001 package calhoun.analysis.crf.features.tricycle13;
002
003 import java.util.ArrayList;
004 import java.util.List;
005
006 import org.apache.commons.logging.Log;
007 import org.apache.commons.logging.LogFactory;
008
009 import calhoun.analysis.crf.AbstractFeatureManager;
010 import calhoun.analysis.crf.FeatureList;
011 import calhoun.analysis.crf.FeatureManagerEdge;
012 import calhoun.analysis.crf.ModelManager;
013 import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
014 import calhoun.analysis.crf.features.supporting.MaxentMotifModel;
015 import calhoun.analysis.crf.io.InputSequence;
016 import calhoun.analysis.crf.io.TrainingSequence;
017 import calhoun.seq.KmerHasher;
018 import calhoun.util.Assert;
019
020 public class MaxentMotifFeatures extends AbstractFeatureManager<Character> implements FeatureManagerEdge<Character> {
021 private static final long serialVersionUID = -7659288739348604129L;
022 private static final Log log = LogFactory.getLog(MaxentMotifFeatures.class);
023 boolean debug = log.isDebugEnabled();
024
025 /* This is intended as an upgrade from PWM models of motifs such as splice sites, to
026 * a more refined model callex maximum entropy model (MEM) that out to have greater
027 * sensitivity and specificity for finding splice sites, and hence also lead to more
028 * accurate gene predictons. Relative to PWMs, MEMs can capture dependencies
029 * between postions, even non-adjacent positions.
030 *
031 * GEOMETRY:
032 * To specify a PWM, you specify its GEOMETRY, ie its span and where that span begins
033 * relative to the transition itself (the offset), and the two hidden states before and
034 * after the transition.
035 *
036 * DOUBLECOUNTING CORRECTION:
037 * Optionally, one may wish to subtract a double-counting correction if the observed
038 * sequence in the window would have been predicted by something else. For this you
039 * must also specify the predictor that would have been used by default for each base
040 * (depending on the hidden state), and the sequence of hidden states to which this
041 * predictor would have been applied over the span of the PWM.
042 * At the moment, the double-counting correction is no longer optional; it is now required.
043 *
044 * Each motif model models a particular transition; e.g. exon3->intron3 has a model
045 * different than exon2->intron2.
046 *
047 * The MaxEnt constraint definition and iterative scaling method are being ported from
048 * Matlab, where Jade Vinson first implemented this based on paper byBurge and Yeo.
049 */
050
051
052 // MEMBER VARIABLES //////////////////////////////////////////////////////
053
054 // ADMINISTRATIVE OVERHEAD
055 int startIx; // The index of the first feature managed by this FeatureManager
056 ModelManager model;
057 KmerHasher h; // for a single letter
058 //DenseBooleanMatrix2D[] transitions;
059 boolean tieFlag = false;
060
061 // GEOMETRY
062 int nFeatures; // derived from geometry
063 int spanLimit = 12;
064 int[] span; // derived from geometry
065 int[] offset; // derived from geometry
066 int[] tranfrom;
067 int[] tranto;
068
069 // DOUBLECOUNTING CORRECTION
070 boolean dcflag;
071 List<int[]> dcc; // DoubleCounting Correction
072 MarkovPredictorLogprob predictorlp;
073
074 // OPTIMIZATION OVERHEAD
075 InputSequence<? extends Character> lastSeq;
076 int lastPos;
077 float[] vals;
078
079 // DATA THAT GETS TRAINED
080 List<double[]> listprob;
081
082
083
084 // CONSTRUCTORS AND SUPPORT //////////////////////////////////////////////////////
085
086 public MaxentMotifFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
087 setThingsUp(geometry,dccorrection,markovhistory);
088 }
089
090 public MaxentMotifFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> flags) {
091 tieFlag = true;
092 setThingsUp(geometry,dccorrection,markovhistory);
093 }
094
095 private void setThingsUp(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
096
097 // ADMINISTRATIVE
098 /* geometry, one of the inputs, has the following interpretation:
099 * For each i, geometry[i] describes the geometry of one PWM feature.
100 * 0) geometry[i][0] is the span of the PWM
101 * 1) geometry[i][1] is the offset j of the PWM, so that the feature for
102 * position i relates the following observable and hidden states:
103 * y_(i-1), y_i, x_(i-j), x_(i-j+1), ... , x_(i-j+span-1)
104 * 2) geometry[i][2] is yprev
105 * 3) geometry[i][3] is y */
106
107
108
109 nFeatures = geometry.size();
110 vals = new float[nFeatures];
111 h = new KmerHasher(KmerHasher.ACGTother, 1);
112
113 // DOUBLECOUNTING CORRECTION
114 dcflag = true;
115 this.predictorlp = new MarkovPredictorLogprob(markovhistory);
116 this.dcc = dccorrection;
117
118 // GEOMETRY
119 span = new int[nFeatures];
120 offset = new int[nFeatures];
121 tranfrom = new int[nFeatures];
122 tranto = new int[nFeatures];
123 listprob = new ArrayList<double[]>();
124 for (int i=0; i<nFeatures; i++) {
125 span[i] = geometry.get(i)[0];
126 offset[i] = geometry.get(i)[1];
127 tranfrom[i] = geometry.get(i)[2];
128 tranto[i] = geometry.get(i)[3];
129 Assert.a(span[i] <= spanLimit);
130 int len=1; for (int j=0; j<span[i]; j++) { len *= 4; }
131 double[] prob = new double[len];
132 listprob.add(prob);
133 }
134
135 //LOTS OF ASSERTIONS
136 Assert.a(dccorrection.size()==nFeatures);
137 for (int i=0; i<nFeatures; i++) {
138 Assert.a( (offset[i]>=0) && (offset[i]<= span[i]) ); // So the span of the transition is WITHIN the span of the span of the PWM
139 Assert.a(dccorrection.get(i).length == span[i]);
140 }
141 }
142
143
144 // BORING ADMINISTRATIVE FUNCTIONS //////////////////////////////////////////////////////
145
146 public int getNumFeatures() {
147 if (tieFlag) { return 1; }
148 return nFeatures;
149 }
150
151 public String getFeatureName(int featureIndex) {
152 if (tieFlag) { return "tiedMaxentMotifModels"; }
153
154 int raw = featureIndex - startIx;
155
156 String ret = "MaxentMotifModels.span" + span[raw] + ".offset" + offset[raw] + ".fromState." + model.getStateName(tranfrom[raw]) + ".toState."+ model.getStateName(tranto[raw]);
157
158 return ret;
159 }
160
161 // EVALUATION FUNCTION AND SUPPORT //////////////////////////////////////////////////////
162
163 public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int previousState, int state, FeatureList result) {
164 if(pos == 0) {
165 return;
166 }
167
168 if((seq != lastSeq) || (pos != lastPos)) {
169 lastSeq = seq;
170 lastPos = pos;
171 updateVals(seq, pos);
172 }
173
174 if (tieFlag) {
175 for (int j=0; j<nFeatures; j++) {
176 if( (tranfrom[j]==previousState) && (tranto[j]==state)) {
177 result.addFeature(startIx, vals[j]);
178 }
179 }
180 } else {
181 for (int j=0; j<nFeatures; j++) {
182 if( (tranfrom[j]==previousState) && (tranto[j]==state)) {
183 result.addFeature(startIx + j, vals[j]);
184 }
185 }
186 }
187 }
188
189
190
191
192
193 void updateVals(InputSequence<? extends Character> seq, int ix) {
194 for (int j=0; j<nFeatures; j++) {
195
196 int spn = span[j];
197 int offset1 = this.offset[j];
198
199 // If there are no missing data in the relevant window, then return log( prob according to maxent / prob according to default )
200
201 float val = 0;
202 if ((ix>=offset1) && ((ix-offset1+spn)<=(seq.length())) ) {
203 boolean completeInformation = true;
204 for (int i=0; i<spn; i++) {
205 int pos = ix - offset1 + i;
206 char c = seq.getX(pos);
207 if (h.hash(c) == 4) { completeInformation=false; }
208 }
209
210 if (completeInformation) {
211 int hash = 0;
212 for (int i=0; i<spn; i++) {
213 int pos = ix - offset1 + i;
214 char c = seq.getX(pos);
215 hash = hash*4 + h.hash(c);
216 }
217 double exval = listprob.get(j)[hash];
218 if (exval>0) {
219 val = (float) Math.log(exval);
220 } else {
221 log.info("Refusing to take log of zero, returning a large penalty instead.");
222 val = -4000;
223 }
224
225 if (dcflag) {
226 for (int i=0; i<spn; i++) {
227 val = val - predictorlp.logprob(dcc.get(j)[i],seq,ix-offset1+i);
228 }
229 }
230 }
231 }
232 vals[j] = val;
233 }
234 }
235
236
237 // TRAINING FUNCTION //////////////////////////////////////////////////////
238
239 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
240 startIx = startingIndex;
241 model = modelInfo;
242
243 log.debug("Training a maxent motif feature manager");
244
245 for (int i=0; i<nFeatures; i++) {
246 List<int[]> motifExamples = new ArrayList<int[]>();
247 // Loop through all of the training data and record all fully present (ie no missing data) examples of the motif.
248 // Use the hasher to get values between 0-3.
249
250 log.debug("Training a maxent motif feature with span " + span[i]);
251
252 for(TrainingSequence<? extends Character> seq : data) {
253 int len = seq.length();
254
255 for (int ix=offset[i]; ix<(len-span[i]+offset[i]); ix++) {
256 if (ix<=0) continue;
257 if (ix>=len) continue;
258 int yprev = seq.getY(ix-1);
259 int y = seq.getY(ix);
260
261 if ( (yprev == tranfrom[i]) && (y == tranto[i]) ) {
262 int[] example = new int[span[i]];
263 boolean completeInformation = true;
264 for (int pos=0; pos<span[i]; pos++) {
265 char c = seq.getX(ix - offset[i] + pos);
266 example[pos] = h.hash(c);
267 if (example[pos]==4) { completeInformation = false; }
268 }
269 if (completeInformation) { motifExamples.add(example); }
270 }
271 }
272 }
273
274 //List<Constraint> motifConstraints = makeAllPairwiseConstraints(motifExamples);
275
276 //listprob.set( i , trainMaxentDistribution(motifConstraints,span[i]) );
277 listprob.set( i , MaxentMotifModel.trainMaxentDistributionUsingAllPairwiseConstraints(motifExamples,span[i],1000,0.01) );
278 }
279
280 if (dcflag) {
281 predictorlp.train(data);
282 }
283 }
284 }
285