001 package calhoun.analysis.crf.test;
002
003 import java.util.List;
004
005 import org.apache.commons.logging.Log;
006 import org.apache.commons.logging.LogFactory;
007
008 import calhoun.analysis.crf.Conrad;
009 import calhoun.analysis.crf.CRFInference.InferenceResult;
010 import calhoun.analysis.crf.io.TrainingSequence;
011 import calhoun.util.AbstractTestCase;
012 import calhoun.util.Assert;
013 import cern.colt.matrix.impl.DenseDoubleMatrix2D;
014
015 public class SemiMarkovTest extends AbstractTestCase {
016 @SuppressWarnings("unused")
017 private static final Log log = LogFactory.getLog(SemiMarkovTest.class);
018
019 public void testSemiMarkovBadTraining() throws Exception {
020 // Assert that trying to train the semi-Markov fails when training data is used that contains segments longer than the maximum allowed length
021 String failureMessage = null;
022 try {
023 Conrad.main(new String[] {"train", "test/input/zeroOrderLBFGSCachedSemiMarkov.xml", "test/input/zeroOrderTest.txt", "test/working/zeroLBGFSModelCachedSemiMarkov.ser"});
024 }
025 catch(Exception ex) {
026 failureMessage = ex.getMessage();
027 }
028 assertEquals("Seq #0 Pos 150 Training segment 127 is longer than allowed length 20", failureMessage);
029 }
030
031 public void testSemiCRFViterbiCompareWithBaseClass() throws Exception {
032 Conrad r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengths.xml");
033 r.trainFeatures("test/input/zeroOrderTrivial.txt");
034 r.setWeights(new double[] {1,1,1});
035
036 Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
037 s.trainFeatures("test/input/zeroOrderTrivial.txt");
038 s.setWeights(new double[] {1,1,1});
039
040 doViterbiComparison("test/input/zeroOrderTrivial.txt", r, s);
041 doViterbiComparison("test/input/zeroOrderTest.txt", r, s);
042 }
043
044 public void testSemiCRFViterbiCompareWithBaseClassWithLengths() throws Exception {
045 Conrad r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthFeatures.xml");
046 r.trainFeatures("test/input/zeroOrderTrivial.txt");
047 r.setWeights(new double[] {1,1,1});
048
049 Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
050 s.trainFeatures("test/input/zeroOrderTrivial.txt");
051 s.setWeights(new double[] {1,1,1});
052
053 doViterbiComparison("test/input/zeroOrderTrivial.txt", r, s);
054 }
055
056 public void testSemiCRFViterbiCompareWithBaseClassWithFeatures() throws Exception {
057 Conrad r = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
058 r.trainFeatures("test/input/zeroOrderTrivial.txt");
059 r.setWeights(new double[] {1,1,1});
060
061 Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
062 //Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthFeatures.txt");
063 s.trainFeatures("test/input/zeroOrderTrivial.txt");
064 s.setWeights(new double[] {1,1,1});
065
066 doViterbiComparison("test/input/zeroOrderTrivial.txt", r, s);
067 }
068
069 public void testSemiCRFCompareWithBaseClass() throws Exception {
070 Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
071 s.train("test/input/zeroOrderTrivial.txt");
072
073 Conrad r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengths.xml");
074 r.train("test/input/zeroOrderTrivial.txt");
075 assertEquals(r.getWeights()[0], s.getWeights()[0], 0.0001);
076
077 r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthFeatures.xml");
078 r.train("test/input/zeroOrderTrivial.txt");
079 assertEquals(r.getWeights()[0], s.getWeights()[0], 0.0001);
080
081 r = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
082 r.train("test/input/zeroOrderTrivial.txt");
083 assertEquals(r.getWeights()[0], s.getWeights()[0], 0.001);
084 }
085
086 void doViterbiComparison(String file, Conrad a, Conrad b) throws Exception {
087 double[] M,N;
088
089 List<? extends TrainingSequence<?>> train = a.getInputHandler().readTrainingData(file);
090 for(TrainingSequence<?> seq : train) {
091 InferenceResult r1 = a.predict(seq.getInputSequence());
092 InferenceResult r2 = b.predict(seq.getInputSequence());
093 M = r1.bestScores;
094 N = r2.bestScores;
095 Assert.a(M.length == a.getNumStates());
096
097 // Scores should be identical
098 for(int r = 0; r<M.length; ++r) {
099 assertEquals("State: "+r+" ", N[r], M[r], 0.0001);
100 }
101
102 // Final paths are not identical due to floating point rounding errors
103 //for(int i=0; i < r1.hiddenStates.length; ++i) {
104 // assertEquals("Difference at pos "+i, r1.hiddenStates[i], r2.hiddenStates[i]);
105 //}
106 }
107 }
108
109 public void testColt() {
110 DenseDoubleMatrix2D f = new DenseDoubleMatrix2D(2,2);
111 DenseDoubleMatrix2D g = new DenseDoubleMatrix2D(2,2);
112 f.setQuick(1,1,3);
113 g.setQuick(1,1,3);
114 assertEquals(f,g);
115 }
116 }