001 package calhoun.analysis.crf.features.supporting.phylogenetic;
002
003 import java.io.Serializable;
004 import java.util.Arrays;
005
006 import org.apache.commons.logging.Log;
007 import org.apache.commons.logging.LogFactory;
008
009 import calhoun.analysis.crf.io.MultipleAlignmentInputSequence.MultipleAlignmentColumn;
010 import calhoun.seq.KmerHasher;
011 import calhoun.util.Assert;
012 import cern.colt.matrix.DoubleMatrix2D;
013
014 public class EvolutionaryModel implements Serializable {
015 private static final long serialVersionUID = -4531626481183209673L;
016
017 private static final Log log = LogFactory.getLog(EvolutionaryModel.class);
018
019 // Fundamental information:
020 PhylogeneticTreeFelsensteinOrder T;
021 double[] pi;
022 NucleotideSubstitutionModel R;
023 int numSpecies;
024
025
026 // Derived information, or precomputed/reserved for efficiency:
027 int[] ileft,iright; // indices of left and right child nodes
028 double[][] Tleft; //Transition matrices for the branches going left
029 double[][] Tright; // transition matrices for branches going right
030 double[][] P; // space in which Felsenstein algorithm recursions will be performed.
031
032 static KmerHasher hforward = new KmerHasher(KmerHasher.ACGTother,1);
033 static KmerHasher hbackward = new KmerHasher(KmerHasher.ACGTotherRC,1);
034
035 public EvolutionaryModel(PhylogeneticTreeFelsensteinOrder T,
036 double[] pi, NucleotideSubstitutionModel R ) {
037 this.T = T;
038 this.pi = pi;
039 this.R = R;
040 numSpecies = T.numSpecies();
041 setup();
042 }
043
044 private void setup() {
045 ileft = T.getileft();
046 iright = T.getiright();
047 double[] bleft = T.getbleft();
048 double[] bright = T.getbright();
049
050 Tleft = new double[T.nSteps][];
051 Tright = new double[T.nSteps][];
052 for (int j=0; j<T.nSteps; j++) {
053 Tleft[j] = createArrayFromTransitionMatrix(R.transitionMatrix(bleft[j]));
054 Tright[j] = createArrayFromTransitionMatrix(R.transitionMatrix(bright[j]));
055 }
056 Assert.a(ileft.length == T.nSteps);
057 Assert.a(iright.length == T.nSteps);
058
059 P = new double[T.nNodes][4];
060 }
061
062
063 public double logprobRC(MultipleAlignmentColumn col, boolean conditionref) {
064 return logprob(col,conditionref,hbackward);
065 }
066
067 public double logprob(MultipleAlignmentColumn col, boolean conditionref) {
068 return logprob(col,conditionref,hforward);
069 }
070
071
072 private double logprob(MultipleAlignmentColumn C,boolean conditionref, KmerHasher h) {
073
074 if ( C.numSpecies() != numSpecies ) {
075 Assert.a(false,"C.numspecies is " + C.numSpecies() + " and numSpecies is " + numSpecies);
076 }
077
078 for (int i=0; i<numSpecies; i++) {
079 int x = h.hash(C.nucleotide(i));
080 for (int j=0; j<4; j++) {
081 P[i][j] = x >= 4 || x == j ? 1.0 : 0.0;
082 }
083 }
084
085 for (int step=0; step<T.nSteps; step++) {
086 int node = step + numSpecies;
087 felsenstein(P[ileft[step]],Tleft[step],P[iright[step]],Tright[step],P[node]);
088 }
089 double prob = 0;
090 for (int i=0; i<4; i++) {
091 prob += pi[i] * P[T.numNodes()-1][i];
092 }
093
094 if (conditionref) {
095 for (int i=1; i<T.numSpecies(); i++) {
096 Arrays.fill(P[i], 1.0);
097 }
098 for (int step=0; step<T.nSteps; step++) {
099 int node = step + numSpecies;
100 felsenstein(P[ileft[step]],Tleft[step],P[iright[step]],Tright[step],P[node]);
101 }
102 double denom = 0;
103 for (int i=0; i<4; i++) {
104 denom += pi[i] * P[T.numNodes()-1][i];
105 }
106 if ( !(prob/denom < 1.00000001) ) {
107 Assert.a(false , "prob=" + prob + " denom="+denom);
108 }
109
110 prob = prob/denom;
111 }
112
113 if (!(prob > 0)) {
114 Assert.a(false,"prob="+prob);
115 }
116 if ( !(prob < 1.00000001) ) {
117 Assert.a(false , "prob=" + prob );
118 }
119 return Math.log(prob);
120 }
121
122 private static double[] createArrayFromTransitionMatrix(DoubleMatrix2D R) {
123 double[] ret = new double[16];
124 for(int i = 0; i<4; ++i) {
125 for(int j = 0; j<4; ++j) {
126 ret[i*4+j] = R.getQuick(i,j);
127 }
128 }
129 return ret;
130 }
131
132 private static void felsenstein(double[] lp, double[] lT,double[] rp, double[] rT,double[] pp) {
133 for (int i=0; i<4; i++) {
134 double leftprob=0.0, rightprob=0.0;
135 for (int j=0; j<4; j++) {
136 leftprob += lT[i*4 + j]*lp[j];
137 rightprob += rT[i*4 + j]*rp[j];
138 }
139
140 pp[i] = leftprob*rightprob;
141 }
142 return;
143 }
144
145 public void summarize() {
146 log.debug("Evolutionary model, initial probabilities: pi = " + pi[0] + "\t" + pi[1] + "\t" + pi[2] + "\t" + pi[3]);
147 R.summarize();
148 }
149
150 }