001 package calhoun.analysis.crf.solver;
002
003 import java.util.ArrayList;
004 import java.util.Arrays;
005 import java.util.HashSet;
006 import java.util.List;
007
008 import org.apache.commons.logging.Log;
009 import org.apache.commons.logging.LogFactory;
010
011 import calhoun.analysis.crf.CacheStrategySpec;
012 import calhoun.analysis.crf.CompositeFeatureManager;
013 import calhoun.analysis.crf.FeatureManager;
014 import calhoun.analysis.crf.FeatureManagerEdge;
015 import calhoun.analysis.crf.FeatureManagerNode;
016 import calhoun.analysis.crf.FeatureManagerNodeBoundaries;
017 import calhoun.analysis.crf.FeatureManagerNodeExplicitLength;
018 import calhoun.analysis.crf.ModelManager;
019 import calhoun.analysis.crf.SemiMarkovSetup;
020 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
021 import calhoun.analysis.crf.io.InputSequence;
022 import calhoun.analysis.crf.io.TrainingSequence;
023 import calhoun.analysis.crf.solver.check.ArrayFeatureList;
024 import calhoun.util.Assert;
025 import calhoun.util.CheckException;
026
027 /** A policy based cache processor. This is the main cache processor used. It uses the cache
 * policies specified by the feature managers to efficiently cache all feature values.
029 * The distinct caching strategies recognized by this CacheProcessor are:
030 * <ul>
031 * <ol> <b>COMPOSITE</b> - must drill down further to a non-composite FeatureManager to learn the strategy
032 * <ol> <b>CONSTANT</b> - Triplets (potential, featureIndex, value) that do not depend on position
033 * <ol> <b>DENSE</b> - Triplets (potential, featureIndex, lookupTable), where lookupTable is of size totalPositions and different
034 * doublets (potential, featureIndex), if they always have the same value, can share the same lookupTable to save space.
035 * <ol> <b>SPARSE</b> - Quadruplets (cumulativePosition, potential, featureIndex, value) of nonzero evaluation stored in order first by cumulativePosition,
036 * and then arbitrarily, similar to the way the previous "FeatureCache" worked but without need to sub-order by potentials in "model order".
037 * <ol> <b>LENGTHFUNCTION</b> - Quadruplets (length, state, featureIndex, value).
038 * <ol> Otherwise, treated as SPARSE
039 * </ul>
040 */
041 public class CacheProcessorDeluxe extends CacheProcessorBasic {
042 private static final Log log = LogFactory.getLog(CacheProcessorDeluxe.class);
043
// True to drop training sequences that violate a constraint instead of failing hard.
private boolean ignoreInvalidTrainingData;
private boolean allPaths;
// True when the training data carries observed hidden states (first label >= 0); set in setTrainingData.
private boolean validHiddenStates = false;

public boolean[] invalidTransitions; // of size totalLength * nPotentials; true marks a (position, potential) vetoed by a feature evaluation

private int[] numFixedEvalIndices; // of length nPotentials; count of eval slots consumed by constant + dense features per potential

// When forceStrategy is set (one-arg constructor), every feature manager is cached with
// onlyStrategy regardless of the strategy it requested.
private boolean forceStrategy;
private CacheStrategy onlyStrategy;

int lookbackArraySize = -1;
int lookbackArrayFeatureSize = -1;

///////////////////// Below holds raw information for Sparsely Cached Feature Evaluations
private ArrayList<FeatureManager> sparseFMList;
private ArrayList<FeatureManager> lengthFMList;
private int[] sparseStarts; // of length (totalPositions + 1); sparseStarts[p]..sparseStarts[p+1] delimits position p's records in sparseEvals
private ArrayList<SparseEvaluation> sparseEvals;
/** One cached nonzero sparse evaluation: (potential, featureIndex, value).
 * The position is implicit: records are stored in position order and indexed via sparseStarts. */
private class SparseEvaluation {
short featureIndex;
int potential;
float value;
}
// This is a global variable used only by evaluation of sparse features from cache
private int[] currentEvalIndices; // of length nPotentials; per-potential write cursor during evaluateSparseFeatures


//////////////////// Below holds raw information for the Constant Features
private boolean writtenYet; // guards evaluateConstantFeatures: constant slots are written exactly once
private ArrayList<ConstantEvaluation> constEvals;
/** A position-independent evaluation plus the fixed slot (evalIndex) it occupies in the
 * eval table of its potential. */
private class ConstantEvaluation {
short featureIndex;
int potential;
float value;
int evalIndex;
}


//////////////////////////////////////// Below holds raw information for DenseFeatureTables
private ArrayList<float[]> denseTables; // first dimension is which table; second dimension is of length modelInfo.totalPositions; faster to flip-flop?
// ********** Above variable is a memory hog
private ArrayList<DenseEvalIndices> denseEvalsList; // use this until finalization, at which time pack into array below with identical info
/** Binds one (potential, featureIndex) pair to its dense lookup table and its fixed eval slot. */
private class DenseEvalIndices {
public float[] lookupTable; // one of denseTables; several DenseEvalIndices may share a table to save space
public int evalIndex; // fixed slot in evalTable that this feature's value is written into
public FeatureEvaluation evalTable; // equal to evals[potential] (set in cacheFeatureManagerAsDenseNode)
public int potential; // evalTable happens to be equal to evals.get(potential)
public short featureIndex;

}


//////////////// Below holds info for DenseNodeBoundary feature caches
private ArrayList<float[]> denseBoundaryTables; // note that these tables need 1 extra unit of length for each sequence, so they are of length modelInfo.totalPositions + modelInfo.nSeqs
ArrayList<DenseNodeBoundaryIndices>[] denseNodeBoundaryEvalsList; // have a list for each potential

private short[] minStateLengths;

private boolean ignoreSemiMarkovSelfTransitions;
/** Binds a (state potential, featureIndex) pair to its cumulative-sum boundary table plus padding. */
private class DenseNodeBoundaryIndices {
public int featureIndex;
public float[] lookupTable; // cumulative sums; interval sums are computed by subtraction
public int rightPad;
public int leftPad;
public int potential; // always a node potential (< nStates); asserted during caching
}

//////////////// Below holds info for Length feature caches
private LengthOnlyEvaluation[][] lengthTables; // [state][length-1] -> evaluations of explicit-length features

/** Parallel arrays of feature indices and values for one (state, segment length) pair. */
private class LengthOnlyEvaluation {
public short[] featureIndex;
public float[] value;
}

////////////////////////////////////////////
////// MEMBER VARIABLES ABOVE, METHODS BELOW
////////////////////////////////////////////
123
/** Default constructor: each feature manager's own requested cache strategy is honored. */
public CacheProcessorDeluxe() {
    super();
    log.debug("Calling Cache Processor Deluxe constructor");
    this.forceStrategy = false;
}
129
/**
 * Constructor that forces a single caching strategy onto every feature manager,
 * overriding whatever strategy each one requests. Only CONSTANT, DENSE and SPARSE
 * are supported here; anything else fails the assertion below.
 */
public CacheProcessorDeluxe(CacheStrategy strategy) {
    super();
    log.warn("Calling CacheProcessorDeluxe constructor, specifying a strategy");
    this.forceStrategy = true;
    this.onlyStrategy = strategy;
    log.info("Constructed a cache processor and requiring it to always use the following cache strategy:");
    if (strategy == CacheStrategy.CONSTANT) {
        log.warn("CONSTANT");
    } else if (strategy == CacheStrategy.DENSE) {
        log.warn("DENSE");
    } else if (strategy == CacheStrategy.SPARSE) {
        log.warn("SPARSE");
    } else {
        Assert.a(false,"ERROR - case not dealt with yet.");
    }
}
150
151 @Override
152 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
153 super.setTrainingData(fm, data);
154 basicInit(allPaths);
155
156 if (maxStateLengths == null) {
157 log.debug("maxStateLengths was not set, setting all to length 1");
158 maxStateLengths = new short[fm.getNumStates()];
159 Arrays.fill(maxStateLengths, (short) 1);
160 }
161
162 if (minStateLengths == null) {
163 log.info("minStateLengths was not set, setting all to length 1");
164 minStateLengths = new short[fm.getNumStates()];
165 Arrays.fill(maxStateLengths, (short) 1);
166 }
167
168 modelInfo.setup(fm, data, allPaths, maxStateLengths, ignoreSemiMarkovSelfTransitions);
169
170 modelInfo.maxLookback = 1;
171 for(int stateLen : maxStateLengths) {
172 modelInfo.maxLookback = (short) Math.max(stateLen, modelInfo.maxLookback);
173 }
174
175 if(lookbackArraySize == -1)
176 lookbackArraySize = modelInfo.maxLookback+2;
177 if(lookbackArrayFeatureSize == -1)
178 lookbackArrayFeatureSize = Math.max(5, modelInfo.nFeatures);
179 lengthEvals = LengthFeatureEvaluation.create(modelInfo.statesWithLookback, lookbackArraySize, lookbackArrayFeatureSize);
180
181 validHiddenStates = data.get(0).getY(0) >= 0;
182 if(validHiddenStates) {
183 // Verify that there are no invalid transitions in the training data.
184 List<TrainingSequence<?>> goodData = new ArrayList<TrainingSequence<?>>();
185 for (int i=0; i<data.size(); ++i) {
186 TrainingSequence seq = data.get(i);
187 short length = 1;
188 int lastState = -1;
189 boolean validSequence = true;
190 try {
191 for (int pos=0; pos < seq.length(); pos++) {
192 int trainingState = seq.getY(pos);
193 boolean segmentEnd = pos == seq.length()-1 || trainingState != seq.getY(pos+1);
194 if(segmentEnd) {
195 Assert.a(maxStateLengths[trainingState]==1 || (length <= maxStateLengths[trainingState] && length >= minStateLengths[trainingState]), "Seq #", i, " Pos ", pos, " Training segment length ", length, " state: "+trainingState+" outside the allowed length ", minStateLengths[trainingState], "-", maxStateLengths[trainingState]);
196 Assert.a(lastState == -1 || modelInfo.transitionIndex.getQuick(lastState,trainingState) != -1);
197 lastState = trainingState;
198 length = 1;
199 }
200 else {
201 ++length;
202 }
203 }
204 }
205 catch(CheckException ex) {
206 if(ignoreInvalidTrainingData) {
207 log.warn("Discarding sequence "+i+" "+ex.getMessage());
208 validSequence = false;
209 }
210 else {
211 throw ex;
212 }
213 }
214 if(validSequence) {
215 goodData.add(seq);
216 }
217 }
218 Assert.a(goodData.size() != 0, "All training sequences were invalid.");
219 int diff = data.size() - goodData.size();
220 if(diff != 0) {
221 log.warn("Using "+goodData.size()+" training sequences. Discarded "+diff+" because of state length or transition problems.");
222 // Redo the initialization with the new set of training data
223 setTrainingData(fm, goodData);
224 return;
225 }
226 }
227
228 initializeCacheProcessor();
229
230 updateCacheProcessor(fm);
231 cacheLengthFeatureManagers();
232 cacheSparseFeatureManagers();
233
234 evaluateConstantFeatures();
235
236 if(validHiddenStates) {
237 checkConstraintsInTrainingData();
238 computeFeatureSums();
239 }
240 }
241
242
243 private void initializeCacheProcessor() {
244 invalidTransitions = new boolean[modelInfo.nPotentials*modelInfo.totalPositions];
245
246 //featureSums = new double[fm.getNumFeatures()];
247
248 numFixedEvalIndices = new int[modelInfo.nPotentials];
249 currentEvalIndices = new int[modelInfo.nPotentials];
250
251 denseEvalsList = new ArrayList<DenseEvalIndices>();
252 denseTables = new ArrayList<float[]>();
253
254 //denseNodeBoundaryEvalsList = new ArrayList<DenseNodeBoundaryIndices>[modelInfo.nPotentials];
255 denseBoundaryTables = new ArrayList<float[]>();
256 denseNodeBoundaryEvalsList = new ArrayList[modelInfo.nPotentials];
257 for (int pot=0; pot<modelInfo.nPotentials; pot++) {
258 denseNodeBoundaryEvalsList[pot] = new ArrayList<DenseNodeBoundaryIndices>();
259 }
260
261 constEvals = new ArrayList<ConstantEvaluation>();
262 writtenYet = false;
263
264 sparseFMList = new ArrayList<FeatureManager>();
265 sparseStarts = new int[modelInfo.totalPositions + 1];
266 sparseEvals = new ArrayList<SparseEvaluation>();
267
268 lengthFMList = new ArrayList<FeatureManager>();
269 lengthTables = new LengthOnlyEvaluation[modelInfo.nStates][];
270 for(int i= 0; i<modelInfo.nStates; ++i) {
271 lengthTables[i] = new LengthOnlyEvaluation[modelInfo.maxStateLengths[i]];
272 for(int j=0; j<lengthTables[i].length; ++j) {
273 lengthTables[i][j] = new LengthOnlyEvaluation();
274 }
275 }
276 }
277
278
/**
 * Routes a single feature manager to the appropriate cache based on its strategy,
 * recursing into composites. When forceStrategy is set, onlyStrategy overrides whatever
 * strategy the feature manager requested. Unrecognized strategies fall through to SPARSE.
 */
private void updateCacheProcessor(FeatureManager fm1) {
log.debug("Calling updateCacheProcessor");
CacheStrategySpec.CacheStrategy strategy;
if (forceStrategy) {
strategy = onlyStrategy;
} else {
strategy = fm1.getCacheStrategy().strategy;
}
switch (strategy) {
case COMPOSITE:
log.debug("	burrowing into a COMPOSITE feature manager");
// Recurse on each child; the composite itself holds no evaluations.
CompositeFeatureManager cfm = (CompositeFeatureManager) fm1;
for (FeatureManager fmchild : cfm.getComponentFeatures()) {
updateCacheProcessor(fmchild);
}
break;
case DENSE:
log.debug("	processing a feature manager as DENSE");
cacheFeatureManagerAsDenseNode(fm1);
break;
case DENSE_NODE_BOUNDARY:
log.debug("Processing a node as DENSE_NODE_BOUNDARY");
cacheFeatureManagerAsDenseNodeBoundary(fm1);
break;
case CONSTANT:
log.debug("	processing a feature manager as CONSTANT");
cacheFeatureManagerAsConstant(fm1);
break;
case LENGTHFUNCTION:
log.debug("	processing a feature manager as LENGTH");
lengthFMList.add(fm1);
break;
case SPARSE:
case UNSPECIFIED:
default:
log.debug("	processing a feature manager as SPARSE");
// Sparse managers are only collected here; evaluation happens in cacheSparseFeatureManagers.
sparseFMList.add(fm1);
// Explicit-length features are additionally registered as length features.
if(fm1 instanceof FeatureManagerNodeExplicitLength) {
lengthFMList.add(fm1);
}
break;
}
}
322
323
324
325 InputSequence getSequence(InputSequence seq, FeatureManager argFm) {
326 return argFm.getInputComponent() == null ? seq : seq.getComponent(argFm.getInputComponent());
327 }
328
/**
 * Evaluates every SPARSE feature manager at every (position, potential) and records the
 * nonzero results in sparseEvals. sparseStarts[p+1] is set to the record count after
 * processing cumulative position p, so sparseStarts[p]..sparseStarts[p+1] later delimits
 * that position's records. Invalid evaluations mark invalidTransitions instead.
 */
private void cacheSparseFeatureManagers() {
sparseStarts[0] = 0;

log.debug("We are caching this many feature managers (some may be composite) as sparse: " + sparseFMList.size());
log.debug("Number of feature in each of the FMs is");
for (FeatureManager fmb : sparseFMList) {
log.debug("Number of features is " + fmb.getNumFeatures());
if (FeatureManagerEdge.class.isInstance(fmb)) {log.debug("	Is edge featuremanager"); }
if (FeatureManagerNode.class.isInstance(fmb)) {log.debug("	Is node featuremanager"); }
}

ArrayFeatureList result = new ArrayFeatureList(fm);
for (int i=0; i<data.size(); i++) {
TrainingSequence seq = data.get(i);
for (int pos=0; pos<seq.length(); pos++) {
int cumulativePosition = modelInfo.seqOffsets[i] + pos;

for (int potential=0; potential< modelInfo.nPotentials; potential++) {
result.clear();

int state;
int prevState;


if (potential<modelInfo.nStates) {
// Node potential: run every node feature manager at (pos, state).
state = potential;
for (FeatureManager fmb : sparseFMList) {
if (!FeatureManagerNode.class.isInstance(fmb)) { continue; }

FeatureManagerNode fmn = (FeatureManagerNode) fmb;
fmn.evaluateNode(getSequence(seq, fmb),pos,state,result);
}
} else if (pos>0) {
// Edge potential: decode the transition, then run every edge feature manager.
// There is no incoming edge at position 0, hence the pos>0 guard.
prevState = modelInfo.transitionFrom[potential - modelInfo.nStates];
state = modelInfo.transitionTo[potential - modelInfo.nStates];
for (FeatureManager fm1 : sparseFMList) {
if (!FeatureManagerEdge.class.isInstance(fm1)) {continue; }
FeatureManagerEdge fme = (FeatureManagerEdge) fm1;

fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
}
}


if (result.isValid()) {
// Only nonzero values are cached; zeros are implicit.
for (int j=0; j< result.size(); j++) {
float value = (float) result.getValue(j);
if(value != 0.0) {
SparseEvaluation se = new SparseEvaluation();
se.featureIndex = (short) result.getIndex(j);
se.potential = potential;
se.value = value;
sparseEvals.add(se);
}
}
} else {
// An invalid result vetoes this (position, potential) pair.
int ind = cumulativePosition*modelInfo.nPotentials + potential;
invalidTransitions[ind] = true;
}
}
// Close out this position: the next position's records begin here.
sparseStarts[cumulativePosition+1] = sparseEvals.size();
}
}
}
393
394 private void cacheFeatureManagerAsConstant(FeatureManager fm1) {
395 CacheStrategySpec s = (CacheStrategySpec) fm1.getCacheStrategy();
396 if (s.strategy != CacheStrategy.CONSTANT) {
397 log.warn("Caching a Feature Manager as CONSTANT even though it requested a different strategy");
398 }
399
400 InputSequence seq = data.get(0);
401 int pos = 0;
402
403 ArrayFeatureList result = new ArrayFeatureList(fm);
404 for (int pot=0; pot< modelInfo.nPotentials; pot++) {
405 result.clear();
406
407 int state;
408 int prevState;
409
410 if (pot<modelInfo.nStates) {
411 if (!FeatureManagerNode.class.isInstance(fm1)) { continue; }
412 state = pot;
413 FeatureManagerNode fmn = (FeatureManagerNode) fm1;
414 fmn.evaluateNode(getSequence(seq, fmn),pos,state,result);
415 } else {
416 if (!FeatureManagerEdge.class.isInstance(fm1)) { continue; }
417 FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
418 prevState = modelInfo.transitionFrom[pot - modelInfo.nStates];
419 state = modelInfo.transitionTo[pot - modelInfo.nStates];
420 fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
421 }
422
423 Assert.a(result.isValid(),"Constant features should not have to encounter invalid potentials");
424
425 for (int j=0; j<result.size(); j++) {
426 ConstantEvaluation ce = new ConstantEvaluation();
427 ce.featureIndex = (short) result.getIndex(j);
428 ce.potential = pot;
429 ce.value = (float) result.getValue(j);
430 ce.evalIndex = numFixedEvalIndices[pot];
431 numFixedEvalIndices[pot]++;
432 constEvals.add(ce);
433 }
434 }
435 }
436
437 class FeaturePotential {
438 int featureIndex;
439 int potential;
440 @Override
441 public int hashCode() {
442 return potential * modelInfo.nPotentials + featureIndex;
443 }
444 @Override
445 public boolean equals(Object rhs) {
446 FeaturePotential f2 = (FeaturePotential) rhs;
447 return featureIndex == f2.featureIndex && potential == f2.potential;
448 }
449 }
450
/**
 * Caches a feature manager using the DENSE strategy: one float lookup table per
 * (potential, feature) combination, indexed by cumulative position. If the strategy spec
 * carries no DenseCachingDetails, the used (feature, potential) combinations are first
 * discovered by evaluating the manager over all training data. Invalid evaluations mark
 * invalidTransitions.
 */
private void cacheFeatureManagerAsDenseNode(FeatureManager fm1) {
CacheStrategySpec s = (CacheStrategySpec) fm1.getCacheStrategy();
if (!(s.strategy == CacheStrategy.DENSE)) {
log.warn("The evaluations are being cached using the DENSE strategy even though that was not the strategy requested by the feature.");
}

CacheStrategySpec.DenseCachingDetails details = (CacheStrategySpec.DenseCachingDetails) s.details;

if (details == null) {
if(fm1.getNumFeatures() > 0) {
log.debug("No details specified for "+fm1+". Using evaluations to find the correct nodes");
}

// Iterate through all sequences and positions to find the feature/potential combinations used
// This replaces code that used to do it based on the first position of the first sequence with length > 2.
HashSet<FeaturePotential> usedPotentials = new HashSet();
ArrayFeatureList result = new ArrayFeatureList(fm);
FeaturePotential current = new FeaturePotential();
for (InputSequence seq : data) {
for(int pos = 0; pos < seq.length(); ++pos) {
for (int pot=0; pot< modelInfo.nPotentials; pot++) {
result.clear();
if (pot<modelInfo.nStates) {
// Node potential: only node feature managers apply.
if (!FeatureManagerNode.class.isInstance(fm1)) { continue; }
FeatureManagerNode fmn = (FeatureManagerNode) fm1;
fmn.evaluateNode(getSequence(seq, fmn),pos,pot,result);
} else {
// Edge potential: only edge feature managers apply.
if (!FeatureManagerEdge.class.isInstance(fm1)) { continue; }
FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
int prevState = modelInfo.transitionFrom[pot - modelInfo.nStates];
int state = modelInfo.transitionTo[pot - modelInfo.nStates];
fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
}
for (int j=0; j<result.size(); j++) {
current.featureIndex = result.getIndex(j);
current.potential = pot;
if(!usedPotentials.contains(current)) {
log.debug("Adding dense node - pot: "+current.potential+" Feat: "+current.featureIndex);
usedPotentials.add(current);
// The probe object is now owned by the set; allocate a fresh one for further probing.
current = new FeaturePotential();
}
}
}
}
}

// Now that we have the list of potentials that will be used, go and create the details.
details = new CacheStrategySpec.DenseCachingDetails();
ArrayList<Integer> densePotList = new ArrayList<Integer>();
ArrayList<Integer> denseTabList = new ArrayList<Integer>();
ArrayList<Short> denseFiList = new ArrayList<Short>();
details.nEvals = usedPotentials.size();
details.nTables = 0;

// One table per discovered combination; no table sharing in the auto-discovered case.
for(FeaturePotential fp : usedPotentials) {
densePotList.add(fp.potential);
denseTabList.add(details.nTables);
denseFiList.add((short) fp.featureIndex);
details.nTables++;
}

details.featureIndex = new short[details.nEvals];
details.tableNum = new int[details.nEvals];
details.potential = new int[details.nEvals];

for (int j=0; j<details.nEvals; j++) {
details.featureIndex[j] = denseFiList.get(j);
details.potential[j] = densePotList.get(j);
details.tableNum[j] = denseTabList.get(j);
}
details.check();
}

// Allocate this manager's tables at the tail of the shared denseTables list.
int firstTableIndex = denseTables.size();
details.check();
for (int j=0; j<details.nTables; j++) {
float[] temp = new float[modelInfo.totalPositions];
for (int k=0; k<modelInfo.totalPositions; k++) {
temp[k] = 0;
}
denseTables.add(temp);
}

// Register each evaluation: claim a fixed eval slot for its potential and bind its table.
for (int j=0; j<details.nEvals; j++) {
DenseEvalIndices de = new DenseEvalIndices();

int potential = details.potential[j];
de.evalTable = evals[potential];
de.potential = potential;
de.evalIndex = numFixedEvalIndices[potential];
numFixedEvalIndices[potential]++;
de.lookupTable = denseTables.get(firstTableIndex + details.tableNum[j]);
de.featureIndex = details.featureIndex[j];
denseEvalsList.add(de);
}


// Second pass: fill the lookup tables with actual evaluations at every position.
ArrayFeatureList result = new ArrayFeatureList(fm);
for (int i=0; i<data.size(); i++) {
TrainingSequence seq = data.get(i);
int actualPreviousState = -1;
for (int pos=0; pos<seq.length(); pos++) {
int actualState = seq.getY(pos);

int cumulativePosition = modelInfo.seqOffsets[i] + pos;

for (int potential=0; potential< modelInfo.nPotentials; potential++) {
result.clear();

int state;
int prevState;


if (potential<modelInfo.nStates) {
if (!FeatureManagerNode.class.isInstance(fm1)) { continue; }
state = potential;
FeatureManagerNode fmn = (FeatureManagerNode) fm1;
fmn.evaluateNode(getSequence(seq, fmn),pos,state,result);
if (state==actualState) {
for (int j=0; j< result.size(); j++) {
// featureSums[result.getIndex(j)] += result.getValue(j);
}
}
} else if (pos>0) {
if (!FeatureManagerEdge.class.isInstance(fm1)) { continue; }
FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
prevState = modelInfo.transitionFrom[potential - modelInfo.nStates];
state = modelInfo.transitionTo[potential - modelInfo.nStates];
fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
if ((state==actualState) && (prevState==actualPreviousState)) {
for (int j=0; j< result.size(); j++) {
// featureSums[result.getIndex(j)] += result.getValue(j);
}
}
}


if (result.isValid()) {

for (int j=0; j<result.size(); j++) {
// Search only this manager's nEvals entries, which sit at the tail of denseEvalsList.
for (int ev=0; ev<details.nEvals; ev++) {
DenseEvalIndices de = denseEvalsList.get(denseEvalsList.size()-1-ev);
if ((potential==de.potential) && (result.getIndex(j)==de.featureIndex)) {
de.lookupTable[modelInfo.seqOffsets[i] + pos] = (float) result.getValue(j);
break;
}
}
}


} else {
// An invalid result vetoes this (position, potential) pair.
int ind = cumulativePosition*modelInfo.nPotentials + potential;
invalidTransitions[ind] = true;
}
}
actualPreviousState = actualState;
}
}


}
612
613
614 private void cacheFeatureManagerAsDenseNodeBoundary(FeatureManager fm1) {
615 // NOTE: This does not calculate feature sums for the semi-Markov features. I think it would be best
616 // to have a separate function do this which calls the CacheProcessor but which is separate and not interwoven with each part of CacheProcessor.
617
618 CacheStrategySpec s = (CacheStrategySpec) fm1.getCacheStrategy();
619 Assert.a(s.strategy == CacheStrategy.DENSE_NODE_BOUNDARY,"ERROR: Cannot cache using thie DENSE_NODE_BOUNDARY strategy unless specifically requested by the feature.");
620
621 CacheStrategySpec.DenseBoundaryCachingDetails details = (CacheStrategySpec.DenseBoundaryCachingDetails) s.details;
622 Assert.a(details != null,"ERROR -- the Cache Strategy DenseNodeBoundary is too intricate to omit specifying details");
623
624 int firstTableIndex = denseBoundaryTables.size();
625 details.check();
626 for (int j=0; j<details.nTables; j++) {
627 float[] temp = new float[modelInfo.totalPositions+modelInfo.nSeqs];
628 for (int k=0; k<modelInfo.totalPositions+modelInfo.nSeqs; k++) {
629 temp[k] = 0;
630 }
631 denseBoundaryTables.add(temp);
632 }
633
634 ArrayList<DenseNodeBoundaryIndices> tempNodeBoundaries = new ArrayList<DenseNodeBoundaryIndices>();
635 //for (int j=0; j<details.entries.size(); j++) {
636 for (CacheStrategySpec.DenseBoundaryEntry be : details.entries) {
637 DenseNodeBoundaryIndices db = new DenseNodeBoundaryIndices();
638
639 int potential = be.potential; //details.potential[j];
640 Assert.a(potential<modelInfo.nStates,"ERROR - Can't cache potential " + potential + " , which is a transition using the DENSE_NODE_BOUNDARY caching strategy.");
641 //numFixedEvalIndices[potential]++;
642 db.lookupTable = denseBoundaryTables.get(firstTableIndex + be.tableNum); //details.tableNum[j]);
643 db.featureIndex = be.featureIndex; //details.featureIndex[j];
644 db.potential = potential;
645 db.rightPad = be.rightPad; //details.rightPad[j];
646 db.leftPad = be.leftPad; //details.leftPad[j];
647 denseNodeBoundaryEvalsList[potential].add(db);
648 log.debug("Node boundary feature: "+db.featureIndex+" pot: "+db.potential+" table: "+db.lookupTable);
649 tempNodeBoundaries.add(db);
650 }
651 Assert.a(tempNodeBoundaries.size() == details.entries.size());
652
653
654 for (int seqNum=0; seqNum<data.size(); seqNum++) {
655 TrainingSequence seq = data.get(seqNum);
656 for (int pos=0; pos<seq.length(); pos++) {
657
658 int cumulativePosition = modelInfo.seqOffsets[seqNum] + pos;
659
660 for (int state=0; state< modelInfo.nStates; state++) {
661 ArrayFeatureList result = new ArrayFeatureList(fm);
662 Assert.a(FeatureManagerNodeBoundaries.class.isInstance(fm1),"ERROR - to cache using DENSE NODE BOUNDARY, must be an instance of FeatureManagerNodeBoundaries, but this isn't.");
663 FeatureManagerNode fmn = (FeatureManagerNode) fm1;
664 fmn.evaluateNode(getSequence(seq, fmn),pos,state,result);
665
666 if (result.isValid()) {
667 for (int j=0; j<result.size(); j++) {
668 boolean found = false;
669 for (DenseNodeBoundaryIndices db : tempNodeBoundaries ) {
670 if ((state==db.potential) && (result.getIndex(j)==db.featureIndex)) {
671 int tx = modelInfo.seqOffsets[seqNum] + seqNum + pos;
672 db.lookupTable[tx + 1] = db.lookupTable[tx] + (float) result.getValue(j);
673 found = true;
674 break;
675 }
676 }
677 if(!found) {
678 Assert.a(false, String.format("Feature Ix: %d State: %d not found. ", result.getIndex(j), state));
679 }
680 }
681 }
682
683 if (!result.isValid()) {
684 int ind = cumulativePosition*modelInfo.nPotentials + state;
685 invalidTransitions[ind] = true;
686 }
687 }
688 }
689 }
690 }
691
692 private void cacheLengthFeatureManagers() {
693 // Iterate through each length for each state
694 ArrayFeatureList featureList = new ArrayFeatureList(fm);
695 for(CacheProcessor.StatePotentials stateInfo : modelInfo.statesWithLookback) {
696 LengthOnlyEvaluation[] lengthForState = lengthTables[stateInfo.state];
697
698 int maxLen = modelInfo.maxStateLengths[stateInfo.state];
699 for(int i=0; i<maxLen; ++i) {
700 LengthOnlyEvaluation eval = lengthForState[i];
701
702 // Fill the evaluation object with the correct info for the state.
703 featureList.clear();
704 for(FeatureManagerNodeExplicitLength fm : (List<FeatureManagerNodeExplicitLength>) (List) lengthFMList) {
705 fm.evaluateNodeLength(data.get(0),maxLen, i+1, stateInfo.state, featureList);
706 }
707 int size = featureList.currentSize;
708 eval.featureIndex = new short[size];
709 eval.value = new float[size];
710 for(int j=0; j<size; ++j) {
711 eval.featureIndex[j] = (short) featureList.indices[j];
712 eval.value[j] = (float) featureList.values[j];
713 }
714 }
715 }
716 }
717
/**
 * Re-checks the observed training labels against the invalidTransitions table built while
 * caching: any observed state or transition that a feature evaluation vetoed is an error.
 * With ignoreInvalidTrainingData set, offending sequences are discarded and the entire
 * cache is rebuilt recursively from the survivors; otherwise the CheckException propagates.
 */
public void checkConstraintsInTrainingData() {
// Verify that there are no invalid transitions in the training data.
List<TrainingSequence<?>> goodData = new ArrayList<TrainingSequence<?>>();

for (int i=0; i<data.size(); ++i) {
TrainingSequence seq = data.get(i);
int seqOffset = modelInfo.seqOffsets[i];
int lastState = -1;
boolean validSequence = true;
try {
for (int pos=0; pos < seq.length(); pos++) {
int trainingState = seq.getY(pos);
// Start of this absolute position's row in invalidTransitions.
int index = (seqOffset + pos) * modelInfo.nPotentials;
Assert.a(!invalidTransitions[index + trainingState], "Seq: ",i," Pos: ", pos, " State: ",trainingState, " violates a constraint.");
// Edge potentials are stored after the nStates node potentials.
if(lastState != -1)
Assert.a(!invalidTransitions[index + modelInfo.nStates + modelInfo.transitionIndex.getQuick(lastState, trainingState)], "Seq: ",i," Pos: ", pos, " Transition: ",lastState,"-",trainingState, " violates a constraint.");

lastState = trainingState;
}
}
catch(CheckException ex) {
if(ignoreInvalidTrainingData) {
log.warn("Discarding sequence "+i+" "+ex.getMessage());
validSequence = false;
}
else {
throw ex;
}
}
if(validSequence) {
goodData.add(seq);
}
}
Assert.a(goodData.size() != 0, "All training sequences were invalid.");
int diff = data.size() - goodData.size();
if(diff != 0) {
log.warn("Using "+goodData.size()+" training sequences. Discarded "+diff+" because of constraint problems.");
log.warn("Rebuilding the feature cache using the remaining good sequences.");
// Redo the initialization with the new set of training data
setTrainingData(fm, goodData);
return;
}
}
761
762 ///////////////////////////////////////////////////////////////////////////////
763 ///// ABOVE: build the cache structure BELOW: evaluate a position using the stored cache
764 ///////////////////////////////////////////////////////////////////////////////
765
766
767 public void evaluatePosition(int seq, int pos) {
768 // The job of this function is to update evals.index and evals.value based on all features and potentials at this location
769 // Recall below the comments from another class about the information in evals:
770
771 int cumulativePosition = modelInfo.seqOffsets[seq] + pos;
772 evaluateDenseFeatures(cumulativePosition);
773 evaluateSparseFeatures(cumulativePosition); // only variable length
774 }
775
776 private void evaluateConstantFeatures() {
777 Assert.a(!writtenYet, "Constant features can only be written once");
778 for ( ConstantEvaluation ce : constEvals) {
779 short fi = ce.featureIndex;
780 int pi = ce.potential;
781 int ci = ce.evalIndex;
782 float val = ce.value;
783
784 evals[pi].index[ci] = fi;
785 evals[pi].value[ci] = val;
786 }
787 writtenYet = true;
788 }
789
790
791 private void evaluateDenseFeatures(int cumulativePosition) {
792 for (DenseEvalIndices de : denseEvalsList) {
793 FeatureEvaluation fe = de.evalTable;// = densePotentialIndices[j];
794 int ei = de.evalIndex;//denseEvalIndices[j];
795 float[] lut = de.lookupTable;// = denseTableIndices[j];
796 fe.index[ei] = de.featureIndex;
797 fe.value[ei] = lut[cumulativePosition];
798 }
799 }
800
801
802 private void evaluateSparseFeatures(int cumulativePosition) {
803 for (int pot=0; pot<modelInfo.nPotentials; pot++) {
804 currentEvalIndices[pot] = numFixedEvalIndices[pot];
805 }
806 int start = sparseStarts[cumulativePosition];
807 int stop = sparseStarts[cumulativePosition+1];
808 for (int j=start; j<stop; j++) {
809 SparseEvaluation se = sparseEvals.get(j);
810 short fi = se.featureIndex;
811 int pi = se.potential;
812 int ci = currentEvalIndices[pi];
813 currentEvalIndices[pi]++;
814 float val = se.value;
815
816 evals[pi].index[ci] = fi;
817 evals[pi].value[ci] = val;
818 }
819
820 for (int pot=0; pot<modelInfo.nPotentials; pot++) {
821 evals[pot].index[currentEvalIndices[pot]] = -1;
822 }
823 }
824
825 /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
826 ////// ABOVE: Evaluate Markov features that don't depend on length BELOW: evaluate semi-Markov features that depend on interval length
827 /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
828
829 /**
830 * The job of this function is to update the table "LengthFeatureEvaluation[][] lengthEvals"
831 *
 * The first dimension is indexed by those states which have at least one explicit length node feature,
833 * corresponding to the state of an interval which ends at position currently under consideration.
834 * The size of this dimension is predetermined and is equal to modelInfo.statesWithLookback.length
835 *
836 * The second dimension is the length of the lookback, i.e. how many bases is
837 * the length of interval currently being considered. The size of this dimension is variable, and
838 * after the last of which one inserts a LengthFeatureEvaluation whose lookback is -1
839 *
840 * A LengthFeatureEvaluation contains the lookback and a FeatureEvaluation for nodes, plus and
841 * edgeEvaluation (which for now we set to null). The FeatureEvaluations we have seen before;
842 * they are comprised of index and value arrays, which are of variable length,
843 * and after last entry you put a -1 in the index array.
844 *
845 */
846
847 public void evaluateSegmentsEndingAt(int seq, int pos) {
848
849 int seqOffset = modelInfo.seqOffsets[seq];
850 int overallPosition = seqOffset+pos;
851 int tx1 = modelInfo.seqOffsets[seq]+pos + seq + 1; // ending position on subtraction lookup table, remember each sequence needs one extra position of padding.
852
853 CacheProcessor.StatePotentials[] statesWithLookback = modelInfo.statesWithLookback;
854 int nSemiMarkovStates = statesWithLookback.length;
855
856 int seqLen = data.get(seq).length();
857 int invalidIndex = overallPosition*modelInfo.nPotentials;
858 for (int stateIx=0; stateIx < nSemiMarkovStates; stateIx++) {
859 CacheProcessor.StatePotentials statePotentials = statesWithLookback[stateIx];
860 LengthFeatureEvaluation[] lookbackEvals = lengthEvals[stateIx];
861
862 if(!checkExit(invalidIndex, pos, seqLen, statePotentials.state)) {
863 lookbackEvals[0].lookback = -1;
864 continue;
865 }
866
867 // For say an interval of exon1 ending at specified position, what lookback need we provide?
868
869 // Well, lets say for instance that exons have a minimum length of 50 and maximum length of 200.
870 // Then I want to start at current pposition (lookback=0) and walk backwards, breaking if I ever
871 // encounter an invalidated node or edge, or when I reach the maximum allowable length for exon.
872 // If not, then as soon as I get to lookback=50, I
873 // start taking notes for every valid beginning (a valid edge leading into beginning of current interval)
874
875 // Now for one of these, I look back, and for every valid transition into I record the lookback, the
876 // relevant interval node features, both index and value, by subtracting some numbers at an offset.
877
878 int nLookbacks = 0;
879 int minLength = minStateLengths[statesWithLookback[stateIx].state];
880 int maxLength = maxStateLengths[statesWithLookback[stateIx].state];
881
882 for (int lookback = 0; lookback < maxLength; lookback++) {
883 int firstPosIndex = (overallPosition-lookback)*modelInfo.nPotentials;
884 if (lookback > pos || invalidTransitions[firstPosIndex + statePotentials.state]) { break; }
885 if (lookback+1 < minLength) { continue; }
886
887 int prevPos = pos - lookback - 1;
888 boolean validEntry = false;
889 if(prevPos == -1) {
890 validEntry = true;
891 }
892 else {
893 for (int pot : statePotentials.potentials) {
894 int entryIndex = (seqOffset+prevPos+1)*modelInfo.nPotentials;
895 if (modelInfo.selfTransitions[statePotentials.state]+modelInfo.nStates != pot && !invalidTransitions[entryIndex + pot]) {
896 validEntry = true;
897 break;
898 }
899 }
900 }
901
902 if(validEntry) {
903 int nEvals = 0;
904 LengthFeatureEvaluation lengthEval = lookbackEvals[nLookbacks];
905 lengthEval.lookback = (short) lookback;
906 lengthEval.edgeEvals = null;
907
908 FeatureEvaluation nodeEval = lengthEval.nodeEval;
909
910 // for this potential, there is a lookup table. It includes the offsets you need for subtraction etc.
911 for (DenseNodeBoundaryIndices db : denseNodeBoundaryEvalsList[statePotentials.state]) {
912 int index = db.featureIndex;
913 float[] lut = db.lookupTable;
914 int rightPad = db.rightPad;
915 int leftPad = db.leftPad;
916
917 float val = lut[tx1 - rightPad ] - lut[tx1-lookback-1 + leftPad];
918
919 nodeEval.index[nEvals] = (short) index;
920 nodeEval.value[nEvals] = val;
921 nEvals++;
922 }
923
924 // Add in length evaluations
925 LengthOnlyEvaluation lengthOnlyEval = lengthTables[statePotentials.state][lookback];
926 int size = lengthOnlyEval.featureIndex.length;
927 System.arraycopy(lengthOnlyEval.featureIndex, 0, nodeEval.index, nEvals, size);
928 System.arraycopy(lengthOnlyEval.value, 0, nodeEval.value, nEvals, size);
929 nEvals += size;
930
931 nodeEval.index[nEvals] = -1;
932 nLookbacks++;
933 }
934 }
935 lookbackEvals[nLookbacks].lookback = -1;
936 }
937 }
938
939 /** Checks if there is a valid transition out of a node */
940 boolean checkExit(int positionIndex, int pos, int seqLen, int state) {
941 if(invalidTransitions[positionIndex+state])
942 return false;
943
944 // This requires that we be in the last position or that there is a valid transition out.
945 if(pos == seqLen-1) {
946 return true;
947 }
948
949 boolean wayOut = false;
950 int nextPosIndex = positionIndex + modelInfo.nPotentials;
951 for(byte pot : modelInfo.exitTransitions[state]) {
952 if(modelInfo.selfTransitions[state]+modelInfo.nStates != pot && !invalidTransitions[nextPosIndex + pot]) {
953 wayOut = true;
954 break;
955 }
956 }
957 return wayOut;
958 }
959
960 /////////////////////////////////////////////////////////////////////////////////////////////////
961
962
963 public boolean[] getInvalidTransitions() {
964 return invalidTransitions;
965 }
966
967 public void setSemiMarkovSetup(SemiMarkovSetup setup) {
968 maxStateLengths = setup.getMaxLengths();
969 minStateLengths = setup.getMinLengths();
970 ignoreSemiMarkovSelfTransitions = setup.isIgnoreSemiMarkovSelfTransitions();
971 }
972
973 public boolean isAllPaths() {
974 return allPaths;
975 }
976
977
978 public void setAllPaths(boolean allPaths) {
979 this.allPaths = allPaths;
980 }
981
982 /**
983 * @return Returns the lookbackArrayFeatureSize.
984 */
985 public int getLookbackArrayFeatureSize() {
986 return lookbackArrayFeatureSize;
987 }
988
989 /**
990 * @param lookbackArrayFeatureSize The lookbackArrayFeatureSize to set.
991 */
992 public void setLookbackArrayFeatureSize(int lookbackArrayFeatureSize) {
993 this.lookbackArrayFeatureSize = lookbackArrayFeatureSize;
994 }
995
996 /**
997 * @return Returns the lookbackArraySize.
998 */
999 public int getLookbackArraySize() {
1000 return lookbackArraySize;
1001 }
1002
1003 /**
1004 * @param lookbackArraySize The lookbackArraySize to set.
1005 */
1006 public void setLookbackArraySize(int lookbackArraySize) {
1007 this.lookbackArraySize = lookbackArraySize;
1008 }
1009
1010 public boolean isIgnoreInvalidTrainingData() {
1011 return ignoreInvalidTrainingData;
1012 }
1013
1014 public void setIgnoreInvalidTrainingData(boolean ignoreInvalidTrainingData) {
1015 this.ignoreInvalidTrainingData = ignoreInvalidTrainingData;
1016 }
1017 }