001    package calhoun.analysis.crf.features.supporting;
002    
003    import java.io.Serializable;
004    import java.util.List;
005    
006    import calhoun.analysis.crf.io.InputSequence;
007    import calhoun.analysis.crf.io.TrainingSequence;
008    import calhoun.seq.KmerHasher;
009    import calhoun.util.Assert;
010    
011    public class MarkovPredictorLogprob implements Serializable {
012            private static final long serialVersionUID = -3801967331270040642L;
013    
014            // For each state S, this class is capable of returning
015            // log(p_S(x_i|x_i-h,...,x_i-2,x_i-1))    h = degree of predictor
016            
017            // Trained using a training sequence with hidden states labeled.
018            // For training, one only wants to look at positions for which not only
019            // the current hidden state but also the preceding h are as specified.
020            
021            // For example, for a fifth order predictor for the exon3 state, I might only
022            // want to train using positions in the training data for which the preceding
023            // five hidden states were (e1,e2,e3,12,e2).
024            
025            boolean trainedYet = false;
026            List<int[]> history;
027            int nStates;    // number of hidden states
028            int[] kmerLengths;  // order of Markov model (5th order means x_i predicted based on previous five).
029            KmerHasher[] hashers;  // One for each state (they might need different history lengths, so use different hashers)
030            float[][][] logProb;  // Log probabilities.  First dimension is the state, second is indexed by the hash of the history, third is the hash of the current position.
031            int maxLength;
032            KmerHasher h; // for a single letter.  Used to index the third dimension of the log prob.
033            
034            int[] currentHash;
035            char[] scratch;
036    
037            /** History is a list that contains one entry for each state.  Each entry in the list is in turn a list of the preceding states
038             * that we want to examine.
039             * 
040             * @param history
041             */
042            public MarkovPredictorLogprob(List<int[]> history) {
043                    this.history = history;
044    
045                    nStates = history.size();
046    
047                    kmerLengths = new int[nStates];
048                    hashers = new KmerHasher[nStates];
049                    logProb = new float[nStates][][];
050                    currentHash = new int[nStates];
051    
052                    h = new KmerHasher(KmerHasher.ACGTN, 1);
053                            
054                    maxLength = 0;
055                    for (int j=0; j<nStates; j++) {
056                            kmerLengths[j] = history.get(j).length;
057                            maxLength = Math.max(kmerLengths[j], maxLength);
058                            hashers[j] = new KmerHasher(KmerHasher.ACGTN, kmerLengths[j]);
059                            int historyRange = hashers[j].range();
060                            int range = h.range();
061                            logProb[j] = new float[historyRange][h.range()]; 
062                            for (int k=0; k<historyRange; k++) {
063                                    for (int l=0; l<range; l++) {
064                                            logProb[j][k][l] = 1.0f;   // Initializing the pseudocounts
065                                    }
066                            }
067                    }
068                    scratch = new char[maxLength];
069            }
070            
071            transient InputSequence<? extends Character> lastSeq = null;
072            int lastPos = -1;
073            char lastChar;
074            int singleHash;
075            public float logprob(int state, InputSequence<? extends Character> seq, int pos) {
076                    /* Returns the log probability of nucleotide at position pos in ISC given
077                     * that the hidden state at position pos is state and given the previous several
078                     * (about 4, depends on how trained) nucleotides in ISC.*/
079                    Assert.a(trainedYet);
080                    
081                    if (pos < maxLength || pos >= seq.length()) { 
082                            return (float) 0.0; 
083                    }               
084                    
085                    if(seq != lastSeq || pos != lastPos) {
086                            if(pos == lastPos + 1) {
087                                    updateHashes(lastChar);
088                            }
089                            else {
090                                    initHashes(seq, pos);
091                            }
092                            lastSeq = seq;
093                            lastPos = pos;
094                            lastChar = seq.getX(pos);
095                            singleHash = h.hash(lastChar);
096                    }
097                    
098                    return logProb[state][currentHash[state]][singleHash];
099            }
100    
101            
102            public void train(List<? extends TrainingSequence<? extends Character>> data) {
103                    if (trainedYet) { 
104                            return; 
105                    }
106                    
107                    // Loop through the data and increment all the counts
108                    for(TrainingSequence<? extends Character> seq : data) {
109                            int len = seq.length();
110    
111                            initHashes(seq, maxLength);
112                            for (int pos=maxLength; pos<len; pos++) {
113                                    int v = seq.getY(pos);
114                                    boolean okHistory = true;
115                                    for (int i=0; i<kmerLengths[v]; i++) {
116                                            if (history.get(v)[i] != seq.getY(pos-kmerLengths[v]+i)) { 
117                                                    okHistory = false;
118                                                    break;
119                                            }
120                                    }
121                                    
122                                    char c = seq.getX(pos);
123                                    if (okHistory) {
124                                            logProb[v][currentHash[v]][h.hash(c)] += 1.0f;
125                                    }
126                                    updateHashes(c);
127                            }       
128                    }
129                    
130                    // Then normalize the counts
131                    for (int v=0; v<nStates; v++) {
132                            for (int j=0; j<logProb[v].length; j++) {
133                                    float norm = (float) 0.0;
134                                    for (int k=0; k<logProb[v][j].length; k++) {
135                                            norm = norm + logProb[v][j][k];
136                                    }
137                                    double logNorm = Math.log(norm);
138                                    for (int k=0; k<h.range(); k++) {
139                                            logProb[v][j][k] = (float) (Math.log(logProb[v][j][k]) - logNorm);
140                                    }                               
141                            }
142                    }       
143                    trainedYet = true;
144            }
145    
146            void initHashes(InputSequence<? extends Character> seq, int start) {
147                    // Fill the scratch buffer with the first entries, from 0 to maxLength-1
148                    for (int pos=0; pos<maxLength; pos++) {
149                            scratch[pos] = seq.getX(pos-maxLength+start);
150                            // Note: if you crash on above line with class cast exception, then a possible
151                            // cause is that you are using a composite input, but in you config file you
152                            // do not specify which input component this feature should be using.
153                    }
154                    for(int i = 0; i<currentHash.length; ++i) {
155                            currentHash[i] = hashers[i].hash(scratch, maxLength - kmerLengths[i]);
156                    }
157            }
158    
159            void updateHashes(char c) {
160                    for(int i = 0; i<currentHash.length; ++i) {
161                            currentHash[i] = hashers[i].shiftHash(c, currentHash[i]);
162                    }
163            }
164    }