001    package calhoun.analysis.crf.features.tricycle13;
002    
003    import java.util.ArrayList;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.AbstractFeatureManager;
010    import calhoun.analysis.crf.FeatureList;
011    import calhoun.analysis.crf.FeatureManagerEdge;
012    import calhoun.analysis.crf.ModelManager;
013    import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
014    import calhoun.analysis.crf.features.supporting.phylogenetic.ColumnConditionalLogProbability;
015    import calhoun.analysis.crf.features.supporting.phylogenetic.EvolutionaryModel;
016    import calhoun.analysis.crf.io.CompositeInput;
017    import calhoun.analysis.crf.io.InputSequence;
018    import calhoun.analysis.crf.io.TrainingSequence;
019    import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
020    import calhoun.seq.KmerHasher;
021    import calhoun.util.Assert;
022    import calhoun.util.DenseBooleanMatrix2D;
023    
024    public class PWM_evolution extends AbstractFeatureManager<CompositeInput> implements FeatureManagerEdge<CompositeInput> {
025            private static final long serialVersionUID = -7659288739348604129L;
026            private static final Log log = LogFactory.getLog(PWM_evolution.class);
027            boolean debug = log.isDebugEnabled();
028            
029            /* PWM evolution is like the position weight matrix fetaures for modeling the
030             * boundaries between two extensive states (eg a donor site separating exons from
031             * introns).  However, this feature does more because it not only does a PWM
032             * for the reference sequence but also trains a Kimura80 model for nucleotide
033             * evolution at each position of the feature (each position gets its own model).
034            
035             * Note that one must subtract double-counting corrections, since these bases would
036             * otherwise have been modeled using the exon states or the intron states.
037             * We'll use the same machinery that PWM uses for knowing what to subtract (ie what is
038             * being replaced.  However, here we must subtract out not only the prior assignment to the reference sequence (as we did with PWM)
039             * but also the prior assignemnt of conditional probability of the multiple alignment
040             * column given the reference sequence and evolutionary model.
041             * 
042             * Note similarities to both PositionWeightMatrixFeatures.java
043             *   and FelsensteinFeatures.java
044             */
045            
046            int startIx;  // The index of the first feature managed by this FeatureManager
047            ModelManager model;
048            
049            // Following block are things that depend explicitly on and are calculated directly
050            //  from geometry, included only for convenience. 
051            int nFeatures;
052            int[] span;
053            int[] offset;
054            int[] nTrans;
055    
056    
057            
058            DenseBooleanMatrix2D[] transitions;
059            
060            // New requirement: each Feture, or element of geometry, describes a single transition
061            // The variable geometry is the information that is needed to initialize
062            List<int[]> geometry; /* For each i, geometry[i] describes the geometry of one PWM feature.
063                                   * 0) geometry[i][0] is the span of the PWM
064                                   * 1) geometry[i][1] is the offset j of the PWM, so that the feature for
065                                   *     position i relates the following observable and hidden states:
066                                   *     y_(i-1), y_i, x_(i-j), x_(i-j+1), ... , x_(i-j+span-1)
067                                   * 2) geometry[i][2] is yprev
068                                   * 3) geometry[i][3] is y  */
069            
070            // These are the parameters that need to be trained empirically.
071            List<float[][]> logprob; /* logprob[i][j] is log probability of base j at position i, i=0..(span-1), j=0..3. */
072    
073            
074            
075            List<int[]> dcc;  // DoubleCounting Correction
076            MarkovPredictorLogprob predictorlp;
077            
078            // This stuff is similar to FelsensteinFeatures:
079            List<int[]> clusters;             // will be an input
080            List<EvolutionaryModel> emodels;  // this gets trained
081            int[] state2cluster;
082            static KmerHasher h = new KmerHasher(KmerHasher.ACGTother,1);
083            ColumnConditionalLogProbability mo;
084            boolean tieFlag;
085    
086            InputSequence<? extends CompositeInput> lastSeq;
087            int lastPos;
088            float[] vals;
089            private int nUpdate = 0;
090            
091    
092            public PWM_evolution(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters) {
093                    tieFlag = false;
094                    PWM_evolution_setup(geometry,dccorrection, markovhistory,clusters);
095            }
096    
097            public PWM_evolution(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters, List<int[]> flags) {
098                    tieFlag = true;
099                    PWM_evolution_setup(geometry,dccorrection, markovhistory,clusters);
100            }
101    
102            private void PWM_evolution_setup(List<int[]> geometry1, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters1) {
103                    dcc = dccorrection;
104                    predictorlp = new MarkovPredictorLogprob(markovhistory);
105                    mo = new ColumnConditionalLogProbability(clusters1,0); // Zero corresponds to default Kimura80 model
106    
107                    this.geometry = geometry1;
108                    this.clusters = clusters1;
109                    
110                    setupGeometry();        
111            }
112    
113            private void setupGeometry() {
114                    nFeatures = geometry.size();
115                    span = new int[nFeatures];
116                    offset = new int[nFeatures];
117                    nTrans = new int[nFeatures];
118                    h = new KmerHasher(KmerHasher.ACGTN, 1);
119                    logprob = new ArrayList<float[][]>();
120                    for (int i=0; i<nFeatures; i++) {
121                            nTrans[i] = (geometry.get(i).length - 2)/2;
122                            span[i] = geometry.get(i)[0];
123                            offset[i] = geometry.get(i)[1];
124                            Assert.a(offset[i]>=0); // So the span of the transition is WITHIN the span of the span of the PWM
125                            float[][] lp = new float[span[i]][h.range()];
126                            logprob.add(lp);
127                    }
128                    
129                    Assert.a(geometry.size()==nFeatures);
130                    Assert.a(dcc.size()==nFeatures);                
131                    for (int j=0; j<nFeatures; j++) {
132                            Assert.a(nTrans[j]==1);
133                            Assert.a(dcc.get(j).length == span[j]);
134                    }
135            }
136    
137            public int getNumFeatures() {
138                    if (tieFlag) { return 1; }
139                    return nFeatures;
140            }       
141            
142            public String getFeatureName(int featureIndex) {
143                    int raw = featureIndex - startIx;
144                    int[] X = geometry.get(raw);
145                    
146                    String ret = "PWM.span" + X[0] + ".offset" + X[1];
147                    for (int j=2; j<X.length; j+=2) {
148                            ret = ret + ".(" + model.getStateName(X[j]) + "," + model.getStateName(X[j+1]) + ")"; 
149                    }       
150                    return ret;
151            }
152    
153    
154            //static int count = 0;
155            public void evaluateEdge(InputSequence<? extends CompositeInput> seq, int pos, int previousState, int state, FeatureList result) {
156                    if(pos == 0) {
157                            return;
158                    }
159                    
160                    if(seq != lastSeq || pos != lastPos) {
161                            lastSeq = seq;
162                            lastPos = pos;
163                            //System.out.println("Pos: "+pos+" Seq: "+seq);
164                            //if(++count > 10) {
165                            //      throw new RuntimeException();
166                            //}
167                            updateVals(seq, pos);
168                    }
169                    
170                    for (int j=0; j<nFeatures; j++) {
171                            if(transitions[j].getQuick(previousState, state)) {
172                                    if (tieFlag) { 
173                                            result.addFeature(startIx, vals[j]);                                    
174                                    } else {
175                                            result.addFeature(startIx + j, vals[j]);
176                                    }
177                            }
178                    }               
179            }
180    
181            void updateVals(InputSequence<? extends CompositeInput> seq, int ix) {
182                    nUpdate ++;
183                    for (int j=0; j<nFeatures; j++) {
184                            // Figure out what needed for Feature j, which might have several valid y pairs, at position i
185                            int[] geo = geometry.get(j);
186                            int spn = geo[0];
187                            int offset1 = geo[1];
188                            float val = 0;
189                            if ((ix>=offset1) && ((ix-offset1+spn)<=(seq.length())) ) {
190                                    InputSequence<Character>                CIS = (InputSequence<Character>) seq.getComponent("ref");
191                                    InputSequence<MultipleAlignmentColumn>  MIS = (InputSequence<MultipleAlignmentColumn>) seq.getComponent("aln");
192                                    
193                                    for (int i=0; i<spn; i++) {
194                                            int pos = ix - offset1 + i;
195                                            char c = CIS.getX(pos);
196                                            val = val + logprob.get(j)[i][h.hash(c)];
197                                    }
198                                    
199                                    Assert.a(nTrans[j]==1);
200                                    Assert.a(dcc.get(j).length == spn);
201                                    
202                                    for (int i=0; i<spn; i++) {
203                                            // dcc.get(j)[i] is a state (a number 0-12); ix-offset+i is a position
204                                            // This is for subtracting to correct for what was added by the feature
205                                            // MarkovPredictorLogProb and is being replaced by the PWM
206                                            val = val - predictorlp.logprob(dcc.get(j)[i],CIS,ix-offset1+i);
207                                            // Below subtracts the correction for the feature FelsensteinFeatures
208                                            // which we are now replacing with the feature PWM_evolution at this position.
209                                            val = (float) (val - mo.condLogProb(MIS,ix-offset1+i,dcc.get(j)[i]));
210                                    }
211                            }
212                            vals[j] = val;
213                    }                       
214            }
215    
216            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends CompositeInput>> data) {
217                    startIx = startingIndex;
218                    model = modelInfo;
219    
220                    vals = new float[nFeatures];
221                    
222                    List<TrainingSequence<Character>> LTSC = new ArrayList<TrainingSequence<Character>>();
223                    List<TrainingSequence<MultipleAlignmentColumn>> LTSMA = new ArrayList<TrainingSequence<MultipleAlignmentColumn>>();             
224                    
225                    for (int j=0; j<data.size(); j++) {
226                            LTSC.add( data.get(j).getTrainingComponent("ref") );
227                            LTSMA.add( data.get(j).getTrainingComponent("aln") );                   
228                    }
229                    
230                    
231                    predictorlp.train(LTSC);
232                    mo.train(model,LTSMA);
233                    
234                    for (int i=0; i<nFeatures; i++) {
235                            float[][] A = new float[span[i]][h.range()];
236                            logprob.add(A);
237                    }
238                    
239                    // Initialize an array to hold the Feature values which will be passed back:
240                    int nStates = model.getNumStates();
241                    transitions = new DenseBooleanMatrix2D[nFeatures];
242                    for (int i=0; i<nFeatures; i++) {
243                            transitions[i] = new DenseBooleanMatrix2D(nStates, nStates);
244                            for (int k=2; k<geometry.get(i).length; k+=2) {
245                                    transitions[i].setQuick(geometry.get(i)[k], geometry.get(i)[k+1], true);
246                            }
247                    }
248            
249                    for (int i=0; i<nFeatures; i++) {
250                            for (int j=0; j<span[i]; j++) {
251                                    for (int k=0; k<h.range(); k++) {
252                                            logprob.get(i)[j][k]=(float) 1.0;
253                                    }
254                            }       
255                    }
256    
257                    // In English, what I want to do is this.  Loop through all of the training data, once for each Feature.
258                    // While so doing, look for any positions where one of the allowed transitions for that feature occurs.
259                    // At such positions, increment the counts for logprob.
260                    for(TrainingSequence<Character> seq : LTSC) {
261                            int len = seq.length();
262                    
263                            for (int i=0; i<nFeatures; i++) {
264                                    for (int ix=0; ix<len; ix++) {
265                                            if ((ix>=offset[i]) && (ix-offset[i]+span[i] <= seq.length()) && (ix > 0) ) {
266                                                    int yprev = seq.getY(ix-1);
267                                                    int y = seq.getY(ix);
268                                                    for (int j=0; j<nTrans[i]; j++) {
269                                                            if ( (yprev == geometry.get(i)[2+2*j]) && (y == geometry.get(i)[2+2*j+1]) ) {
270                                                                    for (int pos=0; pos<span[i]; pos++) {
271                                                                            char c = seq.getX(ix - offset[i] + pos);
272                                                                            logprob.get(i)[pos][h.hash(c)] += 1.0;
273                                                                    }
274                                                            }
275                                                    }
276                                            }
277                                    }
278                            }       
279                    }
280                    
281                    // Above we got counts; we wanted logarithms; here is also where we normalize:          
282                    for (int i=0; i<nFeatures; i++) {                    
283                            for (int j=0; j<span[i]; j++) {
284                                    float norm = (float) 0.0;
285                                    for (int k=0; k<h.range(); k++) {
286                                            norm += (float) logprob.get(i)[j][k];
287                                    }
288                                    Assert.a(norm>0);
289                                    for (int k=0; k<h.range(); k++) {
290                                            logprob.get(i)[j][k] = (float) (Math.log(logprob.get(i)[j][k]) - Math.log(norm));
291                                    }
292                            }       
293                    }
294    
295            }
296    }
297