001 package calhoun.analysis.crf.features.supporting;
002
003 import java.io.Serializable;
004
005 import org.apache.commons.logging.Log;
006 import org.apache.commons.logging.LogFactory;
007
008 import calhoun.analysis.crf.io.InputSequence;
009 import calhoun.seq.KmerHasher;
010 import calhoun.util.Assert;
011
012
013 public class PWMLookup implements Serializable {
014 private static final long serialVersionUID = 5353716408013134581L;
015
016 private static final Log log = LogFactory.getLog(PWMLookup.class);
017
018 final KmerHasher.CharacterHash hashForward = KmerHasher.ACGTN;
019 final KmerHasher.CharacterHash hashReverse = KmerHasher.ACGTNcomp;
020
021 boolean finalized = false;
022
023 final int mult = 4;
024
025 int left;
026 int right;
027 int span;
028
029 double[] lookupTable;
030
031 public PWMLookup(int lookLeft, int lookRight, double pseudoCount) {
032 left = lookLeft;
033 right = lookRight;
034 span = left + right;
035
036 Assert.a( left >= 0); // Perhaps all you really need is (lookLeft+lookRight) >= 0, but leave that for another day.
037 Assert.a( right >= 0);
038 Assert.a( span < 30); // No particular need for this, constraint can be relaxed, just want to be alerted if this happens for now
039
040 int tableSize = mult*span;
041
042 lookupTable = new double[tableSize];
043 for (int i=0; i<tableSize; i++) {
044 lookupTable[i] = pseudoCount;
045 }
046 }
047
048
049 public void increment(InputSequence<? extends Character> seq, int pos, boolean isPlus) {
050 Assert.a(!finalized);
051
052 if (isPlus) {
053 if (pos < left) { return; }
054 if (pos + right > seq.length()) { return; }
055 for (int j = pos - left; j<pos+right; j++) {
056 int h = hashForward.hash(seq.getX(j));
057 if (h<4) {
058 lookupTable[mult*(j-pos+left) + h] += 1.0;
059 }
060 }
061 } else {
062 if (pos < right) { return; }
063 if (pos + left > seq.length()) { return; }
064 for (int j = pos + left - 1; j>=pos - right; j--) {
065 int h = hashReverse.hash(seq.getX(j));
066 if (h<4) {
067 lookupTable[mult*(pos+left-1-j) + h] += 1.0;
068 }
069 }
070 }
071 }
072
073 public void completeCounts() {
074 Assert.a(!finalized);
075 log.debug("finalizing a PWMlookup, span="+span + " mult=" + mult);
076
077 for (int i=0; i<span; i++) {
078 // System.out.println("finalizing i=" + i);
079 double sum = 0;
080 for (int j=mult*i; j<mult*(i+1); j++) {
081 sum += lookupTable[j];
082 }
083 for (int j=mult*i; j<mult*(i+1); j++) {
084 lookupTable[j] = Math.log(lookupTable[j]/sum);
085 Assert.a(lookupTable[j]<=0);
086 }
087 }
088
089 finalized = true;
090 }
091
092 public double lookup(InputSequence<? extends Character> seq, int pos, boolean isPlus) {
093 Assert.a(finalized);
094
095 double ret = 0.0;
096
097 if (isPlus) {
098 if (pos < left) { return 0.0; }
099 if (pos + right > seq.length()) { return 0.0; }
100 for (int j = pos - left; j<pos+right; j++) {
101 int h = hashForward.hash(seq.getX(j));
102 if (h<4) {
103 ret += lookupTable[mult*(j-pos+left) + h];
104 }
105 }
106 } else {
107 if (pos < right) { return 0.0; }
108 if (pos + left > seq.length()) { return 0.0; }
109 for (int j = pos + left - 1; j>=pos - right; j--) {
110 int h = hashReverse.hash(seq.getX(j));
111 if (h<4) {
112 ret += lookupTable[mult*(pos+left-1-j) + h];
113 }
114 }
115 }
116 Assert.a(ret<=0);
117 return ret;
118 }
119
120 }