001 package calhoun.analysis.crf.features.generic;
002 import java.util.List;
003
004 import org.apache.commons.logging.Log;
005 import org.apache.commons.logging.LogFactory;
006
007 import calhoun.analysis.crf.AbstractFeatureManager;
008 import calhoun.analysis.crf.CacheStrategySpec;
009 import calhoun.analysis.crf.FeatureList;
010 import calhoun.analysis.crf.FeatureManagerEdge;
011 import calhoun.analysis.crf.ModelManager;
012 import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
013 import calhoun.analysis.crf.io.InputSequence;
014 import calhoun.analysis.crf.io.TrainingSequence;
015 import calhoun.util.Assert;
016
017 /** learns transition probabilities from the data and then creates a single feature for all edge transitions.
018 * <p>
019 * <b>Notes:</b>
020 * <ul>
021 * <li> Returns the log probability of the transitions.
022 * <li> Uses a CONSTANT cache strategy since the values returns are independent of position.
023 * <li> When learning, initializes each edge count with a pseudocount of 1.
024 * </ul>
025 */
026 public class WeightedEdges extends AbstractFeatureManager<Object> implements FeatureManagerEdge<Object> {
027 private static final long serialVersionUID = 8477631359065280630L;
028 private static final Log log = LogFactory.getLog(WeightedEdges.class);
029
030 int startIx;
031 ModelManager manager;
032 float[][] transitions;
033
034 @Override
035 public CacheStrategySpec getCacheStrategy() {
036 return new CacheStrategySpec(CacheStrategy.CONSTANT);
037 }
038
039 public String getFeatureName(int featureIndex) {
040 Assert.a(featureIndex == startIx, "Invalid feature index: ", featureIndex, ". Must be ", startIx);
041 return "WeightedEdges";
042 }
043
044 public int getNumFeatures() {
045 // The is a single feature that is the log transition probabilities from the data.
046 return 1;
047 }
048
049 public void evaluateEdge(InputSequence<?> seq, int pos, int prevState, int state, FeatureList result) {
050 result.addFeature(startIx, transitions[prevState][state]);
051 }
052
053 public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<?>> data) {
054 log.debug("Training edges");
055 startIx = startingIndex;
056 manager = modelInfo;
057 int nStates = manager.getNumStates();
058
059 // Count transitions from the training data
060 transitions = new float[nStates][nStates];
061 for (int j=0; j<nStates; j++) {
062 for (int k=0; k<nStates; k++) {
063 transitions[j][k] = (float) 1.0; // pseudocounts
064 }
065 }
066
067 //DoubleMatrix2D transitions = new DenseDoubleMatrix2D(nStates, nStates);
068 for(TrainingSequence<?> seq : data) {
069 // Start at 1 because there is no transition for the first element of the sequence.
070 for(int pos = 1; pos < seq.length(); ++pos) {
071 int start = seq.getY(pos-1);
072 int end = seq.getY(pos);
073 transitions[start][end] += (float) 1.0;
074 }
075 }
076
077 log.debug("The transition logprobabilities are as follows (row is the FROM state and column is the TO state");
078 for (int j=0; j<nStates; j++) {
079 String s = modelInfo.getStateName(j);
080 float rowtotal = (float) 0.0;
081 for (int k=0; k<nStates; k++) {
082 rowtotal += transitions[j][k];
083 }
084 for (int k=0; k<nStates; k++) {
085 transitions[j][k] = (float) Math.log(transitions[j][k] / rowtotal);
086 s = s + "\t" + transitions[j][k];
087 }
088 log.debug(s);
089 }
090 }
091 }