001    package calhoun.analysis.crf.test;
002    
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.util.Arrays;
import java.util.List;
008    
009    import org.apache.commons.lang.StringUtils;
010    import org.apache.commons.logging.Log;
011    import org.apache.commons.logging.LogFactory;
012    
013    import calhoun.analysis.crf.Conrad;
014    import calhoun.analysis.crf.ModelManager;
015    import calhoun.analysis.crf.features.tricycle13.KmerFeatures;
016    import calhoun.analysis.crf.io.InputHandlerInterleaved;
017    import calhoun.analysis.crf.io.InputSequence;
018    import calhoun.analysis.crf.io.StringInput;
019    import calhoun.analysis.crf.io.TrainingSequence;
020    import calhoun.analysis.crf.solver.check.ArrayFeatureList;
021    import calhoun.util.AbstractTestCase;
022    import calhoun.util.Assert;
023    
024    public class FeatureTest extends AbstractTestCase {
025            private static final Log log = LogFactory.getLog(FeatureTest.class);
026            boolean debug = log.isDebugEnabled();
027            
028            public void testKmerFeatures() throws Exception {
029                    List<TrainingSequence<Character>> data = (List<TrainingSequence<Character>>) new InputHandlerInterleaved(new StringInput()).readTrainingData("test/input/zeroOrderTest.txt");
030                    KmerFeatures kf = new KmerFeatures();
031                    kf.setRareThreshold(0);
032                    kf.train(0, new ZeroOrderManager(), data);
033                    assertEquals(2, kf.getNumFeatures());
034                    assertEquals(expectedProb(53, 30), kf.getKmerProb(0, "A", 0), .01);
035                    assertEquals(expectedProb(30, 53), kf.getKmerProb(0, "A", 1), .01);
036                    assertEquals(expectedProb(55, 38), kf.getKmerProb(0, "T", 0), .01);
037                    assertEquals(expectedProb(38, 55), kf.getKmerProb(0, "T", 1), .01);
038                    assertEquals(expectedProb(21, 91), kf.getKmerProb(0, "C", 0), .01);
039                    assertEquals(expectedProb(91, 21), kf.getKmerProb(0, "C", 1), .01);
040                    assertEquals(expectedProb(23, 82), kf.getKmerProb(0, "G", 0), .01);
041                    assertEquals(expectedProb(82, 23), kf.getKmerProb(0, "G", 1), .01);
042            }
043    
044            double expectedProb(int a, int b) {
045                    return Math.log((a+1)/(float)(a+b+2));
046            }
047            
048            public void testFeatureNames() throws Exception {
049                    List<? extends TrainingSequence<?>> data = new InputHandlerInterleaved(new StringInput()).readTrainingData("test/input/zeroOrderTest.txt");
050                    Conrad runner = ZeroOrderManager.getCRF();
051                    runner.trainFeatures(data);
052                    assertEquals("Start.lowGC", runner.getModel().getFeatureName(0));
053                    assertEquals("End.highGC", runner.getModel().getFeatureName(3));
054                    assertEquals("Kmer.lowGC.0", runner.getModel().getFeatureName(4));
055                    assertEquals("Kmer.highGC.0", runner.getModel().getFeatureName(5));
056                    assertEquals("Edge.lowGC-lowGC", runner.getModel().getFeatureName(6));
057                    assertEquals(10, runner.getModel().getNumFeatures());
058            }
059    
060            public void testWriteFeatures() throws Exception {
061                    List<? extends TrainingSequence<?>> data = new InputHandlerInterleaved(new StringInput()).readTrainingData("test/input/zeroOrderTest.txt");
062                    Conrad runner = ZeroOrderManager.getCRF();
063                    runner.trainFeatures(data);
064                    writeFeatures(runner, "test/working/zeroOrderFeatures.csv", data.get(0));
065                    assertFilesMatch("test/output/zeroOrderFeatures.csv", "test/working/zeroOrderFeatures.csv");
066            }
067    
068            public void testComponentFeatures() throws Exception {
069                    // Test that component features work correctly
070                    Conrad runner = new Conrad("test/input/componentFeatures.xml");
071                    List<? extends TrainingSequence<?>> data = runner.getInputHandler().readTrainingData("test/input/testTrain.txt");
072                    runner.trainFeatures(data);
073                    ArrayFeatureList result = new ArrayFeatureList(runner.getModel()); 
074                    result.evaluateNode(data.get(0), 0, 0);
075            }
076    
077            public void writeFeatures(Conrad runner, String file, InputSequence data) throws IOException {
078                    ModelManager model = runner.getModel();
079                    boolean training = TrainingSequence.class.isInstance(data);
080                    int numFixed = 2 + (training ? 1 : 0);
081                    int totalFeatures = model.getNumFeatures();
082                    String[] line = new String[numFixed + totalFeatures];
083                    
084                    Writer fw = new BufferedWriter(new FileWriter(file));
085                    // Write the header line
086                    line[0] = "pos";
087                    line[1] = "x";
088                    if(training) {
089                            line[2] = "y";
090                    }
091                    for (int i = 0; i < totalFeatures; ++i) {
092                            line[i + numFixed] = model.getFeatureName(i);
093                    }
094                    fw.write('#' + StringUtils.join(line, "\t") + '\n');
095                    
096                    for (int i = 0; i < data.length(); ++i) {
097                            line[0] = Integer.toString(i);
098                            line[1] = data.getX(i).toString();
099                            if(training) {
100                                    line[2] = Integer.toString(((TrainingSequence) data).getY(i));
101                            }
102                            for(int j = 0; j<totalFeatures; ++j) {
103                                    line[numFixed+j] = "0";
104                            }
105                            int nStates = model.getNumStates();
106                            ArrayFeatureList result = new ArrayFeatureList(model);
107                            for(int state = 0; state < nStates; ++state) {
108                                    result.evaluateNode(data, i, state);
109                                    updateFeatureValues(result, line, numFixed);
110                                    if(i > 0) {
111                                            for(int prevState = 0; prevState < nStates; ++prevState) {
112                                                    result.evaluateEdge(data, i, prevState, state);
113                                                    updateFeatureValues(result, line, numFixed);
114                                            }
115                                    }
116                            }
117                            fw.write(StringUtils.join(line, "\t") + '\n');
118                    }
119                    fw.close();
120            }
121    
122            void updateFeatureValues(ArrayFeatureList result, String[] line, int numFixed) {
123                    int[] indices = result.getIndices();
124                    double[] vals = result.getValues();
125                    int size = result.size();
126                    for(int i = 0; i<size; ++i) {
127                            Assert.a(line[numFixed + indices[i]].equals("0"), "Feature had a previous value: "+line[numFixed + indices[i]]);
128                            line[numFixed + indices[i]] = Float.toString((float)vals[i]);
129                    }
130            }
131    }