001    package calhoun.analysis.crf.features.tricycle13;
002    import java.util.ArrayList;
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.AbstractFeatureManager;
010    import calhoun.analysis.crf.FeatureList;
011    import calhoun.analysis.crf.FeatureManagerNodeExplicitLength;
012    import calhoun.analysis.crf.ModelManager;
013    import calhoun.analysis.crf.io.InputSequence;
014    import calhoun.analysis.crf.io.TrainingSequence;
015    import calhoun.analysis.crf.statistics.GammaDistribution;
016    import calhoun.util.Assert;
017    
018    /**  Learns from data an explicit length distribution for introns, and also subtracts out
019     * what it would have gotten from just the exponential distribution given by
020     * DenseWeightedEdgeFeatures.
021     */
022    public class IntronLengthFeature extends AbstractFeatureManager implements  FeatureManagerNodeExplicitLength {
023            private static final long serialVersionUID = 8477631359065280630L;
024            private static final Log log = LogFactory.getLog(IntronLengthFeature.class);
025            boolean debug = log.isDebugEnabled();
026            
027            // Standard boring member variables
028            int startIx;
029    
030            // Fundamental member variables
031            double[] logProbExtend;       // given that in state i, what is log of probability of remaining in state i for another base?
032            boolean[] explicitLengthFlag; // true if this is a state we wish to model explictly, otherwise false
033            private boolean[] isIntron;     
034                    
035            public String getFeatureName(int featureIndex) {
036                    Assert.a(featureIndex == startIx, "Invalid feature index");
037                    return "IntronLengthFeature";
038            }
039    
040            public int getNumFeatures() { 
041                    return 1;
042            }
043    
044            public void evaluateNodeLength(InputSequence seq, int pos, int length, int state, FeatureList result) {
045                    if(explicitLengthFlag[state]) {
046                            // Gaussian with mean mu=69 and stddev sigma=15 is (parms a guess from experience) is
047                            //  [1/(sigma*sqrt(2*pi))]  * exp[ - (x-mu)^2 / (2*sigma^2) ]           
048                            //double val = -3.6269887 - (length-69.0)*(length-69.0)/450.0 ;
049                            
050                            // Correct for what probably we would have gotten from the exponential state length.
051                            //val -= (length-1)*logProbExtend[state];                               
052                            
053                            // Mixture of two gamma distributions with the following parameters:
054                            // p = 0.86 = probability from distribution 1:
055                            // dist1 is gamma distribution with shape 71 and lambda=1.27, so that mean=56
056                            // dist2 is gamma distribution with shape 4.1 and lambda=0.041, so that mean=101
057                            // [eventually this will be trained, but for now hardcoded.
058                            double pdist1 = 0.86;
059                            double shape1 = 71;
060                            double lambda1 = 1.27;
061                            double shape2 = 4.1;
062                            double lambda2 = 0.041;
063                            double val = pdist1*(GammaDistribution.gamma(shape1,lambda1,length));
064                            val += (1-pdist1)*(GammaDistribution.gamma(shape2,lambda2,length));
065                            val = Math.log(val);
066                            
067                            // Correct for what probably we would have gotten from the exponential state length.
068                            val -= (length-1)*logProbExtend[state];                         
069    
070                            result.addFeature(startIx, val);
071                    }
072            }
073    
074            public void train(int startingIndex, ModelManager modelInfo, List data) {
075                    log.debug("Training the Feature for IntronLengths, explicitly modelled as a Gaussian");
076                    
077                    startIx = startingIndex;
078                    int nStates = modelInfo.getNumStates();
079                    
080                    // Say here which states modelled explicitly; the remainder are not
081                    explicitLengthFlag = new boolean[nStates];
082                    Arrays.fill(explicitLengthFlag, false);
083                    explicitLengthFlag[modelInfo.getStateIndex("intron1")]  = true;
084                    explicitLengthFlag[modelInfo.getStateIndex("intron2")]  = true;
085                    explicitLengthFlag[modelInfo.getStateIndex("intron3")]  = true;
086                    explicitLengthFlag[modelInfo.getStateIndex("intron1m")] = true;
087                    explicitLengthFlag[modelInfo.getStateIndex("intron2m")] = true;
088                    explicitLengthFlag[modelInfo.getStateIndex("intron3m")] = true;
089                    isIntron = explicitLengthFlag.clone();
090                    
091                    
092                    // Count transitions from the training data
093                    float[][] transitions = new float[nStates][nStates];
094                    for (int j=0; j<nStates; j++) {
095                            for (int k=0; k<nStates; k++) {
096                                    transitions[j][k] = (float) 1.0; // pseudocounts
097                            }
098                    }
099                                    
100                    //DoubleMatrix2D transitions  = new DenseDoubleMatrix2D(nStates, nStates);
101                    for(TrainingSequence seq : (List<TrainingSequence>) data) {
102                            // Start at 1 because there is no transition for the first element of the sequence.
103                            for(int pos = 1; pos < seq.length(); ++pos) { 
104                                    int start = seq.getY(pos-1);
105                                    int end = seq.getY(pos);
106                                    transitions[start][end] += (float) 1.0; 
107                            }
108                    }
109    
110                    logProbExtend = new double[nStates];
111                    for (int j=0; j<nStates; j++) {
112                            float rowtotal = (float) 0.0;
113                            for (int k=0; k<nStates; k++) {
114                                    rowtotal += transitions[j][k];
115                            }
116                            logProbExtend[j] = (float) Math.log(transitions[j][j] / rowtotal);
117                    }
118                    
119                    log.debug("logprobextend for the variuos states are:");
120                    for (int j=0; j<modelInfo.getNumStates(); j++) {
121                            log.debug("  " + modelInfo.getStateName(j) + "   " + logProbExtend[j]);
122                    }
123                    
124                    ArrayList intronLengths = new ArrayList();
125                    //= new ArrayList<integer>;
126                    for(TrainingSequence seq : (List<TrainingSequence>) data) {
127                            int lastIntronStart = -1;
128                            int y = seq.getY(0);
129                            for(int pos = 1; pos < seq.length(); ++pos) { 
130                                    int yprev = y;
131                                    y = seq.getY(pos);
132                                    
133                                    if (isIntron[yprev] && (!isIntron[y]) && (lastIntronStart>=-1)) {
134                                            intronLengths.add(pos-lastIntronStart);
135                                            lastIntronStart = -1;
136                                    }
137                                    if ((!isIntron[yprev]) && (isIntron[y])) {
138                                            lastIntronStart = pos;
139                                    }
140                            }
141                    }
142                    
143                    log.debug("The intron lengths are:");
144                    for (int j=0; j<intronLengths.size(); j++) {
145                            log.debug(intronLengths.get(j) + ",");
146                    }
147            }
148    }