001    package calhoun.analysis.crf.features.supporting.phylogenetic;
002    
003    import java.io.Serializable;
004    import java.util.ArrayList;
005    import java.util.Arrays;
006    import java.util.HashMap;
007    import java.util.List;
008    import java.util.Map;
009    
010    import org.apache.commons.logging.Log;
011    import org.apache.commons.logging.LogFactory;
012    
013    import calhoun.analysis.crf.ModelManager;
014    import calhoun.analysis.crf.io.InputSequence;
015    import calhoun.analysis.crf.io.TrainingSequence;
016    import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
017    import calhoun.seq.KmerHasher;
018    import calhoun.util.Assert;
019    import flanagan.math.Minimisation;
020    import flanagan.math.MinimisationFunction;
021    
022    public class ColumnConditionalLogProbability implements Serializable {
023            private static final long serialVersionUID = 5837909206967310115L;
024    
025            private static final Log log = LogFactory.getLog(ColumnConditionalLogProbability.class);
026    
027            InputSequence<MultipleAlignmentColumn> lastSeq; 
028            int lastPos;
029            double[] vals;  
030            
031            List<int[]> clusters;             // will be an input
032            int[] state2cluster;              // will be derived
033            List<EvolutionaryModel> emodels;  // this gets trained
034            private ModelManager model;
035            static KmerHasher h = new KmerHasher(KmerHasher.ACGTother,1);
036            Map<Integer, double[]> columnCache = new HashMap<Integer, double[]>();
037            private int eModelNum;  // 0 = Kimura80Model
038                                // 1 = HKY85Model
039                    
040            public ColumnConditionalLogProbability(List<int[]> clusters, int eModelNum) {
041                    this.clusters = clusters;
042                    this.eModelNum = eModelNum;
043            }
044            
045            public int numClusters() {
046                    return clusters.size();
047            }
048    
049            public int state2cluster(int state) {
050                    return state2cluster[state];
051            }
052    
053            public double condLogProb(InputSequence<? extends MultipleAlignmentColumn> seq, int pos, int state) {
054                    Assert.a(state < model.getNumStates());
055    
056                    MultipleAlignmentColumn col = seq.getX(pos);
057                    return emodels.get(state2cluster[state]).logprob(col,true);
058                    
059                    /*
060                    if(seq != lastSeq || pos != lastPos) {
061                            lastSeq = seq;
062                            lastPos = pos;
063    
064                            MultipleAlignmentColumn col = seq.getX(pos);
065                            Integer hash = col.getUniqueHash();
066                            if(columnCache.containsKey(hash)) {
067                                    vals = columnCache.get(hash); 
068                            }
069                            else {
070                                    for (int i=0; i<clusters.size(); i++) {
071                                            vals = new double[clusters.size()];
072                                            vals[i] = emodels.get(i).logprob(col,true);
073                                    }
074                                    columnCache.put(hash, vals);
075                            }
076                    }
077                    return vals[state2cluster[state]];
078                    */
079            }
080            
081            public void train( ModelManager modelInfo, final List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data) {
082                    model = modelInfo;
083                    vals = new double[clusters.size()];
084                    final PhylogeneticTreeFelsensteinOrder felsOrder = data.get(0).getX(0).getMultipleAlignment().getFelsensteinOrder();
085    
086                    
087                    // Step 1: combinatorially invert clusters to get state2cluster
088                    int nStates = model.getNumStates();
089                    state2cluster = new int[nStates];
090                    for (int i=0; i<nStates; i++) { state2cluster[i] = -1; }
091    
092                    for (int j=0; j<clusters.size(); j++) {
093                            int[] C = clusters.get(j);
094                            for (int k=0; k<C.length; k++) {
095                                    state2cluster[C[k]] = j;
096                            }
097                    }
098                    
099                    for (int i=0; i<nStates; i++) {
100                            Assert.a(state2cluster[i] >= 0);
101                    }       
102                    
103                    
104                    // Step 2: Train the evolutionary models
105                    emodels = new ArrayList<EvolutionaryModel>();
106                    for (int j=0; j<clusters.size(); j++) {
107                            final int cluster = j;
108    
109                            for (int k=0; k<clusters.get(j).length; k++) {
110                                    log.debug("Training evolutionary model for:  " + model.getStateName(clusters.get(j)[k]));
111                            }
112                            
113                            // RootedBinaryPhylogeneticTree rt = 
114                                                    
115                            //final PhylogeneticTreeFelsensteinOrder T = new PhylogeneticTreeFelsensteinOrder();
116                            
117                            // Estimate pi based on the nucleotide frequencies in the reference sequence
118                            final double[] pi = new double[]{1.0,1.0,1.0,1.0};
119                            for(TrainingSequence<? extends MultipleAlignmentColumn> aln : data) {
120                                    int len = aln.length();
121                                            
122                                    for (int ix=0; ix<len; ix++) {
123                                            int y = aln.getY(ix);
124                                            
125                                            if (state2cluster[y] == cluster) {
126                                                    int x = h.hash(aln.getX(ix).nucleotide(0));
127                                                    if (x<4) { pi[x] += 1.0; }                                                           
128                                            }
129                                    }
130                            }
131                            double total = pi[0] + pi[1] + pi[2] + pi[3];
132                            pi[0]/=total; pi[1]/=total; pi[2]/=total; pi[3]/=total;
133                            
134                            if (eModelNum==0) { // Kimura80Model
135                                    MinimisationFunction mFunc = new MinimisationFunction() {
136                                            public double function(double[] d) {
137                                                    double[] ed = new double[2];
138                                                    ed[0] = Math.exp(d[0]);
139                                                    ed[1] = Math.exp(d[1]);
140                                                    
141                                                    Kimura80Model R = new Kimura80Model(ed);
142                                                    EvolutionaryModel M = new EvolutionaryModel(felsOrder,pi,R);
143                                                    
144                                                    double ret = 0;
145                                                    for(TrainingSequence<? extends MultipleAlignmentColumn> aln : data) {
146                                                            int len = aln.length();                                                 
147                                                            for (int ix=0; ix<len; ix++) {
148                                                                    int y = aln.getY(ix);
149                                                                    if (state2cluster[y] == cluster) {
150                                                                            ret += M.logprob(aln.getX(ix),false);                                                                   
151                                                                    }
152                                                            }
153                                                    }                                       
154                                                    return -ret;
155                                            }
156                                    };                                              
157                                    
158                                    // The standard mantra for minimizing the function mFunc defined above 
159                                    int maxIter = 50;
160                                    final int nParm   = 2;
161                                    Minimisation m = new Minimisation();
162                                    m.setNmax(maxIter);
163                                    double[] starts = new double[nParm];
164                                    Arrays.fill(starts, 0.1);
165                                    double[] steps = new double[nParm];
166                                    Arrays.fill(steps, 0.1);                
167                                    m.nelderMead(mFunc, starts, steps);
168                                    if(!m.getConvStatus()) {
169                                            log.warn("WARNING - Nelder-Mead routine says convergence was not reached");
170                                            // throw new ErrorException("Convergence not reached.");
171                                    }
172                                    double[] results = m.getParamValues();
173                                    double[] eresults = new double[]{Math.exp(results[0]),Math.exp(results[1])};
174                                    
175                                    emodels.add(new EvolutionaryModel(felsOrder,pi,new Kimura80Model(eresults)));           
176                                    emodels.get(cluster).summarize();
177                            } else if (eModelNum==1) { // HKY85Model
178                                    MinimisationFunction mFunc = new MinimisationFunction() {
179                                            public double function(double[] d) {
180                                                    double[] ed = new double[5];
181                                                    ed[0] = Math.exp(d[0]);
182                                                    ed[1] = Math.exp(d[1]);
183                                                    ed[2] = pi[0];
184                                                    ed[3] = pi[1];
185                                                    ed[4] = pi[2];
186                                                    
187                                                    HKY85Model R = new HKY85Model(ed);
188                                                    EvolutionaryModel M = new EvolutionaryModel(felsOrder,pi,R);
189                                                    
190                                                    double ret = 0;
191                                                    for(TrainingSequence<? extends MultipleAlignmentColumn> aln : data) {
192                                                            int len = aln.length();                                                 
193                                                            for (int ix=0; ix<len; ix++) {
194                                                                    int y = aln.getY(ix);
195                                                                    if (state2cluster[y] == cluster) {
196                                                                            ret += M.logprob(aln.getX(ix),false);                                                                   
197                                                                    }
198                                                            }
199                                                    }                                       
200                                                    return -ret;
201                                            }
202                                    };                                              
203                                    
204                                    // The standard mantra for minimizing the function mFunc defined above 
205                                    int maxIter = 50;
206                                    final int nParm   = 2;
207                                    Minimisation m = new Minimisation();
208                                    m.setNmax(maxIter);
209                                    double[] starts = new double[nParm];
210                                    Arrays.fill(starts, 0.1);
211                                    double[] steps = new double[nParm];
212                                    Arrays.fill(steps, 0.1);                
213                                    m.nelderMead(mFunc, starts, steps);
214                                    if(!m.getConvStatus()) {
215                                            log.warn("WARNING - Nelder-Mead routine says convergence was not reached");
216                                            // throw new ErrorException("Convergence not reached.");
217                                    }
218                                    double[] results = m.getParamValues();
219                                    double[] eresults = new double[]{Math.exp(results[0]),Math.exp(results[1]),pi[0],pi[1],pi[2]};
220                                    
221                                    emodels.add(new EvolutionaryModel(felsOrder,pi,new HKY85Model(eresults)));              
222                                    emodels.get(cluster).summarize();
223                            } else { Assert.a(false); }
224                    }
225                    
226                    Assert.a(emodels.size() == clusters.size());
227                    log.debug("Just trained the Felsenstein Features");
228            }
229            
230    }