001    package calhoun.analysis.crf.solver.semimarkov;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import calhoun.analysis.crf.LocalPathSimilarityScore;
007    import calhoun.analysis.crf.ModelManager;
008    import calhoun.analysis.crf.io.TrainingSequence;
009    import calhoun.analysis.crf.scoring.SimScoreMaxStateAgreement;
010    import calhoun.analysis.crf.solver.CacheProcessor;
011    import calhoun.analysis.crf.solver.LookbackBuffer;
012    import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
013    import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
014    import calhoun.util.Assert;
015    import calhoun.util.ColtUtil;
016    import calhoun.util.FileUtil;
017    
/** Computes an objective function which is the expected value of a local path similarity score on a
 * semi-Markov model.  Requires a {@link CacheProcessor} and a {@link LocalPathSimilarityScore} to be configured.<p>
020     * <h2>Debugging output</h2>
021     * To get a better understanding of what the objective function is doing, several different properties can be set that
022     * cause the objective function to write out trace files showing its calculations during training.  Usually when turning
023     * these options on, you should set <code>maxIters = 1</code> and <code>requireConvergence = false</code> in your optimizer
024     * to do only a single training iteration, possibly setting the starts to some predetermined value.  Each of these
025     * properties can be configured with a filename and each time {@link #apply} is called, the file will be overwritten with 
026     * data from the current call.  The logging options are:
027     *
028     * <ul>
029     * <li> <b><code>alphaFile</code></b> - computation of alpha values for Markov states, includes all nodes and edges.
 * <li> <b><code>alphaLengthFile</code></b> - computation of alpha values for semi-Markov states, includes all segments
 * <li> <b><code>betaLengthFile</code></b> - computation of beta values for semi-Markov states, includes all segments
032     * <li> <b><code>expectFile</code></b> - computation of expected values for each Markov feature 
033     * <li> <b><code>expectLengthFile</code></b> - computation of expected values for each semi-Markov feature  
034     * <li> <b><code>nodeMarginalFile</code></b> - computation of marginal probability of each state at each position 
035     * </ul>
 *
 * <h4>Implementation Notes</h4>
038     * The general normalization scheme works as follows. When updating alpha values in the forward pass we compute segments
039     * of length 1 first and then work backwards.
040     * <p>
 * Instead of always normalizing to 1, we discretize the normalization. We choose an arbitrary normalization factor w,
 * such as 50. The normalizer at any position is then an integer v, and the true value of each entry at that position
 * is alpha[y]*e^(v*w), where alpha[y] is the stored value.
 * <p>
 * The normalizer can be computed at any position as follows: 1) sum the elements of the alpha array to get s;
 * 2) compute v = log(s)/w and convert it to an integer, so that v is always an appropriate normalizer (it may be
 * positive or negative); 3) divide all elements of the array by e^(v*w).
048     * 
049     */
050    public class CleanLocalScoreSemiMarkovGradient extends CleanMaximumLikelihoodSemiMarkovGradient {
	/** Similarity score whose expected value is the objective; defaults to {@link SimScoreMaxStateAgreement}. */
	LocalPathSimilarityScore score = new SimScoreMaxStateAgreement();

	// Score cache, presumably populated by fillScoreCache (TODO confirm layout there).
	// localScoreTransitionCache is read as [transition][overall position].
	// localScoreStableCache is read as [state][overall position] and differenced across a segment, so it
	// presumably holds running (cumulative) sums -- NOTE(review): verify.
	double[][] localScoreStableCache;
	double[][] localScoreTransitionCache;

	// Beta values [state][pos] and one integer normalizer per position [pos].
	double[][] betas;
	int[] betaNorms;
	// Marginal probabilities: edges are [transition][pos], nodes are [state][pos].
	double[][] allEdgeProb;
	double[][] allNodeProb;

	// Score-weighted forward/backward arrays, [state][pos].
	double[][] scoreAlpha;
	double[][] scoreBeta;
	// Segment-level counterparts, [state][pos]; only allocated when semiMarkov is true.
	double[][] semiMarkovScoreAlpha;
	double[][] semiMarkovScoreBeta;

	/** True when modelInfo.maxLookback > 1, i.e. the semi-Markov (length feature) code paths are active. */
	boolean semiMarkov;
068            
069            @Override
070            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
071                    super.setTrainingData(fm, data);
072    
073                    // Initialize arrays to hold betas and marginals
074                    betas = new double[modelInfo.nStates][modelInfo.longestSeq];
075                    betaNorms = new int[modelInfo.longestSeq];
076                    allEdgeProb = new double[modelInfo.nTransitions][modelInfo.longestSeq]; //[pos][yprev][y], not defined for pos=0
077                    allNodeProb = new double[modelInfo.nStates][modelInfo.longestSeq];
078                    scoreAlpha = new double[modelInfo.nStates][modelInfo.longestSeq];
079                    scoreBeta = new double[modelInfo.nStates][modelInfo.longestSeq];
080    
081                    semiMarkov = modelInfo.maxLookback > 1;
082                    if(semiMarkov) {
083                            semiMarkovScoreAlpha = new double[modelInfo.nStates][modelInfo.longestSeq];
084                            semiMarkovScoreBeta = new double[modelInfo.nStates][modelInfo.longestSeq];
085                    }
086    
087                    betaProcessor.setGlobalArrays(betas, betaNorms, allNodeProb, allEdgeProb);
088                    fillScoreCache(cacheProcessor.getData());
089            }
090    
091            @Override
092            public double apply(double[] param, double[] grad) {
093                    log.debug(String.format("Beginning It: %d Weights: %s", iter, ColtUtil.format(param)));
094                    lambda = param;
095                    double[] scoreFeatureProductExpectation = new double[grad.length];
096                    Arrays.fill(grad, 0);
097                    double result = 0.0;
098    
099                    try {
100                            // Iterate through sequences
101                            logs.open();
102                            for (int i = 0; i < modelInfo.nSeqs; ++i) {
103                                    Arrays.fill(expects, 0);
104    
105                                    int len = modelInfo.seqOffsets[i + 1] - modelInfo.seqOffsets[i];
106                                    alphaAndBetaPass(i, len);
107    
108                                    writeMarginals(i, len);
109    
110                                    scoreAlphaBeta(i, len);
111    
112                                    Arrays.fill(scoreFeatureProductExpectation, 0.0);
113                                    double thisResult = scoreFeatureExpections(i, len, scoreFeatureProductExpectation);
114                                    // Combine the various terms to update the gradient.
115                                    for(int j = 0; j<modelInfo.nFeatures; ++j) {
116                                            grad[j] += scoreFeatureProductExpectation[j] - thisResult * expects[j]; 
117                                            Assert.a(!Double.isNaN(grad[j]));
118                                    }
119                                    if(debug) {
120                                            log.debug(String.format("Iter: %d Seq: %d Expected Score: %g Grad: %s Expected Features: %s Expected Product: %s", iter, i, 
121                                                            thisResult, ColtUtil.format(grad), ColtUtil.format(expects), ColtUtil.format(scoreFeatureProductExpectation)));
122                                    }
123                                    result += thisResult;
124                            }
125    
126                            // Normalize by the number of positions
127                            for(int j = 0; j<modelInfo.nFeatures; ++j) {
128                                    grad[j] = grad[j]/modelInfo.totalPositions; 
129                            }
130                            result = result/modelInfo.totalPositions; 
131                            Assert.a(!Double.isNaN(result));
132                            
133                            if(debug) {
134                                    log.debug(String.format("Iter: %d Val: %g Grad: %s Weights: %s", iter, result, ColtUtil.format(grad), ColtUtil.format(lambda)));
135                            }
136                            else if(log.isInfoEnabled()) {
137                                    log.debug(String.format("Iter: %d Val: %g Grad: %s", iter, result, ColtUtil.format(grad)));
138                            }
139    
140                            iter += 1;
141                    }
142                    finally {
143                            logs.close();
144                    }
145                    return result;
146            }
147    
	/**
	 * For one sequence, accumulates into {@code scoreFeatureProductExpectation} the expected product of the local
	 * similarity score and each feature value, and returns the expected score accumulated over the Markov edges.
	 * <p>
	 * Position 0 is delegated to {@link #scorePosZero}.  For every later position, each edge's features and the
	 * node features associated with it are weighted by an "inner" value.  That value combines three terms:
	 * <ul>
	 * <li>the edge's own score times its marginal;</li>
	 * <li>the score-alpha of the source state at the previous position;</li>
	 * <li>the score-beta of the destination state at the current position.</li>
	 * </ul>
	 * The last two terms are each scaled by a conditional probability derived from the edge and node marginals.
	 * When {@code semiMarkov} is set, the same is done per segment for the length features.
	 * <p>
	 * Reads the marginals, betas, alphas and score-alpha/score-beta arrays, so those must already hold values for
	 * this sequence.  {@link #apply} calls alphaAndBetaPass and scoreAlphaBeta first, which presumably fill them.
	 *
	 * @param seqNum index of the sequence
	 * @param len length of the sequence
	 * @param scoreFeatureProductExpectation per-feature accumulator, updated in place
	 * @return sum over positions 1..len-1 of edge marginal times localScoreTransitionCache, counting only edges
	 *         whose adjacent node marginals are positive
	 */
	private final double scoreFeatureExpections(int seqNum, int len, double[] scoreFeatureProductExpectation) {
		// Now we need to use those alpha and beta vectors to compute the expectations for product of score and features
		// In the process we also need to compute the result
		double result = 0.0;
		int seqOffset = modelInfo.seqOffsets[seqNum];

		scorePosZero(seqNum, scoreFeatureProductExpectation);

		// Put an empty entry in the lookback so the first base has 0's initialized.
		double[] stableState = nextBuffer.stableState;
		Arrays.fill(stableState, 0.0);
		nextBuffer = lookbackBuffer.addFirst(nextBuffer);

		int prevPos = 0;
		for(int pos = 1; pos < len; ++pos) {
			int overallPosition = seqOffset + pos;
			cacheProcessor.evaluatePosition(seqNum, pos);

			// Compute this position's potentials from the previous stable state, then push the buffer into the
			// lookback history so segment code below can reach earlier positions.
			cacheMi(seqNum, nextBuffer.mi, stableState, nextBuffer.stableState, pos);
			stableState = nextBuffer.stableState;
			nextBuffer = lookbackBuffer.addFirst(nextBuffer);

			// Handle the markov features & update the expected value of the similarity score.
			// nodeEvals holds the most recently visited node potential; each edge potential below applies it as its
			// node features.  This assumes orderedPotentials lists a node before the edges into it -- TODO confirm.
			FeatureEvaluation nodeEvals = null;
			for(short potential : modelInfo.orderedPotentials) {
				if(potential < modelInfo.nStates) {
					nodeEvals = evals[potential];
				}
				else {
					// This is an edge potential
					int trans = potential - modelInfo.nStates;
					int yprev = modelInfo.transitionFrom[trans];
					int y = modelInfo.transitionTo[trans];
					// The edge marginal is read at prevPos alongside node marginals at prevPos and pos, so
					// allEdgeProb[trans][p] appears to be the marginal of the transition p -> p+1 (TODO confirm).
					double ep = allEdgeProb[trans][prevPos];
					double npPrev = allNodeProb[yprev][prevPos];
					double np = allNodeProb[y][pos];
					if(np > 0 && npPrev > 0 && ep >= 0) {
						double scoreExpect = ep * localScoreTransitionCache[trans][overallPosition];
						result += scoreExpect;

						// Computing marginals from node and edge probabilities is not really safe.  Clamp to a valid probability.
						// conditionalIn ~ P(y at pos | yprev at prevPos), conditionalOut ~ P(yprev at prevPos | y at pos).
						double conditionalIn = Math.min(1.0, ep / npPrev);
						double conditionalOut = Math.min(1.0, ep / np);

						// inner = edge score * P(edge) + P(y|yprev) * scoreAlpha[yprev][prevPos] + P(yprev|y) * scoreBeta[y][pos]
						double inner = scoreExpect;
						inner += conditionalIn*scoreAlpha[yprev][prevPos];
						inner += conditionalOut*scoreBeta[y][pos];

						// Nothing to add; note this also skips the node features for this edge.
						if(inner == 0.0)
							continue;

						// Edges
						FeatureEvaluation potEvals = evals[potential];
						short[] indices = potEvals.index;
						float[] vals = potEvals.value;

						// Feature index lists are terminated by -1.
						int fCount = 0;
						short index = indices[fCount];
						while(index != -1) {
							// FeatureValue * Score * Marginal Prob
							scoreFeatureProductExpectation[index] += inner*vals[fCount];
							if(logs.expectedProductWriter != null) {
								FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d Edge: %d-%d\tFeat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n",
										seqNum, pos, yprev, y, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount],
										localScoreTransitionCache[trans][overallPosition], ep, scoreAlpha[yprev][prevPos], conditionalIn,
										scoreBeta[y][pos], conditionalOut));
							}
							index = indices[++fCount];
						}

						// Nodes
						indices = nodeEvals.index;
						vals = nodeEvals.value;

						fCount = 0;
						index = indices[fCount];
						while(index != -1) {
							// FeatureValue * Score * Marginal Prob
							scoreFeatureProductExpectation[index] += inner*vals[fCount];
							if(logs.expectedProductWriter != null) {
								FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d State: %d-%d\tNode Feat: %d = %g = %g + Val: %g * (s: %g * ep: %g + a: %g * ms: %g + b: %g * me: %g)\n",
										seqNum, pos, yprev, y, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount],
										localScoreTransitionCache[trans][overallPosition], ep, scoreAlpha[yprev][prevPos], conditionalIn,
										scoreBeta[y][pos], conditionalOut));
							}
							index = indices[++fCount];
						}
					}
				}
			}

			// Handle the semi-markov features - We need to calculate the probability of each segment
			if(semiMarkov) {
				cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);

				int nSemiMarkovStates = modelInfo.statesWithLookback.length;
				for(int i=0; i<nSemiMarkovStates; ++i) {
					LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
					CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
					byte toNode = statePotentials.state;

					// Walk every candidate segment length for toNode ending at pos; the list ends at lookback == -1.
					int lbIndex=0;
					LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
					int lookback = lengthEval.lookback;
					while(lookback != -1) {
						// The segment covers beginPos+1 .. pos (lookback+1 positions); beginPos == -1 means it starts at 0.
						int beginPos = pos - lookback - 1;
						Assert.a(lengthEval.edgeEvals == null);

						// Score of staying in toNode over the segment, taken as a difference of the stable cache.
						double stableScore = localScoreStableCache[toNode][overallPosition] - localScoreStableCache[toNode][overallPosition-lookback];
						double beta = betas[toNode][pos];
						int betaNorm = betaNorms[pos];

						// For speed I hand inline RecyclingBuffer.get
						// segBegin is the lookback buffer entry 'lookback' positions back, i.e. at the segment start.
						LookbackBuffer segBegin = lookbackBuffer.array[(lookbackBuffer.currentStart+lookback)%lookbackBuffer.length];
						double stableValue = stableState[toNode] - segBegin.stableState[toNode];

						// Add in the length based features
						// Handle evaluation of the node potentials
						double nodeValue = stableValue;
						nodeEvals = lengthEval.nodeEval;
						short[] indices = nodeEvals.index;
						float[] vals = nodeEvals.value;
						int ix = 0;
						short index = indices[ix];
						while(index >= 0) {
							nodeValue += vals[ix] * lambda[index];
							index = indices[++ix];
						}

						// Look at all the transitions, calculate an inner value for each and sum.  We multiply the whole sum
						// by our observed node length features
						double inner = 0.0;
						if(beginPos == -1) {
							// Segment starts at position 0: no incoming transition, so starterAlpha is used instead of
							// a previous alpha.  A norm of Integer.MIN_VALUE is treated as zero probability.
							double segProb = beta*zInv*exp(starterAlpha[toNode] + nodeValue + NORM_FACTOR*(betaNorm - zNorm));
							if(betaNorm == Integer.MIN_VALUE)
								segProb = 0.0;

							inner += segProb * stableScore;
							// Probability that a toNode segment ends at pos: node marginal minus the self-transition
							// marginal out of pos (given the edge indexing assumed above).
							double segEndMarg = allNodeProb[toNode][pos]-allEdgeProb[modelInfo.selfTransitions[toNode]][pos];
							double conditionalOut = (segEndMarg > 0) ? minMax(segProb/segEndMarg) : 0.0;
							inner += conditionalOut*semiMarkovScoreBeta[toNode][pos];
							if(logs.expectedProductWriter != null) {
								FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: 0-%d State: %d\t Inner %g = cOut: %g * sb: %g + stabScore: %g * p: %g "+
										"(b: %g * zInv %g * exp(nodeL:%g + stable: %g + FACTOR*(normB: %d + normZ: %d)))\n",
										seqNum, pos, toNode, inner, conditionalOut, semiMarkovScoreBeta[toNode][pos], stableScore, segProb,
										beta, zInv, nodeValue-stableValue, stableValue+starterAlpha[toNode], betaNorm, zNorm));
							}
						}
						else {
							int nEdges = statePotentials.potentials.length;
							for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
								int potential = statePotentials.potentials[edgeIx];
								int trans = potential - modelInfo.nStates;
								int fromNode = modelInfo.transitionFrom[trans];

								Assert.a(lengthEval.edgeEvals == null, "Explicit length edge features not supported.");

								// Skip semi-Markov self transitions
								if(fromNode == toNode)
									continue;

								// Skip transitions flagged invalid at the segment's first position.
								int invalidIndex = (seqOffset+beginPos+1)*modelInfo.nPotentials;
								if(invalidTransitions[invalidIndex + potential]) {
									continue;
								}

								double prevAlpha = alphas[beginPos][fromNode];
								int prevAlphaNorm = alphaNorms[beginPos];

								// Probability of this segment entered from fromNode, undoing the discretized normalization of
								// alpha, beta and Z.  A norm of Integer.MIN_VALUE is treated as zero probability.
								double transitionValue = segBegin.mi[trans];
								double segProb = prevAlpha * beta * zInv * exp(nodeValue + transitionValue + NORM_FACTOR*(prevAlphaNorm + betaNorm - zNorm));
								if(prevAlphaNorm == Integer.MIN_VALUE || betaNorm == Integer.MIN_VALUE)
									segProb = 0.0;
								// NOTE(review): this message reports prevPos (= pos-1), not the segment start beginPos+1.
								if(Double.isNaN(segProb))
									Assert.a(false, "Bad Segment Prob. Seq ", seqNum, " Pos ",prevPos, "-", pos);

								// Segment score = transition score at the segment start + stable score across the segment.
								double segmentScore = localScoreTransitionCache[trans][seqOffset + beginPos+1] + stableScore;
								inner += segProb * segmentScore;

								// Probability that a fromNode segment ends at beginPos; nothing is subtracted when
								// maxStateLengths[fromNode] <= 1.
								double prevSegMarg = allNodeProb[fromNode][beginPos] - ((modelInfo.maxStateLengths[fromNode]>1) ?
										allEdgeProb[modelInfo.selfTransitions[fromNode]][beginPos] : 0);
								double conditionalIn = (prevSegMarg > 0) ? minMax(segProb/prevSegMarg) : 0.0;
								inner += conditionalIn*semiMarkovScoreAlpha[fromNode][beginPos];

								double segEndMarg = allNodeProb[toNode][pos]-allEdgeProb[modelInfo.selfTransitions[toNode]][pos];
								double conditionalOut = (segEndMarg > 0) ? minMax(segProb/segEndMarg) : 0.0;
								inner += conditionalOut*semiMarkovScoreBeta[toNode][pos];
								// NOTE(review): 'stab' is formatted with %s rather than %g, unlike the other values.
								if(logs.expectedProductWriter != null) {
									FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d-%d State: %d-%d\t Inner %g = cIn: %g * sa: %g + cOut: %g * sb: %g + score: %g (trans: %g + stab: %s) * p: %g "+
											"(a: %g * b: %g * zInv %g * exp(nodeL:%g + stable: %g + trans: %g + FACTOR*(normA: %d + normB: %d + normZ: %d)))\n",
											seqNum, beginPos+1, pos, fromNode, toNode, inner, conditionalIn, semiMarkovScoreAlpha[fromNode][beginPos], conditionalOut, semiMarkovScoreBeta[toNode][pos], segmentScore, localScoreTransitionCache[trans][seqOffset + beginPos+1], stableScore, segProb,
											prevAlpha, beta, zInv, nodeValue-stableValue, stableValue, transitionValue, prevAlphaNorm, betaNorm, zNorm));
								}
							}
						}

						// Once the inner value has been computed, multiply it by all of the observed feature values.
						FeatureEvaluation lengthNodeEvals = lengthEval.nodeEval;
						indices = lengthNodeEvals.index;
						vals = lengthNodeEvals.value;

						ix = 0;
						index = indices[ix];
						while(index >= 0) {
							scoreFeatureProductExpectation[index] += inner*vals[ix];
							if(logs.expectedProductWriter != null) {
								FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: %d-%d State: %d\tFeat: %d = %g = %g + Val: %g * Inner: %g\n",
										seqNum, beginPos+1, pos, toNode, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[ix], vals[ix], inner));
							}
							index = indices[++ix];
						}
						Assert.a(lengthEval.edgeEvals == null, "Explicit length edges are not supported.");

						++lbIndex;
						lengthEval = lookbacksForState[lbIndex];
						lookback = lengthEval.lookback;
					}
				}
			}
			prevPos = pos;
		}
		return result;
	}
371                    
372            void scorePosZero(int seqNum, double[] scoreFeatureProductExpectation) {
373                    cacheProcessor.evaluatePosition(seqNum, 0);
374                    for(int state = 0; state < modelInfo.nStates; ++state) {
375                            // This is a node potential
376                            double inner = scoreBeta[state][0];
377                            FeatureEvaluation potEvals = evals[state];
378                            short[] indices = potEvals.index;
379                            float[] vals = potEvals.value;
380    
381                            int fCount = 0;
382                            short index = indices[fCount];
383                            while(index != -1) {
384                                    // FeatureValue * Score * Marginal Prob
385                                    scoreFeatureProductExpectation[index] += inner*vals[fCount]; 
386                                    if(logs.expectedProductWriter != null) {
387                                            FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: 0 State: %d\tFeat: %d = %g = %g + Val: %g * Beta[%d][%d]: %g:\n",
388                                                            seqNum, state, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount], 0, state, inner));
389                                    }
390                                    index = indices[++fCount];
391                            }
392                    }
393    
394                    // Handle the semi-markov features - We avoid having to calculate segment probabilities
395                    if(semiMarkov) {
396                            cacheProcessor.evaluateSegmentsEndingAt(seqNum, 0);
397            
398                            int nSemiMarkovStates = modelInfo.statesWithLookback.length;
399                            for(int i=0; i<nSemiMarkovStates; ++i) {
400                                    LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
401                                    LengthFeatureEvaluation lengthEval = lookbacksForState[0];
402                                    if(lengthEval.lookback != -1) {
403                                            Assert.a(lengthEval.lookback == 0);
404                                            CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
405                                            byte state = statePotentials.state;
406            
407                                            double inner = semiMarkovScoreBeta[state][0];
408                                            short[] indices = lengthEval.nodeEval.index;
409                                            float[] vals = lengthEval.nodeEval.value;
410                                            int fCount = 0;
411                                            short index = indices[fCount];
412                                            while(index != -1) {
413                                                    // FeatureValue * Score * Marginal Prob
414                                                    scoreFeatureProductExpectation[index] += inner*vals[fCount]; 
415                                                    if(logs.expectedProductWriter != null) {
416                                                            FileUtil.safeWrite(logs.expectedProductWriter, String.format("Seq: %d Pos: 0 State: %d\tLen. Feat: %d = %g = %g + Val: %g * Beta[%d][%d]: %g:\n",
417                                                                            seqNum, state, index, scoreFeatureProductExpectation[index], scoreFeatureProductExpectation[index]-inner*vals[fCount], vals[fCount], 0, state, inner));
418                                                    }
419                                                    index = indices[++fCount];
420                                            }
421                                            Assert.a(lookbacksForState[1].lookback == -1);
422                                    }
423                            }
424                    }
425            }
426    
427            private final void scoreAlphaBeta(int seqNum, int len) {
428                    int seqOffset = modelInfo.seqOffsets[seqNum];
429                    
430                    // Do another backward & forward pass to compute the score alpha & betas
431                    // scoreBeta is being defined from 0 to len-1
432                    int pos = len-1;
433                    int prevPos;
434                    for (int y=0; y<modelInfo.nStates; y++) {
435                            Arrays.fill(scoreBeta[y], 0.0);
436                            if(semiMarkov)
437                                    Arrays.fill(semiMarkovScoreBeta[y], 0.0);
438                    }
439                    for (prevPos = len-2; prevPos >= 0; --prevPos) {
440                            for (int trans=0; trans<modelInfo.nTransitions; trans++) {
441                                    int yprev = modelInfo.transitionFrom[trans];
442                                    int y = modelInfo.transitionTo[trans];
443                                    double ep = allEdgeProb[trans][prevPos];
444                                    double np = allNodeProb[y][pos];
445                                    if(np > 0 && ep >= 0) {
446                                            // Computing marginals from node and edge probabilities is not really safe.  Clamp to a valid probability.
447                                            double conditional = Math.min(1.0, ep / np);
448                                            double update = ep*localScoreTransitionCache[trans][seqOffset+pos] + conditional*scoreBeta[y][pos];
449                                            scoreBeta[yprev][prevPos] += update;
450                                            if(modelInfo.maxStateLengths[y]>1 && y != yprev)
451                                                    semiMarkovScoreBeta[yprev][prevPos] += update;
452                                    }                                       
453                            }
454                            pos = prevPos;
455                    }
456                    
457                    // scoreAlpha is being defined from 0 to len-2 (we never use the last alpha position)
458                    prevPos = 0;
459                    for (int y=0; y<modelInfo.nStates; y++) {
460                            Arrays.fill(scoreAlpha[y], 0.0);
461                            if(semiMarkov)
462                                    Arrays.fill(semiMarkovScoreAlpha[y], 0.0);
463                    }
464                    for (pos=1; pos<len-1; pos++) {
465                            for (int trans=0; trans<modelInfo.nTransitions; trans++) {
466                                    int yprev = modelInfo.transitionFrom[trans];
467                                    int y = modelInfo.transitionTo[trans];
468                                    double ep = allEdgeProb[trans][prevPos];
469                                    double np = allNodeProb[yprev][prevPos];
470                                    if(np > 0 && ep >= 0) {
471                                            // Computing marginals from node and edge probabilities is not really safe.  Clamp to a valid probability.
472                                            double conditional = Math.min(1.0, ep / np);
473    
474                                            double update = conditional*scoreAlpha[yprev][prevPos];
475                                            update += ep*localScoreTransitionCache[trans][seqOffset+pos];
476                                            scoreAlpha[y][pos] += update;
477                                            if(modelInfo.maxStateLengths[y]>1) {
478                                                    double nodeMarg = allNodeProb[y][pos];
479                                                    if(nodeMarg > 0) {
480                                                            double outConditional = minMax((nodeMarg-allEdgeProb[modelInfo.selfTransitions[y]][pos])/nodeMarg);
481                                                            semiMarkovScoreAlpha[y][pos] += update*outConditional;
482                                                    }
483                                            }
484                                                    
485                                            if(logs.scoreAlphaWriter != null) {
486                                                    FileUtil.safeWrite(logs.scoreAlphaWriter, String.format("Seq: %d alpha[%d][%d] = %g = %g + Pr: %g * alpha[%d][%d] %g + Pr: %g * Score: %g\n",
487                                                                    seqNum, pos, y, scoreAlpha[y][pos], scoreAlpha[y][pos]-update, ep/allNodeProb[yprev][prevPos], prevPos, yprev, scoreAlpha[yprev][prevPos], ep, localScoreTransitionCache[trans][seqOffset+pos]));
488                                            }
489                                    }
490                            }       
491                            prevPos = pos;
492                    }
493            }
494            
495            void fillScoreCache(List<? extends TrainingSequence<?>> data) {
496                    localScoreStableCache = new double[modelInfo.nStates][modelInfo.totalPositions];
497                    localScoreTransitionCache = new double[modelInfo.nTransitions][modelInfo.totalPositions];
498                    
499                    int overallPosition = 0;
500                    for(int i=0; i<data.size(); ++i) {
501                            TrainingSequence seq = data.get(i);
502                            ++overallPosition;
503                            for(int pos= 1; pos<seq.length(); ++pos ) {
504                                    for(int transition = 0; transition < modelInfo.nTransitions; ++transition) {
505                                            int from = modelInfo.transitionFrom[transition];
506                                            int to = modelInfo.transitionTo[transition];
507                                            double localScore = score.evaluate(from, to, seq, pos);
508                                            localScoreTransitionCache[transition][overallPosition] = localScore;
509                                            if(from == to)
510                                                    localScoreStableCache[to][overallPosition] = localScoreStableCache[to][overallPosition-1] + localScore;
511                                    }
512                                    ++overallPosition;
513                            }
514                    }
515            }
516            
517            private static final double minMax(final double val) {
518                    double ret = (val <= 0.0) ? 0.0 : ((val >= 1.0) ? 1.0 : val);
519                    if(Double.isNaN(ret)) {
520                            Assert.a(false, "Min max called on "+val);
521                    }
522                    return ret;
523            }
524            
525            private final void writeMarginals(int i, int len) {
526                    if(logs.marginalsWriter != null) {
527                            for(int pos=0; pos < len; ++pos) {
528                                    FileUtil.safeWrite(logs.marginalsWriter, String.format("Seq %d Pos %d -", i, pos));
529                                    for(short potential : modelInfo.orderedPotentials) {
530                                            if(potential < modelInfo.nStates) {
531                                                    FileUtil.safeWrite(logs.marginalsWriter, String.format(" State %d: %e", potential, allNodeProb[potential][pos]));
532                                            }
533                                            else {
534                                                    int trans = potential - modelInfo.nStates;
535                                                    FileUtil.safeWrite(logs.marginalsWriter, String.format(" Edge %d-%d: %e", modelInfo.transitionFrom[trans], modelInfo.transitionTo[trans], allEdgeProb[trans][pos]));
536                                            }
537                                    }
538                                    FileUtil.safeWrite(logs.marginalsWriter, "\n");
539                            }
540                    }
541            }
542            
543            /** gets the local similarity score function used to score each position in every path.
544             * @return the configured score function
545             */
546            public LocalPathSimilarityScore getScore() {
547                    return score;
548            }
549    
550            /** sets the local similarity score function used to score each position in every path.  This
551             * is usually specified in the XML configuration file.
552             * @param score the score function to use
553             */
554            public void setScore(LocalPathSimilarityScore score) {
555                    this.score = score;
556            }
557    
558            public String getScoreAlphaFile() {
559                    return logs.scoreAlphaFile;
560            }
561    
562            public void setScoreAlphaFile(String scoreAlphaFile) {
563                    logs.scoreAlphaFile = scoreAlphaFile;
564            }
565    
566            public String getExpectedProductFile() {
567                    return logs.expectedProductFile;
568            }
569    
570            public void setExpectedProductFile(String expectedProductFile) {
571                    logs.expectedProductFile = expectedProductFile;
572            }
573    
574            public String getMarginalsFile() {
575                    return logs.marginalsFile;
576            }
577            
578            public void setMarginalsFile(String marginalsFile) {
579                    logs.marginalsFile = marginalsFile;
580            }
581    }