#!/usr/bin/env python

import argparse
import copy
import os
import sys

import ete2

# Command-line interface. Note the epilog's "--tre" works because argparse
# accepts unambiguous prefixes of long options (here --tree).
parser = argparse.ArgumentParser(description= "arraySynteny: A program to find tandem arrays, establish synteny between tandem array loci and use synteny \
                        and phylostratigraphic analysis to date the tandem arrays",
                        epilog = 'To replicate McKenzie and Kronauer 2018, use program as: python arraySynteny.py --ortholog_file orthologs.txt \
                                 --gene_id_key Or --exclude_gene_ids Orc1,Orc4,Org-1 --tre SyntenyTree.tre --outgroup NvitOr1 --gffs Nvit.SyntenyAnalysis.gff \
                                 Amel.SyntenyAnalysis.gff Sinv.SyntenyAnalysis.gff Obir.OGS4.0.gff --species_list Nvit,Amel,Sinv,Obir')

# The options marked required=True are unconditionally dereferenced later in
# the script; declaring them required turns a confusing AttributeError /
# TypeError on None into an argparse usage error.
parser.add_argument('--ortholog_file', required = True, help = 'orthologs file output from orthomcl')
parser.add_argument('--gene_id_key', required = True, help = 'text that appears in the IDs of the focal gene family but not in other gene families')
# Optional: default to an empty string so .split(',') downstream never sees None.
parser.add_argument('--exclude_gene_ids', default = '', help = 'comma seperated list of gene ids that would match the gene_id_key but should not be considered part of the focal gene family')
parser.add_argument('--tree', required = True, help = 'phylogenetic tree of genes of interest. terminal names should be preeceded by a four \
                    letter species code as used in gff file names and --species_list')
parser.add_argument('--outgroup', default = False, help = 'name of tree terminal to root tree (if not already correctly rooted)')
parser.add_argument('--gffs', nargs='*', required = True, help = 'gff files for each species. File name should be formated as "Species_code.X.gff"')
parser.add_argument('--species_list', required = True, help = 'comma seperated list of species codes, ordered oldest diverging to youngest. Note that \
                    this program can currently only handle cases of (Sp1,(Sp2,(Sp3,Sp4)))')


args = parser.parse_args()
# Species codes ordered oldest-diverging first; the dating logic indexes
# splist[0] through splist[3], so fewer than four codes cannot work.
splist = args.species_list.split(',')
if len(splist) < 4:
    parser.error('--species_list must contain at least four comma separated species codes')
#setup orthodict
# orthodict maps each gene ID to the list of IDs it is paired with in the
# tab-separated orthologs file (first two columns); pairs are recorded in both
# directions. The first 5 characters of each ID are dropped -- presumably a
# four-letter species code plus separator as written by OrthoMCL (verify).
orthologs_file = args.ortholog_file
orthodict = {}
with open(orthologs_file) as ortholog_handle:
    for line in ortholog_handle:
        # Strip the line ending so a two-column file does not leave '\n' in the key.
        fields = line.rstrip('\r\n').split('\t')
        if len(fields) < 2:
            continue  # blank or malformed line
        gene_a = fields[0][5:]
        gene_b = fields[1][5:]
        orthodict.setdefault(gene_a, []).append(gene_b)
        orthodict.setdefault(gene_b, []).append(gene_a)


# Load the gene tree (its text is passed to ete2.Tree) and optionally re-root
# it on the named outgroup terminal.
with open(args.tree) as tree_handle:
    my_tree = ete2.Tree(tree_handle.read())
if args.outgroup:
    outgroup_matches = my_tree.get_leaves_by_name(args.outgroup)
    if not outgroup_matches:
        # Previously an opaque IndexError; report the offending name instead.
        sys.exit("arraySynteny: outgroup '%s' not found among tree terminals" % args.outgroup)
    tree_outgroup = outgroup_matches[0]
    my_tree.set_outgroup(tree_outgroup)


gff_files = args.gffs

# genes[species][seqid][(start, end)] = gene ID, built from 'gene' and
# 'match' features of each GFF file.
genes = {}

#process gff
for gff in gff_files:
    # Species code = file name up to its first '.', ignoring any directory
    # component (a dotted directory such as '../' would otherwise yield '').
    spname = os.path.basename(gff).split('.')[0]
    genes[spname] = {}
    with open(gff) as gff_file:
        for line in gff_file:
            if line.startswith('#'):
                continue  # GFF comment / directive line
            fields = line.rstrip('\r\n').split('\t')
            # Test the feature-type column itself rather than searching the
            # whole line, which could match 'gene' in another column.
            if len(fields) < 9 or fields[2] not in ('gene', 'match'):
                continue
            seqid = fields[0].split()[0]
            # ID is taken from the first attribute; any '-against-...' suffix
            # (seen on 'match' IDs) is dropped.
            gene_ID = fields[8].split(';')[0].split('=')[1].split('-against-')[0]
            coords = (int(fields[3]), int(fields[4]))
            genes[spname].setdefault(seqid, {})[coords] = gene_ID


gene_lists = []       # all gene IDs in coordinate order; 'chromeBreak' brackets each seqid, 'spBreak' follows each species
arrays = {}           # species -> list of arrays, each a list of focal-family gene IDs
nodes_in_arrays = {}  # species -> per-array list of tree leaves (members missing from the tree are skipped)
array_ends = {}       # species -> per-array [first, last] member indices into gene_lists

#debug
errorlist = []  # focal-family genes that have no matching leaf in the tree

# IDs that contain gene_id_key but are not family members. The option may be
# absent, in which case nothing is excluded. Built once instead of per gene.
if args.exclude_gene_ids:
    excluded_ids = set(args.exclude_gene_ids.split(','))
else:
    excluded_ids = set()


def _tree_leaf(gene_name):
    """Return the first tree leaf named gene_name, or None (recording the
    name in errorlist) when the tree has no such leaf."""
    leaves = my_tree.get_leaves_by_name(gene_name)
    if not leaves:
        errorlist.append(gene_name)
        return None
    return leaves[0]


# An array is a run of focal-family genes along one seqid, tolerating a single
# interrupting non-family gene; two consecutive non-family genes close it.
for sp in genes:
    arrays[sp] = []
    nodes_in_arrays[sp] = []
    array_ends[sp] = []
    for seqid in genes[sp]:
        in_array = False
        gap = 0  # consecutive non-family genes since the last family member
        gene_lists.append('chromeBreak')
        for coords in sorted(genes[sp][seqid]):
            name = genes[sp][seqid][coords]
            gene_lists.append(name)
            position = len(gene_lists) - 1
            if args.gene_id_key in name and name not in excluded_ids:
                gap = 0
                leaf = _tree_leaf(name)
                if not in_array:
                    in_array = True
                    arrays[sp].append([name])
                    nodes_in_arrays[sp].append([leaf] if leaf is not None else [])
                    array_ends[sp].append([position])
                else:
                    arrays[sp][-1].append(name)
                    if leaf is not None:
                        nodes_in_arrays[sp][-1].append(leaf)
            elif in_array:
                gap += 1
                if gap > 1:
                    # The last member sits just before the two non-family genes.
                    in_array = False
                    array_ends[sp][-1].append(position - 2)
        if in_array:
            # The seqid ended inside an array. gap is 0 or 1 here; subtracting
            # it keeps a trailing non-family gene out of the array span.
            array_ends[sp][-1].append(len(gene_lists) - 1 - gap)
        gene_lists.append('chromeBreak')
    gene_lists.append('spBreak')

count = 0  # number of arrays with synteny to at least one array in another species

# Each entry: [max_age, "sp\tarray_no\tclade_age\tlower_flank_hits\tupper_flank_hits", synteny_age]
array_list = []


def _flank_windows(start, stop):
    """Return (lower, upper): up to five gene_lists entries before index start
    and after index stop, truncated at seqid boundaries ('chromeBreak')."""
    lower = gene_lists[max(0, start - 5):start]  # clamp: a negative start would wrap around
    upper = gene_lists[stop + 1:stop + 6]
    if 'chromeBreak' in lower:
        # keep only entries after the LAST boundary, i.e. on the array's own seqid
        cut = len(lower) - lower[::-1].index('chromeBreak')
        lower = lower[cut:]
    if 'chromeBreak' in upper:
        upper = upper[:upper.index('chromeBreak')]
    return lower, upper


def _orthologs_of(window):
    """Set of all orthodict partners of the genes in window."""
    found = set()
    for gene in window:
        found.update(orthodict.get(gene, ()))
    return found


def _shared_count(window, lower_orthos, upper_orthos):
    """Largest number of window genes shared with either ortholog set (both are
    checked because the other array may be in either orientation)."""
    window_set = set(window)
    return max(len(window_set & lower_orthos), len(window_set & upper_orthos))


def _age_from_species(text):
    """Age class from which splist species codes occur as substrings of text
    (splist is ordered oldest-diverging first):
    3 = splist[0] together with any of splist[1..3];
    2 = splist[1] together with splist[2] or splist[3];
    1 = splist[2] together with splist[3];
    0 = otherwise (species-specific)."""
    has = [code in text for code in splist[:4]]
    if has[0]:
        return 3 if (has[1] or has[2] or has[3]) else 0
    if has[1]:
        return 2 if (has[2] or has[3]) else 0
    return 1 if (has[2] and has[3]) else 0


for sp in arrays:
    for array_index in range(len(arrays[sp])):
        lower5, upper5 = _flank_windows(array_ends[sp][array_index][0],
                                        array_ends[sp][array_index][1])
        lower_orthos = _orthologs_of(lower5)
        upper_orthos = _orthologs_of(upper5)
        lflank = []
        uflank = []
        for sp2 in arrays:
            if sp2 == sp:
                continue
            for other_index, other_ends in enumerate(array_ends[sp2]):
                label = sp2 + str(other_index + 1)
                lower5_2, upper5_2 = _flank_windows(other_ends[0], other_ends[1])
                l_hits = _shared_count(lower5_2, lower_orthos, upper_orthos)
                u_hits = _shared_count(upper5_2, lower_orthos, upper_orthos)
                # >2 shared genes is a confident match; 1-2 is only suggestive.
                lfind = l_hits > 2
                ufind = u_hits > 2
                lmaybe = 0 < l_hits <= 2
                umaybe = 0 < u_hits <= 2
                if lfind:
                    lflank.append(label)
                if ufind:
                    uflank.append(label)
                # A confident match on one flank lets a suggestive match on the
                # other flank count; two suggestive matches also count.
                if ufind and not lfind:
                    if lmaybe:
                        lflank.append(label)
                elif lfind and not ufind:
                    if umaybe:
                        uflank.append(label)
                elif umaybe and lmaybe:
                    lflank.append(label)
                    uflank.append(label)
        if lflank or uflank:
            count = count + 1

        # Phylostratigraphic age from the smallest tree clade holding the members.
        members_in_tree = nodes_in_arrays[sp][array_index]
        if len(arrays[sp][array_index]) == 1:
            clade_age = "singleton"
            max_age = 0
        elif not members_in_tree:
            # No member is in the tree, so there is no clade to date. (Previously
            # a stale clade from an earlier array was reused here.)
            clade_age = "PseudosOnly"
            max_age = 0
        else:
            genes_in_clade = ",".join(my_tree.get_common_ancestor(members_in_tree).get_leaf_names())
            max_age = _age_from_species(genes_in_clade)
            # Label matches the numeric age (the Sp3+Sp4 case used to say "3" for age 1).
            clade_age = str(max_age) if max_age else "species"

        # Synteny age from the species of all flank-matched arrays plus this one.
        synteny_string = ",".join(lflank) + "," + ",".join(uflank) + "," + sp
        synteny_age = _age_from_species(synteny_string)
        max_age = max(max_age, synteny_age)
        array_list.append([max_age, sp + '\t' + str(array_index + 1) + '\t' + clade_age + '\t' + ','.join(lflank) + '\t' + ','.join(uflank), synteny_age])

#transitive ager
# Partition arrays into groups linked (directly or through a chain of flank
# matches) and record, per group, the oldest all-evidence and synteny ages.
# Each entry of ortho_array_list is [set_of_array_names, max_age, syn_age];
# groups are kept pairwise disjoint.
ortho_array_list = []
for array in array_list:
    array_fields = array[1].split('\t')
    max_age = array[0]
    syn_age = array[2]
    arrays_in_group = set([array_fields[0] + array_fields[1]])
    for flank_field in (array_fields[3], array_fields[4]):
        if flank_field != "":
            arrays_in_group.update(flank_field.split(","))
    # Absorb every existing group that shares a member. The old code grew each
    # overlapping group separately, so an array bridging two groups never
    # merged them and ages did not propagate across the bridge.
    remaining = []
    for ortho_array in ortho_array_list:
        if ortho_array[0] & arrays_in_group:
            arrays_in_group |= ortho_array[0]
            max_age = max(max_age, ortho_array[1])
            syn_age = max(syn_age, ortho_array[2])
        else:
            remaining.append(ortho_array)
    remaining.append([arrays_in_group, max_age, syn_age])
    ortho_array_list = remaining

# Map each array name (species code + array number) to its group's
# [max_age, syn_age], then raise every array's ages to its group's values.
age_dict = {}
for group_members, group_max_age, group_syn_age in ortho_array_list:
    for member in group_members:
        age_dict[member] = [group_max_age, group_syn_age]

for entry in array_list:
    array_name = "".join(entry[1].split('\t')[:2])
    group_max_age, group_syn_age = age_dict[array_name]
    # entry is the list object stored in array_list, so it is updated in place.
    if group_max_age > entry[0]:
        entry[0] = group_max_age
    if group_syn_age > entry[2]:
        entry[2] = group_syn_age


    

##Old, less efficient transitive ager
# num_updates = 1
# while num_updates > 0:
#     num_updates = 0
#     array_age_dict = {}
#     for array_index in range(len(array_list)):
#         array = array_list[array_index]
#         array_fields = array[1].split('\t')
#         array_name = "".join(array_fields[:2])
#         max_age = array[0]
#         syn_age = array[2]
#         flank_list = array_fields[3].split(',') + array_fields[4].split(',')
#         for flanking_array in flank_list:
#             if flanking_array in array_age_dict:
#                 if array_age_dict[flanking_array][0] > max_age:
#                     max_age = array_age_dict[flanking_array][0]
#                     array_list[array_index][0] = max_age
#                     num_updates += 1
#                 # if array_age_dict[flanking_array][1] > syn_age:
#                 #     syn_age = array_age_dict[flanking_array][1]
#                 #     array_list[array_index][2] = syn_age
#                 #     num_updates += 1
#         array_age_dict[array_name] = [max_age,syn_age]
#     for array_index in range(len(array_list)):
#         array = array_list[array_index]
#         array_fields = array[1].split('\t')
#         array_name = "".join(array_fields[:2])
#         max_age = array[0]
#         flank_list = array_fields[3].split(',') + array_fields[4].split(',')
#         for flanking_array in flank_list:
#             if flanking_array in array_age_dict:
#                 if array_age_dict[flanking_array][0] > max_age:
#                     max_age = array_age_dict[flanking_array][0]
#                     array_list[array_index][0] = max_age
#                     num_updates += 1
#                 # if array_age_dict[flanking_array][1] > syn_age:
#                 #     syn_age = array_age_dict[flanking_array][1]
#                 #     array_list[array_index][2] = syn_age
#                 #     num_updates += 1

# Section 1: one tab-separated row per array with its evidence and ages.
# (print(x) with a single argument behaves the same on Python 2 and 3.)
print("#Array ages and synteny information")
print("Species\tArray_number\tpylostratigraphy_age\tarrays_with_synteny_to_upstream_flank\tarrays_with_synteny_to_downstream_flank"
      "\tall_evidence_age\tsynteny_evidence_age")
for array in array_list:
    print(array[1] + '\t' + str(array[0]) + '\t' + str(array[2]))

# Section 2: gene -> 1-based array number within its species. enumerate
# replaces list.index(), which was O(n) per array and returned the first
# equal array when two arrays had identical gene lists.
print("#Gene array assignments")
for sp in arrays:
    for array_number, array in enumerate(arrays[sp], 1):
        for gene in array:
            print(gene + '\t' + str(array_number))
        
    
        

    

 
    


    


    


