001    package calhoun.analysis.crf.solver;
002    
003    import org.apache.commons.logging.Log;
004    import org.apache.commons.logging.LogFactory;
005    
006    import calhoun.analysis.crf.CRFInference;
007    import calhoun.analysis.crf.ModelManager;
008    import calhoun.analysis.crf.io.InputSequence;
009    import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
010    import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
011    import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
012    import calhoun.util.Assert;
013    import calhoun.util.ColtUtil;
014    
015    /** viterbi algorithm for semi-Markov CRFs.  Uses a {@link CacheProcessor} to handle the complexities of evaluation.
016     * This is similar to a markov viterbi search, except we have to search over all valid segments to fill in one entry 
017     * in the matrix instead of just looking at the last position.
018     */
019    public class SemiMarkovViterbi implements CRFInference {
020            private static final Log log = LogFactory.getLog(SemiMarkovViterbi.class);
021            boolean debug = log.isDebugEnabled();
022    
023            double[] lambda;
024    
025            private double[] bestScore;
026            private int[] backPointers;
027            int nStates;
028            
029            SolverSetup modelInfo;
030            CacheProcessor cacheProcessor;
031            FeatureEvaluation[] evals;
032            LengthFeatureEvaluation[][] lengthEvals;
033            boolean[] invalidTransitions;
034            int[] selfTransitions;
035    
036            RecyclingBuffer<double[]> stableStates;
037            double[] stableVector;
038            
039            /** gets the cache processor used to access feature evaluations
040             * @return the configured cache processor
041             */
042            public CacheProcessor getCacheProcessor() {
043                    return cacheProcessor;
044            }
045    
046            /** sets the cache processor used to access feature evaluations
047             * @param cacheProcessor the cache processor to use
048             */
049            public void setCacheProcessor(CacheProcessor cacheProcessor) {
050                    this.cacheProcessor = cacheProcessor;
051            }
052    
053            public InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
054                    this.lambda = lambda;
055    
056                    cacheProcessor.setInputData(fm, seq);
057                    modelInfo = cacheProcessor.getSolverSetup();
058                    nStates = modelInfo.nStates;
059                    Assert.a(modelInfo.maxStateLengths.length == nStates, "Maximum state lengths array was length ("+modelInfo.maxStateLengths.length+").  Must have one entry for each state "+modelInfo.nStates+")");
060                    evals = cacheProcessor.getFeatureEvaluations();
061                    lengthEvals = cacheProcessor.getLengthFeatureEvaluations();
062                    invalidTransitions = cacheProcessor.getInvalidTransitions();
063                    
064                    int len = seq.length();
065                    
066                    selfTransitions = new int[nStates];
067                    for(int i=0; i<nStates; ++i) {
068                            selfTransitions[i] = modelInfo.selfTransitions[i];
069                    }
070    
071                    // Circular buffers of the previous mi matrices.
072                    RecyclingBuffer<double[]> mis = new RecyclingBuffer<double[]>(new double[modelInfo.maxLookback][modelInfo.nTransitions]);
073                    double[] nextMi = new double[modelInfo.nTransitions];
074    
075                    // Circular buffers of the values of staying in a stable state for a given period of time
076                    stableStates = new RecyclingBuffer<double[]>(new double[modelInfo.maxLookback][nStates]);
077                    stableVector = new double[nStates];
078                    
079                    bestScore = new double[len*nStates];
080                    backPointers = new int[len*nStates];
081                    int[] backLengths = new int[len*nStates];
082                    
083                    for (int pos = 0; pos < len; pos++) {
084                            /* compute weighted features.  These are for transitions at the current base for non-length dependent features. */
085                                                    
086                            // Save the first Ri matrix, since it has the initial probabilies
087                            if(pos == 0) {
088                                    computeSparseMi(seq, pos, null, stableVector);
089                                    stableVector = stableStates.addFirst(stableVector);
090                            }
091                            else {
092                                    computeSparseMi(seq, pos, nextMi, null);
093    
094                                    // Use this transition matrix to update the stable vectors 
095                                    updateStableBuffer(nextMi);                             
096    
097                                    // Add this into the list of saved Mi matrices
098                                    nextMi = mis.addFirst(nextMi);                          
099                            }                               
100    
101                            double[] latestStable = stableStates.get(0);
102                            double[] latestMi = mis.get(0);
103    
104                            // Do states without lookback first.
105                            for(CacheProcessor.StatePotentials potentials : modelInfo.statesWithoutLookback) {
106                                    int state = potentials.state;
107                                    double max = Double.NEGATIVE_INFINITY;
108                                    int invalidIndex = pos*modelInfo.nPotentials;
109                                    int bestLookback = 0;
110                                    int bestPrevState = -2;
111                                    
112                                    // Check that the node is valid, otherwise put in a dummy entry. 
113                                    if(!invalidTransitions[invalidIndex + state]) {
114                                            if(pos == 0) {
115                                                    if(debug) 
116                                                            log.debug(String.format("Pos: %d State: %d %.2f", pos, state, latestStable[state]));
117                                                    // For the first position, we just store the starting potential
118                                                    max = latestStable[state];
119                                                    bestPrevState = -1;
120                                            }
121                                            else {
122                                                    for(byte edgePotential : potentials.potentials) {
123                                                            if(invalidTransitions[invalidIndex + state])
124                                                                    continue;
125    
126                                                            int transition = edgePotential - nStates;
127                                                            int prevState = modelInfo.transitionFrom[transition];
128                                                            double transitionCost = latestMi[transition];
129    
130                                                            if(Double.isInfinite(transitionCost)) {
131                                                                    // This transition was invalid at this location, ignore.
132                                                                    continue;
133                                                            }
134    
135                                                            double previous = bestScore[nStates *(pos-1) + prevState]; 
136                                                            double current = previous + transitionCost;
137                                                            if(debug)
138                                                                    log.debug(String.format("Pos: %d Trans: %d-%d %.2f (Prev: %.2f + Trans: %.2f)", pos, prevState, state, current, previous, transitionCost));
139                                                            if(current > max) {
140                                                                    max = current;
141                                                                    bestPrevState = prevState;
142                                                            }
143                                                    }
144                                            }
145                                    }
146                                    
147                                    // Fill in our best entry
148                                    int index= pos*nStates + state;
149                                    //log.debug(String.format("bestScore[%d] = %.2f", index, max));
150                                    bestScore[index] = max; 
151                                    backPointers[index] = bestPrevState;
152                                    backLengths[index] = bestLookback+1;
153                            }
154                            
155                            // Now repeat for states with lookback.
156                            for(int i=0; i<modelInfo.statesWithLookback.length; ++i) {
157                                    CacheProcessor.StatePotentials potentials = modelInfo.statesWithLookback[i];
158                                    LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
159    
160                                    int state = potentials.state;
161                                    double max = Double.NEGATIVE_INFINITY;
162                                    int bestLookback = -1;
163                                    int bestPrevState = -2;
164                                    
165                                    cacheProcessor.evaluateSegmentsEndingAt(0, pos);
166    
167                                    int lbIndex=0;
168                                    LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
169                                    int lookback = lengthEval.lookback;
170                                    while(lookback != -1) {
171                                            //log.info("Pos: "+pos+"\t State: "+modelInfo.statesWithLookback[i].state+"\t Lookback: "+lookback);
172    
173                                            double[] lookbackMi = mis.get(lookback); 
174                                            double[] lookbackStable = stableStates.get(lookback);
175    
176                                            // Handle evaluation of the node potentials
177                                            FeatureEvaluation nodeEvals = lengthEval.nodeEval;
178                                            short[] indices = nodeEvals.index;
179                                            float[] vals = nodeEvals.value;
180                                            int ix = 0;
181                                            short index = indices[ix];
182                                            double nodePotential = 0.0;
183                                            while(index >= 0) {
184                                                    nodePotential += vals[ix] * lambda[index];
185                                                    index = indices[++ix];
186                                            }
187                                            Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid");
188    
189                                            int prevPos = pos - lookback - 1;
190                                            if(prevPos < 0) {
191                                                    // Examine the case where this is the first segment in the sequence.
192                                                    // Score is the sum of non-length dependent fetaures plus length features.
193                                                    double current = latestStable[state] + nodePotential;
194                                                    if(debug)
195                                                            log.debug(String.format("Pos: %d Lb: %d State: %d %.2f (Stable: %.2f + Node: %.2f)", pos, lookback, state, current, latestStable[state], nodePotential));
196                                                    
197                                                    // Check if this is our best so far
198                                                    if(current > max) {
199                                                            max = current;
200                                                            bestLookback = lookback;
201                                                            bestPrevState = -1;
202                                                    }
203                                            }
204                                            else {
205                                                    // If this is not the first segment, we need to deal with edges coming into this segment
206                                                    FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
207                                                    int nEdges = potentials.potentials.length;
208                                                    for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
209                                                            int potential = potentials.potentials[edgeIx];
210                                                            int trans = potential - modelInfo.nStates;
211                                                            int fromNode = modelInfo.transitionFrom[trans];
212                                                            // Skip semi-Markov self transitions
213                                                            if(fromNode == state)
214                                                                    continue;
215    
216                                                            double edgeVal = 0.0;
217    
218                                                            if(edgeEvals == null) {
219                                                                    // If the cache processor does not have edge evaluations
220                                                                    // Just check if this transition is legal based on the invalid transitions matrix
221                                                                    if(invalidTransitions[(prevPos+1)*modelInfo.nPotentials + potential]) {
222                                                                            //log.info("Illegal transition: "+fromNode+"-"+toNode+" at pos: "+prevPos);
223                                                                            continue;
224                                                                    }
225                                                            }
226                                                            else {
227                                                                    // If the cache processor does have edge evaluations, then ignore the illegal transitions matrix
228                                                                    // and update the expval using the edge evaluations
229                                                                    FeatureEvaluation potEvals = edgeEvals[edgeIx];
230                                                                    indices = potEvals.index;
231                                                                    vals = potEvals.value;
232                                                                    ix = 0;
233                                                                    index = indices[i];
234                                                                    if(index == Short.MIN_VALUE) {
235                                                                            log.info("SHORT.MIN_VALUE");
236                                                                            continue;
237                                                                    }
238                                                                    while(index != -1) {
239                                                                            edgeVal += vals[ix] * lambda[index];
240                                                                            index = indices[++ix];
241                                                                    }
242                                                            }
243                                                            
244                                                            double prevBest = bestScore[nStates *(pos-(lookback+1)) + fromNode];
245                                                            double stable = latestStable[state]-lookbackStable[state];
246                                                            
247                                                            // Renormalize and update the exp value.
248                                                            double current = prevBest + nodePotential + edgeVal + stable + lookbackMi[trans];
249                                                            if(debug)
250                                                                    log.debug(String.format("Pos: %d Lb: %d Trans: %d-%d %.4f (Prev: %.4f + Stable: %.4f + Trans: %.4f + Node: %.4f + Edge: %.4f)", pos, lookback, fromNode, state, 
251                                                                                    current, prevBest, stable, lookbackMi[trans], nodePotential, edgeVal));
252                                                            
253                                                            if(current == Double.NEGATIVE_INFINITY) {
254                                                                    // The state was not valid at the previous position
255                                                                    continue;
256                                                            }
257    
258                                                            if(current > max) {
259                                                                    //log.info("Selected Pos: "+pos+" Edge: "+prevState+"-"+state+" "+current+" vs. "+max+" Prev. "+bestScore[nStates *(pos-(lookback+1)) + prevState]+" Trans. "+transitionCost);
260                                                                    max = current;
261                                                                    bestLookback = lookback;
262                                                                    bestPrevState = fromNode;
263                                                            }
264                                                    }
265                                            }
266                                            ++lbIndex;
267                                            lengthEval = lookbacksForState[lbIndex];
268                                            lookback = lengthEval.lookback;
269                                    }
270                                    
271                                    // Fill in our best entry
272                                    //log.info(String.format("Pos: %d State: %d BestScore: "+max+" BackPointer: %d", pos, state, bestPrevState));
273                                    int index= pos*nStates + state; 
274                                    bestScore[index] = max; 
275                                    backPointers[index] = bestPrevState;
276                                    backLengths[index] = bestLookback+1;
277                            }
278                    }
279                    //printForwardPass(bestScore, backPointers, backLengths, nStates, seq.length());
280                    //log.info(ColtUtil.format(bestScore));
281                    //log.info(StringUtils.join(backPointers," "));
282                    //log.info(backLengths);
283                                    
284                    // Now that we have the matrix, trace back to get the best path.
285                    int[] ret = new int[len];
286                    int pos = len-1;
287                    int state = ColtUtil.maxInColumn(bestScore, nStates, len-1);
288                    Assert.a(state != -2, "No valid paths");
289    
290                    while(pos >= 0) {
291                            int stateLen = backLengths[pos*nStates + state];
292                            int prevState = backPointers[pos*nStates + state];
293                            //log.info(String.format("State: %d, Len: %d, Ends At: %d", state, stateLen, pos));
294                            for(int i = 0; i < stateLen; ++i) {
295                                    ret[pos] = state;
296                                    pos--;
297                            }
298                            state = prevState; 
299                    }               
300                    Assert.a(pos == -1);
301                    InferenceResult inferenceResult = new InferenceResult();
302                    inferenceResult.hiddenStates = ret;
303                    inferenceResult.bestScores = new double[nStates];
304                    System.arraycopy(bestScore, nStates*(len-1), inferenceResult.bestScores, 0, nStates );
305                    return inferenceResult;
306            }
307    
308            /** Updates the stableStates list based on the current position.  The stableStates list contains the 
309             cost of the non-length dependent features for the duration of the lookback. 
310             */
311            void updateStableBuffer(double[] nextMi) {
312                    // Recycle matrices at the end
313                    double[] prevState = stableStates.get(0);
314                    for(int ix = 0; ix < nStates; ++ix) {
315                            if(modelInfo.maxStateLengths[ix] > 1) {
316                                    int trans = selfTransitions[ix];
317                                    if(trans != -1) {
318                                            if (Double.isInfinite(nextMi[trans]))
319                                                    stableVector[ix] = prevState[ix];
320                                            else
321                                                    stableVector[ix] = prevState[ix] + nextMi[trans];                               
322                                            //log.info(String.format("stableState[%d] = %f = %f + %f", ix, stableState[ix], prevState[ix], nextMi[trans]));
323                                    }
324                            }
325                    }
326    
327                    // Now add this to the beginning.
328                    stableVector = stableStates.addFirst(stableVector);
329            }
330    
331            void computeSparseMi(InputSequence seq, int pos, double[] mi, double[] ri) {
332                    cacheProcessor.evaluatePosition(0, pos);
333                    double nodeVal = Double.NaN;
334                    int invalidIndex = pos*modelInfo.nPotentials;
335                    for(short potential : modelInfo.orderedPotentials) {
336                            boolean invalid = invalidTransitions[invalidIndex + potential];
337                            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
338    
339                            // Add up all features for this potential.
340                            FeatureEvaluation potEvals = evals[potential];
341                            short[] indices = potEvals.index;
342                            float[] vals = potEvals.value;
343                            int i = 0;
344                            short index = indices[i];
345                            while(index != -1) {
346                                    // An invalid potential is indicated by a feature value of Short.MAX_VALUE
347                                    features += (index == Short.MIN_VALUE) ? Double.NEGATIVE_INFINITY : vals[i]*lambda[index]; 
348                                    index = indices[++i];
349                            }
350                            
351                            if(potential < modelInfo.nStates) {
352                                    nodeVal = features;
353                                    if(ri != null) {
354                                            ri[potential] = nodeVal;
355                                    }
356                            }
357                            else {
358                                    //log.debug(String.format("Mi[%d, %d] = %f, adding in %f to get %f", feat.yprev(), feat.y(), val, feat.value(), val*exp(feat.value()*param[feat.index()])));
359                                    int transition = potential - modelInfo.nStates;
360                                    double val = features + nodeVal;
361                                    if(mi != null)
362                                            mi[transition] = val;
363                            }
364                    }               
365            }
366    }