001    package calhoun.analysis.crf.test;
002    
003    import org.apache.commons.logging.Log;
004    import org.apache.commons.logging.LogFactory;
005    
006    import calhoun.analysis.crf.statistics.BasicStats;
007    import calhoun.analysis.crf.statistics.GammaDistribution;
008    import calhoun.analysis.crf.statistics.MixtureOfGammas;
009    import calhoun.util.AbstractTestCase;
010    import calhoun.util.Assert;
011    import calhoun.util.FileUtil;
012    
013    public class GammaDistributionsTest extends AbstractTestCase {
014            private static final Log log = LogFactory.getLog(CRFIOTest.class);
015            boolean debug = log.isDebugEnabled();
016    
017            
018            
019            public void testGammaDistribution() throws Exception {
020                    String fileName ="test/input/testIntronLengths.txt";
021                    double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
022                    int nIntrons = intronlengths.length;
023                    
024                    assertEquals(nIntrons,5863);
025                    assertEquals(intronlengths[0],59,0.00001);
026                    assertEquals(intronlengths[1],99,0.00001);
027                    
028                    double mean = 0.0;
029                    double meanlog = 0.0;
030                    for (int j=0; j<nIntrons; j++) {
031                            Assert.a(intronlengths[j] > 0);      
032                            mean += intronlengths[j];
033                            meanlog += Math.log(intronlengths[j]);
034                    }
035                    mean /= nIntrons;
036                    meanlog /= nIntrons;            
037                    System.out.println("Mean = " + mean);
038                    System.out.println("Meanlog = " + meanlog);
039                            
040                    double[] plam = GammaDistribution.mleg(mean,meanlog);
041                    double p = plam[0];
042                    double lam = plam[1];
043                    
044                    assertEquals(mean,p/lam);       
045                    System.out.println("Shape parameter p is "  +p);
046                    System.out.println("Rate parameter lambda is " + lam);
047            }
048            
049            
050            public void testGammaFunction() throws Exception {
051                    // for integers p, gamma(p) = (p-1)!
052                    assertEquals(GammaDistribution.gamma(1),1,0.000001);
053                    assertEquals(GammaDistribution.gamma(2),1,0.000001);
054                    assertEquals(GammaDistribution.gamma(3),2,0.000001);
055                    assertEquals(GammaDistribution.gamma(4),6,0.000001);
056                    assertEquals(GammaDistribution.gamma(5),24,0.000001);
057                    // Another identity is gamma(0.5) = sqrt(pi)
058                    assertEquals(GammaDistribution.gamma(0.5),Math.sqrt(Math.PI),0.000001);
059            }
060            
061    
062            
063            
064            public void testMixtureGammaModel() throws Exception {
065                    String fileName ="test/input/testIntronLengths.txt";
066                    double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
067                    int nIntrons = intronlengths.length;            
068                    
069                    /* Think through in english and equations before trying to write out code
070                     * 
071                     * The hidden variables Y will be which of the two mixtures a particular point comes from.
072                     * 
073                     * The observed value X will be the observed intron length.
074                     * 
075                     * The parameters will be p (the probability of the biased coin saying distribution 1),
076                     * the gamma parameters p1 and lambda1 for distribution 1, and the parameters p2 and
077                     * lambda2 for other distribution.  Aggregating, let theta = (p,p1,lambda1,p2,lambda2).
078                     * 
079                     * The goal is to choose the parameters theta that maximize Pr(X|theta) = sum_Y ( Pr(Y,X|theta) ).
080                     * 
081                     * We do this using E-M algorithm, which iteratively increases Pr(X|theta) and will converge to
082                     * a local but not necessarily global maximum.  Given theta_0 we need to be able to find
083                     * theta_1 so that Pr(X|theta1) > Pr(X|theta0).
084                     * 
085                     * log[Pr(X|theta1)/Pr(X|theta0)] = log( sum_Y [ Pr(Y|X,theta0) * Pr(X,Y|theta1) / Pr(Y,X|theta0) ] )
086                     *      [log is concave]    >=  sum_Y [ Pr(Y|X,theta0) * log [ Pr(X,Y|theta1) /  Pr(Y,X|theta0) ] ]
087                     *                          = G(theta1) - G(theta0),
088                     * where G(theta) = sum_Y [ Pr(Y|X,theta0) * log(Pr(X,Y|theta) ] is the auxillary function.
089                     *      
090                     * The EXPECTATION STEP is to form the function G in a way that can be evaluated as a function of theta.
091                     *      
092                     * The MAXIMIZATION STEP is to choose the value theta1 which maximizes G.  It must then be the
093                     *      case that G(theta1)-G(theta0) >= 0, so that Pr(X|theta1) >= Pr(X|theta0). 
094                     * 
095                     * Need to introduce fake data, two points per distribution, to prevent ML convergence to prevent
096                     * having one of the distributions converge on a single point, which is the non-useful trivial
097                     * solution to the ML problem.
098                     */             
099    
100                    // Introduce the fake data to prevent convergence to a spike; essentially a prior
101                    int len = nIntrons+4;
102                    double[] x = new double[len];
103                    double[] post = new double[len];  // posterior probability of coming from distribution 1 not 2
104                    for (int j=0; j<nIntrons; j++) { x[j] = intronlengths[j]; }
105                    x[nIntrons]   = 50; post[nIntrons]   = 1;  // fake data for dist 1
106                    x[nIntrons+1] = 60; post[nIntrons+1] = 1;  // fake data for dist 1
107                    x[nIntrons+2] = 60; post[nIntrons+2] = 0;  // fake data for dist 2
108                    x[nIntrons+3] = 90; post[nIntrons+3] = 0;  // fake data for dist 2
109                                            
110                    
111                    // Start with initial guesses for p,p1,lambda1,p2,lambda2
112                    double p = 0.5;
113                    double shape1 = 15;
114                    double lambda1 = 0.25;
115                    double shape2 = 5;
116                    double lambda2 = 0.05;
117                    
118    
119                    for (int iteration=0; iteration<20; iteration++) {
120                    
121                            System.out.println("Iteration " + iteration + "\tp=" + p + "\tshape1=" + shape1 + "\tlambda1=" + lambda1 + "\tmean1=" + shape1/lambda1 + "\tshape2=" + shape2 + "\tlambda2=" + lambda2 + "\tmean2=" + shape2/lambda2);
122                            
123                            
124                    // Do a single EM loop.
125                    // First the E-step:
126                    for (int j=0; j<nIntrons; j++) {
127                            double p1 = p*GammaDistribution.gamma(shape1,lambda1,x[j]);
128                            double p2 = (1-p)*GammaDistribution.gamma(shape2,lambda2,x[j]);
129                            Assert.a(p1+p2>0);
130                            post[j] = p1/(p1+p2); 
131                    }
132                    
133                    // And now the M-step.
134                    
135                    double mean1=0, mean2=0;
136                    double meanlog1=0, meanlog2=0;
137                    
138                    for (int j=0; j<len; j++) {
139                            Assert.a(x[j] > 0);  
140                            mean1 += post[j]*x[j];
141                            mean2 += (1-post[j])*x[j];
142                            meanlog1 += post[j]*Math.log(x[j]);
143                            meanlog2 += (1-post[j])*Math.log(x[j]);
144                    }
145                    p = BasicStats.meanDoubleArray(post);
146                    mean1 /= p*len;
147                    meanlog1 /= p*len;              
148                    mean2 /= (1-p)*len;
149                    meanlog2 /= (1-p)*len;
150                            
151                            
152                    double[] plam1 = GammaDistribution.mleg(mean1,meanlog1);
153                    double[] plam2 = GammaDistribution.mleg(mean2,meanlog2);
154                    shape1 = plam1[0];
155                    lambda1 = plam1[1];
156                    shape2 = plam2[0];
157                    lambda2 = plam2[1];
158                    
159                    }       
160            }
161            
162            public void testMixtureGammaModelForceExponential() throws Exception {
163                    String fileName ="test/input/testIntronLengths.txt";
164                    double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
165                    new MixtureOfGammas(intronlengths); // tests that this runs to completion
166                    
167                    MixtureOfGammas mg = new MixtureOfGammas(intronlengths,true); // tests that runs to completion and is exponential
168                    
169                    for (int j=2; j<200; j++) {
170                            System.out.println("j="+j+"   " + Math.log(mg.evaluate((double) j)/mg.evaluate((double) (j-1))));
171                    }
172            
173                    assertEquals(Math.log(mg.evaluate(2)/mg.evaluate(1)),Math.log(mg.evaluate(102)/mg.evaluate(101)),0.0001);
174            }
175            
176            
177            public void testMixtureOfGammasClass() throws Exception {
178                    String fileName ="test/input/testIntronLengths.txt";
179                    double[] intronlengths = FileUtil.readDoublesFromSingleTabbedLine(fileName);
180            
181                    MixtureOfGammas MG = new MixtureOfGammas(intronlengths);
182                    MG.summarize(System.out);
183                    assertEquals(MG.getMix(),0.859,0.01);
184                    
185                    MixtureOfGammas MG2 = new MixtureOfGammas(0.86,71,1.27,4.1,0.041);              
186                    assertEquals(MG2.evaluate(55.0),0.053107,0.01);
187                    assertEquals(MG2.logEvaluate(55.0),-2.935446,0.01);
188                    
189                    // MG3 is exponential decay with mean 1
190                    MixtureOfGammas MG3 = new MixtureOfGammas(1.0, 1.0, 1.0, 1.0, 2.0);                             
191                    assertEquals(MG3.evaluate(1.0),Math.exp(-1.0),0.01);
192                    assertEquals(MG3.logEvaluate(1.0),-1.0,0.01);
193                    
194                    // MG4 is exponential decay with mean 0.5
195                    MixtureOfGammas MG4 = new MixtureOfGammas(0, 1.0, 1.0, 1.0, 2.0);               
196                    assertEquals(MG4.evaluate(1.0),2.0*Math.exp(-2.0),0.01);
197                    assertEquals(MG4.logEvaluate(1.0),Math.log(2.0)-2.0,0.01);
198                    
199                    // MG5 is equal mix of above two distributions
200                    MixtureOfGammas MG5 = new MixtureOfGammas(0.5,1.0,1.0,1.0,2.0);         
201                    assertEquals(MG5.evaluate(1.0), 0.5*Math.exp(-1.0) + 0.5*2.0*Math.exp(-2.0),0.01);
202                    assertEquals(MG5.logEvaluate(1.0), Math.log(0.5*Math.exp(-1.0) + 0.5*2.0*Math.exp(-2.0)),0.01);
203                    
204                    // MG6 should be an exponential distribution with mean 3.5
205                    MixtureOfGammas MG6 = new MixtureOfGammas(new double[]{1,2,3,4,5,6});
206                    MG6.summarize(System.out);
207                    assertEquals(MG6.getMean(),3.5,0.001);
208            }
209    
210    }