001 package calhoun.analysis.crf.test;
002
003 import java.util.List;
004
005 import org.apache.commons.logging.Log;
006 import org.apache.commons.logging.LogFactory;
007
008 import calhoun.analysis.crf.Conrad;
009 import calhoun.analysis.crf.ModelManager;
010 import calhoun.analysis.crf.SemiMarkovSetup;
011 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
012 import calhoun.analysis.crf.io.IntInput;
013 import calhoun.analysis.crf.io.StringInput;
014 import calhoun.analysis.crf.io.TrainingSequence;
015 import calhoun.analysis.crf.solver.CacheProcessor;
016 import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
017 import calhoun.analysis.crf.solver.MaximumLikelihoodSemiMarkovGradient;
018 import calhoun.analysis.crf.solver.NoCachingCacheProcessor;
019 import calhoun.analysis.crf.solver.StandardOptimizer;
020 import calhoun.analysis.crf.solver.CacheProcessor.FeatureEvaluation;
021 import calhoun.analysis.crf.solver.CacheProcessor.LengthFeatureEvaluation;
022 import calhoun.analysis.crf.solver.check.AllSparseLengthCacheProcessor;
023 import calhoun.util.AbstractTestCase;
024 import calhoun.util.Assert;
025
026 public class CacheProcessorTest extends AbstractTestCase {
027 private static final Log log = LogFactory.getLog(CacheProcessorTest.class);
028
029 public void testCPDRejectTrainingDataStatesTooLong() throws Exception {
030 checkFailure("test/input/interval13/config/shortIntergenicCPD.xml", "test/input/interval13/data/tooLong.txt");
031 }
032
033 public void testCPDRejectTrainingDataStatesTooShort() throws Exception {
034 checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShort.txt");
035 }
036
037 public void testAllSparseRejectTrainingDataStatesTooShort() throws Exception {
038 checkFailure("test/input/interval13/config/lengthDependentAllSparse.xml", "test/input/interval13/data/tooShort.txt");
039 }
040
041 public void testCPDRejectTrainingDataStatesTooShortStart() throws Exception {
042 checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShortStart.txt");
043 }
044
045 public void testCPDRejectTrainingDataStatesTooShortEnd() throws Exception {
046 checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/tooShortEnd.txt");
047 }
048
049 public void testCPDRejectTrainingDataStatesViolatesConstraints() throws Exception {
050 checkFailure("test/input/interval13/config/lengthDependentCPDuncommented.xml", "test/input/interval13/data/badConstraints.txt");
051 }
052
053 public void testCPDAllSequencesInvalid() throws Exception {
054 checkFailure("test/input/interval13/config/lengthDependentCPDuncommentedDiscard.xml", "test/input/interval13/data/tooShortEnd.txt");
055 }
056
057 public void testCPDAllDiscardInvalid() throws Exception {
058 Conrad conrad = new Conrad("test/input/interval13/config/lengthDependentCPDuncommentedDiscard.xml");
059 conrad.train("test/input/interval13/data/oneGoodOneBad.txt");
060 }
061
062 public void testCPDAllDiscardInvalidLocalScore() throws Exception {
063 Conrad conrad = new Conrad("test/input/interval13/config/lengthDependentCPDuncommentedDiscardLocalScore.xml");
064 conrad.train("test/input/interval13/data/oneGoodOneBad.txt");
065 }
066
067 void checkFailure(String configFile, String data) {
068 Conrad conrad = new Conrad(configFile);
069 boolean fail = false;
070 try {
071 conrad.trainFeatures(data);
072 conrad.trainWeights(conrad.getInputHandler().readTrainingData(data));
073 }
074 catch(Exception ex) {
075 log.warn(ex);
076 fail = true;
077 }
078 assertTrue(fail);
079 }
080
081 // Test edge features with 1 node, should be no evaluations
082 public void testEdgeTrivial() throws Exception {
083 int[][] indices = new int[][] { {-1}, {-1}, {-1}, {-1}, {-1}, {-1} };
084 float[][] vals = new float[0][0];
085 double[] featureSums = new double[] { 0.0 };
086 List<? extends TrainingSequence<?>> data = IntInput.prepareData("0\n0");
087 System.out.println("number of data sequences is " + data.size());
088 doTest(1, data, 0, 0, indices, vals, featureSums);
089 }
090
091 // Test edge features with 2 and 3 positions.
092 public void testEdgeShort() throws Exception {
093 int[][] indices = new int[][] { {-1}, {-1}, {0, -1}, {0, -1}, {0, -1}, {0, -1} };
094 float[][] vals = new float[][] { {}, {}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f} };
095 double[] featureSums = new double[] { -0.4054651f };
096 doTest(1, IntInput.prepareData("00\n00"), 0, 1, indices, vals, featureSums);
097 }
098
099 // Test node features
100 public void testNode() throws Exception {
101 int[][] indices = new int[][] { {0, -1}, {0, -1}, {-1}, {-1}, {-1}, {-1} };
102 float[][] vals1 = new float[][] { {-1.0986123f}, {-0.4054651f}, {}, {}, {}, {} };
103 double[] featureSums = new double[] { -1.0986123f*2 + -0.4054651f*4 };
104 doTest(0, IntInput.prepareData("001111\n001111"), 0, 0, indices, vals1, featureSums); // the DCP feature sums for this are wrong
105 float[][] vals2 = new float[][] { {-1.0986123f}, {-0.4054651f}, {}, {}, {}, {} };
106 doTest(0, IntInput.prepareData("001111\n001111"), 0, 4, indices, vals2, featureSums);
107 }
108
109 // Test 2 features
110 public void testTwoFeaturesTrivial() throws Exception {
111 int[][] indices1 = new int[][] { {0, -1}, {0, -1}, {-1}, {-1}, {-1}, {-1} };
112 float[][] vals1 = new float[][] { {-1.0986123f}, {-0.4054651f}, {}, {}, {}, {} };
113 double[] featureSums = new double[] { -1.0986123f*2, -0.4054651f };
114 doTest(2, IntInput.prepareData("00\n00"), 0, 0, indices1, vals1, featureSums);
115 int[][] indices2 = new int[][] { {0, -1}, {0, -1}, {1, -1}, {1, -1}, {1, -1}, {1, -1} };
116 float[][] vals2 = new float[][] { {-1.0986123f}, {-0.4054651f}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f} };
117 featureSums = new double[] { -1.0986123f*4, -0.4054651f*2 };
118 doTest(2, IntInput.prepareData("00\n00\n00\n00"), 1, 1, indices2, vals2, featureSums);
119 }
120
121 // Test 2 features non-trivial
122 public void testTwoFeaturesNonTrivial() throws Exception {
123 int[][] indices = new int[][] { {0, -1}, {0, -1}, {1, -1}, {1, -1}, {1, -1}, {1, -1} };
124 float[][] vals = new float[][] { {-1.0986123f}, {-0.4054651f}, {-0.4054651f}, {-1.0986123f}, {-1.0986123f}, {-0.4054651f} };
125 double[] featureSums = new double[] { -33.54728, -29.963765};
126 doTest(2, IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"), 1, 4, indices, vals, featureSums);
127 }
128
129 public void testLengthCacheDummy() throws Exception {
130 int[][] lookbacks = new int[][] { {0, 1, -1}, {0, 1, -1} };
131 int[][] nodeIndices = new int[][] { {0, 0, 0, -1}, {1, 0, 0, -1} };
132 float[][] nodeValues = new float[0][4];
133 ModelManager m = new TestFeatureManager(1);
134 doLengthTest(m, IntInput.prepareData("00\n00"), 0, 1, 2, lookbacks, nodeIndices, nodeValues);
135
136 m = new TestFeatureManager(2);
137 lookbacks = new int[][] { {0, 1, 2, 3, -1}, {0, 1, 2, 3, -1} };
138 nodeIndices = new int[][] { {0, 0, 0, -1}, {0, 1, 0, -1}, {0, 2, 0, -1}, {0, 3, 0, -1}, {1, 0, 0, -1} };
139 doLengthTest(m, IntInput.prepareData("0000\n0000"), 0, 3, 2, lookbacks, nodeIndices, nodeValues);
140 }
141
142 public void testLengthCache() throws Exception {
143 int[][] lookbacks = new int[][] { {0, 1, 2, 3, -1} };
144 int[][] nodeIndices = new int[0][4];
145 float[][] nodeValues = new float[][] { {0, 0, 0, -0.0f}, {0, 1, 0, -0.11157f} , {0, 2, 0, -0.11157f*2}, {0, 3, 0, -0.11157f*3}};
146 Conrad c = new Conrad("test/input/semiMarkovTestModelHalfAndHalf.xml");
147 ModelManager m = c.getModel();
148 List<? extends TrainingSequence<?>> data = StringInput.prepareData("00110\nATGCA");
149 c.trainFeatures(data);
150 CacheProcessor cp = ((MaximumLikelihoodSemiMarkovGradient) ((StandardOptimizer)c.getOptimizer()).getObjectiveFunction()).getCacheProcessor();
151 cp.setTrainingData(m, data);
152 cp.evaluateSegmentsEndingAt(0, 3);
153 LengthFeatureEvaluation[][] lenEvals = cp.getLengthFeatureEvaluations();
154 checkLengthEvals(lenEvals, 1, lookbacks, nodeIndices, nodeValues);
155 }
156
157 void doLengthTest(ModelManager m, List<? extends TrainingSequence<?>> data, int seq, int pos, int nStates, int[][] lookback, int[][] nodeIndices, float[][] nodeValues) {
158 AllSparseLengthCacheProcessor cp = new AllSparseLengthCacheProcessor();
159 SemiMarkovSetup setup = new SemiMarkovSetup(new short[] {4, 4});
160 setup.setIgnoreSemiMarkovSelfTransitions(true);
161 cp.setSemiMarkovSetup(setup);
162 cp.setTrainingData(m, data);
163 cp.evaluateSegmentsEndingAt(seq, pos);
164 LengthFeatureEvaluation[][] lenEvals = cp.getLengthFeatureEvaluations();
165 checkLengthEvals(lenEvals, nStates, lookback, nodeIndices, nodeValues);
166 }
167
168 void checkLengthEvals(LengthFeatureEvaluation[][] lenEvals, int nStates, int[][] lookback, int[][] nodeIndices, float[][] nodeValues) {
169 assertEquals(nStates, lenEvals.length);
170 for(int i=0; i<lookback.length; ++i) {
171 for(int j=0; j<lookback[i].length; ++j) {
172 assertEquals(lookback[i][j], lenEvals[i][j].lookback);
173 }
174 }
175 for(int[] entry : nodeIndices) {
176 assertEquals(entry[3], lenEvals[entry[0]][entry[1]].nodeEval.index[entry[2]]);
177 }
178
179 for(float[] entry : nodeValues) {
180 assertEquals(entry[3], lenEvals[(int)entry[0]][(int)entry[1]].nodeEval.value[(int)entry[2]], 0.0001);
181 }
182 }
183
184 void doTest(int mmNum, List<? extends TrainingSequence<?>> data, int seq, int pos, int[][] indices, float[][] vals, double[] featureSums) {
185 ModelManager m = new TestFeatureManager(mmNum);
186
187 AllSparseLengthCacheProcessor cp = new AllSparseLengthCacheProcessor();
188 testOneCacheProcessor(cp,m,data,seq,pos,indices,vals,featureSums);
189
190 NoCachingCacheProcessor ncp = new NoCachingCacheProcessor();
191 testOneCacheProcessor(ncp,m,data,seq,pos,indices,vals,featureSums);
192
193 CacheProcessorDeluxe dcp = new CacheProcessorDeluxe();
194 //dcp.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
195 testOneCacheProcessor(dcp,m,data,seq,pos,indices,vals,featureSums);
196
197 CacheProcessorDeluxe dcp2 = new CacheProcessorDeluxe(CacheStrategy.CONSTANT);
198 //dcp2.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
199 testOneCacheProcessor(dcp2,m,data,seq,pos,indices,vals,featureSums);
200
201 CacheProcessorDeluxe dcp3 = new CacheProcessorDeluxe(CacheStrategy.DENSE);
202 //dcp3.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
203 testOneCacheProcessor(dcp3,m,data,seq,pos,indices,vals,featureSums);
204
205 CacheProcessorDeluxe dcp4 = new CacheProcessorDeluxe(CacheStrategy.SPARSE);
206 //dcp4.setSemiMarkovSetup(new SemiMarkovSetup(new short[]{1,1}, new short[]{50,50},true));
207 testOneCacheProcessor(dcp4,m,data,seq,pos,indices,vals,featureSums);
208 }
209
210 void testOneCacheProcessor(CacheProcessor dcp, ModelManager m, List<? extends TrainingSequence<?>> data, int seq, int pos, int[][] indices, float[][] vals, double[] featureSums) {
211 //dcp.setAllPaths(false);
212 dcp.setTrainingData(m, data);
213 dcp.evaluatePosition(seq, pos);
214 if (pos > 0) {
215 assertEvalEquals(dcp.getFeatureEvaluations(), indices, vals);
216 } else {
217 assertNonedgeEvalEquals(m,dcp.getFeatureEvaluations(), indices, vals);
218 }
219 assertArrayEquals(featureSums, dcp.getFeatureSums(), 0.00001);
220 }
221
222
223 private void assertNonedgeEvalEquals(ModelManager m, FeatureEvaluation[] evals, int[][] indices, float[][] vals) {
224 log.warn(evals);
225
226 Assert.a(indices.length >= m.getNumStates());
227
228 for(int i=0; i<m.getNumStates(); ++i) {
229 for(int j=0; j<indices[i].length; ++j) {
230 assertEquals(indices[i][j], evals[i].index[j]);
231 }
232 }
233 for(int i=0; i<vals.length; ++i) {
234 for(int j=0; j<vals[i].length; ++j) {
235 assertEquals(vals[i][j], evals[i].value[j], .00001);
236 }
237 }
238 }
239
240 private void assertEvalEquals(FeatureEvaluation[] evals, int[][] indices, float[][] vals) {
241 log.warn(evals);
242 for(int i=0; i<indices.length; ++i) {
243 for(int j=0; j<indices[i].length; ++j) {
244 assertEquals(indices[i][j], evals[i].index[j]);
245 }
246 }
247 for(int i=0; i<vals.length; ++i) {
248 for(int j=0; j<vals[i].length; ++j) {
249 assertEquals(vals[i][j], evals[i].value[j], .00001);
250 }
251 }
252 }
253 }