001 package calhoun.analysis.crf.features.generic;
002 import java.util.List;
003
004 import org.apache.commons.logging.Log;
005 import org.apache.commons.logging.LogFactory;
006
007 import calhoun.analysis.crf.AbstractFeatureManager;
008 import calhoun.analysis.crf.CacheStrategySpec;
009 import calhoun.analysis.crf.FeatureList;
010 import calhoun.analysis.crf.FeatureManagerEdge;
011 import calhoun.analysis.crf.ModelManager;
012 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
013 import calhoun.analysis.crf.io.InputSequence;
014 import calhoun.analysis.crf.io.TrainingSequence;
015 import calhoun.util.Assert;
016
017 /** learns transition probabilities from the data and then creates a single feature for all edge transitions out of a given state.
018 * <p>
019 * <b>Notes:</b>
020 * <ul>
021 * <li> This differs from WeightedEdgeFeatures because transitions that stay in the same state are not included. This is useful for explicit length models when
022 * transitions that stay in the same state are modeled separately.
023 * <li> Returns the log probability of the transitions.
024 * <li> Uses a CONSTANT cache strategy since the values returns are independent of position.
025 * <li> When learning, initializes each edge count with a pseudocount of 1.
026 * </ul>
027 */
028 public class WeightedStateChanges extends AbstractFeatureManager<Object> implements FeatureManagerEdge<Object> {
029 private static final long serialVersionUID = 8477631359065280630L;
030 private static final Log log = LogFactory.getLog(WeightedStateChanges.class);
031 boolean debug = log.isDebugEnabled();
032
033 int startIx;
034 ModelManager manager;
035 float[][] transitions;
036
037 @Override
038 public CacheStrategySpec getCacheStrategy() {
039 return new CacheStrategySpec(CacheStrategy.CONSTANT);
040 }
041
042 public String getFeatureName(int featureIndex) {
043 Assert.a(featureIndex == startIx, "Invalid feature index: ", featureIndex, ". Must be ", startIx);
044 return "WeightedEdges";
045 }
046
047 public int getNumFeatures() {
048 // The is a single feature that is the log transition probabilities from the data.
049 return 1;
050 }
051
052 public void evaluateEdge(InputSequence<?> seq, int pos, int prevState, int state, FeatureList result) {
053 if (prevState != state) {
054 result.addFeature(startIx, transitions[prevState][state]);
055 }
056 }
057
058 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<?>> data) {
059 startIx = startingIndex;
060 manager = modelInfo;
061 int nStates = manager.getNumStates();
062
063 // Count transitions from the training data
064 transitions = new float[nStates][nStates];
065 for (int j=0; j<nStates; j++) {
066 for (int k=0; k<nStates; k++) {
067 if (j!=k) {
068 transitions[j][k] = (float) 1.0; // pseudocounts
069 } else {
070 transitions[j][k] = (float) 0.0;
071 }
072 }
073 }
074
075 //DoubleMatrix2D transitions = new DenseDoubleMatrix2D(nStates, nStates);
076 for(TrainingSequence<?> seq : data) {
077 // Start at 1 because there is no transition for the first element of the sequence.
078 for(int pos = 1; pos < seq.length(); ++pos) {
079 int start = seq.getY(pos-1);
080 int end = seq.getY(pos);
081 if (start != end) {
082 transitions[start][end] += (float) 1.0;
083 }
084 }
085 }
086
087 for (int j=0; j<nStates; j++) {
088 float rowtotal = (float) 0.0;
089 for (int k=0; k<nStates; k++) {
090 rowtotal += transitions[j][k];
091 }
092 for (int k=0; k<nStates; k++) {
093 if (j!= k) {
094 transitions[j][k] = (float) Math.log(transitions[j][k] / rowtotal);
095 }
096 }
097 }
098 }
099 }