package calhoun.analysis.crf.test;

import java.util.Arrays;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import calhoun.analysis.crf.Conrad;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.SemiMarkovSetup;
import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
import calhoun.analysis.crf.io.IntInput;
import calhoun.analysis.crf.io.StringInput;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
import calhoun.analysis.crf.solver.MaximumLikelihoodSemiMarkovGradient;
import calhoun.analysis.crf.solver.NoCachingCacheProcessor;
import calhoun.analysis.crf.solver.StandardOptimizer;
import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
import calhoun.util.AbstractTestCase;
import calhoun.util.Assert;

026    public class CacheProcessorTest extends AbstractTestCase {
027            private static final Log log = LogFactory.getLog(CacheProcessorTest.class);
028    
029            public void testCPDRejectTrainingDataStatesTooLong() throws Exception {
030                    checkFailure("test/input/interval13/config/shortIntergenicCPD.xml", "test/input/interval13/data/tooLong.txt");
031            }
032            
033            public void testCPDRejectTrainingDataStatesTooShort() throws Exception {
034                    checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShort.txt");
035            }
036            
037            public void testAllSparseRejectTrainingDataStatesTooShort() throws Exception {
038                    checkFailure("test/input/interval13/config/lengthDependentAllSparse.xml", "test/input/interval13/data/tooShort.txt");
039            }
040            
041            public void testCPDRejectTrainingDataStatesTooShortStart() throws Exception {
042                    checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShortStart.txt");
043            }
044            
045            public void testCPDRejectTrainingDataStatesTooShortEnd() throws Exception {
046                    checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShortEnd.txt");
047            }
048            
049            public void testCPDRejectTrainingDataStatesViolatesConstraints() throws Exception {
050                    checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/badConstraints.txt");
051            }
052            
053            public void testCPDAllSequencesInvalid() throws Exception {
054                    checkFailure("test/input/interval13/config/lengthDependentCPDuncommentedDiscard.xml", "test/input/interval13/data/tooShortEnd.txt");
055            }
056            
057            public void testCPDAllDiscardInvalid() throws Exception {
058                    Conrad conrad = new Conrad("test/input/interval13/config/lengthDependentCPDuncommentedDiscard.xml");
059                    conrad.train("test/input/interval13/data/oneGoodOneBad.txt");
060            }
061            
062            public void testCPDAllDiscardInvalidLocalScore() throws Exception {
063                    Conrad conrad = new Conrad("test/input/interval13/config/lengthDependentCPDuncommentedDiscardLocalScore.xml");
064                    conrad.train("test/input/interval13/data/oneGoodOneBad.txt");
065            }
066            
067            void checkFailure(String configFile, String data) {
068                    Conrad conrad = new Conrad(configFile);
069                    boolean fail = false;
070                    try {
071                            conrad.trainFeatures(data);
072                            conrad.trainWeights(conrad.getInputHandler().readTrainingData(data));
073                    }
074                    catch(Exception ex) {
075                            log.warn(ex);
076                            fail = true;
077                    }
078                    assertTrue(fail);
079            }
080            
081            // Test edge features with 1 node, should be no evaluations
082            public void testEdgeTrivial() throws Exception {
083                    int[][] indices = new int[][] { {-1}, {-1}, {-1}, {-1}, {-1}, {-1} };
084                    float[][] vals = new float[0][0];
085                    double[] featureSums = new double[] { 0.0 };
086                    List<? extends TrainingSequence<?>> data = IntInput.prepareData("0\n0"); 
087                    System.out.println("number of data sequences is " + data.size());
088                    doTest(1, data, 0, 0, indices, vals, featureSums);
089            }
090    
091            // Test edge features with 2 and 3 positions.
092            public void testEdgeShort() throws Exception {
093                    int[][] indices = new int[][] { {-1}, {-1}, {0, -1}, {0, -1}, {0, -1}, {0, -1} };
094                    float[][] vals = new float[][] { {}, {}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f} };
095                    double[] featureSums = new double[] { -0.4054651f };
096                    doTest(1, IntInput.prepareData("00\n00"), 0, 1, indices, vals, featureSums);
097            }
098    
099            // Test node features
100            public void testNode() throws Exception {
101                    int[][] indices = new int[][] { {0, -1}, {0, -1}, {-1}, {-1}, {-1}, {-1} };
102                    float[][] vals1 = new float[][] { {-1.0986123f}, {-0.4054651f}, {}, {}, {}, {} };
103                    double[] featureSums = new double[] { -1.0986123f*2 + -0.4054651f*4 };
104                    doTest(0, IntInput.prepareData("001111\n001111"), 0, 0, indices, vals1, featureSums); // the DCP feature sums for this are wrong
105                    float[][] vals2 = new float[][] { {-1.0986123f}, {-0.4054651f}, {}, {}, {}, {} };
106                    doTest(0, IntInput.prepareData("001111\n001111"), 0, 4, indices, vals2, featureSums);
107            }
108    
109            // Test 2 features
110            public void testTwoFeaturesTrivial() throws Exception {
111                    int[][] indices1 = new int[][] { {0, -1}, {0, -1}, {-1}, {-1}, {-1}, {-1} };
112                    float[][] vals1 = new float[][] { {-1.0986123f}, {-0.4054651f}, {}, {}, {}, {} };
113                    double[] featureSums = new double[] { -1.0986123f*2, -0.4054651f };
114                    doTest(2, IntInput.prepareData("00\n00"), 0, 0, indices1, vals1, featureSums);
115                    int[][] indices2 = new int[][] { {0, -1}, {0, -1}, {1, -1}, {1, -1}, {1, -1}, {1, -1} };
116                    float[][] vals2 = new float[][] { {-1.0986123f}, {-0.4054651f}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f} };
117                    featureSums = new double[] { -1.0986123f*4, -0.4054651f*2 };
118                    doTest(2, IntInput.prepareData("00\n00\n00\n00"), 1, 1, indices2, vals2, featureSums);
119            }
120    
121            // Test 2 features non-trivial
122            public void testTwoFeaturesNonTrivial() throws Exception {
123                    int[][] indices = new int[][] { {0, -1}, {0, -1}, {1, -1}, {1, -1}, {1, -1}, {1, -1} };
124                    float[][] vals = new float[][] { {-1.0986123f}, {-0.4054651f}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f} };
125                    double[] featureSums = new double[] { -33.54728, -29.963765};
126                    doTest(2, IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"), 1, 4, indices, vals, featureSums);
127            }
128    
129            public void testLengthCacheDummy() throws Exception {
130                    int[][] lookbacks = new int[][] { {0, 1, -1}, {0, 1, -1} };
131                    int[][] nodeIndices = new int[][] { {0, 0, 0, -1}, {1, 0, 0, -1} };
132                    float[][] nodeValues = new float[0][4];
133                    ModelManager m = new TestFeatureManager(1);
134                    doLengthTest(m, IntInput.prepareData("00\n00"), 0, 1, 2, lookbacks, nodeIndices, nodeValues);
135    
136                    m = new TestFeatureManager(2);
137                    lookbacks = new int[][] { {0, 1, 2, 3, -1}, {0, 1, 2, 3, -1} };
138                    nodeIndices = new int[][] { {0, 0, 0, -1}, {0, 1, 0, -1}, {0, 2, 0, -1}, {0, 3, 0, -1}, {1, 0, 0, -1} };
139                    doLengthTest(m, IntInput.prepareData("0000\n0000"), 0, 3, 2, lookbacks, nodeIndices, nodeValues);
140            }
141    
142            public void testLengthCache() throws Exception {
143                    int[][] lookbacks = new int[][] { {0, 1, 2, 3, -1} };
144                    int[][] nodeIndices = new int[0][4];
145                    float[][] nodeValues = new float[][] { {0, 0, 0, -0.0f}, {0, 1, 0, -0.11157f} , {0, 2, 0, -0.11157f*2}, {0, 3, 0, -0.11157f*3}};
146                    Conrad c = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
147                    ModelManager m = c.getModel();
148                    List<? extends TrainingSequence<?>> data = StringInput.prepareData("00110\nATGCA");
149                    c.trainFeatures(data);
150                    CacheProcessor cp = ((MaximumLikelihoodSemiMarkovGradient) ((StandardOptimizer)c.getOptimizer()).getObjectiveFunction()).getCacheProcessor();
151                    cp.setTrainingData(m, data);
152                    cp.evaluateSegmentsEndingAt(0, 3);
153                    LengthFeatureEvaluation[][] lenEvals = cp.getLengthFeatureEvaluations();
154                    checkLengthEvals(lenEvals, 1, lookbacks, nodeIndices, nodeValues);
155            }
156            
157            void doLengthTest(ModelManager m, List<? extends TrainingSequence<?>> data, int seq, int pos, int nStates, int[][] lookback, int[][] nodeIndices, float[][] nodeValues) {
158                    AllSparseLengthCacheProcessor cp = new AllSparseLengthCacheProcessor();
159                    SemiMarkovSetup setup = new SemiMarkovSetup(new short[] {4, 4});
160                    setup.setIgnoreSemiMarkovSelfTransitions(true);
161                    cp.setSemiMarkovSetup(setup);
162                    cp.setTrainingData(m, data);
163                    cp.evaluateSegmentsEndingAt(seq, pos);
164                    LengthFeatureEvaluation[][] lenEvals = cp.getLengthFeatureEvaluations();
165                    checkLengthEvals(lenEvals, nStates, lookback, nodeIndices, nodeValues);
166            }
167    
168            void checkLengthEvals(LengthFeatureEvaluation[][] lenEvals, int nStates, int[][] lookback, int[][] nodeIndices, float[][] nodeValues) {
169                    assertEquals(nStates, lenEvals.length);
170                    for(int i=0; i<lookback.length; ++i) {
171                            for(int j=0; j<lookback[i].length; ++j) {
172                                    assertEquals(lookback[i][j], lenEvals[i][j].lookback);
173                            }
174                    }
175                    for(int[] entry : nodeIndices) {
176                            assertEquals(entry[3], lenEvals[entry[0]][entry[1]].nodeEval.index[entry[2]]);
177                    }
178                    
179                    for(float[] entry : nodeValues) {
180                            assertEquals(entry[3], lenEvals[(int)entry[0]][(int)entry[1]].nodeEval.value[(int)entry[2]], 0.0001);
181                    }
182            }
183            
184            void doTest(int mmNum, List<? extends TrainingSequence<?>> data, int seq, int pos, int[][] indices, float[][] vals, double[] featureSums) {
185                    ModelManager m = new TestFeatureManager(mmNum);
186                    
187                    AllSparseLengthCacheProcessor cp = new AllSparseLengthCacheProcessor();
188                    testOneCacheProcessor(cp,m,data,seq,pos,indices,vals,featureSums);
189                    
190                    NoCachingCacheProcessor ncp = new NoCachingCacheProcessor();
191                    testOneCacheProcessor(ncp,m,data,seq,pos,indices,vals,featureSums);
192            
193                    CacheProcessorDeluxe dcp = new CacheProcessorDeluxe();
194                    //dcp.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
195                    testOneCacheProcessor(dcp,m,data,seq,pos,indices,vals,featureSums);
196    
197                    CacheProcessorDeluxe dcp2 = new CacheProcessorDeluxe(CacheStrategy.CONSTANT);
198                    //dcp2.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
199                    testOneCacheProcessor(dcp2,m,data,seq,pos,indices,vals,featureSums);
200                    
201                    CacheProcessorDeluxe dcp3 = new CacheProcessorDeluxe(CacheStrategy.DENSE);
202                    //dcp3.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
203                    testOneCacheProcessor(dcp3,m,data,seq,pos,indices,vals,featureSums);
204                    
205                    CacheProcessorDeluxe dcp4 = new CacheProcessorDeluxe(CacheStrategy.SPARSE);
206                    //dcp4.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
207                    testOneCacheProcessor(dcp4,m,data,seq,pos,indices,vals,featureSums);            
208            }
209    
210            void testOneCacheProcessor(CacheProcessor dcp, ModelManager m, List<? extends TrainingSequence<?>> data,  int seq, int pos, int[][] indices, float[][] vals, double[] featureSums) {
211                    //dcp.setAllPaths(false);
212                    dcp.setTrainingData(m, data);
213                    dcp.evaluatePosition(seq, pos);
214                    if (pos > 0) {
215                            assertEvalEquals(dcp.getFeatureEvaluations(), indices, vals);
216                    } else {
217                            assertNonedgeEvalEquals(m,dcp.getFeatureEvaluations(), indices, vals);          
218                    }
219                    assertArrayEquals(featureSums, dcp.getFeatureSums(), 0.00001);
220            }
221            
222            
223            private void assertNonedgeEvalEquals(ModelManager m, FeatureEvaluation[] evals, int[][] indices, float[][] vals) {
224                    log.warn(evals);
225                    
226                    Assert.a(indices.length >= m.getNumStates());
227                    
228                    for(int i=0; i<m.getNumStates(); ++i) {
229                            for(int j=0; j<indices[i].length; ++j) {
230                                    assertEquals(indices[i][j], evals[i].index[j]);
231                            }
232                    }
233                    for(int i=0; i<vals.length; ++i) {
234                            for(int j=0; j<vals[i].length; ++j) {
235                                    assertEquals(vals[i][j], evals[i].value[j], .00001);
236                            }
237                    }
238            }
239            
240            private void assertEvalEquals(FeatureEvaluation[] evals, int[][] indices, float[][] vals) {
241                    log.warn(evals);
242                    for(int i=0; i<indices.length; ++i) {
243                            for(int j=0; j<indices[i].length; ++j) {
244                                    assertEquals(indices[i][j], evals[i].index[j]);
245                            }
246                    }
247                    for(int i=0; i<vals.length; ++i) {
248                            for(int j=0; j<vals[i].length; ++j) {
249                                    assertEquals(vals[i][j], evals[i].value[j], .00001);
250                            }
251                    }
252            }
253    }