001 package calhoun.analysis.crf.features.generic;
002 import java.util.ArrayList;
003 import java.util.List;
004
005 import org.apache.commons.logging.Log;
006 import org.apache.commons.logging.LogFactory;
007
008 import calhoun.analysis.crf.AbstractFeatureManager;
009 import calhoun.analysis.crf.FeatureList;
010 import calhoun.analysis.crf.FeatureManagerEdge;
011 import calhoun.analysis.crf.ModelManager;
012 import calhoun.analysis.crf.io.InputSequence;
013 import calhoun.util.Assert;
014 import calhoun.util.DenseBooleanMatrix2D;
015 import calhoun.util.DenseIntMatrix2D;
016
017 /** a set of indicator features for the valid transitions in the model.
018 * <p>
019 * <b>Notes:</b>
020 * <ul>
021 * <li> Each indicator returns 1 when evaluated on it's transition, otherwise 0.
022 * <li> Exactly one feature will return a non-zero value for every valid edge transition.
023 * </ul>
024 */
025 public class IndicatorEdges extends AbstractFeatureManager implements FeatureManagerEdge {
026 private static final long serialVersionUID = -2633500053477439285L;
027 private static final Log log = LogFactory.getLog(IndicatorEdges.class);
028 boolean debug = log.isDebugEnabled();
029
030 int startIx;
031 DenseIntMatrix2D transitions;
032 List<String> names;
033
034 public String getFeatureName(int featureIndex) {
035 Assert.a(featureIndex - startIx < names.size(), "Invalid feature index");
036 return names.get(featureIndex - startIx);
037 }
038
039 public int getNumFeatures() {
040 return names.size();
041 }
042
043 public void evaluateEdge(InputSequence seq, int pos, int prevState, int state, FeatureList result) {
044 int index = transitions.getQuick(prevState, state);
045 if(index != -1) {
046 result.addFeature(index, 1);
047 }
048 }
049
050 /** Edge features don't train based on the data. Just set up based on the model. */
051 public void train(int startingIndex, ModelManager modelInfo, List data) {
052 startIx = startingIndex;
053 int nStates = modelInfo.getNumStates();
054 transitions = new DenseIntMatrix2D(nStates, nStates);
055 transitions.assign(-1);
056 names = new ArrayList<String>();
057
058 int n = 0;
059 DenseBooleanMatrix2D trans = modelInfo.getLegalTransitions();
060 for(int i=0; i<nStates; ++i) {
061 for(int j=0; j<nStates; ++j) {
062 if(trans.getQuick(i, j)) {
063 transitions.setQuick(i, j, startIx + n);
064 names.add("Edge."+modelInfo.getStateName(i)+"-"+modelInfo.getStateName(j));
065 ++n;
066 }
067 }
068 }
069 }
070 }