001    package calhoun.analysis.crf.test;
002    
003    import java.util.List;
004    
005    import calhoun.analysis.crf.AbstractFeatureManager;
006    import calhoun.analysis.crf.CacheStrategySpec;
007    import calhoun.analysis.crf.FeatureList;
008    import calhoun.analysis.crf.FeatureManager;
009    import calhoun.analysis.crf.FeatureManagerEdge;
010    import calhoun.analysis.crf.FeatureManagerNode;
011    import calhoun.analysis.crf.FeatureManagerNodeExplicitLength;
012    import calhoun.analysis.crf.ModelManager;
013    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
014    import calhoun.analysis.crf.io.InputSequence;
015    import calhoun.analysis.crf.io.TrainingSequence;
016    
017    public class TestFeatures {
018            public static abstract class TestFeature extends AbstractFeatureManager<Character> implements FeatureManager<Character> {
019    
020                    int startIx;
021                    
022                    public String getFeatureName(int featureIndex) {
023                            return getClass().getName();
024                    }
025    
026                    public int getNumFeatures() {
027                            return 1;
028                    }
029    
030                    public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
031                            startIx = startingIndex;
032                    }
033            }
034    
035            /* Returns log(.8) for G/C in state 1 or A/T in state 0, log(.2) otherwise */
036            public static class EmissionFeature extends TestFeature implements FeatureManagerNode<Character> {
037                    private static final long serialVersionUID = 8039479168741225007L;
038    
039                    public void evaluateNode(InputSequence<? extends Character> seq, int pos, int state, FeatureList result) {
040                            Character c = seq.getX(pos);
041                            boolean match = state == 0;
042                            if(c.charValue() == 'G' || c.charValue() == 'C') {
043                                    match = state == 1;
044                            }
045                            result.addFeature(startIx, match ? Math.log(.8) : Math.log(.2));
046                    }
047                    @Override
048                    public CacheStrategySpec getCacheStrategy() {
049                            return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
050                    }
051            }
052    
053            /* Returns 0 for state 0 transitions.  Returns guassian with mena of 6 and std dev of 1. for state 1 */
054            public static class GaussianLengthFeature extends TestFeature implements FeatureManagerNodeExplicitLength<Character> {
055                    private static final long serialVersionUID = 6050417482057409153L;
056    
057                    public void evaluateNodeLength(InputSequence<? extends Character> seq, int pos, int length, int state, FeatureList result) {
058                            double val = 0.0;
059                            if(state == 1) {
060                                    val = Math.pow(length - 6, 2);
061                                    result.addFeature(startIx, val);
062                            }
063                    }
064                    @Override
065                    public CacheStrategySpec getCacheStrategy() {
066                            return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
067                    }
068            }
069    
070            /* Has a value of 1/2 the edge feature value for staying the same transition. */
071            public static class ExplicitHalfExponentialLengthFeature extends TestFeature implements FeatureManagerNodeExplicitLength<Character>, FeatureManagerEdge<Character> {
072                    private static final long serialVersionUID = 6050417482057409153L;
073    
074                    public void evaluateNodeLength(InputSequence<? extends Character> seq, int pos, int length, int state, FeatureList result) {
075                            if(state == 1) {
076                                    float val = (float)((length-1)*Math.log(.8)/2.0f);
077                                    result.addFeature(startIx, val);
078                            }
079                    }
080    
081                    float[] vals = new float[] {.8f, .2f, .2f, (float) Math.sqrt(.8)};
082                    
083                    public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int prevState, int state, FeatureList result) {
084                            result.addFeature(startIx, Math.log(vals[prevState + 2*state]));
085                    }
086                    @Override
087                    public CacheStrategySpec getCacheStrategy() {
088                            return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
089                    }
090            }
091    
092            public static class ExponentialLengthFeature extends TestFeature implements FeatureManagerEdge<Character> {
093                    private static final long serialVersionUID = 290546936098017942L;
094    
095                    float[] vals = new float[] {.8f, .2f, .2f, .8f};
096                    
097                    public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int prevState, int state, FeatureList result) {
098                            result.addFeature(startIx, Math.log(vals[prevState + 2*state]));
099                    }
100                    @Override
101                    public CacheStrategySpec getCacheStrategy() {
102                            return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
103                    }
104            }
105    
106            public static class HalfExponentialLengthFeature extends TestFeature implements FeatureManagerEdge<Character> {
107                    private static final long serialVersionUID = 290546936098017942L;
108    
109                    float[] vals = new float[] {.8f, .2f, .2f, (float) Math.sqrt(.8)};
110                    
111                    public void evaluateEdge(InputSequence<? extends Character> seq, int pos, int prevState, int state, FeatureList result) {
112                            result.addFeature(startIx, Math.log(vals[prevState + 2*state]));
113                    }
114                    @Override
115                    public CacheStrategySpec getCacheStrategy() {
116                            return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
117                    }
118            }
119    }