001    /**
002     * 
003     */
004    package calhoun.analysis.crf.solver;
005    
006    import java.util.Arrays;
007    import java.util.List;
008    
009    import org.apache.commons.logging.Log;
010    import org.apache.commons.logging.LogFactory;
011    
012    import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
013    import calhoun.analysis.crf.ModelManager;
014    import calhoun.analysis.crf.io.TrainingSequence;
015    import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
016    import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
017    import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
018    import calhoun.util.Assert;
019    import calhoun.util.ColtUtil;
020    
021    /** computes the likelihood of the true path for a Markov CRF.  The likelihood is normalized to a per label likelihood so
022     * that likelihood of different length paths can be meaningfully compared and a single set of optimization tolerances can be used.
023     * Must be configured with a {@link CacheProcessor}. */
024    public class MaximumLikelihoodGradient implements CRFObjectiveFunctionGradient {
025            private static final Log log = LogFactory.getLog(MaximumLikelihoodGradient.class);
026            boolean debug = log.isDebugEnabled();
027    
028            CacheProcessor cacheProcessor = new AllSparseLengthCacheProcessor();
029            SolverSetup modelInfo;
030            FeatureEvaluation[] evals;
031            boolean[] invalidTransitions;
032            
033            /// Cached value of the Mi matrix for all of the features present at every position
034            /// Mi is stored as a sparse matrix
035            int miLength;
036            double[] mi;
037    
038            int iter = 0;
039    
040            double[] prevAlpha;
041            double[] alpha;
042            double[][] betas;
043            double[] betaNorms;
044            double[] expects;
045    
046            // We publish feature sums 
047            private double[] featureSums;
048            
049            /** gets the cache processor used to access feature evaluations
050             * @return the configured cache processor
051             */
052            public CacheProcessor getCacheProcessor() {
053                    return cacheProcessor;
054            }
055    
056            /** sets the cache processor used to access feature evaluations
057             * @param cacheProcessor the cache processor to use
058             */
059            public void setCacheProcessor(CacheProcessor cacheProcessor) {
060                    this.cacheProcessor = cacheProcessor;
061            }
062    
063            public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
064                    cacheProcessor.setTrainingData(fm, data);
065                    modelInfo = cacheProcessor.getSolverSetup();
066                    evals = cacheProcessor.getFeatureEvaluations();
067                    invalidTransitions = cacheProcessor.getInvalidTransitions();
068                    
069                    miLength = modelInfo.nTransitions;
070                    mi = new double[miLength];
071    
072                    expects = new double[modelInfo.nFeatures];
073                    prevAlpha = new double[modelInfo.nStates];
074                    alpha = new double[modelInfo.nStates];
075    
076                    // Initialize betas (for use later, in the gradient computation)
077                    betas = new double[modelInfo.longestSeq][modelInfo.nStates];
078                    betaNorms = new double[modelInfo.longestSeq];
079            }
080            
081            public double apply(double[] param, double[] grad) {
082                    // Initialize values
083                    Arrays.fill(grad, 0);
084                    double result = 0.0;
085                    Arrays.fill(expects, 0);
086    
087                    // Iterate through sequences
088                    int seqStart = 0;
089                    for(int i = 0; i<modelInfo.nSeqs; ++i) {
090                            int len = modelInfo.seqOffsets[i+1] - modelInfo.seqOffsets[i]; 
091                            
092                            // Work backwards, computing betas.
093                            Arrays.fill(betas[len-1], 1.0);
094                            betaNorms[len-1] = 0;
095                            for (int pos = len - 1; pos > 0; --pos) {
096                                    calcMi(i, pos, param);
097                                    quickBetaUpdate(betas[pos], betas[pos-1]);
098    
099                                    double n = normalizePotential(betas[pos-1]);
100                                    betaNorms[pos-1] = betaNorms[pos] + log(n);
101                            }
102                            
103                            // Now work forwards
104                            double logZ = Double.NEGATIVE_INFINITY; // This should always get initialized.  Blow up if not.
105                            double alphaNorm = 0;
106                            double prevAlphaNorm = 0;
107                            for(int pos=0; pos<len; ++pos) {
108                                    double[] beta = betas[pos];
109                                    double betaNorm = betaNorms[pos];
110                                    if(pos == 0) {
111                                            calcStartAlpha(i, param);
112                                            alphaNorm = log(normalizePotential(alpha));
113                                            
114                                            // We now have everything needed to compute Z.
115                                            logZ = log(ColtUtil.dotProduct(alpha, beta)) + betaNorm + alphaNorm;
116                                            //log.info("Z = "+logZ);
117                                    }
118                                    else {
119                                            calcMi(i, pos, param);
120                                            //r * M (or M'*r)
121                                            //mi.zMult(prevAlpha, alpha, 1, 0, true);
122                                            quickAlphaUpdate(prevAlpha, alpha);
123                                            alphaNorm = prevAlphaNorm + log(normalizePotential(alpha));
124                                            
125                                            // Verify our calculations by checking the Z.
126                                            // Expensive, so only enable as needed during debugging
127                                            double newZ = log(ColtUtil.dotProduct(alpha, beta)) + betaNorm + alphaNorm;
128                                            Assert.a(Math.abs(newZ-logZ) < 0.0000001*Math.abs(logZ), "New Z:",newZ," Old was: ", logZ);
129                                    }
130                                    
131                                    // Iterate through the potentials to update feature expectations
132                                    double nodeNorm = exp(alphaNorm + betaNorm - logZ);
133                                    double edgeNorm = exp(prevAlphaNorm + betaNorm - logZ);
134                                    updateExpectations(i, pos, nodeNorm, edgeNorm, beta);
135    
136                                    if(debug) {
137                                            if((i < 2 || i == modelInfo.nSeqs-1) && (pos < 2 || pos >= len-2)) {
138                                                    log.debug(String.format("Pos: %d expects: %s alphas: %s (norm %f) betas: %s (norm %f)", pos, ColtUtil.format(expects), ColtUtil.format(alpha), alphaNorm, ColtUtil.format(beta), betaNorm));
139                                            }
140                                    }
141    
142                                    // Recycle the arrays
143                                    double[] swap = prevAlpha;
144                                    prevAlpha = alpha;
145                                    alpha = swap;
146                                    prevAlphaNorm = alphaNorm;
147                            }
148    
149                            result -=  logZ;
150                            seqStart += len;
151                    }
152                    
153                    // sum_j lambda_j F_j(xk, yk)
154                    double[] featureSums = cacheProcessor.getFeatureSums();
155                    for (int j = 0; j < modelInfo.nFeatures; ++j) {
156                            result += featureSums[j] * param[j];
157                            grad[j] = featureSums[j] - expects[j];
158                    }
159    
160                    if(log.isInfoEnabled()) {
161                            // Report average per-label numbers.
162                            log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad (unnorm): %s", iter, exp(result/modelInfo.totalPositions), result/modelInfo.totalPositions, ColtUtil.norm(grad)/modelInfo.totalPositions, ColtUtil.format(featureSums), ColtUtil.format(expects), ColtUtil.format(param), ColtUtil.format(grad)));
163                    }
164                    iter += 1;
165                    result = result/modelInfo.totalPositions; 
166                    for(int i=0; i<grad.length; ++i) {
167                            grad[i] = grad[i]/modelInfo.totalPositions;
168                    }
169            
170                    this.featureSums = featureSums;
171    
172                    return result;
173            }
174    
175            public void clean() {
176            }
177            
178            /** This is one of the most time critical parts of the entire solver.  The goal is to update the transition matrix.
179             * This function makes a lot of assumptions in order to maximize performance.
180             * 
181             * To maximize performance, we want to make one pass through the Mi matrix, setting each entry to its correct value.
182             * The value for each entry is the exponent of the sum of the weighted feature values of the edge for that entry and 
183             * its corresponding node.  The entry s0,s1 consists of the s0,s1 edge and the s1 node. 
184             * 
185             * Because node features are applied to more than 1 entry in the matrix, we use a sorting of all of the features where
186             * each node preceeds all its corresponding edges.  This allows us to keep track of only 1 node value at a time and easily
187             * apply it to all its edge features.
188             * 
189             * As we evaluate each potential, we check the cache to see if it is valid at this position and to get any features values.
190             * Note that this function very much depends on the fact that the entries in the cache will be in the correct order.
191             * 
192             * The other wrinkle is that for features that always occur (constant features), we pull them from the constant mi array, not
193             * from the cache.
194             * 
195             *  This function is also used to calculate the Mi Matrix for the constant features.
196             *  */
197            void calcMi(int seq, int pos, double[] lambda) {
198                    cacheProcessor.evaluatePosition(seq, pos);
199                    double nodeVal = Double.NaN;
200                    int overallPosition = modelInfo.seqOffsets[seq]+pos;
201                    int invalidIndex = overallPosition*modelInfo.nPotentials;
202                    for(short potential : modelInfo.orderedPotentials) {
203                            boolean invalid = invalidTransitions[invalidIndex + potential];
204                            double features = 0.0;
205                            
206                            if(invalid) {
207                                    features = Double.NEGATIVE_INFINITY;
208                            }
209                            else {
210                                    // Add up all features for this potential.
211                                    FeatureEvaluation potEvals = evals[potential];
212                                    short[] indices = potEvals.index;
213                                    float[] vals = potEvals.value;
214                                    int i = 0;
215                                    short index = indices[i];
216                                    while(index >= 0) {
217                                            features += vals[i]*lambda[index]; 
218                                            index = indices[++i];
219                                    }
220                                    if(index == Short.MIN_VALUE) {
221                                            // An invalid potential is indicated by a feature value of Short.MIN_VALUE
222                                            features = Double.NEGATIVE_INFINITY; 
223                                    }
224                            }
225                            
226                            if(potential < modelInfo.nStates) {
227                                    nodeVal = features;
228                            }
229                            else {
230                                    int transition = potential - modelInfo.nStates;
231                                    mi[transition] = exp(features + nodeVal);
232                            }
233                    }               
234            }
235            
236            /** A specialized version of calcMi for the first position in a sequence. Has the special property that 
237             * constant edge features are not included.  Also optionally allows you to set a node value used to initialize the alphas. */
238            void calcStartAlpha(int seq, double[] lambda) {
239                    cacheProcessor.evaluatePosition(seq, 0);
240                    int overallPosition = modelInfo.seqOffsets[seq];
241                    int invalidIndex = overallPosition*modelInfo.nPotentials;
242                    for(short potential : modelInfo.orderedPotentials) {
243                            if(potential < modelInfo.nStates) {
244                                    boolean invalid = invalidTransitions[invalidIndex + potential];
245                                    double features = 0.0;
246    
247                                    if(invalid) {
248                                            features = Double.NEGATIVE_INFINITY;
249                                    }
250                                    else {
251                                            // Add up all features for this potential.
252                                            FeatureEvaluation potEvals = evals[potential];
253                                            short[] indices = potEvals.index;
254                                            float[] vals = potEvals.value;
255                                            int i = 0;
256                                            short index = indices[i];
257                                            while(index >= 0) {
258                                                    features += vals[i]*lambda[index];
259                                                    index = indices[++i];
260                                            }
261                                            if(index == Short.MIN_VALUE) {
262                                                    features = Double.NEGATIVE_INFINITY;
263                                            }
264                                    }
265                                    alpha[potential] = exp(features);
266                            }
267                    }               
268            }
269            
270            private void quickBetaUpdate(double[] lastBeta, double[] newBeta) {
271                    Arrays.fill(newBeta, 0);
272                    double nodeVal = 0.0;
273                    for(short potential : modelInfo.orderedPotentials) {
274                            if(potential < modelInfo.nStates) {
275                                    nodeVal = lastBeta[potential];
276                            }
277                            else {
278                                    int trans = potential - modelInfo.nStates;
279                                    int from = modelInfo.transitionFrom[trans];
280                                    newBeta[from] += mi[trans]*nodeVal;
281                                    //log.debug(String.format("beta[%d] = %f = %f + mi: %f * last: %f",
282                                    //                      from, newBeta[from], newBeta[from]-mi[trans]*nodeVal, mi[trans], nodeVal));
283                            }
284                    }
285            }
286    
287            private void quickAlphaUpdate(double[] lastAlpha, double[] newAlpha) {
288                    double nodeVal = 0.0;
289                    int lastState = -1;
290                    for(short potential : modelInfo.orderedPotentials) {
291                            if(potential < modelInfo.nStates) {
292                                    if(lastState != -1) {
293                                            newAlpha[lastState] = nodeVal;
294                                    }
295                                    lastState = potential;
296                                    nodeVal = 0.0;
297                            }
298                            else {
299                                    int trans = potential - modelInfo.nStates;
300                                    int from = modelInfo.transitionFrom[trans];
301                                    nodeVal += lastAlpha[from]*mi[trans];
302                            }
303                    }
304                    newAlpha[lastState] = nodeVal;
305            }
306    
307            void updateExpectations(int seq, int pos, double nodeNorm, double edgeNorm, double[] beta) {
308                    int currentNode = -1;
309                    double currentBeta = 0.0f;
310                    int overallPos = modelInfo.seqOffsets[seq]+pos;
311                    int invalidIndex = overallPos*modelInfo.nPotentials;
312                    for(short potential : modelInfo.orderedPotentials) {
313                            boolean invalid = invalidTransitions[invalidIndex + potential];
314                            if(invalid)
315                                    continue;
316                            
317                            double prob = 0.0;
318                            if(potential < modelInfo.nStates) {
319                                    currentNode = potential;
320                                    currentBeta = beta[currentNode];
321                                    //log.info(String.format("State prob #%d: %f * %f * %f ", currentNode, alpha[currentNode], currentBeta, nodeNorm));
322                                    prob = alpha[currentNode] * currentBeta * nodeNorm;
323                            }
324                            else {
325                                    if(pos == 0)
326                                            continue;
327                                    int trans = potential - modelInfo.nStates;
328                                    int yprev = modelInfo.transitionFrom[trans];
329                                    prob = prevAlpha[yprev] * mi[trans] * currentBeta * edgeNorm;
330                            }
331    
332                            // Iterate through features for this potential.
333                            FeatureEvaluation potEvals = evals[potential];
334                            short[] indices = potEvals.index;
335                            float[] vals = potEvals.value;
336                            int i = 0;
337                            short index = indices[i];
338                            if(index != Short.MIN_VALUE) {
339                                    while(index != -1) {
340                                            //log.info(String.format("Expect #%d: %f * %f ", index, prob, vals[i]));
341                                            expects[index] += prob*vals[i];
342                                            index = indices[++i];
343                                    }
344                            }
345                    }
346            }
347            
348            private double normalizePotential(double[] vec) {
349                    double norm = 0.0;
350                    int len = vec.length;
351                    for(int i=0; i<len; ++i) {
352                            norm += vec[i];
353                    }
354                    double mult = 1/norm;
355                    for(int i=0; i<len; ++i) {
356                            vec[i] *= mult;
357                    }
358                    return norm;
359            }
360    
361            static final double exp(double val) {
362                    return Math.exp(val);
363            }
364            
365            static final double log(double val) {
366                    return Math.log(val);
367            }
368    
369            public double[] getFeatureSums() {
370                    return this.featureSums.clone();
371            }
372    }