001 package calhoun.analysis.crf.solver;
002
003 import org.apache.commons.logging.Log;
004 import org.apache.commons.logging.LogFactory;
005
006 import calhoun.analysis.crf.CRFInference;
007 import calhoun.analysis.crf.ModelManager;
008 import calhoun.analysis.crf.io.InputSequence;
009 import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
010 import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
011 import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
012 import calhoun.util.Assert;
013 import calhoun.util.ColtUtil;
014
015 /** viterbi algorithm for semi-Markov CRFs. Uses a {@link CacheProcessor} to handle the complexities of evaluation.
016 * This is similar to a markov viterbi search, except we have to search over all valid segments to fill in one entry
017 * in the matrix instead of just looking at the last position.
018 */
019 public class SemiMarkovViterbi implements CRFInference {
020 private static final Log log = LogFactory.getLog(SemiMarkovViterbi.class);
021 boolean debug = log.isDebugEnabled();
022
023 double[] lambda;
024
025 private double[] bestScore;
026 private int[] backPointers;
027 int nStates;
028
029 SolverSetup modelInfo;
030 CacheProcessor cacheProcessor;
031 FeatureEvaluation[] evals;
032 LengthFeatureEvaluation[][] lengthEvals;
033 boolean[] invalidTransitions;
034 int[] selfTransitions;
035
036 RecyclingBuffer<double[]> stableStates;
037 double[] stableVector;
038
039 /** gets the cache processor used to access feature evaluations
040 * @return the configured cache processor
041 */
042 public CacheProcessor getCacheProcessor() {
043 return cacheProcessor;
044 }
045
046 /** sets the cache processor used to access feature evaluations
047 * @param cacheProcessor the cache processor to use
048 */
049 public void setCacheProcessor(CacheProcessor cacheProcessor) {
050 this.cacheProcessor = cacheProcessor;
051 }
052
053 public InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
054 this.lambda = lambda;
055
056 cacheProcessor.setInputData(fm, seq);
057 modelInfo = cacheProcessor.getSolverSetup();
058 nStates = modelInfo.nStates;
059 Assert.a(modelInfo.maxStateLengths.length == nStates, "Maximum state lengths array was length ("+modelInfo.maxStateLengths.length+"). Must have one entry for each state "+modelInfo.nStates+")");
060 evals = cacheProcessor.getFeatureEvaluations();
061 lengthEvals = cacheProcessor.getLengthFeatureEvaluations();
062 invalidTransitions = cacheProcessor.getInvalidTransitions();
063
064 int len = seq.length();
065
066 selfTransitions = new int[nStates];
067 for(int i=0; i<nStates; ++i) {
068 selfTransitions[i] = modelInfo.selfTransitions[i];
069 }
070
071 // Circular buffers of the previous mi matrices.
072 RecyclingBuffer<double[]> mis = new RecyclingBuffer<double[]>(new double[modelInfo.maxLookback][modelInfo.nTransitions]);
073 double[] nextMi = new double[modelInfo.nTransitions];
074
075 // Circular buffers of the values of staying in a stable state for a given period of time
076 stableStates = new RecyclingBuffer<double[]>(new double[modelInfo.maxLookback][nStates]);
077 stableVector = new double[nStates];
078
079 bestScore = new double[len*nStates];
080 backPointers = new int[len*nStates];
081 int[] backLengths = new int[len*nStates];
082
083 for (int pos = 0; pos < len; pos++) {
084 /* compute weighted features. These are for transitions at the current base for non-length dependent features. */
085
086 // Save the first Ri matrix, since it has the initial probabilies
087 if(pos == 0) {
088 computeSparseMi(seq, pos, null, stableVector);
089 stableVector = stableStates.addFirst(stableVector);
090 }
091 else {
092 computeSparseMi(seq, pos, nextMi, null);
093
094 // Use this transition matrix to update the stable vectors
095 updateStableBuffer(nextMi);
096
097 // Add this into the list of saved Mi matrices
098 nextMi = mis.addFirst(nextMi);
099 }
100
101 double[] latestStable = stableStates.get(0);
102 double[] latestMi = mis.get(0);
103
104 // Do states without lookback first.
105 for(CacheProcessor.StatePotentials potentials : modelInfo.statesWithoutLookback) {
106 int state = potentials.state;
107 double max = Double.NEGATIVE_INFINITY;
108 int invalidIndex = pos*modelInfo.nPotentials;
109 int bestLookback = 0;
110 int bestPrevState = -2;
111
112 // Check that the node is valid, otherwise put in a dummy entry.
113 if(!invalidTransitions[invalidIndex + state]) {
114 if(pos == 0) {
115 if(debug)
116 log.debug(String.format("Pos: %d State: %d %.2f", pos, state, latestStable[state]));
117 // For the first position, we just store the starting potential
118 max = latestStable[state];
119 bestPrevState = -1;
120 }
121 else {
122 for(byte edgePotential : potentials.potentials) {
123 if(invalidTransitions[invalidIndex + state])
124 continue;
125
126 int transition = edgePotential - nStates;
127 int prevState = modelInfo.transitionFrom[transition];
128 double transitionCost = latestMi[transition];
129
130 if(Double.isInfinite(transitionCost)) {
131 // This transition was invalid at this location, ignore.
132 continue;
133 }
134
135 double previous = bestScore[nStates *(pos-1) + prevState];
136 double current = previous + transitionCost;
137 if(debug)
138 log.debug(String.format("Pos: %d Trans: %d-%d %.2f (Prev: %.2f + Trans: %.2f)", pos, prevState, state, current, previous, transitionCost));
139 if(current > max) {
140 max = current;
141 bestPrevState = prevState;
142 }
143 }
144 }
145 }
146
147 // Fill in our best entry
148 int index= pos*nStates + state;
149 //log.debug(String.format("bestScore[%d] = %.2f", index, max));
150 bestScore[index] = max;
151 backPointers[index] = bestPrevState;
152 backLengths[index] = bestLookback+1;
153 }
154
155 // Now repeat for states with lookback.
156 for(int i=0; i<modelInfo.statesWithLookback.length; ++i) {
157 CacheProcessor.StatePotentials potentials = modelInfo.statesWithLookback[i];
158 LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
159
160 int state = potentials.state;
161 double max = Double.NEGATIVE_INFINITY;
162 int bestLookback = -1;
163 int bestPrevState = -2;
164
165 cacheProcessor.evaluateSegmentsEndingAt(0, pos);
166
167 int lbIndex=0;
168 LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
169 int lookback = lengthEval.lookback;
170 while(lookback != -1) {
171 //log.info("Pos: "+pos+"\t State: "+modelInfo.statesWithLookback[i].state+"\t Lookback: "+lookback);
172
173 double[] lookbackMi = mis.get(lookback);
174 double[] lookbackStable = stableStates.get(lookback);
175
176 // Handle evaluation of the node potentials
177 FeatureEvaluation nodeEvals = lengthEval.nodeEval;
178 short[] indices = nodeEvals.index;
179 float[] vals = nodeEvals.value;
180 int ix = 0;
181 short index = indices[ix];
182 double nodePotential = 0.0;
183 while(index >= 0) {
184 nodePotential += vals[ix] * lambda[index];
185 index = indices[++ix];
186 }
187 Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid");
188
189 int prevPos = pos - lookback - 1;
190 if(prevPos < 0) {
191 // Examine the case where this is the first segment in the sequence.
192 // Score is the sum of non-length dependent fetaures plus length features.
193 double current = latestStable[state] + nodePotential;
194 if(debug)
195 log.debug(String.format("Pos: %d Lb: %d State: %d %.2f (Stable: %.2f + Node: %.2f)", pos, lookback, state, current, latestStable[state], nodePotential));
196
197 // Check if this is our best so far
198 if(current > max) {
199 max = current;
200 bestLookback = lookback;
201 bestPrevState = -1;
202 }
203 }
204 else {
205 // If this is not the first segment, we need to deal with edges coming into this segment
206 FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
207 int nEdges = potentials.potentials.length;
208 for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
209 int potential = potentials.potentials[edgeIx];
210 int trans = potential - modelInfo.nStates;
211 int fromNode = modelInfo.transitionFrom[trans];
212 // Skip semi-Markov self transitions
213 if(fromNode == state)
214 continue;
215
216 double edgeVal = 0.0;
217
218 if(edgeEvals == null) {
219 // If the cache processor does not have edge evaluations
220 // Just check if this transition is legal based on the invalid transitions matrix
221 if(invalidTransitions[(prevPos+1)*modelInfo.nPotentials + potential]) {
222 //log.info("Illegal transition: "+fromNode+"-"+toNode+" at pos: "+prevPos);
223 continue;
224 }
225 }
226 else {
227 // If the cache processor does have edge evaluations, then ignore the illegal transitions matrix
228 // and update the expval using the edge evaluations
229 FeatureEvaluation potEvals = edgeEvals[edgeIx];
230 indices = potEvals.index;
231 vals = potEvals.value;
232 ix = 0;
233 index = indices[i];
234 if(index == Short.MIN_VALUE) {
235 log.info("SHORT.MIN_VALUE");
236 continue;
237 }
238 while(index != -1) {
239 edgeVal += vals[ix] * lambda[index];
240 index = indices[++ix];
241 }
242 }
243
244 double prevBest = bestScore[nStates *(pos-(lookback+1)) + fromNode];
245 double stable = latestStable[state]-lookbackStable[state];
246
247 // Renormalize and update the exp value.
248 double current = prevBest + nodePotential + edgeVal + stable + lookbackMi[trans];
249 if(debug)
250 log.debug(String.format("Pos: %d Lb: %d Trans: %d-%d %.4f (Prev: %.4f + Stable: %.4f + Trans: %.4f + Node: %.4f + Edge: %.4f)", pos, lookback, fromNode, state,
251 current, prevBest, stable, lookbackMi[trans], nodePotential, edgeVal));
252
253 if(current == Double.NEGATIVE_INFINITY) {
254 // The state was not valid at the previous position
255 continue;
256 }
257
258 if(current > max) {
259 //log.info("Selected Pos: "+pos+" Edge: "+prevState+"-"+state+" "+current+" vs. "+max+" Prev. "+bestScore[nStates *(pos-(lookback+1)) + prevState]+" Trans. "+transitionCost);
260 max = current;
261 bestLookback = lookback;
262 bestPrevState = fromNode;
263 }
264 }
265 }
266 ++lbIndex;
267 lengthEval = lookbacksForState[lbIndex];
268 lookback = lengthEval.lookback;
269 }
270
271 // Fill in our best entry
272 //log.info(String.format("Pos: %d State: %d BestScore: "+max+" BackPointer: %d", pos, state, bestPrevState));
273 int index= pos*nStates + state;
274 bestScore[index] = max;
275 backPointers[index] = bestPrevState;
276 backLengths[index] = bestLookback+1;
277 }
278 }
279 //printForwardPass(bestScore, backPointers, backLengths, nStates, seq.length());
280 //log.info(ColtUtil.format(bestScore));
281 //log.info(StringUtils.join(backPointers," "));
282 //log.info(backLengths);
283
284 // Now that we have the matrix, trace back to get the best path.
285 int[] ret = new int[len];
286 int pos = len-1;
287 int state = ColtUtil.maxInColumn(bestScore, nStates, len-1);
288 Assert.a(state != -2, "No valid paths");
289
290 while(pos >= 0) {
291 int stateLen = backLengths[pos*nStates + state];
292 int prevState = backPointers[pos*nStates + state];
293 //log.info(String.format("State: %d, Len: %d, Ends At: %d", state, stateLen, pos));
294 for(int i = 0; i < stateLen; ++i) {
295 ret[pos] = state;
296 pos--;
297 }
298 state = prevState;
299 }
300 Assert.a(pos == -1);
301 InferenceResult inferenceResult = new InferenceResult();
302 inferenceResult.hiddenStates = ret;
303 inferenceResult.bestScores = new double[nStates];
304 System.arraycopy(bestScore, nStates*(len-1), inferenceResult.bestScores, 0, nStates );
305 return inferenceResult;
306 }
307
308 /** Updates the stableStates list based on the current position. The stableStates list contains the
309 cost of the non-length dependent features for the duration of the lookback.
310 */
311 void updateStableBuffer(double[] nextMi) {
312 // Recycle matrices at the end
313 double[] prevState = stableStates.get(0);
314 for(int ix = 0; ix < nStates; ++ix) {
315 if(modelInfo.maxStateLengths[ix] > 1) {
316 int trans = selfTransitions[ix];
317 if(trans != -1) {
318 if (Double.isInfinite(nextMi[trans]))
319 stableVector[ix] = prevState[ix];
320 else
321 stableVector[ix] = prevState[ix] + nextMi[trans];
322 //log.info(String.format("stableState[%d] = %f = %f + %f", ix, stableState[ix], prevState[ix], nextMi[trans]));
323 }
324 }
325 }
326
327 // Now add this to the beginning.
328 stableVector = stableStates.addFirst(stableVector);
329 }
330
331 void computeSparseMi(InputSequence seq, int pos, double[] mi, double[] ri) {
332 cacheProcessor.evaluatePosition(0, pos);
333 double nodeVal = Double.NaN;
334 int invalidIndex = pos*modelInfo.nPotentials;
335 for(short potential : modelInfo.orderedPotentials) {
336 boolean invalid = invalidTransitions[invalidIndex + potential];
337 double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
338
339 // Add up all features for this potential.
340 FeatureEvaluation potEvals = evals[potential];
341 short[] indices = potEvals.index;
342 float[] vals = potEvals.value;
343 int i = 0;
344 short index = indices[i];
345 while(index != -1) {
346 // An invalid potential is indicated by a feature value of Short.MAX_VALUE
347 features += (index == Short.MIN_VALUE) ? Double.NEGATIVE_INFINITY : vals[i]*lambda[index];
348 index = indices[++i];
349 }
350
351 if(potential < modelInfo.nStates) {
352 nodeVal = features;
353 if(ri != null) {
354 ri[potential] = nodeVal;
355 }
356 }
357 else {
358 //log.debug(String.format("Mi[%d, %d] = %f, adding in %f to get %f", feat.yprev(), feat.y(), val, feat.value(), val*exp(feat.value()*param[feat.index()])));
359 int transition = potential - modelInfo.nStates;
360 double val = features + nodeVal;
361 if(mi != null)
362 mi[transition] = val;
363 }
364 }
365 }
366 }