001 package calhoun.analysis.crf.test;
002
003 import java.util.List;
004
005 import org.apache.commons.logging.Log;
006 import org.apache.commons.logging.LogFactory;
007
008 import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
009 import calhoun.analysis.crf.Conrad;
010 import calhoun.analysis.crf.io.InputHandlerInterleaved;
011 import calhoun.analysis.crf.io.IntInput;
012 import calhoun.analysis.crf.io.OutputHandlerGeneCallStats;
013 import calhoun.analysis.crf.io.StringInput;
014 import calhoun.analysis.crf.io.TrainingSequence;
015 import calhoun.analysis.crf.solver.MaximumLikelihoodGradient;
016 import calhoun.analysis.crf.solver.SimplexOptimizer;
017 import calhoun.analysis.crf.solver.StandardOptimizer;
018 import calhoun.analysis.crf.solver.Viterbi;
019 import calhoun.util.AbstractTestCase;
020
021 public class NonGradientTest extends AbstractTestCase {
022 static final Log log = LogFactory.getLog(NonGradientTest.class);
023 boolean debug = log.isDebugEnabled();
024
025 public void testNoGradient() throws Exception {
026 SimplexOptimizer opt = new SimplexOptimizer();
027 opt.setStepSize(0.000001);
028 List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111");
029 opt.setObjectiveFunction(new MaximumLikelihoodGradient());
030 double[] weights = opt.optimize(new TestFeatureManager(0), data);
031 assertEquals(1.0, weights[0], 0.00001);
032 }
033
034 public void testConvergence() throws Exception {
035 SimplexOptimizer opt = new SimplexOptimizer();
036 opt.setStepSize(0.000001);
037 opt.setMaxIters(2);
038 opt.setRequireConvergence(true);
039 opt.setObjectiveFunction(new MaximumLikelihoodGradient());
040 List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111");
041 boolean ex = false;
042 try {
043 opt.optimize(new TestFeatureManager(0), data);
044 }
045 catch(Exception e) {
046 ex = true;
047 }
048 assertTrue(ex);
049 }
050
051 public void testGradientEdge() throws Exception {
052 List<? extends TrainingSequence<?>> data = IntInput.prepareData("00110\n00110");
053 CRFObjectiveFunctionGradient obj = new MaximumLikelihoodGradient();
054
055 SimplexOptimizer opt = new SimplexOptimizer();
056 opt.setObjectiveFunction(obj);
057 double[] weightsSimplex = opt.optimize(new TestFeatureManager(1), data);
058
059 StandardOptimizer stdOpt = new StandardOptimizer();
060 stdOpt.setObjectiveFunction(obj);
061 double[] weightsStandard = stdOpt.optimize(new TestFeatureManager(1), data);
062 assertArrayEquals(weightsStandard, weightsSimplex, 1e-5);
063 }
064
065 public void testViterbi() throws Exception {
066 double[] weights = new double[] { -9.639248, 9.659248, -9.982040e+00, 1.000204e+01, -1.534077e+01, 2.854132e+01, -1.301635e+01, -3.040209e+00, -3.383001e+00 , 1.947956e+01 };
067 Conrad crf = new Conrad();
068 crf.setModel(new ZeroOrderManager());
069 crf.setInputHandler(new InputHandlerInterleaved(new StringInput()));
070 OutputHandlerGeneCallStats stats = new OutputHandlerGeneCallStats(crf.getModel(), crf.getInputHandler());
071 stats.setWriteTrainingData(true);
072 crf.setOutputHandler(stats);
073 crf.setInference(new Viterbi());
074
075 crf.trainFeatures("test/input/zeroOrderTest.txt");
076 crf.setWeights(weights);
077 crf.test("test/input/zeroOrderTest.txt", "test/working/zeroOrderPredicted.txt");
078 assertFilesMatch("test/output/zeroOrderPredicted.txt", "test/working/zeroOrderPredicted.txt");
079 //log.info(results);
080 }
081 }