001    package calhoun.analysis.crf.test;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
010    import calhoun.analysis.crf.ModelManager;
011    import calhoun.analysis.crf.SemiMarkovSetup;
012    import calhoun.analysis.crf.io.IntInput;
013    import calhoun.analysis.crf.io.TrainingSequence;
014    import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
015    import calhoun.analysis.crf.solver.MaximumLikelihoodGradient;
016    import calhoun.analysis.crf.solver.MaximumLikelihoodSemiMarkovGradient;
017    import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
018    import calhoun.analysis.crf.solver.check.BasicCRFGradient;
019    import calhoun.analysis.crf.solver.check.NormalizedCRFGradient;
020    import calhoun.util.AbstractTestCase;
021    import calhoun.util.ColtUtil;
022    
023    public class FunctionEvalTest extends AbstractTestCase {
024            private static final Log log = LogFactory.getLog(FunctionEvalTest.class);
025    
026            // Test function and gradient calculation using just edge features
027            public void testGradientEdgeTrivial() throws Exception {
028                    doLikelihoodTest(1, IntInput.prepareData("0\n0"), -.6931, 0.000, -.6931, 0.000);
029            }
030    
031            // / Test function and gradient calculation using just edge features
032            public void testGradientEdgeShort() throws Exception {
033                    double a = Math.log(.5 * 2 / 3.0);
034                    double b = Math.log(.5 * 2 / 3.0 * 2 / 3.0);
035                    // grad = 1/3 ln 2 = .231
036                    doLikelihoodTest(1, IntInput.prepareData("00\n00"), a/2.0, 0.2310/2.0, -.9163/2.0, 0.1386/2.0);
037                    // grad = 2/3 ln 2 = .462
038                    doLikelihoodTest(1, IntInput.prepareData("000\n000"), b/3.0, 0.4620/3.0);
039            }
040    
041            // Test function and gradient calculation using just node features
042            public void testGradStateFunc() throws Exception {
043                    doLikelihoodTest(0, IntInput.prepareData("001111\n001111"), -3.8191/6.0, 0.000/6.0, -4.1115/6.0, -.55451/6.0);
044            }
045    
046            // Test 2 features
047            public void testTwoFeaturesTrivial() throws Exception {
048                    doLikelihoodTest(2, IntInput.prepareData("00\n00"), -1.9459/2.0, -.990/2.0);
049                    doLikelihoodTest(2, IntInput.prepareData("00\n00\n00\n00"), -3.8918/4.0, -1.9804/4.0);
050            }
051    
052            // Test 2 features non-trivial
053            public void testTwoFeaturesNonTrivial() throws Exception {
054                    doLikelihoodTest(2,
055                                    IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"), -39.0440/40.0, -11.2550/40.0);
056            }
057    
058            void doLikelihoodTest(int mmNum, List<? extends TrainingSequence<?>> data, double f, double g) {
059                    doLikelihoodTest(mmNum, data, f, g, false);
060            }
061    
062            void doLikelihoodTest(int mmNum, List<? extends TrainingSequence<?>>  data, double f, double g, double h, double i) {
063                    doLikelihoodTest(mmNum, data, f, g, false);
064                    doLikelihoodTest(mmNum, data, h, i, true);
065            }
066    
067            void doLikelihoodTest(int mmNum, List<? extends TrainingSequence<?>>  data, double f, double g, boolean skewedWeights) {
068                    ModelManager m = new TestFeatureManager(mmNum);
069                    double[] weights = skewedWeights ? new double[] { 2, 0.5 } : new double[] { 1, 1 };
070                    double[] grad = new double[2];
071    
072                    CRFObjectiveFunctionGradient gradFunc;
073                    double val;
074                    AllSparseLengthCacheProcessor cacheProcessor;
075                    short[] max = new short[2];
076                    
077                    gradFunc= new BasicCRFGradient();
078                    gradFunc.setTrainingData(m, data); 
079                    val = gradFunc.apply(weights, grad);
080                    log.info("Grad(Basic): " + ColtUtil.format(grad));
081                    assertEquals(f, val, 0.001);
082                    assertEquals(g, grad[0], 0.001);
083    
084                    gradFunc = new NormalizedCRFGradient();
085                    gradFunc.setTrainingData(m, data); 
086                    val = gradFunc.apply(weights, grad);
087                    log.info("Grad(Norm): " + ColtUtil.format(grad));
088                    assertEquals(f, val, 0.001);
089                    assertEquals(g, grad[0], 0.001);
090    
091                    cacheProcessor = new AllSparseLengthCacheProcessor();
092                    cacheProcessor.setAllPaths(true);
093                    
094                    gradFunc = new MaximumLikelihoodGradient();
095                    cacheProcessor = new AllSparseLengthCacheProcessor();
096                    cacheProcessor.setAllPaths(true);
097                    ((MaximumLikelihoodGradient) gradFunc).setCacheProcessor(cacheProcessor);
098                    gradFunc.setTrainingData(m, data); 
099                    val = gradFunc.apply(weights, grad);
100                    log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(grad));
101                    assertEquals(f, val, 0.001);
102                    assertEquals(g, grad[0], 0.001);
103    
104                    gradFunc = new MaximumLikelihoodGradient();
105                    cacheProcessor = new AllSparseLengthCacheProcessor();
106                    cacheProcessor.setAllPaths(false);
107                    ((MaximumLikelihoodGradient) gradFunc).setCacheProcessor(cacheProcessor);
108                    gradFunc.setTrainingData(m, data); 
109                    val = gradFunc.apply(weights, grad);
110                    log.info("Grad(Cache): " + ColtUtil.format(grad));
111                    assertEquals(f, val, 0.001);
112                    assertEquals(g, grad[0], 0.001);
113    
114                    // Semi CRF no lookback
115                    Arrays.fill(max, (short) 1);
116                    gradFunc = new MaximumLikelihoodSemiMarkovGradient();
117                    cacheProcessor = new AllSparseLengthCacheProcessor();
118                    cacheProcessor.setAllPaths(false);
119                    cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max));
120                    ((MaximumLikelihoodSemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
121                    gradFunc.setTrainingData(m, data); 
122    
123                    val = gradFunc.apply(weights, grad);
124                    log.info("Grad(Semi-cache): " + ColtUtil.format(grad));
125                    assertEquals(f, val, 0.001);
126                    assertEquals(g, grad[0], 0.001);
127    
128                    // Semi CRF lookback without length features
129                    // We need a lookback of 20 because otherwise we can't have an equivalent test between semi-Markov and regular features here.
130                    max = new short[2];
131                    Arrays.fill(max, (short) 20);
132                    m = new TestFeatureManager(mmNum, 1);
133                    /*gradFunc = new CachedSemiCRFGradient(max, false);
134                    gradFunc.setTrainingData(m, data); 
135                    val = gradFunc.apply(weights, grad);
136                    log.info("Grad(Length): " + ColtUtil.format(grad));
137                    assertEquals(f, val, 0.001);
138                    assertEquals(g, grad[0], 0.001);*/
139    
140                    gradFunc = new MaximumLikelihoodSemiMarkovGradient();
141                    CacheProcessorDeluxe cacheProcessor1 = new CacheProcessorDeluxe();
142                    cacheProcessor1.setAllPaths(false);
143                    cacheProcessor1.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
144                    ((MaximumLikelihoodSemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor1);
145                    gradFunc.setTrainingData(m, data); 
146                    val = gradFunc.apply(weights, grad);
147                    log.info("Grad(Length): " + ColtUtil.format(grad));
148                    assertEquals(f, val, 0.001);
149                    assertEquals(g, grad[0], 0.001);
150    
151                    // Semi CRF lookback with length features
152                    m = new TestFeatureManager(mmNum, 2);
153                    Arrays.fill(max, (short) 20);
154                    gradFunc = new MaximumLikelihoodSemiMarkovGradient();
155                    cacheProcessor = new AllSparseLengthCacheProcessor();
156                    cacheProcessor.setAllPaths(false);
157                    cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
158                    //((MaximumLikelihoodSemiMarkovGradient) gradFunc).setBetaLengthFile("betaSM.txt");
159                    ((MaximumLikelihoodSemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
160                    gradFunc.setTrainingData(m, data); 
161                    val = gradFunc.apply(weights, grad);
162                    log.info("Grad(Length): " + ColtUtil.format(grad));
163                    assertEquals(f, val, 0.001);
164                    assertEquals(g, grad[0], 0.001);
165            }
166    
167    }