001 package calhoun.analysis.crf;
002
003 import java.io.IOException;
004 import java.io.Serializable;
005 import java.util.List;
006
007 import org.apache.commons.logging.Log;
008 import org.apache.commons.logging.LogFactory;
009 import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
010 import org.springframework.context.ApplicationContext;
011 import org.springframework.context.support.GenericApplicationContext;
012 import org.springframework.core.io.ByteArrayResource;
013
014 import calhoun.analysis.crf.CRFInference.InferenceResult;
015 import calhoun.analysis.crf.io.InputHandler;
016 import calhoun.analysis.crf.io.InputSequence;
017 import calhoun.analysis.crf.io.OutputHandler;
018 import calhoun.analysis.crf.io.OutputHandlerGeneCallPredict;
019 import calhoun.analysis.crf.io.OutputHandlerGeneCallStats;
020 import calhoun.analysis.crf.io.TrainingSequence;
021 import calhoun.analysis.crf.io.CompositeInput.LegacyInputHandler;
022 import calhoun.util.ErrorException;
023 import calhoun.util.FileUtil;
024
025 /** the central class for the Conrad engine. Has a <code>main</code> function for calling Conrad
026 * from the command line and a high-level public interface for programmatic operations. This class
027 * is mostly just a container which delegates the real work to the various objects set up in the
028 * configuration.
029 * */
030 public class Conrad implements Serializable {
031 private static final long serialVersionUID = -5964610632818921236L;
032 private static final Log log = LogFactory.getLog(Conrad.class);
033
034 byte[] configXml;
035 ModelManager model;
036 InputHandler inputHandler;
037 OutputHandler outputHandler;
038 transient CRFInference inference;
039 transient CRFTraining optimizer;
040
041 double[] weights = null;
042 double trainingTime = 0;
043
044 /** Command line entry point for running CRFs. Can train, test, or predict depending on arguments.
045 * @param args list of command line arguments. See usage for details. */
046 public static void main(String[] args) throws Exception {
047 if(args.length != 4) {
048 usage();
049 }
050 if(args[0].startsWith("train")) {
051 Conrad crf;
052 if(args[0].equals("trainWeights")) {
053 crf = Conrad.read(args[1]);
054 crf.trainWeights(args[2]);
055 }
056 else {
057 crf = new Conrad(args[1]);
058 if(args[0].equals("trainFeatures")) {
059 crf.trainFeatures(args[2]);
060 }
061 else {
062 crf.train(args[2]); // train features and weights
063 }
064 }
065 crf.write(args[3]);
066 }
067 else if(args[0].equals("test")) {
068 Conrad crf = Conrad.read(args[1]);
069 crf.initSolver();
070 crf.test(args[2], args[3]);
071 }
072 else if(args[0].equals("predict")) {
073 Conrad crf = Conrad.read(args[1]);
074 OutputHandlerGeneCallPredict predictOutputHandler = new OutputHandlerGeneCallPredict();
075 predictOutputHandler.setWriteTrainingData(false);
076 predictOutputHandler.setManager(crf.getModel());
077 predictOutputHandler.setInputHandler(crf.getInputHandler());
078 crf.setOutputHandler(predictOutputHandler);
079 crf.initSolver();
080 crf.testWithoutAnswers(args[2], args[3]);
081 }
082 else {
083 usage();
084 }
085 }
086
087 /** creates a Conrad engine with no configuration. All configuration must be done programmatically.
088 */
089 public Conrad() {
090 }
091
092 /** creates a Conrad engine based on configuration information from an XML model file.
093 * @param configFile string filename of the XML model file */
094 @SuppressWarnings("deprecation")
095 public Conrad(String configFile) {
096 try {
097 configXml = FileUtil.readFileAsBytes(configFile);
098 }
099 catch(IOException ex) {
100 throw new RuntimeException("Failed to load config file: "+configFile, ex);
101 }
102 ApplicationContext ctx = initSolver();
103 model = (ModelManager) ctx.getBean("model");
104
105 // For backwards compatibility, if we don't have an inputHandler but we do have an input format, use a legacy handler.
106 if(ctx.containsBean("inputFormat") && !ctx.containsBean("inputHandler")) {
107 inputHandler = new LegacyInputHandler(ctx.getBean("inputFormat"));
108
109 if(ctx.containsBean("outputHandler")) {
110 outputHandler = (OutputHandler)ctx.getBean("outputHandler");
111 } else {
112 OutputHandlerGeneCallStats legacyOutputHandler = new OutputHandlerGeneCallStats();
113 legacyOutputHandler.setWriteTrainingData(true);
114 legacyOutputHandler.setManager(model);
115 legacyOutputHandler.setInputHandler(inputHandler);
116 outputHandler = legacyOutputHandler;
117 }
118 }
119 else {
120 inputHandler = (InputHandler) ctx.getBean("inputHandler");
121 outputHandler = (OutputHandler) ctx.getBean("outputHandler");
122 }
123 }
124
125 /** writes this Conrad engine to a file. This is most often used to save a trained model file.
126 *
127 * @param filename string name of the file that will contain the serialized model.
128 * @throws IOException if a problem occurs writing to the file
129 */
130 public void write(String filename) throws IOException {
131 FileUtil.writeObject(filename, this);
132 }
133
134 /** read in a Conrad engine from a file. This file must have previously been created by calling {@link #write}
135 *
136 * @param filename string name of the file containing the model.
137 * @return the Conrad engine which has been read from the file
138 * @throws IOException if there is a problem reading the file
139 */
140 public static Conrad read(String filename) throws IOException {
141 try {
142 Conrad ret = (Conrad) FileUtil.readObject(filename);
143 ret.initSolver();
144 return ret;
145 } catch (ClassNotFoundException ex) {
146 throw new ErrorException(ex);
147 }
148 }
149
150 /** fully trains this Conrad engine with this training data. The training data is specified as a string location,
151 * which the configured {@link InputHandler } is responsible for converting into a list of training sequences
152 * of the appropriate type.
153 *
154 * @param location string location of the data. The exact meaning will be determined by the InputHandler.
155 * @throws IOException if there is a problem reading the training data.
156 */
157 public void train(String location) throws IOException {
158 List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(location, false);
159 train(t);
160 }
161
162
163 /** fully trains this Conrad engine with this training data. The training data is specified as a list of
164 * training sequences, and no DataInputHandler is used.
165 * @param data a list of training sequences to use for training
166 */
167 public void train(List<? extends TrainingSequence<?>> data) {
168 trainFeatures(data);
169 trainWeights(data);
170 }
171
172 /** trains only the features in the current model with this training data. {@link FeatureManager#train } is called
173 * for each feature in the model, but no optimization is performed and no feature weights are set. This allows the
174 * features themselves to be parameterized on one set of training data, while using a different set for optimizing the
175 * feature weights.<p>
176 * The training data is specified as a string location,
177 * which the configured {@link InputHandler } is responsible for converting into a list of training sequences
178 * of the appropriate type.
179 *
180 * @param location string location of the data. The exact meaning will be determined by the InputHandler.
181 * @throws IOException if there is a problem reading the training data
182 */
183 public void trainFeatures(String location) throws IOException {
184 List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(location, false);
185 trainFeatures(t);
186 }
187
188 /** trains only the features in the current model with this training data. {@link FeatureManager#train } is called
189 * for each feature in the model, but no optimization is performed and no feature weights are set. This allows the
190 * features themselves to be parameterized on one set of training data, while using a different set for optimizing the
191 * feature weights.<p>
192 * The training data is specified as a list of
193 * training sequences, and no DataInputHandler is used.
194 * @param data a list of training sequences to use for training
195 */
196 public void trainFeatures(List<? extends TrainingSequence<?>> data) {
197 print("Training features");
198 double timer = System.currentTimeMillis();
199
200 // Train features
201 model.train(0, model, data);
202 if (log.isDebugEnabled()) {
203 log.debug("Features:");
204 for (int i = 0; i < model.getNumFeatures(); ++i) {
205 log.debug(model.getFeatureName(i));
206 }
207 }
208
209 trainingTime = (System.currentTimeMillis() - timer)/1000;
210 print("Trained in "+trainingTime+" seconds.");
211 }
212
213 /** optimizes the feature weights for the current model with this training data. Assumes that {@link #trainFeatures }
214 * has already been called to train the individual features. <p>
215 * The training data is specified as a list of training sequences, and no DataInputHandler is used.
216 * @param location string location of the data. The exact meaning will be determined by the InputHandler.
217 * @throws IOException if there is a problem reading the training data
218 */
219 public void trainWeights(String location) throws IOException {
220 List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(location, false);
221 trainWeights(t);
222 }
223
224 /** optimizes the feature weights for the current model with this training data. Assumes that {@link #trainFeatures }
225 * has already been called to train the individual features. <p>
226 * The training data is specified as a list of training sequences, and no DataInputHandler is used.
227 * @param data a list of training sequences to use for training
228 */
229 public void trainWeights(List<? extends TrainingSequence<?>> data) {
230 print("Training weights");
231 double timer = System.currentTimeMillis();
232
233 // Train weights
234 weights = optimizer.optimize(model, data);
235
236 timer = (System.currentTimeMillis() - timer)/1000;
237 trainingTime += timer;
238 print("Trained weights in "+timer+" seconds. "+trainingTime+" total.");
239 }
240
241 /** runs a trained model against a set of input data with known results and evaluates the performance. Assumes that
242 * {@link #train} has already been called to train the model. For convenience, the data is passed in as a training set,
243 * although the model is not trained. The input is used to create a set of predictions and then those predictions are
244 * compared against the expected outputs. The result of the prediction is passed to the output handler which can compare
245 * the predicted versus the expected values
246 * @param inputLocation string location of the data. The exact meaning will be determined by the InputHandler.
247 * @throws IOException if there is a problem reading the training data
248 */
249 public void test(String inputLocation) throws IOException {
250 List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(inputLocation, false);
251 test(t);
252 }
253
254 public void test(String inputLocation, String outputLocation) throws IOException {
255 List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(inputLocation, false);
256 test(t, outputLocation);
257 }
258
259 public void test(List<? extends TrainingSequence<?>> data) throws IOException {
260 test(data, null);
261 }
262
263 public void testWithoutAnswers(String inputLocation, String outputLocation) throws IOException {
264 List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(inputLocation, true);
265 test(t, outputLocation);
266 }
267
268 /** runs a trained model against a set of input data with known results and evaluates the performance. Assumes that
269 * {@link #train} has already been called to train the model. For convenience, the data is passed in as a training set,
270 * although the model is not trained. The input is used to create a set of predictions and then those predictions are
271 * compared against the expected outputs. The result of the prediction is passed to the output handler which can compare
272 * the predicted versus the expected values
273 * @param data a list of training sequences to use for training
274 */
275 public void test(List<? extends TrainingSequence<?>> data, String location) throws IOException {
276 print("Beginning test");
277 printWeights();
278 outputHandler.setOutputLocation(location);
279 for (TrainingSequence dr : data) {
280 InferenceResult predictedHiddenSequence = predict(dr);
281 outputHandler.writeTestOutput(dr.getInputSequence(), dr.getY(), predictedHiddenSequence.hiddenStates);
282 }
283 print("Testing complete");
284 outputHandler.outputComplete();
285 }
286
287 /** preforms inference on the input sequence and determines the best labeling for the sequence using
288 * the configured inference algorithm.
289 * @param data the input sequence the engine will use for inference
290 * @return an inference result containing the predicted hidden states
291 */
292 public InferenceResult predict(InputSequence data) {
293 return inference.predict(model, data, weights);
294 }
295
296 /** sets feature weights. Usually these weights are determined during the training process, but they can be set directly.
297 * @param weights an array of doubles containing one weight for each feature.
298 */
299 public void setWeights(double[] weights) {
300 this.weights = weights;
301 }
302
303 /** looks up a feature's name given it's index
304 * @param index index of the feature
305 * @return name of the feature
306 */
307 public String getFeatureName(int index) {
308 return model.getFeatureName(index);
309 }
310
311 /** returns the the total number of seconds used in training. This is the sum of the time to train the features and the time to train the weights.
312 * This is set when each phase of the training (features & weights) is completed.
313 * @return total training time
314 */
315 public double getTrainingTime() {
316 return trainingTime;
317 }
318
319 /** returns the number of individual features in the model. This may differ from the number of <code>FeatureManager</code>s because each
320 * <code>FeatureManager</code> may have 0, 1, or many features associated with it.
321 * @return total number of features in the model.
322 */
323 public int getNumFeatures() {
324 return model.getNumFeatures();
325 }
326
327 /** returns the number of hidden states in the model
328 * @return number of hidden states in the model
329 */
330 public int getNumStates() {
331 return model.getNumStates();
332 }
333
334 /** looks up the name of a state given it's index.
335 * @return string name of the state with this index.
336 */
337 public String getStateName(int state) {
338 return model.getStateName(state);
339 }
340
341 /** returns the configured ModelManager object.
342 * @return the model manager which contains the features and hidden state configuration
343 */
344 public ModelManager getModel() {
345 return model;
346 }
347
348 /** returns the configured numerical optimizer which will be used to
349 * select the optimal feature weights during training.
350 * @return the configured objective function gradient
351 */
352 public CRFTraining getOptimizer() {
353 return optimizer;
354 }
355
356 /** returns the feature weights. These will be valid once the modle is trained.
357 * @return an array of doubles containing the weight for each feature. It will be the
358 * same length as returned by {@link #getNumFeatures()}
359 */
360 public double[] getWeights() {
361 return weights;
362 }
363
364 /** returns the configured inference algorithm which will be used
365 * to predict hidden states for new inputs once the model is trained.
366 * @return the configured inference algorithm
367 */
368 public CRFInference getInference() {
369 return inference;
370 }
371
372 /** sets the inference algorithm. Called automatically during configuration. */
373 public void setInference(CRFInference inference) {
374 this.inference = inference;
375 }
376
377 /** sets the model. Called automatically during configuration. */
378 public void setModel(ModelManager model) {
379 this.model = model;
380 }
381
382 /** sets the numerical optimizer. Called automatically during configuration. */
383 public void setOptimizer(CRFTraining optimizer) {
384 this.optimizer = optimizer;
385 }
386
387 /** Returns a formatted string listing the weights. Useful for debugging.
388 * @return a string containing a human readable list of the feature weights
389 * */
390 public String printWeights() {
391 StringBuilder sb = new StringBuilder();
392 for(int i=0; i<weights.length; ++i) {
393 sb.append(String.format("%.5f\t%s\n", weights[i], getFeatureName(i)));
394 }
395 return sb.toString();
396 }
397
398 private void print(String msg) {
399 System.out.println(msg);
400 }
401
402 private ApplicationContext initSolver() {
403 GenericApplicationContext ctx = new GenericApplicationContext();
404 XmlBeanDefinitionReader xmlReader = new XmlBeanDefinitionReader(ctx);
405 xmlReader.loadBeanDefinitions(new ByteArrayResource(configXml));
406 ctx.refresh();
407 inference = (CRFInference) ctx.getBean("inference");
408 optimizer = (CRFTraining) ctx.getBean("optimizer");
409 return ctx;
410 }
411
412 private static void usage() {
413 System.out.println(" Conrad train(Features) configFile data modelFile");
414 System.out.println(" or Conrad trainWeights modelFileIn data modelFileOut");
415 System.out.println(" or Conrad test modelFile inputData outputData");
416 System.out.println(" or Conrad predict modelFile inputData outputData");
417 System.exit(-1);
418 }
419
420 /** returns the configured input handler.
421 * @return the input handler for this model
422 */
423 public InputHandler getInputHandler() {
424 return inputHandler;
425 }
426
427 /** sets the configured input handler. Must be set before any train or test methods are called. Usually called during config based on setup in the XML file.
428 * @param inputHandler the input handler for this model
429 */
430 public void setInputHandler(InputHandler inputHandler) {
431 this.inputHandler = inputHandler;
432 }
433
434 /** gets the configured output handler. Must be set before any test methods are called.
435 * @return Returns the outputHandler.
436 */
437 public OutputHandler getOutputHandler() {
438 return outputHandler;
439 }
440
441 /** sets the configured output handler. WillMust be set before any test methods are called.
442 * @param outputHandler the output handler for this model
443 */
444 public void setOutputHandler(OutputHandler outputHandler) {
445 this.outputHandler = outputHandler;
446 }
447 }