001 package calhoun.analysis.crf.features.supporting;
002
003 import java.io.Serializable;
004 import java.util.List;
005
006 import calhoun.analysis.crf.io.InputSequence;
007 import calhoun.analysis.crf.io.TrainingSequence;
008 import calhoun.seq.KmerHasher;
009 import calhoun.util.Assert;
010
011 public class MarkovPredictorLogprob implements Serializable {
012 private static final long serialVersionUID = -3801967331270040642L;
013
014 // For each state S, this class is capable of returning
015 // log(p_S(x_i|x_i-h,...,x_i-2,x_i-1)) h = degree of predictor
016
017 // Trained using a training sequence with hidden states labeled.
018 // For training, one only wants to look at positions for which not only
019 // the current hidden state but also the preceding h are as specified.
020
021 // For example, for a fifth order predictor for the exon3 state, I might only
022 // want to train using positions in the training data for which the preceding
023 // five hidden states were (e1,e2,e3,12,e2).
024
025 boolean trainedYet = false;
026 List<int[]> history;
027 int nStates; // number of hidden states
028 int[] kmerLengths; // order of Markov model (5th order means x_i predicted based on previous five).
029 KmerHasher[] hashers; // One for each state (they might need different history lengths, so use different hashers)
030 float[][][] logProb; // Log probabilities. First dimension is the state, second is indexed by the hash of the history, third is the hash of the current position.
031 int maxLength;
032 KmerHasher h; // for a single letter. Used to index the third dimension of the log prob.
033
034 int[] currentHash;
035 char[] scratch;
036
037 /** History is a list that contains one entry for each state. Each entry in the list is in turn a list of the preceding states
038 * that we want to examine.
039 *
040 * @param history
041 */
042 public MarkovPredictorLogprob(List<int[]> history) {
043 this.history = history;
044
045 nStates = history.size();
046
047 kmerLengths = new int[nStates];
048 hashers = new KmerHasher[nStates];
049 logProb = new float[nStates][][];
050 currentHash = new int[nStates];
051
052 h = new KmerHasher(KmerHasher.ACGTN, 1);
053
054 maxLength = 0;
055 for (int j=0; j<nStates; j++) {
056 kmerLengths[j] = history.get(j).length;
057 maxLength = Math.max(kmerLengths[j], maxLength);
058 hashers[j] = new KmerHasher(KmerHasher.ACGTN, kmerLengths[j]);
059 int historyRange = hashers[j].range();
060 int range = h.range();
061 logProb[j] = new float[historyRange][h.range()];
062 for (int k=0; k<historyRange; k++) {
063 for (int l=0; l<range; l++) {
064 logProb[j][k][l] = 1.0f; // Initializing the pseudocounts
065 }
066 }
067 }
068 scratch = new char[maxLength];
069 }
070
071 transient InputSequence<? extends Character> lastSeq = null;
072 int lastPos = -1;
073 char lastChar;
074 int singleHash;
075 public float logprob(int state, InputSequence<? extends Character> seq, int pos) {
076 /* Returns the log probability of nucleotide at position pos in ISC given
077 * that the hidden state at position pos is state and given the previous several
078 * (about 4, depends on how trained) nucleotides in ISC.*/
079 Assert.a(trainedYet);
080
081 if (pos < maxLength || pos >= seq.length()) {
082 return (float) 0.0;
083 }
084
085 if(seq != lastSeq || pos != lastPos) {
086 if(pos == lastPos + 1) {
087 updateHashes(lastChar);
088 }
089 else {
090 initHashes(seq, pos);
091 }
092 lastSeq = seq;
093 lastPos = pos;
094 lastChar = seq.getX(pos);
095 singleHash = h.hash(lastChar);
096 }
097
098 return logProb[state][currentHash[state]][singleHash];
099 }
100
101
102 public void train(List<? extends TrainingSequence<? extends Character>> data) {
103 if (trainedYet) {
104 return;
105 }
106
107 // Loop through the data and increment all the counts
108 for(TrainingSequence<? extends Character> seq : data) {
109 int len = seq.length();
110
111 initHashes(seq, maxLength);
112 for (int pos=maxLength; pos<len; pos++) {
113 int v = seq.getY(pos);
114 boolean okHistory = true;
115 for (int i=0; i<kmerLengths[v]; i++) {
116 if (history.get(v)[i] != seq.getY(pos-kmerLengths[v]+i)) {
117 okHistory = false;
118 break;
119 }
120 }
121
122 char c = seq.getX(pos);
123 if (okHistory) {
124 logProb[v][currentHash[v]][h.hash(c)] += 1.0f;
125 }
126 updateHashes(c);
127 }
128 }
129
130 // Then normalize the counts
131 for (int v=0; v<nStates; v++) {
132 for (int j=0; j<logProb[v].length; j++) {
133 float norm = (float) 0.0;
134 for (int k=0; k<logProb[v][j].length; k++) {
135 norm = norm + logProb[v][j][k];
136 }
137 double logNorm = Math.log(norm);
138 for (int k=0; k<h.range(); k++) {
139 logProb[v][j][k] = (float) (Math.log(logProb[v][j][k]) - logNorm);
140 }
141 }
142 }
143 trainedYet = true;
144 }
145
146 void initHashes(InputSequence<? extends Character> seq, int start) {
147 // Fill the scratch buffer with the first entries, from 0 to maxLength-1
148 for (int pos=0; pos<maxLength; pos++) {
149 scratch[pos] = seq.getX(pos-maxLength+start);
150 // Note: if you crash on above line with class cast exception, then a possible
151 // cause is that you are using a composite input, but in you config file you
152 // do not specify which input component this feature should be using.
153 }
154 for(int i = 0; i<currentHash.length; ++i) {
155 currentHash[i] = hashers[i].hash(scratch, maxLength - kmerLengths[i]);
156 }
157 }
158
159 void updateHashes(char c) {
160 for(int i = 0; i<currentHash.length; ++i) {
161 currentHash[i] = hashers[i].shiftHash(c, currentHash[i]);
162 }
163 }
164 }