001    package calhoun.analysis.crf.features.tricycle13;
002    
003    import java.io.Serializable;
004    import java.util.ArrayList;
005    import java.util.List;
006    
007    import org.apache.commons.logging.Log;
008    import org.apache.commons.logging.LogFactory;
009    
010    import calhoun.analysis.crf.AbstractFeatureManager;
011    import calhoun.analysis.crf.CacheStrategySpec;
012    import calhoun.analysis.crf.FeatureList;
013    import calhoun.analysis.crf.FeatureManagerEdge;
014    import calhoun.analysis.crf.ModelManager;
015    import calhoun.analysis.crf.BeanModel.Node;
016    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
017    import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
018    import calhoun.analysis.crf.features.tricycle13.EmissionMarkovFeature.MarkovHistory;
019    import calhoun.analysis.crf.io.InputSequence;
020    import calhoun.analysis.crf.io.TrainingSequence;
021    import calhoun.seq.KmerHasher;
022    import calhoun.util.Assert;
023    import calhoun.util.DenseBooleanMatrix2D;
024    
025    public class PositionWeightMatrixFeatures extends AbstractFeatureManager<Character> implements FeatureManagerEdge<Character> {
026            private static final long serialVersionUID = -7659288739348604129L;
027            private static final Log log = LogFactory.getLog(PositionWeightMatrixFeatures.class);
028            boolean debug = log.isDebugEnabled();
029            
030            // Position weight matrices are used to model transitions between two extensive features,
031            // for example a donor model for the transition between a positive stranded exon and a
032            // positive stranded intron of the appropriate phase.
033            
034            // To specify a PWM, you specify its GEOMETRY, ie its span and where that span begins
035            // relative to the transition itself (the offset), and the two hidden states before and
036            // after the transition.
037            
038            // The features returned are that transition times the log probability of the observed
039            // sequence within the window/span being modeled.  This must then be trained using
040            // a TraningSequence<Character>.
041            
042            // Optionally, one may wish to subtract a double-counting correction if the observed
043            // sequence in the window would have been predicted by something else.  For this you 
044            // must also specify the predictor that would have been used by default for each base
045            // (depending on the hidden state), and the sequence of hidden states to which this
046            // predictor would have been applied over the span of the PWM.
047            
048            int startIx;  // The index of the first feature managed by this FeatureManager
049            ModelManager model;
050            
051            // Following block are things that depend explicitly on and are calculated directly
052            //  from geometry, included only for convenience. 
053            int nFeatures;
054            int[] span;
055            int[] offset;
056            int[] nTrans;
057            KmerHasher   h; // for a single letter
058            DenseBooleanMatrix2D[] transitions;
059            
060            // The variable geometry is the information that is needed to initialize
061            List<int[]> geometry; /* For each i, geometry[i] describes the geometry of one PWM feature.
062                                   * 0) geometry[i][0] is the span of the PWM
063                                   * 1) geometry[i][1] is the offset j of the PWM, so that the feature for
064                                   *     position i relates the following observable and hidden states:
065                                   *     y_(i-1), y_i, x_(i-j), x_(i-j+1), ... , x_(i-j+span-1)
066                                   * 2) geometry[i][2] is yprev
067                                   * 3) geometry[i][3] is y  */
068            
069            // These are the parameters that need to be trained empirically.
070            List<float[][]> logprob; /* logprob[i][j] is log probability of base j at position i, i=0..(span-1), j=0..3. */
071    
072            
073            
074            // Following block is something that is only meaningful
075            // if you're going to subtract the doublecounting correction
076            // If you use this correction, you are currently required that each
077            // Feature describes exactly one transition.
078            boolean dcflag;
079            List<int[]> dcc;  // DoubleCounting Correction
080            MarkovPredictorLogprob predictorlp;
081    
082            
083            transient InputSequence<? extends Character> lastSeq;
084            int lastPos;
085            float[] vals;
086            
087            boolean tieFlag = false;
088            
089            
090            int UVCount = 0;
091            
092            public static class Geometry implements Serializable {
093                    private static final long serialVersionUID = 4896358213027322167L;
094    
095                    int size;
096                    int transition;
097                    Node prev;
098                    Node current;
099    
100                    List<Node> overlapCorrections;
101    
102                    public Node getCurrent() {
103                            return current;
104                    }
105                    public void setCurrent(Node current) {
106                            this.current = current;
107                    }
108                    public List<Node> getOverlapCorrections() {
109                            return overlapCorrections;
110                    }
111                    public void setOverlapCorrections(List<Node> overlapCorrections) {
112                            this.overlapCorrections = overlapCorrections;
113                    }
114                    public Node getPrev() {
115                            return prev;
116                    }
117                    public void setPrev(Node prev) {
118                            this.prev = prev;
119                    }
120                    public int getSize() {
121                            return size;
122                    }
123                    public void setSize(int size) {
124                            this.size = size;
125                    }
126                    public int getTransition() {
127                            return transition;
128                    }
129                    public void setTransition(int transition) {
130                            this.transition = transition;
131                    }
132            }
133    
134            List<Geometry> pwmGeometry;
135            MarkovHistory markovHistory;
136            
137            ///////////////////////////////////// Class variables above, methods below //////////
138            public PositionWeightMatrixFeatures() { }
139            
140            public PositionWeightMatrixFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
141                    setThingsUp(geometry,dccorrection,markovhistory);
142            }
143    
144            public PositionWeightMatrixFeatures(List<int[]> geometry2, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> flags) {
145                    tieFlag = true;
146                    setThingsUp(geometry2,dccorrection,markovhistory);
147            }
148    
149            public void init() {
150                    List<int[]> geometry1 = new ArrayList(pwmGeometry.size());
151                    List<int[]> dccorrection = new ArrayList(pwmGeometry.size());
152                    for(Geometry g : pwmGeometry) {
153                            int[] params = new int[4];
154                            params[0] = g.getSize();
155                            params[1] = g.getTransition();
156                            params[2] = g.getPrev().getIndex();
157                            params[3] = g.getCurrent().getIndex();
158                            geometry1.add(params);
159                            int[] correction = new int[g.overlapCorrections.size()];
160                            for(int i=0; i<correction.length; ++i) {
161                                    correction[i] = g.overlapCorrections.get(i).getIndex();
162                            }
163                            dccorrection.add(correction);
164                    }
165                    setThingsUp(geometry1, dccorrection, markovHistory.convert());
166            }
167            
168            private void setThingsUp(List<int[]> geometry2, List<int[]> dccorrection, List<int[]> markovhistory) {
169                    this.predictorlp = new MarkovPredictorLogprob(markovhistory);
170                    
171                    setupGeometry(geometry2);       
172                    
173                    Assert.a(geometry.size()==nFeatures);
174                    Assert.a(dccorrection.size()==nFeatures);               
175                    for (int j=0; j<nFeatures; j++) {
176                            Assert.a(nTrans[j]==1);
177                            Assert.a(dccorrection.get(j).length == span[j]);
178                    }
179                    setupDoubleCountCorrections(dccorrection,predictorlp);          
180            }
181            
182            
183            private void setupDoubleCountCorrections(List<int[]> dccorrection, MarkovPredictorLogprob predictorlp) {
184                    dcflag = true;
185                    
186                    this.predictorlp = predictorlp;
187                    this.dcc = dccorrection;
188            }
189    
190            private void setupGeometry(List<int[]> geometry) {
191                    this.geometry = geometry;
192                    nFeatures = geometry.size();
193                    span = new int[nFeatures];
194                    offset = new int[nFeatures];
195                    nTrans = new int[nFeatures];
196                    vals = new float[nFeatures];
197                    h = new KmerHasher(KmerHasher.ACGTN, 1);
198                    logprob = new ArrayList<float[][]>();
199                    for (int i=0; i<nFeatures; i++) {
200                            nTrans[i] = (geometry.get(i).length - 2)/2;
201                            span[i] = geometry.get(i)[0];
202                            offset[i] = geometry.get(i)[1];
203                            Assert.a(offset[i]>=0); // So the span of the transition is WITHIN the span of the span of the PWM
204                            float[][] lp = new float[span[i]][h.range()];
205                            logprob.add(lp);
206                    }
207            }
208    
209            public int getNumFeatures() {
210                    if (tieFlag) { return 1; }
211                    return nFeatures;
212            }       
213            
214            public String getFeatureName(int featureIndex) {
215                    if (tieFlag) { return "tiedPwmFeature"; }
216                    
217                    int raw = featureIndex - startIx;
218                    int[] X = geometry.get(raw);
219                    
220                    String ret = "PWM.span" + X[0] + ".offset" + X[1];
221                    for (int j=2; j<X.length; j+=2) {
222                            ret = ret + ".(" + model.getStateName(X[j]) + "," + model.getStateName(X[j+1]) + ")"; 
223                    }       
224                    return ret;
225            }
226    
227    
228            
229            //static int count = 0;
230            public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int previousState, int state, FeatureList result) {
231                    if(pos == 0) {
232                            return;
233                    }
234                    
235                    if((seq != lastSeq) || (pos != lastPos)) {
236                    //if((pos != lastPos)) {
237                            lastSeq = seq;
238                            lastPos = pos;
239                            updateVals(seq, pos);
240                            //System.out.println("Pos: "+pos+" Seq: "+seq);
241                            //if(++count > 10) {
242                            //      throw new RuntimeException();
243                            //}
244                    }
245                    
246                    if (tieFlag) {
247                            for (int j=0; j<nFeatures; j++) {
248                                    if(transitions[j].getQuick(previousState, state)) {
249                                            result.addFeature(startIx, vals[j]);
250                                    }
251                            }               
252                    } else {
253                            for (int j=0; j<nFeatures; j++) {
254                                    if(transitions[j].getQuick(previousState, state)) {
255                                            result.addFeature(startIx + j, vals[j]);
256                                    }
257                            }               
258                    }
259            }
260    
261            void updateVals(InputSequence<? extends Character> seq, int ix) {
262                    UVCount++;
263                    for (int j=0; j<nFeatures; j++) {
264                            // Figure out what needed for Feature j, which might have several valid y pairs, at position i
265                            int[] geo = geometry.get(j);
266                            int spn = geo[0];
267                            int offset1 = geo[1];
268                            float val = 0;
269                            if ((ix>=offset1) && ((ix-offset1+spn)<=(seq.length())) ) {
270                                    for (int i=0; i<spn; i++) {
271                                            int pos = ix - offset1 + i;
272                                            char c = seq.getX(pos);
273                                            val = val + logprob.get(j)[i][h.hash(c)];
274                                    }
275                            
276                                    if (dcflag) {
277                                            Assert.a(nTrans[j]==1);
278                                            Assert.a(dcc.get(j).length == spn);
279                                            
280                                            for (int i=0; i<spn; i++) {
281                                                    val = val - predictorlp.logprob(dcc.get(j)[i],seq,ix-offset1+i);                                        
282                                            }
283                                    }
284                            }
285                            vals[j] = val;
286                    }                       
287            }
288    
289            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
290                    startIx = startingIndex;
291                    model = modelInfo;
292            
293                    for (int i=0; i<nFeatures; i++) {
294                            float[][] A = new float[span[i]][h.range()];
295                            logprob.add(A);
296                    }
297                    
298                    // Initialize an array to hold the Feature values which will be passed back:
299                    int nStates = model.getNumStates();
300                    transitions = new DenseBooleanMatrix2D[nFeatures];
301                    for (int i=0; i<nFeatures; i++) {
302                            transitions[i] = new DenseBooleanMatrix2D(nStates, nStates);
303                            for (int k=2; k<geometry.get(i).length; k+=2) {
304                                    transitions[i].setQuick(geometry.get(i)[k], geometry.get(i)[k+1], true);
305                            }
306                    }
307            
308                    for (int i=0; i<nFeatures; i++) {
309                            for (int j=0; j<span[i]; j++) {
310                                    for (int k=0; k<h.range(); k++) {
311                                            logprob.get(i)[j][k]=(float) 1.0;
312                                    }
313                            }       
314                    }
315    
316                    // In English, what I want to do is this.  Loop through all of the training data, once for each Feature.
317                    // While so doing, look for any positions where one of the allowed transitions for that feature occurs.
318                    // At such positions, increment the counts for logprob.
319                    for(TrainingSequence<? extends Character> seq : data) {
320                            int len = seq.length();
321                    
322                            for (int i=0; i<nFeatures; i++) {
323                                    for (int ix=0; ix<len; ix++) {
324                                            if ((ix>=offset[i]) && (ix-offset[i]+span[i] <= seq.length()) && (ix > 0) ) {
325                                                    int yprev = seq.getY(ix-1);
326                                                    int y = seq.getY(ix);
327                                                    for (int j=0; j<nTrans[i]; j++) {
328                                                            if ( (yprev == geometry.get(i)[2+2*j]) && (y == geometry.get(i)[2+2*j+1]) ) {
329                                                                    for (int pos=0; pos<span[i]; pos++) {
330                                                                            char c = seq.getX(ix - offset[i] + pos);
331                                                                            logprob.get(i)[pos][h.hash(c)] += 1.0;
332                                                                    }
333                                                            }
334                                                    }
335                                            }
336                                    }
337                            }       
338                    }
339                    
340                    // Above we got counts; we wanted logarithms; here is also where we normalize:          
341                    for (int i=0; i<nFeatures; i++) {                    
342                            for (int j=0; j<span[i]; j++) {
343                                    float norm = (float) 0.0;
344                                    for (int k=0; k<h.range(); k++) {
345                                            norm += (float) logprob.get(i)[j][k];
346                                    }
347                                    Assert.a(norm>0);
348                                    for (int k=0; k<h.range(); k++) {
349                                            logprob.get(i)[j][k] = (float) (Math.log(logprob.get(i)[j][k]) - Math.log(norm));
350                                    }
351                            }       
352                    }
353    
354                    if (dcflag) {
355                            predictorlp.train(data);
356                    }
357            }
358    
359            @Override
360            public CacheStrategySpec getCacheStrategy() {
361                    return new CacheStrategySpec(CacheStrategy.SPARSE);
362            }
363    
364            /**
365             * @return Returns the markovHistory.
366             */
367            public MarkovHistory getMarkovHistory() {
368                    return markovHistory;
369            }
370    
371            /**
372             * @param markovHistory The markovHistory to set.
373             */
374            public void setMarkovHistory(MarkovHistory markovHistory) {
375                    this.markovHistory = markovHistory;
376            }
377    
378            /**
379             * @return Returns the pwmGeometry.
380             */
381            public List<Geometry> getPwmGeometry() {
382                    return pwmGeometry;
383            }
384    
385            /**
386             * @param pwmGeometry The pwmGeometry to set.
387             */
388            public void setPwmGeometry(List<Geometry> pwmGeometry) {
389                    this.pwmGeometry = pwmGeometry;
390            }
391    
392            public boolean isTieFlag() {
393                    return tieFlag;
394            }
395    
396            public void setTieFlag(boolean tieFlag) {
397                    this.tieFlag = tieFlag;
398            }
399    }
400