001    package calhoun.analysis.crf.test;
002    
003    import java.util.List;
004    
005    import org.apache.commons.logging.Log;
006    import org.apache.commons.logging.LogFactory;
007    
008    import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
009    import calhoun.analysis.crf.Conrad;
010    import calhoun.analysis.crf.io.InputHandlerInterleaved;
011    import calhoun.analysis.crf.io.IntInput;
012    import calhoun.analysis.crf.io.OutputHandlerGeneCallStats;
013    import calhoun.analysis.crf.io.StringInput;
014    import calhoun.analysis.crf.io.TrainingSequence;
015    import calhoun.analysis.crf.solver.MaximumLikelihoodGradient;
016    import calhoun.analysis.crf.solver.SimplexOptimizer;
017    import calhoun.analysis.crf.solver.StandardOptimizer;
018    import calhoun.analysis.crf.solver.Viterbi;
019    import calhoun.util.AbstractTestCase;
020    
021    public class NonGradientTest extends AbstractTestCase {
022            static final Log log = LogFactory.getLog(NonGradientTest.class);
023            boolean debug = log.isDebugEnabled();
024            
025            public void testNoGradient() throws Exception {
026                    SimplexOptimizer opt = new SimplexOptimizer();
027                    opt.setStepSize(0.000001);
028                    List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111");
029                    opt.setObjectiveFunction(new MaximumLikelihoodGradient());
030                    double[] weights = opt.optimize(new TestFeatureManager(0), data);
031                    assertEquals(1.0, weights[0], 0.00001);
032            }
033    
034            public void testConvergence() throws Exception {
035                    SimplexOptimizer opt = new SimplexOptimizer();
036                    opt.setStepSize(0.000001);
037                    opt.setMaxIters(2);
038                    opt.setRequireConvergence(true);
039                    opt.setObjectiveFunction(new MaximumLikelihoodGradient());
040                    List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111");
041                    boolean ex = false;
042                    try {
043                            opt.optimize(new TestFeatureManager(0), data);
044                    }
045                    catch(Exception e) {
046                            ex = true;
047                    }
048                    assertTrue(ex);
049            }
050    
051            public void testGradientEdge() throws Exception {
052                    List<? extends TrainingSequence<?>> data = IntInput.prepareData("00110\n00110");
053                    CRFObjectiveFunctionGradient obj = new MaximumLikelihoodGradient();
054    
055                    SimplexOptimizer opt = new SimplexOptimizer();
056                    opt.setObjectiveFunction(obj);
057                    double[] weightsSimplex = opt.optimize(new TestFeatureManager(1), data);
058    
059                    StandardOptimizer stdOpt = new StandardOptimizer(); 
060                    stdOpt.setObjectiveFunction(obj);
061                    double[] weightsStandard = stdOpt.optimize(new TestFeatureManager(1), data);
062                    assertArrayEquals(weightsStandard, weightsSimplex, 1e-5);
063            }
064    
065            public void testViterbi() throws Exception {
066                    double[] weights = new double[] { -9.639248, 9.659248, -9.982040e+00, 1.000204e+01, -1.534077e+01, 2.854132e+01, -1.301635e+01, -3.040209e+00, -3.383001e+00 , 1.947956e+01 };
067                    Conrad crf = new Conrad();
068                    crf.setModel(new ZeroOrderManager());
069                    crf.setInputHandler(new InputHandlerInterleaved(new StringInput()));
070                    OutputHandlerGeneCallStats stats = new OutputHandlerGeneCallStats(crf.getModel(), crf.getInputHandler());
071                    stats.setWriteTrainingData(true);
072                    crf.setOutputHandler(stats);
073                    crf.setInference(new Viterbi());
074    
075                    crf.trainFeatures("test/input/zeroOrderTest.txt");
076                    crf.setWeights(weights);
077                    crf.test("test/input/zeroOrderTest.txt", "test/working/zeroOrderPredicted.txt");
078                    assertFilesMatch("test/output/zeroOrderPredicted.txt", "test/working/zeroOrderPredicted.txt");
079                    //log.info(results);
080            }
081    }