001 package calhoun.analysis.crf.features.tricycle13;
002
003 import java.io.Serializable;
004 import java.util.ArrayList;
005 import java.util.List;
006
007 import org.apache.commons.logging.Log;
008 import org.apache.commons.logging.LogFactory;
009
010 import calhoun.analysis.crf.AbstractFeatureManager;
011 import calhoun.analysis.crf.CacheStrategySpec;
012 import calhoun.analysis.crf.FeatureList;
013 import calhoun.analysis.crf.FeatureManagerEdge;
014 import calhoun.analysis.crf.ModelManager;
015 import calhoun.analysis.crf.BeanModel.Node;
016 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
017 import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
018 import calhoun.analysis.crf.features.tricycle13.EmissionMarkovFeature.MarkovHistory;
019 import calhoun.analysis.crf.io.InputSequence;
020 import calhoun.analysis.crf.io.TrainingSequence;
021 import calhoun.seq.KmerHasher;
022 import calhoun.util.Assert;
023 import calhoun.util.DenseBooleanMatrix2D;
024
025 public class PositionWeightMatrixFeatures extends AbstractFeatureManager<Character> implements FeatureManagerEdge<Character> {
026 private static final long serialVersionUID = -7659288739348604129L;
027 private static final Log log = LogFactory.getLog(PositionWeightMatrixFeatures.class);
028 boolean debug = log.isDebugEnabled();
029
030 // Position weight matrices are used to model transitions between two extensive features,
031 // for example a donor model for the transition between a positive stranded exon and a
032 // positive stranded intron of the appropriate phase.
033
034 // To specify a PWM, you specify its GEOMETRY, ie its span and where that span begins
035 // relative to the transition itself (the offset), and the two hidden states before and
036 // after the transition.
037
038 // The features returned are that transition times the log probability of the observed
039 // sequence within the window/span being modeled. This must then be trained using
040 // a TraningSequence<Character>.
041
042 // Optionally, one may wish to subtract a double-counting correction if the observed
043 // sequence in the window would have been predicted by something else. For this you
044 // must also specify the predictor that would have been used by default for each base
045 // (depending on the hidden state), and the sequence of hidden states to which this
046 // predictor would have been applied over the span of the PWM.
047
048 int startIx; // The index of the first feature managed by this FeatureManager
049 ModelManager model;
050
051 // Following block are things that depend explicitly on and are calculated directly
052 // from geometry, included only for convenience.
053 int nFeatures;
054 int[] span;
055 int[] offset;
056 int[] nTrans;
057 KmerHasher h; // for a single letter
058 DenseBooleanMatrix2D[] transitions;
059
060 // The variable geometry is the information that is needed to initialize
061 List<int[]> geometry; /* For each i, geometry[i] describes the geometry of one PWM feature.
062 * 0) geometry[i][0] is the span of the PWM
063 * 1) geometry[i][1] is the offset j of the PWM, so that the feature for
064 * position i relates the following observable and hidden states:
065 * y_(i-1), y_i, x_(i-j), x_(i-j+1), ... , x_(i-j+span-1)
066 * 2) geometry[i][2] is yprev
067 * 3) geometry[i][3] is y */
068
069 // These are the parameters that need to be trained empirically.
070 List<float[][]> logprob; /* logprob[i][j] is log probability of base j at position i, i=0..(span-1), j=0..3. */
071
072
073
074 // Following block is something that is only meaningful
075 // if you're going to subtract the doublecounting correction
076 // If you use this correction, you are currently required that each
077 // Feature describes exactly one transition.
078 boolean dcflag;
079 List<int[]> dcc; // DoubleCounting Correction
080 MarkovPredictorLogprob predictorlp;
081
082
083 transient InputSequence<? extends Character> lastSeq;
084 int lastPos;
085 float[] vals;
086
087 boolean tieFlag = false;
088
089
090 int UVCount = 0;
091
092 public static class Geometry implements Serializable {
093 private static final long serialVersionUID = 4896358213027322167L;
094
095 int size;
096 int transition;
097 Node prev;
098 Node current;
099
100 List<Node> overlapCorrections;
101
102 public Node getCurrent() {
103 return current;
104 }
105 public void setCurrent(Node current) {
106 this.current = current;
107 }
108 public List<Node> getOverlapCorrections() {
109 return overlapCorrections;
110 }
111 public void setOverlapCorrections(List<Node> overlapCorrections) {
112 this.overlapCorrections = overlapCorrections;
113 }
114 public Node getPrev() {
115 return prev;
116 }
117 public void setPrev(Node prev) {
118 this.prev = prev;
119 }
120 public int getSize() {
121 return size;
122 }
123 public void setSize(int size) {
124 this.size = size;
125 }
126 public int getTransition() {
127 return transition;
128 }
129 public void setTransition(int transition) {
130 this.transition = transition;
131 }
132 }
133
134 List<Geometry> pwmGeometry;
135 MarkovHistory markovHistory;
136
137 ///////////////////////////////////// Class variables above, methods below //////////
138 public PositionWeightMatrixFeatures() { }
139
140 public PositionWeightMatrixFeatures(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory) {
141 setThingsUp(geometry,dccorrection,markovhistory);
142 }
143
144 public PositionWeightMatrixFeatures(List<int[]> geometry2, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> flags) {
145 tieFlag = true;
146 setThingsUp(geometry2,dccorrection,markovhistory);
147 }
148
149 public void init() {
150 List<int[]> geometry1 = new ArrayList(pwmGeometry.size());
151 List<int[]> dccorrection = new ArrayList(pwmGeometry.size());
152 for(Geometry g : pwmGeometry) {
153 int[] params = new int[4];
154 params[0] = g.getSize();
155 params[1] = g.getTransition();
156 params[2] = g.getPrev().getIndex();
157 params[3] = g.getCurrent().getIndex();
158 geometry1.add(params);
159 int[] correction = new int[g.overlapCorrections.size()];
160 for(int i=0; i<correction.length; ++i) {
161 correction[i] = g.overlapCorrections.get(i).getIndex();
162 }
163 dccorrection.add(correction);
164 }
165 setThingsUp(geometry1, dccorrection, markovHistory.convert());
166 }
167
168 private void setThingsUp(List<int[]> geometry2, List<int[]> dccorrection, List<int[]> markovhistory) {
169 this.predictorlp = new MarkovPredictorLogprob(markovhistory);
170
171 setupGeometry(geometry2);
172
173 Assert.a(geometry.size()==nFeatures);
174 Assert.a(dccorrection.size()==nFeatures);
175 for (int j=0; j<nFeatures; j++) {
176 Assert.a(nTrans[j]==1);
177 Assert.a(dccorrection.get(j).length == span[j]);
178 }
179 setupDoubleCountCorrections(dccorrection,predictorlp);
180 }
181
182
183 private void setupDoubleCountCorrections(List<int[]> dccorrection, MarkovPredictorLogprob predictorlp) {
184 dcflag = true;
185
186 this.predictorlp = predictorlp;
187 this.dcc = dccorrection;
188 }
189
190 private void setupGeometry(List<int[]> geometry) {
191 this.geometry = geometry;
192 nFeatures = geometry.size();
193 span = new int[nFeatures];
194 offset = new int[nFeatures];
195 nTrans = new int[nFeatures];
196 vals = new float[nFeatures];
197 h = new KmerHasher(KmerHasher.ACGTN, 1);
198 logprob = new ArrayList<float[][]>();
199 for (int i=0; i<nFeatures; i++) {
200 nTrans[i] = (geometry.get(i).length - 2)/2;
201 span[i] = geometry.get(i)[0];
202 offset[i] = geometry.get(i)[1];
203 Assert.a(offset[i]>=0); // So the span of the transition is WITHIN the span of the span of the PWM
204 float[][] lp = new float[span[i]][h.range()];
205 logprob.add(lp);
206 }
207 }
208
209 public int getNumFeatures() {
210 if (tieFlag) { return 1; }
211 return nFeatures;
212 }
213
214 public String getFeatureName(int featureIndex) {
215 if (tieFlag) { return "tiedPwmFeature"; }
216
217 int raw = featureIndex - startIx;
218 int[] X = geometry.get(raw);
219
220 String ret = "PWM.span" + X[0] + ".offset" + X[1];
221 for (int j=2; j<X.length; j+=2) {
222 ret = ret + ".(" + model.getStateName(X[j]) + "," + model.getStateName(X[j+1]) + ")";
223 }
224 return ret;
225 }
226
227
228
229 //static int count = 0;
230 public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int previousState, int state, FeatureList result) {
231 if(pos == 0) {
232 return;
233 }
234
235 if((seq != lastSeq) || (pos != lastPos)) {
236 //if((pos != lastPos)) {
237 lastSeq = seq;
238 lastPos = pos;
239 updateVals(seq, pos);
240 //System.out.println("Pos: "+pos+" Seq: "+seq);
241 //if(++count > 10) {
242 // throw new RuntimeException();
243 //}
244 }
245
246 if (tieFlag) {
247 for (int j=0; j<nFeatures; j++) {
248 if(transitions[j].getQuick(previousState, state)) {
249 result.addFeature(startIx, vals[j]);
250 }
251 }
252 } else {
253 for (int j=0; j<nFeatures; j++) {
254 if(transitions[j].getQuick(previousState, state)) {
255 result.addFeature(startIx + j, vals[j]);
256 }
257 }
258 }
259 }
260
261 void updateVals(InputSequence<? extends Character> seq, int ix) {
262 UVCount++;
263 for (int j=0; j<nFeatures; j++) {
264 // Figure out what needed for Feature j, which might have several valid y pairs, at position i
265 int[] geo = geometry.get(j);
266 int spn = geo[0];
267 int offset1 = geo[1];
268 float val = 0;
269 if ((ix>=offset1) && ((ix-offset1+spn)<=(seq.length())) ) {
270 for (int i=0; i<spn; i++) {
271 int pos = ix - offset1 + i;
272 char c = seq.getX(pos);
273 val = val + logprob.get(j)[i][h.hash(c)];
274 }
275
276 if (dcflag) {
277 Assert.a(nTrans[j]==1);
278 Assert.a(dcc.get(j).length == spn);
279
280 for (int i=0; i<spn; i++) {
281 val = val - predictorlp.logprob(dcc.get(j)[i],seq,ix-offset1+i);
282 }
283 }
284 }
285 vals[j] = val;
286 }
287 }
288
289 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
290 startIx = startingIndex;
291 model = modelInfo;
292
293 for (int i=0; i<nFeatures; i++) {
294 float[][] A = new float[span[i]][h.range()];
295 logprob.add(A);
296 }
297
298 // Initialize an array to hold the Feature values which will be passed back:
299 int nStates = model.getNumStates();
300 transitions = new DenseBooleanMatrix2D[nFeatures];
301 for (int i=0; i<nFeatures; i++) {
302 transitions[i] = new DenseBooleanMatrix2D(nStates, nStates);
303 for (int k=2; k<geometry.get(i).length; k+=2) {
304 transitions[i].setQuick(geometry.get(i)[k], geometry.get(i)[k+1], true);
305 }
306 }
307
308 for (int i=0; i<nFeatures; i++) {
309 for (int j=0; j<span[i]; j++) {
310 for (int k=0; k<h.range(); k++) {
311 logprob.get(i)[j][k]=(float) 1.0;
312 }
313 }
314 }
315
316 // In English, what I want to do is this. Loop through all of the training data, once for each Feature.
317 // While so doing, look for any positions where one of the allowed transitions for that feature occurs.
318 // At such positions, increment the counts for logprob.
319 for(TrainingSequence<? extends Character> seq : data) {
320 int len = seq.length();
321
322 for (int i=0; i<nFeatures; i++) {
323 for (int ix=0; ix<len; ix++) {
324 if ((ix>=offset[i]) && (ix-offset[i]+span[i] <= seq.length()) && (ix > 0) ) {
325 int yprev = seq.getY(ix-1);
326 int y = seq.getY(ix);
327 for (int j=0; j<nTrans[i]; j++) {
328 if ( (yprev == geometry.get(i)[2+2*j]) && (y == geometry.get(i)[2+2*j+1]) ) {
329 for (int pos=0; pos<span[i]; pos++) {
330 char c = seq.getX(ix - offset[i] + pos);
331 logprob.get(i)[pos][h.hash(c)] += 1.0;
332 }
333 }
334 }
335 }
336 }
337 }
338 }
339
340 // Above we got counts; we wanted logarithms; here is also where we normalize:
341 for (int i=0; i<nFeatures; i++) {
342 for (int j=0; j<span[i]; j++) {
343 float norm = (float) 0.0;
344 for (int k=0; k<h.range(); k++) {
345 norm += (float) logprob.get(i)[j][k];
346 }
347 Assert.a(norm>0);
348 for (int k=0; k<h.range(); k++) {
349 logprob.get(i)[j][k] = (float) (Math.log(logprob.get(i)[j][k]) - Math.log(norm));
350 }
351 }
352 }
353
354 if (dcflag) {
355 predictorlp.train(data);
356 }
357 }
358
359 @Override
360 public CacheStrategySpec getCacheStrategy() {
361 return new CacheStrategySpec(CacheStrategy.SPARSE);
362 }
363
364 /**
365 * @return Returns the markovHistory.
366 */
367 public MarkovHistory getMarkovHistory() {
368 return markovHistory;
369 }
370
371 /**
372 * @param markovHistory The markovHistory to set.
373 */
374 public void setMarkovHistory(MarkovHistory markovHistory) {
375 this.markovHistory = markovHistory;
376 }
377
378 /**
379 * @return Returns the pwmGeometry.
380 */
381 public List<Geometry> getPwmGeometry() {
382 return pwmGeometry;
383 }
384
385 /**
386 * @param pwmGeometry The pwmGeometry to set.
387 */
388 public void setPwmGeometry(List<Geometry> pwmGeometry) {
389 this.pwmGeometry = pwmGeometry;
390 }
391
392 public boolean isTieFlag() {
393 return tieFlag;
394 }
395
396 public void setTieFlag(boolean tieFlag) {
397 this.tieFlag = tieFlag;
398 }
399 }
400