001    package calhoun.analysis.crf.features.supporting.phylogenetic;
002    
003    import java.io.Serializable;
004    import java.util.ArrayList;
005    import java.util.HashMap;
006    import java.util.HashSet;
007    import java.util.Iterator;
008    import java.util.Map;
009    import java.util.Set;
010    import java.util.Map.Entry;
011    
012    import org.apache.commons.logging.Log;
013    import org.apache.commons.logging.LogFactory;
014    
015    import calhoun.util.Assert;
016    
017    public class RootedBinaryPhylogeneticTree implements Serializable {
018            private static final long serialVersionUID = 4800435332663788163L;
019            private static final Log log = LogFactory.getLog(RootedBinaryPhylogeneticTree.class);
020    
021            // Member variables
022            public  ArrayList<BinaryTreeNode> T;
023            
024            // Constructors
025            public RootedBinaryPhylogeneticTree( String ss ) {
026                    ss = ss.trim();
027                    weaklyValidateNewickString(ss);         
028                    T = new ArrayList<BinaryTreeNode>();
029                    growBranch(T,-1,ss.substring(0,ss.length()-1)); 
030            }
031    
032            private RootedBinaryPhylogeneticTree( ArrayList<BinaryTreeNode> TT ) {      
033                    T = TT;
034            }
035            
036            //////////////////////////////////////////////////
037            // Public functions:
038            public String newick() {
039                    Assert.a( T.get(0).p == -1 );
040                    String s = newick_recursion( 0 );
041                    s = s + ";";
042                    return s;
043            }       
044    
045            public void summarize_tree() {
046                    log.debug("Now printing an extended summary of a RootedBiinaryPhylogeneticTree:");
047                    log.debug("Representation as a Newick String --->" + newick() );
048                    log.debug( "Total branch length (normalized): " + ( total_branch_length() ) );
049                    log.debug( "Longest branch length (normalized): " + ( longest_branch_length() ) );  
050                    for (int j=0; j<T.size(); j++) {
051                            log.debug( "" +  j + " -- " + T.get(j).toString() );
052                    }
053            }
054            
055            public double total_branch_length() {
056                    double dist = 0;
057                    for (int j=0; j<T.size(); j++) {
058                            if (T.get(j).p == -1)
059                                    Assert.a (T.get(j).d == 0);
060                            dist += T.get(j).d;
061                    }
062                    return dist;
063            }
064                    
065            public double longest_branch_length() {
066                    double longest = 0;
067                    for (int j=0; j<T.size(); j++) {
068                            if (T.get(j).p == -1) {
069                                    double x = T.get(T.get(j).l).d + T.get(T.get(j).r).d;
070                                    if (x>longest) longest = x;
071                            } else{
072                                    double x = T.get(j).d;
073                                    if (x>longest) longest = x;
074                            }
075                    }
076                    return longest;
077            }
078    
079            public int getNumSpecies() {
080                    return getSpeciesSet().size();
081            }
082            
083            /** Given an ordering of species in a multiple alignment file,
084             * determine an order for computations by which Felsenstein's algorithm can be performed, and the branch lengths involved at each step. 
085             * For example, for the phylogenetic tree</p>
086             * (((cnDT,cnDS),cnAB),(cnBB,cnBV))
087             * and the MSA in the order  (0=cnDT,1=cnDS,2=cnAB,3=cnBB,4=cnBV), can compute in this order:</p>
088             * 5 = combine(0,1)</p>
089             * 6 = combine(2,5)</p>
090             * 7 = combine(3,4)</p>
091             * 8 = combine(6,7)</p>
092             * and the final answer is at node 8.</p>
093             * 
094             * The computation for nodes 5-8 can be represented by</p>
095             *              ileft = new int[]{0,5,3,6};</p>
096             *              iright = new int[]{1,2,4,7};</p>
097             * 
098             * @param msaOrder the (n) species presented in a multiple alignment, in order of their listing in the file
099             * @return  an ordering in which the (n-1) recursive calculations for Felsenstein's algorithm can be performed.
100             */
101            public PhylogeneticTreeFelsensteinOrder getFelsensteinOrder(String[] msaOrder) {
102                    
103                    log.debug("Attempting to determine the order for a Felsenstein recursion given a phylogenetic tree");
104                    
105                    int nSpecies = msaOrder.length;
106                    int nSteps = nSpecies - 1;
107                                    
108    
109                    
110                    // Step 1: Make sure the species in msaOrder exactly correspond to the set of species in Phylogenetic Tree.
111                    Set<String> treeSpecies = getSpeciesSet();
112                    
113                    if (treeSpecies.size() != msaOrder.length) {
114                            Assert.a(false, "treeSpecies.size = " + treeSpecies.size() + "   and msaOrder.length = " + msaOrder.length);
115                    }
116                    for (int j=0; j<msaOrder.length; j++) {
117                            Assert.a(treeSpecies.contains(msaOrder[j]));
118                    }
119                    
120                    // Step 2: Initialize mappings from new to old and from old to new, such that for i=0..(nSpecies-1),
121                    //   T.get(new2old[i]).n = msaOrder[i]    and    old2new[new2old[i]] = i
122                    //   and also initilize the Boolean vector hasParent
123                    Map<Integer,Integer> new2old    = new HashMap<Integer,Integer>();
124                    Map<Integer,Integer> old2new    = new HashMap<Integer,Integer>();
125                    Map<Integer,Boolean> needParent = new HashMap<Integer,Boolean>();
126                    for (int j=0; j<nSpecies; j++) {
127                            int oldIndex = getSpeciesIndex(msaOrder[j]);
128                            log.debug("Species " + j + " is " + msaOrder[j]);
129                            old2new.put(oldIndex,j);
130                            new2old.put(j,oldIndex);
131                            needParent.put(j,true);
132                    }
133    
134                    
135                    for (int step=0; step<nSteps; step++) {
136                            // Step 3: Find a nextNode to add on.  It must be:
137                            //   a) the parent of one of the nodes already in the new list which needs a parent
138                            //   b) both childern of nextNode are already in new list and need a parent
139                            int nextNodeOldIndex = -1;
140                            for (int j=0; j<new2old.size(); j++) {
141                                    if (!needParent.get(j)) { continue; }
142                                    int thisOldIndex = new2old.get(j);
143                                    int parentOldIndex = T.get(new2old.get(j)).p;
144                                    int leftOldIndex = T.get(parentOldIndex).l;
145                                    int rightOldIndex = T.get(parentOldIndex).r;
146                                    Assert.a( (thisOldIndex==leftOldIndex) || (thisOldIndex==rightOldIndex) ); // We went up and then came down two different ways, one of which must have resulted in no net movement
147                                    if (!old2new.containsKey(leftOldIndex)) { continue; }
148                                    if (!old2new.containsKey(rightOldIndex)) { continue; }  // Are both children of the considered parent already on the new list?
149                                    nextNodeOldIndex = parentOldIndex;
150                                    break;
151                            }
152                            Assert.a(nextNodeOldIndex != -1); // The process above has to have resulted in the identification of a suitable parent.
153                            
154                            // Step 4: Add this node to the mappings
155                            new2old.put(nSpecies+step,nextNodeOldIndex);
156                            old2new.put(nextNodeOldIndex,nSpecies+step);
157                            needParent.put(nSpecies+step,true);
158                            needParent.put(old2new.get(T.get(nextNodeOldIndex).l),false);
159                            needParent.put(old2new.get(T.get(nextNodeOldIndex).r),false);
160                    }
161                    
162                    // Step 5: Now that we know the mapping from new nodes to old nodes, just write down the order of computations.
163                    int[] ileft = new int[nSteps];
164                    int[] iright = new int[nSteps];
165                    double[] bleft = new double[nSteps];
166                    double[] bright = new double[nSteps];           
167                    
168                    for (int step=0; step <nSteps; step++) {
169                            int oldNodeIndex = new2old.get(nSpecies+step);
170                            ileft[step] = old2new.get(T.get(oldNodeIndex).l);
171                            iright[step] = old2new.get(T.get(oldNodeIndex).r);
172                            bleft[step] = T.get(new2old.get(ileft[step])).d;
173                            bright[step] = T.get(new2old.get(iright[step])).d;                              
174                    }
175                    
176    //              for (int step=0; step <nSteps; step++) {
177    //                      int oldNodeIndex = old2new.get(nSpecies+step);
178    //                      ileft[step] = T.get(oldNodeIndex).l;
179    //                      iright[step] = T.get(oldNodeIndex).r;
180    //                      bleft[step] = T.get(old2new.get(ileft[step])).d;
181    //                      bright[step] = T.get(old2new.get(iright[step])).d;                              
182    //              }
183                    
184                    // Step 6: Construct the actual computation-order object, and return it.
185                    return new PhylogeneticTreeFelsensteinOrder( ileft, iright, bleft, bright);             
186            }
187    
188            
189            public RootedBinaryPhylogeneticTree subtree(String[] sn) {
190                    HashMap<String,Integer>  selected = new HashMap<String,Integer>();;
191                    
192                    for (int j=0; j<sn.length; j++) {
193                            selected.put(sn[j],0);
194                    }
195                    
196                    return subtree(selected);
197            }
198            
199            
200            private RootedBinaryPhylogeneticTree subtree( HashMap<String,Integer>  selected ) {
201                    ArrayList<BinaryTreeNode> oldT = T;
202                    
203                    for (int j=0; j<oldT.size(); j++) {
204                            if (oldT.get(j).l == -1) {
205                                    Assert.a( oldT.get(j).r == -1 );
206                                    Integer a = selected.get(oldT.get(j).n);
207                                    if (a != null) {
208                                            selected.put(oldT.get(j).n ,   selected.get(oldT.get(j).n).intValue() +1  );
209                                            oldT.get(j).lm=true;
210                                            oldT.get(j).rm=true;
211                                            int s = j;  int p = oldT.get(s).p;
212                                            while (p != -1) {
213                                                    Assert.a ( (oldT.get(p).l == s) || (oldT.get(p).r == s) );
214                                                    if (oldT.get(p).l == s) { oldT.get(p).lm = true; }
215                                                    if (oldT.get(p).r == s) { oldT.get(p).rm = true; }
216                                                    s = p; p = oldT.get(s).p;
217                                            }
218                                    }
219                            }
220                    }
221                    
222                    log.debug( "About to doublecheck that all selected species were found exactly once...");
223                    Iterator ii = selected.entrySet().iterator();
224                    
225                    while ( ii.hasNext() ) {
226                            Map.Entry<String,Integer> me = (Entry<String, Integer>) ii.next();
227                            if (me.getValue() != 1) {
228                                    log.debug("  NOT FOUND UNIQUELY: " + me.getKey() +  "   " + me.getValue() );
229                                    
230                            }
231                            
232                    }
233                    log.debug( "DONE" );            
234                    
235                    
236                    
237                    Map<Integer,Integer> old_to_new = new HashMap<Integer,Integer>();
238                    int numNew = 0;
239                    for (int j=0; j<oldT.size(); j++ ) {
240                            if (oldT.get(j).lm && oldT.get(j).rm) {
241                                    old_to_new.put(j,numNew);
242                                    numNew++;
243                            }
244                    }
245                    
246                    ArrayList<BinaryTreeNode> newT = new ArrayList<BinaryTreeNode>();
247                    for (int j=0; j<numNew; j++) {
248                            newT.add(new BinaryTreeNode());
249                    }
250                    
251                    for (int j=0; j<oldT.size(); j++) {
252                            if (!oldT.get(j).lm || !oldT.get(j).rm) continue;
253                            int s = old_to_new.get(j);
254                            newT.set(s , oldT.get(j) );
255                            
256                            int p = oldT.get(j).p;  double dist=oldT.get(j).d;
257                            while ( p!= -1 ) {
258                                    if (oldT.get(p).lm && oldT.get(p).rm) break;
259                                    dist += oldT.get(p).d;
260                                    p = oldT.get(p).p;
261                            }
262                            if (p==-1) {
263                                    newT.get(s).p=-1; newT.get(s).d=0;
264                            } else {
265                                    newT.get(s).p=old_to_new.get(p);  newT.get(s).d=dist;
266                            }
267                            
268                            
269                            int l = oldT.get(j).l;
270                            if (l == -1) {
271                                    newT.get(s).l = -1;
272                            } else {
273                                    while (!oldT.get(l).lm || !oldT.get(l).rm) {
274                                            Assert.a (oldT.get(l).lm || oldT.get(l).rm);
275                                            if (oldT.get(l).lm) { l = oldT.get(l).l; continue; }
276                                            if (oldT.get(l).rm) { l = oldT.get(l).r; continue; }
277                                    }
278                                    newT.get(s).l = old_to_new.get(l);
279                            }
280                            
281                            int r = oldT.get(j).r;
282                            if (r == -1) {
283                                    newT.get(s).r = -1;
284                            } else {
285                                    while (!oldT.get(r).lm || !oldT.get(r).rm) {
286                                            Assert.a (oldT.get(r).lm || oldT.get(r).rm);
287                                            if (oldT.get(r).lm) { r = oldT.get(r).l; continue; }
288                                            if (oldT.get(r).rm) { r = oldT.get(r).r; continue; }
289                                    }
290                                    newT.get(s).r = old_to_new.get(r);
291                            }
292                    }
293                    
294                    RootedBinaryPhylogeneticTree subRBPT = new RootedBinaryPhylogeneticTree(newT);
295                    return subRBPT;
296            }
297            
298            /////////////////////////////////////////////////////////
299            // Internal private functions:
300    
301    //      private Boolean containsSpecies(String name) {
302    //              Boolean ret = false;
303    //              for (int j=0; j<T.size(); j++) {
304    //                      if (T.get(j).n == name) {
305    //                              ret = true;
306    //                      }
307    //              }
308    //              return ret;
309    //      }
310    
311            private Integer getSpeciesIndex(String name) {
312                    Integer ret = -1;
313                    for (int j=0; j<T.size(); j++) {
314                            String temp =  T.get(j).n;
315                            //System.out.println("This name is " + temp);
316                            if ( name.equals(temp) ) {
317                                    ret = j;
318                                    //System.out.println("  Species " + j + " = " + name);
319                            } else {
320                                    //System.out.println("  --->" + name + "<--- and --->" + temp + "<--- are not equal");
321                            }
322                    }
323                    Assert.a(ret != -1," Could not find species " + name);
324                    return ret;
325            }
326            
327            public Set<String> getSpeciesSet() {
328                    Set<String> ret = new HashSet<String>();
329                    for (int j=0; j<T.size(); j++) {
330                            BinaryTreeNode btn = T.get(j);
331                            if (btn.l == -1) { // this node has no left child
332                                    Assert.a(btn.r == -1); // if no left child, then its a leaf and hence also no right child either.
333                                    String name = btn.n;
334                                    Assert.a(name != ""); // I require that each leaf node have a name.
335                                    Assert.a(!ret.contains(name)); // I want no repitition; if same species is listed in tree twice then fail here
336                                    ret.add(name);
337                            }
338                    }
339                    return ret;
340            }
341            
342            private String newick_recursion( int top ) {
343                    String s = "";  
344                    
345                    if (T.get(top).l == -1) {
346                            Assert.a( T.get(top).r == -1);
347                            s = s + T.get(top).n;
348                    } else {
349                            s = s + "(";
350                            s = s + newick_recursion( T.get(top).l );
351                            s = s + ",";
352                            s = s + newick_recursion( T.get(top).r );
353                            s = s + ")";
354                    }
355                    s = s + ":" + T.get(top).d;
356                    return s;
357            };
358            
359            private void growBranch( ArrayList<BinaryTreeNode> T1, int parent, String S ) {
360                    
361                    int currentNode = T1.size();
362                    
363                    int x = S.lastIndexOf(":");
364                    
365                    Assert.a( x != -1, "we require that all nodes in Newick format have a branch length" );  
366                    Assert.a( x != S.length()-1 );
367                    
368                    double dist = Double.parseDouble(S.substring(x+1));
369                    
370                    String S2 = S.substring(0,x);
371                    String S3;
372                    
373                    if (S2.charAt(0)=='(') {
374                            Assert.a (S2.charAt(S2.length()-1) == ')');
375                            S3 = S2.substring(1,S2.length()-1);
376                    } else {
377                            S3 = S2;
378                    }
379                    
380                    int depth=0;
381                    for (int j=0; j<(S3.length()-1); j++) {
382                            if (S3.charAt(j) == '(') depth++;
383                            if (S3.charAt(j) == ')') depth--;
384                            if ((depth == 0) && (S3.charAt(j)==',')) {
385                                    String leftString = S3.substring(0,j);
386                                    String rightString = S3.substring(j+1,S3.length());
387                                    BinaryTreeNode TN = new BinaryTreeNode( parent, -1,-1, dist, "");
388                                    T1.add(TN);
389                                    T1.get(currentNode).l = T1.size();
390                                    growBranch(T1,currentNode,leftString);
391                                    T1.get(currentNode).r = T1.size();
392                                    growBranch(T1,currentNode,rightString);
393                                    return;
394                            }
395                    }
396                    Assert.a(depth==0);
397                    
398                    BinaryTreeNode T2 = new BinaryTreeNode( parent,-1,-1,dist,S3);
399                    T1.add(T2);
400                    return;
401            }       
402            
403            private void weaklyValidateNewickString(String ss) {
404                    int depth=0;
405                    Assert.a( ss.charAt(ss.length()-1) == ';', "Invalid newick tree: "+ss);
406                    for (int j=0; j<(ss.length()-1); j++) {
407                            if (ss.charAt(j) == '(') depth++;
408                            if (ss.charAt(j) == ')') depth--;
409                            Assert.a(depth >= 0);
410                    }
411                    Assert.a(depth==0);
412            }
413    
414            public int nSpecies() {
415                    int m = T.size(); // should be m=2*n-1, where n is number of species
416                    Assert.a((m%2)==1);
417                    return m/2;
418            }
419    
420            
421    }