001 package calhoun.analysis.crf.features.tricycle13;
002
003 import java.util.ArrayList;
004 import java.util.List;
005
006 import org.apache.commons.logging.Log;
007 import org.apache.commons.logging.LogFactory;
008
009 import calhoun.analysis.crf.AbstractFeatureManager;
010 import calhoun.analysis.crf.FeatureList;
011 import calhoun.analysis.crf.FeatureManagerEdge;
012 import calhoun.analysis.crf.ModelManager;
013 import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
014 import calhoun.analysis.crf.features.supporting.phylogenetic.ColumnConditionalLogProbability;
015 import calhoun.analysis.crf.features.supporting.phylogenetic.EvolutionaryModel;
016 import calhoun.analysis.crf.io.CompositeInput;
017 import calhoun.analysis.crf.io.InputSequence;
018 import calhoun.analysis.crf.io.TrainingSequence;
019 import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
020 import calhoun.seq.KmerHasher;
021 import calhoun.util.Assert;
022 import calhoun.util.DenseBooleanMatrix2D;
023
024 public class PWM_evolution extends AbstractFeatureManager<CompositeInput> implements FeatureManagerEdge<CompositeInput> {
025 private static final long serialVersionUID = -7659288739348604129L;
026 private static final Log log = LogFactory.getLog(PWM_evolution.class);
027 boolean debug = log.isDebugEnabled();
028
029 /* PWM evolution is like the position weight matrix fetaures for modeling the
030 * boundaries between two extensive states (eg a donor site separating exons from
031 * introns). However, this feature does more because it not only does a PWM
032 * for the reference sequence but also trains a Kimura80 model for nucleotide
033 * evolution at each position of the feature (each position gets its own model).
034
035 * Note that one must subtract double-counting corrections, since these bases would
036 * otherwise have been modeled using the exon states or the intron states.
037 * We'll use the same machinery that PWM uses for knowing what to subtract (ie what is
038 * being replaced. However, here we must subtract out not only the prior assignment to the reference sequence (as we did with PWM)
039 * but also the prior assignemnt of conditional probability of the multiple alignment
040 * column given the reference sequence and evolutionary model.
041 *
042 * Note similarities to both PositionWeightMatrixFeatures.java
043 * and FelsensteinFeatures.java
044 */
045
046 int startIx; // The index of the first feature managed by this FeatureManager
047 ModelManager model;
048
049 // Following block are things that depend explicitly on and are calculated directly
050 // from geometry, included only for convenience.
051 int nFeatures;
052 int[] span;
053 int[] offset;
054 int[] nTrans;
055
056
057
058 DenseBooleanMatrix2D[] transitions;
059
060 // New requirement: each Feture, or element of geometry, describes a single transition
061 // The variable geometry is the information that is needed to initialize
062 List<int[]> geometry; /* For each i, geometry[i] describes the geometry of one PWM feature.
063 * 0) geometry[i][0] is the span of the PWM
064 * 1) geometry[i][1] is the offset j of the PWM, so that the feature for
065 * position i relates the following observable and hidden states:
066 * y_(i-1), y_i, x_(i-j), x_(i-j+1), ... , x_(i-j+span-1)
067 * 2) geometry[i][2] is yprev
068 * 3) geometry[i][3] is y */
069
070 // These are the parameters that need to be trained empirically.
071 List<float[][]> logprob; /* logprob[i][j] is log probability of base j at position i, i=0..(span-1), j=0..3. */
072
073
074
075 List<int[]> dcc; // DoubleCounting Correction
076 MarkovPredictorLogprob predictorlp;
077
078 // This stuff is similar to FelsensteinFeatures:
079 List<int[]> clusters; // will be an input
080 List<EvolutionaryModel> emodels; // this gets trained
081 int[] state2cluster;
082 static KmerHasher h = new KmerHasher(KmerHasher.ACGTother,1);
083 ColumnConditionalLogProbability mo;
084 boolean tieFlag;
085
086 InputSequence<? extends CompositeInput> lastSeq;
087 int lastPos;
088 float[] vals;
089 private int nUpdate = 0;
090
091
092 public PWM_evolution(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters) {
093 tieFlag = false;
094 PWM_evolution_setup(geometry,dccorrection, markovhistory,clusters);
095 }
096
097 public PWM_evolution(List<int[]> geometry, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters, List<int[]> flags) {
098 tieFlag = true;
099 PWM_evolution_setup(geometry,dccorrection, markovhistory,clusters);
100 }
101
102 private void PWM_evolution_setup(List<int[]> geometry1, List<int[]> dccorrection, List<int[]> markovhistory, List<int[]> clusters1) {
103 dcc = dccorrection;
104 predictorlp = new MarkovPredictorLogprob(markovhistory);
105 mo = new ColumnConditionalLogProbability(clusters1,0); // Zero corresponds to default Kimura80 model
106
107 this.geometry = geometry1;
108 this.clusters = clusters1;
109
110 setupGeometry();
111 }
112
113 private void setupGeometry() {
114 nFeatures = geometry.size();
115 span = new int[nFeatures];
116 offset = new int[nFeatures];
117 nTrans = new int[nFeatures];
118 h = new KmerHasher(KmerHasher.ACGTN, 1);
119 logprob = new ArrayList<float[][]>();
120 for (int i=0; i<nFeatures; i++) {
121 nTrans[i] = (geometry.get(i).length - 2)/2;
122 span[i] = geometry.get(i)[0];
123 offset[i] = geometry.get(i)[1];
124 Assert.a(offset[i]>=0); // So the span of the transition is WITHIN the span of the span of the PWM
125 float[][] lp = new float[span[i]][h.range()];
126 logprob.add(lp);
127 }
128
129 Assert.a(geometry.size()==nFeatures);
130 Assert.a(dcc.size()==nFeatures);
131 for (int j=0; j<nFeatures; j++) {
132 Assert.a(nTrans[j]==1);
133 Assert.a(dcc.get(j).length == span[j]);
134 }
135 }
136
137 public int getNumFeatures() {
138 if (tieFlag) { return 1; }
139 return nFeatures;
140 }
141
142 public String getFeatureName(int featureIndex) {
143 int raw = featureIndex - startIx;
144 int[] X = geometry.get(raw);
145
146 String ret = "PWM.span" + X[0] + ".offset" + X[1];
147 for (int j=2; j<X.length; j+=2) {
148 ret = ret + ".(" + model.getStateName(X[j]) + "," + model.getStateName(X[j+1]) + ")";
149 }
150 return ret;
151 }
152
153
154 //static int count = 0;
155 public void evaluateEdge(InputSequence<? extends CompositeInput> seq, int pos, int previousState, int state, FeatureList result) {
156 if(pos == 0) {
157 return;
158 }
159
160 if(seq != lastSeq || pos != lastPos) {
161 lastSeq = seq;
162 lastPos = pos;
163 //System.out.println("Pos: "+pos+" Seq: "+seq);
164 //if(++count > 10) {
165 // throw new RuntimeException();
166 //}
167 updateVals(seq, pos);
168 }
169
170 for (int j=0; j<nFeatures; j++) {
171 if(transitions[j].getQuick(previousState, state)) {
172 if (tieFlag) {
173 result.addFeature(startIx, vals[j]);
174 } else {
175 result.addFeature(startIx + j, vals[j]);
176 }
177 }
178 }
179 }
180
181 void updateVals(InputSequence<? extends CompositeInput> seq, int ix) {
182 nUpdate ++;
183 for (int j=0; j<nFeatures; j++) {
184 // Figure out what needed for Feature j, which might have several valid y pairs, at position i
185 int[] geo = geometry.get(j);
186 int spn = geo[0];
187 int offset1 = geo[1];
188 float val = 0;
189 if ((ix>=offset1) && ((ix-offset1+spn)<=(seq.length())) ) {
190 InputSequence<Character> CIS = (InputSequence<Character>) seq.getComponent("ref");
191 InputSequence<MultipleAlignmentColumn> MIS = (InputSequence<MultipleAlignmentColumn>) seq.getComponent("aln");
192
193 for (int i=0; i<spn; i++) {
194 int pos = ix - offset1 + i;
195 char c = CIS.getX(pos);
196 val = val + logprob.get(j)[i][h.hash(c)];
197 }
198
199 Assert.a(nTrans[j]==1);
200 Assert.a(dcc.get(j).length == spn);
201
202 for (int i=0; i<spn; i++) {
203 // dcc.get(j)[i] is a state (a number 0-12); ix-offset+i is a position
204 // This is for subtracting to correct for what was added by the feature
205 // MarkovPredictorLogProb and is being replaced by the PWM
206 val = val - predictorlp.logprob(dcc.get(j)[i],CIS,ix-offset1+i);
207 // Below subtracts the correction for the feature FelsensteinFeatures
208 // which we are now replacing with the feature PWM_evolution at this position.
209 val = (float) (val - mo.condLogProb(MIS,ix-offset1+i,dcc.get(j)[i]));
210 }
211 }
212 vals[j] = val;
213 }
214 }
215
216 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends CompositeInput>> data) {
217 startIx = startingIndex;
218 model = modelInfo;
219
220 vals = new float[nFeatures];
221
222 List<TrainingSequence<Character>> LTSC = new ArrayList<TrainingSequence<Character>>();
223 List<TrainingSequence<MultipleAlignmentColumn>> LTSMA = new ArrayList<TrainingSequence<MultipleAlignmentColumn>>();
224
225 for (int j=0; j<data.size(); j++) {
226 LTSC.add( data.get(j).getTrainingComponent("ref") );
227 LTSMA.add( data.get(j).getTrainingComponent("aln") );
228 }
229
230
231 predictorlp.train(LTSC);
232 mo.train(model,LTSMA);
233
234 for (int i=0; i<nFeatures; i++) {
235 float[][] A = new float[span[i]][h.range()];
236 logprob.add(A);
237 }
238
239 // Initialize an array to hold the Feature values which will be passed back:
240 int nStates = model.getNumStates();
241 transitions = new DenseBooleanMatrix2D[nFeatures];
242 for (int i=0; i<nFeatures; i++) {
243 transitions[i] = new DenseBooleanMatrix2D(nStates, nStates);
244 for (int k=2; k<geometry.get(i).length; k+=2) {
245 transitions[i].setQuick(geometry.get(i)[k], geometry.get(i)[k+1], true);
246 }
247 }
248
249 for (int i=0; i<nFeatures; i++) {
250 for (int j=0; j<span[i]; j++) {
251 for (int k=0; k<h.range(); k++) {
252 logprob.get(i)[j][k]=(float) 1.0;
253 }
254 }
255 }
256
257 // In English, what I want to do is this. Loop through all of the training data, once for each Feature.
258 // While so doing, look for any positions where one of the allowed transitions for that feature occurs.
259 // At such positions, increment the counts for logprob.
260 for(TrainingSequence<Character> seq : LTSC) {
261 int len = seq.length();
262
263 for (int i=0; i<nFeatures; i++) {
264 for (int ix=0; ix<len; ix++) {
265 if ((ix>=offset[i]) && (ix-offset[i]+span[i] <= seq.length()) && (ix > 0) ) {
266 int yprev = seq.getY(ix-1);
267 int y = seq.getY(ix);
268 for (int j=0; j<nTrans[i]; j++) {
269 if ( (yprev == geometry.get(i)[2+2*j]) && (y == geometry.get(i)[2+2*j+1]) ) {
270 for (int pos=0; pos<span[i]; pos++) {
271 char c = seq.getX(ix - offset[i] + pos);
272 logprob.get(i)[pos][h.hash(c)] += 1.0;
273 }
274 }
275 }
276 }
277 }
278 }
279 }
280
281 // Above we got counts; we wanted logarithms; here is also where we normalize:
282 for (int i=0; i<nFeatures; i++) {
283 for (int j=0; j<span[i]; j++) {
284 float norm = (float) 0.0;
285 for (int k=0; k<h.range(); k++) {
286 norm += (float) logprob.get(i)[j][k];
287 }
288 Assert.a(norm>0);
289 for (int k=0; k<h.range(); k++) {
290 logprob.get(i)[j][k] = (float) (Math.log(logprob.get(i)[j][k]) - Math.log(norm));
291 }
292 }
293 }
294
295 }
296 }
297