001    package calhoun.analysis.crf.features.generic;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import calhoun.analysis.crf.AbstractFeatureManager;
007    import calhoun.analysis.crf.FeatureList;
008    import calhoun.analysis.crf.FeatureManagerNode;
009    import calhoun.analysis.crf.ModelManager;
010    import calhoun.analysis.crf.io.InputSequence;
011    import calhoun.util.Assert;
012    
013    /** indicator functions that evaluate to true for a selected set of start states at the first position in the sequence. */
014    public class EndFeatures extends AbstractFeatureManager<Object> implements FeatureManagerNode<Object> {
015            private static final long serialVersionUID = 5206582215565363768L;
016            int startIx;
017            int[] endStates;
018            String[] names;
019            
020            int[] configStates = new int[] {0, 1};
021    
022            public void setStates(int[] config) {
023                    configStates = config;
024            }
025            
026            public String getFeatureName(int featureIndex) {
027                    Assert.a(featureIndex - startIx < names.length, "Invalid feature index");
028                    return names[featureIndex - startIx];
029            }
030    
031            public int getNumFeatures() {
032                    return names.length;
033            }
034    
035            public void evaluateNode(InputSequence<?> seq, int pos, int state, FeatureList result) {
036                    if(pos == seq.length()-1) {
037                            int index = endStates[state];
038                            if(index != -1) {
039                                    result.addFeature(index, 1);
040                            }
041                    }
042            }
043    
044            /** Edge features don't train based on the data.  Just set up based on the model. */
045            public void train(int startingIndex, ModelManager modelInfo, List data) {
046                    startIx = startingIndex;
047                    endStates = new int[modelInfo.getNumStates()];
048                    Arrays.fill(endStates, -1);
049                    names = new String[configStates.length];
050                    for(int i=0; i<configStates.length; ++i) {
051                            endStates[configStates[i]] = startIx+i;
052                            names[i] = "End."+modelInfo.getStateName(configStates[i]);
053                    }
054            }
055    }