001 /**
002 *
003 */
004 package calhoun.analysis.crf.solver;
005
006 import java.util.Arrays;
007 import java.util.List;
008
009 import org.apache.commons.logging.Log;
010 import org.apache.commons.logging.LogFactory;
011
012 import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
013 import calhoun.analysis.crf.ModelManager;
014 import calhoun.analysis.crf.io.TrainingSequence;
015 import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
016 import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
017 import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
018 import calhoun.util.Assert;
019 import calhoun.util.ColtUtil;
020
021 /** computes the likelihood of the true path for a Markov CRF. The likelihood is normalized to a per label likelihood so
022 * that likelihood of different length paths can be meaningfully compared and a single set of optimization tolerances can be used.
023 * Must be configured with a {@link CacheProcessor}. */
024 public class MaximumLikelihoodGradient implements CRFObjectiveFunctionGradient {
025 private static final Log log = LogFactory.getLog(MaximumLikelihoodGradient.class);
026 boolean debug = log.isDebugEnabled();
027
028 CacheProcessor cacheProcessor = new AllSparseLengthCacheProcessor();
029 SolverSetup modelInfo;
030 FeatureEvaluation[] evals;
031 boolean[] invalidTransitions;
032
033 /// Cached value of the Mi matrix for all of the features present at every position
034 /// Mi is stored as a sparse matrix
035 int miLength;
036 double[] mi;
037
038 int iter = 0;
039
040 double[] prevAlpha;
041 double[] alpha;
042 double[][] betas;
043 double[] betaNorms;
044 double[] expects;
045
046 // We publish feature sums
047 private double[] featureSums;
048
049 /** gets the cache processor used to access feature evaluations
050 * @return the configured cache processor
051 */
052 public CacheProcessor getCacheProcessor() {
053 return cacheProcessor;
054 }
055
056 /** sets the cache processor used to access feature evaluations
057 * @param cacheProcessor the cache processor to use
058 */
059 public void setCacheProcessor(CacheProcessor cacheProcessor) {
060 this.cacheProcessor = cacheProcessor;
061 }
062
063 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
064 cacheProcessor.setTrainingData(fm, data);
065 modelInfo = cacheProcessor.getSolverSetup();
066 evals = cacheProcessor.getFeatureEvaluations();
067 invalidTransitions = cacheProcessor.getInvalidTransitions();
068
069 miLength = modelInfo.nTransitions;
070 mi = new double[miLength];
071
072 expects = new double[modelInfo.nFeatures];
073 prevAlpha = new double[modelInfo.nStates];
074 alpha = new double[modelInfo.nStates];
075
076 // Initialize betas (for use later, in the gradient computation)
077 betas = new double[modelInfo.longestSeq][modelInfo.nStates];
078 betaNorms = new double[modelInfo.longestSeq];
079 }
080
081 public double apply(double[] param, double[] grad) {
082 // Initialize values
083 Arrays.fill(grad, 0);
084 double result = 0.0;
085 Arrays.fill(expects, 0);
086
087 // Iterate through sequences
088 int seqStart = 0;
089 for(int i = 0; i<modelInfo.nSeqs; ++i) {
090 int len = modelInfo.seqOffsets[i+1] - modelInfo.seqOffsets[i];
091
092 // Work backwards, computing betas.
093 Arrays.fill(betas[len-1], 1.0);
094 betaNorms[len-1] = 0;
095 for (int pos = len - 1; pos > 0; --pos) {
096 calcMi(i, pos, param);
097 quickBetaUpdate(betas[pos], betas[pos-1]);
098
099 double n = normalizePotential(betas[pos-1]);
100 betaNorms[pos-1] = betaNorms[pos] + log(n);
101 }
102
103 // Now work forwards
104 double logZ = Double.NEGATIVE_INFINITY; // This should always get initialized. Blow up if not.
105 double alphaNorm = 0;
106 double prevAlphaNorm = 0;
107 for(int pos=0; pos<len; ++pos) {
108 double[] beta = betas[pos];
109 double betaNorm = betaNorms[pos];
110 if(pos == 0) {
111 calcStartAlpha(i, param);
112 alphaNorm = log(normalizePotential(alpha));
113
114 // We now have everything needed to compute Z.
115 logZ = log(ColtUtil.dotProduct(alpha, beta)) + betaNorm + alphaNorm;
116 //log.info("Z = "+logZ);
117 }
118 else {
119 calcMi(i, pos, param);
120 //r * M (or M'*r)
121 //mi.zMult(prevAlpha, alpha, 1, 0, true);
122 quickAlphaUpdate(prevAlpha, alpha);
123 alphaNorm = prevAlphaNorm + log(normalizePotential(alpha));
124
125 // Verify our calculations by checking the Z.
126 // Expensive, so only enable as needed during debugging
127 double newZ = log(ColtUtil.dotProduct(alpha, beta)) + betaNorm + alphaNorm;
128 Assert.a(Math.abs(newZ-logZ) < 0.0000001*Math.abs(logZ), "New Z:",newZ," Old was: ", logZ);
129 }
130
131 // Iterate through the potentials to update feature expectations
132 double nodeNorm = exp(alphaNorm + betaNorm - logZ);
133 double edgeNorm = exp(prevAlphaNorm + betaNorm - logZ);
134 updateExpectations(i, pos, nodeNorm, edgeNorm, beta);
135
136 if(debug) {
137 if((i < 2 || i == modelInfo.nSeqs-1) && (pos < 2 || pos >= len-2)) {
138 log.debug(String.format("Pos: %d expects: %s alphas: %s (norm %f) betas: %s (norm %f)", pos, ColtUtil.format(expects), ColtUtil.format(alpha), alphaNorm, ColtUtil.format(beta), betaNorm));
139 }
140 }
141
142 // Recycle the arrays
143 double[] swap = prevAlpha;
144 prevAlpha = alpha;
145 alpha = swap;
146 prevAlphaNorm = alphaNorm;
147 }
148
149 result -= logZ;
150 seqStart += len;
151 }
152
153 // sum_j lambda_j F_j(xk, yk)
154 double[] featureSums = cacheProcessor.getFeatureSums();
155 for (int j = 0; j < modelInfo.nFeatures; ++j) {
156 result += featureSums[j] * param[j];
157 grad[j] = featureSums[j] - expects[j];
158 }
159
160 if(log.isInfoEnabled()) {
161 // Report average per-label numbers.
162 log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad (unnorm): %s", iter, exp(result/modelInfo.totalPositions), result/modelInfo.totalPositions, ColtUtil.norm(grad)/modelInfo.totalPositions, ColtUtil.format(featureSums), ColtUtil.format(expects), ColtUtil.format(param), ColtUtil.format(grad)));
163 }
164 iter += 1;
165 result = result/modelInfo.totalPositions;
166 for(int i=0; i<grad.length; ++i) {
167 grad[i] = grad[i]/modelInfo.totalPositions;
168 }
169
170 this.featureSums = featureSums;
171
172 return result;
173 }
174
175 public void clean() {
176 }
177
178 /** This is one of the most time critical parts of the entire solver. The goal is to update the transition matrix.
179 * This function makes a lot of assumptions in order to maximize performance.
180 *
181 * To maximize performance, we want to make one pass through the Mi matrix, setting each entry to its correct value.
182 * The value for each entry is the exponent of the sum of the weighted feature values of the edge for that entry and
183 * its corresponding node. The entry s0,s1 consists of the s0,s1 edge and the s1 node.
184 *
185 * Because node features are applied to more than 1 entry in the matrix, we use a sorting of all of the features where
186 * each node preceeds all its corresponding edges. This allows us to keep track of only 1 node value at a time and easily
187 * apply it to all its edge features.
188 *
189 * As we evaluate each potential, we check the cache to see if it is valid at this position and to get any features values.
190 * Note that this function very much depends on the fact that the entries in the cache will be in the correct order.
191 *
192 * The other wrinkle is that for features that always occur (constant features), we pull them from the constant mi array, not
193 * from the cache.
194 *
195 * This function is also used to calculate the Mi Matrix for the constant features.
196 * */
197 void calcMi(int seq, int pos, double[] lambda) {
198 cacheProcessor.evaluatePosition(seq, pos);
199 double nodeVal = Double.NaN;
200 int overallPosition = modelInfo.seqOffsets[seq]+pos;
201 int invalidIndex = overallPosition*modelInfo.nPotentials;
202 for(short potential : modelInfo.orderedPotentials) {
203 boolean invalid = invalidTransitions[invalidIndex + potential];
204 double features = 0.0;
205
206 if(invalid) {
207 features = Double.NEGATIVE_INFINITY;
208 }
209 else {
210 // Add up all features for this potential.
211 FeatureEvaluation potEvals = evals[potential];
212 short[] indices = potEvals.index;
213 float[] vals = potEvals.value;
214 int i = 0;
215 short index = indices[i];
216 while(index >= 0) {
217 features += vals[i]*lambda[index];
218 index = indices[++i];
219 }
220 if(index == Short.MIN_VALUE) {
221 // An invalid potential is indicated by a feature value of Short.MIN_VALUE
222 features = Double.NEGATIVE_INFINITY;
223 }
224 }
225
226 if(potential < modelInfo.nStates) {
227 nodeVal = features;
228 }
229 else {
230 int transition = potential - modelInfo.nStates;
231 mi[transition] = exp(features + nodeVal);
232 }
233 }
234 }
235
236 /** A specialized version of calcMi for the first position in a sequence. Has the special property that
237 * constant edge features are not included. Also optionally allows you to set a node value used to initialize the alphas. */
238 void calcStartAlpha(int seq, double[] lambda) {
239 cacheProcessor.evaluatePosition(seq, 0);
240 int overallPosition = modelInfo.seqOffsets[seq];
241 int invalidIndex = overallPosition*modelInfo.nPotentials;
242 for(short potential : modelInfo.orderedPotentials) {
243 if(potential < modelInfo.nStates) {
244 boolean invalid = invalidTransitions[invalidIndex + potential];
245 double features = 0.0;
246
247 if(invalid) {
248 features = Double.NEGATIVE_INFINITY;
249 }
250 else {
251 // Add up all features for this potential.
252 FeatureEvaluation potEvals = evals[potential];
253 short[] indices = potEvals.index;
254 float[] vals = potEvals.value;
255 int i = 0;
256 short index = indices[i];
257 while(index >= 0) {
258 features += vals[i]*lambda[index];
259 index = indices[++i];
260 }
261 if(index == Short.MIN_VALUE) {
262 features = Double.NEGATIVE_INFINITY;
263 }
264 }
265 alpha[potential] = exp(features);
266 }
267 }
268 }
269
270 private void quickBetaUpdate(double[] lastBeta, double[] newBeta) {
271 Arrays.fill(newBeta, 0);
272 double nodeVal = 0.0;
273 for(short potential : modelInfo.orderedPotentials) {
274 if(potential < modelInfo.nStates) {
275 nodeVal = lastBeta[potential];
276 }
277 else {
278 int trans = potential - modelInfo.nStates;
279 int from = modelInfo.transitionFrom[trans];
280 newBeta[from] += mi[trans]*nodeVal;
281 //log.debug(String.format("beta[%d] = %f = %f + mi: %f * last: %f",
282 // from, newBeta[from], newBeta[from]-mi[trans]*nodeVal, mi[trans], nodeVal));
283 }
284 }
285 }
286
287 private void quickAlphaUpdate(double[] lastAlpha, double[] newAlpha) {
288 double nodeVal = 0.0;
289 int lastState = -1;
290 for(short potential : modelInfo.orderedPotentials) {
291 if(potential < modelInfo.nStates) {
292 if(lastState != -1) {
293 newAlpha[lastState] = nodeVal;
294 }
295 lastState = potential;
296 nodeVal = 0.0;
297 }
298 else {
299 int trans = potential - modelInfo.nStates;
300 int from = modelInfo.transitionFrom[trans];
301 nodeVal += lastAlpha[from]*mi[trans];
302 }
303 }
304 newAlpha[lastState] = nodeVal;
305 }
306
307 void updateExpectations(int seq, int pos, double nodeNorm, double edgeNorm, double[] beta) {
308 int currentNode = -1;
309 double currentBeta = 0.0f;
310 int overallPos = modelInfo.seqOffsets[seq]+pos;
311 int invalidIndex = overallPos*modelInfo.nPotentials;
312 for(short potential : modelInfo.orderedPotentials) {
313 boolean invalid = invalidTransitions[invalidIndex + potential];
314 if(invalid)
315 continue;
316
317 double prob = 0.0;
318 if(potential < modelInfo.nStates) {
319 currentNode = potential;
320 currentBeta = beta[currentNode];
321 //log.info(String.format("State prob #%d: %f * %f * %f ", currentNode, alpha[currentNode], currentBeta, nodeNorm));
322 prob = alpha[currentNode] * currentBeta * nodeNorm;
323 }
324 else {
325 if(pos == 0)
326 continue;
327 int trans = potential - modelInfo.nStates;
328 int yprev = modelInfo.transitionFrom[trans];
329 prob = prevAlpha[yprev] * mi[trans] * currentBeta * edgeNorm;
330 }
331
332 // Iterate through features for this potential.
333 FeatureEvaluation potEvals = evals[potential];
334 short[] indices = potEvals.index;
335 float[] vals = potEvals.value;
336 int i = 0;
337 short index = indices[i];
338 if(index != Short.MIN_VALUE) {
339 while(index != -1) {
340 //log.info(String.format("Expect #%d: %f * %f ", index, prob, vals[i]));
341 expects[index] += prob*vals[i];
342 index = indices[++i];
343 }
344 }
345 }
346 }
347
348 private double normalizePotential(double[] vec) {
349 double norm = 0.0;
350 int len = vec.length;
351 for(int i=0; i<len; ++i) {
352 norm += vec[i];
353 }
354 double mult = 1/norm;
355 for(int i=0; i<len; ++i) {
356 vec[i] *= mult;
357 }
358 return norm;
359 }
360
361 static final double exp(double val) {
362 return Math.exp(val);
363 }
364
365 static final double log(double val) {
366 return Math.log(val);
367 }
368
369 public double[] getFeatureSums() {
370 return this.featureSums.clone();
371 }
372 }