001    package calhoun.analysis.crf.solver;
002    
003    import java.util.ArrayList;
004    import java.util.Arrays;
005    import java.util.Iterator;
006    import java.util.LinkedList;
007    import java.util.List;
008    
009    import org.apache.commons.logging.Log;
010    import org.apache.commons.logging.LogFactory;
011    
012    import calhoun.analysis.crf.CRFInference;
013    import calhoun.analysis.crf.ModelManager;
014    import calhoun.analysis.crf.SemiMarkovSetup;
015    import calhoun.analysis.crf.io.InputSequence;
016    import calhoun.analysis.crf.solver.check.FeatureCalculator;
017    import calhoun.analysis.crf.solver.check.TransitionInfo;
018    import calhoun.util.Assert;
019    import calhoun.util.ColtUtil;
020    import calhoun.util.DenseIntMatrix2D;
021    
/** Viterbi algorithm for semi-Markov CRFs.  Does not perform any caching, and <b>does not correctly handle
 * NodeBoundary features</b>.  Use only if you have no NodeBoundary features and are running into memory limitations.
 * A {@link SemiMarkovSetup} should be configured; if none is set, every state is given a maximum length of 1
 * and a standard Viterbi search is performed.
 */
026    public class SemiMarkovViterbiNoCache implements CRFInference {
027            private static final Log log = LogFactory.getLog(SemiMarkovViterbi.class);
028            boolean debug = log.isDebugEnabled();
029    
030            private boolean allPaths;
031            private double[] bestScore;
032            private int[] backPointers;
033            DenseIntMatrix2D backLengths;
034            short[] maxStateLengths;
035            boolean ignoreSemiMarkovSelfTransitions;
036            
037            int maxLookback = 1;
038            int nStates;
039    
040            TransitionInfo transitions;
041            int[] selfTransitions;
042            
043            class LengthTransitionInfo extends TransitionInfo {
044                    short[] localMaxStateLengths;
045                    LengthTransitionInfo(ModelManager fm, boolean allPaths, short[] maxStateLengths, boolean ignoreSemiMarkovSelfTransitions) {
046                            this.localMaxStateLengths = maxStateLengths;
047                            this.ignoreSemiMarkovSelf = ignoreSemiMarkovSelfTransitions;
048                            initTrans(fm, allPaths);
049                    }
050                    @Override
051                    protected boolean allowSelf(int state) {
052                            return localMaxStateLengths[state] > 1;
053                    }
054            }
055            
056            /** true if all paths (valid and invalid) are to be evaluated during the viterbi search.  Defaults to false.
057             * @return true if all paths are to be examined
058             */
059            public boolean isAllPaths() {
060                    return allPaths;
061            }
062    
063            /** sets whether all paths (valid and invalid) are to be evaluated during the viterbi search.  Defaults to false.
064             * @param allPaths allPath true if all paths are to be examined
065             */
066            public void setAllPaths(boolean allPaths) {
067                    this.allPaths = allPaths;
068            }
069    
070            /** sets up the semiMarkov parameters, such as maximum segment lengths.  These should be the same as the model was trained with.
071             * @param setup the parameters to use during the semi-Markov viterbi search
072             */
073            public void setSemiMarkovSetup(SemiMarkovSetup setup) {
074                    maxStateLengths = setup.getMaxLengths();
075                    ignoreSemiMarkovSelfTransitions = setup.isIgnoreSemiMarkovSelfTransitions();
076            }
077    
078            public InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {          
079                    nStates = fm.getNumStates();
080                    int len = seq.length();
081                    if(maxStateLengths == null) {
082                            log.info("No state lengths set - standard viterbi search.");
083                            maxStateLengths = new short[nStates];
084                            Arrays.fill(maxStateLengths, (short) 1);
085                    }
086                    
087                    // Determine the longest lookback we might have to do.
088                    for(int i : maxStateLengths) {
089                            maxLookback = Math.max(i, maxLookback);
090                    }
091    
092                    transitions = new LengthTransitionInfo(fm, allPaths, maxStateLengths, ignoreSemiMarkovSelfTransitions);
093                    int[][] statePotentials = getStatePotentials(transitions);
094                    selfTransitions = new int[nStates];
095                    for(int i=0; i<nStates; ++i) {
096                            selfTransitions[i] = transitions.transitionIndex.getQuick(i, i);
097                    }
098                    
099                    FeatureCalculator calc = new FeatureCalculator(fm, lambda, transitions);
100    
101                    // Buffer of the previous mi matrices.
102                    LinkedList<double[]> mis = new LinkedList<double[]>();
103                    // Buffer of the values of staying in a stable state for a given period of time
104                    LinkedList<double[]> stableStates = new LinkedList<double[]>();
105                    
106                    // The next MI matrix we plan to use (may be recycled from the end of the list
107                    double[] nextMi = new double[transitions.nTransitions];
108    
109                    double[] Ri = new double[nStates];
110                    
111                    bestScore = new double[seq.length()*nStates];
112                    backPointers = new int[seq.length()*nStates];
113                    int[] localBackLengths = new int[seq.length()*nStates];
114                    
115                    for (int pos = 0; pos < len; pos++) {
116                            /* compute weighted features.  These are for transitions at the current base for non-length dependent features. */
117                            calc.computeSparseMi(seq, pos, nextMi, Ri);
118                                                    
119                            // Save the first Ri matrix, since it has the initial probabilies
120                            if(pos == 0) {
121                                    stableStates.add(Ri);
122                                    Ri = new double[nStates];
123                            }
124                            else {
125                                    // Use this transition matrix to update the stable vectors 
126                                    updateStableBuffer(stableStates, nextMi);                               
127    
128                                    // Add this into the list of saved Mi matrices
129                                    nextMi = updateMiBuffer(mis, nextMi);                           
130                            }                               
131    
132                            double[] latestStable = stableStates.getFirst();
133                            // Now investigate each state to determine the optimal path to this point.
134                            for(int state = 0; state<nStates; ++state) {
135                                    int[] transitionPotentials = statePotentials[state];
136                                    
137                                    // For each state, determine how far we should look back 
138                                    int lookbackSize = maxStateLengths[state];
139                                    
140                                    // Examine each length, starting with 1 and increasing.
141                                    double max = Double.NEGATIVE_INFINITY;
142                                    int bestLookback = -1;
143                                    int bestPrevState = -2;
144                                    
145                                    // As we go through the lookbacks for the previous state, go through the stored Mi matrics and stable positions
146                                    Iterator<double[]> miIter = mis.iterator();
147                                    Iterator<double[]> stableIter = stableStates.iterator();
148                                    for(int lookback = 0; lookback < lookbackSize; ++lookback) {
149                                            // Find the starting position we are evaluating for this state
150                                            int startPos = pos - lookback;
151                                            //Assert.a(startPos >= 0);
152                                            if(startPos == 0) {
153                                                    // In this case we have the extra check that we should have the same number of previous Mis as our length
154                                                    Assert.a(mis.size() == lookback, "More Mi matrices in history than there are previous positions in the sequence.");
155                                                    
156                                                    // Examine the case where this is the first position in the sequence.
157                                                    // Score is the initial score plus the score of staying in that feature for this long.
158                                                    double current = latestStable[state];
159                                                    // Now add length dependent features
160                                                    calc.result.evaluateNodeLength(seq, pos, lookback+1, state);
161                                                    double lengthCost = calc.calcRet(false);
162                                                    current += lengthCost;
163                                                    
164                                                    // Check if this is our best so far
165                                                    if(current > max) {
166                                                            max = current;
167                                                            bestLookback = lookback;
168                                                            bestPrevState = -1;
169                                                    }
170                                                    // Don't bother looking back any farther
171                                                    break;
172                                            }
173                                            else {
174                                                    double nodeLengthCost = Double.NaN;
175                                                    // This state starts after the beginning of the sequences.
176                                                    // Need to check all legal transitions from previous states.
177                                                    // This includes transitions to self although I'm not sure I like that for states with length models.
178                                                    double[] lookbackMi = miIter.next(); 
179                                                    double[] lookbackStable = stableIter.next();
180                                                    for(int transition : transitionPotentials) {
181                                                            int prevState = transitions.transitionFrom[transition];
182                                                            if(prevState == state && lookbackSize > 1) {
183                                                                    // For states with explicit length distributions, ignore transition to self.
184                                                                    continue;
185                                                            }
186                                                            double transitionCost = lookbackMi[transition];
187                                                            if(Double.isInfinite(transitionCost)) {
188                                                                    // This transition was invalid at this location, ignore.
189                                                                    continue;
190                                                            }
191                                                            if(Double.isNaN(nodeLengthCost)) {
192                                                                    // Only calculate this if you know you need it.
193                                                                    nodeLengthCost = calc.calcNodeLengthValue(seq, pos, lookback+1, state);
194                                                            }
195                                                            double current = bestScore[nStates *(pos-(lookback+1)) + prevState];
196                                                            if(current == Double.NEGATIVE_INFINITY) {
197                                                                    // The state was not valid at the previous position
198                                                                    continue;
199                                                            }
200                                                            double stable = latestStable[state]-lookbackStable[state];
201                                                            calc.result.evaluateEdgeLength(seq, pos, lookback+1, prevState, state);
202                                                            double lengthCost = nodeLengthCost + calc.calcRet(false);
203                                                            
204                                                            current += transitionCost + stable + lengthCost;
205                                                            if(current > max) {
206                                                                    //log.info("Selected Pos: "+pos+" Edge: "+prevState+"-"+state+" "+current+" vs. "+max+" Prev. "+bestScore[nStates *(pos-(lookback+1)) + prevState]+" Trans. "+transitionCost);
207                                                                    max = current;
208                                                                    bestLookback = lookback;
209                                                                    bestPrevState = prevState;
210                                                            }
211                                                    }
212                                            }
213                                    }
214                                    
215                                    // Fill in our best entry
216                                    //log.info(String.format("Pos: %d State: %d BestScore: "+max+" BackPointer: %d", pos, state, bestPrevState));
217                                    int index= pos*nStates + state; 
218                                    bestScore[index] = max; 
219                                    backPointers[index] = bestPrevState;
220                                    localBackLengths[index] = bestLookback+1;
221                            }
222                    }
223                    
224                    //printForwardPass(bestScore, backPointers, backLengths, nStates, seq.length());
225                    //log.info(ColtUtil.format(bestScore));
226                    //log.info(StringUtils.join(backPointers," "));
227                    //log.info(backLengths);
228                                    
229                    // Now that we have the matrix, trace back to get the best path.
230                    int[] ret = new int[len];
231                    int pos = len-1;
232                    int state = ColtUtil.maxInColumn(bestScore, nStates, len-1);
233                    Assert.a(state != -2, "No valid paths");
234    
235                    while(pos >= 0) {
236                            int stateLen = localBackLengths[pos*nStates + state];
237                            int prevState = backPointers[pos*nStates + state];
238                            //log.info(String.format("State: %d, Len: %d, Ends At: %d", state, stateLen, pos));
239                            for(int i = 0; i < stateLen; ++i) {
240                                    ret[pos] = state;
241                                    pos--;
242                            }
243                            state = prevState; 
244                    }               
245                    Assert.a(pos == -1);
246                    InferenceResult inferenceResult = new InferenceResult();
247                    inferenceResult.hiddenStates = ret;
248                    inferenceResult.bestScores = new double[nStates];
249                    System.arraycopy(bestScore, nStates*(len-1), inferenceResult.bestScores, 0, nStates );
250                    return inferenceResult;
251            }
252    
253            /** Updates the stableStates list based on the current position.  The stableStates list contains the 
254             cost of the non-length dependent features for the duration of the lookback. 
255             */
256            void updateStableBuffer(LinkedList<double[]> stableStates, double[] nextMi) {
257                    // Recycle matrices at the end
258                    double[] stableState;
259                    if(stableStates.size() > maxLookback) {
260                            stableState = stableStates.removeLast();
261                    }
262                    else {
263                            stableState = new double[nStates];
264                    }
265                    
266                    double[] prevState = stableStates.getFirst();
267                    for(int ix = 0; ix < nStates; ++ix) {
268                            if(maxStateLengths[ix] > 1) {
269                                    int trans = selfTransitions[ix];
270                                    if(trans != -1) {
271                                            if (Double.isInfinite(nextMi[trans]))
272                                                    stableState[ix] = prevState[ix];
273                                            else
274                                                    stableState[ix] = prevState[ix] + nextMi[trans];                                
275                                            //log.info(String.format("stableState[%d] = %f = %f + %f", ix, stableState[ix], prevState[ix], nextMi[trans]));
276                                    }
277                            }
278                    }
279    
280                    // Now add this to the beginning.
281                    stableStates.addFirst(stableState);
282            }
283    
284            double[] updateMiBuffer(LinkedList<double[]> mis, double[] nextMi) {
285                    mis.addFirst(nextMi);
286                    if(mis.size() > maxLookback) {
287                            nextMi = mis.removeLast();
288                    }
289                    else {
290                            nextMi = new double[transitions.nTransitions];
291                    }
292                    return nextMi;
293            }
294            
295            // Used for debugging only.
296            /*
297             * private void printForwardPass(double[] bestScore, int[] backPointers, int[] backLengths, int numStates, int seqLen)
298            {
299                    try {
300                            //Writer fout = new BufferedWriter(new FileWriter("test/working/crf_forwardPass.txt"));
301                            Writer fout = new BufferedWriter(new FileWriter("crf_forwardPass.txt"));
302                            
303                            int i, pos, st;
304                            //double[][] bestScores = new double[seqLen][numStates];
305                            //int[][] bestIndices  = new int[seqLen][numStates];
306                            
307                            fout.write("Viterbi Map from Forward Pass\n");
308                            
309                            fout.write("\n");
310    
311                            for (pos=0; pos<seqLen; pos++)
312                            {
313                                    for (st=0; st<numStates; st++)
314                                    {
315                                            fout.write( (pos*numStates + st) + "\t" + pos + "\t" + st + "\t");
316                                            if (bestScore[pos*numStates + st] == 0)
317                                                    fout.write(String.format("%1$11.0f", bestScore[pos*numStates + st]) + "  \t" + "." + "\n");
318                                            else if (bestScore[pos*numStates + st] == Double.NEGATIVE_INFINITY)
319                                                    fout.write(String.format("%1$11.0f", 0.0f) + "  \t" +  "." + "\n");
320                                            else {
321                                                    //fout.write(String.format("%1$11.2f", bestScore[pos*numStates + st]) + "  \t" + "--" + "\n");
322    
323                                                    if (backPointers[pos*numStates + st] == -1)
324                                                            fout.write(String.format("%1$11.2f", bestScore[pos*numStates + st]) + "  \t" + "--" + "\t" +  backPointers[pos*numStates + st] + "\n");
325                                                    else 
326                                                            fout.write(String.format("%1$11.2f", bestScore[pos*numStates + st]) + "  \t" + (pos - backLengths[pos*numStates + st]) + "\t" +  backPointers[pos*numStates + st] + "\n");
327                                    
328                                            }
329                                    }
330                                    //fout.write("\n");
331                            }
332                            fout.close();
333                    
334                    } catch (IOException e) {
335                            throw new RuntimeException("Error writing alpha pass");
336                    }       
337            }
338            */
339    
340            int[][] getStatePotentials(TransitionInfo transitions) {
341                    int[][] statePotentials = new int[nStates][];
342                    int currentState = -1;
343                    List<Integer> currentList = null;
344                    for(int potential : transitions.orderedPotentials) {
345                            if(potential < nStates) {
346                                    if(currentState != -1) {
347                                            statePotentials[currentState] = toIntArray(currentList);
348                                    }
349                                    currentState = potential;
350                                    currentList = new ArrayList<Integer>();
351                            }
352                            else {
353                                    currentList.add(potential-nStates);
354                            }
355                    }
356                    statePotentials[currentState] = toIntArray(currentList);
357                    return statePotentials;
358            }
359            
360            int[] toIntArray(List<Integer> list) {
361                    int[] ret = new int[list.size()];
362                    for(int i=0; i<ret.length; ++i) {
363                            ret[i] = list.get(i);
364                    }
365                    return ret;
366            }
367    }