import sys, re
from Bio import Phylo
from StringIO import StringIO

# Node keys used throughout: leaves are species codes ('ptr', 'gma', ...),
# internal nodes are the integer IDs given as "branch lengths" in cafeTree
# below. The root (33) appears only as a parent value, never as a key.

# Position of each branch's p-value in the flattened p-value list parsed from
# column 4 of the CAFE report. The values are a permutation of 0..33, so every
# position is mapped exactly once.
id2pind = {
    'ptr': 0, 'gma': 1, 'ath': 4, 'gra': 5, 'vvi': 7, 'sly': 10, 'ugi': 11,
    'nnu': 13, 'bdi': 16, 'tur': 17, 'osa': 19, 'zma': 22, 'sbi': 23,
    'sit': 25, 'mac': 27, 'pab': 29, 'smo': 31, 'ppa': 33,
    1: 2, 5: 3, 3: 6, 7: 8, 11: 9, 9: 12, 13: 14, 17: 18, 19: 20, 27: 15,
    23: 24, 25: 21, 21: 26, 15: 28, 29: 30, 31: 32,
}

# Parent of every non-root node in cafeTree.
parentalNode = {
    'ptr': 1, 'gma': 1, 'ath': 5, 'gra': 5, 'vvi': 7, 'sly': 11, 'ugi': 11,
    'nnu': 13, 'bdi': 17, 'tur': 17, 'osa': 19, 'zma': 23, 'sbi': 23,
    'sit': 25, 'mac': 27, 'pab': 29, 'smo': 31, 'ppa': 33,
    1: 3, 5: 3, 3: 7, 7: 9, 11: 9, 9: 13, 13: 15, 17: 19, 19: 21, 27: 15,
    23: 25, 25: 21, 21: 27, 15: 29, 29: 31, 31: 33,
}

# Reference topology in Newick. Each "branch length" is the node's integer ID:
# leaves carry even IDs 0..34, internal nodes odd IDs 1..33 (33 is the root).
cafeTree='(((((((((ptr:0,gma:2):1,(ath:4,gra:6):5):3,vvi:8):7,(sly:10,ugi:12):11):9,nnu:14):13,((((bdi:16,tur:18):17,osa:20):19,((zma:22,sbi:24):23,sit:26):25):21,mac:28):27):15,pab:30):29,smo:32):31,ppa:34):33'

# Scratch map: node key -> value parsed from the current family's tree.
branchLength = {}

# Per-branch tallies, one entry per non-root node (the keys of parentalNode).
# Derived from parentalNode so the key sets cannot drift apart. Comprehensions
# (not dict.fromkeys(..., [])) give every node its own list.
expansion = {node: 0 for node in parentalNode}
contraction = {node: 0 for node in parentalNode}
expansion_clusterID = {node: [] for node in parentalNode}
contraction_clusterID = {node: [] for node in parentalNode}

# Parse the reference topology. Its branch lengths hold integer node IDs, so
# reading them off the internal clades in preorder (root first) yields the ID
# that corresponds to each position of get_nonterminals() on a tree with the
# same topology.
handle = StringIO(cafeTree)
cafeTree = Phylo.read(handle, "newick")
innerNodeId = [int(clade.branch_length)
               for clade in cafeTree.get_nonterminals(order='preorder')]


# A branch is tallied only when its p-value is <= this cutoff.
PVALUE_CUTOFF = 0.01
# Leading report lines skipped before the per-family rows.
HEADER_LINES = 11
# Default input; pass another path as the first command-line argument
# (an alternative run used 'cafe_18g_mcl_output.cafe').
DEFAULT_CAFE_REPORT = 'cafe_18_monocods_output.cafe'

_NODE_LENGTH = re.compile(r':\d+')
_PARENS = re.compile(r'[()]')


def _record_change(nodeId, familyId):
    """Tally familyId as an expansion or contraction on the branch to nodeId.

    Compares branchLength[nodeId] with the value at its parent
    (parentalNode[nodeId]). A larger child value is an expansion and a
    smaller one is a contraction. Equal values are neither, so they are not
    tallied.
    """
    childSize = branchLength[nodeId]
    parentSize = branchLength[parentalNode[nodeId]]
    if childSize > parentSize:
        expansion[nodeId] += 1
        expansion_clusterID[nodeId].append(familyId)
    elif childSize < parentSize:
        contraction[nodeId] += 1
        contraction_clusterID[nodeId].append(familyId)


reportPath = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CAFE_REPORT
with open(reportPath) as fp:
    myfile = fp.read().splitlines()[HEADER_LINES:]  # drop the report header

for line in myfile:
    if not line.strip():
        continue  # skip blank lines instead of failing on arr[1]
    arr = line.split('\t')
    familyId = arr[0]

    # Column 2 is the annotated family tree. Remove ':<digits>' lengths, then
    # rewrite 'name_N' as 'name:N' so Bio.Phylo exposes N as branch_length
    # (presumably the family size at that node -- confirm against CAFE docs).
    newick = _NODE_LENGTH.sub('', arr[1]).replace('_', ':')
    tree = Phylo.read(StringIO(newick), "newick")
    # Column 4: per-branch p-values nested in parentheses; flatten them.
    pvalues = [float(v) for v in _PARENS.sub('', arr[3]).split(',')]

    leaves = tree.get_terminals()
    innerNodes = tree.get_nonterminals()
    # Internal nodes are keyed by position, so the topology must match.
    if len(innerNodes) != len(innerNodeId):
        raise ValueError('family %s: tree has %d internal nodes, expected %d'
                         % (familyId, len(innerNodes), len(innerNodeId)))

    for leaf in leaves:
        branchLength[leaf.name] = leaf.branch_length
    for nodeId, node in zip(innerNodeId, innerNodes):
        branchLength[nodeId] = node.branch_length

    # Tally significant branches. innerNodeId[0] is the root (first in
    # preorder); it has no parent branch, so it is skipped.
    for leaf in leaves:
        if pvalues[id2pind[leaf.name]] <= PVALUE_CUTOFF:
            _record_change(leaf.name, familyId)
    for nodeId in innerNodeId[1:]:
        if pvalues[id2pind[nodeId]] <= PVALUE_CUTOFF:
            _record_change(nodeId, familyId)
    
# Report expansions on one branch: the expansion count, the number of recorded
# expansion family IDs, and the tab-separated IDs themselves. Set REPORT_NODE
# to another leaf code or integer node ID to report a different branch.
REPORT_NODE = 'osa'
print(str(expansion[REPORT_NODE]))
print(len(expansion_clusterID[REPORT_NODE]))
print('ExpansionClusters:\t' + '\t'.join(expansion_clusterID[REPORT_NODE]))