001    package calhoun.analysis.crf.test;
002    
003    import java.util.List;
004    
005    import org.apache.commons.logging.Log;
006    import org.apache.commons.logging.LogFactory;
007    
008    import calhoun.analysis.crf.Conrad;
009    import calhoun.analysis.crf.CRFInference.InferenceResult;
010    import calhoun.analysis.crf.io.TrainingSequence;
011    import calhoun.util.AbstractTestCase;
012    import calhoun.util.Assert;
013    import cern.colt.matrix.impl.DenseDoubleMatrix2D;
014    
015    public class SemiMarkovTest extends AbstractTestCase {
016            @SuppressWarnings("unused")
017            private static final Log log = LogFactory.getLog(SemiMarkovTest.class);
018            
019            public void testSemiMarkovBadTraining() throws Exception {
020                    // Assert that trying to train the semi-Markov fails when training data is used that contains segments longer than the maximum allowed length
021                    String failureMessage = null;
022                    try {
023                            Conrad.main(new String[] {"train", "test/input/zeroOrderLBFGSCachedSemiMarkov.xml", "test/input/zeroOrderTest.txt", "test/working/zeroLBGFSModelCachedSemiMarkov.ser"});
024                    }
025                    catch(Exception ex) {
026                            failureMessage = ex.getMessage();
027                    }
028                    assertEquals("Seq #0 Pos 150 Training segment 127 is longer than allowed length 20", failureMessage);
029            }
030    
031            public void testSemiCRFViterbiCompareWithBaseClass() throws Exception {
032                    Conrad r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengths.xml");
033                    r.trainFeatures("test/input/zeroOrderTrivial.txt");
034                    r.setWeights(new double[] {1,1,1});
035    
036                    Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
037                    s.trainFeatures("test/input/zeroOrderTrivial.txt");
038                    s.setWeights(new double[] {1,1,1});
039                    
040                    doViterbiComparison("test/input/zeroOrderTrivial.txt", r, s);
041                    doViterbiComparison("test/input/zeroOrderTest.txt", r, s);
042            }
043            
044            public void testSemiCRFViterbiCompareWithBaseClassWithLengths() throws Exception {
045                    Conrad r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthFeatures.xml");
046                    r.trainFeatures("test/input/zeroOrderTrivial.txt");
047                    r.setWeights(new double[] {1,1,1});
048    
049                    Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
050                    s.trainFeatures("test/input/zeroOrderTrivial.txt");
051                    s.setWeights(new double[] {1,1,1});
052                    
053                    doViterbiComparison("test/input/zeroOrderTrivial.txt", r, s);
054            }
055            
056            public void testSemiCRFViterbiCompareWithBaseClassWithFeatures() throws Exception {
057                    Conrad r = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
058                    r.trainFeatures("test/input/zeroOrderTrivial.txt");
059                    r.setWeights(new double[] {1,1,1});
060    
061                    Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
062                    //Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthFeatures.txt");
063                    s.trainFeatures("test/input/zeroOrderTrivial.txt");
064                    s.setWeights(new double[] {1,1,1});
065                    
066                    doViterbiComparison("test/input/zeroOrderTrivial.txt", r, s);
067            }
068    
069            public void testSemiCRFCompareWithBaseClass() throws Exception {
070                    Conrad s = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml");
071                    s.train("test/input/zeroOrderTrivial.txt");
072                    
073                    Conrad r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengths.xml");
074                    r.train("test/input/zeroOrderTrivial.txt");
075                    assertEquals(r.getWeights()[0], s.getWeights()[0], 0.0001);
076    
077                    r = new Conrad("test/input/semiMarkovTestModelNoExplicitLengthFeatures.xml");
078                    r.train("test/input/zeroOrderTrivial.txt");
079                    assertEquals(r.getWeights()[0], s.getWeights()[0], 0.0001);
080    
081                    r = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
082                    r.train("test/input/zeroOrderTrivial.txt");
083                    assertEquals(r.getWeights()[0], s.getWeights()[0], 0.001);
084            }
085            
086            void doViterbiComparison(String file, Conrad a, Conrad b) throws Exception {
087                    double[] M,N;
088                    
089                    List<? extends TrainingSequence<?>> train = a.getInputHandler().readTrainingData(file);
090                    for(TrainingSequence<?> seq : train) {
091                            InferenceResult r1 = a.predict(seq.getInputSequence());
092                            InferenceResult r2 = b.predict(seq.getInputSequence());
093                            M = r1.bestScores;
094                            N = r2.bestScores;
095                            Assert.a(M.length == a.getNumStates());
096                            
097                            // Scores should be identical
098                            for(int r = 0; r<M.length; ++r) {
099                                            assertEquals("State: "+r+" ", N[r], M[r], 0.0001);
100                            }
101                            
102                            // Final paths are not identical due to floating point rounding errors
103                            //for(int i=0; i < r1.hiddenStates.length; ++i) {
104                            //      assertEquals("Difference at pos "+i, r1.hiddenStates[i], r2.hiddenStates[i]);
105                            //}
106                    }
107            }
108            
109            public void testColt() {
110                    DenseDoubleMatrix2D f = new DenseDoubleMatrix2D(2,2);
111                    DenseDoubleMatrix2D g = new DenseDoubleMatrix2D(2,2);
112                    f.setQuick(1,1,3);
113                    g.setQuick(1,1,3);
114                    assertEquals(f,g);
115            }
116    }