001 package calhoun.analysis.crf.test;
002
003 import java.io.BufferedWriter;
004 import java.io.FileWriter;
005 import java.io.IOException;
006 import java.io.Writer;
007 import java.util.List;
008
009 import org.apache.commons.lang.StringUtils;
010 import org.apache.commons.logging.Log;
011 import org.apache.commons.logging.LogFactory;
012
013 import calhoun.analysis.crf.Conrad;
014 import calhoun.analysis.crf.ModelManager;
015 import calhoun.analysis.crf.features.tricycle13.KmerFeatures;
016 import calhoun.analysis.crf.io.InputHandlerInterleaved;
017 import calhoun.analysis.crf.io.InputSequence;
018 import calhoun.analysis.crf.io.StringInput;
019 import calhoun.analysis.crf.io.TrainingSequence;
020 import calhoun.analysis.crf.solver.check.ArrayFeatureList;
021 import calhoun.util.AbstractTestCase;
022 import calhoun.util.Assert;
023
024 public class FeatureTest extends AbstractTestCase {
025 private static final Log log = LogFactory.getLog(FeatureTest.class);
026 boolean debug = log.isDebugEnabled();
027
028 public void testKmerFeatures() throws Exception {
029 List<TrainingSequence<Character>> data = (List<TrainingSequence<Character>>) new InputHandlerInterleaved(new StringInput()).readTrainingData("test/input/zeroOrderTest.txt");
030 KmerFeatures kf = new KmerFeatures();
031 kf.setRareThreshold(0);
032 kf.train(0, new ZeroOrderManager(), data);
033 assertEquals(2, kf.getNumFeatures());
034 assertEquals(expectedProb(53, 30), kf.getKmerProb(0, "A", 0), .01);
035 assertEquals(expectedProb(30, 53), kf.getKmerProb(0, "A", 1), .01);
036 assertEquals(expectedProb(55, 38), kf.getKmerProb(0, "T", 0), .01);
037 assertEquals(expectedProb(38, 55), kf.getKmerProb(0, "T", 1), .01);
038 assertEquals(expectedProb(21, 91), kf.getKmerProb(0, "C", 0), .01);
039 assertEquals(expectedProb(91, 21), kf.getKmerProb(0, "C", 1), .01);
040 assertEquals(expectedProb(23, 82), kf.getKmerProb(0, "G", 0), .01);
041 assertEquals(expectedProb(82, 23), kf.getKmerProb(0, "G", 1), .01);
042 }
043
044 double expectedProb(int a, int b) {
045 return Math.log((a+1)/(float)(a+b+2));
046 }
047
048 public void testFeatureNames() throws Exception {
049 List<? extends TrainingSequence<?>> data = new InputHandlerInterleaved(new StringInput()).readTrainingData("test/input/zeroOrderTest.txt");
050 Conrad runner = ZeroOrderManager.getCRF();
051 runner.trainFeatures(data);
052 assertEquals("Start.lowGC", runner.getModel().getFeatureName(0));
053 assertEquals("End.highGC", runner.getModel().getFeatureName(3));
054 assertEquals("Kmer.lowGC.0", runner.getModel().getFeatureName(4));
055 assertEquals("Kmer.highGC.0", runner.getModel().getFeatureName(5));
056 assertEquals("Edge.lowGC-lowGC", runner.getModel().getFeatureName(6));
057 assertEquals(10, runner.getModel().getNumFeatures());
058 }
059
060 public void testWriteFeatures() throws Exception {
061 List<? extends TrainingSequence<?>> data = new InputHandlerInterleaved(new StringInput()).readTrainingData("test/input/zeroOrderTest.txt");
062 Conrad runner = ZeroOrderManager.getCRF();
063 runner.trainFeatures(data);
064 writeFeatures(runner, "test/working/zeroOrderFeatures.csv", data.get(0));
065 assertFilesMatch("test/output/zeroOrderFeatures.csv", "test/working/zeroOrderFeatures.csv");
066 }
067
068 public void testComponentFeatures() throws Exception {
069 // Test that component features work correctly
070 Conrad runner = new Conrad("test/input/componentFeatures.xml");
071 List<? extends TrainingSequence<?>> data = runner.getInputHandler().readTrainingData("test/input/testTrain.txt");
072 runner.trainFeatures(data);
073 ArrayFeatureList result = new ArrayFeatureList(runner.getModel());
074 result.evaluateNode(data.get(0), 0, 0);
075 }
076
077 public void writeFeatures(Conrad runner, String file, InputSequence data) throws IOException {
078 ModelManager model = runner.getModel();
079 boolean training = TrainingSequence.class.isInstance(data);
080 int numFixed = 2 + (training ? 1 : 0);
081 int totalFeatures = model.getNumFeatures();
082 String[] line = new String[numFixed + totalFeatures];
083
084 Writer fw = new BufferedWriter(new FileWriter(file));
085 // Write the header line
086 line[0] = "pos";
087 line[1] = "x";
088 if(training) {
089 line[2] = "y";
090 }
091 for (int i = 0; i < totalFeatures; ++i) {
092 line[i + numFixed] = model.getFeatureName(i);
093 }
094 fw.write('#' + StringUtils.join(line, "\t") + '\n');
095
096 for (int i = 0; i < data.length(); ++i) {
097 line[0] = Integer.toString(i);
098 line[1] = data.getX(i).toString();
099 if(training) {
100 line[2] = Integer.toString(((TrainingSequence) data).getY(i));
101 }
102 for(int j = 0; j<totalFeatures; ++j) {
103 line[numFixed+j] = "0";
104 }
105 int nStates = model.getNumStates();
106 ArrayFeatureList result = new ArrayFeatureList(model);
107 for(int state = 0; state < nStates; ++state) {
108 result.evaluateNode(data, i, state);
109 updateFeatureValues(result, line, numFixed);
110 if(i > 0) {
111 for(int prevState = 0; prevState < nStates; ++prevState) {
112 result.evaluateEdge(data, i, prevState, state);
113 updateFeatureValues(result, line, numFixed);
114 }
115 }
116 }
117 fw.write(StringUtils.join(line, "\t") + '\n');
118 }
119 fw.close();
120 }
121
122 void updateFeatureValues(ArrayFeatureList result, String[] line, int numFixed) {
123 int[] indices = result.getIndices();
124 double[] vals = result.getValues();
125 int size = result.size();
126 for(int i = 0; i<size; ++i) {
127 Assert.a(line[numFixed + indices[i]].equals("0"), "Feature had a previous value: "+line[numFixed + indices[i]]);
128 line[numFixed + indices[i]] = Float.toString((float)vals[i]);
129 }
130 }
131 }