001 package calhoun.analysis.crf.statistics;
002
003 import java.io.PrintStream;
004 import java.io.Serializable;
005 import java.util.Arrays;
006
007 import org.apache.commons.logging.Log;
008 import org.apache.commons.logging.LogFactory;
009
010 import calhoun.util.Assert;
011 import calhoun.util.ColtUtil;
012
013 public class MixtureOfGammas implements Serializable {
014 private static final long serialVersionUID = 6269582005717276381L;
015 private static final Log log = LogFactory.getLog(MixtureOfGammas.class);
016
017 private double pdist1;
018 private double shape1;
019 private double lambda1;
020 private double shape2;
021 private double lambda2;
022
023 public MixtureOfGammas(double pdist1, double shape1, double lambda1, double shape2, double lambda2) {
024 Assert.a(pdist1>=0);
025 Assert.a(pdist1<=1.0);
026 Assert.a(shape1>0);
027 Assert.a(lambda1>0);
028 Assert.a(shape2>0);
029 Assert.a(lambda2>0);
030
031 this.pdist1 = pdist1;
032 this.shape1 = shape1;
033 this.lambda1 = lambda1;
034 this.shape2 = shape2;
035 this.lambda2 = lambda2;
036 }
037
038
039
040 public MixtureOfGammas(double[] lengths) {
041 setup(lengths,false);
042 log.info(summary());
043 }
044
045 public MixtureOfGammas(double[] lengths, boolean forceExponentialLength) {
046 setup(lengths,forceExponentialLength);
047 log.info(summary());
048 }
049
050 public void setup(double[] lengths, boolean forceExponentialLength) {
051 int nLengths = lengths.length;
052 for (int j=0; j<nLengths; j++) {
053 Assert.a(lengths[j]>0);
054 }
055
056 Assert.a(nLengths>=0);
057
058 boolean exponentialDistribution = false;
059
060 if (nLengths==0) {
061 log.warn("Train mixture gamma called with no length inputs; returning an exponential distn with mean 100");
062 double mean = 100;
063 pdist1 = 1.0;
064 shape1 = 1.0;
065 lambda1 = 1/mean;
066 shape2 = 1.0;
067 lambda2 = 1.0;
068
069 return;
070 }
071
072 double[] sortedLengths = lengths;
073 Arrays.sort(sortedLengths);
074 if(log.isDebugEnabled())
075 log.debug("Mixture of Gammas trainer called to model these lengths: "+ColtUtil.format(sortedLengths));
076
077
078 if (forceExponentialLength) {
079 log.warn("you called a mixture of gammas model but set flag to force it to model as an exponential length distribution.");
080 exponentialDistribution = true;
081 }
082
083 if (nLengths < 20) {
084 log.warn("fewer than 20 lengths supplied for training; modeling with an exponential distribution instead of a mixture of Gammas");
085 exponentialDistribution = true;
086 }
087
088 boolean allSame = true;
089 double firstLength = lengths[0];
090 for (int j=0; j<lengths.length; j++) {
091 if ((lengths[j]<0.99*firstLength) || (lengths[j]>1.01*firstLength)) {
092 allSame = false;
093 }
094 }
095 if (allSame) {
096 log.warn("All the lengths we're asked to model are extremely close to the same value, suggesting the length is artifically fixed for some reason. This is probably a mistake and some othe model of length is appropriate. Will model as exponential length distribution.");
097 exponentialDistribution = true;
098 }
099
100 if (exponentialDistribution) {
101 double mean = BasicStats.meanDoubleArray(lengths);
102 pdist1 = 1.0;
103 shape1 = 1.0;
104 lambda1 = 1/mean;
105 shape2 = 1.0;
106 lambda2 = 1.0;
107
108 return;
109 }
110
111 /* Think through in english and equations before trying to write out code
112 *
113 * The hidden variables Y will be which of the two mixtures a particular point comes from.
114 *
115 * The observed value X will be the observed intron length.
116 *
117 * The parameters will be p (the probability of the biased coin saying distribution 1),
118 * the gamma parameters p1 and lambda1 for distribution 1, and the parameters p2 and
119 * lambda2 for other distribution. Aggregating, let theta = (p,p1,lambda1,p2,lambda2).
120 *
121 * The goal is to choose the parameters theta that maximize Pr(X|theta) = sum_Y ( Pr(Y,X|theta) ).
122 *
123 * We do this using E-M algorithm, which iteratively increases Pr(X|theta) and will converge to
124 * a local but not necessarily global maximum. Given theta_0 we need to be able to find
125 * theta_1 so that Pr(X|theta1) > Pr(X|theta0).
126 *
127 * log[Pr(X|theta1)/Pr(X|theta0)] = log( sum_Y [ Pr(Y|X,theta0) * Pr(X,Y|theta1) / Pr(Y,X|theta0) ] )
128 * [log is concave] >= sum_Y [ Pr(Y|X,theta0) * log [ Pr(X,Y|theta1) / Pr(Y,X|theta0) ] ]
129 * = G(theta1) - G(theta0),
130 * where G(theta) = sum_Y [ Pr(Y|X,theta0) * log(Pr(X,Y|theta) ] is the auxillary function.
131 *
132 * The EXPECTATION STEP is to form the function G in a way that can be evaluated as a function of theta.
133 *
134 * The MAXIMIZATION STEP is to choose the value theta1 which maximizes G. It must then be the
135 * case that G(theta1)-G(theta0) >= 0, so that Pr(X|theta1) >= Pr(X|theta0).
136 *
137 * Need to introduce fake data, two points per distribution, to prevent ML convergence to prevent
138 * having one of the distributions converge on a single point, which is the non-useful trivial
139 * solution to the ML problem. The fake data will be introduced as follows.
140 * We sort the observed lengths, and take the median. Then the fake data for distribution 1
141 * is 0.90*median and 1.00*median; the fake data for distn 2 is 1.00*median and 1.10*median.
142 */
143
144 // Introduce the fake data to prevent convergence to a spike; essentially a prior
145 double median = BasicStats.medianDoubleArray(lengths);
146 int len = nLengths+4;
147 double[] x = new double[len];
148 double[] post = new double[len]; // posterior probability of coming from distribution 1 not 2
149 for (int j=0; j<nLengths; j++) { x[j] = lengths[j]; }
150 x[nLengths] = 0.90*median; post[nLengths] = 1; // fake data for dist 1
151 x[nLengths+1] = 1.00*median; post[nLengths+1] = 1; // fake data for dist 1
152 x[nLengths+2] = 1.00*median; post[nLengths+2] = 0; // fake data for dist 2
153 x[nLengths+3] = 1.10*median; post[nLengths+3] = 0; // fake data for dist 2
154
155
156 // Start with initial seed guesses for p,p1,lambda1,p2,lambda2
157 pdist1 = 0.5;
158 shape1 = 15;
159 lambda1 = 0.25;
160 shape2 = 5;
161 lambda2 = 0.05;
162
163
164 for (int iteration=0; iteration<40; iteration++) {
165
166 //System.out.println("Iteration " + iteration + "\tp=" + pdist1 + "\tshape1=" + shape1 + "\tlambda1=" + lambda1 + "\tmean1=" + shape1/lambda1 + "\tshape2=" + shape2 + "\tlambda2=" + lambda2 + "\tmean2=" + shape2/lambda2);
167
168
169 // Do a single EM loop.
170 // First the E-step:
171 for (int j=0; j<nLengths; j++) {
172 double p1 = pdist1*GammaDistribution.gamma(shape1,lambda1,x[j]);
173 double p2 = (1-pdist1)*GammaDistribution.gamma(shape2,lambda2,x[j]);
174 if (!(p1 + p2 > 0)) {
175 Assert.a(false,"x[j]=" + x[j] + " p1="+p1 + " p2="+p2+" pdist1=" + pdist1 + " shape1="+shape1+" lambda1="+ lambda1+" shape2="+shape2+" lambda2="+lambda2);
176 }
177 post[j] = p1 / (p1 + p2);
178 }
179
180 // And now the M-step.
181
182 double mean1 = 0, mean2 = 0;
183 double meanlog1 = 0, meanlog2 = 0;
184
185 for (int j = 0; j < len; j++) {
186 Assert.a(x[j] > 0);
187 mean1 += post[j] * x[j];
188 mean2 += (1 - post[j]) * x[j];
189 meanlog1 += post[j] * Math.log(x[j]);
190 meanlog2 += (1 - post[j]) * Math.log(x[j]);
191 }
192 pdist1 = BasicStats.meanDoubleArray(post);
193 mean1 /= pdist1 * len;
194 meanlog1 /= pdist1 * len;
195 mean2 /= (1 - pdist1) * len;
196 meanlog2 /= (1 - pdist1) * len;
197
198 double[] plam1 = GammaDistribution.mleg(mean1, meanlog1);
199 double[] plam2 = GammaDistribution.mleg(mean2, meanlog2);
200 shape1 = plam1[0];
201 lambda1 = plam1[1];
202 shape2 = plam2[0];
203 lambda2 = plam2[1];
204
205 }
206
207 }
208
209
210 public double logEvaluate(double x) {
211 double lret1 = GammaDistribution.lgamma(shape1,lambda1,x); // plus log(pdist1)
212 double lret2 = GammaDistribution.lgamma(shape2,lambda2,x); // plus log(1-pdist1)
213
214 if (pdist1>0.999) {
215 return lret1;
216 }
217 if (pdist1<0.001) {
218 return lret2;
219 }
220
221 double maxlog = Math.max(lret1,lret2);
222
223 lret1 -= maxlog;
224 lret2 -= maxlog;
225
226 if (lret1<-100) {
227 return maxlog + Math.log(1-pdist1) + lret2;
228 }
229
230 if (lret2<-100) {
231 return maxlog + Math.log(pdist1) + lret1;
232 }
233
234 double ret = maxlog + Math.log((pdist1)*Math.exp(lret1) + (1-pdist1)*Math.exp(lret2));
235
236 Assert.a((ret != Double.NEGATIVE_INFINITY) && (ret != Double.POSITIVE_INFINITY) && (!Double.isNaN(ret)));
237 return ret;
238 }
239
240
241 public double evaluate(double x) {
242 double ret = pdist1*(GammaDistribution.gamma(shape1,lambda1,x));
243 ret += (1-pdist1)*(GammaDistribution.gamma(shape2,lambda2,x));
244 Assert.a((ret != Double.NEGATIVE_INFINITY) && (ret != Double.POSITIVE_INFINITY) && (!Double.isNaN(ret)));
245 return ret;
246 }
247
248
249 public void summarize(PrintStream out) {
250 out.println(summary());
251
252 }
253
254 private String summary() {
255 String ret = "";
256 ret = ret + "MIXTURE OF GAMMAS INFO: pr(dist1)=" + pdist1;
257 ret = ret + " shape1=" + shape1;
258 ret = ret + " rate1=" + lambda1;
259 ret = ret + " mean1=" + (shape1/lambda1);
260 ret = ret + " shape2=" + shape2;
261 ret = ret + " rate2=" + lambda2;
262 ret = ret + " mean2=" + (shape2/lambda2);
263 return ret;
264 }
265
266
267 public double getMix() {
268 return pdist1;
269 }
270
271
272 public double getMean() {
273 double ret = pdist1*shape1/lambda1 + (1-pdist1)*shape2/lambda2;
274 return ret;
275 }
276
277
278
279 }
280
281