import sys, re, argparse
import random
from Bio import Phylo
from cStringIO import StringIO

def lookup_by_names(tree):
    """
    Index the named clades of ``tree`` by name.

    Clades whose name is empty or None are skipped.

    :param tree: object exposing ``find_clades()`` (a Bio.Phylo tree here).
    :return: dict mapping clade name -> clade object.
    :raises ValueError: if two clades carry the same name.
    """
    by_name = {}
    for node in tree.find_clades():
        label = node.name
        if not label:
            continue
        if label in by_name:
            raise ValueError("Duplicate key: %s" % label)
        by_name[label] = node
    return by_name

def TransformTree(treedata, numberOfInternalNode):
    """
    Give a number to each internal node of a Newick string.

    Every '(' is replaced by '#i#' and its matching ')' by the same '#i#'.
    Here i = 1, 2, ... numbers the opening parentheses from left to right,
    so node 1 is the root and every descendant has a larger number than its
    ancestors.

    Only the first ``numberOfInternalNode`` opening parentheses are
    numbered. Any later pair is left as '(' / ')' and stays correctly
    matched. main passes ``treedata.count('(')``, so every pair is numbered.

    Implemented as one left-to-right pass with a stack of open nodes. This
    replaces one regular-expression substitution per node, which was
    quadratic in the size of the tree.

    :param treedata: single-line Newick string.
    :param numberOfInternalNode: how many '(' to number.
    :return: the marked-up string.
    """
    out = []
    open_nodes = []  # node numbers (None = unnumbered) of currently open '('
    next_id = 0
    for ch in treedata:
        if ch == '(':
            next_id += 1
            if next_id <= numberOfInternalNode:
                open_nodes.append(next_id)
                out.append('#%d#' % next_id)
            else:
                open_nodes.append(None)
                out.append(ch)
        elif ch == ')' and open_nodes:
            node_id = open_nodes.pop()
            out.append(ch if node_id is None else '#%d#' % node_id)
        else:
            # Ordinary characters and unmatched ')' are copied unchanged.
            out.append(ch)
    return ''.join(out)

def GetLeafBase(input):
    """Unimplemented placeholder; no code shown here calls it. Always returns None."""
    return None

def GetChildren(tree, numberOfInternalNode):
    """
    Map every internal node number to the list of its direct children.

    ``tree`` is the marked-up string from TransformTree, where the content
    of node k sits between two '#k#' markers.

    Nodes are processed from the highest number down. After node k is read,
    its whole span '#k#...#k#' is collapsed to a single '#k#'. So when the
    parent is processed, each internal child appears as one '#m#' token.

    A child entry is either:
    - the internal node number (as a string), or
    - the leaf name: the text before the first ':' (branch length), with
      surrounding whitespace removed.

    :return: dict node number (str) -> list of child ids (str).
    """
    node_marker = re.compile(r'#(\d+)#')
    mytree = tree
    node2children = {}
    for i in reversed(range(1, numberOfInternalNode + 1)):
        k = str(i)
        marker = '#' + k + '#'
        children = mytree.split(marker)[1]
        kids = []
        for part in children.split(','):
            m = node_marker.search(part)
            if m:
                kids.append(m.group(1))
            else:
                kids.append(part.split(':')[0].strip())
        node2children[k] = kids
        # Collapse with a literal replace. The span holds leaf names and
        # branch lengths such as '1e+05'; as a regex, the '+' (or '|', '*',
        # '[' in names) would stop the pattern from matching itself, the
        # collapse would silently fail, and the parent would see extra
        # children.
        mytree = mytree.replace(marker + children + marker, marker)
    return node2children

def GetLeaves(tree, numberOfInternalNode, leaves):
    """
    Map every internal node number to the names in ``leaves`` found below it.

    ``tree`` is the marked-up string from TransformTree. The text between the
    two '#k#' markers of node k is the node's full subtree.

    That text is cut into labels at commas and at '#m#' markers. Each label
    loses its ':' branch-length suffix and surrounding whitespace. A name
    from ``leaves`` is reported only if it equals one of those labels
    exactly. Substring/regex matching would wrongly report leaf 'A' under
    any node containing 'AB', or let '.' in a name match any character.

    The per-node lists follow the iteration order of ``leaves``.

    :return: dict node number (str) -> list of leaf names.
    """
    separators = re.compile(r'#\d+#|,')
    node2leaves = {}
    for i in reversed(range(1, numberOfInternalNode + 1)):
        k = str(i)
        children = tree.split('#' + k + '#')[1]
        labels = set(token.split(':')[0].strip() for token in separators.split(children))
        node2leaves[k] = [leaf for leaf in leaves if leaf in labels]
    return node2leaves

def GetNodeBase(node, node2base, leaf2base, leaves):
    """
    Return the candidate base list of ``node``.

    Names in ``leaves`` are looked up in ``leaf2base`` (their observed
    genotype). Any other node is looked up in ``node2base`` (the
    internal-node lists filled in so far by the forward pass).

    If a leaf has no genotype entry, an error and the leaf name are written
    to stderr and the program exits with status 1. A missing internal node
    raises KeyError.
    """
    if node not in leaves:
        return node2base[node]
    try:
        return leaf2base[node]
    except KeyError:
        sys.stderr.write('Leaves don\'t match genotype, quit..\n')
        sys.stderr.write(str(node) + '\n')
        # Non-zero status so shell pipelines notice the failure.
        sys.exit(1)

def ForwardPass(internode2base, internode2children, leaf2base, numberOfInternalNode, leaves):
    """
    Fitch algorithm, bottom-up (forward) pass for one site.

    Internal nodes are visited from the highest number down to 1, so every
    child is resolved before its parent (TransformTree numbers descendants
    after their ancestors).

    For node k with children u, w and candidate lists su, sw:
    - ['N'] (missing) children are ignored. If both are ['N'] the parent is
      ['N']; if only one is, the parent takes the other child's list.
    - Otherwise the parent gets the intersection of su and sw if it is
      non-empty. If it is empty, the parent gets the union and the
      parsimony score is incremented.

    ``internode2base`` is filled in place; BackwardPass reads it afterwards.

    If a node is missing from ``internode2children`` or does not have
    exactly two children, an error is written to stderr and the program
    exits with status 1.

    :return: (shallow copy of internode2base after the pass, parsimony score)
    """
    score = 0
    for i in reversed(range(1, numberOfInternalNode + 1)):
        k = str(i)
        try:
            u, w = internode2children[k]
        except (KeyError, TypeError, ValueError):
            # ValueError: the node does not have exactly two children.
            sys.stderr.write('Please use bifurcate tree, quit..\n')
            sys.exit(1)
        su = GetNodeBase(u, internode2base, leaf2base, leaves)
        sw = GetNodeBase(w, internode2base, leaf2base, leaves)
        if su == ['N'] and sw == ['N']:
            internode2base[k] = ['N']
        elif su == ['N']:
            internode2base[k] = sw
        elif sw == ['N']:
            internode2base[k] = su
        else:
            common = set(su) & set(sw)
            if common:
                internode2base[k] = list(common)
            else:
                internode2base[k] = list(set(su + sw))
                score += 1
    return internode2base.copy(), score
        
def BackwardPass(internode2base, internode2children, leaf2base, numberOfInternalNode, leaves):
    """
    Fitch algorithm, top-down (backward) pass for one site.

    Works in place on ``internode2base`` as left by ForwardPass.

    The root ('1') is resolved to one base picked at random from its
    candidate list. If that list contains 'N', the root is left as is and a
    message is written to stderr instead.

    Nodes are then visited from 1 upwards, parents before children:
    - A node whose value is ['N'] is deleted from the dict.
    - For each internal child (not in ``leaves``) whose candidate list is
      not ['N']:
      - if the parent's base is among the child's candidates, the child
        takes the parent's base;
      - otherwise it takes the first candidate.

    Both children's lists are read before either child is updated.
    ``leaf2base`` is only read. Returns None.
    """
    root_candidates = internode2base['1']
    if 'N' not in root_candidates:
        internode2base['1'] = random.choice(root_candidates)
    else:
        sys.stderr.write("Lines with all missing data detected, skip...\n")

    for number in range(1, numberOfInternalNode + 1):
        node = str(number)
        parent_base = internode2base[node]
        if parent_base == ['N']:
            # Missing-data node: removed so CountMutation ignores it.
            del internode2base[node]
        left, right = internode2children[node]
        left_set = leaf2base[left] if left in leaves else internode2base[left]
        right_set = leaf2base[right] if right in leaves else internode2base[right]
        for child, child_set in ((left, left_set), (right, right_set)):
            if child_set == ['N'] or child in leaves:
                continue
            if parent_base in child_set:
                internode2base[child] = parent_base
            else:
                internode2base[child] = internode2base[child][0]

def CountMutation(internode2base, internode2nmut, internode2children, leaf2base, numberOfInternalNode):
    """
    Add this site's base changes to the per-node mutation counters.

    For every internal node k (1..numberOfInternalNode) and each of its two
    children, the child's counter in ``internode2nmut`` is incremented when
    the child's base differs from k's base.

    Which children are compared:
    - internal children only if they are still present in ``internode2base``;
    - leaf children (keys of ``leaf2base``) only if their genotype is not
      ['N'], comparing against the genotype's first element.

    A child that appears in both dicts is checked twice, as in the original.
    ``internode2nmut`` is updated in place; returns None.
    """
    for number in range(1, numberOfInternalNode + 1):
        parent = str(number)
        left, right = internode2children[parent]
        pair = (left, right)
        for child in pair:
            if child in internode2base and internode2base[parent] != internode2base[child]:
                internode2nmut[child] += 1
        for child in pair:
            if child not in leaf2base:
                continue
            genotype = leaf2base[child]
            if genotype != ['N'] and internode2base[parent] != genotype[0]:
                internode2nmut[child] += 1
        

def OutputDebug(tree, internode2children, internode2leaves, internode2baseAfterForwardPass, internode2base, internode2nmut, score):
    """
    Dump the per-site working state to stderr, one item per line.

    The items appear in argument order, framed by a banner line before and
    after.
    """
    banner = '******Tree Info******'
    sections = (banner, tree, internode2children, internode2leaves,
                internode2baseAfterForwardPass, internode2base,
                internode2nmut, score, banner)
    for item in sections:
        sys.stderr.write('%s\n' % (item,))
    
def OutputResults(internode2nmut,args):
    """
    Write one "<node>\\t<mutation count>" line per node to args.outputFile[0].

    Lines are ordered by count, largest first. Ties are ordered by node name
    so the output is reproducible; with dict order alone, tie order was
    arbitrary. The file is closed even if a write fails.
    """
    ranked = sorted(internode2nmut.items(), key=lambda item: (-item[1], item[0]))
    with open(args.outputFile[0], 'w') as fw:
        for node, count in ranked:
            fw.write(node + '\t' + str(count) + '\n')

def OutputTree(tree,internode2nmut, numberOfInternalNode, leaves, args):
    """
    Write ``tree`` back as Newick with mutation counts as labels.

    ``tree`` is the marked-up string produced by TransformTree. For every
    internal node k:
    - the opening '#k#' becomes '(';
    - the closing '#k#' becomes ')' followed by internode2nmut[k].

    A number directly after the closing marker (e.g. a support value) is
    replaced by the count. If there is none, the count is inserted. Either
    way no '#k#' marker is left in the output; previously the root, and
    nodes carrying only ':length', kept their raw marker.

    Each name in ``leaves`` that is a key of ``internode2nmut`` is then
    suffixed with two spaces and its count ("A" -> "A  3"). The name is
    matched as a whole, literal leaf token, so leaf 'A' cannot hit inside
    'AB' or inside a branch length.

    The result is written to args.outputFile[0].
    """
    # Optional numeric label right after a closing marker; replaced by the count.
    support = r'(?:[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)?'
    mytree = tree
    for i in range(1, numberOfInternalNode + 1):
        k = str(i)
        marker = '#' + k + '#'
        # TransformTree emits the opening marker before the closing one.
        mytree = mytree.replace(marker, '(', 1)
        mytree = re.sub(re.escape(marker) + support, ')' + str(internode2nmut[k]), mytree, count=1)

    for name, nmut in internode2nmut.items():
        if name not in leaves:
            continue
        # A leaf label follows '(' or ',' and ends at ':', ',', ')', ';' or
        # the end of the string.
        pattern = re.compile(r'([(,]\s*)' + re.escape(name) + r'(?=\s*(?:[:,);]|$))')
        annotated = name + '  ' + str(nmut)
        # Function replacement: backslashes in names are not template escapes.
        mytree = pattern.sub(lambda m, text=annotated: m.group(1) + text, mytree, count=1)

    with open(args.outputFile[0], 'w') as fw:
        fw.write(mytree)
    

def main(argv):
    """
    Command-line entry point: count, per tree branch, where SNPs change.

    Steps:
    1. Read the first line of the Newick tree file and number its internal
       nodes (TransformTree).
    2. Derive each node's children (GetChildren).
    3. For every non-blank line of the genotype file, run the Fitch
       forward/backward passes and add the resulting per-node changes to
       ``internode2nmut`` (CountMutation).
    4. Write the tree annotated with those counts (OutputTree).

    Genotype file format:
    - tab-separated;
    - the genotype of the i-th ID (0-based) in the sample-ID file is read
      from column FIRST_SAMPLE_COLUMN + i (0-based), so columns 0-5 are
      never read as sample genotypes;
    - genotypes longer than one character are treated as missing ('N').

    With ``-z K`` (K != 0), sites whose parsimony score exceeds K are not
    counted. A bare ``-z`` means no filtering.
    """
    FIRST_SAMPLE_COLUMN = 6

    parser = argparse.ArgumentParser(prog='SNPtrack.py', description='Track the origin of mutations on the phylogenetic tree.')
    # const=0: a bare "-z" previously produced None, which was then compared
    # against the integer parsimony score.
    parser.add_argument('-z', nargs='?', default=0, const=0, type=int, help='Number of mutations on the tree.')
    parser.add_argument('newickTree', nargs=1, help='Tree file in Newick format.')
    parser.add_argument('genotypeFile', nargs=1, help='File with the genotype of each sample on the tree.')
    parser.add_argument('sampleID', nargs=1, help='File with the IDs of the samples on the tree.')
    parser.add_argument('outputFile', nargs=1, help='File for output')
    args = parser.parse_args(argv)

    with open(args.newickTree[0]) as fp:  # e.g. tree.codeONLY.nwk
        treedata = fp.readline().strip()
    tree = Phylo.read(StringIO(treedata), "newick")
    names = lookup_by_names(tree)
    # A set: the per-site passes test leaf membership for every node.
    leaves = set(names)
    numberOfInternalNode = treedata.count('(')
    tree = TransformTree(treedata, numberOfInternalNode)

    internode2nmut = {}
    for i in range(1, numberOfInternalNode + 1):
        internode2nmut[str(i)] = 0
    for leaf in leaves:
        internode2nmut[leaf] = 0
    internode2children = GetChildren(tree, numberOfInternalNode)
    internode2leaves = GetLeaves(tree, numberOfInternalNode, leaves)  # only used by OutputDebug

    with open(args.sampleID[0]) as fp:  # e.g. Sample_id.txt
        samples = [line.strip() for line in fp]
    sample2index = {}
    for offset, sid in enumerate(samples):
        sample2index[sid] = FIRST_SAMPLE_COLUMN + offset

    leaf2base = {}
    with open(args.genotypeFile[0]) as fp:  # e.g. multiple_sample_SNPs_2ndversion_MAF0.01_striking.txt
        for line in fp:
            if not line.strip():
                continue  # blank line, e.g. a trailing newline at EOF
            arr = line.strip().split('\t')
            for sid in samples:
                gt = arr[sample2index[sid]]
                leaf2base[sid] = ['N'] if len(gt) > 1 else [gt]
            # Fitch algorithm - Forward Pass
            internode2base = {}
            internode2baseAfterForwardPass, score = ForwardPass(internode2base, internode2children, leaf2base, numberOfInternalNode, leaves)
            if args.z != 0 and score > args.z:
                continue  # site exceeds the requested parsimony score
            # Fitch algorithm - Backward Pass
            BackwardPass(internode2base, internode2children, leaf2base, numberOfInternalNode, leaves)
            CountMutation(internode2base, internode2nmut, internode2children, leaf2base, numberOfInternalNode)
            # Debug hook: OutputDebug(tree, internode2children, internode2leaves, internode2baseAfterForwardPass, internode2base, internode2nmut, score)
    # Alternative tabular output: OutputResults(internode2nmut, args)
    OutputTree(tree, internode2nmut, numberOfInternalNode, leaves, args)
if __name__ == "__main__":
    # Script entry: hand every argument after the program name to main().
    cli_args = sys.argv[1:]
    main(cli_args)
