001 package calhoun.analysis.crf.solver;
002
003 import org.apache.commons.logging.Log;
004 import org.apache.commons.logging.LogFactory;
005
006 import calhoun.analysis.crf.CRFInference;
007 import calhoun.analysis.crf.ModelManager;
008 import calhoun.analysis.crf.io.InputSequence;
009 import calhoun.analysis.crf.solver.check.FeatureCalculator;
010 import calhoun.analysis.crf.solver.check.TransitionInfo;
011 import calhoun.util.ColtUtil;
012 import calhoun.util.DenseBooleanMatrix2D;
013
014 public class Viterbi implements CRFInference {
015 Log log = LogFactory.getLog(Viterbi.class);
016 boolean debug = false; //log.isInfoEnabled();
017
018 private double[] bestScore;
019 private int[] backPointers;
020 private boolean allPaths;
021
022 /** true if all paths (valid and invalid) are to be evaluated during the viterbi search. Defaults to false.
023 * @return true if all paths are to be examined
024 */
025 public boolean isAllPaths() {
026 return allPaths;
027 }
028
029 /** sets whether all paths (valid and invalid) are to be evaluated during the viterbi search. Defaults to false.
030 * @param allPaths allPath true if all paths are to be examined
031 */
032 public void setAllPaths(boolean allPaths) {
033 this.allPaths = allPaths;
034 }
035
036 public InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
037 int numStates = fm.getNumStates();
038 int len = seq.length();
039
040 DenseBooleanMatrix2D transitions = fm.getLegalTransitions();
041 if(allPaths || transitions == null) {
042 transitions = new DenseBooleanMatrix2D(numStates, numStates);
043 transitions.assign(true);
044 }
045 TransitionInfo t = new TransitionInfo(fm, false);
046 FeatureCalculator calc = new FeatureCalculator(fm, lambda, t);
047
048 bestScore = new double[numStates * len];
049 backPointers = new int[numStates * len];
050 for (int pos = 0; pos < len; pos++) {
051 int posIndex = pos*numStates;
052 for(int state = 0; state<numStates; ++state) {
053 int index = posIndex + state;
054 double nodeVal = calc.calcNodeValue(seq, pos, state);
055 if(pos == 0) {
056 //log.debug(String.format("Pos: %d State: %d Total: %.2f", pos, state, nodeVal));
057 bestScore[index] = nodeVal;
058 }
059 else {
060 double max = Double.NEGATIVE_INFINITY;
061 int prevState = -1;
062 for(int k=0; k<numStates; ++k) {
063 if(!transitions.getQuick(k, state)) {
064 continue;
065 }
066 double previous = bestScore[posIndex-numStates+k];
067 double edge = calc.calcEdgeValue(seq, pos, k, state);
068 double current = previous + edge + nodeVal;
069 if(debug)
070 log.debug(String.format("Pos: %d Trans: %d-%d %.2f (Prev: %.2f + Edge: %.2f + Node: %.2f)", pos, k, state, current, previous, edge, nodeVal));
071 if(current > max) {
072 //log.info("Selected Pos: "+pos+" Edge: "+k+"-"+state+" "+current+" vs. "+max+" Prev. "+previous+" Node "+nodeVal + " Calc "+calc.calcEdgeValue(seq, pos, k, state));
073 max = current;
074 prevState = k;
075 }
076 }
077 //Assert.a(prevState != -1, "No legal transitions found to state ", state, ". Pos ", pos);
078 // With constraints it is legal to have states that are disallowed at a given position.
079 // Just put a -infinity in there for the score and a -1 for the backpointer.
080 bestScore[index] = max;
081 backPointers[index] = prevState;
082 //log.info(String.format("Pos: %d Prev: %d State: %d Total: "+max, pos, prevState, state));
083 }
084 }
085 }
086
087 //log.info(ColtUtil.format(bestScore));
088 //log.info(backPointers);
089
090 int[] ret = new int[len];
091 ret[len-1] = ColtUtil.maxInColumn(bestScore, numStates, len-1);
092 for(int i = len-1; i>0; --i) {
093 ret[i-1] = backPointers[numStates*i + ret[i]];
094 }
095 InferenceResult inferenceResult = new InferenceResult();
096 inferenceResult.hiddenStates = ret;
097 inferenceResult.bestScores = new double[numStates];
098 System.arraycopy(bestScore, numStates*(len-1), inferenceResult.bestScores, 0, numStates );
099 return inferenceResult;
100 }
101
102 public int[] getBackPointers() {
103 return backPointers;
104 }
105
106 public void setBackPointers(int[] backPointers) {
107 this.backPointers = backPointers;
108 }
109
110 public double[] getBestScore() {
111 return bestScore;
112 }
113
114 public void setBestScore(double[] bestScore) {
115 this.bestScore = bestScore;
116 }
117
118 }