001    package calhoun.analysis.crf.solver.semimarkov;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
010    import calhoun.analysis.crf.ModelManager;
011    import calhoun.analysis.crf.io.TrainingSequence;
012    import calhoun.analysis.crf.solver.CacheProcessor;
013    import calhoun.analysis.crf.solver.LogFiles;
014    import calhoun.analysis.crf.solver.LookbackBuffer;
015    import calhoun.analysis.crf.solver.RecyclingBuffer;
016    import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
017    import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
018    import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
019    import calhoun.analysis.crf.solver.CacheProcessor.StatePotentials;
020    import calhoun.util.Assert;
021    import calhoun.util.ColtUtil;
022    
023    /** computes the likelihood of the true path for a semi-Markov CRF.  The likelihood is normalized to a per label likelihood. 
024     * <h2>Debugging output</h2>
025     * To get a better understanding of what the objective function is doing, several different properties can be set that
026     * cause the objective function to write out trace files showing its calculations during training.  Usually when turning
027     * these options on, you should set <code>maxIters = 1</code> and <code>requireConvergence = false</code> in your optimizer
028     * to do only a single training iteration, possibly setting the starts to some predetermined value.  Each of these
029     * properties can be configured with a filename and each time {@link #apply} is called, the file will be overwritten with 
030     * data from the current call.  The logging options are:
031     * <ul>
032     * <li> <b><code>alphaFile</code></b> - computation of alpha values for Markov states, includes all nodes and edges.
033     * <li> <b><code>alphaLengthFile</code></b> - computation of alpha values for semi-Markov states , includes all segments
034     * <li> <b><code>expectFile</code></b> - computation of expected values for each Markov feature 
035     * <li> <b><code>expectLengthFile</code></b> - computation of expected values for each semi-Markov feature  
036     * <li> <b><code>nodeMarginalFile</code></b> - computation of marginal probability of each state at each position 
037     * </ul>
038     * */
039    public class CleanMaximumLikelihoodSemiMarkovGradient implements CRFObjectiveFunctionGradient {
040            static final Log log = LogFactory.getLog(CleanMaximumLikelihoodSemiMarkovGradient.class);
041            public static final boolean debug = log.isDebugEnabled();
042            public static final double ASSERTION_TOLERANCE = 0.0001;
043            
044            public static final int NORM_FACTOR = 50;
045            public static final double NORM_MIN = Math.exp(-NORM_FACTOR);
046            public static final double NORM_MAX = Math.exp(NORM_FACTOR);
047    
048            final LogFiles logs = new LogFiles();
049            
050            SolverSetup modelInfo;
051            CacheProcessor cacheProcessor;
052            FeatureEvaluation[] evals;
053            LengthFeatureEvaluation[][] lengthEvals;
054            boolean[] invalidTransitions;
055            
056            // / Cache feature information
057            // / Cached value of the Mi matrix for all of the features present at every position
058            // / Mi is stored as a sparse matrix
059            short maxLookback;
060            StatePotentials[] statesWithLookback;
061            StatePotentials[] statesWithoutLookback;
062            int iter = 0;
063            double[][] alphas;
064            int[] alphaNorms;
065            double[] starterAlpha;
066            int nSemiMarkovStates;
067    
068            // At any given point, lookbackBuffer.get(x) returns the information about a lookback of x. Lookbacks start at 0.
069            RecyclingBuffer<LookbackBuffer> lookbackBuffer;
070            LookbackBuffer nextBuffer;
071    
072            double[] lambda;
073            double logZ;
074            int zNorm;
075            double zInv;
076            double[] expects;
077    
078            AlphaLengthFeatureProcessor alphaProcessor; 
079            BetaLengthFeatureProcessor betaProcessor;
080            
081            // We publish feature sums 
082            private double[] featureSums;
083                    
084            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
085                    cacheProcessor.setTrainingData(fm, data);
086    
087                    modelInfo = cacheProcessor.getSolverSetup();
088                    Assert.a(modelInfo.maxStateLengths != null, "Maximum state lengths not set.");
089                    Assert.a(modelInfo.maxStateLengths.length == modelInfo.nStates, "Maximum state lengths array was length ("+modelInfo.maxStateLengths.length+").  Must have one entry for each state "+modelInfo.nStates+")");
090                    evals = cacheProcessor.getFeatureEvaluations();
091                    lengthEvals = cacheProcessor.getLengthFeatureEvaluations();
092                    invalidTransitions = cacheProcessor.getInvalidTransitions();
093    
094                    // Create local references to cache values
095                    maxLookback = modelInfo.maxLookback;
096                    statesWithLookback = modelInfo.statesWithLookback;
097                    statesWithoutLookback = modelInfo.statesWithoutLookback;
098                    nSemiMarkovStates = modelInfo.statesWithLookback.length;
099    
100                    // Initialize betas (for use later, in the gradient computation)
101                    alphas = new double[modelInfo.longestSeq][modelInfo.nStates];
102                    alphaNorms = new int[modelInfo.longestSeq];
103                    expects = new double[modelInfo.nFeatures];
104    
105                    LookbackBuffer[] bufferContents = new LookbackBuffer[maxLookback+3];
106                    for(int i = 0; i<maxLookback+3; ++i) {
107                            bufferContents[i] = new LookbackBuffer(modelInfo.nStates, modelInfo.nTransitions);
108                    }
109                    lookbackBuffer = new RecyclingBuffer<LookbackBuffer>(bufferContents);
110                    nextBuffer = new LookbackBuffer(modelInfo.nStates, modelInfo.nTransitions);
111                    
112                    alphaProcessor = new AlphaLengthFeatureProcessor(this); 
113                    betaProcessor = new BetaLengthFeatureProcessor(this); 
114    
115                    starterAlpha = new double[modelInfo.nStates];
116            }
117    
118            public double apply(double[] param, double[] grad) {
119                    log.debug(String.format("Beginning It: %d Weights: %s", iter, ColtUtil.format(param)));
120                    logs.open();
121                    lambda = param;
122                    Arrays.fill(grad, 0);
123                    double totalZ = 0.0;
124                    double result = 0.0;
125    
126                    try {
127                            // Iterate through sequences
128                            Arrays.fill(expects, 0);
129                            for (int i = 0; i < modelInfo.nSeqs; ++i) {
130                                    int len = modelInfo.seqOffsets[i + 1] - modelInfo.seqOffsets[i];
131    
132                                    alphaAndBetaPass(i, len);
133                                    
134                                    // Update for the next sequence
135                                    totalZ += logZ;
136                            }
137                            
138                            // sum_j lambda_j F_j(xk, yk)
139                            double[] featureSums = cacheProcessor.getFeatureSums();
140                            this.featureSums = featureSums;
141                            for (int j = 0; j < modelInfo.nFeatures; ++j) {
142                                    result += featureSums[j] * param[j];
143                                    grad[j] = featureSums[j] - expects[j];
144                            }
145                            log.debug("Path Value: "+result+" Norm: "+totalZ);
146                            result -= totalZ;
147                            if (log.isInfoEnabled()) {
148                                    log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s", iter, exp(result), result,
149                                                    ColtUtil.norm(grad), ColtUtil.format(featureSums), ColtUtil.format(expects), ColtUtil.format(param), ColtUtil.format(grad)));
150                            }
151                            Assert.a(exp(result) <= 1.0, "Likelihood is greater than 1.");
152    
153                            // Normalize by the length of the sequence
154                            result = result/modelInfo.totalPositions; 
155                            for(int i=0; i<grad.length; ++i) {
156                                    grad[i] = grad[i]/modelInfo.totalPositions;
157                            }
158    
159                            iter += 1;
160                    }
161                    finally {
162                            logs.close();
163                    }
164                    return result;
165            }
166    
167            public void clean() {
168            }
169                    
170            void alphaAndBetaPass(int i, int len) {
171                    // Work forwards, computing alphas
172                    alphaProcessor.computeAlpha(i, len);
173            
174                    // Since the final beta array is all ones, we can sum the alphas to get the Z
175                    double sum = 0.0;
176                    for (double val : alphas[len - 1]) {
177                            sum += val;
178                    }
179            
180                    logZ = log(sum) + NORM_FACTOR * (alphaNorms[len - 1]);
181                    zNorm = ((int) logZ) / NORM_FACTOR;
182                    zInv = exp(zNorm * NORM_FACTOR - logZ);
183                    log.debug("Seq: "+i+" Z: "+printNorm(1/zInv, zNorm));
184            
185                    // Work backwards, computing betas and expectations.
186                    betaProcessor.computeBetasAndExpectations(i, len);
187            
188                    //if(debug) {
189                    //      logFeatureSums(i);
190                    //}
191            }
192    
193            void logBuf() {
194                    int l = lookbackBuffer.length;
195                    String s = "";
196                    for(int i=0; i < l; ++i) {
197                            s += lookbackBuffer.get(i).pos + " ";
198                    }
199                    log.info(s);
200            }
201            
202            void logBufBeta() {
203                    int l = lookbackBuffer.length;
204                    String s = "";
205                    for(int i=0; i < l; ++i) {
206                            s += ColtUtil.format(lookbackBuffer.get(i).beta) + " ";
207                    }
208                    log.info(s);
209            }
210            
211            /**
212             * Computes an unexponentiated mi matrix and updates stable states. Used to create caches for lookback searches.
213             * 
214             */
215            void cacheMi(int seqNum, double[] mi, double[] prevStable, double [] newStable, int miPos) {
216                    if (miPos < 0)
217                            return;
218                    //calcMi(mi, overallPosition, cacheStart, cacheStop, false);
219                    calcMi(mi, seqNum, miPos, false);
220                    // Go through the mi matrix and for all states with length dependence compute stable values for self transitions
221                    for (int i = 0; i < modelInfo.nStates; ++i) {
222                            if (modelInfo.maxStateLengths[i] > 1) {
223                                    // These are all log values so we add them
224                                    newStable[i] = prevStable[i];
225                                    double trans = mi[modelInfo.selfTransitions[i]];
226                                    if(!Double.isInfinite(trans)) {
227                                            //log.debug("Pos: "+miPos+" State: "+i+" Trans: "+trans+" Total: "+(newStable[i]+trans));
228                                            newStable[i] += trans;
229                                    }
230                            }
231                    }
232            }
233            
234            /**
235             * This is one of the most time critical parts of the entire solver. The goal is to update the transition matrix.
236             * This function makes a lot of assumptions in order to maximize performance.
237             * 
238             * To maximize performance, we want to make one pass through the Mi matrix, setting each entry to its correct value.
239             * The value for each entry is the exponent of the sum of the weighted feature values of the edge for that entry and
240             * its corresponding node. The entry s0,s1 consists of the s0,s1 edge and the s1 node.
241             */
242            void calcMi(double[] mi, int seq, int pos, boolean doExp) {
243                    cacheProcessor.evaluatePosition(seq, pos);
244                    double nodeVal = Double.NaN;
245                    int overallPosition = modelInfo.seqOffsets[seq]+pos;
246                    int invalidIndex = overallPosition*modelInfo.nPotentials;
247                    for(short potential : modelInfo.orderedPotentials) {
248                            boolean invalid = invalidTransitions[invalidIndex + potential];
249                            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
250    
251                            // Add up all features for this potential.
252                            FeatureEvaluation potEvals = evals[potential];
253                            short[] indices = potEvals.index;
254                            float[] vals = potEvals.value;
255                            int i = 0;
256                            short index = indices[i];
257                            while(index >= 0) {
258                                    // An invalid potential is indicated by a feature value of Short.MAX_VALUE
259                                    features += vals[i]*lambda[index]; 
260                                    index = indices[++i];
261                            }
262                            if(index == Short.MIN_VALUE) {
263                                    features = Double.NEGATIVE_INFINITY; 
264                            }
265                                    
266                            if(potential < modelInfo.nStates) {
267                                    nodeVal = features;
268                            }
269                            else {
270                                    //log.debug(String.format("Mi[%d, %d] = %f, adding in %f to get %f", feat.yprev(), feat.y(), val, feat.value(), val*exp(feat.value()*param[feat.index()])));
271                                    int transition = potential - modelInfo.nStates;
272                                    double val = features + nodeVal;
273                                    if(doExp)
274                                            val = exp(val);
275                                    mi[transition] = val;
276                            }
277                    }               
278            }
279    
280            /**
281             * Given a vector with an existing normalization factor, convert it to a new normalization factor by scaling the
282             * entries.
283             */
284            static final void renormalize(final double[] vec, final int currentNorm, final int newNorm) {
285                    // Instead of dividing by the different (new-current), we reverse the subtraction to negate the exponent and
286                    // then multiply.
287                    double factor = exp(NORM_FACTOR * (currentNorm - newNorm));
288                    //log.info(factor);
289                    //log.info(ColtUtil.format(vec));
290                    int len = vec.length;
291                    for (int i = 0; i < len; ++i) {
292                            if(vec[i] != 0.0)
293                                    vec[i] *= factor;
294                    }
295                    //log.info(ColtUtil.format(vec));
296            }
297    
298            /** Given a vector, computes a normalization factor for the entries and scales them according to that factor. */
299            static final int normalize(final double[] vec) {
300                    double sum = 0.0;
301                    for(double val : vec) {
302                            sum += val;
303                    }
304                    if(sum == 0.0 || (sum > NORM_MIN && sum < NORM_MAX)) {
305                            // No normalization required, our vector is in range.
306                            return 0;
307                    }
308                    if(debug)
309                            Assert.a(!Double.isNaN(sum));
310    
311                    //log.info("performing normalization");
312                    double val = log(sum);
313                    int norm = (int) val / NORM_FACTOR;
314                    val = exp(NORM_FACTOR * norm);
315                    int len = vec.length;
316                    for (int i = 0; i < len; ++i) {
317                            vec[i] /= val;
318                    }
319                    return norm;
320            }
321    
322            static final double exp(final double val) {
323                    return Math.exp(val);
324            }
325    
326            static final double log(final double val) {
327                    return Math.log(val);
328            }
329    
330            final void logFeatureSums(int seqNum) {
331                    double[][] seqFeatureSums = cacheProcessor.getSequenceFeatureSums();
332                    if(seqFeatureSums != null) {
333                            double seqResult = 0.0;
334                            for (int j = 0; j < modelInfo.nFeatures; ++j) {
335                                    seqResult += seqFeatureSums[seqNum][j] * lambda[j];
336                            }
337                            log.debug(String.format("Seq: %d L: %g LL: %f Training path: %f Z: %f", seqNum, exp(seqResult-logZ), seqResult-logZ, seqResult, logZ));
338                            Assert.a(exp(seqResult-logZ) < 1.0);
339                    }
340            }
341    
342            public static final String printNorm(final double value, final int norm) {
343                    if( value == 0.0)
344                            return "0 ("+norm+")";
345                    if( Double.isNaN(value))
346                            return "NaN ("+norm+")";
347                    int exponent = (int) log(value);
348    
349                    double eValue = value/exp(exponent);
350                    if(Double.isNaN(eValue)) {
351                            return String.format("NaN(%e n:%d)", value, norm);
352                    }
353                    //return String.format("%e(%d) %fe%d", value, norm, eValue, exponent+norm*NORM_FACTOR);
354                    return String.format("%fe%d", eValue, exponent+norm*NORM_FACTOR);
355            }
356    
357            /** gets the cache processor used to access feature evaluations
358             * @return the configured cache processor
359             */
360            public CacheProcessor getCacheProcessor() {
361                    return cacheProcessor;
362            }
363    
364            /** sets the cache processor used to access feature evaluations
365             * @param cacheProcessor the cache processor to use
366             */
367    
368            public void setCacheProcessor(CacheProcessor cacheProcessor) {
369                    this.cacheProcessor = cacheProcessor;
370            }
371    
372            public String getAlphaLengthFile() {
373                    return logs.alphaLengthFile;
374            }
375    
376            public void setAlphaLengthFile(String alphaLengthFile) {
377                    logs.alphaLengthFile = alphaLengthFile;
378            }
379    
380            public String getAlphaFile() {
381                    return logs.alphaFile;
382            }
383    
384            public void setAlphaFile(String alphaFile) {
385                    logs.alphaFile = alphaFile;
386            }
387    
388            public String getExpectFile() {
389                    return logs.expectFile;
390            }
391    
392            public void setExpectFile(String expectFile) {
393                    logs.expectFile = expectFile;
394            }
395    
396            public String getExpectLengthFile() {
397                    return logs.expectLengthFile;
398            }
399    
400            public void setExpectLengthFile(String expectLengthFile) {
401                    logs.expectLengthFile = expectLengthFile;
402            }
403    
404            public String getNodeMarginalFile() {
405                    return logs.nodeMarginalFile;
406            }
407    
408            public void setNodeMarginalFile(String nodeMarginalFile) {
409                    logs.nodeMarginalFile = nodeMarginalFile;
410            }
411    
412            public String getBetaLengthFile() {
413                    return logs.betaLengthFile;
414            }
415    
416            public void setBetaLengthFile(String betaLengthFile) {
417                    logs.betaLengthFile = betaLengthFile;
418            }
419    }