001 package calhoun.analysis.crf.features.supporting.phylogenetic;
002
003 import java.io.Serializable;
004 import java.util.ArrayList;
005 import java.util.Arrays;
006 import java.util.HashMap;
007 import java.util.List;
008 import java.util.Map;
009
010 import org.apache.commons.logging.Log;
011 import org.apache.commons.logging.LogFactory;
012
013 import calhoun.analysis.crf.ModelManager;
014 import calhoun.analysis.crf.io.InputSequence;
015 import calhoun.analysis.crf.io.TrainingSequence;
016 import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
017 import calhoun.seq.KmerHasher;
018 import calhoun.util.Assert;
019 import flanagan.math.Minimisation;
020 import flanagan.math.MinimisationFunction;
021
022 public class ColumnConditionalLogProbability implements Serializable {
023 private static final long serialVersionUID = 5837909206967310115L;
024
025 private static final Log log = LogFactory.getLog(ColumnConditionalLogProbability.class);
026
027 InputSequence<MultipleAlignmentColumn> lastSeq;
028 int lastPos;
029 double[] vals;
030
031 List<int[]> clusters; // will be an input
032 int[] state2cluster; // will be derived
033 List<EvolutionaryModel> emodels; // this gets trained
034 private ModelManager model;
035 static KmerHasher h = new KmerHasher(KmerHasher.ACGTother,1);
036 Map<Integer, double[]> columnCache = new HashMap<Integer, double[]>();
037 private int eModelNum; // 0 = Kimura80Model
038 // 1 = HKY85Model
039
040 public ColumnConditionalLogProbability(List<int[]> clusters, int eModelNum) {
041 this.clusters = clusters;
042 this.eModelNum = eModelNum;
043 }
044
045 public int numClusters() {
046 return clusters.size();
047 }
048
049 public int state2cluster(int state) {
050 return state2cluster[state];
051 }
052
053 public double condLogProb(InputSequence<? extends MultipleAlignmentColumn> seq, int pos, int state) {
054 Assert.a(state < model.getNumStates());
055
056 MultipleAlignmentColumn col = seq.getX(pos);
057 return emodels.get(state2cluster[state]).logprob(col,true);
058
059 /*
060 if(seq != lastSeq || pos != lastPos) {
061 lastSeq = seq;
062 lastPos = pos;
063
064 MultipleAlignmentColumn col = seq.getX(pos);
065 Integer hash = col.getUniqueHash();
066 if(columnCache.containsKey(hash)) {
067 vals = columnCache.get(hash);
068 }
069 else {
070 for (int i=0; i<clusters.size(); i++) {
071 vals = new double[clusters.size()];
072 vals[i] = emodels.get(i).logprob(col,true);
073 }
074 columnCache.put(hash, vals);
075 }
076 }
077 return vals[state2cluster[state]];
078 */
079 }
080
081 public void train( ModelManager modelInfo, final List<? extends TrainingSequence<? extends MultipleAlignmentColumn>> data) {
082 model = modelInfo;
083 vals = new double[clusters.size()];
084 final PhylogeneticTreeFelsensteinOrder felsOrder = data.get(0).getX(0).getMultipleAlignment().getFelsensteinOrder();
085
086
087 // Step 1: combinatorially invert clusters to get state2cluster
088 int nStates = model.getNumStates();
089 state2cluster = new int[nStates];
090 for (int i=0; i<nStates; i++) { state2cluster[i] = -1; }
091
092 for (int j=0; j<clusters.size(); j++) {
093 int[] C = clusters.get(j);
094 for (int k=0; k<C.length; k++) {
095 state2cluster[C[k]] = j;
096 }
097 }
098
099 for (int i=0; i<nStates; i++) {
100 Assert.a(state2cluster[i] >= 0);
101 }
102
103
104 // Step 2: Train the evolutionary models
105 emodels = new ArrayList<EvolutionaryModel>();
106 for (int j=0; j<clusters.size(); j++) {
107 final int cluster = j;
108
109 for (int k=0; k<clusters.get(j).length; k++) {
110 log.debug("Training evolutionary model for: " + model.getStateName(clusters.get(j)[k]));
111 }
112
113 // RootedBinaryPhylogeneticTree rt =
114
115 //final PhylogeneticTreeFelsensteinOrder T = new PhylogeneticTreeFelsensteinOrder();
116
117 // Estimate pi based on the nucleotide frequencies in the reference sequence
118 final double[] pi = new double[]{1.0,1.0,1.0,1.0};
119 for(TrainingSequence<? extends MultipleAlignmentColumn> aln : data) {
120 int len = aln.length();
121
122 for (int ix=0; ix<len; ix++) {
123 int y = aln.getY(ix);
124
125 if (state2cluster[y] == cluster) {
126 int x = h.hash(aln.getX(ix).nucleotide(0));
127 if (x<4) { pi[x] += 1.0; }
128 }
129 }
130 }
131 double total = pi[0] + pi[1] + pi[2] + pi[3];
132 pi[0]/=total; pi[1]/=total; pi[2]/=total; pi[3]/=total;
133
134 if (eModelNum==0) { // Kimura80Model
135 MinimisationFunction mFunc = new MinimisationFunction() {
136 public double function(double[] d) {
137 double[] ed = new double[2];
138 ed[0] = Math.exp(d[0]);
139 ed[1] = Math.exp(d[1]);
140
141 Kimura80Model R = new Kimura80Model(ed);
142 EvolutionaryModel M = new EvolutionaryModel(felsOrder,pi,R);
143
144 double ret = 0;
145 for(TrainingSequence<? extends MultipleAlignmentColumn> aln : data) {
146 int len = aln.length();
147 for (int ix=0; ix<len; ix++) {
148 int y = aln.getY(ix);
149 if (state2cluster[y] == cluster) {
150 ret += M.logprob(aln.getX(ix),false);
151 }
152 }
153 }
154 return -ret;
155 }
156 };
157
158 // The standard mantra for minimizing the function mFunc defined above
159 int maxIter = 50;
160 final int nParm = 2;
161 Minimisation m = new Minimisation();
162 m.setNmax(maxIter);
163 double[] starts = new double[nParm];
164 Arrays.fill(starts, 0.1);
165 double[] steps = new double[nParm];
166 Arrays.fill(steps, 0.1);
167 m.nelderMead(mFunc, starts, steps);
168 if(!m.getConvStatus()) {
169 log.warn("WARNING - Nelder-Mead routine says convergence was not reached");
170 // throw new ErrorException("Convergence not reached.");
171 }
172 double[] results = m.getParamValues();
173 double[] eresults = new double[]{Math.exp(results[0]),Math.exp(results[1])};
174
175 emodels.add(new EvolutionaryModel(felsOrder,pi,new Kimura80Model(eresults)));
176 emodels.get(cluster).summarize();
177 } else if (eModelNum==1) { // HKY85Model
178 MinimisationFunction mFunc = new MinimisationFunction() {
179 public double function(double[] d) {
180 double[] ed = new double[5];
181 ed[0] = Math.exp(d[0]);
182 ed[1] = Math.exp(d[1]);
183 ed[2] = pi[0];
184 ed[3] = pi[1];
185 ed[4] = pi[2];
186
187 HKY85Model R = new HKY85Model(ed);
188 EvolutionaryModel M = new EvolutionaryModel(felsOrder,pi,R);
189
190 double ret = 0;
191 for(TrainingSequence<? extends MultipleAlignmentColumn> aln : data) {
192 int len = aln.length();
193 for (int ix=0; ix<len; ix++) {
194 int y = aln.getY(ix);
195 if (state2cluster[y] == cluster) {
196 ret += M.logprob(aln.getX(ix),false);
197 }
198 }
199 }
200 return -ret;
201 }
202 };
203
204 // The standard mantra for minimizing the function mFunc defined above
205 int maxIter = 50;
206 final int nParm = 2;
207 Minimisation m = new Minimisation();
208 m.setNmax(maxIter);
209 double[] starts = new double[nParm];
210 Arrays.fill(starts, 0.1);
211 double[] steps = new double[nParm];
212 Arrays.fill(steps, 0.1);
213 m.nelderMead(mFunc, starts, steps);
214 if(!m.getConvStatus()) {
215 log.warn("WARNING - Nelder-Mead routine says convergence was not reached");
216 // throw new ErrorException("Convergence not reached.");
217 }
218 double[] results = m.getParamValues();
219 double[] eresults = new double[]{Math.exp(results[0]),Math.exp(results[1]),pi[0],pi[1],pi[2]};
220
221 emodels.add(new EvolutionaryModel(felsOrder,pi,new HKY85Model(eresults)));
222 emodels.get(cluster).summarize();
223 } else { Assert.a(false); }
224 }
225
226 Assert.a(emodels.size() == clusters.size());
227 log.debug("Just trained the Felsenstein Features");
228 }
229
230 }