001    package calhoun.analysis.crf.features.tricycle13;
002    
003    import java.io.Serializable;
004    import java.util.ArrayList;
005    import java.util.List;
006    
007    import org.apache.commons.logging.Log;
008    import org.apache.commons.logging.LogFactory;
009    
010    import calhoun.analysis.crf.AbstractFeatureManager;
011    import calhoun.analysis.crf.CacheStrategySpec;
012    import calhoun.analysis.crf.FeatureList;
013    import calhoun.analysis.crf.FeatureManagerNode;
014    import calhoun.analysis.crf.ModelManager;
015    import calhoun.analysis.crf.BeanModel.Node;
016    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
017    import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
018    import calhoun.analysis.crf.io.InputSequence;
019    import calhoun.analysis.crf.io.TrainingSequence;
020    
021    public class EmissionMarkovFeature extends AbstractFeatureManager<Character> implements FeatureManagerNode<Character> {
022            private static final long serialVersionUID = -7659288739348604129L;
023            private static final Log log = LogFactory.getLog(EmissionMarkovFeature.class);
024            boolean debug = log.isDebugEnabled();
025            
026            int startIx;  // The index of the first feature managed by this FeatureManager
027            ModelManager model;
028            
029            // Following block are things that depend explicitly on and are calculated directly
030            //  from geometry, included only for convenience. 
031            MarkovPredictorLogprob predictorlp;
032            boolean tieFlag = false;
033    
034            public static class MarkovHistory implements Serializable {
035                    private static final long serialVersionUID = 8731309130784681776L;
036    
037                    List<List<Node>> states;
038    
039                    public List<List<Node>> getStates() {
040                            return states;
041                    }
042    
043                    public void setStates(List<List<Node>> history) {
044                            this.states = history;
045                    }
046                    
047                    public List<int[]> convert() {
048                            List<int[]> historyArray= new ArrayList<int[]>();
049                            for(List<Node> nodeList : states) {
050                                    int[] historyEntry = new int[nodeList.size()];
051                                    historyArray.add(historyEntry);
052                                    for(int i=0; i<historyEntry.length; ++i) {
053                                            historyEntry[i] = nodeList.get(i).getIndex();
054                                    }
055                            }
056                            return historyArray;
057                    }
058            }
059            
060            public EmissionMarkovFeature() {
061            }
062    
063            public void setHistory(MarkovHistory markovHistory) {
064                    this.predictorlp = new MarkovPredictorLogprob(markovHistory.convert());
065            }
066            
067            public EmissionMarkovFeature(List<int[]> history) {
068                    this.predictorlp = new MarkovPredictorLogprob(history);
069            }
070    
071            public EmissionMarkovFeature(List<int[]> history, List<int[]> flags) {
072                    this.predictorlp = new MarkovPredictorLogprob(history);
073                    tieFlag = true;
074            }
075    
076            public int getNumFeatures() {
077                    return tieFlag ? 1 : model.getNumStates();
078            }       
079            
080            public String getFeatureName(int featureIndex) {
081                    if (tieFlag) { 
082                            return "TiedEmissionMarkovFeature";
083                    } else {
084                            int raw = featureIndex - startIx;               
085                            return "EmissionMarkov.span" + model.getStateName(raw);
086                    }
087            }
088            
089            public void evaluateNode(InputSequence<? extends Character> seq, int pos, int state, FeatureList result) {
090                    int index = startIx + (tieFlag ? 0 : state);
091                    result.addFeature(index, predictorlp.logprob(state,seq,pos));
092            }
093            
094            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
095                    startIx = startingIndex;
096                    model = modelInfo;
097                    
098                    predictorlp.train(data);
099            }
100    
101            @Override
102            public CacheStrategySpec getCacheStrategy() {
103                    return new CacheStrategySpec(CacheStrategy.DENSE);
104            }
105    }
106    
107