001 package calhoun.analysis.crf;
002
003
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.util.Assert;
012
013 /** a feature manager that combines feature types together. */
014 public class CompositeFeatureManager extends AbstractFeatureManager implements FeatureManagerEdge, FeatureManagerNode, FeatureManagerEdgeExplicitLength, FeatureManagerNodeExplicitLength {
015 private static final long serialVersionUID = 5061912595256694049L;
016
017 protected List<FeatureManager> allFeatureTypes = new ArrayList<FeatureManager>();
018
019 protected List<FeatureManagerNode> nodeFeatureTypes = new ArrayList<FeatureManagerNode>();
020 protected List<FeatureManagerEdge> edgeFeatureTypes = new ArrayList<FeatureManagerEdge>();
021 protected List<FeatureManagerNodeExplicitLength> explicitLengthNodeFeatureTypes = new ArrayList<FeatureManagerNodeExplicitLength>();
022 protected List<FeatureManagerEdgeExplicitLength> explicitLengthEdgeFeatureTypes = new ArrayList<FeatureManagerEdgeExplicitLength>();
023 int[] startIndexes = null;
024
025 int startIx;
026 int totalFeatures;
027
028 public List<FeatureManager> getComponentFeatures() {
029 return allFeatureTypes;
030 }
031
032 public void setComponentFeatures(List<FeatureManager> components) {
033 for ( FeatureManager component : components ) {
034 addFeatureManager(component);
035 }
036 }
037
038 public void addFeatureManager(FeatureManager fm) {
039 addFeatureManager(null, null, fm);
040 }
041
042 public void addFeatureManager(String name, String inputParams, FeatureManager fm) {
043 Assert.a(startIndexes == null, "Attempted to add a new FeatureManager after training.");
044
045 if(name != null)
046 fm.setInputComponent(name);
047 allFeatureTypes.add(fm);
048
049 // Add each feature type into the right list for evalution
050 if(fm instanceof FeatureManagerNode) {
051 nodeFeatureTypes.add((FeatureManagerNode) fm);
052 }
053 if(fm instanceof FeatureManagerEdge) {
054 edgeFeatureTypes.add((FeatureManagerEdge) fm);
055 }
056 if(fm instanceof FeatureManagerNodeExplicitLength) {
057 explicitLengthNodeFeatureTypes.add((FeatureManagerNodeExplicitLength) fm);
058 }
059 if(fm instanceof FeatureManagerEdgeExplicitLength) {
060 explicitLengthEdgeFeatureTypes.add((FeatureManagerEdgeExplicitLength) fm);
061 }
062 }
063
064 public int getNumFeatures() {
065 Assert.a(startIndexes != null, "Attempted to get number of features before training.");
066 return totalFeatures;
067 }
068
069 public String getFeatureName(int featureIndex) {
070 int index = Arrays.binarySearch(startIndexes, featureIndex);
071 if(index < 0) {
072 index = -index-2;
073 }
074 // Feature managers may have no features (just constraints).
075 while(allFeatureTypes.get(index).getNumFeatures() == 0) {
076 index += 1;
077 }
078 return allFeatureTypes.get(index).getFeatureName(featureIndex);
079 }
080
081 public int getFeatureOffset(int featureIndex)
082 {
083 int index = Arrays.binarySearch(startIndexes, featureIndex);
084 if (index < 0) {
085 index = -index-2;
086 }
087 // Feature managers may have no features (just constraints).
088 while(allFeatureTypes.get(index).getNumFeatures() == 0) {
089 index += 1;
090 }
091 return (startIndexes[index]);
092 }
093
094 protected static class ComponentList extends ArrayList<InputSequence<?>> {
095 private static final long serialVersionUID = 2954775229926328434L;
096 String name;
097 public ComponentList(List<TrainingSequence<?>> start, String name) {
098 super(start);
099 this.name = name;
100 }
101 @Override
102 public InputSequence<?> get(int i) {
103 return transform((TrainingSequence<?>) super.get(i));
104 }
105 public InputSequence<?> transform(TrainingSequence<?> in) {
106 return in.getComponent(name);
107 }
108 }
109
110 public void train(int startingIndex, ModelManager modelInfo, List data) {
111 Assert.a(allFeatureTypes.size() > 0, "No features types have been assigned.");
112 Assert.a(startIndexes == null, "FeatureManager has already been trained.");
113 startIx = startingIndex;
114
115 // Train each of the individual FeatureManagers and calculate offsets
116 startIndexes = new int[allFeatureTypes.size()];
117 totalFeatures = 0;
118 for(int i = 0; i<startIndexes.length; ++i) {
119 startIndexes[i] = totalFeatures + startIx;
120 FeatureManager fm = allFeatureTypes.get(i);
121 List compData = fm.getInputComponent() == null ? data : new ComponentList(data, fm.getInputComponent());
122 fm.train(totalFeatures, modelInfo, compData);
123 totalFeatures += fm.getNumFeatures();
124 }
125 }
126
127 public void evaluateNode(InputSequence seq, int pos, int state, FeatureList result) {
128 for(FeatureManagerNode fm : nodeFeatureTypes) {
129 InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
130 fm.evaluateNode(componentSeq, pos, state, result);
131 if(!result.isValid())
132 break;
133 }
134 }
135
136 public void evaluateEdge(InputSequence seq, int pos, int prevState, int state, FeatureList result) {
137 for(FeatureManagerEdge fm : edgeFeatureTypes) {
138 InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
139 fm.evaluateEdge(componentSeq, pos, prevState, state, result);
140 if(!result.isValid()) {
141 break;
142 }
143 }
144 }
145
146 public void evaluateNodeLength(InputSequence seq, int pos, int length, int state, FeatureList result) {
147 Assert.a(length>0);
148 for(FeatureManagerNodeExplicitLength fm : explicitLengthNodeFeatureTypes) {
149 InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
150 fm.evaluateNodeLength(componentSeq, pos, length, state, result);
151 if(!result.isValid())
152 break;
153 }
154 }
155
156 public void evaluateEdgeLength(InputSequence seq, int pos, int length, int prevState, int state, FeatureList result) {
157 Assert.a(length>0);
158 for(FeatureManagerEdgeExplicitLength fm : explicitLengthEdgeFeatureTypes) {
159 InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
160 fm.evaluateEdgeLength(componentSeq, pos, length, prevState, state, result);
161 if(!result.isValid())
162 break;
163 }
164 }
165
166 @Override
167 public CacheStrategySpec getCacheStrategy() {
168 return new CacheStrategySpec(CacheStrategy.COMPOSITE);
169 }
170 }