001    package calhoun.analysis.crf.features.supporting.phylogenetic;
002    
003    import java.io.Serializable;
004    import java.util.Arrays;
005    
006    import org.apache.commons.logging.Log;
007    import org.apache.commons.logging.LogFactory;
008    
009    import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
010    import calhoun.seq.KmerHasher;
011    import calhoun.util.Assert;
012    import cern.colt.matrix.DoubleMatrix2D;
013    
014    public class EvolutionaryModel implements Serializable {
015            private static final long serialVersionUID = -4531626481183209673L;
016    
017            private static final Log log = LogFactory.getLog(EvolutionaryModel.class);
018    
019            // Fundamental information:
020            PhylogeneticTreeFelsensteinOrder             T;
021            double[]                     pi;
022            NucleotideSubstitutionModel  R;
023            int numSpecies;
024    
025            
026            // Derived information, or precomputed/reserved for efficiency:
027            int[] ileft,iright;  // indices of left and right child nodes
028            double[][] Tleft;  //Transition matrices for the branches going left
029            double[][] Tright;  // transition matrices for branches going right
030            double[][] P;   // space in which Felsenstein algorithm recursions will be performed.
031    
032            static KmerHasher hforward = new KmerHasher(KmerHasher.ACGTother,1);
033            static KmerHasher hbackward = new KmerHasher(KmerHasher.ACGTotherRC,1); 
034            
035            public EvolutionaryModel(PhylogeneticTreeFelsensteinOrder T,
036                            double[] pi, NucleotideSubstitutionModel  R ) {
037                    this.T = T;
038                    this.pi = pi;
039                    this.R = R;
040                    numSpecies = T.numSpecies();
041                    setup();
042            }
043    
044            private void setup() {
045                    ileft = T.getileft();
046                    iright = T.getiright();
047                    double[] bleft = T.getbleft();
048                    double[] bright = T.getbright();
049                    
050                    Tleft  = new double[T.nSteps][];
051                    Tright = new double[T.nSteps][];
052                    for (int j=0; j<T.nSteps; j++) {
053                            Tleft[j] = createArrayFromTransitionMatrix(R.transitionMatrix(bleft[j]));
054                            Tright[j] = createArrayFromTransitionMatrix(R.transitionMatrix(bright[j]));
055                    }
056                    Assert.a(ileft.length == T.nSteps);
057                    Assert.a(iright.length == T.nSteps);
058                    
059                    P = new double[T.nNodes][4];
060            }
061            
062            
063            public double logprobRC(MultipleAlignmentColumn col, boolean conditionref) {
064                    return logprob(col,conditionref,hbackward);
065            }
066    
067            public double logprob(MultipleAlignmentColumn col, boolean conditionref) {
068                    return logprob(col,conditionref,hforward);
069            }
070            
071            
072            private double logprob(MultipleAlignmentColumn C,boolean conditionref, KmerHasher h) {
073                            
074                    if ( C.numSpecies() != numSpecies ) {
075                            Assert.a(false,"C.numspecies is " + C.numSpecies() + "  and numSpecies is " + numSpecies);
076                    }
077                    
078                    for (int i=0; i<numSpecies; i++) {
079                            int x = h.hash(C.nucleotide(i));
080                            for (int j=0; j<4; j++) { 
081                                    P[i][j] = x >= 4 || x == j ? 1.0 : 0.0; 
082                            }
083                    }
084                    
085                    for (int step=0; step<T.nSteps; step++) {
086                            int node = step + numSpecies;                   
087                            felsenstein(P[ileft[step]],Tleft[step],P[iright[step]],Tright[step],P[node]);
088                    }
089                    double prob = 0;
090                    for (int i=0; i<4; i++) {
091                            prob += pi[i] * P[T.numNodes()-1][i];
092                    }
093                    
094                    if (conditionref) {
095                            for (int i=1; i<T.numSpecies(); i++) {
096                                    Arrays.fill(P[i], 1.0);
097                            }
098                            for (int step=0; step<T.nSteps; step++) {
099                                    int node = step + numSpecies;                   
100                                    felsenstein(P[ileft[step]],Tleft[step],P[iright[step]],Tright[step],P[node]);
101                            }
102                            double denom = 0;
103                            for (int i=0; i<4; i++) {
104                                    denom += pi[i] * P[T.numNodes()-1][i];
105                            }
106                            if ( !(prob/denom < 1.00000001) ) {
107                                    Assert.a(false , "prob=" + prob + "  denom="+denom);
108                            }
109                            
110                            prob = prob/denom;
111                    }
112                    
113                    if (!(prob > 0)) {
114                            Assert.a(false,"prob="+prob);
115                    }
116                    if ( !(prob < 1.00000001) ) {
117                            Assert.a(false , "prob=" + prob );
118                    }
119                    return Math.log(prob);  
120            }
121    
122            private static double[] createArrayFromTransitionMatrix(DoubleMatrix2D R) {
123                    double[] ret = new double[16];
124                    for(int i = 0; i<4; ++i) {
125                            for(int j = 0; j<4; ++j) {
126                                    ret[i*4+j] = R.getQuick(i,j);
127                            }                       
128                    }
129                    return ret;
130            }
131            
132            private static void felsenstein(double[] lp, double[] lT,double[] rp, double[] rT,double[] pp) {
133                    for (int i=0; i<4; i++) { 
134                            double leftprob=0.0,  rightprob=0.0;
135                            for (int j=0; j<4; j++) {
136                                    leftprob += lT[i*4 + j]*lp[j];
137                                    rightprob += rT[i*4 + j]*rp[j];
138                            }
139                            
140                            pp[i] = leftprob*rightprob;
141                    }
142                    return;
143            }
144    
145            public void summarize() {
146                    log.debug("Evolutionary model, initial probabilities:   pi = " + pi[0] + "\t" + pi[1] + "\t" + pi[2] + "\t" + pi[3]);
147                    R.summarize();
148            }
149    
150    }