001 package calhoun.analysis.crf.solver;
002
003 import java.util.ArrayList;
004 import java.util.Arrays;
005 import java.util.Iterator;
006 import java.util.LinkedList;
007 import java.util.List;
008
009 import org.apache.commons.logging.Log;
010 import org.apache.commons.logging.LogFactory;
011
012 import calhoun.analysis.crf.CRFInference;
013 import calhoun.analysis.crf.ModelManager;
014 import calhoun.analysis.crf.SemiMarkovSetup;
015 import calhoun.analysis.crf.io.InputSequence;
016 import calhoun.analysis.crf.solver.check.FeatureCalculator;
017 import calhoun.analysis.crf.solver.check.TransitionInfo;
018 import calhoun.util.Assert;
019 import calhoun.util.ColtUtil;
020 import calhoun.util.DenseIntMatrix2D;
021
022 /** viterbi algorithm for semi-Markov CRFs. Does not perform any caching, and <b>does not correctly handle
023 * NodeBoundary features</b>. Use only if you have no NodeBoundary features and are running into memory limitations.
024 * Requires that a {@link SemiMarkovSetup} be configured.
025 */
026 public class SemiMarkovViterbiNoCache implements CRFInference {
027 private static final Log log = LogFactory.getLog(SemiMarkovViterbi.class);
028 boolean debug = log.isDebugEnabled();
029
030 private boolean allPaths;
031 private double[] bestScore;
032 private int[] backPointers;
033 DenseIntMatrix2D backLengths;
034 short[] maxStateLengths;
035 boolean ignoreSemiMarkovSelfTransitions;
036
037 int maxLookback = 1;
038 int nStates;
039
040 TransitionInfo transitions;
041 int[] selfTransitions;
042
043 class LengthTransitionInfo extends TransitionInfo {
044 short[] localMaxStateLengths;
045 LengthTransitionInfo(ModelManager fm, boolean allPaths, short[] maxStateLengths, boolean ignoreSemiMarkovSelfTransitions) {
046 this.localMaxStateLengths = maxStateLengths;
047 this.ignoreSemiMarkovSelf = ignoreSemiMarkovSelfTransitions;
048 initTrans(fm, allPaths);
049 }
050 @Override
051 protected boolean allowSelf(int state) {
052 return localMaxStateLengths[state] > 1;
053 }
054 }
055
056 /** true if all paths (valid and invalid) are to be evaluated during the viterbi search. Defaults to false.
057 * @return true if all paths are to be examined
058 */
059 public boolean isAllPaths() {
060 return allPaths;
061 }
062
063 /** sets whether all paths (valid and invalid) are to be evaluated during the viterbi search. Defaults to false.
064 * @param allPaths allPath true if all paths are to be examined
065 */
066 public void setAllPaths(boolean allPaths) {
067 this.allPaths = allPaths;
068 }
069
070 /** sets up the semiMarkov parameters, such as maximum segment lengths. These should be the same as the model was trained with.
071 * @param setup the parameters to use during the semi-Markov viterbi search
072 */
073 public void setSemiMarkovSetup(SemiMarkovSetup setup) {
074 maxStateLengths = setup.getMaxLengths();
075 ignoreSemiMarkovSelfTransitions = setup.isIgnoreSemiMarkovSelfTransitions();
076 }
077
078 public InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
079 nStates = fm.getNumStates();
080 int len = seq.length();
081 if(maxStateLengths == null) {
082 log.info("No state lengths set - standard viterbi search.");
083 maxStateLengths = new short[nStates];
084 Arrays.fill(maxStateLengths, (short) 1);
085 }
086
087 // Determine the longest lookback we might have to do.
088 for(int i : maxStateLengths) {
089 maxLookback = Math.max(i, maxLookback);
090 }
091
092 transitions = new LengthTransitionInfo(fm, allPaths, maxStateLengths, ignoreSemiMarkovSelfTransitions);
093 int[][] statePotentials = getStatePotentials(transitions);
094 selfTransitions = new int[nStates];
095 for(int i=0; i<nStates; ++i) {
096 selfTransitions[i] = transitions.transitionIndex.getQuick(i, i);
097 }
098
099 FeatureCalculator calc = new FeatureCalculator(fm, lambda, transitions);
100
101 // Buffer of the previous mi matrices.
102 LinkedList<double[]> mis = new LinkedList<double[]>();
103 // Buffer of the values of staying in a stable state for a given period of time
104 LinkedList<double[]> stableStates = new LinkedList<double[]>();
105
106 // The next MI matrix we plan to use (may be recycled from the end of the list
107 double[] nextMi = new double[transitions.nTransitions];
108
109 double[] Ri = new double[nStates];
110
111 bestScore = new double[seq.length()*nStates];
112 backPointers = new int[seq.length()*nStates];
113 int[] localBackLengths = new int[seq.length()*nStates];
114
115 for (int pos = 0; pos < len; pos++) {
116 /* compute weighted features. These are for transitions at the current base for non-length dependent features. */
117 calc.computeSparseMi(seq, pos, nextMi, Ri);
118
119 // Save the first Ri matrix, since it has the initial probabilies
120 if(pos == 0) {
121 stableStates.add(Ri);
122 Ri = new double[nStates];
123 }
124 else {
125 // Use this transition matrix to update the stable vectors
126 updateStableBuffer(stableStates, nextMi);
127
128 // Add this into the list of saved Mi matrices
129 nextMi = updateMiBuffer(mis, nextMi);
130 }
131
132 double[] latestStable = stableStates.getFirst();
133 // Now investigate each state to determine the optimal path to this point.
134 for(int state = 0; state<nStates; ++state) {
135 int[] transitionPotentials = statePotentials[state];
136
137 // For each state, determine how far we should look back
138 int lookbackSize = maxStateLengths[state];
139
140 // Examine each length, starting with 1 and increasing.
141 double max = Double.NEGATIVE_INFINITY;
142 int bestLookback = -1;
143 int bestPrevState = -2;
144
145 // As we go through the lookbacks for the previous state, go through the stored Mi matrics and stable positions
146 Iterator<double[]> miIter = mis.iterator();
147 Iterator<double[]> stableIter = stableStates.iterator();
148 for(int lookback = 0; lookback < lookbackSize; ++lookback) {
149 // Find the starting position we are evaluating for this state
150 int startPos = pos - lookback;
151 //Assert.a(startPos >= 0);
152 if(startPos == 0) {
153 // In this case we have the extra check that we should have the same number of previous Mis as our length
154 Assert.a(mis.size() == lookback, "More Mi matrices in history than there are previous positions in the sequence.");
155
156 // Examine the case where this is the first position in the sequence.
157 // Score is the initial score plus the score of staying in that feature for this long.
158 double current = latestStable[state];
159 // Now add length dependent features
160 calc.result.evaluateNodeLength(seq, pos, lookback+1, state);
161 double lengthCost = calc.calcRet(false);
162 current += lengthCost;
163
164 // Check if this is our best so far
165 if(current > max) {
166 max = current;
167 bestLookback = lookback;
168 bestPrevState = -1;
169 }
170 // Don't bother looking back any farther
171 break;
172 }
173 else {
174 double nodeLengthCost = Double.NaN;
175 // This state starts after the beginning of the sequences.
176 // Need to check all legal transitions from previous states.
177 // This includes transitions to self although I'm not sure I like that for states with length models.
178 double[] lookbackMi = miIter.next();
179 double[] lookbackStable = stableIter.next();
180 for(int transition : transitionPotentials) {
181 int prevState = transitions.transitionFrom[transition];
182 if(prevState == state && lookbackSize > 1) {
183 // For states with explicit length distributions, ignore transition to self.
184 continue;
185 }
186 double transitionCost = lookbackMi[transition];
187 if(Double.isInfinite(transitionCost)) {
188 // This transition was invalid at this location, ignore.
189 continue;
190 }
191 if(Double.isNaN(nodeLengthCost)) {
192 // Only calculate this if you know you need it.
193 nodeLengthCost = calc.calcNodeLengthValue(seq, pos, lookback+1, state);
194 }
195 double current = bestScore[nStates *(pos-(lookback+1)) + prevState];
196 if(current == Double.NEGATIVE_INFINITY) {
197 // The state was not valid at the previous position
198 continue;
199 }
200 double stable = latestStable[state]-lookbackStable[state];
201 calc.result.evaluateEdgeLength(seq, pos, lookback+1, prevState, state);
202 double lengthCost = nodeLengthCost + calc.calcRet(false);
203
204 current += transitionCost + stable + lengthCost;
205 if(current > max) {
206 //log.info("Selected Pos: "+pos+" Edge: "+prevState+"-"+state+" "+current+" vs. "+max+" Prev. "+bestScore[nStates *(pos-(lookback+1)) + prevState]+" Trans. "+transitionCost);
207 max = current;
208 bestLookback = lookback;
209 bestPrevState = prevState;
210 }
211 }
212 }
213 }
214
215 // Fill in our best entry
216 //log.info(String.format("Pos: %d State: %d BestScore: "+max+" BackPointer: %d", pos, state, bestPrevState));
217 int index= pos*nStates + state;
218 bestScore[index] = max;
219 backPointers[index] = bestPrevState;
220 localBackLengths[index] = bestLookback+1;
221 }
222 }
223
224 //printForwardPass(bestScore, backPointers, backLengths, nStates, seq.length());
225 //log.info(ColtUtil.format(bestScore));
226 //log.info(StringUtils.join(backPointers," "));
227 //log.info(backLengths);
228
229 // Now that we have the matrix, trace back to get the best path.
230 int[] ret = new int[len];
231 int pos = len-1;
232 int state = ColtUtil.maxInColumn(bestScore, nStates, len-1);
233 Assert.a(state != -2, "No valid paths");
234
235 while(pos >= 0) {
236 int stateLen = localBackLengths[pos*nStates + state];
237 int prevState = backPointers[pos*nStates + state];
238 //log.info(String.format("State: %d, Len: %d, Ends At: %d", state, stateLen, pos));
239 for(int i = 0; i < stateLen; ++i) {
240 ret[pos] = state;
241 pos--;
242 }
243 state = prevState;
244 }
245 Assert.a(pos == -1);
246 InferenceResult inferenceResult = new InferenceResult();
247 inferenceResult.hiddenStates = ret;
248 inferenceResult.bestScores = new double[nStates];
249 System.arraycopy(bestScore, nStates*(len-1), inferenceResult.bestScores, 0, nStates );
250 return inferenceResult;
251 }
252
253 /** Updates the stableStates list based on the current position. The stableStates list contains the
254 cost of the non-length dependent features for the duration of the lookback.
255 */
256 void updateStableBuffer(LinkedList<double[]> stableStates, double[] nextMi) {
257 // Recycle matrices at the end
258 double[] stableState;
259 if(stableStates.size() > maxLookback) {
260 stableState = stableStates.removeLast();
261 }
262 else {
263 stableState = new double[nStates];
264 }
265
266 double[] prevState = stableStates.getFirst();
267 for(int ix = 0; ix < nStates; ++ix) {
268 if(maxStateLengths[ix] > 1) {
269 int trans = selfTransitions[ix];
270 if(trans != -1) {
271 if (Double.isInfinite(nextMi[trans]))
272 stableState[ix] = prevState[ix];
273 else
274 stableState[ix] = prevState[ix] + nextMi[trans];
275 //log.info(String.format("stableState[%d] = %f = %f + %f", ix, stableState[ix], prevState[ix], nextMi[trans]));
276 }
277 }
278 }
279
280 // Now add this to the beginning.
281 stableStates.addFirst(stableState);
282 }
283
284 double[] updateMiBuffer(LinkedList<double[]> mis, double[] nextMi) {
285 mis.addFirst(nextMi);
286 if(mis.size() > maxLookback) {
287 nextMi = mis.removeLast();
288 }
289 else {
290 nextMi = new double[transitions.nTransitions];
291 }
292 return nextMi;
293 }
294
295 // Used for debugging only.
296 /*
297 * private void printForwardPass(double[] bestScore, int[] backPointers, int[] backLengths, int numStates, int seqLen)
298 {
299 try {
300 //Writer fout = new BufferedWriter(new FileWriter("test/working/crf_forwardPass.txt"));
301 Writer fout = new BufferedWriter(new FileWriter("crf_forwardPass.txt"));
302
303 int i, pos, st;
304 //double[][] bestScores = new double[seqLen][numStates];
305 //int[][] bestIndices = new int[seqLen][numStates];
306
307 fout.write("Viterbi Map from Forward Pass\n");
308
309 fout.write("\n");
310
311 for (pos=0; pos<seqLen; pos++)
312 {
313 for (st=0; st<numStates; st++)
314 {
315 fout.write( (pos*numStates + st) + "\t" + pos + "\t" + st + "\t");
316 if (bestScore[pos*numStates + st] == 0)
317 fout.write(String.format("%1$11.0f", bestScore[pos*numStates + st]) + " \t" + "." + "\n");
318 else if (bestScore[pos*numStates + st] == Double.NEGATIVE_INFINITY)
319 fout.write(String.format("%1$11.0f", 0.0f) + " \t" + "." + "\n");
320 else {
321 //fout.write(String.format("%1$11.2f", bestScore[pos*numStates + st]) + " \t" + "--" + "\n");
322
323 if (backPointers[pos*numStates + st] == -1)
324 fout.write(String.format("%1$11.2f", bestScore[pos*numStates + st]) + " \t" + "--" + "\t" + backPointers[pos*numStates + st] + "\n");
325 else
326 fout.write(String.format("%1$11.2f", bestScore[pos*numStates + st]) + " \t" + (pos - backLengths[pos*numStates + st]) + "\t" + backPointers[pos*numStates + st] + "\n");
327
328 }
329 }
330 //fout.write("\n");
331 }
332 fout.close();
333
334 } catch (IOException e) {
335 throw new RuntimeException("Error writing alpha pass");
336 }
337 }
338 */
339
340 int[][] getStatePotentials(TransitionInfo transitions) {
341 int[][] statePotentials = new int[nStates][];
342 int currentState = -1;
343 List<Integer> currentList = null;
344 for(int potential : transitions.orderedPotentials) {
345 if(potential < nStates) {
346 if(currentState != -1) {
347 statePotentials[currentState] = toIntArray(currentList);
348 }
349 currentState = potential;
350 currentList = new ArrayList<Integer>();
351 }
352 else {
353 currentList.add(potential-nStates);
354 }
355 }
356 statePotentials[currentState] = toIntArray(currentList);
357 return statePotentials;
358 }
359
360 int[] toIntArray(List<Integer> list) {
361 int[] ret = new int[list.size()];
362 for(int i=0; i<ret.length; ++i) {
363 ret[i] = list.get(i);
364 }
365 return ret;
366 }
367 }