001    package calhoun.analysis.crf.solver;
002    
003    import java.util.ArrayList;
004    import java.util.Arrays;
005    import java.util.HashSet;
006    import java.util.List;
007    
008    import org.apache.commons.logging.Log;
009    import org.apache.commons.logging.LogFactory;
010    
011    import calhoun.analysis.crf.CacheStrategySpec;
012    import calhoun.analysis.crf.CompositeFeatureManager;
013    import calhoun.analysis.crf.FeatureManager;
014    import calhoun.analysis.crf.FeatureManagerEdge;
015    import calhoun.analysis.crf.FeatureManagerNode;
016    import calhoun.analysis.crf.FeatureManagerNodeBoundaries;
017    import calhoun.analysis.crf.FeatureManagerNodeExplicitLength;
018    import calhoun.analysis.crf.ModelManager;
019    import calhoun.analysis.crf.SemiMarkovSetup;
020    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
021    import calhoun.analysis.crf.io.InputSequence;
022    import calhoun.analysis.crf.io.TrainingSequence;
023    import calhoun.analysis.crf.solver.check.ArrayFeatureList;
024    import calhoun.util.Assert;
025    import calhoun.util.CheckException;
026    
/** A policy based cache processor.  This is the main cache processor used.  It uses the cache
 * policies specified by the feature managers to efficiently cache all feature values.
 * The distinct caching strategies recognized by this CacheProcessor are:
 * <ul>
 *  <li> <b>COMPOSITE</b> - must drill down further to a non-composite FeatureManager to learn the strategy
 *  <li> <b>CONSTANT</b> - Triplets (potential, featureIndex, value) that do not depend on position
 *  <li> <b>DENSE</b> - Triplets (potential, featureIndex, lookupTable), where lookupTable is of size totalPositions and different
 *    doublets (potential, featureIndex), if they always have the same value, can share the same lookupTable to save space.
 *  <li> <b>SPARSE</b> - Quadruplets (cumulativePosition, potential, featureIndex, value) of nonzero evaluations stored in order first by cumulativePosition,
 *    and then arbitrarily, similar to the way the previous "FeatureCache" worked but without the need to sub-order by potentials in "model order".
 *  <li> <b>LENGTHFUNCTION</b> - Quadruplets (length, state, featureIndex, value).
 *  <li> Any other strategy is treated as SPARSE.
 * </ul>
 */
041    public class CacheProcessorDeluxe extends CacheProcessorBasic {
042            private static final Log log = LogFactory.getLog(CacheProcessorDeluxe.class);
043                    
044            private boolean ignoreInvalidTrainingData;
045            private boolean allPaths;
046            private boolean validHiddenStates = false;
047            
048            public boolean[] invalidTransitions;  // of size totalLength * nPotentials
049            
050            private int[] numFixedEvalIndices;  // of length nPotentials
051            
052            private boolean        forceStrategy;
053            private CacheStrategy  onlyStrategy;
054            
055            int lookbackArraySize = -1;
056            int lookbackArrayFeatureSize = -1;
057            
058            ///////////////////// Below holds raw information for Sparsely Cached Feature Evaluations
059            private ArrayList<FeatureManager> sparseFMList;
060            private ArrayList<FeatureManager> lengthFMList;
061            private int[] sparseStarts;      // of length (totalPositions + 1)
062            private ArrayList<SparseEvaluation> sparseEvals;
063            private class SparseEvaluation {
064                    short featureIndex;
065                    int   potential;
066                    float value;
067            }
068            // This is a global variable used only by evaluation of sparse features from cache
069            private int[] currentEvalIndices;   // of length nPotentials
070            
071            
072            //////////////////// Below holds raw information for the Constant Features
073            private boolean writtenYet;
074            private ArrayList<ConstantEvaluation> constEvals;
075            private class ConstantEvaluation {
076                    short featureIndex;
077                    int   potential;
078                    float value;
079                    int   evalIndex;
080            }
081            
082            
083            //////////////////////////////////////// Below holds raw information for DenseFeatureTables
084            private ArrayList<float[]> denseTables; // first dimension is which table; second dimension is of length modelInfo.totalPositions; faster to flip-flop?
085            // ********** Above variable is a memory hog
086            private ArrayList<DenseEvalIndices> denseEvalsList; // use this until finilization, at which time pack into array below with identical info
087            private class DenseEvalIndices {
088                    public float[]           lookupTable;
089                    public int               evalIndex;
090                    public FeatureEvaluation evalTable;   // a pointer to one of the tables in denseTables, several DenseEvalIndices can point to same table; the number of possible tables to choose from is nPotentials
091                    public int               potential;   // evalTable happens to be equal to evals.get(potential)
092                    public short             featureIndex;
093                    
094            }
095            
096            
097            //////////////// Below holds info for DenseNodeBoundary feature caches
098            private ArrayList<float[]> denseBoundaryTables;  // note that these tables need 1 extra unit of lengh for each sequence, so they are of length modelInfo.totalPositions + modelInfo.nSeqs
099            ArrayList<DenseNodeBoundaryIndices>[] denseNodeBoundaryEvalsList; // have a list for each potential
100    
101            private short[] minStateLengths;
102    
103            private boolean ignoreSemiMarkovSelfTransitions;
104            private class DenseNodeBoundaryIndices {
105                    public int featureIndex;
106                    public float[] lookupTable;
107                    public int rightPad;
108                    public int leftPad;
109                    public int potential;   
110            }
111            
112            //////////////// Below holds info for Length feature caches
113            private LengthOnlyEvaluation[][] lengthTables;
114            
115            private class LengthOnlyEvaluation {
116                    public short[] featureIndex;
117                    public float[] value;
118            }
119            
120            ////////////////////////////////////////////
121            ////// MEMBER VARIABLES ABOVE, METHODS BELOW
122            ////////////////////////////////////////////
123            
124            public CacheProcessorDeluxe() {
125                    super();
126                    log.debug("Calling Cache Processor Deluxe constructor");
127                    forceStrategy = false;
128            }
129            
130            public CacheProcessorDeluxe(CacheStrategy strategy) {
131                    super();
132                    log.warn("Calling CacheProcessorDeluxe constructor, specifying a strategy");
133                    forceStrategy = true;
134                    onlyStrategy = strategy;
135                    log.info("Constructed a cache processor and requiring it to always use the following cache strategy:");
136                    switch (strategy) {
137                    case CONSTANT:
138                            log.warn("CONSTANT");
139                    break;
140                    case DENSE:
141                            log.warn("DENSE");
142                    break;
143                    case SPARSE:
144                            log.warn("SPARSE");
145                    break;
146                    default:
147                            Assert.a(false,"ERROR - case not dealt with yet.");
148                    }
149            }
150            
151            @Override
152            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
153                    super.setTrainingData(fm, data);
154                    basicInit(allPaths);
155    
156                    if (maxStateLengths == null) {
157                            log.debug("maxStateLengths was not set, setting all to length 1");
158                            maxStateLengths = new short[fm.getNumStates()];
159                            Arrays.fill(maxStateLengths, (short) 1);
160                    }
161                    
162                    if (minStateLengths == null) {
163                            log.info("minStateLengths was not set, setting all to length 1");
164                            minStateLengths = new short[fm.getNumStates()];
165                            Arrays.fill(maxStateLengths, (short) 1);
166                    }
167                    
168                    modelInfo.setup(fm, data, allPaths, maxStateLengths, ignoreSemiMarkovSelfTransitions);
169                    
170                    modelInfo.maxLookback = 1;
171                    for(int stateLen : maxStateLengths) {
172                            modelInfo.maxLookback = (short) Math.max(stateLen, modelInfo.maxLookback);
173                    }
174    
175                    if(lookbackArraySize == -1)
176                            lookbackArraySize = modelInfo.maxLookback+2;
177                    if(lookbackArrayFeatureSize == -1)
178                            lookbackArrayFeatureSize = Math.max(5, modelInfo.nFeatures);
179                    lengthEvals = LengthFeatureEvaluation.create(modelInfo.statesWithLookback, lookbackArraySize, lookbackArrayFeatureSize);
180    
181                    validHiddenStates = data.get(0).getY(0) >= 0; 
182                    if(validHiddenStates) {
183                            // Verify that there are no invalid transitions in the training data.
184                            List<TrainingSequence<?>> goodData = new ArrayList<TrainingSequence<?>>();
185                            for (int i=0; i<data.size(); ++i) {
186                                    TrainingSequence seq = data.get(i);
187                                    short length = 1;                       
188                                    int lastState = -1;
189                                    boolean validSequence = true;
190                                    try {
191                                            for (int pos=0; pos < seq.length(); pos++) {
192                                                    int trainingState = seq.getY(pos);
193                                                    boolean segmentEnd = pos == seq.length()-1 || trainingState != seq.getY(pos+1);
194                                                    if(segmentEnd) {
195                                                            Assert.a(maxStateLengths[trainingState]==1 || (length <= maxStateLengths[trainingState] && length >= minStateLengths[trainingState]), "Seq #", i, " Pos ", pos, " Training segment length ", length, " state: "+trainingState+" outside the allowed length ", minStateLengths[trainingState], "-", maxStateLengths[trainingState]);
196                                                            Assert.a(lastState == -1 || modelInfo.transitionIndex.getQuick(lastState,trainingState) != -1);
197                                                            lastState = trainingState;
198                                                            length = 1;
199                                                    }
200                                                    else {
201                                                            ++length;
202                                                    }
203                                            }
204                                    }
205                                    catch(CheckException ex) {
206                                            if(ignoreInvalidTrainingData) {
207                                                    log.warn("Discarding sequence "+i+" "+ex.getMessage());
208                                                    validSequence = false;
209                                            }
210                                            else {
211                                                    throw ex;
212                                            }
213                                    }
214                                    if(validSequence) {
215                                            goodData.add(seq);
216                                    }
217                            }
218                            Assert.a(goodData.size() != 0, "All training sequences were invalid.");
219                            int diff = data.size() - goodData.size();
220                            if(diff != 0) {
221                                    log.warn("Using "+goodData.size()+" training sequences.  Discarded "+diff+" because of state length or transition problems.");
222                                    // Redo the initialization with the new set of training data
223                                    setTrainingData(fm, goodData);
224                                    return;
225                            }
226                    }
227                    
228                    initializeCacheProcessor();
229    
230                    updateCacheProcessor(fm);
231                    cacheLengthFeatureManagers();
232                    cacheSparseFeatureManagers();
233                    
234                    evaluateConstantFeatures();
235    
236                    if(validHiddenStates) {
237                            checkConstraintsInTrainingData();
238                            computeFeatureSums();
239                    }
240            }
241    
242    
243            private void initializeCacheProcessor() {
244                    invalidTransitions = new boolean[modelInfo.nPotentials*modelInfo.totalPositions];
245                    
246                    //featureSums = new double[fm.getNumFeatures()];
247                    
248                    numFixedEvalIndices = new int[modelInfo.nPotentials];
249                    currentEvalIndices = new int[modelInfo.nPotentials];
250                    
251                    denseEvalsList  = new ArrayList<DenseEvalIndices>();
252                    denseTables = new ArrayList<float[]>(); 
253                    
254                    //denseNodeBoundaryEvalsList = new ArrayList<DenseNodeBoundaryIndices>[modelInfo.nPotentials];
255                    denseBoundaryTables = new ArrayList<float[]>();
256                    denseNodeBoundaryEvalsList = new ArrayList[modelInfo.nPotentials];
257                    for (int pot=0; pot<modelInfo.nPotentials; pot++) {
258                            denseNodeBoundaryEvalsList[pot] = new   ArrayList<DenseNodeBoundaryIndices>();
259                    }
260                    
261                    constEvals      = new ArrayList<ConstantEvaluation>();
262                    writtenYet = false;
263                    
264                    sparseFMList = new ArrayList<FeatureManager>();
265                    sparseStarts = new int[modelInfo.totalPositions + 1];
266                    sparseEvals     = new ArrayList<SparseEvaluation>();
267                    
268                    lengthFMList = new ArrayList<FeatureManager>();
269                    lengthTables = new LengthOnlyEvaluation[modelInfo.nStates][];
270                    for(int i= 0; i<modelInfo.nStates; ++i) {
271                            lengthTables[i] = new LengthOnlyEvaluation[modelInfo.maxStateLengths[i]];
272                            for(int j=0; j<lengthTables[i].length; ++j) {
273                                    lengthTables[i][j] = new LengthOnlyEvaluation();
274                            }
275                    }
276            }
277            
278            
279            private void updateCacheProcessor(FeatureManager fm1) {
280                    log.debug("Calling updateCacheProcessor");
281                    CacheStrategySpec.CacheStrategy strategy;
282                    if (forceStrategy) {
283                            strategy = onlyStrategy;
284                    } else {
285                            strategy = fm1.getCacheStrategy().strategy;
286                    }
287                    switch (strategy) {
288                    case COMPOSITE:
289                            log.debug("  burrowing into a COMPOSITE feature manager");
290                    CompositeFeatureManager cfm = (CompositeFeatureManager) fm1;
291                    for (FeatureManager fmchild : cfm.getComponentFeatures()) {
292                            updateCacheProcessor(fmchild);
293                    }
294                    break;
295                    case DENSE:
296                            log.debug("  processing a feature manager as DENSE");
297                            cacheFeatureManagerAsDenseNode(fm1);
298                    break;
299                    case DENSE_NODE_BOUNDARY:
300                            log.debug("Processing a node as DENSE_NODE_BOUNDARY");
301                            cacheFeatureManagerAsDenseNodeBoundary(fm1);
302                    break;
303                    case CONSTANT:
304                            log.debug("  processing a feature manager as CONSTANT");
305                            cacheFeatureManagerAsConstant(fm1);
306                    break;
307                    case LENGTHFUNCTION:
308                            log.debug("  processing a feature manager as LENGTH");
309                            lengthFMList.add(fm1);
310                    break;
311                    case SPARSE:
312                    case UNSPECIFIED:
313                    default: 
314                            log.debug("  processing a feature manager as SPARSE");
315                            sparseFMList.add(fm1);
316                            if(fm1 instanceof FeatureManagerNodeExplicitLength) {
317                                    lengthFMList.add(fm1);
318                            }
319                    break;
320                    }
321            }
322            
323            
324            
325            InputSequence getSequence(InputSequence seq, FeatureManager argFm) {
326                    return argFm.getInputComponent() == null ? seq : seq.getComponent(argFm.getInputComponent());
327            }
328            
329            private void cacheSparseFeatureManagers() {
330                    sparseStarts[0] = 0;
331                    
332                    log.debug("We are caching this many feature managers (some may be composite) as sparse: " + sparseFMList.size());
333                    log.debug("Number of feature in each of the FMs is");
334                    for (FeatureManager fmb : sparseFMList) {
335                            log.debug("Number of features is " + fmb.getNumFeatures());
336                            if (FeatureManagerEdge.class.isInstance(fmb))  {log.debug("  Is edge featuremanager"); }
337                            if (FeatureManagerNode.class.isInstance(fmb))  {log.debug("  Is node featuremanager"); }
338                    }
339                    
340                    ArrayFeatureList result = new ArrayFeatureList(fm);                             
341                    for (int i=0; i<data.size(); i++) {
342                            TrainingSequence seq = data.get(i);
343                            for (int pos=0; pos<seq.length(); pos++) {
344                                    int cumulativePosition = modelInfo.seqOffsets[i] + pos;
345                                    
346                                    for (int potential=0; potential< modelInfo.nPotentials; potential++) {
347                                            result.clear();                         
348                                            
349                                            int state;
350                                            int prevState;
351                                            
352                                            
353                                            if (potential<modelInfo.nStates) {
354                                                    state = potential;
355                                                    for (FeatureManager fmb : sparseFMList) {
356                                                            if (!FeatureManagerNode.class.isInstance(fmb)) { continue; }                                                    
357                                                            
358                                                            FeatureManagerNode fmn = (FeatureManagerNode) fmb;
359                                                            fmn.evaluateNode(getSequence(seq, fmb),pos,state,result);
360                                                    }
361                                            } else if (pos>0) {
362                                                    prevState = modelInfo.transitionFrom[potential - modelInfo.nStates];
363                                                    state = modelInfo.transitionTo[potential - modelInfo.nStates];
364                                                    for (FeatureManager fm1 : sparseFMList) {
365                                                            if (!FeatureManagerEdge.class.isInstance(fm1))  {continue; }
366                                                            FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
367                                                            
368                                                            fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);                             
369                                                    }
370                                            }
371                                            
372                                            
373                                            if (result.isValid()) {
374                                                    for (int j=0; j< result.size(); j++) {
375                                                            float value = (float) result.getValue(j); 
376                                                            if(value != 0.0) {
377                                                                    SparseEvaluation se = new SparseEvaluation();
378                                                                    se.featureIndex = (short) result.getIndex(j);
379                                                                    se.potential = potential;
380                                                                    se.value = value;
381                                                                    sparseEvals.add(se);
382                                                            }
383                                                    }
384                                            } else {
385                                                    int ind = cumulativePosition*modelInfo.nPotentials + potential;
386                                                    invalidTransitions[ind] = true;
387                                            }
388                                    }
389                                    sparseStarts[cumulativePosition+1] = sparseEvals.size();
390                            }
391                    }
392            }
393            
394            private void cacheFeatureManagerAsConstant(FeatureManager fm1) {
395                    CacheStrategySpec s = (CacheStrategySpec) fm1.getCacheStrategy();
396                    if (s.strategy != CacheStrategy.CONSTANT) {
397                            log.warn("Caching a Feature Manager as CONSTANT even though it requested a different strategy");
398                    }
399                    
400                    InputSequence seq = data.get(0);
401                    int pos = 0;
402                    
403                    ArrayFeatureList result = new ArrayFeatureList(fm);
404                    for (int pot=0; pot< modelInfo.nPotentials; pot++) {
405                            result.clear();
406                            
407                            int state;
408                            int prevState;
409                            
410                            if (pot<modelInfo.nStates) {
411                                    if (!FeatureManagerNode.class.isInstance(fm1)) { continue; }
412                                    state = pot;
413                                    FeatureManagerNode fmn = (FeatureManagerNode) fm1;
414                                    fmn.evaluateNode(getSequence(seq, fmn),pos,state,result);                               
415                            } else {
416                                    if (!FeatureManagerEdge.class.isInstance(fm1)) { continue; }
417                                    FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
418                                    prevState = modelInfo.transitionFrom[pot - modelInfo.nStates];
419                                    state = modelInfo.transitionTo[pot - modelInfo.nStates];
420                                    fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
421                            }
422                            
423                            Assert.a(result.isValid(),"Constant features should not have to encounter invalid potentials");
424    
425                            for (int j=0; j<result.size(); j++) {
426                                    ConstantEvaluation ce = new ConstantEvaluation();
427                                    ce.featureIndex = (short) result.getIndex(j);
428                                    ce.potential = pot;
429                                    ce.value = (float) result.getValue(j);
430                                    ce.evalIndex = numFixedEvalIndices[pot];
431                                    numFixedEvalIndices[pot]++;
432                                    constEvals.add(ce);
433                            }
434                    }
435            }
436            
437            class FeaturePotential {
438                    int featureIndex;
439                    int potential;
440                    @Override
441                    public int hashCode() {
442                            return potential * modelInfo.nPotentials + featureIndex;
443                    }
444                    @Override
445                    public boolean equals(Object rhs) {
446                            FeaturePotential f2 = (FeaturePotential) rhs;
447                            return featureIndex == f2.featureIndex && potential == f2.potential;
448                    }
449            }
450            
451            private void cacheFeatureManagerAsDenseNode(FeatureManager fm1) {
452                    CacheStrategySpec s = (CacheStrategySpec) fm1.getCacheStrategy();
453                    if (!(s.strategy == CacheStrategy.DENSE)) {
454                            log.warn("The evaluations are being cached using the DENSE strategy even though that was not the strategy requested by the feature.");
455                    }
456                    
457                    CacheStrategySpec.DenseCachingDetails details = (CacheStrategySpec.DenseCachingDetails) s.details;
458                    
459                    if (details == null) { 
460                            if(fm1.getNumFeatures() > 0) {
461                                    log.debug("No details specified for "+fm1+".  Using evaluations to find the correct nodes");
462                            }
463                            
464                            // Iterate through all sequences and positions to find the feature/potential combinations used
465                            // This replaces code that used to do it based on the first position of the first sequence with length > 2.
466                            HashSet<FeaturePotential> usedPotentials = new HashSet();
467                            ArrayFeatureList result = new ArrayFeatureList(fm);
468                            FeaturePotential current = new FeaturePotential();
469                            for (InputSequence seq : data) {
470                                    for(int pos = 0; pos < seq.length(); ++pos) {
471                                            for (int pot=0; pot< modelInfo.nPotentials; pot++) {
472                                                    result.clear();
473                                                    if (pot<modelInfo.nStates) {
474                                                            if (!FeatureManagerNode.class.isInstance(fm1)) { continue; }
475                                                            FeatureManagerNode fmn = (FeatureManagerNode) fm1;
476                                                            fmn.evaluateNode(getSequence(seq, fmn),pos,pot,result);                         
477                                                    } else {
478                                                            if (!FeatureManagerEdge.class.isInstance(fm1)) { continue; }
479                                                            FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
480                                                            int prevState = modelInfo.transitionFrom[pot - modelInfo.nStates];
481                                                            int state = modelInfo.transitionTo[pot - modelInfo.nStates];
482                                                            fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
483                                                    }
484                                                    for (int j=0; j<result.size(); j++) {
485                                                            current.featureIndex = result.getIndex(j); 
486                                                            current.potential = pot;
487                                                            if(!usedPotentials.contains(current)) {
488                                                                    log.debug("Adding dense node - pot: "+current.potential+" Feat: "+current.featureIndex);
489                                                                    usedPotentials.add(current);
490                                                                    current = new FeaturePotential();
491                                                            }
492                                                    }
493                                            }
494                                    }
495                            }
496    
497                            // Now that we have the list of potentials that will be used, go and create the details.
498                            details = new CacheStrategySpec.DenseCachingDetails();
499                            ArrayList<Integer> densePotList = new ArrayList<Integer>();
500                            ArrayList<Integer> denseTabList = new ArrayList<Integer>();                 
501                            ArrayList<Short> denseFiList = new ArrayList<Short>();
502                            details.nEvals = usedPotentials.size();
503                            details.nTables = 0;
504    
505                            for(FeaturePotential fp : usedPotentials) {
506                                    densePotList.add(fp.potential);
507                                    denseTabList.add(details.nTables);
508                                    denseFiList.add((short) fp.featureIndex);
509                                    details.nTables++;
510                            }               
511    
512                            details.featureIndex = new short[details.nEvals];
513                            details.tableNum     = new int[details.nEvals];
514                            details.potential    = new int[details.nEvals];
515                            
516                            for (int j=0; j<details.nEvals; j++) {
517                                    details.featureIndex[j] = denseFiList.get(j);
518                                    details.potential[j] = densePotList.get(j);
519                                    details.tableNum[j] = denseTabList.get(j);
520                            }
521                            details.check();
522                    } 
523    
524                    int firstTableIndex = denseTables.size();
525                    details.check();
526                    for (int j=0; j<details.nTables; j++) {
527                            float[] temp = new float[modelInfo.totalPositions];
528                            for (int k=0; k<modelInfo.totalPositions; k++) {
529                                    temp[k] = 0;
530                            }
531                            denseTables.add(temp);
532                    }
533                    
534                    for (int j=0; j<details.nEvals; j++) {
535                            DenseEvalIndices de = new DenseEvalIndices();
536                            
537                            int potential = details.potential[j];
538                            de.evalTable = evals[potential];
539                            de.potential = potential;
540                            de.evalIndex = numFixedEvalIndices[potential];
541                            numFixedEvalIndices[potential]++;
542                            de.lookupTable = denseTables.get(firstTableIndex + details.tableNum[j]);
543                            de.featureIndex = details.featureIndex[j];
544                            denseEvalsList.add(de);
545                    }
546                    
547                    
548                    ArrayFeatureList result = new ArrayFeatureList(fm);                             
549                    for (int i=0; i<data.size(); i++) {
550                            TrainingSequence seq = data.get(i);
551                            int actualPreviousState = -1;
552                            for (int pos=0; pos<seq.length(); pos++) {
553                                    int actualState = seq.getY(pos);
554                                    
555                                    int cumulativePosition = modelInfo.seqOffsets[i] + pos;
556                                    
557                                    for (int potential=0; potential< modelInfo.nPotentials; potential++) {
558                                            result.clear();                         
559                                            
560                                            int state;
561                                            int prevState;
562                                            
563                                            
564                                            if (potential<modelInfo.nStates) {
565                                                    if (!FeatureManagerNode.class.isInstance(fm1)) { continue; }
566                                                    state = potential;
567                                                    FeatureManagerNode fmn = (FeatureManagerNode) fm1;
568                                                    fmn.evaluateNode(getSequence(seq, fmn),pos,state,result);
569                                                    if (state==actualState) {
570                                                            for (int j=0; j< result.size(); j++) {
571    //                                                              featureSums[result.getIndex(j)] += result.getValue(j); 
572                                                            }                                                               
573                                                    }
574                                            } else if (pos>0) {
575                                                    if (!FeatureManagerEdge.class.isInstance(fm1)) { continue; }
576                                                    FeatureManagerEdge fme = (FeatureManagerEdge) fm1;
577                                                    prevState = modelInfo.transitionFrom[potential - modelInfo.nStates];
578                                                    state = modelInfo.transitionTo[potential - modelInfo.nStates];
579                                                    fme.evaluateEdge(getSequence(seq, fme),pos,prevState,state,result);
580                                                    if ((state==actualState) && (prevState==actualPreviousState)) {
581                                                            for (int j=0; j< result.size(); j++) {
582    //                                                              featureSums[result.getIndex(j)] += result.getValue(j); 
583                                                            }                                                               
584                                                    }
585                                            }
586                                            
587                                            
588                                            if (result.isValid()) {
589                                                    
590                                                    for (int j=0; j<result.size(); j++) {
591                                                            for (int ev=0; ev<details.nEvals; ev++) {
592                                                                    DenseEvalIndices de = denseEvalsList.get(denseEvalsList.size()-1-ev);
593                                                                    if ((potential==de.potential) && (result.getIndex(j)==de.featureIndex)) {
594                                                                            de.lookupTable[modelInfo.seqOffsets[i] + pos] = (float) result.getValue(j);
595                                                                            break;
596                                                                    }
597                                                            }
598                                                    }
599                                                    
600                                                    
601                                            } else {
602                                                    int ind = cumulativePosition*modelInfo.nPotentials + potential;
603                                                    invalidTransitions[ind] = true;
604                                            }
605                                    }
606                                    actualPreviousState = actualState;
607                            }
608                    }
609                    
610                    
611            }
612            
613            
614            private void cacheFeatureManagerAsDenseNodeBoundary(FeatureManager fm1) {
615                    // NOTE: This does not calculate feature sums for the semi-Markov features.  I think it would be best
616                    // to have a separate function do this which calls the CacheProcessor but which is separate and not interwoven with each part of CacheProcessor.
617                    
618                    CacheStrategySpec s = (CacheStrategySpec) fm1.getCacheStrategy();
619                    Assert.a(s.strategy == CacheStrategy.DENSE_NODE_BOUNDARY,"ERROR: Cannot cache using thie DENSE_NODE_BOUNDARY strategy unless specifically requested by the feature.");
620                    
621                    CacheStrategySpec.DenseBoundaryCachingDetails details = (CacheStrategySpec.DenseBoundaryCachingDetails) s.details;      
622                    Assert.a(details != null,"ERROR -- the Cache Strategy DenseNodeBoundary is too intricate to omit specifying details");
623                    
624                    int firstTableIndex = denseBoundaryTables.size();
625                    details.check();
626                    for (int j=0; j<details.nTables; j++) {
627                            float[] temp = new float[modelInfo.totalPositions+modelInfo.nSeqs];
628                            for (int k=0; k<modelInfo.totalPositions+modelInfo.nSeqs; k++) {
629                                    temp[k] = 0;
630                            }
631                            denseBoundaryTables.add(temp);
632                    }
633                    
634                    ArrayList<DenseNodeBoundaryIndices> tempNodeBoundaries = new ArrayList<DenseNodeBoundaryIndices>();
635                    //for (int j=0; j<details.entries.size(); j++) {
636                    for (CacheStrategySpec.DenseBoundaryEntry be : details.entries) {
637                            DenseNodeBoundaryIndices db = new DenseNodeBoundaryIndices();
638                            
639                            int potential = be.potential; //details.potential[j];
640                            Assert.a(potential<modelInfo.nStates,"ERROR - Can't cache potential " + potential + " , which is a transition using the DENSE_NODE_BOUNDARY caching strategy.");
641                            //numFixedEvalIndices[potential]++;
642                            db.lookupTable = denseBoundaryTables.get(firstTableIndex + be.tableNum); //details.tableNum[j]);
643                            db.featureIndex = be.featureIndex; //details.featureIndex[j];
644                            db.potential = potential;
645                            db.rightPad = be.rightPad; //details.rightPad[j];
646                            db.leftPad = be.leftPad; //details.leftPad[j];
647                            denseNodeBoundaryEvalsList[potential].add(db);
648                            log.debug("Node boundary feature: "+db.featureIndex+" pot: "+db.potential+" table: "+db.lookupTable);
649                            tempNodeBoundaries.add(db);
650                    }
651                    Assert.a(tempNodeBoundaries.size() == details.entries.size());
652                    
653                    
654                    for (int seqNum=0; seqNum<data.size(); seqNum++) {
655                            TrainingSequence seq = data.get(seqNum);
656                            for (int pos=0; pos<seq.length(); pos++) {
657                                    
658                                    int cumulativePosition = modelInfo.seqOffsets[seqNum] + pos;
659                                    
660                                    for (int state=0; state< modelInfo.nStates; state++) {
661                                            ArrayFeatureList result = new ArrayFeatureList(fm);                             
662                                            Assert.a(FeatureManagerNodeBoundaries.class.isInstance(fm1),"ERROR - to cache using DENSE NODE BOUNDARY, must be an instance of FeatureManagerNodeBoundaries, but this isn't.");
663                                            FeatureManagerNode fmn = (FeatureManagerNode) fm1;
664                                            fmn.evaluateNode(getSequence(seq, fmn),pos,state,result);
665                                            
666                                            if (result.isValid()) { 
667                                                    for (int j=0; j<result.size(); j++) {
668                                                            boolean found = false;
669                                                            for (DenseNodeBoundaryIndices db : tempNodeBoundaries ) {       
670                                                                    if ((state==db.potential) && (result.getIndex(j)==db.featureIndex)) {
671                                                                            int tx = modelInfo.seqOffsets[seqNum] + seqNum + pos;
672                                                                            db.lookupTable[tx + 1] = db.lookupTable[tx] + (float) result.getValue(j);
673                                                                            found = true;
674                                                                            break;
675                                                                    }
676                                                            }
677                                                            if(!found) {
678                                                                    Assert.a(false, String.format("Feature Ix: %d State: %d not found. ", result.getIndex(j), state));
679                                                            }
680                                                    }
681                                            } 
682                                            
683                                            if (!result.isValid()) {
684                                                    int ind = cumulativePosition*modelInfo.nPotentials + state;
685                                                    invalidTransitions[ind] = true;
686                                            }
687                                    }
688                            }
689                    }       
690            }
691            
692            private void cacheLengthFeatureManagers() {
693                    // Iterate through each length for each state
694                    ArrayFeatureList featureList = new ArrayFeatureList(fm);
695                    for(CacheProcessor.StatePotentials stateInfo : modelInfo.statesWithLookback) {
696                            LengthOnlyEvaluation[] lengthForState = lengthTables[stateInfo.state];
697    
698                            int maxLen = modelInfo.maxStateLengths[stateInfo.state];
699                            for(int i=0; i<maxLen; ++i) {
700                                    LengthOnlyEvaluation eval = lengthForState[i];
701                                    
702                                    // Fill the evaluation object with the correct info for the state.
703                                    featureList.clear();
704                                    for(FeatureManagerNodeExplicitLength fm : (List<FeatureManagerNodeExplicitLength>) (List) lengthFMList) {
705                                            fm.evaluateNodeLength(data.get(0),maxLen, i+1, stateInfo.state, featureList);
706                                    }
707                                    int size = featureList.currentSize;
708                                    eval.featureIndex = new short[size]; 
709                                    eval.value = new float[size];
710                                    for(int j=0; j<size; ++j) {
711                                            eval.featureIndex[j] = (short) featureList.indices[j];
712                                            eval.value[j] = (float) featureList.values[j];
713                                    }
714                            }
715                    }
716            }
717            
718            public void checkConstraintsInTrainingData() {
719                    // Verify that there are no invalid transitions in the training data.
720                    List<TrainingSequence<?>> goodData = new ArrayList<TrainingSequence<?>>();
721    
722                    for (int i=0; i<data.size(); ++i) {
723                            TrainingSequence seq = data.get(i);
724                            int seqOffset = modelInfo.seqOffsets[i];
725                            int lastState = -1;
726                            boolean validSequence = true;
727                            try {
728                                    for (int pos=0; pos < seq.length(); pos++) {
729                                            int trainingState = seq.getY(pos);
730                                            int index = (seqOffset + pos) * modelInfo.nPotentials;
731                                            Assert.a(!invalidTransitions[index + trainingState], "Seq: ",i," Pos: ", pos, " State: ",trainingState, " violates a constraint.");
732                                            if(lastState != -1)
733                                                    Assert.a(!invalidTransitions[index + modelInfo.nStates + modelInfo.transitionIndex.getQuick(lastState, trainingState)], "Seq: ",i," Pos: ", pos, " Transition: ",lastState,"-",trainingState, " violates a constraint.");
734                                            
735                                            lastState = trainingState;
736                                    }
737                            }
738                            catch(CheckException ex) {
739                                    if(ignoreInvalidTrainingData) {
740                                            log.warn("Discarding sequence "+i+" "+ex.getMessage());
741                                            validSequence = false;
742                                    }
743                                    else {
744                                            throw ex;
745                                    }
746                            }
747                            if(validSequence) {
748                                    goodData.add(seq);
749                            }
750                    }
751                    Assert.a(goodData.size() != 0, "All training sequences were invalid.");
752                    int diff = data.size() - goodData.size();
753                    if(diff != 0) {
754                            log.warn("Using "+goodData.size()+" training sequences.  Discarded "+diff+" because of constraint problems.");
755                            log.warn("Rebuilding the feature cache using the remaining good sequences.");
756                            // Redo the initialization with the new set of training data
757                            setTrainingData(fm, goodData);
758                            return;
759                    }
760            }
761    
762            ///////////////////////////////////////////////////////////////////////////////
763            ///// ABOVE: build the cache structure    BELOW: evaluate a position using the stored cache
764            ///////////////////////////////////////////////////////////////////////////////
765            
766            
767            public void evaluatePosition(int seq, int pos) {
768                    // The job of this function is to update evals.index and evals.value based on all features and potentials at this location
769                    // Recall below the comments from another class about the information in evals:
770                    
771                    int cumulativePosition = modelInfo.seqOffsets[seq] + pos;               
772                    evaluateDenseFeatures(cumulativePosition);
773                    evaluateSparseFeatures(cumulativePosition); // only variable length
774            }
775            
776            private void evaluateConstantFeatures() {
777                    Assert.a(!writtenYet, "Constant features can only be written once");
778                    for ( ConstantEvaluation ce : constEvals) {
779                            short fi    = ce.featureIndex;
780                            int pi      = ce.potential;
781                            int ci      = ce.evalIndex;
782                            float val   = ce.value;
783                            
784                            evals[pi].index[ci] = fi;
785                            evals[pi].value[ci] = val;                      
786                    }
787                    writtenYet = true;
788            }
789            
790            
791            private void evaluateDenseFeatures(int cumulativePosition) {            
792                    for (DenseEvalIndices de : denseEvalsList) {                    
793                            FeatureEvaluation fe  = de.evalTable;//   = densePotentialIndices[j];
794                            int               ei  = de.evalIndex;//denseEvalIndices[j];
795                            float[]           lut = de.lookupTable;//    = denseTableIndices[j];
796                            fe.index[ei] = de.featureIndex;
797                            fe.value[ei] = lut[cumulativePosition];
798                    }
799            }
800            
801            
802            private void evaluateSparseFeatures(int cumulativePosition) {
803                    for (int pot=0; pot<modelInfo.nPotentials; pot++) {
804                            currentEvalIndices[pot] = numFixedEvalIndices[pot];
805                    }
806                    int start = sparseStarts[cumulativePosition];
807                    int stop = sparseStarts[cumulativePosition+1];
808                    for (int j=start; j<stop; j++) {
809                            SparseEvaluation se = sparseEvals.get(j);
810                            short fi   = se.featureIndex;
811                            int pi     = se.potential;
812                            int ci     = currentEvalIndices[pi];
813                            currentEvalIndices[pi]++;
814                            float val  = se.value;
815                            
816                            evals[pi].index[ci] = fi;
817                            evals[pi].value[ci] = val;                      
818                    }
819                    
820                    for (int pot=0; pot<modelInfo.nPotentials; pot++) {
821                            evals[pot].index[currentEvalIndices[pot]] = -1;
822                    }
823            }
824            
825            /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
826            //////  ABOVE: Evaluate Markov features that don't depend on length   BELOW: evaluate semi-Markov features that depend on interval length
827            /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
828            
829            /**
830             *  The job of this function is to update the table "LengthFeatureEvaluation[][] lengthEvals"
831             *  
 *  The first dimension is indexed by those states which have at least one explicit length node feature,
833             *  corresponding to the state of an interval which ends at position currently under consideration.
834             *  The size of this dimension is predetermined and is equal to modelInfo.statesWithLookback.length
835             *  
836             *  The second dimension is the length of the lookback, i.e. how many bases is
837             *  the length of interval currently being considered.  The size of this dimension is variable, and
838             *  after the last of which one inserts a LengthFeatureEvaluation whose lookback is -1
839             * 
 *  A LengthFeatureEvaluation contains the lookback and a FeatureEvaluation for nodes, plus an
 *  edgeEvaluation (which for now we set to null).  The FeatureEvaluations we have seen before;
 *  they are comprised of index and value arrays, which are of variable length,
 *  and after the last entry you put a -1 in the index array.
844             * 
845             */
846            
	// Fills lengthEvals with, for each semi-Markov state, the feature evaluations of every
	// valid segment (interval) of that state ending at (seq, pos).  Entries are terminated
	// by a sentinel lookback of -1.  See the class-level javadoc above for the table layout.
	public void evaluateSegmentsEndingAt(int seq, int pos) {
		
		int seqOffset = modelInfo.seqOffsets[seq];
		int overallPosition = seqOffset+pos;
		int tx1 = modelInfo.seqOffsets[seq]+pos + seq + 1;  // ending position on subtraction lookup table, remember each sequence needs one extra position of padding.
		
		CacheProcessor.StatePotentials[] statesWithLookback = modelInfo.statesWithLookback;
		int nSemiMarkovStates = statesWithLookback.length;

		int seqLen = data.get(seq).length();
		int invalidIndex = overallPosition*modelInfo.nPotentials; 
		for (int stateIx=0; stateIx < nSemiMarkovStates; stateIx++) {
			CacheProcessor.StatePotentials statePotentials = statesWithLookback[stateIx];
			LengthFeatureEvaluation[] lookbackEvals = lengthEvals[stateIx];

			// If there is no valid way to exit this state at pos, no segment of this state
			// can end here: record an empty lookback list and move on.
			if(!checkExit(invalidIndex, pos, seqLen, statePotentials.state)) {
				lookbackEvals[0].lookback = -1;
				continue;
			}
			
			// For say an interval of exon1 ending at specified position, what lookback need we provide?
			
			// Well, lets say for instance that exons have a minimum length of 50 and maximum length of 200.
			// Then I want to start at current pposition (lookback=0) and walk backwards, breaking if I ever
			// encounter an invalidated node or edge, or when I reach the maximum allowable length for exon.
			// If not, then as soon as I get to lookback=50, I
			// start taking notes for every valid beginning (a valid edge leading into beginning of current interval)
			
			// Now for one of these, I look back, and for every valid transition into I record the lookback, the
			// relevant interval node features, both index and value, by subtracting some numbers at an offset.
			
			int nLookbacks = 0;
			int minLength = minStateLengths[statesWithLookback[stateIx].state];
			int maxLength = maxStateLengths[statesWithLookback[stateIx].state];

			for (int lookback = 0; lookback < maxLength; lookback++) {
				int firstPosIndex = (overallPosition-lookback)*modelInfo.nPotentials;
				// Stop walking back once we fall off the sequence start or hit an invalidated node.
				if (lookback > pos || invalidTransitions[firstPosIndex + statePotentials.state]) { break; }
				// Segment length is lookback+1; skip segments shorter than the state's minimum.
				if (lookback+1 < minLength) { continue; }
				
				int prevPos = pos - lookback - 1;
				boolean validEntry = false;
				// prevPos == -1 means the segment starts at the beginning of the sequence,
				// which needs no entry transition.
				if(prevPos == -1) {
					validEntry = true;
				}
				else {
					// Otherwise require at least one valid non-self transition into the segment.
					for (int pot : statePotentials.potentials) {
						int entryIndex = (seqOffset+prevPos+1)*modelInfo.nPotentials;
						if (modelInfo.selfTransitions[statePotentials.state]+modelInfo.nStates != pot && !invalidTransitions[entryIndex + pot]) {
							validEntry = true;
							break; 
						}
					}
				}
				
				if(validEntry) {
					int nEvals = 0;
					LengthFeatureEvaluation lengthEval = lookbackEvals[nLookbacks];
					lengthEval.lookback  = (short) lookback;
					lengthEval.edgeEvals = null;
					
					FeatureEvaluation nodeEval = lengthEval.nodeEval;

						// for this potential, there is a lookup table.  It includes the offsets you need for subtraction etc.
					for (DenseNodeBoundaryIndices db : denseNodeBoundaryEvalsList[statePotentials.state]) {
						int index    = db.featureIndex;
						float[] lut  = db.lookupTable;
						int rightPad = db.rightPad;
						int leftPad  = db.leftPad;
						
						// Interval sum from the padded cumulative table: end minus start, with pads.
						float val = lut[tx1 - rightPad ] - lut[tx1-lookback-1 + leftPad];
						
						nodeEval.index[nEvals] = (short) index;
						nodeEval.value[nEvals] = val;
						nEvals++;
					}

					// Add in length evaluations
					LengthOnlyEvaluation lengthOnlyEval = lengthTables[statePotentials.state][lookback];
					int size = lengthOnlyEval.featureIndex.length;
					System.arraycopy(lengthOnlyEval.featureIndex, 0, nodeEval.index, nEvals, size);
					System.arraycopy(lengthOnlyEval.value, 0, nodeEval.value, nEvals, size);
					nEvals += size;
					
					// -1 sentinel terminates this segment's evaluation list.
					nodeEval.index[nEvals] = -1;						
					nLookbacks++;
				}
			}
			// -1 sentinel terminates the lookback list for this state.
			lookbackEvals[nLookbacks].lookback = -1;
		}
	}
938            
939            /** Checks if there is a valid transition out of a node */
940            boolean checkExit(int positionIndex, int pos, int seqLen, int state) {
941                    if(invalidTransitions[positionIndex+state])
942                            return false;
943    
944                    // This requires that we be in the last position or that there is a valid transition out.
945                    if(pos == seqLen-1) {
946                            return true;
947                    }
948                    
949                    boolean wayOut = false;
950                    int nextPosIndex = positionIndex +  modelInfo.nPotentials;
951                    for(byte pot : modelInfo.exitTransitions[state]) {
952                            if(modelInfo.selfTransitions[state]+modelInfo.nStates != pot && !invalidTransitions[nextPosIndex + pot]) {
953                                    wayOut = true;
954                                    break;
955                            }
956                    }
957                    return wayOut; 
958            }
959            
960            /////////////////////////////////////////////////////////////////////////////////////////////////
961            
962            
	/** @return the shared invalid-transitions table, indexed by
	 *          cumulativePosition*nPotentials + potential.  The returned array is the
	 *          internal table itself, not a copy. */
	public boolean[] getInvalidTransitions() {
		return invalidTransitions;
	}
966            
967            public void setSemiMarkovSetup(SemiMarkovSetup setup) {
968                    maxStateLengths = setup.getMaxLengths();
969                    minStateLengths = setup.getMinLengths();
970                    ignoreSemiMarkovSelfTransitions = setup.isIgnoreSemiMarkovSelfTransitions();
971            }
972                    
973            public boolean isAllPaths() {
974                    return allPaths;
975            }
976            
977            
978            public void setAllPaths(boolean allPaths) {
979                    this.allPaths = allPaths;
980            }
981    
982            /**
983             * @return Returns the lookbackArrayFeatureSize.
984             */
985            public int getLookbackArrayFeatureSize() {
986                    return lookbackArrayFeatureSize;
987            }
988    
989            /**
990             * @param lookbackArrayFeatureSize The lookbackArrayFeatureSize to set.
991             */
992            public void setLookbackArrayFeatureSize(int lookbackArrayFeatureSize) {
993                    this.lookbackArrayFeatureSize = lookbackArrayFeatureSize;
994            }
995    
996            /**
997             * @return Returns the lookbackArraySize.
998             */
999            public int getLookbackArraySize() {
1000                    return lookbackArraySize;
1001            }
1002    
1003            /**
1004             * @param lookbackArraySize The lookbackArraySize to set.
1005             */
1006            public void setLookbackArraySize(int lookbackArraySize) {
1007                    this.lookbackArraySize = lookbackArraySize;
1008            }
1009    
1010            public boolean isIgnoreInvalidTrainingData() {
1011                    return ignoreInvalidTrainingData;
1012            }
1013    
1014            public void setIgnoreInvalidTrainingData(boolean ignoreInvalidTrainingData) {
1015                    this.ignoreInvalidTrainingData = ignoreInvalidTrainingData;
1016            }
1017    }