001    package calhoun.analysis.crf.statistics;
002    
003    import java.io.PrintStream;
004    import java.io.Serializable;
005    import java.util.Arrays;
006    
007    import org.apache.commons.logging.Log;
008    import org.apache.commons.logging.LogFactory;
009    
010    import calhoun.util.Assert;
011    import calhoun.util.ColtUtil;
012    
013    public class MixtureOfGammas implements Serializable {
014            private static final long serialVersionUID = 6269582005717276381L;
015            private static final Log log = LogFactory.getLog(MixtureOfGammas.class);
016    
017            private double pdist1;
018            private double shape1;
019            private double lambda1;
020            private double shape2;
021            private double lambda2;
022    
023            public MixtureOfGammas(double pdist1, double shape1, double lambda1, double shape2, double lambda2) {
024                    Assert.a(pdist1>=0);
025                    Assert.a(pdist1<=1.0);
026                    Assert.a(shape1>0);
027                    Assert.a(lambda1>0);
028                    Assert.a(shape2>0);
029                    Assert.a(lambda2>0);
030                    
031                    this.pdist1 = pdist1;
032                    this.shape1 = shape1;
033                    this.lambda1 = lambda1;
034                    this.shape2 = shape2;
035                    this.lambda2 = lambda2;
036            }
037    
038    
039            
040            public MixtureOfGammas(double[] lengths) {
041                    setup(lengths,false);
042                    log.info(summary());
043            }
044    
045            public MixtureOfGammas(double[] lengths, boolean forceExponentialLength) {
046                    setup(lengths,forceExponentialLength);
047                    log.info(summary());
048            }
049            
050            public void setup(double[] lengths, boolean forceExponentialLength) {
051                    int nLengths = lengths.length;          
052                    for (int j=0; j<nLengths; j++) {
053                            Assert.a(lengths[j]>0);
054                    }
055    
056                    Assert.a(nLengths>=0);
057                    
058                    boolean exponentialDistribution = false;
059                    
060                    if (nLengths==0) {
061                            log.warn("Train mixture gamma called with no length inputs; returning an exponential distn with mean 100");
062                            double mean = 100;
063                            pdist1 = 1.0;
064                            shape1 = 1.0;
065                            lambda1 = 1/mean;
066                            shape2 = 1.0;
067                            lambda2 = 1.0;
068                            
069                            return;         
070                    }
071    
072                    double[] sortedLengths = lengths;
073                    Arrays.sort(sortedLengths);
074                    if(log.isDebugEnabled())
075                            log.debug("Mixture of Gammas trainer called to model these lengths: "+ColtUtil.format(sortedLengths));
076                    
077                    
078                    if (forceExponentialLength) {
079                            log.warn("you called a mixture of gammas model but set flag to force it to model as an exponential length distribution.");
080                            exponentialDistribution = true;                 
081                    }
082                    
083                    if (nLengths < 20) {
084                            log.warn("fewer than 20 lengths supplied for training; modeling with an exponential distribution instead of a mixture of Gammas");
085                            exponentialDistribution = true;
086                    }       
087                    
088                    boolean allSame = true;
089                    double firstLength = lengths[0];
090                    for (int j=0; j<lengths.length; j++) {
091                            if ((lengths[j]<0.99*firstLength) || (lengths[j]>1.01*firstLength)) {
092                                    allSame = false;
093                            }
094                    }
095                    if (allSame) {
096                            log.warn("All the lengths we're asked to model are extremely close to the same value, suggesting the length is artifically fixed for some reason.  This is probably a mistake and some othe model of length is appropriate.  Will model as exponential length distribution.");
097                            exponentialDistribution = true;
098                    }
099                    
100                    if (exponentialDistribution) {
101                            double mean = BasicStats.meanDoubleArray(lengths);
102                            pdist1 = 1.0;
103                            shape1 = 1.0;
104                            lambda1 = 1/mean;
105                            shape2 = 1.0;
106                            lambda2 = 1.0;
107                            
108                            return;                         
109                    }
110                    
111                    /* Think through in english and equations before trying to write out code
112                     * 
113                     * The hidden variables Y will be which of the two mixtures a particular point comes from.
114                     * 
115                     * The observed value X will be the observed intron length.
116                     * 
117                     * The parameters will be p (the probability of the biased coin saying distribution 1),
118                     * the gamma parameters p1 and lambda1 for distribution 1, and the parameters p2 and
119                     * lambda2 for other distribution.  Aggregating, let theta = (p,p1,lambda1,p2,lambda2).
120                     * 
121                     * The goal is to choose the parameters theta that maximize Pr(X|theta) = sum_Y ( Pr(Y,X|theta) ).
122                     * 
123                     * We do this using E-M algorithm, which iteratively increases Pr(X|theta) and will converge to
124                     * a local but not necessarily global maximum.  Given theta_0 we need to be able to find
125                     * theta_1 so that Pr(X|theta1) > Pr(X|theta0).
126                     * 
127                     * log[Pr(X|theta1)/Pr(X|theta0)] = log( sum_Y [ Pr(Y|X,theta0) * Pr(X,Y|theta1) / Pr(Y,X|theta0) ] )
128                     *      [log is concave]    >=  sum_Y [ Pr(Y|X,theta0) * log [ Pr(X,Y|theta1) /  Pr(Y,X|theta0) ] ]
129                     *                          = G(theta1) - G(theta0),
130                     * where G(theta) = sum_Y [ Pr(Y|X,theta0) * log(Pr(X,Y|theta) ] is the auxillary function.
131                     *      
132                     * The EXPECTATION STEP is to form the function G in a way that can be evaluated as a function of theta.
133                     *      
134                     * The MAXIMIZATION STEP is to choose the value theta1 which maximizes G.  It must then be the
135                     *      case that G(theta1)-G(theta0) >= 0, so that Pr(X|theta1) >= Pr(X|theta0). 
136                     * 
137                     * Need to introduce fake data, two points per distribution, to prevent ML convergence to prevent
138                     * having one of the distributions converge on a single point, which is the non-useful trivial
139                     * solution to the ML problem.  The fake data will be introduced as follows.
140                     * We sort the observed lengths, and take the median.  Then the fake data for distribution 1
141                     * is 0.90*median and 1.00*median; the fake data for distn 2 is 1.00*median and 1.10*median. 
142                     */             
143    
144                    // Introduce the fake data to prevent convergence to a spike; essentially a prior
145                    double median = BasicStats.medianDoubleArray(lengths);
146                    int len = nLengths+4;
147                    double[] x = new double[len];
148                    double[] post = new double[len];  // posterior probability of coming from distribution 1 not 2
149                    for (int j=0; j<nLengths; j++) { x[j] = lengths[j]; }
150                    x[nLengths]   = 0.90*median; post[nLengths]   = 1;  // fake data for dist 1
151                    x[nLengths+1] = 1.00*median; post[nLengths+1] = 1;  // fake data for dist 1
152                    x[nLengths+2] = 1.00*median; post[nLengths+2] = 0;  // fake data for dist 2
153                    x[nLengths+3] = 1.10*median; post[nLengths+3] = 0;  // fake data for dist 2
154                                            
155                    
156                    // Start with initial seed guesses for p,p1,lambda1,p2,lambda2
157                    pdist1 = 0.5;
158                    shape1 = 15;
159                    lambda1 = 0.25;
160                    shape2 = 5;
161                    lambda2 = 0.05;
162                    
163                    
164                    for (int iteration=0; iteration<40; iteration++) {
165                    
166                            //System.out.println("Iteration " + iteration + "\tp=" + pdist1 + "\tshape1=" + shape1 + "\tlambda1=" + lambda1 + "\tmean1=" + shape1/lambda1 + "\tshape2=" + shape2 + "\tlambda2=" + lambda2 + "\tmean2=" + shape2/lambda2);
167                            
168                            
169                            // Do a single EM loop.
170                            // First the E-step:
171                            for (int j=0; j<nLengths; j++) {
172                                    double p1 = pdist1*GammaDistribution.gamma(shape1,lambda1,x[j]);
173                                    double p2 = (1-pdist1)*GammaDistribution.gamma(shape2,lambda2,x[j]);
174                                    if (!(p1 + p2 > 0)) {
175                                            Assert.a(false,"x[j]=" + x[j] + "  p1="+p1 + "  p2="+p2+"  pdist1=" + pdist1 + "  shape1="+shape1+"  lambda1="+ lambda1+"  shape2="+shape2+"  lambda2="+lambda2);
176                                    }
177                                    post[j] = p1 / (p1 + p2);
178                            }
179                            
180                            // And now the M-step.
181                            
182                            double mean1 = 0, mean2 = 0;
183                            double meanlog1 = 0, meanlog2 = 0;
184                            
185                            for (int j = 0; j < len; j++) {
186                                    Assert.a(x[j] > 0);
187                                    mean1 += post[j] * x[j];
188                                    mean2 += (1 - post[j]) * x[j];
189                                    meanlog1 += post[j] * Math.log(x[j]);
190                                    meanlog2 += (1 - post[j]) * Math.log(x[j]);
191                            }
192                            pdist1 = BasicStats.meanDoubleArray(post);
193                            mean1 /= pdist1 * len;
194                            meanlog1 /= pdist1 * len;
195                            mean2 /= (1 - pdist1) * len;
196                            meanlog2 /= (1 - pdist1) * len;
197                            
198                            double[] plam1 = GammaDistribution.mleg(mean1, meanlog1);
199                            double[] plam2 = GammaDistribution.mleg(mean2, meanlog2);
200                            shape1 = plam1[0];
201                            lambda1 = plam1[1];
202                            shape2 = plam2[0];
203                            lambda2 = plam2[1];
204                            
205                    }
206                    
207            }
208    
209    
210            public double logEvaluate(double x) {
211                    double lret1 = GammaDistribution.lgamma(shape1,lambda1,x); // plus log(pdist1)
212                    double lret2 = GammaDistribution.lgamma(shape2,lambda2,x); // plus log(1-pdist1)
213                    
214                    if (pdist1>0.999) {
215                            return lret1;
216                    }
217                    if (pdist1<0.001) {
218                            return lret2;
219                    }
220                    
221                    double maxlog = Math.max(lret1,lret2);
222                    
223                    lret1 -= maxlog;
224                    lret2 -= maxlog;
225                    
226                    if (lret1<-100) {
227                            return maxlog + Math.log(1-pdist1) + lret2;
228                    }
229    
230                    if (lret2<-100) {
231                            return maxlog + Math.log(pdist1) + lret1;
232                    }               
233                    
234                    double ret = maxlog + Math.log((pdist1)*Math.exp(lret1) + (1-pdist1)*Math.exp(lret2));  
235                    
236                    Assert.a((ret != Double.NEGATIVE_INFINITY) && (ret != Double.POSITIVE_INFINITY) && (!Double.isNaN(ret)));
237                    return ret;
238            }
239            
240    
241            public double evaluate(double x) {
242                    double ret = pdist1*(GammaDistribution.gamma(shape1,lambda1,x));
243                    ret += (1-pdist1)*(GammaDistribution.gamma(shape2,lambda2,x));
244                    Assert.a((ret != Double.NEGATIVE_INFINITY) && (ret != Double.POSITIVE_INFINITY) && (!Double.isNaN(ret)));
245                    return ret;
246            }
247    
248    
249            public void summarize(PrintStream out) {
250                    out.println(summary());
251            
252            }
253    
254            private String summary() {
255                    String ret = "";
256                    ret = ret + "MIXTURE OF GAMMAS INFO: pr(dist1)=" + pdist1;
257                    ret = ret + "  shape1=" + shape1;
258                    ret = ret + "  rate1=" + lambda1;
259                    ret = ret + "  mean1=" + (shape1/lambda1);
260                    ret = ret + "  shape2=" + shape2;
261                    ret = ret + "  rate2=" + lambda2;
262                    ret = ret + "  mean2=" + (shape2/lambda2);                      
263                    return ret;
264            }
265            
266    
267            public double getMix() {
268                    return pdist1;
269            }
270    
271    
272            public double getMean() {
273                    double ret = pdist1*shape1/lambda1 + (1-pdist1)*shape2/lambda2;
274                    return ret;
275            }
276            
277            
278            
279    }
280    
281