001    package calhoun.analysis.crf.solver;
002    
003    import java.util.List;
004    
005    import calhoun.analysis.crf.FeatureList;
006    import calhoun.analysis.crf.ModelManager;
007    import calhoun.analysis.crf.io.InputSequence;
008    import calhoun.analysis.crf.io.TrainingSequence;
009    import calhoun.util.Assert;
010    
011    /** a dummy cache processor that fulfills the interface but doesn't cache.  It always retrieves values
012     * by calling the {@link calhoun.analysis.crf.FeatureManager} <code>evlaute</code> functions.  <b>Does not work with 
013     * NodeBoundary features.</b>
014     */
015    public class NoCachingCacheProcessor extends CacheProcessorBasic {
016            //private static final Log log = LogFactory.getLog(NoCachingCacheProcessor.class);
017    
018            boolean allPaths;
019            
020            /// Cached values of the sums of each feature value through the whole training set.
021            double[] featureSumsLocal;
022            boolean[] invalidTransitions;
023            DirectFeatureList result;
024    
025            class DirectFeatureList implements FeatureList {
026                    FeatureEvaluation evals1;
027                    public int position;
028                    boolean valid;
029                    
030                    public DirectFeatureList() {
031                    }
032                    
033                    public void addFeature(int index, double val) {
034                            evals1.index[position] = (short) index;
035                            evals1.value[position++] = (float) val;
036                    }
037    
038                    /** Returns the invalid flag. */
039                    public boolean isValid() {
040                            return valid;
041                    }
042    
043                    /** Invalidates results. */
044                    public void invalidate() {
045                            valid = false;
046                    }
047            }
048            
049            /** true if all paths (valid and invalid) are to be evaluated during the viterbi search.  Defaults to false.
050             * @return true if all paths are to be examined
051             */
052            public boolean isAllPaths() {
053                    return allPaths;
054            }
055    
056            /** sets whether all paths (valid and invalid) are to be evaluated during the viterbi search.  Defaults to false.
057             * @param allPaths allPath true if all paths are to be examined
058             */
059            public void setAllPaths(boolean allPaths) {
060                    this.allPaths = allPaths;
061            }
062    
063            @Override
064            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
065                    super.setTrainingData(fm, data);
066                    basicInit(allPaths);
067                    result = new DirectFeatureList();
068                    
069                    invalidTransitions = new boolean[modelInfo.nPotentials*modelInfo.totalPositions];
070                    calcFeatureSums();
071            }
072    
073            public boolean[] getInvalidTransitions() {
074                    return invalidTransitions;
075            }
076    
077            @Override
078            public double[] getFeatureSums() {
079                    return featureSumsLocal;
080            }
081    
082            void calcFeatureSums() {
083                    featureSumsLocal = new double[modelInfo.nFeatures];
084                    for(int seqNum = 0; seqNum < modelInfo.nSeqs; ++seqNum) {
085                            TrainingSequence train = (TrainingSequence) data.get(seqNum);
086                            int len = train.length();
087                            int previousState = -1;
088                            for(int pos = 0; pos < len; ++pos) {
089                                    evaluatePosition(seqNum, pos);
090                                    int state = train.getY(pos);
091                                    int i = 0;
092                                    FeatureEvaluation potEval = evals[state];
093                                    int index = potEval.index[i];
094                                    while(index != -1) {
095                                            float val = potEval.value[i];
096                                            Assert.a(!Float.isNaN(val));
097                                            featureSumsLocal[index] += val;
098                                            index = potEval.index[++i];
099                                    }
100                                    
101                                    for(int trans = 0; trans < modelInfo.nTransitions; ++trans) {
102                                            if(modelInfo.transitionTo[trans] == state && modelInfo.transitionFrom[trans] == previousState) {
103                                                    i = 0;
104                                                    int pot = modelInfo.nStates+trans;
105                                                    potEval = evals[pot];
106                                                    index = potEval.index[i];
107                                                    while(index != -1) {
108                                                            float val = potEval.value[i];
109                                                            Assert.a(!Float.isNaN(val));
110                                                            featureSumsLocal[index] += val;
111                                                            index = potEval.index[++i];
112                                                    }
113                                            }
114                                    }
115                                    previousState = state;
116                            }
117                    }
118            }
119    
120            public void evaluatePosition(int seqNum, int pos) {
121                    InputSequence seq = data.get(seqNum);
122                    for(int pot=0; pot<modelInfo.nPotentials; ++pot) {
123                            result.evals1 = evals[pot];
124                            result.position = 0;
125                            result.valid = true;
126                            if(pot < modelInfo.nStates) {
127                                    fm.evaluateNode(seq, pos, pot, result);
128                            }
129                            else {
130                                    if(pos == 0) {
131                                            result.evals1.index[0] = -1;
132                                            continue;
133                                    }
134                                    int trans = pot - modelInfo.nStates;
135                                    fm.evaluateEdge(seq, pos, modelInfo.transitionFrom[trans], modelInfo.transitionTo[trans], result);
136                            }
137                            if(result.isValid()) {
138                                    result.evals1.index[result.position] = -1;
139                            }
140                            else {
141                                    result.evals1.value[0] = Float.NaN;
142                                    result.evals1.index[0] = Short.MIN_VALUE;
143                                    result.evals1.index[1] = -1;
144                            }
145                    }
146            }
147    
148            public void evaluateSegmentsEndingAt(int seq, int pos) {
149                    throw new UnsupportedOperationException();
150            }
151    
152    }
153