001    package calhoun.analysis.crf.features.generic;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import calhoun.analysis.crf.AbstractFeatureManager;
007    import calhoun.analysis.crf.FeatureList;
008    import calhoun.analysis.crf.FeatureManagerNode;
009    import calhoun.analysis.crf.ModelManager;
010    import calhoun.analysis.crf.io.InputSequence;
011    import calhoun.util.Assert;
012    
013    /** indicator functions that evaluate to true for a selected set of start states at the first position in the sequence. */
014    public class StartFeatures extends AbstractFeatureManager<Object> implements FeatureManagerNode<Object> {
015            private static final long serialVersionUID = 2578820778300251051L;
016            int startIx;
017            int[] startStates;
018            String[] names;
019    
020            int[] configStates = new int[] {0, 1};
021    
022            public void setStates(int[] config) {
023                    configStates = config;
024            }
025            
026            public String getFeatureName(int featureIndex) {
027                    int raw = featureIndex - startIx;
028                    Assert.a(raw >= 0 && raw < names.length, "Invalid feature index");
029                    return names[raw];
030            }
031    
032            public int getNumFeatures() {
033                    return names.length;
034            }
035    
036            public void evaluateNode(InputSequence<?> seq, int pos, int state, FeatureList result) {
037                    if(pos == 0) {
038                            int index = startStates[state];
039                            if(index != -1) {
040                                    result.addFeature(index, 1);
041                            }
042                    }
043            }
044    
045            /** Start features don't train based on the data.  Just set up based on the model. */
046            public void train(int startingIndex, ModelManager modelInfo, List data) {
047                    startIx = startingIndex;
048                    startStates = new int[modelInfo.getNumStates()];
049                    Arrays.fill(startStates, -1);
050                    names = new String[configStates.length];
051                    for(int i=0; i<configStates.length; ++i) {
052                            startStates[configStates[i]] = startIx+i;
053                            names[i] = "Start."+modelInfo.getStateName(configStates[i]);
054                    }
055            }
056    }