001    package calhoun.analysis.crf.features.tricycle13;
002    
003    import java.util.ArrayList;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.AbstractFeatureManager;
010    import calhoun.analysis.crf.FeatureList;
011    import calhoun.analysis.crf.FeatureManagerEdge;
012    import calhoun.analysis.crf.ModelManager;
013    import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
014    import calhoun.analysis.crf.features.supporting.MaxentMotifModel;
015    import calhoun.analysis.crf.io.InputSequence;
016    import calhoun.analysis.crf.io.TrainingSequence;
017    import calhoun.seq.KmerHasher;
018    import calhoun.util.Assert;
019    
020    public class MaxentMotifFeatures extends AbstractFeatureManager<Character> implements FeatureManagerEdge<Character> {
021            private static final long serialVersionUID = -7659288739348604129L;
022            private static final Log log = LogFactory.getLog(MaxentMotifFeatures.class);
023            boolean debug = log.isDebugEnabled();
024            
025            /* This is intended as an upgrade from PWM models of motifs such as splice sites, to 
026             * a more refined model callex maximum entropy model (MEM) that out to have greater
027             * sensitivity and specificity for finding splice sites, and hence also lead to more
028             * accurate gene predictons.  Relative to PWMs, MEMs can capture dependencies
029             * between postions, even non-adjacent positions.
030             * 
031             * GEOMETRY:
032             * To specify a PWM, you specify its GEOMETRY, ie its span and where that span begins
033             * relative to the transition itself (the offset), and the two hidden states before and
034             * after the transition.
035             * 
036             * DOUBLECOUNTING CORRECTION:
037             * Optionally, one may wish to subtract a double-counting correction if the observed
038             * sequence in the window would have been predicted by something else.  For this you 
039             * must also specify the predictor that would have been used by default for each base
040             * (depending on the hidden state), and the sequence of hidden states to which this
041             * predictor would have been applied over the span of the PWM.
042             * At the moment, the double-counting correction is no longer optional; it is now required.
043             * 
044             * Each motif model models a particular transition; e.g. exon3->intron3 has a model
045             * different than exon2->intron2.
046             * 
047             * The MaxEnt constraint definition and iterative scaling method are being ported from
048             * Matlab, where Jade Vinson first implemented this based on paper byBurge and Yeo.
049             */
050            
051            
052            // MEMBER VARIABLES //////////////////////////////////////////////////////
053            
054            // ADMINISTRATIVE OVERHEAD
055            int startIx;  // The index of the first feature managed by this FeatureManager
056            ModelManager model;
057            KmerHasher   h; // for a single letter
058            //DenseBooleanMatrix2D[] transitions;
059            boolean tieFlag = false;                
060            
061            // GEOMETRY
062            int nFeatures;  // derived from geometry
063            int spanLimit = 12;
064            int[] span;             // derived from geometry
065            int[] offset;   // derived from geometry
066            int[] tranfrom;
067            int[] tranto;
068            
069            // DOUBLECOUNTING CORRECTION
070            boolean dcflag;
071            List<int[]> dcc;  // DoubleCounting Correction
072            MarkovPredictorLogprob predictorlp;
073            
074            // OPTIMIZATION OVERHEAD
075            InputSequence<? extends Character> lastSeq;
076            int lastPos;
077            float[] vals;
078            
079            // DATA THAT GETS TRAINED
080            List<double[]> listprob;
081            
082            
083            
084            //      CONSTRUCTORS AND SUPPORT //////////////////////////////////////////////////////
085            
086            public MaxentMotifFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
087                    setThingsUp(geometry,dccorrection,markovhistory);
088            }
089            
090            public MaxentMotifFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> flags) {
091                    tieFlag = true;
092                    setThingsUp(geometry,dccorrection,markovhistory);
093            }
094            
095            private void setThingsUp(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
096                    
097                    // ADMINISTRATIVE
098                    /* geometry, one of the inputs, has the following interpretation:
099                     * For each i, geometry[i] describes the geometry of one PWM feature.
100                     * 0) geometry[i][0] is the span of the PWM
101                     * 1) geometry[i][1] is the offset j of the PWM, so that the feature for
102                     *     position i relates the following observable and hidden states:
103                     *     y_(i-1), y_i, x_(i-j), x_(i-j+1), ... , x_(i-j+span-1)
104                     * 2) geometry[i][2] is yprev
105                     * 3) geometry[i][3] is y  */
106                    
107                    
108                    
109                    nFeatures = geometry.size();
110                    vals = new float[nFeatures];
111                    h = new KmerHasher(KmerHasher.ACGTother, 1);    
112                    
113                    // DOUBLECOUNTING CORRECTION
114                    dcflag = true;
115                    this.predictorlp = new MarkovPredictorLogprob(markovhistory);           
116                    this.dcc = dccorrection;        
117                    
118                    // GEOMETRY
119                    span = new int[nFeatures];
120                    offset = new int[nFeatures];
121                    tranfrom = new int[nFeatures];
122                    tranto = new int[nFeatures];
123                    listprob = new ArrayList<double[]>();
124                    for (int i=0; i<nFeatures; i++) {
125                            span[i]     = geometry.get(i)[0];
126                            offset[i]   = geometry.get(i)[1];
127                            tranfrom[i] = geometry.get(i)[2];
128                            tranto[i]   = geometry.get(i)[3];
129                            Assert.a(span[i] <= spanLimit);
130                            int len=1; for (int j=0; j<span[i]; j++) { len *= 4; }
131                            double[] prob = new double[len];
132                            listprob.add(prob);
133                    }       
134                    
135                    //LOTS OF ASSERTIONS
136                    Assert.a(dccorrection.size()==nFeatures);               
137                    for (int i=0; i<nFeatures; i++) {
138                            Assert.a(  (offset[i]>=0) && (offset[i]<= span[i])  ); // So the span of the transition is WITHIN the span of the span of the PWM
139                            Assert.a(dccorrection.get(i).length == span[i]);
140                    }       
141            }
142            
143            
144            // BORING ADMINISTRATIVE FUNCTIONS //////////////////////////////////////////////////////
145            
146            public int getNumFeatures() {
147                    if (tieFlag) { return 1; }
148                    return nFeatures;
149            }       
150            
151            public String getFeatureName(int featureIndex) {
152                    if (tieFlag) { return "tiedMaxentMotifModels"; }
153                    
154                    int raw = featureIndex - startIx;
155                    
156                    String ret = "MaxentMotifModels.span" + span[raw] + ".offset" + offset[raw] + ".fromState." + model.getStateName(tranfrom[raw]) + ".toState."+ model.getStateName(tranto[raw]);
157                    
158                    return ret;
159            }
160            
161            // EVALUATION FUNCTION AND SUPPORT //////////////////////////////////////////////////////
162            
163            public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int previousState, int state, FeatureList result) {
164                    if(pos == 0) {
165                            return;
166                    }
167                    
168                    if((seq != lastSeq) || (pos != lastPos)) {
169                            lastSeq = seq;
170                            lastPos = pos;
171                            updateVals(seq, pos);
172                    }
173                    
174                    if (tieFlag) {
175                            for (int j=0; j<nFeatures; j++) {
176                                    if( (tranfrom[j]==previousState) && (tranto[j]==state)) {
177                                            result.addFeature(startIx, vals[j]);
178                                    }
179                            }               
180                    } else {
181                            for (int j=0; j<nFeatures; j++) {
182                                    if( (tranfrom[j]==previousState) && (tranto[j]==state)) {
183                                            result.addFeature(startIx + j, vals[j]);
184                                    }
185                            }               
186                    }
187            }
188            
189            
190            
191            
192            
193            void updateVals(InputSequence<? extends Character> seq, int ix) {
194                    for (int j=0; j<nFeatures; j++) {
195                            
196                            int spn = span[j];
197                            int offset1 = this.offset[j];
198                            
199                            // If there are no missing data in the relevant window, then return log( prob according to maxent / prob according to default )
200                            
201                            float val = 0;
202                            if ((ix>=offset1) && ((ix-offset1+spn)<=(seq.length())) ) {
203                                    boolean completeInformation = true;
204                                    for (int i=0; i<spn; i++) {
205                                            int pos = ix - offset1 + i;
206                                            char c = seq.getX(pos);
207                                            if (h.hash(c) == 4) { completeInformation=false; }
208                                    }                               
209                                    
210                                    if (completeInformation) {
211                                            int hash = 0;
212                                            for (int i=0; i<spn; i++) {
213                                                    int pos = ix - offset1 + i;
214                                                    char c = seq.getX(pos);
215                                                    hash = hash*4 + h.hash(c);
216                                            }
217                                            double exval = listprob.get(j)[hash];
218                                            if (exval>0) {
219                                                    val = (float) Math.log(exval);
220                                            } else {
221                                                    log.info("Refusing to take log of zero, returning a large penalty instead.");
222                                                    val = -4000;
223                                            }
224                                                    
225                                            if (dcflag) {                                           
226                                                    for (int i=0; i<spn; i++) {
227                                                            val = val - predictorlp.logprob(dcc.get(j)[i],seq,ix-offset1+i);                                        
228                                                    }
229                                            }
230                                    }
231                            }
232                            vals[j] = val;
233                    }                       
234            }
235            
236            
237            // TRAINING FUNCTION //////////////////////////////////////////////////////
238            
239            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
240                    startIx = startingIndex;
241                    model = modelInfo;
242                    
243                    log.debug("Training a maxent motif feature manager");
244                    
245                    for (int i=0; i<nFeatures; i++) {
246                            List<int[]> motifExamples = new ArrayList<int[]>();
247                            // Loop through all of the training data and record all fully present (ie no missing data) examples of the motif.
248                            // Use the hasher to get values between 0-3.
249                            
250                            log.debug("Training a maxent motif feature with span " + span[i]);
251                            
252                            for(TrainingSequence<? extends Character> seq : data) {
253                                    int len = seq.length();
254                                    
255                                    for (int ix=offset[i]; ix<(len-span[i]+offset[i]); ix++) {
256                                            if (ix<=0) continue;
257                                            if (ix>=len) continue;
258                                            int yprev = seq.getY(ix-1);
259                                            int y = seq.getY(ix);
260                                            
261                                            if ( (yprev == tranfrom[i]) && (y == tranto[i]) ) {
262                                                    int[] example = new int[span[i]];
263                                                    boolean completeInformation = true;
264                                                    for (int pos=0; pos<span[i]; pos++) {
265                                                            char c = seq.getX(ix - offset[i] + pos);
266                                                            example[pos] = h.hash(c);
267                                                            if (example[pos]==4) { completeInformation = false; } 
268                                                    }
269                                                    if (completeInformation) { motifExamples.add(example); }
270                                            }
271                                    }       
272                            }
273                                                    
274                            //List<Constraint> motifConstraints = makeAllPairwiseConstraints(motifExamples);
275                            
276                            //listprob.set( i , trainMaxentDistribution(motifConstraints,span[i]) );
277                            listprob.set( i , MaxentMotifModel.trainMaxentDistributionUsingAllPairwiseConstraints(motifExamples,span[i],1000,0.01) );
278                    }
279                                    
280                    if (dcflag) {
281                            predictorlp.train(data);
282                    }
283            }
284    }
285