001    package calhoun.analysis.crf.features.supporting;
002    
003    import java.io.Serializable;
004    
005    import org.apache.commons.logging.Log;
006    import org.apache.commons.logging.LogFactory;
007    
008    import calhoun.analysis.crf.io.InputSequence;
009    import calhoun.seq.KmerHasher;
010    import calhoun.util.Assert;
011    
012    
013    public class PWMLookup implements Serializable {
014            private static final long serialVersionUID = 5353716408013134581L;
015    
016            private static final Log log = LogFactory.getLog(PWMLookup.class);
017    
018            final KmerHasher.CharacterHash hashForward = KmerHasher.ACGTN;
019            final KmerHasher.CharacterHash hashReverse = KmerHasher.ACGTNcomp;              
020    
021            boolean finalized = false;
022            
023            final int mult = 4;
024            
025            int left;
026            int right;
027            int span;
028            
029            double[] lookupTable;
030    
031            public PWMLookup(int lookLeft, int lookRight, double pseudoCount) {
032                    left  = lookLeft;
033                    right = lookRight;
034                    span  = left + right;
035                    
036                    Assert.a( left  >= 0);  // Perhaps all you really need is (lookLeft+lookRight) >= 0, but leave that for another day.
037                    Assert.a( right >= 0);
038                    Assert.a( span  < 30);  // No particular need for this, constraint can be relaxed, just want to be alerted if this happens for now
039    
040                    int tableSize = mult*span;
041            
042                    lookupTable = new double[tableSize];
043                    for (int i=0; i<tableSize; i++) {
044                            lookupTable[i] = pseudoCount;
045                    }
046            }
047    
048            
049            public void increment(InputSequence<? extends Character> seq, int pos, boolean isPlus) {
050                    Assert.a(!finalized);
051                    
052                    if (isPlus) {
053                            if (pos < left) { return; }
054                            if (pos + right > seq.length()) { return; }
055                            for (int j = pos - left; j<pos+right; j++) {
056                                    int h = hashForward.hash(seq.getX(j));
057                                    if (h<4) {
058                                            lookupTable[mult*(j-pos+left) + h] += 1.0;
059                                    }
060                            }
061                    } else {
062                            if (pos < right) { return; }
063                            if (pos + left > seq.length()) { return; }                   
064                            for (int j = pos + left - 1; j>=pos - right; j--) {
065                                    int h = hashReverse.hash(seq.getX(j));
066                                    if (h<4) {
067                                            lookupTable[mult*(pos+left-1-j) + h] += 1.0;
068                                    }
069                            }
070                    }
071            }
072    
073            public void completeCounts() {
074                    Assert.a(!finalized);
075                    log.debug("finalizing a PWMlookup, span="+span + "    mult=" + mult);
076                    
077                    for (int i=0; i<span; i++) {
078                            // System.out.println("finalizing i=" + i);
079                            double sum = 0;
080                            for (int j=mult*i; j<mult*(i+1); j++) {
081                                    sum += lookupTable[j];
082                            }
083                            for (int j=mult*i; j<mult*(i+1); j++) {
084                                    lookupTable[j] = Math.log(lookupTable[j]/sum);
085                                    Assert.a(lookupTable[j]<=0);
086                            }
087                    }
088                    
089                    finalized = true;
090            }
091            
092            public double lookup(InputSequence<? extends Character> seq, int pos, boolean isPlus) {
093                    Assert.a(finalized);
094                    
095                    double ret = 0.0;
096                    
097                    if (isPlus) {
098                            if (pos < left) { return 0.0; }
099                            if (pos + right > seq.length()) { return 0.0; }
100                            for (int j = pos - left; j<pos+right; j++) {
101                                    int h = hashForward.hash(seq.getX(j));
102                                    if (h<4) {
103                                            ret += lookupTable[mult*(j-pos+left) + h];
104                                    }
105                            }
106                    } else {
107                            if (pos < right) { return 0.0; }
108                            if (pos + left > seq.length()) { return 0.0; }                       
109                            for (int j = pos + left - 1; j>=pos - right; j--) {
110                                    int h = hashReverse.hash(seq.getX(j));
111                                    if (h<4) {
112                                            ret += lookupTable[mult*(pos+left-1-j) + h];
113                                    }
114                            }
115                    }
116                    Assert.a(ret<=0);
117                    return ret;
118            }
119    
120    }