001 package calhoun.analysis.crf.features.tricycle13;
002
003 import java.io.Serializable;
004 import java.util.ArrayList;
005 import java.util.List;
006
007 import org.apache.commons.logging.Log;
008 import org.apache.commons.logging.LogFactory;
009
010 import calhoun.analysis.crf.AbstractFeatureManager;
011 import calhoun.analysis.crf.CacheStrategySpec;
012 import calhoun.analysis.crf.FeatureList;
013 import calhoun.analysis.crf.FeatureManagerNode;
014 import calhoun.analysis.crf.ModelManager;
015 import calhoun.analysis.crf.BeanModel.Node;
016 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
017 import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
018 import calhoun.analysis.crf.io.InputSequence;
019 import calhoun.analysis.crf.io.TrainingSequence;
020
021 public class EmissionMarkovFeature extends AbstractFeatureManager<Character> implements FeatureManagerNode<Character> {
022 private static final long serialVersionUID = -7659288739348604129L;
023 private static final Log log = LogFactory.getLog(EmissionMarkovFeature.class);
024 boolean debug = log.isDebugEnabled();
025
026 int startIx; // The index of the first feature managed by this FeatureManager
027 ModelManager model;
028
029 // Following block are things that depend explicitly on and are calculated directly
030 // from geometry, included only for convenience.
031 MarkovPredictorLogprob predictorlp;
032 boolean tieFlag = false;
033
034 public static class MarkovHistory implements Serializable {
035 private static final long serialVersionUID = 8731309130784681776L;
036
037 List<List<Node>> states;
038
039 public List<List<Node>> getStates() {
040 return states;
041 }
042
043 public void setStates(List<List<Node>> history) {
044 this.states = history;
045 }
046
047 public List<int[]> convert() {
048 List<int[]> historyArray= new ArrayList<int[]>();
049 for(List<Node> nodeList : states) {
050 int[] historyEntry = new int[nodeList.size()];
051 historyArray.add(historyEntry);
052 for(int i=0; i<historyEntry.length; ++i) {
053 historyEntry[i] = nodeList.get(i).getIndex();
054 }
055 }
056 return historyArray;
057 }
058 }
059
060 public EmissionMarkovFeature() {
061 }
062
063 public void setHistory(MarkovHistory markovHistory) {
064 this.predictorlp = new MarkovPredictorLogprob(markovHistory.convert());
065 }
066
067 public EmissionMarkovFeature(List<int[]> history) {
068 this.predictorlp = new MarkovPredictorLogprob(history);
069 }
070
071 public EmissionMarkovFeature(List<int[]> history, List<int[]> flags) {
072 this.predictorlp = new MarkovPredictorLogprob(history);
073 tieFlag = true;
074 }
075
076 public int getNumFeatures() {
077 return tieFlag ? 1 : model.getNumStates();
078 }
079
080 public String getFeatureName(int featureIndex) {
081 if (tieFlag) {
082 return "TiedEmissionMarkovFeature";
083 } else {
084 int raw = featureIndex - startIx;
085 return "EmissionMarkov.span" + model.getStateName(raw);
086 }
087 }
088
089 public void evaluateNode(InputSequence<? extends Character> seq, int pos, int state, FeatureList result) {
090 int index = startIx + (tieFlag ? 0 : state);
091 result.addFeature(index, predictorlp.logprob(state,seq,pos));
092 }
093
094 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
095 startIx = startingIndex;
096 model = modelInfo;
097
098 predictorlp.train(data);
099 }
100
101 @Override
102 public CacheStrategySpec getCacheStrategy() {
103 return new CacheStrategySpec(CacheStrategy.DENSE);
104 }
105 }
106
107