001 package calhoun.analysis.crf.solver;
002
003 import java.io.BufferedWriter;
004 import java.util.Arrays;
005 import java.util.List;
006
007 import org.apache.commons.logging.Log;
008 import org.apache.commons.logging.LogFactory;
009
010 import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
011 import calhoun.analysis.crf.ModelManager;
012 import calhoun.analysis.crf.io.TrainingSequence;
013 import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
014 import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
015 import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
016 import calhoun.util.Assert;
017 import calhoun.util.ColtUtil;
018 import calhoun.util.FileUtil;
019
020
021 /** computes the likelihood of the true path for a semi-Markov CRF. The likelihood is normalized to a per label likelihood.
022 * <h2>Debugging output</h2>
* To get a better understanding of what the objective function is doing, several different properties can be set that
024 * cause the objective function to write out trace files showing its calculations during training. Usually when turning
025 * these options on, you should set <code>maxIters = 1</code> and <code>requireConvergence = false</code> in your optimizer
026 * to do only a single training iteration, possibly setting the starts to some predetermined value. Each of these
027 * properties can be configured with a filename and each time {@link #apply} is called, the file will be overwritten with
028 * data from the current call. The logging options are:
029 * <ul>
030 * <li> <b><code>alphaFile</code></b> - computation of alpha values for Markov states, includes all nodes and edges.
* <li> <b><code>alphaLengthFile</code></b> - computation of alpha values for semi-Markov states, includes all segments
032 * <li> <b><code>expectFile</code></b> - computation of expected values for each Markov feature
033 * <li> <b><code>expectLengthFile</code></b> - computation of expected values for each semi-Markov feature
034 * <li> <b><code>nodeMarginalFile</code></b> - computation of marginal probability of each state at each position
035 * </ul>
036 * */
037 //public class MaximumLikelihoodSemiMarkovGradient extends CleanMaximumLikelihoodSemiMarkovGradient {
038 //}
039 public class MaximumLikelihoodSemiMarkovGradient implements CRFObjectiveFunctionGradient {
040 private static final Log log = LogFactory.getLog(MaximumLikelihoodSemiMarkovGradient.class);
041 private static final boolean debug = log.isDebugEnabled();
042 private static final double ASSERTION_TOLERANCE = 0.0001;
043
044 private static final int NORM_FACTOR = 50;
045 private static final double NORM_MIN = Math.exp(-NORM_FACTOR);
046 private static final double NORM_MAX = Math.exp(NORM_FACTOR);
047
048 String alphaFile = null;
049 String alphaLengthFile = null;
050 String betaLengthFile = null;
051 String expectFile = null;
052 String expectLengthFile = null;
053 String nodeMarginalFile = null;
054 BufferedWriter alphaWriter = null;
055 BufferedWriter alphaLengthWriter = null;
056 BufferedWriter betaLengthWriter = null;
057 BufferedWriter expectWriter = null;
058 BufferedWriter expectLengthWriter = null;
059 BufferedWriter nodeMarginalWriter = null;
060
061 SolverSetup modelInfo;
062 CacheProcessor cacheProcessor;
063 FeatureEvaluation[] evals;
064 LengthFeatureEvaluation[][] lengthEvals;
065 boolean[] invalidTransitions;
066
067 // / Cache feature information
068 // / Cached value of the Mi matrix for all of the features present at every position
069 // / Mi is stored as a sparse matrix
070 short maxLookback;
071 CacheProcessor.StatePotentials[] statesWithLookback;
072 CacheProcessor.StatePotentials[] statesWithoutLookback;
073 List<TrainingSequence> data;
074 int iter = 0;
075 double[][] alphas;
076 int[] alphaNorms;
077 double[] starterAlpha;
078
079 // At any given point, lookbackBuffer.get(x) returns the information about a lookback of x. Lookbacks start at 0.
080 RecyclingBuffer<LookbackBuffer> lookbackBuffer;
081 LookbackBuffer nextBuffer;
082
083 double[] lambda;
084 double logZ;
085 int zNorm;
086 double zInv;
087 double[] expects;
088
089 AlphaLengthFeatureProcessor alphaProcessor;
090 BetaLengthFeatureProcessor betaProcessor;
091
092 // We publish feature sums
093 private double[] featureSums;
094
095 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
096 cacheProcessor.setTrainingData(fm, data);
097 modelInfo = cacheProcessor.getSolverSetup();
098 Assert.a(modelInfo.maxStateLengths != null, "Maximum state lengths not set.");
099 Assert.a(modelInfo.maxStateLengths.length == modelInfo.nStates, "Maximum state lengths array was length ("+modelInfo.maxStateLengths.length+"). Must have one entry for each state "+modelInfo.nStates+")");
100 evals = cacheProcessor.getFeatureEvaluations();
101 lengthEvals = cacheProcessor.getLengthFeatureEvaluations();
102 invalidTransitions = cacheProcessor.getInvalidTransitions();
103
104 // Create local references to cache values
105 maxLookback = modelInfo.maxLookback;
106 statesWithLookback = modelInfo.statesWithLookback;
107 statesWithoutLookback = modelInfo.statesWithoutLookback;
108
109 // Initialize betas (for use later, in the gradient computation)
110 alphas = new double[modelInfo.longestSeq][modelInfo.nStates];
111 alphaNorms = new int[modelInfo.longestSeq];
112 expects = new double[modelInfo.nFeatures];
113
114 LookbackBuffer[] bufferContents = new LookbackBuffer[maxLookback+3];
115 for(int i = 0; i<maxLookback+3; ++i) {
116 bufferContents[i] = new LookbackBuffer();
117 }
118 lookbackBuffer = new RecyclingBuffer<LookbackBuffer>(bufferContents);
119 nextBuffer = new LookbackBuffer();
120
121 alphaProcessor = new AlphaLengthFeatureProcessor();
122 betaProcessor = new BetaLengthFeatureProcessor();
123
124 starterAlpha = new double[modelInfo.nStates];
125 }
126
127 public double apply(double[] param, double[] grad) {
128 log.debug(String.format("Beginning It: %d Weights: %s", iter, ColtUtil.format(param)));
129 alphaWriter = FileUtil.safeOpen(alphaFile);
130 alphaLengthWriter = FileUtil.safeOpen(alphaLengthFile);
131 betaLengthWriter = FileUtil.safeOpen(betaLengthFile);
132 expectWriter = FileUtil.safeOpen(expectFile);
133 expectLengthWriter = FileUtil.safeOpen(expectLengthFile);
134 nodeMarginalWriter = FileUtil.safeOpen(nodeMarginalFile);
135 lambda = param;
136 Arrays.fill(grad, 0);
137 double totalZ = 0.0;
138 double result = 0.0;
139
140 try {
141 // Iterate through sequences
142 Arrays.fill(expects, 0);
143 for (int i = 0; i < modelInfo.nSeqs; ++i) {
144 int len = modelInfo.seqOffsets[i + 1] - modelInfo.seqOffsets[i];
145 // Work forwards, computing alphas
146 alphaProcessor.computeAlpha(i, len);
147
148 // Since the final beta array is all ones, we can sum the alphas to get the Z
149 double sum = 0.0;
150 for (double val : alphas[len - 1]) {
151 sum += val;
152 }
153
154 logZ = log(sum) + NORM_FACTOR * (alphaNorms[len - 1]);
155 zNorm = ((int) logZ) / NORM_FACTOR;
156 zInv = exp(zNorm * NORM_FACTOR - logZ);
157 //log.debug("Seq: "+i+" Z: "+printNorm(1/zInv, zNorm));
158
159 // Work backwards, computing betas and expectations.
160 betaProcessor.computeBetasAndExpectations(i, len);
161
162 if(log.isDebugEnabled()) {
163 double[][] seqFeatureSums = cacheProcessor.getSequenceFeatureSums();
164 if(seqFeatureSums != null) {
165 double seqResult = 0.0;
166 for (int j = 0; j < modelInfo.nFeatures; ++j) {
167 seqResult += seqFeatureSums[i][j] * param[j];
168 }
169 log.debug(String.format("Seq: %d L: %g LL: %f Training path: %f Z: %f", i, exp(seqResult-logZ), seqResult-logZ, seqResult, logZ));
170 Assert.a(exp(seqResult-logZ) < 1.0);
171 }
172 }
173
174 // Update for the next sequence
175 totalZ += logZ;
176 }
177
178 // sum_j lambda_j F_j(xk, yk)
179 double[] featureSums = cacheProcessor.getFeatureSums();
180 this.featureSums = featureSums;
181 for (int j = 0; j < modelInfo.nFeatures; ++j) {
182 result += featureSums[j] * param[j];
183 grad[j] = featureSums[j] - expects[j];
184 }
185 log.debug("Path Value: "+result+" Norm: "+totalZ);
186 result -= totalZ;
187 if (log.isInfoEnabled()) {
188 if(log.isDebugEnabled()) {
189 log.debug(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s", iter, exp(result), result,
190 ColtUtil.norm(grad), ColtUtil.format(featureSums), ColtUtil.format(expects), ColtUtil.format(param), ColtUtil.format(grad)));
191 }
192 else {
193 log.info(String.format("It: %d LL=%f, norm(grad): %f", iter, exp(result), result, ColtUtil.norm(grad)));
194 }
195 }
196 Assert.a(exp(result) <= 1.0, "Likelihood is greater than 1.");
197 result = result/modelInfo.totalPositions;
198 for(int i=0; i<grad.length; ++i) {
199 grad[i] = grad[i]/modelInfo.totalPositions;
200 }
201 iter += 1;
202 }
203 finally {
204 FileUtil.safeClose(alphaWriter);
205 FileUtil.safeClose(alphaLengthWriter);
206 FileUtil.safeClose(betaLengthWriter);
207 FileUtil.safeClose(expectWriter);
208 FileUtil.safeClose(expectLengthWriter);
209 FileUtil.safeClose(nodeMarginalWriter);
210 }
211
212 return result;
213 }
214
215 public void clean() {
216 // Clean up as much as possible
217 modelInfo = null;
218 cacheProcessor = null;
219 evals = null;
220 lengthEvals = null;
221 invalidTransitions = null;
222 statesWithLookback = null;
223 statesWithoutLookback = null;
224 data = null;
225 alphas = null;
226 alphaNorms = null;
227 starterAlpha = null;
228 lookbackBuffer = null;
229 nextBuffer = null;
230 expects = null;
231 alphaProcessor = null;
232 betaProcessor = null;
233 }
234
235 private final class AlphaLengthFeatureProcessor {
236 int seqOffset;
237 int pos;
238 double[] alpha;
239 int alphaNorm;
240 double[] stableState;
241
242 /**
243 * In the forward pass we compute alpha values and expections. This is simpler than the backwards pass because
244 * the cache is set up for us so that we can always look at one position at a time. We have to cache previous
245 * values but we never have to look ahead.
246 */
247 final void computeAlpha(final int seqNum, final int len) {
248 // Result the alpha norms
249 Arrays.fill(alphaNorms, Integer.MIN_VALUE);
250
251 Arrays.fill(starterAlpha, 0.0);
252 double[] prevAlpha = null;
253
254 seqOffset = modelInfo.seqOffsets[seqNum];
255
256 for(pos = 0; pos < len; ++pos) {
257 prevAlpha = alpha;
258 alpha = alphas[pos];
259 Arrays.fill(alpha, 0.0);
260 if (pos == 0) {
261 alphaNorm = 0;
262 calcStartAlpha(alpha, seqNum);
263
264 // Put an empty entry in the lookback so the first base has 0's initialized.
265 Arrays.fill(nextBuffer.stableState, 0.0);
266 } else {
267 cacheMi(seqNum, nextBuffer.mi, stableState, nextBuffer.stableState, pos);
268 regularAlphaUpdate(pos, nextBuffer.mi, prevAlpha, alpha);
269 }
270
271 // Add the lookback into the array
272 stableState = nextBuffer.stableState;
273 nextBuffer = lookbackBuffer.addFirst(nextBuffer);
274
275 // Now we need to loop through for the length dependent cache
276 lengthAlpha(seqNum, pos);
277
278 int norm = normalize(alpha);
279 /*if(norm != 0) {
280 log.info("Pos: "+pos+" Renormalized alpha by "+norm+" to "+(alphaNorm+norm)+" : "+ColtUtil.format(alpha));
281 }*/
282 alphaNorm += norm;
283 alphaNorms[pos] = alphaNorm;
284 // if(alphaNorm > 0 && pos > 0 && alphaNorms[pos-1]<=0)
285 // log.info("Norm prob at pos: "+pos);
286 }
287 }
288
289 /**
290 * Updates the alpha vector for non-length dependent states. We don't have to worry about normalization here
291 * because the regular alpha update is done before the length dependent, so these are the first values that will
292 * be set.
293 */
294 private final void regularAlphaUpdate(final int pos, final double[] mi, final double[] lastAlpha, final double[] newAlpha) {
295 double nodeVal = 0.0;
296 int lastState = -1;
297 boolean lengthNode = false;
298 for (short potential : modelInfo.orderedPotentials) {
299 if (potential < modelInfo.nStates) {
300 if (lastState != -1) {
301 newAlpha[lastState] = nodeVal;
302 }
303 lastState = potential;
304 nodeVal = 0.0;
305 lengthNode = modelInfo.maxStateLengths[potential] > 1;
306 } else {
307 if (!lengthNode) {
308 int trans = potential - modelInfo.nStates;
309 double transVal = mi[trans];
310 if(!Double.isInfinite(transVal)) {
311 int from = modelInfo.transitionFrom[trans];
312 if(alphaWriter != null)
313 FileUtil.safeWrite(alphaWriter, String.format("alpha[%d][%d] = %s = %s + alpha[%d][%d] %s * %s exp(%f)\n", pos, lastState, printNorm(nodeVal + lastAlpha[from] * exp(mi[trans]), alphaNorm), printNorm(nodeVal, alphaNorm), pos-1, from, printNorm(lastAlpha[from], alphaNorm), printNorm(exp(mi[trans]), 0), mi[trans]));
314 nodeVal += lastAlpha[from] * exp(transVal);
315 }
316 }
317 }
318 }
319 newAlpha[lastState] = nodeVal;
320 }
321
322 /** Updates an alpha entry with a weighted sum of features values for a given potential */
323 private final void lengthAlpha(final int seqNum, final int pos) {
324 cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);
325 /*
326 * Updates an existing alpha by adding in: potentialValue - The value of any length-dependent features for
327 * this node f(y, i, d) and edge f(y', y, i, d) stableValue - The value of the non-length dependent node
328 * features summed across the length of this segment mis - The value of the non-length dependent transition
329 * from the previous node to this one f(y', y, i-d)
330 */
331 int nSemiMarkovStates = modelInfo.statesWithLookback.length;
332 for(int i=0; i<nSemiMarkovStates; ++i) {
333 LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
334 CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
335 byte toNode = statePotentials.state;
336
337 int lbIndex=0;
338 LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
339 int lookback = lengthEval.lookback;
340 while(lookback != -1) {
341 //log.info("Pos: "+pos+"\t State: "+modelInfo.statesWithLookback[i].state+"\t Lookback: "+lookback);
342 int prevPos = pos - lookback - 1;
343 // For speed I hand inline RecyclingBuffer.get
344 LookbackBuffer buffer = lookbackBuffer.array[(lookbackBuffer.currentStart+lookback)%lookbackBuffer.length];
345
346 // Handle evaluation of the node potentials
347 FeatureEvaluation nodeEvals = lengthEval.nodeEval;
348 short[] indices = nodeEvals.index;
349 float[] vals = nodeEvals.value;
350 int ix = 0;
351 short index = indices[ix];
352 double stableValue = stableState[toNode] - buffer.stableState[toNode];
353 double nodePotential = stableValue;
354 while(index >= 0) {
355 nodePotential += vals[ix] * lambda[index];
356 index = indices[++ix];
357 }
358 if(debug)
359 Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid");
360
361 if(prevPos < 0) {
362 double nodeVal = nodePotential + starterAlpha[toNode];
363 // If this is the first segment, then we don't worry about edges and handle the node directly.
364 int norm = ((int) nodeVal) / NORM_FACTOR;
365 nodeVal -= norm * NORM_FACTOR;
366
367 if (norm > alphaNorm) {
368 renormalize(alpha, alphaNorm, norm);
369 //log.info("Renormalized alpha: "+ColtUtil.format(alpha));
370 alphaNorm = norm;
371 } else if (norm < alphaNorm) {
372 nodeVal += NORM_FACTOR * (norm - alphaNorm);
373 }
374 if(alphaLengthWriter != null) {
375 FileUtil.safeWrite(alphaLengthWriter, String.format("seq: %d alpha[%d][%d] = %s = %s + %s (Pot: %f Starter: %f)\n", seqNum, pos, toNode, printNorm(alpha[toNode] + exp(nodeVal), alphaNorm), printNorm(alpha[toNode], alphaNorm), printNorm(exp(nodeVal), alphaNorm), nodePotential, starterAlpha[toNode]));
376 }
377 /*if((pos == 499 && toNode == 0))
378 log.info(String.format("alpha[%d][%d] = %s = %s + %s (Pot: %f Starter: %f)", pos, toNode, printNorm(alpha[toNode] + exp(nodeVal), alphaNorm), printNorm(alpha[toNode], alphaNorm), printNorm(exp(nodeVal), alphaNorm), nodePotential, starterAlpha[toNode]));
379 */
380 alpha[toNode] += exp(nodeVal);
381 }
382 else {
383 // If this is not the first segment, we need to deal with edges coming into this segment
384 FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
385 int nEdges = statePotentials.potentials.length;
386 for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
387 int potential = statePotentials.potentials[edgeIx];
388 int trans = potential - modelInfo.nStates;
389 int fromNode = modelInfo.transitionFrom[trans];
390 // Skip semi-Markov self transitions
391 if(fromNode == toNode)
392 continue;
393
394 double edgeVal = 0.0;
395
396 if(edgeEvals == null) {
397 // If the cache processor does not have edge evaluations
398 // Just check if this transition is legal based on the invalid transitions matrix
399 int invalidIndex = (seqOffset+prevPos+1)*modelInfo.nPotentials;
400 if(invalidTransitions[invalidIndex + potential]) {
401 //log.info("Illegal transition: "+fromNode+"-"+toNode+" at pos: "+prevPos);
402 continue;
403 }
404 }
405 else {
406 // If the cache processor does have edge evaluations, then ignore the illegal transitions matrix
407 // and update the expval using the edge evaluations
408 FeatureEvaluation potEvals = edgeEvals[edgeIx];
409 indices = potEvals.index;
410 vals = potEvals.value;
411 ix = 0;
412 index = indices[i];
413 while(index >= 0) {
414 edgeVal += vals[ix] * lambda[index];
415 index = indices[++ix];
416 }
417 if(index == Short.MIN_VALUE) {
418 continue;
419 }
420 }
421
422 // Renormalize and update the exp value.
423 double expVal = edgeVal + buffer.mi[trans] + nodePotential;
424 int expNorm = ((int) expVal)/NORM_FACTOR;
425 expVal -= expNorm*NORM_FACTOR;
426
427 int prevNorm = alphaNorms[prevPos];
428 int updateNorm = expNorm + prevNorm;
429 if(updateNorm > alphaNorm) {
430 // Our updated value is larger than the existing alpha value, renormalize that alpha vector.
431 renormalize(alpha, alphaNorm, updateNorm);
432 //log.info("Renormalized alpha from "+alphaNorm+" to "+updateNorm +" : "+ColtUtil.format(alpha));
433 alphaNorm = updateNorm;
434 }
435 else if(alphaNorm > updateNorm) {
436 // Renormalize the expVal
437 int expShift = alphaNorm - updateNorm;
438 //log.info(String.format("Renormalize feature by %d from %d to %d",expShift, expNorm, expNorm+expShift));
439 expNorm += expShift;
440 expVal -= expShift*NORM_FACTOR;
441 }
442
443 double prevAlpha = alphas[prevPos][fromNode];
444 double update = exp(expVal) * prevAlpha;
445 if(alphaLengthWriter != null) {
446 FileUtil.safeWrite(alphaLengthWriter, String.format("seq: %d alpha[%d][%d] = %s = %s + %s (alpha[%d][%d]) * %s exp(EdgeLength: %f NodeLength: %f Edge: %f Node: %f )\n",
447 seqNum, pos, toNode, printNorm(alpha[toNode] + update, alphaNorm), printNorm(alpha[toNode], alphaNorm), printNorm(prevAlpha, alphaNorms[prevPos]),
448 prevPos, modelInfo.transitionFrom[trans], printNorm(exp(expVal), expNorm), edgeVal, nodePotential - stableValue, buffer.mi[trans], stableValue));
449 }
450
451 alpha[toNode] += update;
452 // Expensive assertion that can catch some normalization problems
453 if(debug)
454 Assert.a(expNorm + prevNorm == alphaNorm, "Norm problem. Exp: ", expNorm, " Prev alpha: ", prevNorm, " Alpha: ", alphaNorm);
455 }
456 }
457
458 ++lbIndex;
459 lengthEval = lookbacksForState[lbIndex];
460 lookback = lengthEval.lookback;
461 }
462 }
463 }
464
465 /**
466 * A specialized version of calcMi for the first position in a sequence. Has the special property that constant
467 * edge features are not included.
468 */
469 void calcStartAlpha(double[] currentAlpha, int seq) {
470 cacheProcessor.evaluatePosition(seq, 0);
471 int invalidIndex = seqOffset*modelInfo.nPotentials;
472 for(short potential : modelInfo.orderedPotentials) {
473 if(potential < modelInfo.nStates) {
474 boolean invalid = invalidTransitions[invalidIndex + potential];
475 double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
476
477 // Add up all features for this potential.
478 FeatureEvaluation potEvals = evals[potential];
479 short[] indices = potEvals.index;
480 float[] vals = potEvals.value;
481 int i = 0;
482 short index = indices[i];
483 while(index != -1) {
484 if(index == Short.MIN_VALUE) {
485 features = Double.NEGATIVE_INFINITY;
486 break;
487 }
488 features += vals[i]*lambda[index];
489 index = indices[++i];
490 }
491 if(modelInfo.maxStateLengths[potential]> 1) {
492 starterAlpha[potential] = features;
493 }
494 else {
495 currentAlpha[potential] = exp(features);
496 }
497 }
498 }
499 }
500 }
501
502 /** Updates a beta entry with a potential from a length dependent feature */
503 class BetaLengthFeatureProcessor {
504 int seqOffset;
505 int lengthPos;
506 double[] lengthStable;
507 int miPos;
508 double[] stableState;
509
510 double[] beta;
511 int betaNorm;
512
513 double prob;
514
515 LookbackBuffer posLookback;
516 LookbackBuffer prevLookback;
517
518 // This vector tracks the marginal probability of each node, accounting for length dependence
519 double[] nodeProb = new double[modelInfo.nStates];
520 double[] newNodeProb = new double[modelInfo.nStates];
521
522 // This vector tracks the marginal probability of each edge, accounting for length dependence
523 double[] edgeProb = new double[modelInfo.nTransitions];
524
525 /**
526 * Computing the betas is a bit tricky. The problem is that our cache associates each explicit
527 * length feature with the position at which it ends. For the betas, we are going backwards, and so we need to
528 * look at the positions where the length based feature starts.
529 *
530 * The way we handle this is by evaluating the length based features as they arise and incrementally adding
531 * their potential values into the beta positions as we go along. Therefore we aren't always filling one beta
532 * matrix at a time, we fill any beta entry that starts a state which ends the current position as we move
533 * backwards through the sequence.
534 *
535 * In order to do this we need to keep mi matrices and arrays of the potentials of staying in a stable state for
536 * the entire lookback period. These matrices let us compute betas anywhere we want.
537 *
538 * Unfortunately this process needs more exponentiation than the existing process.
539 *
540 * Normalization works as follows: First, we need to find the right normalization factor for the value we are
541 * adding. 1) We add all of the exponents we will need. 2) We integer divide by NORM_FACTOR to get the
542 * normalization constant for our beta update 3) The constant*NORM_FACTOR is subtracted from the total exponent
543 * to get the unnormalized exponent 4) The unnormalized exponent is exponentiated and multiplied by the
544 * following beta value 5) We compare the norm constant to the current constant for the beta we are about to
545 * update 5.1) If they are equal we are fine, the values can be added 5.2) If Our new norm constant has larger
546 * abs(), we renormalize the beta vector by dividing all elements by e^(NORM_FACTOR*diff) 5.3) If our new norm
547 * constant has smaller abs(), we renormalize it but dividing by e^(FACTOR*diff) 6) When we finally update the
548 * beta vector at the end we redo the normalization anyway.
549 *
550 * To optimize this for fastest execution it ends up being one big function. Putting function calls in the inner
551 * loops really slowed things down. I think the optimizer can't do as good a job if there are function calls
552 * there.
553 */
// Backward pass for one sequence: fills beta values lazily (a segment's beta is
// updated at the position where the segment *starts*) and accumulates node/edge
// marginals and feature expectations. See the normalization notes in the javadoc above.
void computeBetasAndExpectations(int seqNum, int len) {
    // Start by initializing an array of unexponentiated mi and stable state matrices for all of the possible
    // lookback positions from the last base
    // Also initialize the array of previous beta vectors
    seqOffset = modelInfo.seqOffsets[seqNum];

    // lastInitPos holds the position of the leftmost position which gets evaluated at the start of the run.
    int lastInitPos = len - 2 - maxLookback;

    // posLookback holds the lookback buffer for the position currently specified by pos
    posLookback = nextBuffer;
    // prevLookback holds the lookback buffer for the position pos+1
    prevLookback = null;

    Arrays.fill(nodeProb, 0.0);
    Arrays.fill(newNodeProb, 0.0);

    // Now work through the sequence backwards
    miPos = len-1;
    nextBuffer.clear();
    for (int pos = len - 1; pos >= 0; --pos) {
        // First, update the lookback, which caches mi and stable values if necessary
        // (extends the cached window leftwards so that every lookback from `pos` is available).
        while(miPos >= 0 && miPos >= lastInitPos) {
            if(miPos == len-1) {
                // Initialize stable states
                // Need to prime with a stable state of 0 since we don't have the init that we do for alphas.
                Arrays.fill(nextBuffer.stableState, 0);
            }
            else {
                // Update stable states given the previous transition
                nextBuffer.clear();
                cacheMi(seqNum, nextBuffer.mi, stableState, nextBuffer.stableState, miPos+1);
            }
            nextBuffer.pos = miPos;
            stableState = nextBuffer.stableState;
            nextBuffer = lookbackBuffer.addFirst(nextBuffer);
            nextBuffer.clear();
            --miPos;
        }

        // At this point, miPos contains the leftmost position for which markov features have been evaluated.
        // The lookback buffer at this position has a stableState vector for each length dependent state and
        // an mi matrix computed for that state.

        // Retrieve the lookback information for the current position. The lookback buffer runs left to right, so the
        // current position is not at the beginning.
        posLookback = lookbackBuffer.get(pos - miPos -1);

        // Check that posLookback is correct
        if(debug)
            Assert.a(posLookback.pos == pos, "Wrong lookback buffer: was ", posLookback.pos, " should be ", pos);

        // Update betas (if necessary)
        if(prevLookback == null) {
            // For the last position in the sequence (first iteration in the for loop)
            if(debug)
                Assert.a(pos == len-1);

            // Betas for the last position are all 1. exp(0)
            Arrays.fill(posLookback.beta, 1);
            posLookback.betaNorm = 0;

            /* Initialize node marginals. Since all segments end at the last position, we can compute
            this by multiplying the alpha and beta vectors and dividing by Z. The beta vector is all 1's though
            so we just divide alpha by z */
            double nodeNorm = exp((alphaNorms[pos] - zNorm) * NORM_FACTOR) * zInv;
            for(int i=0; i<modelInfo.nStates; ++i) {
                nodeProb[i] = nodeNorm * alphas[pos][i];
            }
            log.debug("Node marginals at seq "+seqNum+" last position ("+pos+"): "+ColtUtil.format(nodeProb));
        }
        else {
            // For all but the last position, we update beta first for the markov states and then the semi-markov states.
            if(debug)
                Assert.a(prevLookback.pos == pos+1);
            posLookback.betaNorm = regularBetaUpdate(pos+1, posLookback.beta, posLookback.betaNorm, prevLookback.beta, prevLookback.betaNorm, prevLookback.transitionProb, posLookback.mi);

            // At this point, the nodeProb and edgeProb vectors contain the marginal probabilities of markov nodes and edges into them.
            // The beta is fully updated for the markov states.
            // Semi markov has not been done.

            // Now we need to loop through for the length dependent cache
            beta = prevLookback.beta;
            /*for(int ix=0; ix<beta.length; ++ix) {
                Assert.a(!Double.isNaN(beta[ix]), prevLookback.pos+ " "+ ix);
            }*/
            betaNorm = prevLookback.betaNorm;
            lengthStable = prevLookback.stableState;
            lengthPos = pos + 1;
            lengthBeta(seqNum, lengthPos);

            System.arraycopy(nodeProb, 0, newNodeProb, 0, modelInfo.nStates);

            // Now calculate edge marginals for staying in the same state
            // Take the node probabilities, and subtract off each transition
            for(CacheProcessor.StatePotentials lb: statesWithLookback) {
                int state = lb.state;
                int index = modelInfo.selfTransitions[state];
                double transProb = nodeProb[state];
                for(byte pot : lb.potentials) {
                    double lbTrans = posLookback.transitionProb[pot - modelInfo.nStates];
                    transProb -= lbTrans;
                    newNodeProb[state] -= lbTrans;
                }
                Assert.a(posLookback.transitionProb[index] == 0.0);
                posLookback.transitionProb[index] = transProb;
            }
        }

        // As a check, we verify that the node marginals sum to one for each position.
        double sum =0.0;
        for(double x : nodeProb) {
            if(x > 1.0+ASSERTION_TOLERANCE || x < -ASSERTION_TOLERANCE)
                Assert.a(false, "Iter ",iter," Pos: "+pos+" Node marginals not valid "+x);
            sum += x;
        }
        if(Math.abs(1.0-sum) > ASSERTION_TOLERANCE) {
            Assert.a(false, "Iter ",iter," Pos: "+pos+" Node marginals sum to "+sum+" not 1: ", ColtUtil.format(nodeProb), " at ", seqNum," ",pos);
        }

        // Verify that edge marginals sum to 1.
        if(prevLookback != null) {
            if(debug) {
                sum =0.0;
                for(double x : edgeProb) {
                    if(x > 1+ASSERTION_TOLERANCE || x < -ASSERTION_TOLERANCE)
                        Assert.a(false, "Iter ",iter," Pos: "+pos+" Edge marginal not valid "+x);
                    sum += x;
                }
                for(double x : posLookback.transitionProb) {
                    if(x > 1+ASSERTION_TOLERANCE || x < -ASSERTION_TOLERANCE)
                        Assert.a(false, "Iter ",iter," Pos: "+pos+" Self-trans marginal not valid "+x);
                    sum += x;
                }
                /*
                for(int i=0; i<modelInfo.nStates; ++i) {
                    log.debug("Nod "+i+" "+nodeProb[i]);
                }*/
                // NOTE(review): this check uses a hard-coded 0.001 tolerance rather than
                // ASSERTION_TOLERANCE (0.0001) — possibly intentional slack; confirm.
                if(Math.abs(1.0-sum) > 0.001) {
                    /*for(int i = 0; i < posLookback.transitionProb.length; ++i) {
                        if(posLookback.transitionProb[i] != 0)
                            log.debug("Seg "+modelInfo.transitionFrom[i]+"-"+modelInfo.transitionTo[i]+" "+posLookback.transitionProb[i]);
                    }
                    for(int i = 0; i < edgeProb.length; ++i) {
                        if(edgeProb[i] != 0)
                            log.debug("Edg "+modelInfo.transitionFrom[i]+"-"+modelInfo.transitionTo[i]+" "+edgeProb[i]);
                    }*/
                    Assert.a(false, "Edge marginals don't sum to 1. Sum to: ", sum, " - ", ColtUtil.format(edgeProb), ColtUtil.format(posLookback.transitionProb));
                }
            }

            // At this point the previous beta values are all updated.
            // Update expectations for the transitions we just calculated.
            updateExpectations(seqNum, pos+1, posLookback.transitionProb);
        }

        // Update the node probabilities to remove the length transitions
        // (swap the double-buffered marginal vectors for the next iteration).
        double[] temp = nodeProb;
        nodeProb = newNodeProb;
        newNodeProb = temp;

        prevLookback = posLookback;
        --lastInitPos;
    }
    // Now update for the first position
    posLookback.betaNorm = regularBetaUpdate(0, null, posLookback.betaNorm, posLookback.beta, posLookback.betaNorm, null, null);
    beta = posLookback.beta;
    betaNorm = posLookback.betaNorm;
    lengthStable = posLookback.stableState;
    lengthPos = 0;
    lengthBeta(seqNum, lengthPos);
    updateExpectations(seqNum, 0, posLookback.transitionProb);
}
738
739 /**
740 * Does the update of the beta values for all of the regular, non-length dependent states. Also calculates all
741 * of the node and edge probabilities for the non-length nodes and the transitions into them.
742 */
743 private int regularBetaUpdate(int pos, double[] newBeta, int newNorm, double[] oldBeta, int oldNorm, double[] transitionProb, double[] mi) {
744 // Need to deal with possibly different normalization constants.
745 int norm = newNorm;
746 double normAdjust = 0.0;
747 if (oldNorm > newNorm) {
748 // We never make the constant smaller, so set the constant for the new factor
749 //log.info(String.format("Renormalizing beta[%d] from %d to %d", pos, newNorm, oldNorm));
750 renormalize(newBeta, newNorm, oldNorm);
751 norm = oldNorm;
752 newNorm = oldNorm;
753 } else {
754 // The case where beta(pos-1) already has a normalization constant larger than beta(pos)
755 // We need to adjust all of the updated beta values.
756 normAdjust = (oldNorm - newNorm) * NORM_FACTOR;
757 }
758
759 //log.info(String.format("Node norm pos: %d e^(alpha: %d + beta: %d - z: %d) * zinv: %f ", pos, alphaNorms[pos], norm, zNorm, zInv));
760 double[] nodeAlpha = alphas[pos];
761 double nodeNorm = exp((alphaNorms[pos] + oldNorm - zNorm) * NORM_FACTOR) * zInv;
762 double[] edgeAlpha = null;
763 double edgeNorm = Double.NaN;
764 if(pos > 0) {
765 edgeAlpha = alphas[pos-1];
766 // We add newNorm here because it ends up cancelling with normAdjust when we calc the edge prob.
767 edgeNorm = exp((alphaNorms[pos-1] + newNorm - zNorm) * NORM_FACTOR) * zInv;
768 }
769
770 for(CacheProcessor.StatePotentials potentials : statesWithoutLookback) {
771 byte node = potentials.state;
772 double nodePotential = 0.0;
773 double betaVal = oldBeta[node];
774 if(nodeMarginalWriter != null) {
775 FileUtil.safeWrite(nodeMarginalWriter, String.format("NodeMarg[%d][%d] = %f = %f * %f * %f (aN: %d bN: %d zN: %d 1/z: %f)\n", pos, node, nodeAlpha[node] * betaVal * nodeNorm, nodeAlpha[node], betaVal, nodeNorm, alphaNorms[pos], oldNorm, zNorm, zInv));
776 }
777 nodeProb[node] = nodeAlpha[node] * betaVal * nodeNorm;
778
779 // For regular states, we sum edge probabilities to get node probabilities.
780 if(pos > 0) {
781 for (short potential : potentials.potentials) {
782 int trans = potential - modelInfo.nStates;
783 double transVal = mi[trans];
784 if(!Double.isInfinite(transVal)) {
785 // Mi for beta is not exponentiated, so we do it here.
786 double potentialValue = exp(mi[trans] + normAdjust);
787 nodePotential += potentialValue;
788
789 int from = modelInfo.transitionFrom[trans];
790 /*if(pos > 2400) {
791 log.info(String.format("Beta[%d][%d] = %f = %f + Beta[%d][%d] %f * %f (Mi: %f)", pos, from, newBeta[from] + potentialValue * betaVal, newBeta[from], pos+1, node, betaVal, potentialValue, mi[trans]));
792 }*/
793 newBeta[from] += potentialValue * betaVal;
794 edgeProb[trans] = edgeAlpha[from] * potentialValue * betaVal * edgeNorm;
795 }
796 else {
797 edgeProb[trans] = 0.0;
798 }
799 }
800 }
801 }
802
803 // Now check to see if this new vector needs normalization again.
804 //log.debug("pos "+pos);
805 int ret = norm;
806 if(newBeta != null) {
807 try {
808 ret += normalize(newBeta);
809 }
810 catch(RuntimeException ex) {
811 log.warn("Normalization problem at " + pos + " "+ColtUtil.format(newBeta));
812 throw ex;
813 }
814 }
815 return ret;
816 }
817
818 /**
819 * Updates an existing beta by adding in: potentialValue - The value of any length-dependent features for this
820 * node f(y, i, d) and edge f(y', y, i, d) stableValue - The value of the non-length dependent node features
821 * summed across the length of this segment mis - The value of the non-length dependent transition from the
822 * previous node to this one f(y', y, i-d)
823 *
824 * In addition to computing the betas for length dependent features, it updates the probabilities vectors needed
825 * for the feature expectations. These are the marginal at each position or each edge.
826 */
827 private void lengthBeta(int seqNum, int pos) {
828 cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);
829
830 int nSemiMarkovStates = modelInfo.statesWithLookback.length;
831 for(int i=0; i<nSemiMarkovStates; ++i) {
832 LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
833 CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
834 byte toNode = statePotentials.state;
835
836 int lbArrayIndex=0;
837 LengthFeatureEvaluation lengthEval = lookbacksForState[lbArrayIndex];
838 int lookback = lengthEval.lookback;
839 while(lookback != -1) {
840 int prevPos = lengthPos - lookback - 1;
841 int lbIndex = prevPos - miPos - 1;
842
843 //log.info("Pos: "+pos+"\t State: "+modelInfo.statesWithLookback[i].state+"\t Lookback: "+lookback);
844 LookbackBuffer segBegin = null;
845 if(prevPos >= 0) {
846 // For speed I hand inline RecyclingBuffer.get
847 segBegin = lookbackBuffer.array[(lookbackBuffer.currentStart+lbIndex)%lookbackBuffer.length];
848 }
849 LookbackBuffer stableBuffer = lookbackBuffer.array[(lookbackBuffer.currentStart+lbIndex+1)%lookbackBuffer.length];
850
851 // Handle evaluation of the node potentials
852 double stableValue = stableBuffer.stableState[toNode] - lengthStable[toNode];
853 double nodePotential = stableValue;
854
855 FeatureEvaluation nodeEvals = lengthEval.nodeEval;
856 short[] indices = nodeEvals.index;
857 float[] vals = nodeEvals.value;
858 int ix = 0;
859 short index = indices[ix];
860 while(index >= 0) {
861 nodePotential += vals[ix] * lambda[index];
862 index = indices[++ix];
863 }
864 if(debug)
865 Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid. They can be invalid because a node is invalid or a self-transition edge is invalid.");
866
867 if(prevPos < 0) {
868 // If this is the first segment, then we don't worry about edges and handle the node directly.
869 double expVal = nodePotential + starterAlpha[toNode];
870 lengthBetaHandling(seqNum, prevPos, pos, expVal, -1, toNode, 1.0, 0, nodeEvals);
871 }
872 else {
873 // If this is not the first segment, we need to deal with edges
874 FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
875 int nEdges = statePotentials.potentials.length;
876 for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
877 int potential = statePotentials.potentials[edgeIx];
878 int trans = potential - modelInfo.nStates;
879 int fromNode = modelInfo.transitionFrom[trans];
880 // Skip semi-Markov self transitions
881 if(fromNode == toNode)
882 continue;
883
884 double edgeVal = 0.0;
885
886 if(edgeEvals == null) {
887 // If the cache processor does not have edge evaluations
888 // Just check if this transition is legal based on the invalid transitions matrix
889 int invalidIndex = (seqOffset + prevPos+1)*modelInfo.nPotentials;
890 if(invalidTransitions[invalidIndex + potential])
891 continue;
892 }
893 else {
894 // If the cache processor does have edge evaluations, then ignore the illegal transitions matrix
895 // and update the expval using the edge evaluations
896 FeatureEvaluation potEvals = edgeEvals[edgeIx];
897 indices = potEvals.index;
898 vals = potEvals.value;
899 ix = 0;
900 index = indices[i];
901 while(index >= 0) {
902 edgeVal += vals[ix] * lambda[index];
903 index = indices[++ix];
904 }
905 if(index == Short.MIN_VALUE) {
906 continue;
907 }
908 }
909
910 //log.debug("Stab: "+ stableBuffer.stableState[toNode]+" - " + lengthStable[toNode]);
911 if(debug) {
912 Assert.a(prevLookback.pos == lengthPos, "Expected ",lengthPos, " was ",prevLookback.pos);
913 Assert.a(segBegin.pos == (lengthPos-lookback-1), "Expected ",(lengthPos-lookback-1), " was ",segBegin.pos);
914 }
915 double expVal = edgeVal + segBegin.mi[trans] + nodePotential;
916 double prevAlpha = alphas[prevPos][fromNode];
917 int prevAlphaNorm = alphaNorms[prevPos];
918 //log.debug("mi: "+buffer.mi[trans]+" s: "+nodePotential+" Base: "+(expVal - buffer.mi[trans] - nodePotential));
919 int expNorm = lengthBetaHandling(seqNum, prevPos, pos, expVal, fromNode, toNode, prevAlpha, prevAlphaNorm, nodeEvals);
920 expVal -= expNorm * NORM_FACTOR;
921
922 // To determine the normalization value, we compare the existing beta value to the value we are about to add
923 // Whichever is a larger number will dominate, and so we use that normalization value as the new value and ignore the
924 // previous one.
925
926 // Update the beta values
927 int updateNorm = expNorm + betaNorm;
928 if(updateNorm > segBegin.betaNorm) {
929 //log.info(String.format("Renormalize beta from %d to %d", segBegin.betaNorm, updateNorm));
930 renormalize(segBegin.beta, segBegin.betaNorm, updateNorm);
931 segBegin.betaNorm = updateNorm;
932 }
933 else if(segBegin.betaNorm > updateNorm) {
934 int expShift = updateNorm - segBegin.betaNorm;
935 //log.info(String.format("Renormalize feature from %d to %d",expNorm, expNorm+expShift));
936 expNorm += expShift;
937 expVal += expShift*NORM_FACTOR;
938 }
939 double transPotential = exp(expVal);
940 double update = transPotential * beta[toNode];
941 segBegin.beta[fromNode] += update;
942
943 // Update expectations
944 if(edgeEvals != null) {
945 FeatureEvaluation potEvals = edgeEvals[edgeIx];
946 indices = potEvals.index;
947 vals = potEvals.value;
948 ix = 0;
949 index = indices[i];
950 while(index != -1) {
951 if(expectLengthWriter != null)
952 FileUtil.safeWrite(expectLengthWriter, String.format("Seq %d Pos %d-%d Expect #%d: %e = %e + Prob: %e * EdgeVal: %e\n", seqNum, prevPos, pos, index, expects[index]+prob*vals[i], expects[index], prob, vals[i]));
953 expects[index] += prob * vals[ix];
954 index = indices[++ix];
955 }
956 }
957
958 // Updates the transition probabilities
959 //log.info(String.format("EdgeMarg[%d][%d] = %f = %f + %f", prevPos, toNode, buffer.transitionProb[trans] + prob, buffer.transitionProb[trans], prob));
960 segBegin.transitionProb[trans] += prob;
961
962 if(betaLengthWriter != null) {
963 FileUtil.safeWrite(betaLengthWriter, String.format(String.format("Beta[%d][%d] = %s = %s + %s beta[%d][%d] * %s exp(Edge: %f Node: %f Stable: %f Trans: %f)\n",
964 prevPos, fromNode, printNorm(segBegin.beta[fromNode], segBegin.betaNorm), printNorm(segBegin.beta[fromNode]-update, segBegin.betaNorm), printNorm(beta[toNode], betaNorm), lengthPos, toNode, printNorm(transPotential, expNorm),
965 edgeVal, nodePotential - stableValue, stableValue, segBegin.mi[trans])));
966 }
967 }
968 }
969
970 ++lbArrayIndex;
971 lengthEval = lookbacksForState[lbArrayIndex];
972 lookback = lengthEval.lookback;
973 }
974 }
975 }
976
/**
 * Computes the probability of a single length-dependent segment and folds it into the node marginals and
 * the expectations of the segment's node features.
 *
 * Side channel: sets the {@code prob} field to the computed segment probability so the caller can reuse it
 * for edge expectations and transition probabilities.
 *
 * @param seqNum sequence index (used only for trace/assert messages)
 * @param prevPos position just before the segment starts (may be -1 for a segment at the sequence start)
 * @param pos position where the segment ends (trace output only)
 * @param expVal unexponentiated potential of the segment (edge + mi + node contributions)
 * @param fromNode source state of the transition, or -1 for a first segment (trace output only)
 * @param toNode the semi-Markov state occupying the segment
 * @param prevAlpha alpha value at prevPos for fromNode (1.0 for a first segment)
 * @param prevAlphaNorm normalization constant of prevAlpha
 * @param nodeEvals sparse (index, value) list of length-dependent node features, -1 terminated
 * @return the normalization constant extracted from expVal; the caller must subtract norm*NORM_FACTOR
 *         from its own copy of expVal
 */
int lengthBetaHandling(int seqNum, int prevPos, int pos, double expVal, int fromNode, int toNode, double prevAlpha, int prevAlphaNorm, FeatureEvaluation nodeEvals) {
	// Pull an integer normalization constant out of the exponent so exp() stays in range.
	// Note: integer division truncates toward zero, so small negative exponents map to norm 0.
	int norm = ((int) expVal) / NORM_FACTOR;
	expVal -= norm * NORM_FACTOR;

	// In addition to updating the beta array, we need to calculate a probability for this segment so we can
	// correctly calculate feature expectations
	double afterExp = exp(expVal + NORM_FACTOR * (prevAlphaNorm + norm + betaNorm - zNorm));
	// This if statement is here in case alpha or beta is 0, but the normalization is large. This could cause the exp value to go to infinity and result in a NaN probability instead of 0.
	if(prevAlpha == 0 || beta[toNode] == 0)
		prob = 0.0;
	else
		prob = prevAlpha * beta[toNode] * afterExp * zInv;
	if(Double.isNaN(prob)) {
		log.info(String.format("NaN = Alpha: %s * Beta: %s * Seg: %s / Z: %s",printNorm(prevAlpha, prevAlphaNorm), printNorm(beta[toNode], betaNorm), printNorm(exp(expVal), norm), printNorm(1/zInv, zNorm)));
		Assert.a(false, String.format("Seq: %d Pos: %d-%d: Bad prob (NaN) = Alpha: %e * Beta[%d] %e * %e exp(%f Norm a:%d n:%d b:%d z:%d) * %e",
				seqNum, prevPos, pos, prevAlpha, toNode, beta[toNode], afterExp, expVal, prevAlphaNorm, norm, betaNorm, zNorm, zInv));
	}

	// Now update expectations for all node features for this edge
	short[] indices = nodeEvals.index;
	float[] vals = nodeEvals.value;
	int i = 0;
	short index = indices[i];
	while(index != -1) {
		if(expectLengthWriter != null)
			FileUtil.safeWrite(expectLengthWriter, String.format("Seq %d Pos %d-%d Expect #%d: %e = %e + Prob: %e * NodeVal: %e\n", seqNum, prevPos, pos, index, expects[index]+prob*vals[i], expects[index], prob, vals[i]));
		if(prob != 0.0)
			expects[index] += prob * vals[i];
		index = indices[++i];
	}

	if(nodeMarginalWriter != null) {
		FileUtil.safeWrite(nodeMarginalWriter, String.format(
				"NodeMarg[%d][%d] = %f = %f + Alpha[%d][%d]: %s * Beta[%d][%d]: %s * seg: %s / Z: %s\n",
				lengthPos, toNode, nodeProb[toNode]+prob, nodeProb[toNode], prevPos, fromNode, printNorm(prevAlpha, prevAlphaNorm),
				pos, toNode, printNorm(beta[toNode], betaNorm),
				printNorm(expVal, norm), printNorm(1/zInv, zNorm)));
	}
	// The segment contributes to the marginal of toNode at this position.
	nodeProb[toNode] += prob;

	return norm;
}
1020
/**
 * Accumulates feature expectations at a single position using the node and edge marginals computed by the
 * beta pass (nodeProb/edgeProb fields, plus the caller-supplied transitionProb for length-dependent states).
 *
 * @param seqNum sequence index
 * @param pos position within the sequence; at pos 0 edge potentials are skipped
 * @param transitionProb per-transition probabilities for edges into length-dependent (semi-Markov) states
 */
void updateExpectations(int seqNum, int pos, double[] transitionProb) {
	cacheProcessor.evaluatePosition(seqNum, pos);

	// First compute the expectations for the length-dependent states.
	int invalidIndex = (seqOffset+pos)* modelInfo.nPotentials;

	// lengthNode carries from a node potential to the edge potentials that follow it; this relies on
	// orderedPotentials listing each node before its incoming edges — TODO confirm against ModelInfo.
	boolean lengthNode = false;
	for (short potential : modelInfo.orderedPotentials) {
		boolean invalid = invalidTransitions[invalidIndex + potential];

		double prob = Double.NaN;
		if(potential < modelInfo.nStates) {
			// Node potential: remember whether it is semi-Markov for the edges that follow.
			lengthNode = modelInfo.maxStateLengths[potential] > 1;
			prob = nodeProb[potential];
		}
		else {
			// Edge potential; there are no incoming edges at position 0.
			if(pos == 0)
				continue;
			// Semi-Markov targets use the segment-based transitionProb; Markov targets use edgeProb.
			prob = (lengthNode ? transitionProb : edgeProb)[potential-modelInfo.nStates];
		}

		if (!invalid) {
			// Iterate through features for this potential (sparse list, -1 terminated).
			FeatureEvaluation potEvals = evals[potential];
			short[] indices = potEvals.index;
			float[] vals = potEvals.value;
			int i = 0;
			short index = indices[i];
			while(index != -1) {
				if(expectWriter != null)
					FileUtil.safeWrite(expectWriter, String.format("Seq %d Pos %d Expect #%d: %e = %e + Prob: %e * Val: %e\n", seqNum, pos, index, expects[index]+prob*vals[i], expects[index], prob, vals[i]));
				expects[index] += prob*vals[i];
				index = indices[++i];
			}
		}
	}
}
1059 }
1060
1061 void logBuf() {
1062 int l = lookbackBuffer.length;
1063 String s = "";
1064 for(int i=0; i < l; ++i) {
1065 s += lookbackBuffer.get(i).pos + " ";
1066 }
1067 log.info(s);
1068 }
1069
1070 void logBufBeta() {
1071 int l = lookbackBuffer.length;
1072 String s = "";
1073 for(int i=0; i < l; ++i) {
1074 s += ColtUtil.format(lookbackBuffer.get(i).beta) + " ";
1075 }
1076 log.info(s);
1077 }
1078
1079 /**
1080 * Computes an unexponentiated mi matrix and updates stable states. Used to create caches for lookback searches.
1081 *
1082 */
1083 void cacheMi(int seqNum, double[] mi, double[] prevStable, double [] newStable, int miPos) {
1084 if (miPos < 0)
1085 return;
1086 //calcMi(mi, overallPosition, cacheStart, cacheStop, false);
1087 calcMi(mi, seqNum, miPos, false);
1088 // Go through the mi matrix and for all states with length dependence compute stable values for self transitions
1089 for (int i = 0; i < modelInfo.nStates; ++i) {
1090 if (modelInfo.maxStateLengths[i] > 1) {
1091 // These are all log values so we add them
1092 newStable[i] = prevStable[i];
1093 double trans = mi[modelInfo.selfTransitions[i]];
1094 if(!Double.isInfinite(trans)) {
1095 //log.debug("Pos: "+miPos+" State: "+i+" Trans: "+trans+" Total: "+(newStable[i]+trans));
1096 newStable[i] += trans;
1097 }
1098 }
1099 }
1100 }
1101
1102 /**
1103 * This is one of the most time critical parts of the entire solver. The goal is to update the transition matrix.
1104 * This function makes a lot of assumptions in order to maximize performance.
1105 *
1106 * To maximize performance, we want to make one pass through the Mi matrix, setting each entry to its correct value.
1107 * The value for each entry is the exponent of the sum of the weighted feature values of the edge for that entry and
1108 * its corresponding node. The entry s0,s1 consists of the s0,s1 edge and the s1 node.
1109 */
1110 void calcMi(double[] mi, int seq, int pos, boolean doExp) {
1111 cacheProcessor.evaluatePosition(seq, pos);
1112 double nodeVal = Double.NaN;
1113 int overallPosition = modelInfo.seqOffsets[seq]+pos;
1114 int invalidIndex = overallPosition*modelInfo.nPotentials;
1115 for(short potential : modelInfo.orderedPotentials) {
1116 boolean invalid = invalidTransitions[invalidIndex + potential];
1117 double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
1118
1119 // Add up all features for this potential.
1120 FeatureEvaluation potEvals = evals[potential];
1121 short[] indices = potEvals.index;
1122 float[] vals = potEvals.value;
1123 int i = 0;
1124 short index = indices[i];
1125 while(index >= 0) {
1126 // An invalid potential is indicated by a feature value of Short.MAX_VALUE
1127 features += vals[i]*lambda[index];
1128 index = indices[++i];
1129 }
1130 if(index == Short.MIN_VALUE) {
1131 features = Double.NEGATIVE_INFINITY;
1132 }
1133
1134 if(potential < modelInfo.nStates) {
1135 nodeVal = features;
1136 }
1137 else {
1138 //log.debug(String.format("Mi[%d, %d] = %f, adding in %f to get %f", feat.yprev(), feat.y(), val, feat.value(), val*exp(feat.value()*param[feat.index()])));
1139 int transition = potential - modelInfo.nStates;
1140 double val = features + nodeVal;
1141 if(doExp)
1142 val = exp(val);
1143 mi[transition] = val;
1144 }
1145 }
1146 }
1147
1148 /**
1149 * Given a vector with an existing normalization factor, convert it to a new normalization factor by scaling the
1150 * entries.
1151 */
1152 private static final void renormalize(final double[] vec, final int currentNorm, final int newNorm) {
1153 // Instead of dividing by the different (new-current), we reverse the subtraction to negate the exponent and
1154 // then multiply.
1155 double factor = exp(NORM_FACTOR * (currentNorm - newNorm));
1156 //log.info(factor);
1157 //log.info(ColtUtil.format(vec));
1158 int len = vec.length;
1159 for (int i = 0; i < len; ++i) {
1160 if(vec[i] != 0.0)
1161 vec[i] *= factor;
1162 }
1163 //log.info(ColtUtil.format(vec));
1164 }
1165
1166 /** Given a vector, computes a normalization factor for the entries and scales them according to that factor. */
1167 private static final int normalize(final double[] vec) {
1168 double sum = 0.0;
1169 for(double val : vec) {
1170 sum += val;
1171 }
1172 if(sum == 0.0 || (sum > NORM_MIN && sum < NORM_MAX)) {
1173 // No normalization required, our vector is in range.
1174 return 0;
1175 }
1176 if(debug)
1177 Assert.a(!Double.isNaN(sum));
1178
1179 //log.info("performing normalization");
1180 double val = log(sum);
1181 int norm = (int) val / NORM_FACTOR;
1182 val = exp(NORM_FACTOR * norm);
1183 int len = vec.length;
1184 for (int i = 0; i < len; ++i) {
1185 vec[i] /= val;
1186 }
1187 return norm;
1188 }
1189
/** Local shorthand for {@link Math#exp} used throughout the hot loops. */
private static final double exp(final double val) {
	return Math.exp(val);
}
1193
/** Local shorthand for {@link Math#log} (natural log) used throughout the hot loops. */
private static final double log(final double val) {
	return Math.log(val);
}
1197
1198 /**
1199 * This object holds information about previous positions during the computation of betas and expectations. This
1200 * allows us to quickly access data about previous positions. These objects are kept in a recycling buffer that
1201 * keeps one buffer for each possible lookback.
1202 *
1203 * One tricky aspect of this is that the details change slightly between the forward and backwards pass. On the forward
1204 * pass, the lookback contains the information in the normal way. In the backwards pass, stable states and transitions are
1205 * shifted back one base compared to the betas.
1206 */
1207 final class LookbackBuffer {
1208 int pos;
1209
1210 // The mi matrix for transitioning from pos-lookback-1 to pos-lookback
1211 double[] mi = new double[modelInfo.nPotentials];
1212
1213 // The weighted sum of feature values for staying in this position from the beginning to pos-lookback
1214 double[] stableState = new double[modelInfo.nStates];
1215
1216 // Initial values of the beta vector for somelength dependent states.
1217 double[] beta = new double[modelInfo.nStates];
1218
1219 // Norm of prevBeta.
1220 int betaNorm;
1221
1222 // Stores the probability of all segments begining at this position using this transition.
1223 double[] transitionProb = new double[modelInfo.nTransitions];
1224
1225 /** mi and stableStates are cleared as new values are entered. This fixes the others */
1226 void clear()
1227 {
1228 pos = -1;
1229 Arrays.fill(beta, 0.0);
1230 betaNorm = Integer.MIN_VALUE;
1231 Arrays.fill(transitionProb, 0.0);
1232 }
1233 }
1234
1235 public static final String printNorm(final double value, final int norm) {
1236 if( value == 0.0)
1237 return "0 ("+norm+")";
1238 if( Double.isNaN(value))
1239 return "NaN ("+norm+")";
1240 int exponent = (int) log(value);
1241
1242 double eValue = value/exp(exponent);
1243 if(Double.isNaN(eValue)) {
1244 return String.format("NaN(%e n:%d)", value, norm);
1245 }
1246 //return String.format("%e(%d) %fe%d", value, norm, eValue, exponent+norm*NORM_FACTOR);
1247 return String.format("%fe%d", eValue, exponent+norm*NORM_FACTOR);
1248 }
1249
/** gets the cache processor used to access feature evaluations
 * @return the configured cache processor
 */
public CacheProcessor getCacheProcessor() {
	return cacheProcessor;
}

/** sets the cache processor used to access feature evaluations
 * @param cacheProcessor the cache processor to use
 */

public void setCacheProcessor(CacheProcessor cacheProcessor) {
	this.cacheProcessor = cacheProcessor;
}
1264
/** gets the trace file for semi-Markov alpha computations
 * @return the file name, or null if this trace is disabled
 */
public String getAlphaLengthFile() {
	return alphaLengthFile;
}

/** sets the trace file for semi-Markov alpha computations; each call to apply overwrites it
 * @param alphaLengthFile the file name, or null to disable
 */
public void setAlphaLengthFile(String alphaLengthFile) {
	this.alphaLengthFile = alphaLengthFile;
}
1272
/** gets the trace file for Markov-state alpha computations
 * @return the file name, or null if this trace is disabled
 */
public String getAlphaFile() {
	return alphaFile;
}

/** sets the trace file for Markov-state alpha computations; each call to apply overwrites it
 * @param alphaFile the file name, or null to disable
 */
public void setAlphaFile(String alphaFile) {
	this.alphaFile = alphaFile;
}
1280
/** gets the trace file for Markov feature expectation computations
 * @return the file name, or null if this trace is disabled
 */
public String getExpectFile() {
	return expectFile;
}

/** sets the trace file for Markov feature expectation computations; each call to apply overwrites it
 * @param expectFile the file name, or null to disable
 */
public void setExpectFile(String expectFile) {
	this.expectFile = expectFile;
}
1288
/** gets the trace file for semi-Markov feature expectation computations
 * @return the file name, or null if this trace is disabled
 */
public String getExpectLengthFile() {
	return expectLengthFile;
}

/** sets the trace file for semi-Markov feature expectation computations; each call to apply overwrites it
 * @param expectLengthFile the file name, or null to disable
 */
public void setExpectLengthFile(String expectLengthFile) {
	this.expectLengthFile = expectLengthFile;
}
1296
/** gets the trace file for node marginal computations
 * @return the file name, or null if this trace is disabled
 */
public String getNodeMarginalFile() {
	return nodeMarginalFile;
}

/** sets the trace file for node marginal computations; each call to apply overwrites it
 * @param nodeMarginalFile the file name, or null to disable
 */
public void setNodeMarginalFile(String nodeMarginalFile) {
	this.nodeMarginalFile = nodeMarginalFile;
}
1304
/** gets the trace file for semi-Markov beta computations
 * @return the file name, or null if this trace is disabled
 */
public String getBetaLengthFile() {
	return betaLengthFile;
}

/** sets the trace file for semi-Markov beta computations; each call to apply overwrites it
 * @param betaLengthFile the file name, or null to disable
 */
public void setBetaLengthFile(String betaLengthFile) {
	this.betaLengthFile = betaLengthFile;
}
1312
/** gets the sums of the feature values along the true paths
 * @return a defensive copy of the feature sums array, so callers cannot mutate internal state
 */
public double[] getFeatureSums() {
	return this.featureSums.clone();
}
1316 }