001    package calhoun.analysis.crf.features.generic;
002    import java.util.List;
003    
004    import org.apache.commons.logging.Log;
005    import org.apache.commons.logging.LogFactory;
006    
007    import calhoun.analysis.crf.AbstractFeatureManager;
008    import calhoun.analysis.crf.CacheStrategySpec;
009    import calhoun.analysis.crf.FeatureList;
010    import calhoun.analysis.crf.FeatureManagerEdge;
011    import calhoun.analysis.crf.ModelManager;
012    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
013    import calhoun.analysis.crf.io.InputSequence;
014    import calhoun.analysis.crf.io.TrainingSequence;
015    import calhoun.util.Assert;
016    
017    /** learns transition probabilities from the data and then creates a single feature for all edge transitions.
018     * <p>
019     * <b>Notes:</b>
020     * <ul>
021     * <li> Returns the log probability of the transitions.  
022     * <li> Uses a CONSTANT cache strategy since the values returns are independent of position.
023     * <li> When learning, initializes each edge count with a pseudocount of 1.
024     * </ul>
025     */
026    public class WeightedEdges extends AbstractFeatureManager<Object> implements FeatureManagerEdge<Object> {
027            private static final long serialVersionUID = 8477631359065280630L;
028            private static final Log log = LogFactory.getLog(WeightedEdges.class);
029            
030            int startIx;
031            ModelManager manager;
032            float[][] transitions;
033    
034            @Override
035            public CacheStrategySpec getCacheStrategy() {
036                    return new CacheStrategySpec(CacheStrategy.CONSTANT);
037            }
038    
039            public String getFeatureName(int featureIndex) {
040                    Assert.a(featureIndex == startIx, "Invalid feature index: ", featureIndex, ". Must be ", startIx);
041                    return "WeightedEdges";
042            }
043    
044            public int getNumFeatures() {
045                    // The is a single feature that is the log transition probabilities from the data. 
046                    return 1;
047            }
048    
049            public void evaluateEdge(InputSequence<?> seq, int pos, int prevState, int state, FeatureList result) {
050                    result.addFeature(startIx, transitions[prevState][state]);
051            }
052    
053            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<?>> data) {
054                    log.debug("Training edges");
055                    startIx = startingIndex;
056                    manager = modelInfo;
057                    int nStates = manager.getNumStates();
058                    
059                    // Count transitions from the training data
060                    transitions = new float[nStates][nStates];
061                    for (int j=0; j<nStates; j++) {
062                            for (int k=0; k<nStates; k++) {
063                                    transitions[j][k] = (float) 1.0; // pseudocounts
064                            }
065                    }
066                    
067                    //DoubleMatrix2D transitions  = new DenseDoubleMatrix2D(nStates, nStates);
068                    for(TrainingSequence<?> seq : data) {
069                            // Start at 1 because there is no transition for the first element of the sequence.
070                            for(int pos = 1; pos < seq.length(); ++pos) {
071                                    int start = seq.getY(pos-1);
072                                    int end = seq.getY(pos);
073                                    transitions[start][end] += (float) 1.0; 
074                            }
075                    }
076    
077                    log.debug("The transition logprobabilities are as follows (row is the FROM state and column is the TO state");
078                    for (int j=0; j<nStates; j++) {
079                            String s = modelInfo.getStateName(j);
080                            float rowtotal = (float) 0.0;
081                            for (int k=0; k<nStates; k++) {
082                                    rowtotal += transitions[j][k];
083                            }
084                            for (int k=0; k<nStates; k++) {
085                                    transitions[j][k] = (float) Math.log(transitions[j][k] / rowtotal);
086                                    s = s + "\t" + transitions[j][k];
087                            }
088                            log.debug(s);
089                    }
090            }
091    }