001    package calhoun.analysis.crf;
002    
003    
004    import java.util.ArrayList;
005    import java.util.Arrays;
006    import java.util.List;
007    
008    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
009    import calhoun.analysis.crf.io.InputSequence;
010    import calhoun.analysis.crf.io.TrainingSequence;
011    import calhoun.util.Assert;
012    
013    /** a feature manager that combines feature types together. */
014    public class CompositeFeatureManager extends AbstractFeatureManager implements FeatureManagerEdge, FeatureManagerNode, FeatureManagerEdgeExplicitLength, FeatureManagerNodeExplicitLength {
015            private static final long serialVersionUID = 5061912595256694049L;
016    
017            protected List<FeatureManager> allFeatureTypes = new ArrayList<FeatureManager>();
018    
019            protected List<FeatureManagerNode> nodeFeatureTypes = new ArrayList<FeatureManagerNode>();
020            protected List<FeatureManagerEdge> edgeFeatureTypes = new ArrayList<FeatureManagerEdge>();
021            protected List<FeatureManagerNodeExplicitLength> explicitLengthNodeFeatureTypes = new ArrayList<FeatureManagerNodeExplicitLength>();
022            protected List<FeatureManagerEdgeExplicitLength> explicitLengthEdgeFeatureTypes = new ArrayList<FeatureManagerEdgeExplicitLength>();
023            int[] startIndexes = null;
024            
025            int startIx;
026            int totalFeatures;
027    
028            public  List<FeatureManager> getComponentFeatures() {
029                    return allFeatureTypes;
030            }
031            
032            public  void setComponentFeatures(List<FeatureManager> components) {
033                    for ( FeatureManager component : components ) {
034                            addFeatureManager(component);
035                    }
036            }
037            
038            public void addFeatureManager(FeatureManager fm) {
039                    addFeatureManager(null, null, fm);
040            }
041            
042            public void addFeatureManager(String name, String inputParams, FeatureManager fm) {
043                    Assert.a(startIndexes == null, "Attempted to add a new FeatureManager after training.");
044    
045                    if(name != null)
046                            fm.setInputComponent(name);
047                    allFeatureTypes.add(fm);
048    
049                    // Add each feature type into the right list for evalution
050                    if(fm instanceof FeatureManagerNode) {
051                            nodeFeatureTypes.add((FeatureManagerNode) fm);
052                    }
053                    if(fm instanceof FeatureManagerEdge) {
054                            edgeFeatureTypes.add((FeatureManagerEdge) fm);
055                    }
056                    if(fm instanceof FeatureManagerNodeExplicitLength) {
057                            explicitLengthNodeFeatureTypes.add((FeatureManagerNodeExplicitLength) fm);
058                    }
059                    if(fm instanceof FeatureManagerEdgeExplicitLength) {
060                            explicitLengthEdgeFeatureTypes.add((FeatureManagerEdgeExplicitLength) fm);
061                    }
062            }
063    
064            public int getNumFeatures() {
065                    Assert.a(startIndexes != null, "Attempted to get number of features before training.");
066                    return totalFeatures;
067            }
068            
069            public String getFeatureName(int featureIndex) {
070                    int index = Arrays.binarySearch(startIndexes, featureIndex);
071                    if(index < 0) {
072                            index = -index-2;
073                    }
074                    // Feature managers may have no features (just constraints).
075                    while(allFeatureTypes.get(index).getNumFeatures() == 0) {
076                            index += 1;
077                    }
078                    return allFeatureTypes.get(index).getFeatureName(featureIndex);
079            }
080    
081            public int getFeatureOffset(int featureIndex)
082            {
083                    int index = Arrays.binarySearch(startIndexes, featureIndex);
084                    if (index < 0) {
085                            index = -index-2;
086                    }
087                    // Feature managers may have no features (just constraints).
088                    while(allFeatureTypes.get(index).getNumFeatures() == 0) {
089                            index += 1;
090                    }
091                    return (startIndexes[index]);
092            }
093            
094            protected static class ComponentList extends ArrayList<InputSequence<?>> {
095                    private static final long serialVersionUID = 2954775229926328434L;
096                    String name;
097                    public ComponentList(List<TrainingSequence<?>> start, String name) {
098                            super(start);
099                            this.name = name;
100                    }
101                    @Override
102                    public InputSequence<?> get(int i) {
103                            return transform((TrainingSequence<?>) super.get(i));
104                    }
105                    public InputSequence<?> transform(TrainingSequence<?> in) {
106                            return in.getComponent(name);
107                    }
108            }
109    
110            public void train(int startingIndex, ModelManager modelInfo, List data) {
111                    Assert.a(allFeatureTypes.size() > 0, "No features types have been assigned.");
112                    Assert.a(startIndexes == null, "FeatureManager has already been trained.");
113                    startIx = startingIndex;
114    
115                    // Train each of the individual FeatureManagers and calculate offsets
116                    startIndexes = new int[allFeatureTypes.size()];
117                    totalFeatures = 0;
118                    for(int i = 0; i<startIndexes.length; ++i) {
119                            startIndexes[i] = totalFeatures + startIx;
120                            FeatureManager fm = allFeatureTypes.get(i);
121                            List compData = fm.getInputComponent() == null ? data : new ComponentList(data, fm.getInputComponent());
122                            fm.train(totalFeatures, modelInfo, compData);
123                            totalFeatures += fm.getNumFeatures();
124                    }
125            }
126    
127            public void evaluateNode(InputSequence seq, int pos, int state, FeatureList result) {
128                    for(FeatureManagerNode fm : nodeFeatureTypes) {
129                            InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
130                            fm.evaluateNode(componentSeq, pos, state, result);
131                            if(!result.isValid())
132                                    break;
133                    }
134            }
135    
136            public void evaluateEdge(InputSequence seq, int pos, int prevState, int state, FeatureList result) {
137                    for(FeatureManagerEdge fm : edgeFeatureTypes) {
138                            InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
139                            fm.evaluateEdge(componentSeq, pos, prevState, state, result);
140                            if(!result.isValid()) {
141                                    break;
142                            }
143                    }
144            }
145    
146            public void evaluateNodeLength(InputSequence seq, int pos, int length, int state, FeatureList result) {
147                    Assert.a(length>0);
148                    for(FeatureManagerNodeExplicitLength fm : explicitLengthNodeFeatureTypes) {
149                            InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
150                            fm.evaluateNodeLength(componentSeq, pos, length, state, result);
151                            if(!result.isValid())
152                                    break;
153                    }
154            }
155    
156            public void evaluateEdgeLength(InputSequence seq, int pos, int length, int prevState, int state, FeatureList result) {
157                    Assert.a(length>0);
158                    for(FeatureManagerEdgeExplicitLength fm : explicitLengthEdgeFeatureTypes) {
159                            InputSequence componentSeq = fm.getInputComponent() == null ? seq : seq.getComponent(fm.getInputComponent());
160                            fm.evaluateEdgeLength(componentSeq, pos, length, prevState, state, result);
161                            if(!result.isValid())
162                                    break;
163                    }
164            }
165    
166            @Override
167            public CacheStrategySpec getCacheStrategy() {
168                    return new CacheStrategySpec(CacheStrategy.COMPOSITE);
169            }
170    }