001 package calhoun.analysis.crf.test;
002
003 import org.apache.commons.logging.Log;
004 import org.apache.commons.logging.LogFactory;
005
006 import calhoun.analysis.crf.statistics.BasicStats;
007 import calhoun.analysis.crf.statistics.GammaDistribution;
008 import calhoun.analysis.crf.statistics.MixtureOfGammas;
009 import calhoun.util.AbstractTestCase;
010 import calhoun.util.Assert;
011 import calhoun.util.FileUtil;
012
013 public class GammaDistributionsTest extends AbstractTestCase {
014 private static final Log log = LogFactory.getLog(CRFIOTest.class);
015 boolean debug = log.isDebugEnabled();
016
017
018
019 public void testGammaDistribution() throws Exception {
020 String fileName ="test/input/testIntronLengths.txt";
021 double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
022 int nIntrons = intronlengths.length;
023
024 assertEquals(nIntrons,5863);
025 assertEquals(intronlengths[0],59,0.00001);
026 assertEquals(intronlengths[1],99,0.00001);
027
028 double mean = 0.0;
029 double meanlog = 0.0;
030 for (int j=0; j<nIntrons; j++) {
031 Assert.a(intronlengths[j] > 0);
032 mean += intronlengths[j];
033 meanlog += Math.log(intronlengths[j]);
034 }
035 mean /= nIntrons;
036 meanlog /= nIntrons;
037 System.out.println("Mean = " + mean);
038 System.out.println("Meanlog = " + meanlog);
039
040 double[] plam = GammaDistribution.mleg(mean,meanlog);
041 double p = plam[0];
042 double lam = plam[1];
043
044 assertEquals(mean,p/lam);
045 System.out.println("Shape parameter p is " +p);
046 System.out.println("Rate parameter lambda is " + lam);
047 }
048
049
050 public void testGammaFunction() throws Exception {
051 // for integers p, gamma(p) = (p-1)!
052 assertEquals(GammaDistribution.gamma(1),1,0.000001);
053 assertEquals(GammaDistribution.gamma(2),1,0.000001);
054 assertEquals(GammaDistribution.gamma(3),2,0.000001);
055 assertEquals(GammaDistribution.gamma(4),6,0.000001);
056 assertEquals(GammaDistribution.gamma(5),24,0.000001);
057 // Another identity is gamma(0.5) = sqrt(pi)
058 assertEquals(GammaDistribution.gamma(0.5),Math.sqrt(Math.PI),0.000001);
059 }
060
061
062
063
064 public void testMixtureGammaModel() throws Exception {
065 String fileName ="test/input/testIntronLengths.txt";
066 double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
067 int nIntrons = intronlengths.length;
068
069 /* Think through in english and equations before trying to write out code
070 *
071 * The hidden variables Y will be which of the two mixtures a particular point comes from.
072 *
073 * The observed value X will be the observed intron length.
074 *
075 * The parameters will be p (the probability of the biased coin saying distribution 1),
076 * the gamma parameters p1 and lambda1 for distribution 1, and the parameters p2 and
077 * lambda2 for other distribution. Aggregating, let theta = (p,p1,lambda1,p2,lambda2).
078 *
079 * The goal is to choose the parameters theta that maximize Pr(X|theta) = sum_Y ( Pr(Y,X|theta) ).
080 *
081 * We do this using E-M algorithm, which iteratively increases Pr(X|theta) and will converge to
082 * a local but not necessarily global maximum. Given theta_0 we need to be able to find
083 * theta_1 so that Pr(X|theta1) > Pr(X|theta0).
084 *
085 * log[Pr(X|theta1)/Pr(X|theta0)] = log( sum_Y [ Pr(Y|X,theta0) * Pr(X,Y|theta1) / Pr(Y,X|theta0) ] )
086 * [log is concave] >= sum_Y [ Pr(Y|X,theta0) * log [ Pr(X,Y|theta1) / Pr(Y,X|theta0) ] ]
087 * = G(theta1) - G(theta0),
088 * where G(theta) = sum_Y [ Pr(Y|X,theta0) * log(Pr(X,Y|theta) ] is the auxillary function.
089 *
090 * The EXPECTATION STEP is to form the function G in a way that can be evaluated as a function of theta.
091 *
092 * The MAXIMIZATION STEP is to choose the value theta1 which maximizes G. It must then be the
093 * case that G(theta1)-G(theta0) >= 0, so that Pr(X|theta1) >= Pr(X|theta0).
094 *
095 * Need to introduce fake data, two points per distribution, to prevent ML convergence to prevent
096 * having one of the distributions converge on a single point, which is the non-useful trivial
097 * solution to the ML problem.
098 */
099
100 // Introduce the fake data to prevent convergence to a spike; essentially a prior
101 int len = nIntrons+4;
102 double[] x = new double[len];
103 double[] post = new double[len]; // posterior probability of coming from distribution 1 not 2
104 for (int j=0; j<nIntrons; j++) { x[j] = intronlengths[j]; }
105 x[nIntrons] = 50; post[nIntrons] = 1; // fake data for dist 1
106 x[nIntrons+1] = 60; post[nIntrons+1] = 1; // fake data for dist 1
107 x[nIntrons+2] = 60; post[nIntrons+2] = 0; // fake data for dist 2
108 x[nIntrons+3] = 90; post[nIntrons+3] = 0; // fake data for dist 2
109
110
111 // Start with initial guesses for p,p1,lambda1,p2,lambda2
112 double p = 0.5;
113 double shape1 = 15;
114 double lambda1 = 0.25;
115 double shape2 = 5;
116 double lambda2 = 0.05;
117
118
119 for (int iteration=0; iteration<20; iteration++) {
120
121 System.out.println("Iteration " + iteration + "\tp=" + p + "\tshape1=" + shape1 + "\tlambda1=" + lambda1 + "\tmean1=" + shape1/lambda1 + "\tshape2=" + shape2 + "\tlambda2=" + lambda2 + "\tmean2=" + shape2/lambda2);
122
123
124 // Do a single EM loop.
125 // First the E-step:
126 for (int j=0; j<nIntrons; j++) {
127 double p1 = p*GammaDistribution.gamma(shape1,lambda1,x[j]);
128 double p2 = (1-p)*GammaDistribution.gamma(shape2,lambda2,x[j]);
129 Assert.a(p1+p2>0);
130 post[j] = p1/(p1+p2);
131 }
132
133 // And now the M-step.
134
135 double mean1=0, mean2=0;
136 double meanlog1=0, meanlog2=0;
137
138 for (int j=0; j<len; j++) {
139 Assert.a(x[j] > 0);
140 mean1 += post[j]*x[j];
141 mean2 += (1-post[j])*x[j];
142 meanlog1 += post[j]*Math.log(x[j]);
143 meanlog2 += (1-post[j])*Math.log(x[j]);
144 }
145 p = BasicStats.meanDoubleArray(post);
146 mean1 /= p*len;
147 meanlog1 /= p*len;
148 mean2 /= (1-p)*len;
149 meanlog2 /= (1-p)*len;
150
151
152 double[] plam1 = GammaDistribution.mleg(mean1,meanlog1);
153 double[] plam2 = GammaDistribution.mleg(mean2,meanlog2);
154 shape1 = plam1[0];
155 lambda1 = plam1[1];
156 shape2 = plam2[0];
157 lambda2 = plam2[1];
158
159 }
160 }
161
162 public void testMixtureGammaModelForceExponential() throws Exception {
163 String fileName ="test/input/testIntronLengths.txt";
164 double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
165 new MixtureOfGammas(intronlengths); // tests that this runs to completion
166
167 MixtureOfGammas mg = new MixtureOfGammas(intronlengths,true); // tests that runs to completion and is exponential
168
169 for (int j=2; j<200; j++) {
170 System.out.println("j="+j+" " + Math.log(mg.evaluate((double) j)/mg.evaluate((double) (j-1))));
171 }
172
173 assertEquals(Math.log(mg.evaluate(2)/mg.evaluate(1)),Math.log(mg.evaluate(102)/mg.evaluate(101)),0.0001);
174 }
175
176
177 public void testMixtureOfGammasClass() throws Exception {
178 String fileName ="test/input/testIntronLengths.txt";
179 double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
180
181 MixtureOfGammas MG = new MixtureOfGammas(intronlengths);
182 MG.summarize(System.out);
183 assertEquals(MG.getMix(),0.859,0.01);
184
185 MixtureOfGammas MG2 = new MixtureOfGammas(0.86,71,1.27,4.1,0.041);
186 assertEquals(MG2.evaluate(55.0),0.053107,0.01);
187 assertEquals(MG2.logEvaluate(55.0),-2.935446,0.01);
188
189 // MG3 is exponential decay with mean 1
190 MixtureOfGammas MG3 = new MixtureOfGammas(1.0, 1.0, 1.0, 1.0, 2.0);
191 assertEquals(MG3.evaluate(1.0),Math.exp(-1.0),0.01);
192 assertEquals(MG3.logEvaluate(1.0),-1.0,0.01);
193
194 // MG4 is exponential decay with mean 0.5
195 MixtureOfGammas MG4 = new MixtureOfGammas(0, 1.0, 1.0, 1.0, 2.0);
196 assertEquals(MG4.evaluate(1.0),2.0*Math.exp(-2.0),0.01);
197 assertEquals(MG4.logEvaluate(1.0),Math.log(2.0)-2.0,0.01);
198
199 // MG5 is equal mix of above two distributions
200 MixtureOfGammas MG5 = new MixtureOfGammas(0.5,1.0,1.0,1.0,2.0);
201 assertEquals(MG5.evaluate(1.0), 0.5*Math.exp(-1.0) + 0.5*2.0*Math.exp(-2.0),0.01);
202 assertEquals(MG5.logEvaluate(1.0), Math.log(0.5*Math.exp(-1.0) + 0.5*2.0*Math.exp(-2.0)),0.01);
203
204 // MG6 should be an exponential distribution with mean 3.5
205 MixtureOfGammas MG6 = new MixtureOfGammas(new double[]{1,2,3,4,5,6});
206 MG6.summarize(System.out);
207 assertEquals(MG6.getMean(),3.5,0.001);
208 }
209
210 }