001    package calhoun.analysis.crf.features.tricycle13;
002    
003    import java.util.List;
004    
005    import org.apache.commons.logging.Log;
006    import org.apache.commons.logging.LogFactory;
007    
008    import calhoun.analysis.crf.AbstractFeatureManager;
009    import calhoun.analysis.crf.CacheStrategySpec;
010    import calhoun.analysis.crf.FeatureList;
011    import calhoun.analysis.crf.FeatureManagerEdge;
012    import calhoun.analysis.crf.ModelManager;
013    import calhoun.analysis.crf.CacheStrategySpec.CacheStrategy;
014    import calhoun.analysis.crf.features.supporting.MarkovPredictorLogprob;
015    import calhoun.analysis.crf.features.tricycle13.EmissionMarkovFeature.MarkovHistory;
016    import calhoun.analysis.crf.io.InputSequence;
017    import calhoun.analysis.crf.io.TrainingSequence;
018    import calhoun.seq.KmerHasher;
019    import calhoun.util.Assert;
020    
021    public class CodingStopFeature extends AbstractFeatureManager<Character> implements FeatureManagerEdge<Character> {
022            private static final long serialVersionUID = -7659288739348604129L;
023            private static final Log log = LogFactory.getLog(CodingStopFeature.class);
024            boolean debug = log.isDebugEnabled();
025            
026            int startIx;  // The index of the first feature managed by this FeatureManager
027            ModelManager model;
028            
029            KmerHasher   h; // for a single letter
030            
031            float[][] pluslogprob; /* logprob[i][j] is log probability of base j at position i, i=0..(span-1), j=0..3. */
032            float[][] minuslogprob;
033    
034            int[][] pc; // for plus corrections.  Second dimension is of length 2; First entry is where to look relative to the transition; second entry is what the hidden state in that position should be.  
035            int[][] mc; // minus corrections
036    
037            int stateIntergenic;
038            int stateExon3;
039            int stateExon3m;
040    
041            MarkovPredictorLogprob predictorlp;
042            
043            ///////////////////////////////////// Class variables above, methods below //////////
044    
045            
046            public CodingStopFeature() {
047            }
048    
049            public void setHistory(MarkovHistory markovHistory) {
050                    this.predictorlp = new MarkovPredictorLogprob(markovHistory.convert());
051            }
052            
053            public CodingStopFeature( List<int[]> markovhistory ) {
054                    this.predictorlp = new MarkovPredictorLogprob(markovhistory);   
055            }
056    
057    
058            public int getNumFeatures() {
059                    return 1;
060            }       
061            
062            
063            public String getFeatureName(int featureIndex) {
064                    Assert.a(featureIndex == startIx);
065                    return "CodingStopFeature";
066            }
067    
068            
069            
070            public void evaluateEdge(InputSequence<? extends Character> seq, int ix, int prevState, int state, FeatureList result) {
071                    int len = seq.length();
072                    if(prevState == stateExon3 && state == stateIntergenic) {
073                            // first the positive strand
074                            float pval = 0;
075                            if ((ix>=3) && (ix<len-5) ) {
076    
077                                    char a = seq.getX(ix-3);
078                                    char b = seq.getX(ix-2);
079                                    char c = seq.getX(ix-1);
080                                    if ( (a=='T') && (b=='A') && (c=='G') ) {       pval = (float) Math.log(0.33333333); }
081                                    else if ( (a=='T') && (b=='G') && (c=='A') ) { pval = (float) Math.log(0.33333333); }
082                                    else if ( (a=='T') && (b=='A') && (c=='A') ) { pval = (float) Math.log(0.33333333); }
083                                    else { pval = (float) -50.0; }
084                            
085                                    for (int i=0; i<6; i++) {
086                                            char d = seq.getX(ix+i);
087                                            pval += pluslogprob[i][h.hash(d)];
088                                    }
089                                    
090                                    for (int i=0; i<9; i++) {
091                                            pval -= predictorlp.logprob(pc[i][1],seq,ix-pc[i][0]);                                  
092                                    }
093                            }
094                            result.addFeature(startIx, pval);
095                    }
096                    else if(state == stateExon3m && prevState == stateIntergenic) {
097                            // now the negative strand;
098                            float mval = 0;
099                            if ((ix>=6) && (ix<len-2) ) {
100    
101                                    char a = seq.getX(ix+0);
102                                    char b = seq.getX(ix+1);
103                                    char c = seq.getX(ix+2);
104                                    if ( (a=='T') && (b=='T') && (c=='A') ) {       mval = (float) Math.log(0.33333333); }
105                                    else if ( (a=='C') && (b=='T') && (c=='A') ) { mval = (float) Math.log(0.33333333); }
106                                    else if ( (a=='T') && (b=='C') && (c=='A') ) { mval = (float) Math.log(0.33333333); }
107                                    else { mval = (float) -50.0; }
108                            
109                                    for (int i=0; i<6; i++) {
110                                            char d = seq.getX(ix+i-6);
111                                            mval += minuslogprob[i][h.hash(d)];
112                                    }
113                                    
114                                    for (int i=0; i<9; i++) {
115                                            mval -= predictorlp.logprob(mc[i][1],seq,ix-mc[i][0]);                                  
116                                    }
117                            }
118                            result.addFeature(startIx, mval);
119                    }
120            }
121    
122            
123            public void train(int startingIndex, ModelManager modelInfo, List<? extends TrainingSequence<? extends Character>> data) {
124    
125                    startIx = startingIndex;
126                    model = modelInfo;
127    
128                    stateIntergenic = model.getStateIndex("intergenic");
129                    stateExon3 =  model.getStateIndex("exon3");
130                    stateExon3m =  model.getStateIndex("exon3m");
131    
132                    predictorlp.train(data);
133                    h = new KmerHasher(KmerHasher.ACGTN, 1);                
134                    
135                    
136                    /*
137                    # stop signal positive strand TAG|xxxxxx   
138                    L 9 3 exon3 intergenic
139                    # stop signal negative strand xxxxxx|CTA
140                    L 9 6 intergenic exon3m
141    
142                    # stop signal positive strand TAG|xxxxxx  
143                    L exon1 exon2 exon3 intergenic intergenic intergenic intergenic intergenic intergenic
144                    # stop signal negative strand xxxxxx|CTA
145                    L intergenic intergenic intergenic intergenic intergenic intergenic exon3m exon2m exon1m
146                    */
147                    
148                    pc = new int[9][2];
149                    pc[0][0] = -3;  // pc[0][1] = model.getStateIndex("exon1");
150                    pc[1][0] = -2;   pc[1][1] = model.getStateIndex("exon2");
151                    pc[2][0] = -1;   pc[2][1] = model.getStateIndex("exon3");
152                    pc[3][0] =  0;   pc[3][1] = model.getStateIndex("intergenic");
153                    pc[4][0] =  1;   pc[4][1] = model.getStateIndex("intergenic");
154                    pc[5][0] =  2;   pc[5][1] = model.getStateIndex("intergenic");
155                    pc[6][0] =  3;   pc[6][1] = model.getStateIndex("intergenic");
156                    pc[7][0] =  4;   pc[7][1] = model.getStateIndex("intergenic");
157                    pc[8][0] =  5;   pc[8][1] = model.getStateIndex("intergenic");
158    
159                    mc = new int[9][2];
160                    mc[0][0] = -6;   mc[0][1] = model.getStateIndex("intergenic");
161                    mc[1][0] = -5;   mc[1][1] = model.getStateIndex("intergenic");
162                    mc[2][0] = -4;   mc[2][1] = model.getStateIndex("intergenic");
163                    mc[3][0] = -3;   mc[3][1] = model.getStateIndex("intergenic");
164                    mc[4][0] = -2;   mc[4][1] = model.getStateIndex("intergenic");
165                    mc[5][0] = -1;   mc[5][1] = model.getStateIndex("intergenic");
166                    mc[6][0] =  0;   mc[6][1] = model.getStateIndex("exon3m");
167                    mc[7][0] =  1;   mc[7][1] = model.getStateIndex("exon2m");
168                    mc[8][0] =  2;   mc[8][1] = model.getStateIndex("exon1m");
169                            
170                    pluslogprob = new float[6][h.range()];
171                    minuslogprob = new float[6][h.range()];         
172                    
173                    for (int i=0; i<6; i++) {
174                            for (int j=0; j<h.range(); j++) {
175                                    pluslogprob[i][j] = (float) 1.0;
176                                    minuslogprob[i][j] = (float) 1.0;                               
177                            }
178                    }
179                    
180                    
181    
182                    // In English, what I want to do is this.  Loop through all of the training data, once for each Feature.
183                    // While so doing, look for any positions where one of the allowed transitions for that feature occurs.
184                    // At such positions, increment the counts for logprob.
185                    for(TrainingSequence<? extends Character> seq : data) {
186                            int len = seq.length();         
187                            
188                            for (int ix=3; ix<(len-5); ix++) {
189                                    int yprev = seq.getY(ix-1);
190                                    int y = seq.getY(ix);
191                                    if (  (yprev == model.getStateIndex("exon3") ) && (y == model.getStateIndex("intergenic") )  ) {        
192                                            for (int k=0; k<6; k++) {
193                                                    char c = seq.getX(ix + k);
194                                                    pluslogprob[k][h.hash(c)] += 1.0;
195                                            }
196                                    }
197                            }
198                            
199                            for (int ix=6; ix<(len-3); ix++) {
200                                    int yprev = seq.getY(ix-1);
201                                    int y = seq.getY(ix);
202                                    if (  (yprev == model.getStateIndex("intergenic") ) && (y == model.getStateIndex("exon3m") )  ) {       
203                                            for (int k=0; k<6; k++) {
204                                                    char c = seq.getX(ix + k - 6);
205                                                    minuslogprob[k][h.hash(c)] += 1.0;
206                                            }
207                                    }
208                            }       
209                                    
210                    }
211    
212                    for (int k=0; k<6; k++) {
213                            float totalp = 0;
214                            for (int j=0; j<h.range(); j++) { totalp += pluslogprob[k][j]; }
215                            Assert.a(totalp > 0);
216                            for (int j=0; j<h.range(); j++) { 
217                                    pluslogprob[k][j] = (float) (Math.log(pluslogprob[k][j]) - Math.log(totalp));
218                            }
219                    }
220                    
221                    for (int k=0; k<6; k++) {                    
222                            float totalm = 0;
223                            for (int j=0; j<h.range(); j++) { totalm += minuslogprob[k][j]; }
224                            Assert.a(totalm > 0);
225                            for (int j=0; j<h.range(); j++) { 
226                                    minuslogprob[k][j] = (float) (Math.log(minuslogprob[k][j]) - Math.log(totalm));
227                            }       
228                    }               
229            }
230            @Override
231            public CacheStrategySpec getCacheStrategy() {
232                    return new CacheStrategySpec(CacheStrategy.UNSPECIFIED);
233            }
234            
235    }
236