001 package calhoun.analysis.crf.test;
002
003 import java.util.Arrays;
004 import java.util.List;
005
006 import org.apache.commons.logging.Log;
007 import org.apache.commons.logging.LogFactory;
008
009 import calhoun.analysis.crf.Conrad;
010 import calhoun.analysis.crf.LocalPathSimilarityScore;
011 import calhoun.analysis.crf.ModelManager;
012 import calhoun.analysis.crf.SemiMarkovSetup;
013 import calhoun.analysis.crf.io.IntInput;
014 import calhoun.analysis.crf.io.TrainingSequence;
015 import calhoun.analysis.crf.scoring.SimScoreMaxStateAgreement;
016 import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
017 import calhoun.analysis.crf.solver.MaximumExpectedAccuracySemiMarkovGradient;
018 import calhoun.analysis.crf.solver.StandardOptimizer;
019 import calhoun.analysis.crf.solver.check.CachedAOFGradient;
020 import calhoun.util.AbstractTestCase;
021 import calhoun.util.ColtUtil;
022
023 public class MaximumExpectedAccuracyTest extends AbstractTestCase {
024 private static final Log log = LogFactory.getLog(MaximumExpectedAccuracyTest.class);
025
026 public void testLittle() throws Exception {
027
028 String config = "test/input/aofreg_test/delta_conservation_aofreg.xml";
029 String input = "test/input/aofreg_test/testSeq.txt";
030
031 Conrad conrad = new Conrad(config);
032 conrad.train(input);
033 double[] regularWeights = conrad.getWeights();
034 conrad.test(input);
035
036 config = "test/input/aofreg_test/delta_conservation_aofreg_semi.xml";
037 conrad = new Conrad(config);
038 conrad.train(input);
039 double[] semiMarkovWeights = conrad.getWeights();
040 conrad.test(input);
041 assertArrayEquals(regularWeights, semiMarkovWeights, 0.001);
042 }
043
044 public void testGradEvals() throws Exception {
045 doFuncEvalsSkewed(0, IntInput.prepareData("0\n0"));
046 doFuncEvalsSkewed(0, IntInput.prepareData("00\n00"));
047 doFuncEvalsSkewed(0, IntInput.prepareData("000\n000"));
048 doFuncEvalsSkewed(1, IntInput.prepareData("1\n1"));
049 doFuncEvalsSkewed(3, IntInput.prepareData("11\n11"));
050 doFuncEvalsSkewed(1, IntInput.prepareData("10\n10"));
051 doFuncEvalsSkewed(1, IntInput.prepareData("111\n111"));
052 doFuncEvalsSkewed(3, IntInput.prepareData("101\n101"));
053 doFuncEvalsSkewed(3, IntInput.prepareData("011\n011"));
054 doFuncEvalsSkewed(1, IntInput.prepareData("1111\n1111"));
055 doFuncEvalsSkewed(1, IntInput.prepareData("001111\n001111"));
056 doFuncEvalsSkewed(2, IntInput.prepareData("00\n00"));
057 doFuncEvalsSkewed(2, IntInput.prepareData("00\n00\n00\n00"));
058 doFuncEvalsSkewed(1, IntInput.prepareData("00\n00\n0\n0\n"));
059 doFuncEvalsSkewed(1, IntInput.prepareData("0\n0\n1111\n1111\n"));
060 doFuncEvalsSkewed(2, IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"));
061 }
062
063 void doFuncEvalsSkewed(int mmNum, List<? extends TrainingSequence<?>> data) throws Exception {
064 //doFuncEvals(mmNum, false, false, data);
065 //doFuncEvals(mmNum, true, false, data);
066
067 if(mmNum == 1 || mmNum == 3) {
068 //doFuncEvals(mmNum, false, true, data);
069 doFuncEvals(mmNum, true, true, data);
070 }
071 }
072
073 void doFuncEvals(int mmNum, boolean skewedWeights, boolean fm3, List<? extends TrainingSequence<?>> data) throws Exception {
074 double[] weights = skewedWeights ? new double[] { 2, 0.5, 1 } : new double[] { 1, 1, 1 };
075
076 MaximumExpectedAccuracySemiMarkovGradient gradFunc;
077 CacheProcessorDeluxe cacheProcessor;
078 short[] max = new short[2];
079 Arrays.fill(max, (short) 1);
080
081 ModelManager m = fm3 ? new TestFeatureManager3(mmNum, false) : new TestFeatureManager2(mmNum, false);
082 m.train(0, m, data);
083
084 // Semi CRF no lookback
085 gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
086 cacheProcessor = new CacheProcessorDeluxe();
087 cacheProcessor.setAllPaths(false);
088 cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
089 ((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
090 gradFunc.setTrainingData(m, data);
091 double[] grad1 = new double[m.getNumFeatures()];
092 double val1 = gradFunc.apply(weights, grad1);
093 log.info("Grad1: " + ColtUtil.format(grad1));
094
095 // Semi CRF lookback
096 Arrays.fill(max, (short) 20);
097 m = fm3 ? new TestFeatureManager3(mmNum, false) : new TestFeatureManager2(mmNum, false);
098 m.train(0, m, data);
099 gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
100 cacheProcessor = new CacheProcessorDeluxe();
101 cacheProcessor.setAllPaths(false);
102 cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
103 ((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
104 //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setBetaLengthFile("modeMarginal.txt");
105 //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setNodeMarginalFile("nodeMarginal.txt");
106 //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setExpectedProductFile("expectedProduct.txt");
107 //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setScoreAlphaFile("scoreAlpha.txt");
108 gradFunc.setTrainingData(m, data);
109 double[] grad2 = new double[m.getNumFeatures()];
110 double val2 = gradFunc.apply(weights, grad2);
111 log.info("Grad2: " + ColtUtil.format(grad2));
112
113 assertEquals(val1, val2, 0.001);
114 assertArrayEquals(grad1, grad2, 0.001);
115
116 // Semi CRF lookback with length features
117 m = fm3 ? new TestFeatureManager3(mmNum, true) : new TestFeatureManager2(mmNum, true);
118 m.train(0, m, data);
119 gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
120 cacheProcessor = new CacheProcessorDeluxe();
121 cacheProcessor.setAllPaths(false);
122 cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
123 ((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setCacheProcessor(cacheProcessor);
124 //((MaximumExpectedAccuracySemiMarkovGradient) gradFunc).setExpectedProductFile("expectedProductLen.txt");
125 gradFunc.setTrainingData(m, data);
126 double[] grad3 = new double[m.getNumFeatures()];
127 double val3 = gradFunc.apply(weights, grad3);
128 log.info("Grad3: " + ColtUtil.format(grad3));
129
130 assertEquals(val1, val3, 0.001);
131 assertArrayEquals(grad1, grad3, 0.001);
132 }
133
134 public void testGeneCallerLocalScore() throws Exception {
135 Conrad nodeOnly = new Conrad("test/input/geneCallerLocal/baseline_aof.xml");
136 Conrad length = new Conrad("test/input/geneCallerLocal/baseline_aof_length.xml");
137 nodeOnly.train("test/input/geneCallerLocal");
138 length.train("test/input/geneCallerLocal");
139 assertArrayEquals(nodeOnly.getWeights(), length.getWeights(), 0.0001);
140 }
141
142 public void testAlternateObjectiveFunction() throws Exception {
143 LocalPathSimilarityScore s = new SimScoreMaxStateAgreement();
144 // s measures the number of agreeing nucleotides between a hidden path
145 // and the actual hidden path, not counting position zero.
146
147 // This is a two feature example I worked out by hand
148 ModelManager m2 = new TestFeatureManager(2);
149 List<? extends TrainingSequence<?>> data = IntInput.prepareData("00\n00");
150 doAlternateObjectiveFunctionTest( m2, data, s, 0.2857142/2.0,-0.1838962/2.0);
151
152 // below should be exactly double what is above (the
153 // observations are not used by this feature manager)
154 List<? extends TrainingSequence<?>> data2 = IntInput.prepareData("00\n11\n00\n00");
155 doAlternateObjectiveFunctionTest( m2, data2, s,0.57142853/4.0,-0.367792/4.0);
156
157 // the following should grow linearly with length-1
158 // The manager has just one node feature. At weight 1.0, the probability of a zero,
159 // and hence the expected value of S, is 1/3.
160 ModelManager m0 = new TestFeatureManager(0); // this has just node features
161 List<? extends TrainingSequence<?>> data3 = IntInput.prepareData("00\n00");
162 double t3 = -0.1540327; // I didn't check the gradient by hand but copied it from output
163 doAlternateObjectiveFunctionTest( m0, data3, s,0.33333333/2.0,t3/2.0);
164
165 // since this grows linearly, it should be exactly three times the above
166 List<? extends TrainingSequence<?>> data4 = IntInput.prepareData("0000\n0000");
167 doAlternateObjectiveFunctionTest( m0, data4, s, 1.0/4.0,3*t3/4.0);
168
169 // Expected value of S should be 0.5; gradient should be zero
170 ModelManager m1 = new TestFeatureManager(1); // this has just edge features
171 doAlternateObjectiveFunctionTest( m1, data3, s, 0.5/2.0,0.0/2.0);
172
173 // Expected value of S should be 1.5; gradient should be 0
174 doAlternateObjectiveFunctionTest( m1, data4, s,1.5/4.0,0.0/4.0);
175 }
176
177 // Similar logic to above bug on a slightly bigger example.
178 // If LBGFS can converge given my gradient/value function, then my function is
179 // probably self-consistent.
180 public void testSemiRealExample() throws Exception {
181 String input = "test/input/cryptoAOFUnittest/Tiny_1_Train_Test.txt";
182 String config1 = "test/input/cryptoAOFUnittest/delta_aof_Model.xml";
183 String config2 = "test/input/cryptoAOFUnittest/delta_aof_Model_semi_nolen.xml";
184 String config3 = "test/input/cryptoAOFUnittest/delta_aof_Model_semi.xml";
185 doSameWeightsTest(config1, config2, config3, input);
186 }
187
188 void doSameWeightsTest(String config1, String config2, String config3, String input) throws Exception {
189 double[] semiMarkovWeights;
190 Conrad conrad;
191 conrad = new Conrad(config1);
192 conrad.train(input);
193 double[] regularWeights = conrad.getWeights();
194 conrad.test(input);
195
196 conrad = new Conrad(config2);
197 conrad.train(input);
198 double[] noLenWeights = conrad.getWeights();
199 conrad.test(input);
200 assertArrayEquals(regularWeights, noLenWeights, 0.001);
201
202 if(config3 != null) {
203 conrad = new Conrad(config3);
204 conrad.train(input);
205 semiMarkovWeights = conrad.getWeights();
206 conrad.test(input);
207 assertArrayEquals(noLenWeights, semiMarkovWeights, 0.001);
208 }
209 }
210
211 // The idea of the test below is that of requiring LBGFS to converge to a point with zero gradient.
212 // Doesn't this require that the gradient equal zero at a point where the function is maximized.
213 // Isn't this very unlikely to succeed if either the function value or its
214 // gradient were computed incorrcetly? I think actually this is a strong test;
215 // the gradient defines a direction for a line search, and the function value determines
216 // where on that line search to stop next. It's hard to image this process
217 // converging in 2 or more dimensions unless the function and gradient are matched.
218 public void testCachedAOFGradient() throws Exception {
219 List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111\n001111\n001111\n001111\n001111\n001111\n001111\n");
220 ModelManager m = new TestFeatureManager(2);
221
222 StandardOptimizer opt = new StandardOptimizer();
223 opt.setStarts(new double[] {0.1, 0.2});
224 opt.setRequireConvergence(true);
225 opt.setEpsForConvergence(0.0000005);
226 opt.setObjectiveFunction(new CachedAOFGradient());
227 opt.optimize(m, data);
228
229 MaximumExpectedAccuracySemiMarkovGradient semiAof = new MaximumExpectedAccuracySemiMarkovGradient();
230 semiAof.setCacheProcessor(new CacheProcessorDeluxe());
231 opt.setObjectiveFunction(semiAof);
232 opt.optimize(m, data);
233 }
234
235 void doAlternateObjectiveFunctionTest(ModelManager m, List<? extends TrainingSequence<?>> data, LocalPathSimilarityScore s, double hand_val, double hand_grad0 ) {
236
237 double[] weights = new double[m.getNumFeatures()];
238 Arrays.fill(weights,1.0);
239
240 CachedAOFGradient gradFunc = new CachedAOFGradient();
241 gradFunc.setAllPaths(true);
242 gradFunc.setScoreAlphaFile("scoreAlphaOld.txt");
243 gradFunc.setExpectedProductFile("expectedProductOld.txt");
244 gradFunc.setLocalPathSimilarityScore(s);
245 gradFunc.setTrainingData(m, data);
246
247 double val;
248 double[] grad = new double[m.getNumFeatures()];
249
250 val = gradFunc.apply(weights, grad);
251 log.info(" val = " + val + " grad[0] = " + grad[0]) ;
252 log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(grad));
253 assertEquals(hand_val, val, 0.001);
254 assertEquals(hand_grad0, grad[0], 0.001);
255
256 MaximumExpectedAccuracySemiMarkovGradient semiAof = new MaximumExpectedAccuracySemiMarkovGradient();
257 semiAof.setCacheProcessor(new CacheProcessorDeluxe());
258 semiAof.setTrainingData(m, data);
259 semiAof.setMarginalsFile("marginals.txt");
260 semiAof.setScoreAlphaFile("scoreAlpha.txt");
261 semiAof.setExpectedProductFile("expectedProduct.txt");
262 val = semiAof.apply(weights, grad);
263 log.info(" val = " + val + " grad[0] = " + grad[0]) ;
264 log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(grad));
265 assertEquals(hand_val, val, 0.001);
266 assertEquals(hand_grad0, grad[0], 0.001);
267 }
268 }