001 package calhoun.analysis.crf.test;
002
003 import java.util.List;
004
005 import calhoun.analysis.crf.AbstractFeatureManager;
006 import calhoun.analysis.crf.CacheStrategySpec;
007 import calhoun.analysis.crf.Conrad;
008 import calhoun.analysis.crf.FeatureList;
009 import calhoun.analysis.crf.FeatureManagerEdge;
010 import calhoun.analysis.crf.ModelManager;
011 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
012 import calhoun.analysis.crf.io.InputSequence;
013 import calhoun.util.AbstractTestCase;
014
015 public class ConstraintTest extends AbstractTestCase {
016 public void testViterbiConstraints() throws Exception {
017 Conrad r = new Conrad("test/input/constraintModel.xml");
018 r.trainFeatures("test/input/shortTrain.tricycle13.txt");
019 r.setWeights(new double[] {1.0});
020
021 // Test that the feature name prints out correctly
022 assertEquals("Edge Feature", r.getModel().getFeatureName(0));
023
024 // Test that the constraints are observed by checking against a manually validated set.
025 r.test("test/input/shortTrain.tricycle13.txt", "test/working/constraintTest.txt");
026 assertFilesMatch("test/output/constraintTest.txt", "test/working/constraintTest.txt");
027 }
028
029
030 public void testConstraintsDifferentCaches() throws Exception {
031 Conrad r = new Conrad("test/input/constraintModel.xml");
032 r.train("test/input/genesShorterTrain.txt");
033 double normWeight = r.getWeights()[0];
034
035 r = new Conrad("test/input/constraintModelCache.xml");
036 r.train("test/input/genesShorterTrain.txt");
037 assertEquals(normWeight, r.getWeights()[0], 0.0001);
038 }
039
040 /** Edge class used for testing that favors changing states whenever possible. Used to generate interesting paths for constraints */
041 public static class FixedEdges extends AbstractFeatureManager implements FeatureManagerEdge {
042 private static final long serialVersionUID = 5995552526733022868L;
043 int startIx;
044 int[] clusters = new int[] {0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4};
045 public void evaluateEdge(InputSequence seq, int pos, int prevState, int state, FeatureList result) {
046 result.addFeature(startIx, clusters[state] == clusters[prevState] ? Math.log(.3) : Math.log(.7));
047 }
048
049 public String getFeatureName(int featureIndex) {
050 return "Edge Feature";
051 }
052
053 public int getNumFeatures() {
054 return 1;
055 }
056
057 public void train(int startingIndex, ModelManager modelInfo, List data) {
058 startIx = startingIndex;
059 }
060
061 @Override
062 public CacheStrategySpec getCacheStrategy() {
063 return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
064 }
065
066 }
067
068
069 }