001    package calhoun.analysis.crf.test;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.Conrad;
010    import calhoun.analysis.crf.LocalPathSimilarityScore;
011    import calhoun.analysis.crf.ModelManager;
012    import calhoun.analysis.crf.SemiMarkovSetup;
013    import calhoun.analysis.crf.io.IntInput;
014    import calhoun.analysis.crf.io.TrainingSequence;
015    import calhoun.analysis.crf.scoring.SimScoreMaxStateAgreement;
016    import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
017    import calhoun.analysis.crf.solver.MaximumExpectedAccuracySemiMarkovGradient;
018    import calhoun.analysis.crf.solver.StandardOptimizer;
019    import calhoun.analysis.crf.solver.check.CachedAOFGradient;
020    import calhoun.util.AbstractTestCase;
021    import calhoun.util.ColtUtil;
022    
023    public class MaximumExpectedAccuracyTest extends AbstractTestCase {
024            private static final Log log = LogFactory.getLog(MaximumExpectedAccuracyTest.class);
025    
026            public void testLittle() throws Exception {
027                    
028                    String config         = "test/input/aofreg_test/delta_conservation_aofreg.xml";
029                    String input          = "test/input/aofreg_test/testSeq.txt";
030    
031                    Conrad conrad = new Conrad(config);
032                    conrad.train(input);
033                    double[] regularWeights = conrad.getWeights();
034                    conrad.test(input);
035    
036                    config         = "test/input/aofreg_test/delta_conservation_aofreg_semi.xml";
037                    conrad = new Conrad(config);
038                    conrad.train(input);
039                    double[] semiMarkovWeights = conrad.getWeights();
040                    conrad.test(input);
041                    assertArrayEquals(regularWeights, semiMarkovWeights, 0.001);
042            }
043    
044            public void testGradEvals() throws Exception {
045                    doFuncEvalsSkewed(0, IntInput.prepareData("0\n0"));
046                    doFuncEvalsSkewed(0, IntInput.prepareData("00\n00"));
047                    doFuncEvalsSkewed(0, IntInput.prepareData("000\n000"));
048                    doFuncEvalsSkewed(1, IntInput.prepareData("1\n1"));
049                    doFuncEvalsSkewed(3, IntInput.prepareData("11\n11"));
050                    doFuncEvalsSkewed(1, IntInput.prepareData("10\n10"));
051                    doFuncEvalsSkewed(1, IntInput.prepareData("111\n111"));
052                    doFuncEvalsSkewed(3, IntInput.prepareData("101\n101"));
053                    doFuncEvalsSkewed(3, IntInput.prepareData("011\n011"));
054                    doFuncEvalsSkewed(1, IntInput.prepareData("1111\n1111"));
055                    doFuncEvalsSkewed(1, IntInput.prepareData("001111\n001111"));
056                    doFuncEvalsSkewed(2, IntInput.prepareData("00\n00"));
057                    doFuncEvalsSkewed(2, IntInput.prepareData("00\n00\n00\n00"));
058                    doFuncEvalsSkewed(1, IntInput.prepareData("00\n00\n0\n0\n"));
059                    doFuncEvalsSkewed(1, IntInput.prepareData("0\n0\n1111\n1111\n"));
060                    doFuncEvalsSkewed(2, IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"));
061            }
062    
063            void doFuncEvalsSkewed(int mmNum, List<? extends TrainingSequence<?>> data) throws Exception {
064                    //doFuncEvals(mmNum, false, false, data);
065                    //doFuncEvals(mmNum, true, false, data);
066                    
067                    if(mmNum == 1 || mmNum == 3) {
068                            //doFuncEvals(mmNum, false, true, data);
069                            doFuncEvals(mmNum, true, true, data);
070                    }
071            }
072            
073            void doFuncEvals(int mmNum, boolean skewedWeights, boolean fm3, List<? extends TrainingSequence<?>> data) throws Exception {
074                    double[] weights = skewedWeights ? new double[] { 2, 0.5, 1 } : new double[] { 1, 1, 1 };
075    
076                    MaximumExpectedAccuracySemiMarkovGradient gradFunc;
077                    CacheProcessorDeluxe cacheProcessor;
078                    short[] max = new short[2];
079                    Arrays.fill(max, (short) 1);
080    
081                    ModelManager m = fm3 ? new TestFeatureManager3(mmNum, false) : new TestFeatureManager2(mmNum, false);
082                    m.train(0, m, data);
083                    
084                    // Semi CRF no lookback
085                    gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
086                    cacheProcessor = new CacheProcessorDeluxe();
087                    cacheProcessor.setAllPaths(false);
088                    cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
089                    ((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
090                    gradFunc.setTrainingData(m, data); 
091                    double[] grad1 = new double[m.getNumFeatures()];
092                    double val1 = gradFunc.apply(weights, grad1);
093                    log.info("Grad1: " + ColtUtil.format(grad1));
094    
095                    // Semi CRF lookback
096                    Arrays.fill(max, (short) 20);
097                    m = fm3 ? new TestFeatureManager3(mmNum, false) : new TestFeatureManager2(mmNum, false);
098                    m.train(0, m, data);
099                    gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
100                    cacheProcessor = new CacheProcessorDeluxe();
101                    cacheProcessor.setAllPaths(false);
102                    cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
103                    ((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
104                    //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setBetaLengthFile("modeMarginal.txt");
105                    //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setNodeMarginalFile("nodeMarginal.txt");
106                    //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setExpectedProductFile("expectedProduct.txt");
107                    //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setScoreAlphaFile("scoreAlpha.txt");
108                    gradFunc.setTrainingData(m, data); 
109                    double[] grad2 = new double[m.getNumFeatures()];
110                    double val2 = gradFunc.apply(weights, grad2);
111                    log.info("Grad2: " + ColtUtil.format(grad2));
112                    
113                    assertEquals(val1, val2, 0.001);
114                    assertArrayEquals(grad1, grad2, 0.001);
115    
116                    // Semi CRF lookback with length features
117                    m = fm3 ? new TestFeatureManager3(mmNum, true) : new TestFeatureManager2(mmNum, true);
118                    m.train(0, m, data);
119                    gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
120                    cacheProcessor = new CacheProcessorDeluxe();
121                    cacheProcessor.setAllPaths(false);
122                    cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
123                    ((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
124                    //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setExpectedProductFile("expectedProductLen.txt");
125                    gradFunc.setTrainingData(m, data); 
126                    double[] grad3 = new double[m.getNumFeatures()];
127                    double val3 = gradFunc.apply(weights, grad3);
128                    log.info("Grad3: " + ColtUtil.format(grad3));
129                    
130                    assertEquals(val1, val3, 0.001);
131                    assertArrayEquals(grad1, grad3, 0.001);
132            }
133            
134            public void testGeneCallerLocalScore() throws Exception {
135                    Conrad nodeOnly = new Conrad("test/input/geneCallerLocal/baseline_aof.xml");
136                    Conrad length = new Conrad("test/input/geneCallerLocal/baseline_aof_length.xml");
137                    nodeOnly.train("test/input/geneCallerLocal");
138                    length.train("test/input/geneCallerLocal");
139                    assertArrayEquals(nodeOnly.getWeights(), length.getWeights(), 0.0001);
140            }
141            
142            public void testAlternateObjectiveFunction() throws Exception {
143                    LocalPathSimilarityScore s = new SimScoreMaxStateAgreement();
144                    // s measures the number of agreeing nucleotides between a hidden path
145                    // and the actual hidden path, not counting position zero.
146                    
147                    // This is a two feature example I worked out by hand
148                    ModelManager m2 = new TestFeatureManager(2);
149                    List<? extends TrainingSequence<?>> data = IntInput.prepareData("00\n00");          
150                    doAlternateObjectiveFunctionTest( m2, data, s, 0.2857142/2.0,-0.1838962/2.0);
151                    
152                    // below should be exactly double what is above (the
153                    // observations are not used by this feature manager)
154                    List<? extends TrainingSequence<?>> data2 = IntInput.prepareData("00\n11\n00\n00");
155                    doAlternateObjectiveFunctionTest( m2, data2, s,0.57142853/4.0,-0.367792/4.0);           
156                    
157                    // the following should grow linearly with length-1
158                    // The manager has just one node feature.  At weight 1.0, the probability of a zero,
159                    // and hence the expected value of S, is 1/3.
160                    ModelManager m0 = new TestFeatureManager(0);            // this has just node features
161                    List<? extends TrainingSequence<?>> data3 = IntInput.prepareData("00\n00");
162                    double t3 = -0.1540327;                 // I didn't check the gradient by hand but copied it from output
163                    doAlternateObjectiveFunctionTest( m0, data3, s,0.33333333/2.0,t3/2.0);                  
164    
165                    // since this grows linearly, it should be exactly three times the above
166                    List<? extends TrainingSequence<?>> data4 = IntInput.prepareData("0000\n0000");
167                    doAlternateObjectiveFunctionTest( m0, data4, s, 1.0/4.0,3*t3/4.0);      
168    
169                    // Expected value of S should be 0.5; gradient should be zero
170                    ModelManager m1 = new TestFeatureManager(1);            // this has just edge features
171                    doAlternateObjectiveFunctionTest( m1, data3, s, 0.5/2.0,0.0/2.0);       
172    
173                    // Expected value of S should be 1.5; gradient should be 0
174                    doAlternateObjectiveFunctionTest( m1, data4, s,1.5/4.0,0.0/4.0);
175            }
176    
177            // Similar logic to above bug on a slightly bigger example.
178            // If LBGFS can converge given my gradient/value function, then my function is
179            // probably self-consistent.
180            public void testSemiRealExample() throws Exception {
181                    String input          = "test/input/cryptoAOFUnittest/Tiny_1_Train_Test.txt";
182                    String config1        = "test/input/cryptoAOFUnittest/delta_aof_Model.xml";
183                    String config2         = "test/input/cryptoAOFUnittest/delta_aof_Model_semi_nolen.xml";
184                    String config3         = "test/input/cryptoAOFUnittest/delta_aof_Model_semi.xml";
185                    doSameWeightsTest(config1, config2, config3, input);
186            }
187    
188            void doSameWeightsTest(String config1, String config2, String config3, String input) throws Exception {
189                    double[] semiMarkovWeights;
190                    Conrad conrad;
191                    conrad = new Conrad(config1);
192                    conrad.train(input);
193                    double[] regularWeights = conrad.getWeights();
194                    conrad.test(input);
195    
196                    conrad = new Conrad(config2);
197                    conrad.train(input);
198                    double[] noLenWeights = conrad.getWeights();
199                    conrad.test(input);
200                    assertArrayEquals(regularWeights, noLenWeights, 0.001);
201    
202                    if(config3 != null) {
203                            conrad = new Conrad(config3);
204                            conrad.train(input);
205                            semiMarkovWeights = conrad.getWeights();
206                            conrad.test(input);
207                            assertArrayEquals(noLenWeights, semiMarkovWeights, 0.001);
208                    }
209            }
210            
211            // The idea of the test below is that of requiring LBGFS to converge to a point with zero gradient.
212            // Doesn't this require that the gradient equal zero at a point where the function is maximized.
213            // Isn't this very unlikely to succeed if either the function value or its
214            // gradient were computed incorrcetly?  I think actually this is a strong test;
215            // the gradient defines a direction for a line search, and the function value determines
216            // where on that line search to stop next.  It's hard to image this process
217            // converging in 2 or more dimensions unless the function and gradient are matched.
218            public void testCachedAOFGradient() throws Exception {
219                    List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111\n001111\n001111\n001111\n001111\n001111\n001111\n");
220                    ModelManager m = new TestFeatureManager(2);
221    
222                    StandardOptimizer opt = new StandardOptimizer();
223                    opt.setStarts(new double[] {0.1, 0.2});
224                    opt.setRequireConvergence(true);
225                    opt.setEpsForConvergence(0.0000005);
226                    opt.setObjectiveFunction(new CachedAOFGradient());
227                    opt.optimize(m, data);
228    
229                    MaximumExpectedAccuracySemiMarkovGradient semiAof = new MaximumExpectedAccuracySemiMarkovGradient();
230                    semiAof.setCacheProcessor(new CacheProcessorDeluxe());
231                    opt.setObjectiveFunction(semiAof);
232                    opt.optimize(m, data);
233            }
234    
235            void doAlternateObjectiveFunctionTest(ModelManager m, List<? extends TrainingSequence<?>> data, LocalPathSimilarityScore s, double hand_val, double hand_grad0 ) {
236    
237                    double[] weights = new double[m.getNumFeatures()];
238                    Arrays.fill(weights,1.0);
239    
240                    CachedAOFGradient gradFunc = new CachedAOFGradient();
241                    gradFunc.setAllPaths(true);
242                    gradFunc.setScoreAlphaFile("scoreAlphaOld.txt");
243                    gradFunc.setExpectedProductFile("expectedProductOld.txt");
244                    gradFunc.setLocalPathSimilarityScore(s);
245                    gradFunc.setTrainingData(m, data);
246    
247                    double val;
248                    double[] grad = new double[m.getNumFeatures()];
249                    
250                    val = gradFunc.apply(weights, grad);
251                    log.info("  val = " + val + "   grad[0] = " + grad[0]) ;
252                    log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(grad));
253                    assertEquals(hand_val, val, 0.001);
254                    assertEquals(hand_grad0, grad[0], 0.001);               
255    
256                    MaximumExpectedAccuracySemiMarkovGradient semiAof = new MaximumExpectedAccuracySemiMarkovGradient();
257                    semiAof.setCacheProcessor(new CacheProcessorDeluxe());
258                    semiAof.setTrainingData(m, data);
259                    semiAof.setMarginalsFile("marginals.txt");
260                    semiAof.setScoreAlphaFile("scoreAlpha.txt");
261                    semiAof.setExpectedProductFile("expectedProduct.txt");
262                    val = semiAof.apply(weights, grad);
263                    log.info("  val = " + val + "   grad[0] = " + grad[0]) ;
264                    log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(grad));
265                    assertEquals(hand_val, val, 0.001);
266                    assertEquals(hand_grad0, grad[0], 0.001);               
267            }
268    }