001 package calhoun.analysis.crf.solver.semimarkov;
002
003 import java.util.Arrays;
004 import java.util.List;
005
006 import org.apache.commons.logging.Log;
007 import org.apache.commons.logging.LogFactory;
008
009 import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
010 import calhoun.analysis.crf.ModelManager;
011 import calhoun.analysis.crf.io.TrainingSequence;
012 import calhoun.analysis.crf.solver.CacheProcessor;
013 import calhoun.analysis.crf.solver.LogFiles;
014 import calhoun.analysis.crf.solver.LookbackBuffer;
015 import calhoun.analysis.crf.solver.RecyclingBuffer;
016 import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
017 import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
018 import calhoun.analysis.crf.solver.CacheProcessor.SolverSetup;
019 import calhoun.analysis.crf.solver.CacheProcessor.StatePotentials;
020 import calhoun.util.Assert;
021 import calhoun.util.ColtUtil;
022
023 /** computes the likelihood of the true path for a semi-Markov CRF. The likelihood is normalized to a per label likelihood.
024 * <h2>Debugging output</h2>
025 * To get a better understanding of what the objective function is doing, several different properties can be set that
026 * cause the objective function to write out trace files showing its calculations during training. Usually when turning
027 * these options on, you should set <code>maxIters = 1</code> and <code>requireConvergence = false</code> in your optimizer
028 * to do only a single training iteration, possibly setting the starts to some predetermined value. Each of these
029 * properties can be configured with a filename and each time {@link #apply} is called, the file will be overwritten with
030 * data from the current call. The logging options are:
031 * <ul>
032 * <li> <b><code>alphaFile</code></b> - computation of alpha values for Markov states, includes all nodes and edges.
033 * <li> <b><code>alphaLengthFile</code></b> - computation of alpha values for semi-Markov states, includes all segments
034 * <li> <b><code>expectFile</code></b> - computation of expected values for each Markov feature
035 * <li> <b><code>expectLengthFile</code></b> - computation of expected values for each semi-Markov feature
036 * <li> <b><code>nodeMarginalFile</code></b> - computation of marginal probability of each state at each position
037 * </ul>
038 * */
039 public class CleanMaximumLikelihoodSemiMarkovGradient implements CRFObjectiveFunctionGradient {
040 static final Log log = LogFactory.getLog(CleanMaximumLikelihoodSemiMarkovGradient.class);
041 public static final boolean debug = log.isDebugEnabled();
042 public static final double ASSERTION_TOLERANCE = 0.0001;
043
044 public static final int NORM_FACTOR = 50;
045 public static final double NORM_MIN = Math.exp(-NORM_FACTOR);
046 public static final double NORM_MAX = Math.exp(NORM_FACTOR);
047
048 final LogFiles logs = new LogFiles();
049
050 SolverSetup modelInfo;
051 CacheProcessor cacheProcessor;
052 FeatureEvaluation[] evals;
053 LengthFeatureEvaluation[][] lengthEvals;
054 boolean[] invalidTransitions;
055
056 // / Cache feature information
057 // / Cached value of the Mi matrix for all of the features present at every position
058 // / Mi is stored as a sparse matrix
059 short maxLookback;
060 StatePotentials[] statesWithLookback;
061 StatePotentials[] statesWithoutLookback;
062 int iter = 0;
063 double[][] alphas;
064 int[] alphaNorms;
065 double[] starterAlpha;
066 int nSemiMarkovStates;
067
068 // At any given point, lookbackBuffer.get(x) returns the information about a lookback of x. Lookbacks start at 0.
069 RecyclingBuffer<LookbackBuffer> lookbackBuffer;
070 LookbackBuffer nextBuffer;
071
072 double[] lambda;
073 double logZ;
074 int zNorm;
075 double zInv;
076 double[] expects;
077
078 AlphaLengthFeatureProcessor alphaProcessor;
079 BetaLengthFeatureProcessor betaProcessor;
080
081 // We publish feature sums
082 private double[] featureSums;
083
084 public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
085 cacheProcessor.setTrainingData(fm, data);
086
087 modelInfo = cacheProcessor.getSolverSetup();
088 Assert.a(modelInfo.maxStateLengths != null, "Maximum state lengths not set.");
089 Assert.a(modelInfo.maxStateLengths.length == modelInfo.nStates, "Maximum state lengths array was length ("+modelInfo.maxStateLengths.length+"). Must have one entry for each state "+modelInfo.nStates+")");
090 evals = cacheProcessor.getFeatureEvaluations();
091 lengthEvals = cacheProcessor.getLengthFeatureEvaluations();
092 invalidTransitions = cacheProcessor.getInvalidTransitions();
093
094 // Create local references to cache values
095 maxLookback = modelInfo.maxLookback;
096 statesWithLookback = modelInfo.statesWithLookback;
097 statesWithoutLookback = modelInfo.statesWithoutLookback;
098 nSemiMarkovStates = modelInfo.statesWithLookback.length;
099
100 // Initialize betas (for use later, in the gradient computation)
101 alphas = new double[modelInfo.longestSeq][modelInfo.nStates];
102 alphaNorms = new int[modelInfo.longestSeq];
103 expects = new double[modelInfo.nFeatures];
104
105 LookbackBuffer[] bufferContents = new LookbackBuffer[maxLookback+3];
106 for(int i = 0; i<maxLookback+3; ++i) {
107 bufferContents[i] = new LookbackBuffer(modelInfo.nStates, modelInfo.nTransitions);
108 }
109 lookbackBuffer = new RecyclingBuffer<LookbackBuffer>(bufferContents);
110 nextBuffer = new LookbackBuffer(modelInfo.nStates, modelInfo.nTransitions);
111
112 alphaProcessor = new AlphaLengthFeatureProcessor(this);
113 betaProcessor = new BetaLengthFeatureProcessor(this);
114
115 starterAlpha = new double[modelInfo.nStates];
116 }
117
118 public double apply(double[] param, double[] grad) {
119 log.debug(String.format("Beginning It: %d Weights: %s", iter, ColtUtil.format(param)));
120 logs.open();
121 lambda = param;
122 Arrays.fill(grad, 0);
123 double totalZ = 0.0;
124 double result = 0.0;
125
126 try {
127 // Iterate through sequences
128 Arrays.fill(expects, 0);
129 for (int i = 0; i < modelInfo.nSeqs; ++i) {
130 int len = modelInfo.seqOffsets[i + 1] - modelInfo.seqOffsets[i];
131
132 alphaAndBetaPass(i, len);
133
134 // Update for the next sequence
135 totalZ += logZ;
136 }
137
138 // sum_j lambda_j F_j(xk, yk)
139 double[] featureSums = cacheProcessor.getFeatureSums();
140 this.featureSums = featureSums;
141 for (int j = 0; j < modelInfo.nFeatures; ++j) {
142 result += featureSums[j] * param[j];
143 grad[j] = featureSums[j] - expects[j];
144 }
145 log.debug("Path Value: "+result+" Norm: "+totalZ);
146 result -= totalZ;
147 if (log.isInfoEnabled()) {
148 log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s", iter, exp(result), result,
149 ColtUtil.norm(grad), ColtUtil.format(featureSums), ColtUtil.format(expects), ColtUtil.format(param), ColtUtil.format(grad)));
150 }
151 Assert.a(exp(result) <= 1.0, "Likelihood is greater than 1.");
152
153 // Normalize by the length of the sequence
154 result = result/modelInfo.totalPositions;
155 for(int i=0; i<grad.length; ++i) {
156 grad[i] = grad[i]/modelInfo.totalPositions;
157 }
158
159 iter += 1;
160 }
161 finally {
162 logs.close();
163 }
164 return result;
165 }
166
167 public void clean() {
168 }
169
170 void alphaAndBetaPass(int i, int len) {
171 // Work forwards, computing alphas
172 alphaProcessor.computeAlpha(i, len);
173
174 // Since the final beta array is all ones, we can sum the alphas to get the Z
175 double sum = 0.0;
176 for (double val : alphas[len - 1]) {
177 sum += val;
178 }
179
180 logZ = log(sum) + NORM_FACTOR * (alphaNorms[len - 1]);
181 zNorm = ((int) logZ) / NORM_FACTOR;
182 zInv = exp(zNorm * NORM_FACTOR - logZ);
183 log.debug("Seq: "+i+" Z: "+printNorm(1/zInv, zNorm));
184
185 // Work backwards, computing betas and expectations.
186 betaProcessor.computeBetasAndExpectations(i, len);
187
188 //if(debug) {
189 // logFeatureSums(i);
190 //}
191 }
192
193 void logBuf() {
194 int l = lookbackBuffer.length;
195 String s = "";
196 for(int i=0; i < l; ++i) {
197 s += lookbackBuffer.get(i).pos + " ";
198 }
199 log.info(s);
200 }
201
202 void logBufBeta() {
203 int l = lookbackBuffer.length;
204 String s = "";
205 for(int i=0; i < l; ++i) {
206 s += ColtUtil.format(lookbackBuffer.get(i).beta) + " ";
207 }
208 log.info(s);
209 }
210
211 /**
212 * Computes an unexponentiated mi matrix and updates stable states. Used to create caches for lookback searches.
213 *
214 */
215 void cacheMi(int seqNum, double[] mi, double[] prevStable, double [] newStable, int miPos) {
216 if (miPos < 0)
217 return;
218 //calcMi(mi, overallPosition, cacheStart, cacheStop, false);
219 calcMi(mi, seqNum, miPos, false);
220 // Go through the mi matrix and for all states with length dependence compute stable values for self transitions
221 for (int i = 0; i < modelInfo.nStates; ++i) {
222 if (modelInfo.maxStateLengths[i] > 1) {
223 // These are all log values so we add them
224 newStable[i] = prevStable[i];
225 double trans = mi[modelInfo.selfTransitions[i]];
226 if(!Double.isInfinite(trans)) {
227 //log.debug("Pos: "+miPos+" State: "+i+" Trans: "+trans+" Total: "+(newStable[i]+trans));
228 newStable[i] += trans;
229 }
230 }
231 }
232 }
233
234 /**
235 * This is one of the most time critical parts of the entire solver. The goal is to update the transition matrix.
236 * This function makes a lot of assumptions in order to maximize performance.
237 *
238 * To maximize performance, we want to make one pass through the Mi matrix, setting each entry to its correct value.
239 * The value for each entry is the exponent of the sum of the weighted feature values of the edge for that entry and
240 * its corresponding node. The entry s0,s1 consists of the s0,s1 edge and the s1 node.
241 */
242 void calcMi(double[] mi, int seq, int pos, boolean doExp) {
243 cacheProcessor.evaluatePosition(seq, pos);
244 double nodeVal = Double.NaN;
245 int overallPosition = modelInfo.seqOffsets[seq]+pos;
246 int invalidIndex = overallPosition*modelInfo.nPotentials;
247 for(short potential : modelInfo.orderedPotentials) {
248 boolean invalid = invalidTransitions[invalidIndex + potential];
249 double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
250
251 // Add up all features for this potential.
252 FeatureEvaluation potEvals = evals[potential];
253 short[] indices = potEvals.index;
254 float[] vals = potEvals.value;
255 int i = 0;
256 short index = indices[i];
257 while(index >= 0) {
258 // An invalid potential is indicated by a feature value of Short.MAX_VALUE
259 features += vals[i]*lambda[index];
260 index = indices[++i];
261 }
262 if(index == Short.MIN_VALUE) {
263 features = Double.NEGATIVE_INFINITY;
264 }
265
266 if(potential < modelInfo.nStates) {
267 nodeVal = features;
268 }
269 else {
270 //log.debug(String.format("Mi[%d, %d] = %f, adding in %f to get %f", feat.yprev(), feat.y(), val, feat.value(), val*exp(feat.value()*param[feat.index()])));
271 int transition = potential - modelInfo.nStates;
272 double val = features + nodeVal;
273 if(doExp)
274 val = exp(val);
275 mi[transition] = val;
276 }
277 }
278 }
279
280 /**
281 * Given a vector with an existing normalization factor, convert it to a new normalization factor by scaling the
282 * entries.
283 */
284 static final void renormalize(final double[] vec, final int currentNorm, final int newNorm) {
285 // Instead of dividing by the different (new-current), we reverse the subtraction to negate the exponent and
286 // then multiply.
287 double factor = exp(NORM_FACTOR * (currentNorm - newNorm));
288 //log.info(factor);
289 //log.info(ColtUtil.format(vec));
290 int len = vec.length;
291 for (int i = 0; i < len; ++i) {
292 if(vec[i] != 0.0)
293 vec[i] *= factor;
294 }
295 //log.info(ColtUtil.format(vec));
296 }
297
298 /** Given a vector, computes a normalization factor for the entries and scales them according to that factor. */
299 static final int normalize(final double[] vec) {
300 double sum = 0.0;
301 for(double val : vec) {
302 sum += val;
303 }
304 if(sum == 0.0 || (sum > NORM_MIN && sum < NORM_MAX)) {
305 // No normalization required, our vector is in range.
306 return 0;
307 }
308 if(debug)
309 Assert.a(!Double.isNaN(sum));
310
311 //log.info("performing normalization");
312 double val = log(sum);
313 int norm = (int) val / NORM_FACTOR;
314 val = exp(NORM_FACTOR * norm);
315 int len = vec.length;
316 for (int i = 0; i < len; ++i) {
317 vec[i] /= val;
318 }
319 return norm;
320 }
321
322 static final double exp(final double val) {
323 return Math.exp(val);
324 }
325
326 static final double log(final double val) {
327 return Math.log(val);
328 }
329
330 final void logFeatureSums(int seqNum) {
331 double[][] seqFeatureSums = cacheProcessor.getSequenceFeatureSums();
332 if(seqFeatureSums != null) {
333 double seqResult = 0.0;
334 for (int j = 0; j < modelInfo.nFeatures; ++j) {
335 seqResult += seqFeatureSums[seqNum][j] * lambda[j];
336 }
337 log.debug(String.format("Seq: %d L: %g LL: %f Training path: %f Z: %f", seqNum, exp(seqResult-logZ), seqResult-logZ, seqResult, logZ));
338 Assert.a(exp(seqResult-logZ) < 1.0);
339 }
340 }
341
342 public static final String printNorm(final double value, final int norm) {
343 if( value == 0.0)
344 return "0 ("+norm+")";
345 if( Double.isNaN(value))
346 return "NaN ("+norm+")";
347 int exponent = (int) log(value);
348
349 double eValue = value/exp(exponent);
350 if(Double.isNaN(eValue)) {
351 return String.format("NaN(%e n:%d)", value, norm);
352 }
353 //return String.format("%e(%d) %fe%d", value, norm, eValue, exponent+norm*NORM_FACTOR);
354 return String.format("%fe%d", eValue, exponent+norm*NORM_FACTOR);
355 }
356
357 /** gets the cache processor used to access feature evaluations
358 * @return the configured cache processor
359 */
360 public CacheProcessor getCacheProcessor() {
361 return cacheProcessor;
362 }
363
364 /** sets the cache processor used to access feature evaluations
365 * @param cacheProcessor the cache processor to use
366 */
367
368 public void setCacheProcessor(CacheProcessor cacheProcessor) {
369 this.cacheProcessor = cacheProcessor;
370 }
371
372 public String getAlphaLengthFile() {
373 return logs.alphaLengthFile;
374 }
375
376 public void setAlphaLengthFile(String alphaLengthFile) {
377 logs.alphaLengthFile = alphaLengthFile;
378 }
379
380 public String getAlphaFile() {
381 return logs.alphaFile;
382 }
383
384 public void setAlphaFile(String alphaFile) {
385 logs.alphaFile = alphaFile;
386 }
387
388 public String getExpectFile() {
389 return logs.expectFile;
390 }
391
392 public void setExpectFile(String expectFile) {
393 logs.expectFile = expectFile;
394 }
395
396 public String getExpectLengthFile() {
397 return logs.expectLengthFile;
398 }
399
400 public void setExpectLengthFile(String expectLengthFile) {
401 logs.expectLengthFile = expectLengthFile;
402 }
403
404 public String getNodeMarginalFile() {
405 return logs.nodeMarginalFile;
406 }
407
408 public void setNodeMarginalFile(String nodeMarginalFile) {
409 logs.nodeMarginalFile = nodeMarginalFile;
410 }
411
412 public String getBetaLengthFile() {
413 return logs.betaLengthFile;
414 }
415
416 public void setBetaLengthFile(String betaLengthFile) {
417 logs.betaLengthFile = betaLengthFile;
418 }
419 }