001    package calhoun.analysis.crf.test;
002    
003    import java.util.Arrays;
004    import java.util.Collections;
005    import java.util.List;
006    
007    import org.apache.commons.logging.Log;
008    import org.apache.commons.logging.LogFactory;
009    
010    import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
011    import calhoun.analysis.crf.Conrad;
012    import calhoun.analysis.crf.io.StringInput;
013    import calhoun.analysis.crf.io.TrainingSequence;
014    import calhoun.analysis.crf.solver.StandardOptimizer;
015    import calhoun.util.AbstractTestCase;
016    
017    /** Tests that CRF is working with valid probabilities - the sum of all possible labelings is 1.
018     *  
019     * Test that the code to walk through only the valid paths works correctly.
020     * Uses a two state model that disallows transitions to self.  010101... or 101010... are the only allowed paths. */
021    public class ValidPathsTest extends AbstractTestCase {
022            private static final Log log = LogFactory.getLog(ValidPathsTest.class);
023    
024            static List<? extends TrainingSequence<?>> train1 = StringInput.prepareData("000001\nATAGAG\n101010\nAGAGAG\n");
025            static String[] test1 = new String[] {
026                            "1\nA\n0\nT\n",
027                            "10\nAT\n01\nAA\n00\nTA\n11\nAA\n"
028            };
029    
030            static String[] test2 = new String[] {
031                    "1\nA\n0\nT\n",
032                    "10\nAT\n01\nAA\n"
033            };
034    
035            /** Tests all 2-state combinations with a restricted model and onlyValidTransitions = false */
036            public void testSumPathsAllPaths() throws Exception {
037                    sumLikelihoodSets("test/input/validPathsModelAllPathsBasic.xml", train1, test1);
038            }
039            
040            /** Tests all 2-state combinations with a restricted model and onlyValidTransitions = false */
041            public void testSumPathsAllPathsCached() throws Exception {
042                    sumLikelihoodSets("test/input/validPathsModelAllPaths.xml", train1, test1);
043            }
044    
045            /** Tests that the sum of the likelihoods of the valid paths is 1 */
046            public void testSumPaths() throws Exception {
047                    List<? extends TrainingSequence<?>> train = StringInput.prepareData("010101\nATAGAG\n101010\nAGAGAG\n");
048                    String[] tests = new String[] {
049                                    "101010101010101010101010\nATGACAGTAGACATGACAGTAGAC\n010101010101010101010101\nAATGACTGACACAATGACTGACAC\n",
050                                    "101010101010\nATGACAGTAGAC\n010101010101\nAATGACTGACAC\n",
051                                    "10\nAT\n01\nAA\n",
052                                    "1\nA\n0\nT\n"
053                    };
054                    sumLikelihoodSets("test/input/validPathsModel.xml", train, tests);
055            }
056    
057            /** Tests that the sum of the likelihoods of the valid paths is 1, uses a model that allows 0-0 */
058            public void testSumPathsMoreComplicated() throws Exception {
059                    List<? extends TrainingSequence<?>> train = StringInput.prepareData("000001\nATAGAG\n101010\nAGAGAG\n");
060                    
061                    String[] tests = new String[] {
062                                    "1\nA\n0\nT\n",
063                                    "10\nAT\n01\nAA\n00\nTA\n",
064                                    "000\nATG\n001\nAAC\n010\nTAG\n100\nGAG\n101\nTGG\n"
065                    };
066                    sumLikelihoodSets("test/input/validPathsModelMoreComplicated.xml", train, tests);
067            }
068    
069            /* If we have an invalid path in training data we need to reject */
070            public void testRejectInvalidTrainingPath() throws Exception {
071                    boolean failure = false;
072                    List<? extends TrainingSequence<?>> train = StringInput.prepareData("010101\nATAGAG\n011101\nAGAGAG\n");
073                    Conrad r = new Conrad("test/input/validPathsModel.xml");
074                    try {
075                            r.train(train);
076                    } catch (Exception e) {
077                            log.info(e);
078                            failure = true;
079                    }
080                    assertTrue(failure);
081            }
082    
083            void sumLikelihoodSets(String runner, List<? extends TrainingSequence<?>> train, String[] tests) throws Exception {
084                    Conrad r = new Conrad(runner);
085                    // Test weights of 1
086                    r.trainFeatures(train);
087                    double[] weights = new double[r.getModel().getNumFeatures()];
088                    Arrays.fill(weights, 1.0);
089                    sumLikelihood(r, weights, tests);                       
090                    
091                    // Test random weights
092                    for(int i =0; i<weights.length; ++i) {
093                            weights[i] = Math.random() * 10;
094                    }
095                    sumLikelihood(r, weights, tests);                       
096            }
097            
098            void sumLikelihood(Conrad r, double[] weights, String[] tests) throws Exception {
099                    log.debug("Sumlikelihoods test");
100                    for(String testString: tests) {
101                            // Sum a bunch of likelihoods and make sure they add up to 1.
102                            List<? extends TrainingSequence<?>> test = StringInput.prepareData(testString);
103                            double likelihood = 0.0;
104                            for(TrainingSequence<?> t : test) {
105                                    CRFObjectiveFunctionGradient obj = ((StandardOptimizer)r.getOptimizer()).getObjectiveFunction();
106                                    obj.setTrainingData(r.getModel(), Collections.singletonList(t));
107                                    double[] dummy = new double[weights.length];
108                                    double ll = obj.apply(weights, dummy);
109                                    double l = Math.exp(ll*t.length());
110                                    log.debug("Likelihood: "+l+" for "+t);
111                                    likelihood += l;
112                            }
113                            assertEquals(1.0,likelihood,0.000001);
114                    }
115            }
116            
117    
118    }