package calhoun.analysis.crf.test;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.Conrad;
import calhoun.analysis.crf.io.StringInput;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.StandardOptimizer;
import calhoun.util.AbstractTestCase;

017 /** Tests that CRF is working with valid probabilities - the sum of all possible labelings is 1.
018 *
019 * Test that the code to walk through only the valid paths works correctly.
020 * Uses a two state model that disallows transitions to self. 010101... or 101010... are the only allowed paths. */
021 public class ValidPathsTest extends AbstractTestCase {
022 private static final Log log = LogFactory.getLog(ValidPathsTest.class);
023
024 static List<? extends TrainingSequence<?>> train1 = StringInput.prepareData("000001\nATAGAG\n101010\nAGAGAG\n");
025 static String[] test1 = new String[] {
026 "1\nA\n0\nT\n",
027 "10\nAT\n01\nAA\n00\nTA\n11\nAA\n"
028 };
029
030 static String[] test2 = new String[] {
031 "1\nA\n0\nT\n",
032 "10\nAT\n01\nAA\n"
033 };
034
035 /** Tests all 2-state combinations with a restricted model and onlyValidTransitions = false */
036 public void testSumPathsAllPaths() throws Exception {
037 sumLikelihoodSets("test/input/validPathsModelAllPathsBasic.xml", train1, test1);
038 }
039
040 /** Tests all 2-state combinations with a restricted model and onlyValidTransitions = false */
041 public void testSumPathsAllPathsCached() throws Exception {
042 sumLikelihoodSets("test/input/validPathsModelAllPaths.xml", train1, test1);
043 }
044
045 /** Tests that the sum of the likelihoods of the valid paths is 1 */
046 public void testSumPaths() throws Exception {
047 List<? extends TrainingSequence<?>> train = StringInput.prepareData("010101\nATAGAG\n101010\nAGAGAG\n");
048 String[] tests = new String[] {
049 "101010101010101010101010\nATGACAGTAGACATGACAGTAGAC\n010101010101010101010101\nAATGACTGACACAATGACTGACAC\n",
050 "101010101010\nATGACAGTAGAC\n010101010101\nAATGACTGACAC\n",
051 "10\nAT\n01\nAA\n",
052 "1\nA\n0\nT\n"
053 };
054 sumLikelihoodSets("test/input/validPathsModel.xml", train, tests);
055 }
056
057 /** Tests that the sum of the likelihoods of the valid paths is 1, uses a model that allows 0-0 */
058 public void testSumPathsMoreComplicated() throws Exception {
059 List<? extends TrainingSequence<?>> train = StringInput.prepareData("000001\nATAGAG\n101010\nAGAGAG\n");
060
061 String[] tests = new String[] {
062 "1\nA\n0\nT\n",
063 "10\nAT\n01\nAA\n00\nTA\n",
064 "000\nATG\n001\nAAC\n010\nTAG\n100\nGAG\n101\nTGG\n"
065 };
066 sumLikelihoodSets("test/input/validPathsModelMoreComplicated.xml", train, tests);
067 }
068
069 /* If we have an invalid path in training data we need to reject */
070 public void testRejectInvalidTrainingPath() throws Exception {
071 boolean failure = false;
072 List<? extends TrainingSequence<?>> train = StringInput.prepareData("010101\nATAGAG\n011101\nAGAGAG\n");
073 Conrad r = new Conrad("test/input/validPathsModel.xml");
074 try {
075 r.train(train);
076 } catch (Exception e) {
077 log.info(e);
078 failure = true;
079 }
080 assertTrue(failure);
081 }
082
083 void sumLikelihoodSets(String runner, List<? extends TrainingSequence<?>> train, String[] tests) throws Exception {
084 Conrad r = new Conrad(runner);
085 // Test weights of 1
086 r.trainFeatures(train);
087 double[] weights = new double[r.getModel().getNumFeatures()];
088 Arrays.fill(weights, 1.0);
089 sumLikelihood(r, weights, tests);
090
091 // Test random weights
092 for(int i =0; i<weights.length; ++i) {
093 weights[i] = Math.random() * 10;
094 }
095 sumLikelihood(r, weights, tests);
096 }
097
098 void sumLikelihood(Conrad r, double[] weights, String[] tests) throws Exception {
099 log.debug("Sumlikelihoods test");
100 for(String testString: tests) {
101 // Sum a bunch of likelihoods and make sure they add up to 1.
102 List<? extends TrainingSequence<?>> test = StringInput.prepareData(testString);
103 double likelihood = 0.0;
104 for(TrainingSequence<?> t : test) {
105 CRFObjectiveFunctionGradient obj = ((StandardOptimizer)r.getOptimizer()).getObjectiveFunction();
106 obj.setTrainingData(r.getModel(), Collections.singletonList(t));
107 double[] dummy = new double[weights.length];
108 double ll = obj.apply(weights, dummy);
109 double l = Math.exp(ll*t.length());
110 log.debug("Likelihood: "+l+" for "+t);
111 likelihood += l;
112 }
113 assertEquals(1.0,likelihood,0.000001);
114 }
115 }
116
117
118 }