001 package calhoun.analysis.crf.solver.semimarkov;
002
003 import java.util.Arrays;
004 import java.util.List;
005
006 import calhoun.analysis.crf.LocalPathSimilarityScore;
007 import calhoun.analysis.crf.ModelManager;
008 import calhoun.analysis.crf.io.TrainingSequence;
009 import calhoun.analysis.crf.scoring.SimScoreMaxStateAgreement;
010 import calhoun.analysis.crf.solver.CacheProcessor;
011 import calhoun.analysis.crf.solver.LookbackBuffer;
012 import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
013 import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
014 import calhoun.util.Assert;
015 import calhoun.util.ColtUtil;
016 import calhoun.util.FileUtil;
017
018 /** computes an objective function which is the expected value of a local path similarity score on a
019 * semi-Markov model. Requires a {@link CacheProcessor} and a {@link LocalPathSimilarityScore} to be configured.<p>
020 * <h2>Debugging output</h2>
021 * To get a better understanding of what the objective function is doing, several different properties can be set that
022 * cause the objective function to write out trace files showing its calculations during training. Usually when turning
023 * these options on, you should set <code>maxIters = 1</code> and <code>requireConvergence = false</code> in your optimizer
024 * to do only a single training iteration, possibly setting the starts to some predetermined value. Each of these
025 * properties can be configured with a filename and each time {@link #apply} is called, the file will be overwritten with
026 * data from the current call. The logging options are:
027 *
028 * <ul>
029 * <li> <b><code>alphaFile</code></b> - computation of alpha values for Markov states, includes all nodes and edges.
 * <li> <b><code>alphaLengthFile</code></b> - computation of alpha values for semi-Markov states, includes all segments
 * <li> <b><code>betaLengthFile</code></b> - computation of beta values for semi-Markov states, includes all segments
032 * <li> <b><code>expectFile</code></b> - computation of expected values for each Markov feature
033 * <li> <b><code>expectLengthFile</code></b> - computation of expected values for each semi-Markov feature
034 * <li> <b><code>nodeMarginalFile</code></b> - computation of marginal probability of each state at each position
035 * </ul>
036
037 * <h4>Implementation Notes</h4>
038 * The general normalization scheme works as follows. When updating alpha values in the forward pass we compute segments
039 * of length 1 first and then work backwards.
040 * <p>
041 * Instead of always normalizing to 1 we discretize the normalization. We choose an arbitrary normalization factor w,
042 * such as 50. The normalization factor at any position is then an integer v, and all entries at that position are
043 * alpha[y]*e^(v*w).
044 * <p>
 * The normalization can be computed at any position as follows: 1) The elements of the alpha array are summed to give s. 2) v = log(s)/w.
 * By integer division v will always be an appropriate normalizer. It may be positive or negative. 3) All elements of
 * the array are divided by e^(v*w).
048 *
049 */
public class CleanLocalScoreSemiMarkovGradient extends CleanMaximumLikelihoodSemiMarkovGradient {
    /** Local similarity score applied at each position; configurable via {@link #setScore}. */
    LocalPathSimilarityScore score = new SimScoreMaxStateAgreement();

    // Score cache
    // Running prefix sums of the self-transition score per state, indexed [state][overallPosition].
    // Segment interiors are scored by subtracting two prefix sums (see fillScoreCache/scoreFeatureExpections).
    double[][] localScoreStableCache;
    // Local score for each transition at each position, indexed [transition][overallPosition].
    double[][] localScoreTransitionCache;

    // Beta values for Markov states, indexed [state][pos].
    double[][] betas;
    // Normalization exponent for the betas at each position.
    int[] betaNorms;
    // Edge marginal probabilities, indexed [transition][pos]; not defined for pos=0.
    double[][] allEdgeProb;
    // Node marginal probabilities, indexed [state][pos].
    double[][] allNodeProb;

    // Expected-score forward/backward accumulators for Markov states, indexed [state][pos].
    double[][] scoreAlpha;
    double[][] scoreBeta;
    // Semi-Markov variants of the accumulators above; only allocated when semiMarkov is true.
    double[][] semiMarkovScoreAlpha;
    double[][] semiMarkovScoreBeta;

    // True when the model contains states with lookback > 1 (true semi-Markov states).
    boolean semiMarkov;
068
    /** {@inheritDoc}
     * Also allocates the per-position beta/marginal/score arrays and precomputes the
     * local similarity score caches over the training data.
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        super.setTrainingData(fm, data);

        // Initialize arrays to hold betas and marginals
        betas = new double[modelInfo.nStates][modelInfo.longestSeq];
        betaNorms = new int[modelInfo.longestSeq];
        allEdgeProb = new double[modelInfo.nTransitions][modelInfo.longestSeq]; // indexed [trans][pos], not defined for pos=0
        allNodeProb = new double[modelInfo.nStates][modelInfo.longestSeq];
        scoreAlpha = new double[modelInfo.nStates][modelInfo.longestSeq];
        scoreBeta = new double[modelInfo.nStates][modelInfo.longestSeq];

        // Only allocate the semi-Markov accumulators when some state actually has lookback > 1.
        semiMarkov = modelInfo.maxLookback > 1;
        if(semiMarkov) {
            semiMarkovScoreAlpha = new double[modelInfo.nStates][modelInfo.longestSeq];
            semiMarkovScoreBeta = new double[modelInfo.nStates][modelInfo.longestSeq];
        }

        betaProcessor.setGlobalArrays(betas, betaNorms, allNodeProb, allEdgeProb);
        fillScoreCache(cacheProcessor.getData());
    }
090
091 @Override
092 public double apply(double[] param, double[] grad) {
093 log.debug(String.format("Beginning It: %d Weights: %s", iter, ColtUtil.format(param)));
094 lambda = param;
095 double[] scoreFeatureProductExpectation = new double[grad.length];
096 Arrays.fill(grad, 0);
097 double result = 0.0;
098
099 try {
100 // Iterate through sequences
101 logs.open();
102 for (int i = 0; i < modelInfo.nSeqs; ++i) {
103 Arrays.fill(expects, 0);
104
105 int len = modelInfo.seqOffsets[i + 1] - modelInfo.seqOffsets[i];
106 alphaAndBetaPass(i, len);
107
108 writeMarginals(i, len);
109
110 scoreAlphaBeta(i, len);
111
112 Arrays.fill(scoreFeatureProductExpectation, 0.0);
113 double thisResult = scoreFeatureExpections(i, len, scoreFeatureProductExpectation);
114 // Combine the various terms to update the gradient.
115 for(int j = 0; j<modelInfo.nFeatures; ++j) {
116 grad[j] += scoreFeatureProductExpectation[j] - thisResult * expects[j];
117 Assert.a(!Double.isNaN(grad[j]));
118 }
119 if(debug) {
120 log.debug(String.format("Iter: %d Seq: %d Expected Score: %g Grad: %s Expected Features: %s Expected Product: %s", iter, i,
121 thisResult, ColtUtil.format(grad), ColtUtil.format(expects), ColtUtil.format(scoreFeatureProductExpectation)));
122 }
123 result += thisResult;
124 }
125
126 // Normalize by the number of positions
127 for(int j = 0; j<modelInfo.nFeatures; ++j) {
128 grad[j] = grad[j]/modelInfo.totalPositions;
129 }
130 result = result/modelInfo.totalPositions;
131 Assert.a(!Double.isNaN(result));
132
133 if(debug) {
134 log.debug(String.format("Iter: %d Val: %g Grad: %s Weights: %s", iter, result, ColtUtil.format(grad), ColtUtil.format(lambda)));
135 }
136 else if(log.isInfoEnabled()) {
137 log.debug(String.format("Iter: %d Val: %g Grad: %s", iter, result, ColtUtil.format(grad)));
138 }
139
140 iter += 1;
141 }
142 finally {
143 logs.close();
144 }
145 return result;
146 }
147
    /** walks the sequence and, for every active feature, accumulates the expectation of
     * (feature value * similarity score) over all paths, using the previously computed
     * marginals and the scoreAlpha/scoreBeta accumulators. Also sums up the total expected
     * similarity score for the sequence.
     * @param seqNum index of the sequence being processed
     * @param len length of the sequence
     * @param scoreFeatureProductExpectation output array, one entry per feature, incremented in place
     * @return the (unnormalized) expected similarity score for this sequence
     */
    private final double scoreFeatureExpections(int seqNum, int len, double[] scoreFeatureProductExpectation) {
        // Now we need to use those alpha and beta vectors to compute the expectations for product of score and features
        // In the process we also need to compute the result
        double result = 0.0;
        int seqOffset = modelInfo.seqOffsets[seqNum];

        // Position 0 has no incoming edge and is handled separately.
        scorePosZero(seqNum, scoreFeatureProductExpectation);

        // Put an empty entry in the lookback so the first base has 0's initialized.
        double[] stableState = nextBuffer.stableState;
        Arrays.fill(stableState, 0.0);
        nextBuffer = lookbackBuffer.addFirst(nextBuffer);

        int prevPos = 0;
        for(int pos = 1; pos < len; ++pos) {
            int overallPosition = seqOffset + pos;
            cacheProcessor.evaluatePosition(seqNum, pos);

            // Cache the transition values (mi) and running self-transition totals for this position.
            cacheMi(seqNum, nextBuffer.mi, stableState, nextBuffer.stableState, pos);
            stableState = nextBuffer.stableState;
            nextBuffer = lookbackBuffer.addFirst(nextBuffer);

            // Handle the markov features & update the expected value of the similarity score
            FeatureEvaluation nodeEvals = null;
            for(short potential : modelInfo.orderedPotentials) {
                if(potential < modelInfo.nStates) {
                    // Node potentials precede their edges in orderedPotentials; remember the node's
                    // evaluations so the edge branch below can add in the node features too.
                    nodeEvals = evals[potential];
                }
                else {
                    // This is an edge potential
                    int trans = potential - modelInfo.nStates;
                    int yprev = modelInfo.transitionFrom[trans];
                    int y = modelInfo.transitionTo[trans];
                    // NOTE(review): edge marginals appear to be indexed by the source position — confirm.
                    double ep = allEdgeProb[trans][prevPos];
                    double npPrev = allNodeProb[yprev][prevPos];
                    double np = allNodeProb[y][pos];
                    if(np > 0 && npPrev > 0 && ep >= 0) {
                        double scoreExpect = ep * localScoreTransitionCache[trans][overallPosition];
                        result += scoreExpect;

                        // Computing marginals from node and edge probabilities is not really safe. Clamp to a valid probability.
                        double conditionalIn = Math.min(1.0, ep / npPrev);
                        double conditionalOut = Math.min(1.0, ep / np);

                        // inner = local expected score plus the expected score of everything
                        // before (scoreAlpha) and after (scoreBeta) this edge, conditioned on the edge.
                        double inner = scoreExpect;
                        inner += conditionalIn*scoreAlpha[yprev][prevPos];
                        inner += conditionalOut*scoreBeta[y][pos];

                        if(inner == 0.0)
                            continue;

                        // Edges
                        FeatureEvaluation potEvals = evals[potential];
                        short[] indices = potEvals.index;
                        float[] vals = potEvals.value;

                        int fCount = 0;
                        short index = indices[fCount];
                        while(index != -1) {    // -1 terminates the feature list
                            // FeatureValue * Score * Marginal Prob
                            scoreFeatureProductExpectation[index] += inner*vals[fCount];
                            if(logs.expectedProductWriter != null) {
                                FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d Edge: %d-%d\tFeat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n",
                                        seqNum, pos, yprev, y, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount],
                                        localScoreTransitionCache[trans][overallPosition], ep, scoreAlpha[yprev][prevPos], conditionalIn,
                                        scoreBeta[y][pos], conditionalOut));
                            }
                            index = indices[++fCount];
                        }

                        // Nodes
                        indices = nodeEvals.index;
                        vals = nodeEvals.value;

                        fCount = 0;
                        index = indices[fCount];
                        while(index != -1) {
                            // FeatureValue * Score * Marginal Prob
                            scoreFeatureProductExpectation[index] += inner*vals[fCount];
                            if(logs.expectedProductWriter != null) {
                                FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d State: %d-%d\tNode Feat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n",
                                        seqNum, pos, yprev, y, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount],
                                        localScoreTransitionCache[trans][overallPosition], ep, scoreAlpha[yprev][prevPos], conditionalIn,
                                        scoreBeta[y][pos], conditionalOut));
                            }
                            index = indices[++fCount];
                        }
                    }
                }
            }

            // Handle the semi-markov features - We need to calculate the probability of each segment
            if(semiMarkov) {
                cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);

                int nSemiMarkovStates = modelInfo.statesWithLookback.length;
                for(int i=0; i<nSemiMarkovStates; ++i) {
                    LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
                    CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
                    byte toNode = statePotentials.state;

                    // Iterate over all valid lookbacks (segment lengths) ending at pos; -1 terminates.
                    int lbIndex=0;
                    LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
                    int lookback = lengthEval.lookback;
                    while(lookback != -1) {
                        int beginPos = pos - lookback - 1;
                        Assert.a(lengthEval.edgeEvals == null);

                        // Score of the segment interior: difference of self-transition prefix sums.
                        double stableScore = localScoreStableCache[toNode][overallPosition] - localScoreStableCache[toNode][overallPosition-lookback];
                        double beta = betas[toNode][pos];
                        int betaNorm = betaNorms[pos];

                        // For speed I hand inline RecyclingBuffer.get
                        LookbackBuffer segBegin = lookbackBuffer.array[(lookbackBuffer.currentStart+lookback)%lookbackBuffer.length];
                        double stableValue = stableState[toNode] - segBegin.stableState[toNode];

                        // Add in the length based features
                        // Handle evaluation of the node potentials
                        double nodeValue = stableValue;
                        nodeEvals = lengthEval.nodeEval;
                        short[] indices = nodeEvals.index;
                        float[] vals = nodeEvals.value;
                        int ix = 0;
                        short index = indices[ix];
                        while(index >= 0) {
                            nodeValue += vals[ix] * lambda[index];
                            index = indices[++ix];
                        }

                        // Look at all the transitions, calculate an inner value for each and sum. We multiply the whole sum
                        // by our observed node length features
                        double inner = 0.0;
                        if(beginPos == -1) {
                            // Segment starts at the beginning of the sequence: no incoming transition,
                            // use the start-state weights instead of a previous alpha.
                            double segProb = beta*zInv*exp(starterAlpha[toNode] + nodeValue + NORM_FACTOR*(betaNorm - zNorm));
                            if(betaNorm == Integer.MIN_VALUE)
                                segProb = 0.0;

                            inner += segProb * stableScore;
                            double segEndMarg = allNodeProb[toNode][pos]-allEdgeProb[modelInfo.selfTransitions[toNode]][pos];
                            double conditionalOut = (segEndMarg > 0) ? minMax(segProb/segEndMarg) : 0.0;
                            inner += conditionalOut*semiMarkovScoreBeta[toNode][pos];
                            if(logs.expectedProductWriter != null) {
                                FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: 0-%d State: %d\t Inner %g = cOut: %g * sb: %g + stabScore: %g * p: %g "+
                                        "(b: %g * zInv %g * exp(nodeL:%g + stable: %g + FACTOR*(normB: %d + normZ: %d)))\n",
                                        seqNum, pos, toNode, inner, conditionalOut, semiMarkovScoreBeta[toNode][pos], stableScore, segProb,
                                        beta, zInv, nodeValue-stableValue, stableValue+starterAlpha[toNode], betaNorm, zNorm));
                            }
                        }
                        else {
                            // Interior segment: sum over all non-self transitions into this state.
                            int nEdges = statePotentials.potentials.length;
                            for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
                                int potential = statePotentials.potentials[edgeIx];
                                int trans = potential - modelInfo.nStates;
                                int fromNode = modelInfo.transitionFrom[trans];

                                Assert.a(lengthEval.edgeEvals == null, "Explicit length edge features not supported.");

                                // Skip semi-Markov self transitions
                                if(fromNode == toNode)
                                    continue;

                                int invalidIndex = (seqOffset+beginPos+1)*modelInfo.nPotentials;
                                if(invalidTransitions[invalidIndex + potential]) {
                                    continue;
                                }

                                double prevAlpha = alphas[beginPos][fromNode];
                                int prevAlphaNorm = alphaNorms[beginPos];

                                double transitionValue = segBegin.mi[trans];
                                // Probability of this exact segment: alpha * beta * transition * Z^-1, renormalized.
                                double segProb = prevAlpha * beta * zInv * exp(nodeValue + transitionValue + NORM_FACTOR*(prevAlphaNorm + betaNorm - zNorm));
                                if(prevAlphaNorm == Integer.MIN_VALUE || betaNorm == Integer.MIN_VALUE)
                                    segProb = 0.0;
                                if(Double.isNaN(segProb))
                                    Assert.a(false, "Bad Segment Prob. Seq ", seqNum, " Pos ",prevPos, "-", pos);

                                double segmentScore = localScoreTransitionCache[trans][seqOffset + beginPos+1] + stableScore;
                                inner += segProb * segmentScore;

                                // Expected score before the segment, conditioned on the previous segment ending at beginPos.
                                double prevSegMarg = allNodeProb[fromNode][beginPos] - ((modelInfo.maxStateLengths[fromNode]>1) ?
                                        allEdgeProb[modelInfo.selfTransitions[fromNode]][beginPos] : 0);
                                double conditionalIn = (prevSegMarg > 0) ? minMax(segProb/prevSegMarg) : 0.0;
                                inner += conditionalIn*semiMarkovScoreAlpha[fromNode][beginPos];

                                // Expected score after the segment, conditioned on this segment ending at pos.
                                double segEndMarg = allNodeProb[toNode][pos]-allEdgeProb[modelInfo.selfTransitions[toNode]][pos];
                                double conditionalOut = (segEndMarg > 0) ? minMax(segProb/segEndMarg) : 0.0;
                                inner += conditionalOut*semiMarkovScoreBeta[toNode][pos];
                                if(logs.expectedProductWriter != null) {
                                    FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d-%d State: %d-%d\t Inner %g = cIn: %g * sa: %g + cOut: %g * sb: %g + score: %g (trans: %g + stab: %s) * p: %g "+
                                            "(a: %g * b: %g * zInv %g * exp(nodeL:%g + stable: %g + trans: %g + FACTOR*(normA: %d + normB: %d + normZ: %d)))\n",
                                            seqNum, beginPos+1, pos, fromNode, toNode, inner, conditionalIn, semiMarkovScoreAlpha[fromNode][beginPos], conditionalOut, semiMarkovScoreBeta[toNode][pos], segmentScore, localScoreTransitionCache[trans][seqOffset + beginPos+1], stableScore, segProb,
                                            prevAlpha, beta, zInv, nodeValue-stableValue, stableValue, transitionValue, prevAlphaNorm, betaNorm, zNorm));
                                }
                            }
                        }

                        // Once the inner value has been computed, multiply it by all of the observed feature values.
                        FeatureEvaluation lengthNodeEvals = lengthEval.nodeEval;
                        indices = lengthNodeEvals.index;
                        vals = lengthNodeEvals.value;

                        ix = 0;
                        index = indices[ix];
                        while(index >= 0) {
                            scoreFeatureProductExpectation[index] += inner*vals[ix];
                            if(logs.expectedProductWriter != null) {
                                FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d-%d State: %d\tFeat: %d = %g = %g + Val: %g * Inner: %g\n",
                                        seqNum, beginPos+1, pos, toNode, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[ix], vals[ix], inner));
                            }
                            index = indices[++ix];
                        }
                        Assert.a(lengthEval.edgeEvals == null, "Explicit length edges are not supported.");

                        ++lbIndex;
                        lengthEval = lookbacksForState[lbIndex];
                        lookback = lengthEval.lookback;
                    }
                }
            }
            prevPos = pos;
        }
        return result;
    }
371
    /** handles position 0, which has no incoming edges: for each feature active there, the
     * expected score-feature product is just the feature value times the expected downstream
     * score (scoreBeta / semiMarkovScoreBeta at position 0).
     * @param seqNum index of the sequence being processed
     * @param scoreFeatureProductExpectation output array, incremented in place
     */
    void scorePosZero(int seqNum, double[] scoreFeatureProductExpectation) {
        cacheProcessor.evaluatePosition(seqNum, 0);
        for(int state = 0; state < modelInfo.nStates; ++state) {
            // This is a node potential
            double inner = scoreBeta[state][0];
            FeatureEvaluation potEvals = evals[state];
            short[] indices = potEvals.index;
            float[] vals = potEvals.value;

            int fCount = 0;
            short index = indices[fCount];
            while(index != -1) {    // -1 terminates the feature list
                // FeatureValue * Score * Marginal Prob
                scoreFeatureProductExpectation[index] += inner*vals[fCount];
                if(logs.expectedProductWriter != null) {
                    FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: 0 State: %d\tFeat: %d = %g = %g + Val: %g * Beta[%d][%d]: %g:\n",
                            seqNum, state, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount], 0, state, inner));
                }
                index = indices[++fCount];
            }
        }

        // Handle the semi-markov features - We avoid having to calculate segment probabilities
        if(semiMarkov) {
            cacheProcessor.evaluateSegmentsEndingAt(seqNum, 0);

            int nSemiMarkovStates = modelInfo.statesWithLookback.length;
            for(int i=0; i<nSemiMarkovStates; ++i) {
                LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
                LengthFeatureEvaluation lengthEval = lookbacksForState[0];
                if(lengthEval.lookback != -1) {
                    // At position 0 the only possible segment is the length-1 segment (lookback 0).
                    Assert.a(lengthEval.lookback == 0);
                    CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
                    byte state = statePotentials.state;

                    double inner = semiMarkovScoreBeta[state][0];
                    short[] indices = lengthEval.nodeEval.index;
                    float[] vals = lengthEval.nodeEval.value;
                    int fCount = 0;
                    short index = indices[fCount];
                    while(index != -1) {
                        // FeatureValue * Score * Marginal Prob
                        scoreFeatureProductExpectation[index] += inner*vals[fCount];
                        if(logs.expectedProductWriter != null) {
                            FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: 0 State: %d\tLen. Feat: %d = %g = %g + Val: %g * Beta[%d][%d]: %g:\n",
                                    seqNum, state, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount], 0, state, inner));
                        }
                        index = indices[++fCount];
                    }
                    Assert.a(lookbacksForState[1].lookback == -1);
                }
            }
        }
    }
426
    /** runs a backward and then a forward pass over the marginals to fill in scoreBeta and
     * scoreAlpha (and their semi-Markov variants): the expected similarity score contributed
     * by the part of the path after / before each state at each position.
     * @param seqNum index of the sequence being processed
     * @param len length of the sequence
     */
    private final void scoreAlphaBeta(int seqNum, int len) {
        int seqOffset = modelInfo.seqOffsets[seqNum];

        // Do another backward & forward pass to compute the score alpha & betas
        // scoreBeta is being defined from 0 to len-1
        int pos = len-1;
        int prevPos;
        for (int y=0; y<modelInfo.nStates; y++) {
            Arrays.fill(scoreBeta[y], 0.0);
            if(semiMarkov)
                Arrays.fill(semiMarkovScoreBeta[y], 0.0);
        }
        for (prevPos = len-2; prevPos >= 0; --prevPos) {
            for (int trans=0; trans<modelInfo.nTransitions; trans++) {
                int yprev = modelInfo.transitionFrom[trans];
                int y = modelInfo.transitionTo[trans];
                double ep = allEdgeProb[trans][prevPos];
                double np = allNodeProb[y][pos];
                if(np > 0 && ep >= 0) {
                    // Computing marginals from node and edge probabilities is not really safe. Clamp to a valid probability.
                    double conditional = Math.min(1.0, ep / np);
                    // Expected downstream score through this edge: local score plus the later score.
                    double update = ep*localScoreTransitionCache[trans][seqOffset+pos] + conditional*scoreBeta[y][pos];
                    scoreBeta[yprev][prevPos] += update;
                    // Semi-Markov betas exclude self transitions of length-modeled states.
                    if(modelInfo.maxStateLengths[y]>1 && y != yprev)
                        semiMarkovScoreBeta[yprev][prevPos] += update;
                }
            }
            pos = prevPos;
        }

        // scoreAlpha is being defined from 0 to len-2 (we never use the last alpha position)
        prevPos = 0;
        for (int y=0; y<modelInfo.nStates; y++) {
            Arrays.fill(scoreAlpha[y], 0.0);
            if(semiMarkov)
                Arrays.fill(semiMarkovScoreAlpha[y], 0.0);
        }
        for (pos=1; pos<len-1; pos++) {
            for (int trans=0; trans<modelInfo.nTransitions; trans++) {
                int yprev = modelInfo.transitionFrom[trans];
                int y = modelInfo.transitionTo[trans];
                double ep = allEdgeProb[trans][prevPos];
                double np = allNodeProb[yprev][prevPos];
                if(np > 0 && ep >= 0) {
                    // Computing marginals from node and edge probabilities is not really safe. Clamp to a valid probability.
                    double conditional = Math.min(1.0, ep / np);

                    double update = conditional*scoreAlpha[yprev][prevPos];
                    update += ep*localScoreTransitionCache[trans][seqOffset+pos];
                    scoreAlpha[y][pos] += update;
                    if(modelInfo.maxStateLengths[y]>1) {
                        // For length-modeled states, weight by the probability that the segment ends here
                        // (node marginal minus the self-transition marginal).
                        double nodeMarg = allNodeProb[y][pos];
                        if(nodeMarg > 0) {
                            double outConditional = minMax((nodeMarg-allEdgeProb[modelInfo.selfTransitions[y]][pos])/nodeMarg);
                            semiMarkovScoreAlpha[y][pos] += update*outConditional;
                        }
                    }

                    if(logs.scoreAlphaWriter != null) {
                        FileUtil.safeWrite(logs.scoreAlphaWriter, String.format("Seq: %d alpha[%d][%d] = %g = %g + Pr: %g * alpha[%d][%d] %g + Pr: %g * Score: %g\n",
                                seqNum, pos, y, scoreAlpha[y][pos], scoreAlpha[y][pos]-update, ep/allNodeProb[yprev][prevPos], prevPos, yprev, scoreAlpha[yprev][prevPos], ep, localScoreTransitionCache[trans][seqOffset+pos]));
                    }
                }
            }
            prevPos = pos;
        }
    }
494
495 void fillScoreCache(List<? extends TrainingSequence<?>> data) {
496 localScoreStableCache = new double[modelInfo.nStates][modelInfo.totalPositions];
497 localScoreTransitionCache = new double[modelInfo.nTransitions][modelInfo.totalPositions];
498
499 int overallPosition = 0;
500 for(int i=0; i<data.size(); ++i) {
501 TrainingSequence seq = data.get(i);
502 ++overallPosition;
503 for(int pos= 1; pos<seq.length(); ++pos ) {
504 for(int transition = 0; transition < modelInfo.nTransitions; ++transition) {
505 int from = modelInfo.transitionFrom[transition];
506 int to = modelInfo.transitionTo[transition];
507 double localScore = score.evaluate(from, to, seq, pos);
508 localScoreTransitionCache[transition][overallPosition] = localScore;
509 if(from == to)
510 localScoreStableCache[to][overallPosition] = localScoreStableCache[to][overallPosition-1] + localScore;
511 }
512 ++overallPosition;
513 }
514 }
515 }
516
517 private static final double minMax(final double val) {
518 double ret = (val <= 0.0) ? 0.0 : ((val >= 1.0) ? 1.0 : val);
519 if(Double.isNaN(ret)) {
520 Assert.a(false, "Min max called on "+val);
521 }
522 return ret;
523 }
524
525 private final void writeMarginals(int i, int len) {
526 if(logs.marginalsWriter != null) {
527 for(int pos=0; pos < len; ++pos) {
528 FileUtil.safeWrite(logs.marginalsWriter, String.format("Seq %d Pos %d -", i, pos));
529 for(short potential : modelInfo.orderedPotentials) {
530 if(potential < modelInfo.nStates) {
531 FileUtil.safeWrite(logs.marginalsWriter, String.format(" State %d: %e", potential, allNodeProb[potential][pos]));
532 }
533 else {
534 int trans = potential - modelInfo.nStates;
535 FileUtil.safeWrite(logs.marginalsWriter, String.format(" Edge %d-%d: %e", modelInfo.transitionFrom[trans], modelInfo.transitionTo[trans], allEdgeProb[trans][pos]));
536 }
537 }
538 FileUtil.safeWrite(logs.marginalsWriter, "\n");
539 }
540 }
541 }
542
    /** gets the local similarity score function used to score each position in every path.
     * @return the configured score function
     */
    public LocalPathSimilarityScore getScore() {
        return score;
    }

    /** sets the local similarity score function used to score each position in every path. This
     * is usually specified in the XML configuration file.
     * @param score the score function to use
     */
    public void setScore(LocalPathSimilarityScore score) {
        this.score = score;
    }
557
    /** gets the name of the trace file for score alpha computations.
     * @return the configured score alpha filename
     */
    public String getScoreAlphaFile() {
        return logs.scoreAlphaFile;
    }

    /** sets the name of a trace file to which score alpha computations are written during training.
     * @param scoreAlphaFile the filename to write to
     */
    public void setScoreAlphaFile(String scoreAlphaFile) {
        logs.scoreAlphaFile = scoreAlphaFile;
    }

    /** gets the name of the trace file for expected score-feature product computations.
     * @return the configured expected product filename
     */
    public String getExpectedProductFile() {
        return logs.expectedProductFile;
    }

    /** sets the name of a trace file to which expected score-feature product computations are written.
     * @param expectedProductFile the filename to write to
     */
    public void setExpectedProductFile(String expectedProductFile) {
        logs.expectedProductFile = expectedProductFile;
    }

    /** gets the name of the trace file for node and edge marginals.
     * @return the configured marginals filename
     */
    public String getMarginalsFile() {
        return logs.marginalsFile;
    }

    /** sets the name of a trace file to which node and edge marginals are written.
     * @param marginalsFile the filename to write to
     */
    public void setMarginalsFile(String marginalsFile) {
        logs.marginalsFile = marginalsFile;
    }
}