001    package calhoun.analysis.crf.features.generic;
002    import java.util.List;
003    
004    import org.apache.commons.logging.Log;
005    import org.apache.commons.logging.LogFactory;
006    
007    import calhoun.analysis.crf.AbstractFeatureManager;
008    import calhoun.analysis.crf.CacheStrategySpec;
009    import calhoun.analysis.crf.FeatureList;
010    import calhoun.analysis.crf.FeatureManagerEdge;
011    import calhoun.analysis.crf.ModelManager;
012    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
013    import calhoun.analysis.crf.io.InputSequence;
014    import calhoun.analysis.crf.io.TrainingSequence;
015    import calhoun.util.Assert;
016    
017    /** learns transition probabilities from the data and then creates a single feature for all edge transitions out of a given state.
018     * <p>
019     * <b>Notes:</b>
020     * <ul>
021     * <li> This differs from WeightedEdgeFeatures because transitions that stay in the same state are not included.  This is useful for explicit length models when
022     * transitions that stay in the same state are modeled separately.   
023     * <li> Returns the log probability of the transitions.  
024     * <li> Uses a CONSTANT cache strategy since the values returns are independent of position.
025     * <li> When learning, initializes each edge count with a pseudocount of 1.
026     * </ul>
027     */
028    public class WeightedStateChanges extends AbstractFeatureManager<Object> implements FeatureManagerEdge<Object> {
029            private static final long serialVersionUID = 8477631359065280630L;
030            private static final Log log = LogFactory.getLog(WeightedStateChanges.class);
031            boolean debug = log.isDebugEnabled();
032            
033            int startIx;
034            ModelManager manager;
035            float[][] transitions;
036    
037            @Override
038            public CacheStrategySpec getCacheStrategy() {
039                    return new CacheStrategySpec(CacheStrategy.CONSTANT);
040            }
041    
042            public String getFeatureName(int featureIndex) {
043                    Assert.a(featureIndex == startIx, "Invalid feature index: ", featureIndex, ". Must be ", startIx);
044                    return "WeightedEdges";
045            }
046    
047            public int getNumFeatures() {
048                    // The is a single feature that is the log transition probabilities from the data. 
049                    return 1;
050            }
051    
052            public void evaluateEdge(InputSequence<?> seq, int pos, int prevState, int state, FeatureList result) {
053                    if (prevState != state) {
054                            result.addFeature(startIx, transitions[prevState][state]);
055                    }
056            }
057    
058            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<?>> data) {
059                    startIx = startingIndex;
060                    manager = modelInfo;
061                    int nStates = manager.getNumStates();
062                    
063                    // Count transitions from the training data
064                    transitions = new float[nStates][nStates];
065                    for (int j=0; j<nStates; j++) {
066                            for (int k=0; k<nStates; k++) {
067                                    if (j!=k) {
068                                    transitions[j][k] = (float) 1.0; // pseudocounts
069                                    } else {
070                                            transitions[j][k] = (float) 0.0;
071                                    }
072                            }
073                    }
074                    
075                    //DoubleMatrix2D transitions  = new DenseDoubleMatrix2D(nStates, nStates);
076                    for(TrainingSequence<?> seq : data) {
077                            // Start at 1 because there is no transition for the first element of the sequence.
078                            for(int pos = 1; pos < seq.length(); ++pos) {
079                                    int start = seq.getY(pos-1);
080                                    int end = seq.getY(pos);
081                                    if (start != end) {
082                                            transitions[start][end] += (float) 1.0;
083                                    }
084                            }
085                    }
086    
087                    for (int j=0; j<nStates; j++) {
088                            float rowtotal = (float) 0.0;
089                            for (int k=0; k<nStates; k++) {
090                                    rowtotal += transitions[j][k];
091                            }
092                            for (int k=0; k<nStates; k++) {
093                                    if (j!= k) {
094                                            transitions[j][k] = (float) Math.log(transitions[j][k] / rowtotal);
095                                    }
096                            }
097                    }
098            }
099    }