001    package calhoun.analysis.crf.solver;
002    
003    import java.io.BufferedWriter;
004    import java.util.Arrays;
005    import java.util.List;
006    
007    import org.apache.commons.logging.Log;
008    import org.apache.commons.logging.LogFactory;
009    
010    import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
011    import calhoun.analysis.crf.ModelManager;
012    import calhoun.analysis.crf.io.TrainingSequence;
013    import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
014    import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
015    import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
016    import calhoun.util.Assert;
017    import calhoun.util.ColtUtil;
018    import calhoun.util.FileUtil;
019    
020    
021    /** computes the likelihood of the true path for a semi-Markov CRF.  The likelihood is normalized to a per label likelihood. 
022     * <h2>Debugging output</h2>
 * To get a better understanding of what the objective function is doing, several different properties can be set that
024     * cause the objective function to write out trace files showing its calculations during training.  Usually when turning
025     * these options on, you should set <code>maxIters = 1</code> and <code>requireConvergence = false</code> in your optimizer
026     * to do only a single training iteration, possibly setting the starts to some predetermined value.  Each of these
027     * properties can be configured with a filename and each time {@link #apply} is called, the file will be overwritten with 
028     * data from the current call.  The logging options are:
029     * <ul>
030     * <li> <b><code>alphaFile</code></b> - computation of alpha values for Markov states, includes all nodes and edges.
 * <li> <b><code>alphaLengthFile</code></b> - computation of alpha values for semi-Markov states, includes all segments
032     * <li> <b><code>expectFile</code></b> - computation of expected values for each Markov feature 
033     * <li> <b><code>expectLengthFile</code></b> - computation of expected values for each semi-Markov feature  
034     * <li> <b><code>nodeMarginalFile</code></b> - computation of marginal probability of each state at each position 
035     * </ul>
036     * */
037    //public class MaximumLikelihoodSemiMarkovGradient extends CleanMaximumLikelihoodSemiMarkovGradient {
038    //}
039    public class MaximumLikelihoodSemiMarkovGradient implements CRFObjectiveFunctionGradient {
040            private static final Log log = LogFactory.getLog(MaximumLikelihoodSemiMarkovGradient.class);
041            private static final boolean debug = log.isDebugEnabled();
042            private static final double ASSERTION_TOLERANCE = 0.0001;
043            
044            private static final int NORM_FACTOR = 50;
045            private static final double NORM_MIN = Math.exp(-NORM_FACTOR);
046            private static final double NORM_MAX = Math.exp(NORM_FACTOR);
047    
048            String alphaFile = null;
049            String alphaLengthFile = null;
050            String betaLengthFile = null;
051            String expectFile = null;
052            String expectLengthFile = null;
053            String nodeMarginalFile = null;
054            BufferedWriter alphaWriter = null;
055            BufferedWriter alphaLengthWriter = null;
056            BufferedWriter betaLengthWriter = null;
057            BufferedWriter expectWriter = null;
058            BufferedWriter expectLengthWriter = null;
059            BufferedWriter nodeMarginalWriter = null;
060            
061            SolverSetup modelInfo;
062            CacheProcessor cacheProcessor;
063            FeatureEvaluation[] evals;
064            LengthFeatureEvaluation[][] lengthEvals;
065            boolean[] invalidTransitions;
066            
067            // / Cache feature information
068            // / Cached value of the Mi matrix for all of the features present at every position
069            // / Mi is stored as a sparse matrix
070            short maxLookback;
071            CacheProcessor.StatePotentials[] statesWithLookback;
072            CacheProcessor.StatePotentials[] statesWithoutLookback;
073            List<TrainingSequence> data;
074            int iter = 0;
075            double[][] alphas;
076            int[] alphaNorms;
077            double[] starterAlpha;
078    
079            // At any given point, lookbackBuffer.get(x) returns the information about a lookback of x. Lookbacks start at 0.
080            RecyclingBuffer<LookbackBuffer> lookbackBuffer;
081            LookbackBuffer nextBuffer;
082    
083            double[] lambda;
084            double logZ;
085            int zNorm;
086            double zInv;
087            double[] expects;
088    
089            AlphaLengthFeatureProcessor alphaProcessor; 
090            BetaLengthFeatureProcessor betaProcessor; 
091            
092            // We publish feature sums 
093            private double[] featureSums;
094            
095            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
096                    cacheProcessor.setTrainingData(fm, data);
097                    modelInfo = cacheProcessor.getSolverSetup();
098                    Assert.a(modelInfo.maxStateLengths != null, "Maximum state lengths not set.");
099                    Assert.a(modelInfo.maxStateLengths.length == modelInfo.nStates, "Maximum state lengths array was length ("+modelInfo.maxStateLengths.length+").  Must have one entry for each state "+modelInfo.nStates+")");
100                    evals = cacheProcessor.getFeatureEvaluations();
101                    lengthEvals = cacheProcessor.getLengthFeatureEvaluations();
102                    invalidTransitions = cacheProcessor.getInvalidTransitions();
103    
104                    // Create local references to cache values
105                    maxLookback = modelInfo.maxLookback;
106                    statesWithLookback = modelInfo.statesWithLookback;
107                    statesWithoutLookback = modelInfo.statesWithoutLookback;
108    
109                    // Initialize betas (for use later, in the gradient computation)
110                    alphas = new double[modelInfo.longestSeq][modelInfo.nStates];
111                    alphaNorms = new int[modelInfo.longestSeq];
112                    expects = new double[modelInfo.nFeatures];
113    
114                    LookbackBuffer[] bufferContents = new LookbackBuffer[maxLookback+3];
115                    for(int i = 0; i<maxLookback+3; ++i) {
116                            bufferContents[i] = new LookbackBuffer();
117                    }
118                    lookbackBuffer = new RecyclingBuffer<LookbackBuffer>(bufferContents);
119                    nextBuffer = new LookbackBuffer();
120                    
121                    alphaProcessor = new AlphaLengthFeatureProcessor(); 
122                    betaProcessor = new BetaLengthFeatureProcessor(); 
123    
124                    starterAlpha = new double[modelInfo.nStates];
125            }
126    
127            public double apply(double[] param, double[] grad) {
128                    log.debug(String.format("Beginning It: %d Weights: %s", iter, ColtUtil.format(param)));
129                    alphaWriter = FileUtil.safeOpen(alphaFile);
130                    alphaLengthWriter = FileUtil.safeOpen(alphaLengthFile);
131                    betaLengthWriter = FileUtil.safeOpen(betaLengthFile);
132                    expectWriter = FileUtil.safeOpen(expectFile);
133                    expectLengthWriter = FileUtil.safeOpen(expectLengthFile);
134                    nodeMarginalWriter = FileUtil.safeOpen(nodeMarginalFile);
135                    lambda = param;
136                    Arrays.fill(grad, 0);
137                    double totalZ = 0.0;
138                    double result = 0.0;
139    
140                    try {
141                            // Iterate through sequences
142                            Arrays.fill(expects, 0);
143                            for (int i = 0; i < modelInfo.nSeqs; ++i) {
144                                    int len = modelInfo.seqOffsets[i + 1] - modelInfo.seqOffsets[i];
145                                    // Work forwards, computing alphas
146                                    alphaProcessor.computeAlpha(i, len);
147            
148                                    // Since the final beta array is all ones, we can sum the alphas to get the Z
149                                    double sum = 0.0;
150                                    for (double val : alphas[len - 1]) {
151                                            sum += val;
152                                    }
153            
154                                    logZ = log(sum) + NORM_FACTOR * (alphaNorms[len - 1]);
155                                    zNorm = ((int) logZ) / NORM_FACTOR;
156                                    zInv = exp(zNorm * NORM_FACTOR - logZ);
157                                    //log.debug("Seq: "+i+" Z: "+printNorm(1/zInv, zNorm));
158            
159                                    // Work backwards, computing betas and expectations.
160                                    betaProcessor.computeBetasAndExpectations(i, len);
161            
162                                    if(log.isDebugEnabled()) {
163                                            double[][] seqFeatureSums = cacheProcessor.getSequenceFeatureSums();
164                                            if(seqFeatureSums != null) {
165                                                    double seqResult = 0.0;
166                                                    for (int j = 0; j < modelInfo.nFeatures; ++j) {
167                                                            seqResult += seqFeatureSums[i][j] * param[j];
168                                                    }
169                                                    log.debug(String.format("Seq: %d L: %g LL: %f Training path: %f Z: %f", i, exp(seqResult-logZ), seqResult-logZ, seqResult, logZ));
170                                                    Assert.a(exp(seqResult-logZ) < 1.0);
171                                            }
172                                    }
173                                    
174                                    // Update for the next sequence
175                                    totalZ += logZ;
176                            }
177                            
178                            // sum_j lambda_j F_j(xk, yk)
179                            double[] featureSums = cacheProcessor.getFeatureSums();
180                            this.featureSums = featureSums;
181                            for (int j = 0; j < modelInfo.nFeatures; ++j) {
182                                    result += featureSums[j] * param[j];
183                                    grad[j] = featureSums[j] - expects[j];
184                            }
185                            log.debug("Path Value: "+result+" Norm: "+totalZ);
186                            result -= totalZ;
187                            if (log.isInfoEnabled()) {
188                                    if(log.isDebugEnabled()) {
189                                            log.debug(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s", iter, exp(result), result,
190                                                            ColtUtil.norm(grad), ColtUtil.format(featureSums), ColtUtil.format(expects), ColtUtil.format(param), ColtUtil.format(grad)));
191                                    }
192                                    else {
193                                            log.info(String.format("It: %d LL=%f, norm(grad): %f", iter, exp(result), result, ColtUtil.norm(grad)));
194                                    }
195                            }
196                            Assert.a(exp(result) <= 1.0, "Likelihood is greater than 1.");
197                            result = result/modelInfo.totalPositions; 
198                            for(int i=0; i<grad.length; ++i) {
199                                    grad[i] = grad[i]/modelInfo.totalPositions;
200                            }
201                            iter += 1;
202                    }
203                    finally {
204                            FileUtil.safeClose(alphaWriter);
205                            FileUtil.safeClose(alphaLengthWriter);
206                            FileUtil.safeClose(betaLengthWriter);
207                            FileUtil.safeClose(expectWriter);
208                            FileUtil.safeClose(expectLengthWriter);
209                            FileUtil.safeClose(nodeMarginalWriter);
210                    }
211                            
212                    return result;
213            }
214    
215            public void clean() {
216                    // Clean up as much as possible
217                    modelInfo = null;
218                    cacheProcessor = null;
219                    evals = null;
220                    lengthEvals = null;
221                    invalidTransitions = null;
222                    statesWithLookback = null;
223                    statesWithoutLookback = null;
224                    data = null;
225                    alphas = null;
226                    alphaNorms = null;
227                    starterAlpha = null;
228                    lookbackBuffer = null;
229                    nextBuffer = null;
230                    expects = null;
231                    alphaProcessor = null; 
232                    betaProcessor = null; 
233            }
234    
235            private final class AlphaLengthFeatureProcessor {
236                    int seqOffset;
237                    int pos;
238                    double[] alpha;
239                    int alphaNorm;
240                    double[] stableState;
241                    
242                    /**
243                     * In the forward pass we compute alpha values and expections. This is simpler than the backwards pass because
244                     * the cache is set up for us so that we can always look at one position at a time. We have to cache previous
245                     * values but we never have to look ahead.
246                     */
247                    final void computeAlpha(final int seqNum, final int len) {
248                            // Result the alpha norms
249                            Arrays.fill(alphaNorms, Integer.MIN_VALUE);
250    
251                            Arrays.fill(starterAlpha, 0.0);
252                            double[] prevAlpha = null;
253                            
254                            seqOffset = modelInfo.seqOffsets[seqNum];
255    
256                            for(pos = 0; pos < len; ++pos) {
257                                    prevAlpha = alpha;
258                                    alpha = alphas[pos];
259                                    Arrays.fill(alpha, 0.0);
260                                    if (pos == 0) {
261                                            alphaNorm = 0;
262                                            calcStartAlpha(alpha, seqNum);
263    
264                                            // Put an empty entry in the lookback so the first base has 0's initialized.
265                                            Arrays.fill(nextBuffer.stableState, 0.0);
266                                    } else {
267                                            cacheMi(seqNum, nextBuffer.mi, stableState, nextBuffer.stableState, pos);
268                                            regularAlphaUpdate(pos, nextBuffer.mi, prevAlpha, alpha);
269                                    }
270                                    
271                                    // Add the lookback into the array
272                                    stableState = nextBuffer.stableState;
273                                    nextBuffer = lookbackBuffer.addFirst(nextBuffer);
274                                            
275                                    // Now we need to loop through for the length dependent cache
276                                    lengthAlpha(seqNum, pos);
277    
278                                    int norm = normalize(alpha);
279                                    /*if(norm != 0) {
280                                            log.info("Pos: "+pos+" Renormalized alpha by "+norm+" to "+(alphaNorm+norm)+" : "+ColtUtil.format(alpha));
281                                    }*/
282                                    alphaNorm += norm; 
283                                    alphaNorms[pos] = alphaNorm;
284    //                              if(alphaNorm > 0 && pos > 0 && alphaNorms[pos-1]<=0)
285    //                                      log.info("Norm prob at pos: "+pos);
286                            }
287                    }
288    
289                    /**
290                     * Updates the alpha vector for non-length dependent states. We don't have to worry about normalization here
291                     * because the regular alpha update is done before the length dependent, so these are the first values that will
292                     * be set.
293                     */
294                    private final void regularAlphaUpdate(final int pos, final double[] mi, final double[] lastAlpha, final double[] newAlpha) {
295                            double nodeVal = 0.0;
296                            int lastState = -1;
297                            boolean lengthNode = false;
298                            for (short potential : modelInfo.orderedPotentials) {
299                                    if (potential < modelInfo.nStates) {
300                                            if (lastState != -1) {
301                                                    newAlpha[lastState] = nodeVal;
302                                            }
303                                            lastState = potential;
304                                            nodeVal = 0.0;
305                                            lengthNode = modelInfo.maxStateLengths[potential] > 1;
306                                    } else {
307                                            if (!lengthNode) {
308                                                    int trans = potential - modelInfo.nStates;
309                                                    double transVal = mi[trans];
310                                                    if(!Double.isInfinite(transVal)) {
311                                                            int from = modelInfo.transitionFrom[trans];
312                                                            if(alphaWriter != null)
313                                                                    FileUtil.safeWrite(alphaWriter, String.format("alpha[%d][%d] = %s = %s + alpha[%d][%d] %s * %s exp(%f)\n", pos, lastState, printNorm(nodeVal + lastAlpha[from] * exp(mi[trans]), alphaNorm), printNorm(nodeVal, alphaNorm), pos-1, from, printNorm(lastAlpha[from], alphaNorm), printNorm(exp(mi[trans]), 0), mi[trans]));
314                                                            nodeVal += lastAlpha[from] * exp(transVal);
315                                                    }
316                                            }
317                                    }
318                            }
319                            newAlpha[lastState] = nodeVal;
320                    }
321    
322                    /** Updates an alpha entry with a weighted sum of features values for a given potential */
323                    private final void lengthAlpha(final int seqNum, final int pos) {
324                            cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);
325                            /*
326                             * Updates an existing alpha by adding in: potentialValue - The value of any length-dependent features for
327                             * this node f(y, i, d) and edge f(y', y, i, d) stableValue - The value of the non-length dependent node
328                             * features summed across the length of this segment mis - The value of the non-length dependent transition
329                             * from the previous node to this one f(y', y, i-d)
330                             */
331                            int nSemiMarkovStates = modelInfo.statesWithLookback.length;
332                            for(int i=0; i<nSemiMarkovStates; ++i) {
333                                    LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
334                                    CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
335                                    byte toNode = statePotentials.state;
336                                    
337                                    int lbIndex=0;
338                                    LengthFeatureEvaluation lengthEval = lookbacksForState[lbIndex];
339                                    int lookback = lengthEval.lookback;
340                                    while(lookback != -1) {
341                                            //log.info("Pos: "+pos+"\t State: "+modelInfo.statesWithLookback[i].state+"\t Lookback: "+lookback);
342                                            int prevPos = pos - lookback - 1;
343                                            // For speed I hand inline RecyclingBuffer.get
344                                            LookbackBuffer buffer = lookbackBuffer.array[(lookbackBuffer.currentStart+lookback)%lookbackBuffer.length];
345    
346                                            // Handle evaluation of the node potentials
347                                            FeatureEvaluation nodeEvals = lengthEval.nodeEval;
348                                            short[] indices = nodeEvals.index;
349                                            float[] vals = nodeEvals.value;
350                                            int ix = 0;
351                                            short index = indices[ix];
352                                            double stableValue = stableState[toNode] - buffer.stableState[toNode];
353                                            double nodePotential = stableValue;
354                                            while(index >= 0) {
355                                                    nodePotential += vals[ix] * lambda[index];
356                                                    index = indices[++ix];
357                                            }
358                                            if(debug)
359                                                    Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid");
360    
361                                            if(prevPos < 0) {
362                                                    double nodeVal = nodePotential + starterAlpha[toNode];
363                                                    // If this is the first segment, then we don't worry about edges and handle the node directly.
364                                                    int norm = ((int) nodeVal) / NORM_FACTOR;
365                                                    nodeVal -= norm * NORM_FACTOR;
366                            
367                                                    if (norm > alphaNorm) {
368                                                            renormalize(alpha, alphaNorm, norm);
369                                                            //log.info("Renormalized alpha: "+ColtUtil.format(alpha));
370                                                            alphaNorm = norm;
371                                                    } else if (norm < alphaNorm) {
372                                                            nodeVal += NORM_FACTOR * (norm - alphaNorm);
373                                                    }
374                                                    if(alphaLengthWriter != null) {
375                                                            FileUtil.safeWrite(alphaLengthWriter, String.format("seq: %d alpha[%d][%d] = %s = %s + %s (Pot: %f Starter: %f)\n", seqNum, pos, toNode, printNorm(alpha[toNode] + exp(nodeVal), alphaNorm), printNorm(alpha[toNode], alphaNorm), printNorm(exp(nodeVal), alphaNorm), nodePotential, starterAlpha[toNode])); 
376                                                    }
377                                                            /*if((pos == 499 && toNode == 0))
378                                                            log.info(String.format("alpha[%d][%d] = %s = %s + %s (Pot: %f Starter: %f)", pos, toNode, printNorm(alpha[toNode] + exp(nodeVal), alphaNorm), printNorm(alpha[toNode], alphaNorm), printNorm(exp(nodeVal), alphaNorm), nodePotential, starterAlpha[toNode]));
379                                                            */
380                                                    alpha[toNode] += exp(nodeVal);
381                                            }
382                                            else {
383                                                    // If this is not the first segment, we need to deal with edges coming into this segment
384                                                    FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
385                                                    int nEdges = statePotentials.potentials.length;
386                                                    for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
387                                                            int potential = statePotentials.potentials[edgeIx];
388                                                            int trans = potential - modelInfo.nStates;
389                                                            int fromNode = modelInfo.transitionFrom[trans];
390                                                            // Skip semi-Markov self transitions
391                                                            if(fromNode == toNode)
392                                                                    continue;
393    
394                                                            double edgeVal = 0.0;
395    
396                                                            if(edgeEvals == null) {
397                                                                    // If the cache processor does not have edge evaluations
398                                                                    // Just check if this transition is legal based on the invalid transitions matrix
399                                                                    int invalidIndex = (seqOffset+prevPos+1)*modelInfo.nPotentials;
400                                                                    if(invalidTransitions[invalidIndex + potential]) {
401                                                                            //log.info("Illegal transition: "+fromNode+"-"+toNode+" at pos: "+prevPos);
402                                                                            continue;
403                                                                    }
404                                                            }
405                                                            else {
406                                                                    // If the cache processor does have edge evaluations, then ignore the illegal transitions matrix
407                                                                    // and update the expval using the edge evaluations
408                                                                    FeatureEvaluation potEvals = edgeEvals[edgeIx];
409                                                                    indices = potEvals.index;
410                                                                    vals = potEvals.value;
411                                                                    ix = 0;
412                                                                    index = indices[i];
413                                                                    while(index >= 0) {
414                                                                            edgeVal += vals[ix] * lambda[index];
415                                                                            index = indices[++ix];
416                                                                    }
417                                                                    if(index == Short.MIN_VALUE) {
418                                                                            continue;
419                                                                    }
420                                                            }
421                                                            
422                                                            // Renormalize and update the exp value.
423                                                            double expVal = edgeVal + buffer.mi[trans] + nodePotential;
424                                                            int expNorm = ((int) expVal)/NORM_FACTOR;
425                                                            expVal -= expNorm*NORM_FACTOR;
426                                                            
427                                                            int prevNorm = alphaNorms[prevPos];
428                                                            int updateNorm = expNorm + prevNorm;
429                                                            if(updateNorm > alphaNorm) {
430                                                                    // Our updated value is larger than the existing alpha value, renormalize that alpha vector.
431                                                                    renormalize(alpha, alphaNorm, updateNorm);
432                                                                    //log.info("Renormalized alpha from "+alphaNorm+" to "+updateNorm +" : "+ColtUtil.format(alpha));
433                                                                    alphaNorm = updateNorm;
434                                                            }
435                                                            else if(alphaNorm > updateNorm) {
436                                                                    // Renormalize the expVal 
437                                                                    int expShift = alphaNorm - updateNorm;
438                                                                    //log.info(String.format("Renormalize feature by %d from %d to %d",expShift, expNorm, expNorm+expShift));
439                                                                    expNorm += expShift;
440                                                                    expVal -= expShift*NORM_FACTOR;
441                                                            }
442                                                            
443                                                            double prevAlpha = alphas[prevPos][fromNode];
444                                                            double update = exp(expVal) * prevAlpha;
445                                                            if(alphaLengthWriter != null) {
446                                                                    FileUtil.safeWrite(alphaLengthWriter, String.format("seq: %d alpha[%d][%d] = %s = %s + %s (alpha[%d][%d]) * %s exp(EdgeLength: %f NodeLength: %f Edge: %f Node: %f )\n", 
447                                                                                    seqNum, pos, toNode, printNorm(alpha[toNode] + update, alphaNorm), printNorm(alpha[toNode], alphaNorm), printNorm(prevAlpha, alphaNorms[prevPos]), 
448                                                                                    prevPos, modelInfo.transitionFrom[trans], printNorm(exp(expVal), expNorm), edgeVal, nodePotential - stableValue, buffer.mi[trans], stableValue));
449                                                            }
450                                                                                    
451                                                            alpha[toNode] += update;
452                                                            // Expensive assertion that can catch some normalization problems
453                                                            if(debug)
454                                                                    Assert.a(expNorm + prevNorm == alphaNorm, "Norm problem.  Exp: ", expNorm, " Prev alpha: ", prevNorm, " Alpha: ", alphaNorm);
455                                                    }
456                                            }
457                                            
458                                            ++lbIndex;
459                                            lengthEval = lookbacksForState[lbIndex];
460                                            lookback = lengthEval.lookback;
461                                    }
462                            }
463                    }
464    
465                    /**
466                     * A specialized version of calcMi for the first position in a sequence. Has the special property that constant
467                     * edge features are not included.
468                     */
469                    void calcStartAlpha(double[] currentAlpha, int seq) {
470                            cacheProcessor.evaluatePosition(seq, 0);
471                            int invalidIndex = seqOffset*modelInfo.nPotentials;
472                            for(short potential : modelInfo.orderedPotentials) {
473                                    if(potential < modelInfo.nStates) {
474                                            boolean invalid = invalidTransitions[invalidIndex + potential];
475                                            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
476    
477                                            // Add up all features for this potential.
478                                            FeatureEvaluation potEvals = evals[potential];
479                                            short[] indices = potEvals.index;
480                                            float[] vals = potEvals.value;
481                                            int i = 0;
482                                            short index = indices[i];
483                                            while(index != -1) {
484                                                    if(index == Short.MIN_VALUE) {
485                                                            features = Double.NEGATIVE_INFINITY;
486                                                            break;
487                                                    }
488                                                    features += vals[i]*lambda[index];
489                                                    index = indices[++i];
490                                            }
491                                            if(modelInfo.maxStateLengths[potential]> 1) {
492                                                    starterAlpha[potential] = features;
493                                            }
494                                            else {
495                                                    currentAlpha[potential] = exp(features);
496                                            }
497                                    }
498                            }               
499                    }
500            }
501    
502            /** Updates a beta entry with a potential from a length dependent feature */
503            class BetaLengthFeatureProcessor {
504                    int seqOffset;
505                    int lengthPos;
506                    double[] lengthStable;
507                    int miPos;
508                    double[] stableState;
509    
510                    double[] beta;
511                    int betaNorm;
512    
513                    double prob;
514                    
515                    LookbackBuffer posLookback; 
516                    LookbackBuffer prevLookback; 
517                    
518                    // This vector tracks the marginal probability of each node, accounting for length dependence
519                    double[] nodeProb = new double[modelInfo.nStates];
520                    double[] newNodeProb = new double[modelInfo.nStates];
521    
522                    // This vector tracks the marginal probability of each edge, accounting for length dependence
523                    double[] edgeProb = new double[modelInfo.nTransitions];
524                    
525                    /**
526                     * Computing the betas is a bit tricky. The problem is that our cache associates each explicit
527                     * length feature with the position at which it ends. For the betas, we are going backwards, and so we need to
528                     * look at the positions where the length based feature starts.
529                     * 
530                     * The way we handle this is by evaluating the length based features as they arise and incrementally adding
531                     * their potential values into the beta positions as we go along. Therefore we aren't always filling one beta
532                     * matrix at a time, we fill any beta entry that starts a state which ends the current position as we move
533                     * backwards through the sequence.
534                     * 
535                     * In order to do this we need to keep mi matrices and arrays of the potentials of staying in a stable state for
536                     * the entire lookback period. These matrices let us compute betas anywhere we want.
537                     * 
538                     * Unfortunately this process needs more exponentiation than the existing process.
539                     * 
540                     * Normalization works as follows: First, we need to find the right normalization factor for the value we are
541                     * adding. 1) We add all of the exponents we will need. 2) We integer divide by NORM_FACTOR to get the
542                     * normalization constant for our beta update 3) The constant*NORM_FACTOR is subtracted from the total exponent
543                     * to get the unnormalized exponent 4) The unnormalized exponent is exponentiated and multiplied by the
544                     * following beta value 5) We compare the norm constant to the current constant for the beta we are about to
545                     * update 5.1) If they are equal we are fine, the values can be added 5.2) If Our new norm constant has larger
546                     * abs(), we renormalize the beta vector by dividing all elements by e^(NORM_FACTOR*diff) 5.3) If our new norm
		 * constant has smaller abs(), we renormalize it by dividing by e^(NORM_FACTOR*diff) 6) When we finally update the
548                     * beta vector at the end we redo the normalization anyway.
549                     * 
550                     * To optimize this for fastest execution it ends up being one big function. Putting function calls in the inner
551                     * loops really slowed things down. I think the optimizer can't do as good a job if there are function calls
552                     * there.
553                     */
		void computeBetasAndExpectations(int seqNum, int len) {
			// Start by initializing an array of unexponentiated mi and stable state matrices for all of the possible
			// lookback positions from the last base
			// Also initialize the array of previous beta vectors
			seqOffset = modelInfo.seqOffsets[seqNum];

			// lastInitPos holds the position of the leftmost position which gets evaluated at the start of the run.
			int lastInitPos = len - 2 - maxLookback;

			// posLookback holds the lookback buffer for the position currently specified by pos
			posLookback = nextBuffer;
			// prevLookback holds the lookback buffer for the position pos+1
			prevLookback = null;

			Arrays.fill(nodeProb, 0.0);
			Arrays.fill(newNodeProb, 0.0);
			
			// Now work through the sequence backwards
			miPos = len-1;
			nextBuffer.clear();
			for (int pos = len - 1; pos >= 0; --pos) {
				// First, update the lookback, which caches mi and stable values if necessary.
				// This pre-fills the ring of LookbackBuffers leftward to lastInitPos so that every
				// position a semi-Markov segment can start from has its mi matrix and stable-state
				// sums already computed.
				while(miPos >= 0 && miPos >= lastInitPos) {
					if(miPos == len-1) {
						// Initialize stable states
						// Need to prime with a stable state of 0 since we don't have the init that we do for alphas.
						Arrays.fill(nextBuffer.stableState, 0);
					}
					else {
						// Update stable states given the previous transition
						nextBuffer.clear();
						cacheMi(seqNum, nextBuffer.mi, stableState, nextBuffer.stableState, miPos+1);
					}
					nextBuffer.pos = miPos;
					stableState = nextBuffer.stableState;
					// Push onto the front of the recycling ring; addFirst hands back the buffer evicted
					// from the far end, which we reuse as the next scratch buffer (see RecyclingBuffer).
					nextBuffer = lookbackBuffer.addFirst(nextBuffer);
					nextBuffer.clear();
					--miPos;
				}

				// At this point, miPos contains the leftmost position for which markov features have been evaluated.
				// The lookback buffer at this position has a stableState vector for each length dependent state and
				// an mi matrix computed for that state.
				
				// Retrieve the lookback information for the current position.  The lookback buffer runs left to right, so the 
				// current position is not at the beginning.
				posLookback = lookbackBuffer.get(pos - miPos -1);

				// Check that posLookback is correct
				if(debug)
					Assert.a(posLookback.pos == pos, "Wrong lookback buffer: was ", posLookback.pos, " should be ", pos);

				// Update betas (if necessary)
				if(prevLookback == null) {
					// For the last position in the sequence (first iteration in the for loop)
					if(debug)
						Assert.a(pos == len-1);

					// Betas for the last position are all 1.  exp(0)
					Arrays.fill(posLookback.beta, 1);
					posLookback.betaNorm = 0;

					/* Initialize node marginals.  Since all segments end at the last position, we can compute
					this by multiplying the alpha and beta vectors and dividing by Z.  The beta vector is all 1's though
					so we just divide alpha by z */ 
					double nodeNorm = exp((alphaNorms[pos] - zNorm) * NORM_FACTOR) * zInv;
					for(int i=0; i<modelInfo.nStates; ++i) {
						nodeProb[i] = nodeNorm * alphas[pos][i];
					}
					log.debug("Node marginals at seq "+seqNum+" last position ("+pos+"): "+ColtUtil.format(nodeProb));
				}
				else {
					// For all but the last position, we update beta first for the markov states and then the semi-markov states.
					if(debug)
						Assert.a(prevLookback.pos == pos+1);
					posLookback.betaNorm = regularBetaUpdate(pos+1, posLookback.beta, posLookback.betaNorm, prevLookback.beta, prevLookback.betaNorm, prevLookback.transitionProb, posLookback.mi);
					//log.info(String.format("Pos: %d, Mipos: %d LbPos: %d", pos, miPos, posLookback.pos));

					// At this point, the nodeProb and edgeProb vectors contain the marginal probabilities of markov nodes and edges into them.
					// The beta is fully updated for the markov states.
					// Semi markov has not been done.
					
					// Now we need to loop through for the length dependent cache
					beta = prevLookback.beta;
					/*for(int ix=0; ix<beta.length; ++ix) {
						Assert.a(!Double.isNaN(beta[ix]), prevLookback.pos+ " "+ ix);
					}*/
					betaNorm = prevLookback.betaNorm;
					lengthStable = prevLookback.stableState;
					lengthPos = pos + 1;
					lengthBeta(seqNum, lengthPos);

					System.arraycopy(nodeProb, 0, newNodeProb, 0, modelInfo.nStates);
					
					// Now calculate edge marginals for staying in the same state
					// Take the node probabilities, and subtract off each transition
					for(CacheProcessor.StatePotentials lb: statesWithLookback) {
						int state = lb.state;
						int index = modelInfo.selfTransitions[state];
						double transProb = nodeProb[state];
						for(byte pot : lb.potentials) {
							double lbTrans = posLookback.transitionProb[pot - modelInfo.nStates];
							transProb -= lbTrans;
							newNodeProb[state] -= lbTrans;
						}
						Assert.a(posLookback.transitionProb[index] == 0.0);
						//log.debug("Self-trans marginal for ("+state+"): "+transProb );
						posLookback.transitionProb[index] = transProb;
					}
				}
				
				// As a check, we verify that the node marginals sum to one for each position.
				double sum =0.0;
				for(double x : nodeProb) {
					if(x > 1.0+ASSERTION_TOLERANCE || x < -ASSERTION_TOLERANCE)
						Assert.a(false, "Iter ",iter," Pos: "+pos+" Node marginals not valid "+x);
					sum += x;
				}
				if(Math.abs(1.0-sum) > ASSERTION_TOLERANCE) {
					Assert.a(false, "Iter ",iter," Pos: "+pos+" Node marginals sum to "+sum+" not 1: ", ColtUtil.format(nodeProb), " at ", seqNum," ",pos);
				}

				// Verify that edge marginals sum to 1.
				if(prevLookback != null) {
					if(debug) {
						sum =0.0;
						for(double x : edgeProb) {
							if(x > 1+ASSERTION_TOLERANCE || x < -ASSERTION_TOLERANCE)
								Assert.a(false, "Iter ",iter," Pos: "+pos+" Edge marginal not valid "+x);
							sum += x;
						}
						for(double x : posLookback.transitionProb) {
							if(x > 1+ASSERTION_TOLERANCE || x < -ASSERTION_TOLERANCE)
								Assert.a(false, "Iter ",iter," Pos: "+pos+" Self-trans marginal not valid "+x);
							sum += x;
						}
						/*
						for(int i=0; i<modelInfo.nStates; ++i) {
							log.debug("Nod "+i+" "+nodeProb[i]);
						}*/
						if(Math.abs(1.0-sum) > 0.001) {
							/*for(int i = 0; i < posLookback.transitionProb.length; ++i) {
								if(posLookback.transitionProb[i] != 0)
									log.debug("Seg "+modelInfo.transitionFrom[i]+"-"+modelInfo.transitionTo[i]+" "+posLookback.transitionProb[i]);
							}
							for(int i = 0; i < edgeProb.length; ++i) {
								if(edgeProb[i] != 0)
									log.debug("Edg "+modelInfo.transitionFrom[i]+"-"+modelInfo.transitionTo[i]+" "+edgeProb[i]);
							}*/
							Assert.a(false, "Edge marginals don't sum to 1.  Sum to: ", sum, " - ", ColtUtil.format(edgeProb), ColtUtil.format(posLookback.transitionProb));
						}
					}
					
					// At this point the previous beta values are all updated.
					// Update expectations for the transitions we just calculated.
					updateExpectations(seqNum, pos+1, posLookback.transitionProb);
				}

				// Update the node probabilities to remove the length transitions.
				// Swap rather than copy: newNodeProb becomes the working vector for the next (leftward) position.
				double[] temp = nodeProb;
				nodeProb = newNodeProb;
				newNodeProb = temp;
				
				/*if (debug) {
					if ((seqOffset == 0) && (pos < 2 || pos >= len - 2)) {
						log.debug(String.format("Pos: %d expects: %s alphas: %s (norm %d) betas: %s (norm %d) MiPos: %d", pos, ColtUtil.format(expects), ColtUtil
								.format(alphas[pos]), alphaNorms[pos], ColtUtil.format(posLookback.beta), posLookback.betaNorm, miPos+1));
					}
				}*/

				prevLookback = posLookback;
				--lastInitPos;
			}
			// Now update for the first position.  There is no beta vector to the left of position 0, so
			// newBeta and mi are passed as null; this call only computes the node/edge marginals at 0
			// (the pos==0 and newBeta==null branches inside regularBetaUpdate skip the beta writes).
			//log.debug(String.format(ColtUtil.format(posLookback.beta)));
			posLookback.betaNorm = regularBetaUpdate(0, null, posLookback.betaNorm, posLookback.beta, posLookback.betaNorm, null, null);
			beta = posLookback.beta;
			betaNorm = posLookback.betaNorm;
			lengthStable = posLookback.stableState;
			lengthPos = 0;
			lengthBeta(seqNum, lengthPos);
			//betaProcessor.lengthCache(seqStart);
			updateExpectations(seqNum, 0, posLookback.transitionProb);
		}
738    
739                    /**
740                     * Does the update of the beta values for all of the regular, non-length dependent states. Also calculates all
741                     * of the node and edge probabilities for the non-length nodes and the transitions into them.
742                     */
743                    private int regularBetaUpdate(int pos, double[] newBeta, int newNorm, double[] oldBeta, int oldNorm, double[] transitionProb, double[] mi) {
744                            // Need to deal with possibly different normalization constants.
745                            int norm = newNorm;
746                            double normAdjust = 0.0;
747                            if (oldNorm > newNorm) {
748                                    // We never make the constant smaller, so set the constant for the new factor
749                                    //log.info(String.format("Renormalizing beta[%d] from %d to %d", pos, newNorm, oldNorm));
750                                    renormalize(newBeta, newNorm, oldNorm);
751                                    norm = oldNorm;
752                                    newNorm = oldNorm;
753                            } else {
754                                    // The case where beta(pos-1) already has a normalization constant larger than beta(pos)
755                                    // We need to adjust all of the updated beta values.
756                                    normAdjust = (oldNorm - newNorm) * NORM_FACTOR;
757                            }
758    
759                            //log.info(String.format("Node norm pos: %d e^(alpha: %d + beta: %d - z: %d) * zinv: %f ", pos, alphaNorms[pos], norm, zNorm, zInv));
760                            double[] nodeAlpha = alphas[pos];
761                            double nodeNorm = exp((alphaNorms[pos] + oldNorm - zNorm) * NORM_FACTOR) * zInv;
762                            double[] edgeAlpha = null;
763                            double edgeNorm = Double.NaN;
764                            if(pos > 0) {
765                                    edgeAlpha = alphas[pos-1];
766                                    // We add newNorm here because it ends up cancelling with normAdjust when we calc the edge prob.
767                                    edgeNorm = exp((alphaNorms[pos-1] + newNorm - zNorm) * NORM_FACTOR) * zInv;
768                            }
769    
770                            for(CacheProcessor.StatePotentials potentials : statesWithoutLookback) {
771                                    byte node = potentials.state;
772                                    double nodePotential = 0.0;
773                                    double betaVal = oldBeta[node];
774                                    if(nodeMarginalWriter != null) {
775                                            FileUtil.safeWrite(nodeMarginalWriter, String.format("NodeMarg[%d][%d] = %f = %f * %f * %f (aN: %d bN: %d zN: %d 1/z: %f)\n", pos, node, nodeAlpha[node] * betaVal * nodeNorm, nodeAlpha[node], betaVal, nodeNorm, alphaNorms[pos], oldNorm, zNorm, zInv));
776                                    }
777                                    nodeProb[node] = nodeAlpha[node] * betaVal * nodeNorm;
778                                    
779                                    // For regular states, we sum edge probabilities to get node probabilities.
780                                    if(pos > 0) {
781                                            for (short potential : potentials.potentials) {
782                                                    int trans = potential - modelInfo.nStates;
783                                                    double transVal = mi[trans];
784                                                    if(!Double.isInfinite(transVal)) {
785                                                            // Mi for beta is not exponentiated, so we do it here.
786                                                            double potentialValue = exp(mi[trans] + normAdjust);
787                                                            nodePotential += potentialValue;
788    
789                                                            int from = modelInfo.transitionFrom[trans];
790                                                            /*if(pos > 2400) {
791                                                                    log.info(String.format("Beta[%d][%d] = %f = %f + Beta[%d][%d] %f * %f (Mi: %f)", pos, from, newBeta[from] + potentialValue * betaVal, newBeta[from], pos+1, node, betaVal, potentialValue, mi[trans]));
792                                                            }*/
793                                                            newBeta[from] += potentialValue * betaVal;
794                                                            edgeProb[trans] = edgeAlpha[from] * potentialValue * betaVal * edgeNorm;
795                                                    }
796                                                    else {
797                                                            edgeProb[trans] = 0.0;
798                                                    }
799                                            }
800                                    }
801                            }
802    
803                            // Now check to see if this new vector needs normalization again.
804                            //log.debug("pos "+pos);
805                            int ret = norm;
806                            if(newBeta != null) {
807                                    try {
808                                            ret += normalize(newBeta);
809                                    }
810                                    catch(RuntimeException ex) {
811                                            log.warn("Normalization problem at " + pos + " "+ColtUtil.format(newBeta));
812                                            throw ex;
813                                    }
814                            }
815                            return ret;
816                    }
817    
818                    /**
819                     * Updates an existing beta by adding in: potentialValue - The value of any length-dependent features for this
820                     * node f(y, i, d) and edge f(y', y, i, d) stableValue - The value of the non-length dependent node features
821                     * summed across the length of this segment mis - The value of the non-length dependent transition from the
822                     * previous node to this one f(y', y, i-d)
823                     * 
824                     * In addition to computing the betas for length dependent features, it updates the probabilities vectors needed
825                     * for the feature expectations. These are the marginal at each position or each edge.
826                     */
827                    private void lengthBeta(int seqNum, int pos) {
828                            cacheProcessor.evaluateSegmentsEndingAt(seqNum, pos);
829    
830                            int nSemiMarkovStates = modelInfo.statesWithLookback.length;
831                            for(int i=0; i<nSemiMarkovStates; ++i) {
832                                    LengthFeatureEvaluation[] lookbacksForState = lengthEvals[i];
833                                    CacheProcessor.StatePotentials statePotentials = modelInfo.statesWithLookback[i];
834                                    byte toNode = statePotentials.state;
835                                    
836                                    int lbArrayIndex=0;
837                                    LengthFeatureEvaluation lengthEval = lookbacksForState[lbArrayIndex];
838                                    int lookback = lengthEval.lookback;
839                                    while(lookback != -1) {
840                                            int prevPos = lengthPos - lookback - 1;
841                                            int lbIndex = prevPos - miPos - 1;
842    
843                                            //log.info("Pos: "+pos+"\t State: "+modelInfo.statesWithLookback[i].state+"\t Lookback: "+lookback);
844                                            LookbackBuffer segBegin = null;
845                                            if(prevPos >= 0) {
846                                                    // For speed I hand inline RecyclingBuffer.get
847                                                    segBegin = lookbackBuffer.array[(lookbackBuffer.currentStart+lbIndex)%lookbackBuffer.length];
848                                            }
849                                            LookbackBuffer stableBuffer = lookbackBuffer.array[(lookbackBuffer.currentStart+lbIndex+1)%lookbackBuffer.length];
850    
851                                            // Handle evaluation of the node potentials
852                                            double stableValue = stableBuffer.stableState[toNode] - lengthStable[toNode];
853                                            double nodePotential = stableValue;
854                                            
855                                            FeatureEvaluation nodeEvals = lengthEval.nodeEval;
856                                            short[] indices = nodeEvals.index;
857                                            float[] vals = nodeEvals.value;
858                                            int ix = 0;
859                                            short index = indices[ix];
860                                            while(index >= 0) {
861                                                    nodePotential += vals[ix] * lambda[index];
862                                                    index = indices[++ix];
863                                            }
864                                            if(debug)
865                                                    Assert.a(index != Short.MIN_VALUE, "Node lengths should only be returned in the cache if they are valid.  They can be invalid because a node is invalid or a self-transition edge is invalid.");
866    
867                                            if(prevPos < 0) {
868                                                    // If this is the first segment, then we don't worry about edges and handle the node directly.
869                                                    double expVal = nodePotential + starterAlpha[toNode];
870                                                    lengthBetaHandling(seqNum, prevPos, pos, expVal, -1, toNode, 1.0, 0, nodeEvals);
871                                            }
872                                            else {
873                                                    // If this is not the first segment, we need to deal with edges
874                                                    FeatureEvaluation[] edgeEvals = lengthEval.edgeEvals;
875                                                    int nEdges = statePotentials.potentials.length;
876                                                    for(int edgeIx=0; edgeIx < nEdges; ++edgeIx) {
877                                                            int potential = statePotentials.potentials[edgeIx];
878                                                            int trans = potential - modelInfo.nStates;
879                                                            int fromNode = modelInfo.transitionFrom[trans];
880                                                            // Skip semi-Markov self transitions
881                                                            if(fromNode == toNode)
882                                                                    continue;
883    
884                                                            double edgeVal = 0.0;
885    
886                                                            if(edgeEvals == null) {
887                                                                    // If the cache processor does not have edge evaluations
888                                                                    // Just check if this transition is legal based on the invalid transitions matrix
889                                                                    int invalidIndex = (seqOffset + prevPos+1)*modelInfo.nPotentials;
890                                                                    if(invalidTransitions[invalidIndex + potential])
891                                                                            continue;
892                                                            }
893                                                            else {
894                                                                    // If the cache processor does have edge evaluations, then ignore the illegal transitions matrix
895                                                                    // and update the expval using the edge evaluations
896                                                                    FeatureEvaluation potEvals = edgeEvals[edgeIx];
897                                                                    indices = potEvals.index;
898                                                                    vals = potEvals.value;
899                                                                    ix = 0;
900                                                                    index = indices[i];
901                                                                    while(index >= 0) {
902                                                                            edgeVal += vals[ix] * lambda[index];
903                                                                            index = indices[++ix];
904                                                                    }
905                                                                    if(index == Short.MIN_VALUE) {
906                                                                            continue;
907                                                                    }
908                                                            }
909    
910                                                            //log.debug("Stab: "+ stableBuffer.stableState[toNode]+" - " + lengthStable[toNode]);
911                                                            if(debug) {
912                                                                    Assert.a(prevLookback.pos == lengthPos, "Expected ",lengthPos, " was ",prevLookback.pos);
913                                                                    Assert.a(segBegin.pos == (lengthPos-lookback-1), "Expected ",(lengthPos-lookback-1), " was ",segBegin.pos);
914                                                            }
915                                                            double expVal = edgeVal + segBegin.mi[trans] + nodePotential;
916                                                            double prevAlpha = alphas[prevPos][fromNode];
917                                                            int prevAlphaNorm = alphaNorms[prevPos];
918                                                            //log.debug("mi: "+buffer.mi[trans]+" s: "+nodePotential+" Base: "+(expVal - buffer.mi[trans] - nodePotential));
919                                                            int expNorm = lengthBetaHandling(seqNum, prevPos, pos, expVal, fromNode, toNode, prevAlpha, prevAlphaNorm, nodeEvals);
920                                                            expVal -= expNorm * NORM_FACTOR;
921    
922                                                            // To determine the normalization value, we compare the existing beta value to the value we are about to add
923                                                            // Whichever is a larger number will dominate, and so we use that normalization value as the new value and ignore the
924                                                            // previous one.
925                                                            
926                                                            // Update the beta values
927                                                            int updateNorm = expNorm + betaNorm;
928                                                            if(updateNorm > segBegin.betaNorm) {
929                                                                    //log.info(String.format("Renormalize beta from %d to %d", segBegin.betaNorm, updateNorm));
930                                                                    renormalize(segBegin.beta, segBegin.betaNorm, updateNorm);
931                                                                    segBegin.betaNorm = updateNorm;
932                                                            }
933                                                            else if(segBegin.betaNorm > updateNorm) {
934                                                                    int expShift = updateNorm - segBegin.betaNorm;
935                                                                    //log.info(String.format("Renormalize feature from %d to %d",expNorm, expNorm+expShift));
936                                                                    expNorm += expShift;
937                                                                    expVal += expShift*NORM_FACTOR;
938                                                            }
939                                                            double transPotential = exp(expVal);
940                                                            double update = transPotential * beta[toNode];
941                                                            segBegin.beta[fromNode] += update;
942    
943                                                            // Update expectations
944                                                            if(edgeEvals != null) {
945                                                                    FeatureEvaluation potEvals = edgeEvals[edgeIx];
946                                                                    indices = potEvals.index;
947                                                                    vals = potEvals.value;
948                                                                    ix = 0;
949                                                                    index = indices[i];
950                                                                    while(index != -1) {
951                                                                            if(expectLengthWriter != null)
952                                                                                    FileUtil.safeWrite(expectLengthWriter, String.format("Seq %d Pos %d-%d Expect #%d: %e = %e + Prob: %e * EdgeVal: %e\n", seqNum, prevPos, pos, index, expects[index]+prob*vals[i], expects[index], prob, vals[i]));
953                                                                            expects[index] += prob * vals[ix];
954                                                                            index = indices[++ix];
955                                                                    }
956                                                            }
957                                                            
958                                                            // Updates the transition probabilities
959                                                            //log.info(String.format("EdgeMarg[%d][%d] = %f = %f + %f", prevPos, toNode, buffer.transitionProb[trans] + prob, buffer.transitionProb[trans], prob));
960                                                            segBegin.transitionProb[trans] += prob;
961    
962                                                            if(betaLengthWriter != null) {
963                                                                    FileUtil.safeWrite(betaLengthWriter, String.format(String.format("Beta[%d][%d] = %s = %s + %s beta[%d][%d] * %s exp(Edge: %f Node: %f Stable: %f Trans: %f)\n", 
964                                                                                    prevPos, fromNode, printNorm(segBegin.beta[fromNode], segBegin.betaNorm), printNorm(segBegin.beta[fromNode]-update, segBegin.betaNorm), printNorm(beta[toNode], betaNorm), lengthPos, toNode, printNorm(transPotential, expNorm), 
965                                                                                    edgeVal, nodePotential - stableValue, stableValue, segBegin.mi[trans]))); 
966                                                            }
967                                                    }
968                                            }
969                                            
970                                            ++lbArrayIndex;
971                                            lengthEval = lookbacksForState[lbArrayIndex];
972                                            lookback = lengthEval.lookback;
973                                    }
974                            }
975                    }
976    
		/** Computes the posterior probability of one semi-Markov segment (entered from fromNode at prevPos,
		 * occupying state toNode through pos) and folds it into the node-feature expectations and node marginals.
		 * <p>
		 * Side effect: sets the instance field {@code prob} to the segment probability; callers read it after
		 * this returns.  Also accumulates into {@code expects[]} and {@code nodeProb[]}.
		 *
		 * @param expVal unnormalized log-space potential for the segment (edge + node + stable contributions)
		 * @param prevAlpha alpha value at the segment start; 1.0 with norm 0 for a sequence-initial segment
		 * @param nodeEvals sparse node-feature evaluations; index array is terminated by a negative sentinel
		 * @return the integer normalization factor extracted from expVal, in units of NORM_FACTOR
		 */
		int lengthBetaHandling(int seqNum, int prevPos, int pos, double expVal, int fromNode, int toNode, double prevAlpha, int prevAlphaNorm, FeatureEvaluation nodeEvals) {
			// Pull the whole multiples of NORM_FACTOR out of expVal so the later exp() stays in range.
			int norm = ((int) expVal) / NORM_FACTOR;
			expVal -= norm * NORM_FACTOR;
			
			// In addition to updating the beta array, we need to calculate a probability for this segment so we can
			// correctly calculate feature expectations.  All the norm counters (alpha, segment, beta, Z) are
			// recombined here so the probability below is on the true (unnormalized) scale.
			//log.info(String.format("PrevAlpha %f Beta: %f Exp: %f zInv: %f", prevAlpha, beta[toNode], expVal, zInv));
			double afterExp = exp(expVal + NORM_FACTOR * (prevAlphaNorm + norm + betaNorm - zNorm));
			// This if statement is here in case alpha or beta is 0, but the normalization is large.  This could cause the exp value to go to infinity and result in a NaN probability instead of 0.
			if(prevAlpha == 0 || beta[toNode] == 0)
				prob = 0.0;
			else
				prob = prevAlpha * beta[toNode] * afterExp * zInv;  // p = alpha * beta * exp(potential) / Z
			if(Double.isNaN(prob)) {
				log.info(String.format("NaN = Alpha: %s * Beta: %s * Seg: %s / Z: %s",printNorm(prevAlpha, prevAlphaNorm), printNorm(beta[toNode], betaNorm), printNorm(exp(expVal), norm), printNorm(1/zInv, zNorm)));
				Assert.a(false, String.format("Seq: %d Pos: %d-%d: Bad prob (NaN) = Alpha: %e * Beta[%d] %e * %e exp(%f Norm  a:%d n:%d b:%d z:%d) * %e",
						seqNum, prevPos, pos, prevAlpha, toNode, beta[toNode], afterExp, expVal, prevAlphaNorm, norm, betaNorm, zNorm, zInv));
			}
			
			// Now update expectations for all node features for this edge.  The sparse arrays are walked until
			// the negative sentinel index.
			short[] indices = nodeEvals.index;
			float[] vals = nodeEvals.value;
			int i = 0;
			short index = indices[i];
			while(index != -1) {
				if(expectLengthWriter != null)
					FileUtil.safeWrite(expectLengthWriter, String.format("Seq %d Pos %d-%d Expect #%d: %e = %e + Prob: %e * NodeVal: %e\n", seqNum, prevPos, pos, index, expects[index]+prob*vals[i], expects[index], prob, vals[i]));
				if(prob != 0.0)
					expects[index] += prob * vals[i];
				index = indices[++i];
			}
			
			if(nodeMarginalWriter != null) {
				FileUtil.safeWrite(nodeMarginalWriter, String.format(
						"NodeMarg[%d][%d] = %f = %f + Alpha[%d][%d]: %s * Beta[%d][%d]: %s * seg: %s / Z: %s\n", 
						lengthPos, toNode, nodeProb[toNode]+prob, nodeProb[toNode], prevPos, fromNode, printNorm(prevAlpha, prevAlphaNorm), 
						pos, toNode, printNorm(beta[toNode], betaNorm), 
						printNorm(expVal, norm), printNorm(1/zInv, zNorm)));
			}
			// Accumulate the node marginal for this state.
			nodeProb[toNode] += prob;
			
			return norm;
		}
1020                    
		/** Accumulates feature expectations for every valid potential (node and edge) at one position.
		 * For each potential the accumulated marginal probability is multiplied by the cached feature values
		 * and added into {@code expects[]}.
		 *
		 * @param seqNum index of the training sequence
		 * @param pos position within the sequence
		 * @param transitionProb per-transition marginals to use for edges into length-dependent (semi-Markov) states;
		 *        edges into Markov states use the {@code edgeProb} field instead
		 */
		void updateExpectations(int seqNum, int pos, double[] transitionProb) {
			cacheProcessor.evaluatePosition(seqNum, pos);

			// First compute the expectations for the length-dependent states.
			int invalidIndex = (seqOffset+pos)* modelInfo.nPotentials;

			// Tracks whether the most recently seen node potential is semi-Markov; edge potentials reuse it.
			// NOTE(review): this relies on orderedPotentials listing each state's incoming transitions
			// immediately after the state itself — presumably guaranteed by the model setup; confirm.
			boolean lengthNode = false;
			for (short potential : modelInfo.orderedPotentials) {
				boolean invalid = invalidTransitions[invalidIndex + potential];
				
				double prob = Double.NaN;
				if(potential < modelInfo.nStates) {
					// Node potential: use the node marginal accumulated earlier.
					lengthNode = modelInfo.maxStateLengths[potential] > 1;
					prob = nodeProb[potential];
				}
				else {
					// Edge potential: there are no incoming edges at the first position.
					if(pos == 0)
						continue;
					prob = (lengthNode ? transitionProb : edgeProb)[potential-modelInfo.nStates]; 
					//prob = transitionProb[potential-nStates]; 
				}
				
				if (!invalid) {
					// Iterate through features for this potential.  Sparse arrays end at index -1.
					FeatureEvaluation potEvals = evals[potential];
					short[] indices = potEvals.index;
					float[] vals = potEvals.value;
					int i = 0;
					short index = indices[i];
					while(index != -1) {
						if(expectWriter != null)
							FileUtil.safeWrite(expectWriter, String.format("Seq %d Pos %d Expect #%d: %e = %e + Prob: %e * Val: %e\n", seqNum, pos, index, expects[index]+prob*vals[i], expects[index], prob, vals[i]));
						expects[index] += prob*vals[i];
						index = indices[++i];
					}
				}
			}
		}
1059            }
1060            
1061            void logBuf() {
1062                    int l = lookbackBuffer.length;
1063                    String s = "";
1064                    for(int i=0; i < l; ++i) {
1065                            s += lookbackBuffer.get(i).pos + " ";
1066                    }
1067                    log.info(s);
1068            }
1069            
1070            void logBufBeta() {
1071                    int l = lookbackBuffer.length;
1072                    String s = "";
1073                    for(int i=0; i < l; ++i) {
1074                            s += ColtUtil.format(lookbackBuffer.get(i).beta) + " ";
1075                    }
1076                    log.info(s);
1077            }
1078            
1079            /**
1080             * Computes an unexponentiated mi matrix and updates stable states. Used to create caches for lookback searches.
1081             * 
1082             */
1083            void cacheMi(int seqNum, double[] mi, double[] prevStable, double [] newStable, int miPos) {
1084                    if (miPos < 0)
1085                            return;
1086                    //calcMi(mi, overallPosition, cacheStart, cacheStop, false);
1087                    calcMi(mi, seqNum, miPos, false);
1088                    // Go through the mi matrix and for all states with length dependence compute stable values for self transitions
1089                    for (int i = 0; i < modelInfo.nStates; ++i) {
1090                            if (modelInfo.maxStateLengths[i] > 1) {
1091                                    // These are all log values so we add them
1092                                    newStable[i] = prevStable[i];
1093                                    double trans = mi[modelInfo.selfTransitions[i]];
1094                                    if(!Double.isInfinite(trans)) {
1095                                            //log.debug("Pos: "+miPos+" State: "+i+" Trans: "+trans+" Total: "+(newStable[i]+trans));
1096                                            newStable[i] += trans;
1097                                    }
1098                            }
1099                    }
1100            }
1101            
1102            /**
1103             * This is one of the most time critical parts of the entire solver. The goal is to update the transition matrix.
1104             * This function makes a lot of assumptions in order to maximize performance.
1105             * 
1106             * To maximize performance, we want to make one pass through the Mi matrix, setting each entry to its correct value.
1107             * The value for each entry is the exponent of the sum of the weighted feature values of the edge for that entry and
1108             * its corresponding node. The entry s0,s1 consists of the s0,s1 edge and the s1 node.
1109             */
1110            void calcMi(double[] mi, int seq, int pos, boolean doExp) {
1111                    cacheProcessor.evaluatePosition(seq, pos);
1112                    double nodeVal = Double.NaN;
1113                    int overallPosition = modelInfo.seqOffsets[seq]+pos;
1114                    int invalidIndex = overallPosition*modelInfo.nPotentials;
1115                    for(short potential : modelInfo.orderedPotentials) {
1116                            boolean invalid = invalidTransitions[invalidIndex + potential];
1117                            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
1118    
1119                            // Add up all features for this potential.
1120                            FeatureEvaluation potEvals = evals[potential];
1121                            short[] indices = potEvals.index;
1122                            float[] vals = potEvals.value;
1123                            int i = 0;
1124                            short index = indices[i];
1125                            while(index >= 0) {
1126                                    // An invalid potential is indicated by a feature value of Short.MAX_VALUE
1127                                    features += vals[i]*lambda[index]; 
1128                                    index = indices[++i];
1129                            }
1130                            if(index == Short.MIN_VALUE) {
1131                                    features = Double.NEGATIVE_INFINITY; 
1132                            }
1133                                    
1134                            if(potential < modelInfo.nStates) {
1135                                    nodeVal = features;
1136                            }
1137                            else {
1138                                    //log.debug(String.format("Mi[%d, %d] = %f, adding in %f to get %f", feat.yprev(), feat.y(), val, feat.value(), val*exp(feat.value()*param[feat.index()])));
1139                                    int transition = potential - modelInfo.nStates;
1140                                    double val = features + nodeVal;
1141                                    if(doExp)
1142                                            val = exp(val);
1143                                    mi[transition] = val;
1144                            }
1145                    }               
1146            }
1147    
1148            /**
1149             * Given a vector with an existing normalization factor, convert it to a new normalization factor by scaling the
1150             * entries.
1151             */
1152            private static final void renormalize(final double[] vec, final int currentNorm, final int newNorm) {
1153                    // Instead of dividing by the different (new-current), we reverse the subtraction to negate the exponent and
1154                    // then multiply.
1155                    double factor = exp(NORM_FACTOR * (currentNorm - newNorm));
1156                    //log.info(factor);
1157                    //log.info(ColtUtil.format(vec));
1158                    int len = vec.length;
1159                    for (int i = 0; i < len; ++i) {
1160                            if(vec[i] != 0.0)
1161                                    vec[i] *= factor;
1162                    }
1163                    //log.info(ColtUtil.format(vec));
1164            }
1165    
1166            /** Given a vector, computes a normalization factor for the entries and scales them according to that factor. */
1167            private static final int normalize(final double[] vec) {
1168                    double sum = 0.0;
1169                    for(double val : vec) {
1170                            sum += val;
1171                    }
1172                    if(sum == 0.0 || (sum > NORM_MIN && sum < NORM_MAX)) {
1173                            // No normalization required, our vector is in range.
1174                            return 0;
1175                    }
1176                    if(debug)
1177                            Assert.a(!Double.isNaN(sum));
1178    
1179                    //log.info("performing normalization");
1180                    double val = log(sum);
1181                    int norm = (int) val / NORM_FACTOR;
1182                    val = exp(NORM_FACTOR * norm);
1183                    int len = vec.length;
1184                    for (int i = 0; i < len; ++i) {
1185                            vec[i] /= val;
1186                    }
1187                    return norm;
1188            }
1189    
	/** Exponential used throughout the solver; thin wrapper over {@link Math#exp} giving a single
	 * substitution point (e.g. for a faster approximation). */
	private static final double exp(final double val) {
		return Math.exp(val);
	}
1193    
	/** Natural logarithm used throughout the solver; thin wrapper over {@link Math#log} giving a single
	 * substitution point. */
	private static final double log(final double val) {
		return Math.log(val);
	}
1197    
1198            /**
1199             * This object holds information about previous positions during the computation of betas and expectations. This
1200             * allows us to quickly access data about previous positions. These objects are kept in a recycling buffer that
1201             * keeps one buffer for each possible lookback.
1202             * 
1203             * One tricky aspect of this is that the details change slightly between the forward and backwards pass.  On the forward
1204             * pass, the lookback contains the information in the normal way.  In the backwards pass, stable states and transitions are 
1205             * shifted back one base compared to the betas.
1206             */
1207            final class LookbackBuffer {
1208                    int pos;
1209                    
1210                    // The mi matrix for transitioning from pos-lookback-1 to pos-lookback
1211                    double[] mi = new double[modelInfo.nPotentials];
1212    
1213                    // The weighted sum of feature values for staying in this position from the beginning to pos-lookback
1214                    double[] stableState = new double[modelInfo.nStates];
1215    
1216                    // Initial values of the beta vector for somelength dependent states.
1217                    double[] beta = new double[modelInfo.nStates];
1218                    
1219                    // Norm of prevBeta.
1220                    int betaNorm;
1221                    
1222                    // Stores the probability of all segments begining at this position using this transition.
1223                    double[] transitionProb = new double[modelInfo.nTransitions];
1224    
1225                    /** mi and stableStates are cleared as new values are entered. This fixes the others */
1226                    void clear()
1227                    {
1228                            pos = -1;
1229                            Arrays.fill(beta, 0.0);
1230                            betaNorm = Integer.MIN_VALUE;
1231                            Arrays.fill(transitionProb, 0.0);
1232                    }
1233            }
1234    
1235            public static final String printNorm(final double value, final int norm) {
1236                    if( value == 0.0)
1237                            return "0 ("+norm+")";
1238                    if( Double.isNaN(value))
1239                            return "NaN ("+norm+")";
1240                    int exponent = (int) log(value);
1241    
1242                    double eValue = value/exp(exponent);
1243                    if(Double.isNaN(eValue)) {
1244                            return String.format("NaN(%e n:%d)", value, norm);
1245                    }
1246                    //return String.format("%e(%d) %fe%d", value, norm, eValue, exponent+norm*NORM_FACTOR);
1247                    return String.format("%fe%d", eValue, exponent+norm*NORM_FACTOR);
1248            }
1249    
	/** gets the cache processor used to access feature evaluations
	 * @return the configured cache processor
	 */
	public CacheProcessor getCacheProcessor() {
		return cacheProcessor;
	}
1256    
1257            /** sets the cache processor used to access feature evaluations
1258             * @param cacheProcessor the cache processor to use
1259             */
1260    
	/** sets the cache processor used to access feature evaluations
	 * @param cacheProcessor the cache processor to use
	 */
	public void setCacheProcessor(CacheProcessor cacheProcessor) {
		this.cacheProcessor = cacheProcessor;
	}
1264    
	/** gets the trace file for semi-Markov alpha calculations, or null if disabled
	 * @return the configured file name
	 */
	public String getAlphaLengthFile() {
		return alphaLengthFile;
	}
1268    
	/** sets the trace file for semi-Markov alpha calculations; null disables the trace
	 * @param alphaLengthFile file name to write, overwritten on each call to apply
	 */
	public void setAlphaLengthFile(String alphaLengthFile) {
		this.alphaLengthFile = alphaLengthFile;
	}
1272    
	/** gets the trace file for Markov alpha calculations, or null if disabled
	 * @return the configured file name
	 */
	public String getAlphaFile() {
		return alphaFile;
	}
1276    
	/** sets the trace file for Markov alpha calculations; null disables the trace
	 * @param alphaFile file name to write, overwritten on each call to apply
	 */
	public void setAlphaFile(String alphaFile) {
		this.alphaFile = alphaFile;
	}
1280    
	/** gets the trace file for Markov feature expectation calculations, or null if disabled
	 * @return the configured file name
	 */
	public String getExpectFile() {
		return expectFile;
	}
1284    
	/** sets the trace file for Markov feature expectation calculations; null disables the trace
	 * @param expectFile file name to write, overwritten on each call to apply
	 */
	public void setExpectFile(String expectFile) {
		this.expectFile = expectFile;
	}
1288    
	/** gets the trace file for semi-Markov (length-dependent) feature expectation calculations, or null if disabled
	 * @return the configured file name
	 */
	public String getExpectLengthFile() {
		return expectLengthFile;
	}
1292    
	/** sets the trace file for semi-Markov (length-dependent) feature expectation calculations; null disables the trace
	 * @param expectLengthFile file name to write, overwritten on each call to apply
	 */
	public void setExpectLengthFile(String expectLengthFile) {
		this.expectLengthFile = expectLengthFile;
	}
1296    
	/** gets the trace file for node marginal probability calculations, or null if disabled
	 * @return the configured file name
	 */
	public String getNodeMarginalFile() {
		return nodeMarginalFile;
	}
1300    
	/** sets the trace file for node marginal probability calculations; null disables the trace
	 * @param nodeMarginalFile file name to write, overwritten on each call to apply
	 */
	public void setNodeMarginalFile(String nodeMarginalFile) {
		this.nodeMarginalFile = nodeMarginalFile;
	}
1304    
	/** gets the trace file for semi-Markov beta calculations, or null if disabled
	 * @return the configured file name
	 */
	public String getBetaLengthFile() {
		return betaLengthFile;
	}
1308    
	/** sets the trace file for semi-Markov beta calculations; null disables the trace
	 * @param betaLengthFile file name to write, overwritten on each call to apply
	 */
	public void setBetaLengthFile(String betaLengthFile) {
		this.betaLengthFile = betaLengthFile;
	}
1312    
	/** gets the sums of feature values along the true paths
	 * @return a defensive copy of the feature sums array; mutating it does not affect the solver
	 */
	public double[] getFeatureSums() {
		return this.featureSums.clone();
	}
1316    }