001    package calhoun.analysis.crf.solver;
002    
003    import org.apache.commons.logging.Log;
004    import org.apache.commons.logging.LogFactory;
005    
006    import calhoun.analysis.crf.CRFInference;
007    import calhoun.analysis.crf.ModelManager;
008    import calhoun.analysis.crf.io.InputSequence;
009    import calhoun.analysis.crf.solver.check.FeatureCalculator;
010    import calhoun.analysis.crf.solver.check.TransitionInfo;
011    import calhoun.util.ColtUtil;
012    import calhoun.util.DenseBooleanMatrix2D;
013    
014    public class Viterbi implements CRFInference {
015            Log log = LogFactory.getLog(Viterbi.class);
016            boolean debug = false; //log.isInfoEnabled();
017            
018            private double[] bestScore;
019            private int[] backPointers;
020            private boolean allPaths;
021    
022            /** true if all paths (valid and invalid) are to be evaluated during the viterbi search.  Defaults to false.
023             * @return true if all paths are to be examined
024             */
025            public boolean isAllPaths() {
026                    return allPaths;
027            }
028    
029            /** sets whether all paths (valid and invalid) are to be evaluated during the viterbi search.  Defaults to false.
030             * @param allPaths allPath true if all paths are to be examined
031             */
032            public void setAllPaths(boolean allPaths) {
033                    this.allPaths = allPaths;
034            }
035    
036            public InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
037                    int numStates = fm.getNumStates();
038                    int len = seq.length();
039                    
040                    DenseBooleanMatrix2D transitions = fm.getLegalTransitions();
041                    if(allPaths || transitions == null) {
042                            transitions = new DenseBooleanMatrix2D(numStates, numStates);
043                            transitions.assign(true);
044                    }
045                    TransitionInfo t = new TransitionInfo(fm, false);
046                    FeatureCalculator calc = new FeatureCalculator(fm, lambda, t);
047    
048                    bestScore = new double[numStates * len];
049                    backPointers = new int[numStates * len];
050                    for (int pos = 0; pos < len; pos++) {
051                            int posIndex = pos*numStates;
052                            for(int state = 0; state<numStates; ++state) {
053                                    int index = posIndex + state;
054                                    double nodeVal = calc.calcNodeValue(seq, pos, state);
055                                    if(pos == 0) {
056                                            //log.debug(String.format("Pos: %d State: %d Total: %.2f", pos, state, nodeVal));
057                                            bestScore[index] = nodeVal;
058                                    }
059                                    else {
060                                            double max = Double.NEGATIVE_INFINITY;
061                                            int prevState = -1;
062                                            for(int k=0; k<numStates; ++k) {
063                                                    if(!transitions.getQuick(k, state)) {
064                                                            continue;
065                                                    }
066                                                    double previous = bestScore[posIndex-numStates+k];
067                                                    double edge = calc.calcEdgeValue(seq, pos, k, state);
068                                                    double current = previous + edge + nodeVal;
069                                                    if(debug)
070                                                            log.debug(String.format("Pos: %d Trans: %d-%d %.2f (Prev: %.2f + Edge: %.2f + Node: %.2f)", pos, k, state, current, previous, edge, nodeVal));
071                                                    if(current > max) {
072                                                            //log.info("Selected Pos: "+pos+" Edge: "+k+"-"+state+" "+current+" vs. "+max+" Prev. "+previous+" Node "+nodeVal + " Calc "+calc.calcEdgeValue(seq, pos, k, state));
073                                                            max = current;
074                                                            prevState = k;
075                                                    }
076                                            }
077                                            //Assert.a(prevState != -1, "No legal transitions found to state ", state, ".  Pos ", pos);
078                                            // With constraints it is legal to have states that are disallowed at a given position.
079                                            // Just put a -infinity in there for the score and a -1 for the backpointer.
080                                            bestScore[index] = max;
081                                            backPointers[index] = prevState;
082                                            //log.info(String.format("Pos: %d Prev: %d State: %d Total: "+max, pos, prevState, state));
083                                    }
084                            }
085                    }
086                    
087                    //log.info(ColtUtil.format(bestScore));
088                    //log.info(backPointers);
089                    
090                    int[] ret = new int[len];
091                    ret[len-1] = ColtUtil.maxInColumn(bestScore, numStates, len-1);
092                    for(int i = len-1; i>0; --i) {
093                            ret[i-1] = backPointers[numStates*i + ret[i]];
094                    }
095                    InferenceResult inferenceResult = new InferenceResult();
096                    inferenceResult.hiddenStates = ret;
097                    inferenceResult.bestScores = new double[numStates];
098                    System.arraycopy(bestScore, numStates*(len-1), inferenceResult.bestScores, 0, numStates );
099                    return inferenceResult;
100            }
101    
102            public int[] getBackPointers() {
103                    return backPointers;
104            }
105    
106            public void setBackPointers(int[] backPointers) {
107                    this.backPointers = backPointers;
108            }
109    
110            public double[] getBestScore() {
111                    return bestScore;
112            }
113    
114            public void setBestScore(double[] bestScore) {
115                    this.bestScore = bestScore;
116            }
117    
118    }