001    package calhoun.analysis.crf;
002    
003    import java.io.IOException;
004    import java.io.Serializable;
005    import java.util.List;
006    
007    import org.apache.commons.logging.Log;
008    import org.apache.commons.logging.LogFactory;
009    import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
010    import org.springframework.context.ApplicationContext;
011    import org.springframework.context.support.GenericApplicationContext;
012    import org.springframework.core.io.ByteArrayResource;
013    
014    import calhoun.analysis.crf.CRFInference.InferenceResult;
015    import calhoun.analysis.crf.io.InputHandler;
016    import calhoun.analysis.crf.io.InputSequence;
017    import calhoun.analysis.crf.io.OutputHandler;
018    import calhoun.analysis.crf.io.OutputHandlerGeneCallPredict;
019    import calhoun.analysis.crf.io.OutputHandlerGeneCallStats;
020    import calhoun.analysis.crf.io.TrainingSequence;
021    import calhoun.analysis.crf.io.CompositeInput.LegacyInputHandler;
022    import calhoun.util.ErrorException;
023    import calhoun.util.FileUtil;
024    
025    /** the central class for the Conrad engine.  Has a <code>main</code> function for calling Conrad
026     * from the command line and a high-level public interface for programmatic operations.  This class
027     * is mostly just a container which delegates the real work to the various objects set up in the
028     * configuration.
029     * */
030    public class Conrad implements Serializable {
031            private static final long serialVersionUID = -5964610632818921236L;
032            private static final Log log = LogFactory.getLog(Conrad.class);
033            
034            byte[] configXml;
035            ModelManager model;
036            InputHandler inputHandler;
037            OutputHandler outputHandler;
038            transient CRFInference inference;
039            transient CRFTraining optimizer;
040            
041            double[] weights = null;
042            double trainingTime = 0;
043            
044            /** Command line entry point for running CRFs.  Can train, test, or predict depending on arguments.  
045             * @param args list of command line arguments.  See usage for details. */
046            public static void main(String[] args) throws Exception {
047                    if(args.length != 4) {
048                            usage();
049                    }
050                    if(args[0].startsWith("train")) {
051                            Conrad crf;
052                            if(args[0].equals("trainWeights")) {
053                                    crf = Conrad.read(args[1]);
054                                    crf.trainWeights(args[2]);
055                            }
056                            else {
057                                    crf = new Conrad(args[1]);
058                                    if(args[0].equals("trainFeatures")) {
059                                            crf.trainFeatures(args[2]);
060                                    }
061                                    else {
062                                            crf.train(args[2]);                     // train features and weights
063                                    }
064                            }
065                            crf.write(args[3]);
066                    }
067                    else if(args[0].equals("test")) {
068                            Conrad crf = Conrad.read(args[1]);
069                            crf.initSolver();
070                            crf.test(args[2], args[3]);
071                    }
072                    else if(args[0].equals("predict")) {
073                            Conrad crf = Conrad.read(args[1]);
074                            OutputHandlerGeneCallPredict predictOutputHandler = new OutputHandlerGeneCallPredict();
075                            predictOutputHandler.setWriteTrainingData(false);
076                            predictOutputHandler.setManager(crf.getModel());
077                            predictOutputHandler.setInputHandler(crf.getInputHandler());
078                            crf.setOutputHandler(predictOutputHandler);
079                            crf.initSolver();
080                            crf.testWithoutAnswers(args[2], args[3]);
081                    }
082                    else {
083                            usage();
084                    }
085            }
086    
087            /** creates a Conrad engine with no configuration.  All configuration must be done programmatically. 
088             */
089            public Conrad() {
090            }
091            
092            /** creates a Conrad engine based on configuration information from an XML model file.
093             * @param configFile string filename of the XML model file  */
094            @SuppressWarnings("deprecation")
095            public Conrad(String configFile) {
096                    try {
097                            configXml = FileUtil.readFileAsBytes(configFile);
098                    }
099                    catch(IOException ex) {
100                            throw new RuntimeException("Failed to load config file: "+configFile, ex);
101                    }
102                    ApplicationContext ctx = initSolver();
103                    model = (ModelManager) ctx.getBean("model");
104                    
105                    // For backwards compatibility, if we don't have an inputHandler but we do have an input format, use a legacy handler.
106                    if(ctx.containsBean("inputFormat") && !ctx.containsBean("inputHandler")) {
107                            inputHandler = new LegacyInputHandler(ctx.getBean("inputFormat"));
108                            
109                            if(ctx.containsBean("outputHandler")) {
110                                    outputHandler = (OutputHandler)ctx.getBean("outputHandler");                            
111                            } else {
112                                    OutputHandlerGeneCallStats legacyOutputHandler = new OutputHandlerGeneCallStats();
113                                    legacyOutputHandler.setWriteTrainingData(true);
114                                    legacyOutputHandler.setManager(model);
115                                    legacyOutputHandler.setInputHandler(inputHandler);      
116                                    outputHandler = legacyOutputHandler;
117                            }                       
118                    }
119                    else {
120                            inputHandler = (InputHandler) ctx.getBean("inputHandler");
121                            outputHandler = (OutputHandler) ctx.getBean("outputHandler");
122                    }
123            }
124    
125            /**  writes this Conrad engine to a file.  This is most often used to save a trained model file.
126             * 
127             * @param filename string name of the file that will contain the serialized model.
128             * @throws IOException if a problem occurs writing to the file
129             */
130            public void write(String filename) throws IOException {
131                    FileUtil.writeObject(filename, this);
132            }
133    
134            /** read in a Conrad engine from a file.  This file must have previously been created by calling {@link #write}
135             * 
136             * @param filename string name of the file containing the model.
137             * @return the Conrad engine which has been read from the file
138             * @throws IOException if there is a problem reading the file
139             */
140            public static Conrad read(String filename) throws IOException {
141                    try {
142                            Conrad ret = (Conrad) FileUtil.readObject(filename);
143                            ret.initSolver();
144                            return ret;
145                    } catch (ClassNotFoundException ex) {
146                            throw new ErrorException(ex);
147                    }
148            }
149    
150            /** fully trains this Conrad engine with this training data.  The training data is specified as a string location,
151             * which the configured {@link InputHandler } is responsible for converting into a list of training sequences
152             * of the appropriate type.
153             * 
154             * @param location string location of the data.  The exact meaning will be determined by the InputHandler.
155             * @throws IOException if there is a problem reading the training data.
156             */
157            public void train(String location) throws IOException {
158                    List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(location, false);
159                    train(t);
160            }
161    
162            
163            /** fully trains this Conrad engine with this training data.  The training data is specified as a list of 
164             * training sequences, and no DataInputHandler is used.
165             * @param data a list of training sequences to use for training
166             */
167            public void train(List<? extends TrainingSequence<?>> data) {
168                    trainFeatures(data);
169                    trainWeights(data);
170            }
171            
172            /** trains only the features in the current model with this training data.  {@link FeatureManager#train } is called
173             * for each feature in the model, but no optimization is performed and no feature weights are set.  This allows the
174             * features themselves to be parameterized on one set of training data, while using a different set for optimizing the
175             * feature weights.<p>
176             * The training data is specified as a string location,
177             * which the configured {@link InputHandler } is responsible for converting into a list of training sequences
178             * of the appropriate type.
179             * 
180             * @param location string location of the data.  The exact meaning will be determined by the InputHandler.
181             * @throws IOException if there is a problem reading the training data
182             */
183            public void trainFeatures(String location) throws IOException {
184                    List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(location, false);
185                    trainFeatures(t);
186            }
187    
188            /** trains only the features in the current model with this training data.  {@link FeatureManager#train } is called
189             * for each feature in the model, but no optimization is performed and no feature weights are set.  This allows the
190             * features themselves to be parameterized on one set of training data, while using a different set for optimizing the
191             * feature weights.<p>
192             * The training data is specified as a list of 
193             * training sequences, and no DataInputHandler is used.  
194             * @param data a list of training sequences to use for training
195             */
196            public void trainFeatures(List<? extends TrainingSequence<?>> data) {
197                    print("Training features");
198                    double timer = System.currentTimeMillis();
199    
200                    // Train features
201                    model.train(0, model, data);
202                    if (log.isDebugEnabled()) {
203                            log.debug("Features:");
204                            for (int i = 0; i < model.getNumFeatures(); ++i) {
205                                    log.debug(model.getFeatureName(i));
206                            }
207                    }
208                    
209                    trainingTime = (System.currentTimeMillis() - timer)/1000;
210                    print("Trained in "+trainingTime+" seconds.");
211            }
212    
213            /** optimizes the feature weights for the current model with this training data.  Assumes that {@link #trainFeatures } 
214             * has already been called to train the individual features.  <p>
215             * The training data is specified as a list of  training sequences, and no DataInputHandler is used.      
216             * @param location string location of the data.  The exact meaning will be determined by the InputHandler.
217             * @throws IOException if there is a problem reading the training data
218             */
219            public void trainWeights(String location) throws IOException {
220                    List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(location, false);
221                    trainWeights(t);
222            }
223    
224            /** optimizes the feature weights for the current model with this training data.  Assumes that {@link #trainFeatures } 
225             * has already been called to train the individual features.  <p>
226             * The training data is specified as a list of  training sequences, and no DataInputHandler is used.     
227             * @param data a list of training sequences to use for training
228             */
229            public void trainWeights(List<? extends TrainingSequence<?>> data) {
230                    print("Training weights");
231                    double timer = System.currentTimeMillis();
232    
233                    // Train weights
234                    weights = optimizer.optimize(model, data);
235    
236                    timer =  (System.currentTimeMillis() - timer)/1000;
237                    trainingTime += timer;
238                    print("Trained weights in "+timer+" seconds.  "+trainingTime+" total.");
239            }
240    
241            /** runs a trained model against a set of input data with known results and evaluates the performance.  Assumes that
242             * {@link #train} has already been called to train the model.  For convenience, the data is passed in as a training set,
243             * although the model is not trained.  The input is used to create a set of predictions and then those predictions are
244             * compared against the expected outputs.  The result of the prediction is passed to the output handler which can compare
245             * the predicted versus the expected values
246             * @param inputLocation string location of the data.  The exact meaning will be determined by the InputHandler.
247             * @throws IOException if there is a problem reading the training data
248             */
249            public void test(String inputLocation) throws IOException {
250                    List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(inputLocation, false);
251                    test(t);
252            }
253    
254            public void test(String inputLocation, String outputLocation) throws IOException {
255                    List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(inputLocation, false);
256                    test(t, outputLocation);
257            }
258    
259            public void test(List<? extends TrainingSequence<?>> data) throws IOException {
260                    test(data, null);
261            }
262            
263            public void testWithoutAnswers(String inputLocation, String outputLocation) throws IOException {
264                    List<? extends TrainingSequence<?>> t = inputHandler.readTrainingData(inputLocation, true);
265                    test(t, outputLocation);
266            }
267            
268            /** runs a trained model against a set of input data with known results and evaluates the performance.  Assumes that
269             * {@link #train} has already been called to train the model.  For convenience, the data is passed in as a training set,
270             * although the model is not trained.  The input is used to create a set of predictions and then those predictions are
271             * compared against the expected outputs.  The result of the prediction is passed to the output handler which can compare
272             * the predicted versus the expected values
273             * @param data a list of training sequences to use for training
274             */
275            public void test(List<? extends TrainingSequence<?>> data, String location) throws IOException {
276                    print("Beginning test");
277                    printWeights();
278                    outputHandler.setOutputLocation(location);
279                    for (TrainingSequence dr : data) {
280                            InferenceResult predictedHiddenSequence = predict(dr);
281                            outputHandler.writeTestOutput(dr.getInputSequence(), dr.getY(), predictedHiddenSequence.hiddenStates);
282                    }
283                    print("Testing complete");
284                    outputHandler.outputComplete();
285            }
286    
287            /** preforms inference on the input sequence and determines the best labeling for the sequence using
288             * the configured inference algorithm.
289             * @param data the input sequence the engine will use for inference
290             * @return an inference result containing the predicted hidden states
291             */
292            public InferenceResult predict(InputSequence data) {
293                    return inference.predict(model, data, weights);
294            }
295    
296            /** sets feature weights.  Usually these weights are determined during the training process, but they can be set directly.
297             * @param weights an array of doubles containing one weight for each feature.
298             */
299            public void setWeights(double[] weights) {
300                    this.weights = weights;
301            }
302            
303            /** looks up a feature's name given it's index
304             * @param index index of the feature
305             * @return name of the feature
306             */
307            public String getFeatureName(int index) {
308                    return model.getFeatureName(index);
309            }
310    
311            /** returns the the total number of seconds used in training.  This is the sum of the time to train the features and the time to train the weights.
312             * This is set when each phase of the training (features & weights) is completed.
313             * @return total training time
314             */
315            public double getTrainingTime() {
316                    return trainingTime;
317            }
318    
319            /** returns the number of individual features in the model.  This may differ from the number of <code>FeatureManager</code>s because each
320             * <code>FeatureManager</code> may have 0, 1, or many features associated with it.
321             * @return total number of features in the model.
322             */
323            public int getNumFeatures() {
324                    return model.getNumFeatures();
325            }
326    
327            /** returns the number of hidden states in the model
328             * @return number of hidden states in the model
329             */
330            public int getNumStates() {
331                    return model.getNumStates();
332            }
333    
334            /** looks up the name of a state given it's index.
335             * @return string name of the state with this index.
336             */
337            public String getStateName(int state) {
338                    return model.getStateName(state);
339            }
340    
341            /** returns the configured ModelManager object.
342             * @return the model manager which contains the features and hidden state configuration
343             */
344            public ModelManager getModel() {
345                    return model;
346            }
347            
348            /** returns the configured numerical optimizer which will be used to
349             * select the optimal feature weights during training.
350             * @return the configured objective function gradient
351             */
352            public CRFTraining getOptimizer() {
353                    return optimizer;
354            }
355            
356            /** returns the feature weights.  These will be valid once the modle is trained.
357             * @return an array of doubles containing the weight for each feature.  It will be the
358             * same length as returned by {@link #getNumFeatures()}
359             */
360            public double[] getWeights() {
361                    return weights;
362            }
363    
364            /** returns the configured inference algorithm which will be used 
365             * to predict hidden states for new inputs once the model is trained.
366             * @return the configured inference algorithm
367             */
368            public CRFInference getInference() {
369                    return inference;
370            }
371    
372            /** sets the inference algorithm.  Called automatically during configuration. */
373            public void setInference(CRFInference inference) {
374                    this.inference = inference;
375            }
376    
377            /** sets the model.  Called automatically during configuration. */
378            public void setModel(ModelManager model) {
379                    this.model = model;
380            }
381    
382            /** sets the numerical optimizer.  Called automatically during configuration. */
383            public void setOptimizer(CRFTraining optimizer) {
384                    this.optimizer = optimizer;
385            }
386    
387            /** Returns a formatted string listing the weights.  Useful for debugging. 
388             * @return a string containing a human readable list of the feature weights
389             * */
390            public String printWeights() {
391                    StringBuilder sb = new StringBuilder();
392                    for(int i=0; i<weights.length; ++i) {
393                            sb.append(String.format("%.5f\t%s\n", weights[i], getFeatureName(i)));
394                    }
395                    return sb.toString();
396            }
397    
398            private void print(String msg) {
399                    System.out.println(msg);
400            }
401    
402            private ApplicationContext initSolver() {
403                    GenericApplicationContext ctx = new GenericApplicationContext();
404                    XmlBeanDefinitionReader xmlReader = new XmlBeanDefinitionReader(ctx);
405                    xmlReader.loadBeanDefinitions(new ByteArrayResource(configXml));
406                    ctx.refresh();
407                    inference = (CRFInference) ctx.getBean("inference");
408                    optimizer = (CRFTraining) ctx.getBean("optimizer");
409                    return ctx;
410            }
411            
412            private static void usage() {
413                    System.out.println("       Conrad train(Features) configFile data modelFile");
414                    System.out.println(" or    Conrad trainWeights modelFileIn data modelFileOut");
415                    System.out.println(" or    Conrad test modelFile inputData outputData");
416                    System.out.println(" or    Conrad predict modelFile inputData outputData");
417                    System.exit(-1);
418            }
419    
420            /** returns the configured input handler.
421             * @return the input handler for this model
422             */
423            public InputHandler getInputHandler() {
424                    return inputHandler;
425            }
426    
427            /** sets the configured input handler.  Must be set before any train or test methods are called.  Usually called during config based on setup in the XML file.
428             * @param inputHandler the input handler for this model
429             */
430            public void setInputHandler(InputHandler inputHandler) {
431                    this.inputHandler = inputHandler;
432            }
433    
434            /** gets the configured output handler.  Must be set before any test methods are called.
435             * @return Returns the outputHandler.
436             */
437            public OutputHandler getOutputHandler() {
438                    return outputHandler;
439            }
440    
441            /** sets the configured output handler.  WillMust be set before any test methods are called.
442             * @param outputHandler the output handler for this model
443             */
444            public void setOutputHandler(OutputHandler outputHandler) {
445                    this.outputHandler = outputHandler;
446            }
447    }