001    package calhoun.analysis.crf.features.interval29;
002    
003    import java.util.ArrayList;
004    import java.util.List;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.CacheStrategySpec;
010    import calhoun.analysis.crf.FeatureList;
011    import calhoun.analysis.crf.FeatureManagerNodeExplicitLength;
012    import calhoun.analysis.crf.ModelManager;
013    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
014    import calhoun.analysis.crf.io.InputSequence;
015    import calhoun.analysis.crf.io.SequenceConverter;
016    import calhoun.analysis.crf.io.TrainingSequence;
017    import calhoun.analysis.crf.statistics.MixtureOfGammas;
018    import calhoun.util.Assert;
019    
020    public class StateLengthLogprobInterval29 implements FeatureManagerNodeExplicitLength<Character> {        
021            private static final long serialVersionUID = 8685543199212865835L;
022    
023            /* Returns the log probability of a state having the specified duration
024             * returned value depends only on state and length
025             * 
026             * trained using mixture of gammas model, with a few tweaks:
027             *   a) if all values identical (eg always have exactly 200 intergenic bases in training examples), then model as exponential length
028             *   b) Otherwise introduce priors at 90 and 110% of median and at 95% and 105% of median, and train mixture of gammas using EM.
029             *      This prevents convergence to a case where one of the two components is a single data point and likelihood is infinite.
030             *     
031             * Some notes on normalization: (see also mixGamma):
032             *    If you have a given state and evaluate these probabilities (before taking logs) and add up over all possible lengths,
033             *    the result should be approximately one, but not exactly.  Two reasons for discrepancy:
034             *    a) mixGamma, integrated from to to infinity, should add up to 1.0.  But summing at 1,2,3,4,... is a discrete
035             *      approximation and might not agree exactly.
036             *    b) The summation isn't from 1,2,3,4,... but from minLength[state] to maxLength[state].  However, this feature does not
037             *      have access to that min/max length information, so normalization for this reason must happen downstream.
038             */
039            private static final Log log = LogFactory.getLog(StateLengthLogprobInterval29.class);
040            
041            private int startIx;
042            ModelManager mi;
043            private String inputComponentName;
044    
045            MixtureOfGammas intergenicMixGamma;
046            MixtureOfGammas exonMixGamma;
047            MixtureOfGammas intronMixGamma;
048            
049            private boolean forceExponential = false;
050            private boolean exonExponential = false;
051            private boolean multipleFeatures = true;
052    
053            public void setForceExponential(boolean forceExponential) {
054                    this.forceExponential = forceExponential;
055            }
056            
057            public void setExonExponential(boolean exonExponential) {
058                    this.exonExponential = exonExponential;
059            }
060            
061            public void evaluateNodeLength(InputSequence<? extends Character> seq, int pos, int length, int state, FeatureList result) {
062                    Assert.a(length>0);
063                    MixtureOfGammas mg = null;
064                    int indexOffset = Integer.MIN_VALUE;
065                    switch (state) {
066                    case(0):
067                            indexOffset = 0;
068                            mg = intergenicMixGamma;
069                    break;
070                    case(1):
071                    case(2):
072                    case(3):
073                    case(7):
074                    case(8):
075                    case(9):
076                            indexOffset = 1;
077                            mg = exonMixGamma;
078                    break;
079                    case(4):
080                    case(5):
081                    case(6):
082                    case(10):
083                    case(11):
084                    case(12):
085                            indexOffset = 2;
086                            mg = intronMixGamma;
087                    break;
088                    case(13):
089                    case(14):
090                    case(15):
091                    case(16):
092                    case(17):
093                    case(18):
094                    case(19):
095                    case(20):
096                    case(21):
097                    case(22):
098                    case(23):
099                    case(24):
100                    case(25):
101                    case(26):
102                    case(27):
103                    case(28):                       
104                            indexOffset = 3;
105                            break;
106                    default:
107                            Assert.a(false);
108                    }
109                    double val;
110                    if (mg != null) {
111                            val = mg.logEvaluate((double) length);
112                    } else {
113                            Assert.a(state >= 13 && state <= 28, "Trying to use exact state lengths for non-boundary state.");
114                            // XXX: Should we be doing this?
115                            if (length == 2) {
116                                    val = -0.5;
117                            } else {
118                                    val = -200;
119                            }
120                    }
121                    Assert.a((val != Double.NEGATIVE_INFINITY) && (val != Double.POSITIVE_INFINITY) && (!Double.isNaN(val)));
122                    if (val>0) {
123                            log.warn("About to return a state length logprob evaluation that is greater than zero, see notes in source code.");
124                    }
125                    result.addFeature(startIx + (multipleFeatures ? indexOffset : 0),val);
126            }
127    
128            public CacheStrategySpec getCacheStrategy() {
129                    return new CacheStrategySpec(CacheStrategy.LENGTHFUNCTION);
130            }
131    
132            public String getFeatureName(int featureIndex) {
133                    if(multipleFeatures) {
134                            String type = "";
135                            switch(featureIndex - startIx) {
136                            case 0:
137                                    type = "intergenic";
138                                    break;
139                            case 1:
140                                    type = "exon";
141                                    break;
142                            case 2:
143                                    type = "intron";
144                                    break;
145                            case 3:
146                                    type = "boundary";
147                                    break;                          
148                            }
149                            return "StateDurationLogProbForModelInterval21 "+type;
150                    }
151                    else {
152                            return "StateDurationLogProbForModelInterval21";
153                    }
154            }
155    
156            public String getInputComponent() {
157                    return inputComponentName;
158            }
159    
160            public void setInputComponent(String name) {
161                    inputComponentName = name;              
162            }
163    
164            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
165                    startIx = startingIndex;
166                    mi = modelInfo;
167                    Assert.a(mi.getNumStates()==29);
168                    
169                    ArrayList<ArrayList<Integer>> stateDurations;
170                    
171                    stateDurations = SequenceConverter.stateVector2StateLengths(data,mi.getNumStates());
172                    //stateDurations = SequenceConverter.addInInterval29Boundaries(stateDurations);
173                    List<Integer> exonLengths = new ArrayList<Integer>();
174                    List<Integer> intronLengths = new ArrayList<Integer>();
175                    List<Integer> intergenicLengths = new ArrayList<Integer>();
176                    
177                    intergenicLengths.addAll(stateDurations.get(0));
178                    
179                    exonLengths.addAll(stateDurations.get(1));
180                    exonLengths.addAll(stateDurations.get(2));
181                    exonLengths.addAll(stateDurations.get(3));
182                    exonLengths.addAll(stateDurations.get(7));
183                    exonLengths.addAll(stateDurations.get(8));
184                    exonLengths.addAll(stateDurations.get(9));
185                    
186                    intronLengths.addAll(stateDurations.get(4));
187                    intronLengths.addAll(stateDurations.get(5));
188                    intronLengths.addAll(stateDurations.get(6));
189                    intronLengths.addAll(stateDurations.get(10));
190                    intronLengths.addAll(stateDurations.get(11));
191                    intronLengths.addAll(stateDurations.get(12));
192                    
193                    double[] inter = new double[intergenicLengths.size()];
194                    for (int j=0; j<intergenicLengths.size(); j++) {
195                            inter[j] = (double) intergenicLengths.get(j);
196                    }
197                    
198                    double[] exon = new double[exonLengths.size()];
199                    for (int j=0; j<exonLengths.size(); j++) {
200                            exon[j] = (double) exonLengths.get(j);
201                    }
202                    
203                    double[] intron = new double[intronLengths.size()];
204                    for (int j=0; j<intronLengths.size(); j++) {
205                            intron[j] = (double) intronLengths.get(j);
206                    }
207    
208                    if (forceExponential) {
209                            intergenicMixGamma = new MixtureOfGammas(inter,true);
210                            exonMixGamma       = new MixtureOfGammas(exon,true);
211                            intronMixGamma     = new MixtureOfGammas(intron,true);
212                    } else if (exonExponential) {
213                            intergenicMixGamma = new MixtureOfGammas(inter,true);  
214                            exonMixGamma       = new MixtureOfGammas(exon,true);
215                            intronMixGamma     = new MixtureOfGammas(intron);                               
216                    } else {
217                            // by default, only intergenic regions modeled with exp length distributions
218                            intergenicMixGamma = new MixtureOfGammas(inter,true);  
219                            exonMixGamma       = new MixtureOfGammas(exon);
220                            intronMixGamma     = new MixtureOfGammas(intron);                       
221                    }
222            }
223    
224            public int getNumFeatures() {
225                    return multipleFeatures ? 4 : 1;
226            }
227    
228            /** 
229             * @return Returns the multipleFeatures.
230             */
231            public boolean isMultipleFeatures() {
232                    return multipleFeatures;
233            }
234    
235            /** set to true to indicate that intergenic, exonic, and intergenic lengths should each get a separate weight.
236             * @param multipleFeatures The multipleFeatures to set.
237             */
238            public void setMultipleFeatures(boolean multipleFeatures) {
239                    this.multipleFeatures = multipleFeatures;
240            }
241    }