# -*- coding: utf-8 -*-
from cmath import nan
import os
import random
from signal import pthread_kill
import sys

import numpy as np
import pandas as pd
from ete3 import Tree, TreeStyle, TextFace,  CircleFace, NodeStyle,BarChartFace,RectFace

# Date stamp appended to the output directory and file names.
date="20220902"

# Per-species attribute table (tab-separated). Columns read below include
# Species, leaf_name, Fabre_leaf_name, Family, Order, total, the per-dataset
# flag columns and Precipitation_of_Driest_Quarter.
attributes_table = "data/metadata/attributes_Seq_dataonline_20220630.tsv"
#env_table = "../Pluvio_phylogeography/df_taxonomy_and_env_dataset_species_with_flagged_suspicious_coor.tsv"
#env_table = "data/df_taxonomy_and_env_dataset_species_with_flagged_suspicious_coor.tsv"
# Family -> colour lookup (comma-separated; columns "family" and "fam_col").
fam_color = "data/col_taxonomy.csv"
# Sample metadata (tab-separated); read inside the main loop.
coldata_expr = "data/metadata/coldata_localandonline_info20220630.tsv"


df_attribut_tsv = pd.read_csv(attributes_table, sep="\t")
df_col_fam = pd.read_csv(fam_color, sep=",")


# Branch colours: ancestral branches, branches inside clades where every
# species has bio17 <= 40 mm (Condition == "1"), and transition branches
# (Transition == "1").
col_anc="#A2A2A2"   # alternatives: "black"; "#A2A2A2" (thesis grey)
col_conv= "#1A1A1A" # alternatives: "orange"; "#058C45" (thesis green)
col_trans="#1A1A1A"  # alternative: "red"
width_branch = 6
size_leaves = 20


# Draw the bio17 (Precipitation of Driest Quarter) box next to each leaf.
bio17_yes = True

# Fill colours of the bio17 classes: <= 40 mm, 40-1500 mm, > 1500 mm.
# Greyscale variant: set all three to "#DCDCDC".
col_grad1 =  "#FFA500" # "#F6F2C8"
col_grad2 =  "#5EA336" # "#D3FFBA"
col_grad3 =  "#B4D7FF"

legend = True
add_fam_col = True

tree_file_name="data/tree/local_online_species_tree_rep_8_20200605.raxml.bestTree"

# NOTE(review): not referenced anywhere in the visible code.
discarded_species = []

for dataset_type in ["sequences","expression"]: # ["all","expression"]:
    dataset_x = "localonline"
    for dataset_name in ["total", "murinae", "recent_trans", "ancient_trans"]: # " "nov2020_restricted_old-4sp","nov2020_restricted_recent-4sp-spermo","nov2020_restricted_recent-4sp-macedo"]:
        # One figure is produced per (dataset_type, dataset_name) pair.
        path_figure=date+"_figures_output_"+dataset_x+"_species/"
        print("tree_file_name")
        print(tree_file_name)

        print("dataset_x")
        print(dataset_x)

        # Read the sample metadata once and derive both filtered views from it.
        # Species_ok is the sample ID with its "<a>_<b>_<code>_" prefix and a
        # trailing "2" removed. .copy() makes each view an independent frame, so
        # later column additions do not hit pandas' chained-assignment warning.
        df_coldata_all = pd.read_csv(coldata_expr, sep="\t")
        df_coldata_all["Species_ok"] = (
            df_coldata_all["ID"]
            .str.replace('[a-z]*_[a-z]*_[0-9A-Z]*_', '', n=1,regex=True)
            .str.replace('2$', '', n=1,regex=True)
        )
        # Samples used for this dataset's expression analysis.
        df_coldata_expr = df_coldata_all[df_coldata_all["exp_"+dataset_name] == "yes"].copy()
        # All usable samples (used for the per-species sample counts).
        df_coldata_expr_total = df_coldata_all[df_coldata_all["to_be_used"] == "yes"].copy()


        t=Tree(tree_file_name)

        # Prune the tree for dataset X.
        # Candidate species: flagged "yes" or "only_s" in the "total" column.
        df_attribut_tsv_dataset_x = df_attribut_tsv.loc[df_attribut_tsv["total"].isin(["yes", "only_s"])]

        sp_leaves = [n.name for n in t.get_leaves()]

        print("sp_leaves")
        print(sp_leaves)

        sp_att = [x.replace(" ","_") for x in df_attribut_tsv_dataset_x["Species"].values]

        # Remember the original tip labels; they are restored before the raw
        # Newick is written at the end of the iteration.
        for leaf in t.get_leaves():
            leaf.raw_name = leaf.name

        # If the tip labels do not match the table's species names, relabel every
        # tip found either by its "leaf_name" or by its underscore species name.
        if set(sp_leaves) != set(sp_att):
            for rows in df_attribut_tsv.to_dict('records'):
                sp_t_row   = rows["leaf_name"]
                sp_att_row = rows["Species"].replace(" ","_")
                try:
                    node_row = t.search_nodes(name=sp_t_row)[0]
                except IndexError:
                    # No tip carries leaf_name: fall back to the species name. A second
                    # miss propagates as IndexError.
                    node_row = t.search_nodes(name=sp_att_row)[0]

                node_row.name = sp_att_row

        # Keep the species of dataset_name ("yes" or "only_s") present in the tree.
        sp_restricted = list(df_attribut_tsv_dataset_x.loc[df_attribut_tsv_dataset_x[dataset_name].isin(["yes","only_s"]), "Species"])
        sp_leaves2 = {n.name for n in t.get_leaves()}
        sp_sp_restricted2 = [sp for sp in sp_restricted if sp.replace(" ","_") in sp_leaves2]
        print("sp_restricted")
        t.prune([n.replace(" ","_") for n in sp_sp_restricted2], preserve_branch_length = True)

        # Species flagged "yes" (expression data) that survived the pruning above.
        sp_expr = list(df_attribut_tsv_dataset_x.loc[df_attribut_tsv_dataset_x[dataset_name] == "yes", "Species"])
        sp_leaves2 = {n.name for n in t.get_leaves()}
        sp_expr2 = [sp for sp in sp_expr if sp.replace(" ","_") in sp_leaves2]

        # Expression figures keep only the species with expression data.
        if dataset_type  == "expression":
            t.prune([n.replace(" ","_") for n in sp_expr2], preserve_branch_length = True)

        # Leaf names (underscore form) of the final tree; used by the legend.
        all_sp = [n.name for n in t.get_leaves()]

        def get_FabreName(sp, replace=False):
            """Return the Fabre_leaf_name recorded for species *sp*.

            With replace=True, underscores in *sp* become spaces before the lookup,
            so tree-style labels ("Genus_species") can be passed directly.
            Raises IndexError when the species is not in the attribute table.
            """
            key = sp.replace("_", " ") if replace else sp
            matches = df_attribut_tsv.loc[df_attribut_tsv["Species"] == key, "Fabre_leaf_name"]
            return matches.values[0]

        def get_bio17(sp):
            """Return the Precipitation_of_Driest_Quarter value of species *sp* as text.

            Returns "0" for an empty name, "NA" when the value is missing, and
            otherwise the value truncated to an integer string.
            Raises IndexError when *sp* is not in the attribute table.
            """
            if sp == "":
                return "0"
            val = df_attribut_tsv["Precipitation_of_Driest_Quarter"][df_attribut_tsv.Species == sp].values[0]
            # pd.isna also accepts None/object-dtype missing values; np.isnan raised TypeError on them.
            if pd.isna(val):
                return "NA"
            return str(int(val))


        def get_analysis_node(sp):
            """Classify species *sp* by available analyses.

            Returns "CE" when the species is flagged "yes" in the "total" column,
            "C" when it is in the table with any other flag, and "" when *sp* is
            empty or absent. Underscores in *sp* are read as spaces.
            """
            if sp == "":
                return ""
            rows = df_attribut_tsv.loc[df_attribut_tsv["Species"] == sp.replace("_", " ")]
            if rows.empty:
                return ""
            has_expression = rows["total"].values[0] == "yes"
            return "CE" if has_expression else "C"
        
        def get_nb_samples_expr(sp):
            """Count the usable samples of species *sp* (matched on Species_ok).

            Returns [n_local, n_other]: samples whose Bioproject is "local", and
            all remaining samples of that species.
            """
            bioprojects = df_coldata_expr_total.loc[df_coldata_expr_total["Species_ok"] == sp, "Bioproject"]
            n_local = int((bioprojects == "local").sum())
            return [n_local, len(bioprojects) - n_local]
        
        def get_val_from_coldata_expr(sp,column):
            """Return the distinct values of column dataset_name + *column* for species *sp*.

            The values (rows where Species_ok == *sp*) are concatenated without a
            separator. Missing values become "", numbers are written as integers,
            and anything else is converted with str(). Returns "" when the column
            does not exist or the species has no rows.
            """
            col_name = dataset_name + column
            if col_name not in df_coldata_expr.columns:
                return ""
            values = df_coldata_expr.loc[df_coldata_expr["Species_ok"] == sp, col_name].unique()

            def _as_text(x):
                # Convert each value on its own: an int-dtype column, or strings mixed
                # with NaN, used to make "".join raise TypeError.
                if pd.isna(x):
                    return ""
                if isinstance(x, (int, float, np.integer, np.floating)) and not isinstance(x, bool):
                    return str(int(x))
                return str(x)

            return "".join(_as_text(x) for x in values)
        
        
        def is_a_sp_in_a_dataset(genus_name, dataset_to_check):
            """Return a one-letter presence code for species *genus_name* in a sub-dataset.

            Reads column *dataset_to_check* of the filtered attribute table:
            "yes" -> "Y", "only_s" -> "y", anything else -> "N".
            Raises IndexError when the species is not in the table.
            """
            flag = df_attribut_tsv_dataset_x.loc[df_attribut_tsv_dataset_x["Species"] == genus_name, dataset_to_check].values[0]
            return {"yes": "Y", "only_s": "y"}.get(flag, "N")


        def rect_bio17_node(pr, value=False):
            """Return an ete3 RectFace labelled with a bio17 (driest-quarter precipitation) value.

            pr    -- label text: a number such as "12", "NA", or a legend label such
                     as "≤ 40mm" ("mm" and comparison signs are stripped before parsing).
            value -- optional number that picks the colour instead of parsing *pr*.
                     It is tested against False/None rather than by truthiness, so
                     value=0 is honoured.

            Fill colour: white for "NA" (or negative values), col_grad1 for <= 40,
            col_grad2 for (40, 1500], col_grad3 above 1500.
            """
            if value is not False and value is not None:
                int_pr = value
            else:
                pr_str_int = pr
                for token in ("mm", "<", ">", "≤", "="):
                    pr_str_int = pr_str_int.replace(token, "")
                int_pr = -1 if pr_str_int == "NA" else int(pr_str_int)

            if int_pr < 0:
                col = "#FFFFFF"
            elif int_pr <= 40:
                col = col_grad1
            elif int_pr <= 1500:
                col = col_grad2
            else:
                col = col_grad3

            if "mm" not in pr:
                # Left-pad short numeric labels to 3 characters so the boxes line up.
                pr = " " * max(0, 3 - len(pr)) + pr

            label = {"text" : pr, "fontsize" : 12, "color": "black" }
            # Box width grows with the label length.
            PrecRect_style = RectFace(6+12*len(pr),25,"black",col,label)
            PrecRect_style.margin_bottom = 5
            PrecRect_style.margin_top = 0
            PrecRect_style.margin_right = 5
            PrecRect_style.margin_left = 5
            PrecRect_style.background.color = "white"
            return PrecRect_style

        # Type of analyses possible rectangle definition
        def rect_analysis_node(value, text = ""):
            """Return an ete3 RectFace naming an analysis type.

            When *text* is non-empty it is the label and *value* only selects the
            fill. Otherwise the label is *value* when it is "CDS" or "Exp", and
            empty for anything else. Only "Exp" gets the grey fill; all others are white.
            """
            if text != "":
                an_str, an = value, text
            elif value in ("CDS", "Exp"):
                an_str = an = value
            else:
                an_str = an = ""

            col = "#D4D4D4" if an_str == 'Exp' else "White"

            label = {"text" : an, "fontsize" : 12, "color": "black" }
            # Box width grows with the label length.
            AnaRect_style = RectFace(8*len(an),25,"black",col,label)
            AnaRect_style.margin_bottom = 5
            AnaRect_style.margin_top = 5
            AnaRect_style.margin_right = 5
            AnaRect_style.margin_left = 5
            return AnaRect_style

        def bns_style(col="black", bg="", width= 5):
            """Return a NodeStyle for an internal node: size-0 marker, branches drawn in *col*.

            *bg*, when truthy, becomes the node background colour; *width* sets the
            horizontal and vertical line thickness.
            """
            internal_style = NodeStyle()
            internal_style["size"] = 0
            internal_style["fgcolor"] = col
            if bg:
                internal_style["bgcolor"] = bg
            line_settings = (
                ("hz_line_width", width),
                ("vt_line_width", width),
                ("vt_line_color", col),
                ("hz_line_color", col),
            )
            for key, setting in line_settings:
                internal_style[key] = setting
            return internal_style

        def get_lns(sp, bg, col, fgcol,width,size):
            """Return a NodeStyle for a leaf: a circle marker of size *size* filled with *fgcol*.

            *bg*, when truthy, becomes the node background colour; branch lines use
            colour *col* and thickness *width*. *sp* is accepted but not used.
            """
            leaf_style = NodeStyle()
            leaf_style["size"] = size
            leaf_style["shape"] = "circle"
            leaf_style["fgcolor"] = fgcol
            if bg:
                leaf_style["bgcolor"] = bg
            line_settings = (
                ("hz_line_width", width),
                ("vt_line_width", width),
                ("vt_line_color", col),
                ("hz_line_color", col),
            )
            for key, setting in line_settings:
                leaf_style[key] = setting
            return leaf_style


        def get_col_fam(n):
            """Return the family colour for node *n*, or "" when it is not monophyletic.

            All named leaves below *n* are mapped to their Family (underscores read
            as spaces). If exactly one family is found, its "fam_col" from the
            colour table is returned. Otherwise "" is returned, which is falsy like
            the `col_fam = ""` fallback and the `bg=""` style defaults.
            Raises IndexError if a species or family is missing from the tables.
            """
            families = {
                df_attribut_tsv["Family"][df_attribut_tsv.Species == leaf.name.replace("_", " ")].values[0]
                for leaf in n.get_leaves()
                if leaf.name
            }
            if len(families) == 1:
                (family,) = families
                return df_col_fam.fam_col[df_col_fam["family"] == family].values[0]
            return ""
        
        #automated annotation
        ## add bio17 value in tag
        # Each leaf gets bio17 (text from get_bio17) and Condition: "1" when bio17 <= 40 mm,
        # else "0". An internal node gets Condition "1" only when all its children have
        # it. When only some children do, those children are marked Transition = "1".
        # Post-order visits children before their parent, so one pass reaches the final
        # state. Iterating over all children (instead of unpacking exactly two) also
        # handles non-binary nodes such as a multifurcating root.
        print("add tag")
        for node in t.traverse(strategy="postorder"):
            if node.is_leaf():
                node.bio17 = get_bio17(node.name.replace("_", " "))
                # Missing values ("NA") count as not dry instead of crashing in int().
                is_dry = node.bio17 != "NA" and int(node.bio17) <= 40
                node.Condition = "1" if is_dry else "0"
            else:
                children = node.get_children()
                dry_children = [child for child in children if child.Condition == "1"]
                if len(dry_children) == len(children):
                    node.Condition = "1"
                else:
                    node.Condition = "0"
                    for child in dry_children:
                        child.Transition = "1"

        # Basic tree style
        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.show_scale = True
        ts.draw_guiding_lines = True
        ts.scale = 4000

        # Branch colour: col_trans on Transition branches, else col_conv inside
        # Condition == "1" clades, else col_anc. The node background is the family
        # colour when add_fam_col is set. Leaf circles are coloured by bio17 class.
        for n in t.traverse():
            col_fam = get_col_fam(n) if add_fam_col else ""
            col = col_anc
            if getattr(n, "Condition", None) == "1":
                col = col_conv
            if getattr(n, "Transition", None) == "1":
                col = col_trans
            if n.is_leaf():
                bio17 = get_bio17(n.name.replace("_", " "))
                if bio17 == "NA":
                    # Missing value: white, matching the "NA" box colour; int("NA") would crash.
                    fgcol = "#FFFFFF"
                elif int(bio17) <= 40:
                    fgcol = col_grad1
                else:
                    fgcol = col_grad2
                n.set_style(get_lns(n.name, bg = col_fam, col = col, fgcol=fgcol, width = width_branch, size = size_leaves))
            else:
                n.set_style(bns_style(col=col, bg = col_fam,width = width_branch))


        # Per-leaf aligned columns, left to right:
        #   spacer, bio17 box (if bio17_yes), spacer, sample-count circle,
        #   species name, spacer, then either
        #   - the permutation-group and DESeq2-family labels (expression figures), or
        #   - one presence circle per sub-dataset (total/sequences figure only).
        # Leaf names are switched from "Genus_species" to "Genus species" here.
        for node in t.traverse():

            if node.is_leaf():
                nb_col = 0
                
                # Underscore form of the leaf label, kept for the sample-metadata lookups
                # (presumably Species_ok uses underscores too - confirm).
                # NOTE(review): despite the name, this holds the whole species label, not the genus.
                genus_name = node.name
                node.name = node.name.replace("_", " ")
                node.add_face(TextFace(" ", fsize=8), column=nb_col, position = "aligned")
                
                if bio17_yes :
                    bio17 = get_bio17(node.name)
                    nb_col +=1
                    node.add_face(rect_bio17_node(bio17), column=nb_col, position = "aligned")

                # "CE" if flagged "yes" in "total", "C" if present otherwise, "" if absent.
                analysis_node = get_analysis_node(node.name)
                nb_col +=1

                #############################
                # Spacer column.

                empty_face = TextFace("  ", fsize=12, fstyle='italic')
                empty_face.background.color = "white"
                node.add_face(empty_face, column=nb_col, position = "aligned")
                nb_col +=1
                
                #############################
                # Sample-count column: [number of "local" samples, number of other samples].

                nb_inds_exp = get_nb_samples_expr(genus_name)

                empty_face = TextFace(" ", fsize=6, fstyle='italic')

                # Colour constants, redefined for every leaf. The legend code after this
                # loop also reads them.
                # NOTE(review): they are undefined if the tree has no leaves.
                col_bg_C = "#eaeaea"
                col_bg_CE = "#FFFFFF"

                col_Y_CE = "#cb997e"
                col_Y_C  = "#d5bdaf"
                col_N_CE = "#b7b7a4"
                col_N_C  = "#cccccc"

                col_cercle_New = "#777777"
                col_cercle_Pub = "#c3c9cf"


                empty_face.background.color = col_bg_C

                # NOTE(review): "or True" makes this branch unconditional, so the
                # `elif analysis_node == "C"` branch below is unreachable.
                if analysis_node == "CE" or True :
                    
                    label_N = {"text" : str(nb_inds_exp[0]), "fontsize" : 11, "color": "black" }
                    face_N = CircleFace(12,col_cercle_New,label=label_N)
                    label_P = {"text" : str(nb_inds_exp[1]), "fontsize" : 11, "color": "black" }
                    face_P = CircleFace(12,col_cercle_Pub,label=label_P)

                    face_N.background.color = col_bg_CE
                    face_P.background.color = col_bg_CE
                    empty_face.background.color = col_bg_C


                    # At most one circle per leaf: the local count if non-zero, else the
                    # other count if non-zero, else an empty placeholder.
                    if nb_inds_exp[0] > 0 :
                        node.add_face(face_N, column=nb_col, position = "aligned")
                    #else:
                    #    node.add_face(empty_face, column=nb_col, position = "aligned")
                    elif nb_inds_exp[1] > 0 :
                        node.add_face(face_P, column=nb_col, position = "aligned")
                    else:
                        node.add_face(empty_face, column=nb_col, position = "aligned")

                elif analysis_node == "C":
                    node.add_face(empty_face, column=nb_col, position = "aligned")
                nb_col +=1

                #################################
                # Species name column.

                name_face = TextFace("  " + node.name, fsize=14, fstyle='italic', bold = True)
                name_face.background.color = "white"
                node.add_face(name_face,  column=nb_col, position = "aligned")
                nb_col +=1

                ##################################
                # Spacer column.

                empty_face = TextFace("  ", fsize=12, fstyle='italic')
                empty_face.background.color = "white"
                node.add_face(empty_face, column=nb_col, position = "aligned")
                nb_col +=1
                
                # Expression figures: labels from the "<dataset_name>_permutation" and
                # "<dataset_name>_familydeseq2" columns ("" when the column is absent).
                if dataset_type == "expression":
                    group_name = get_val_from_coldata_expr(genus_name,"_permutation")
                    print(group_name)
                    group_face = TextFace("  " + group_name, fsize=14)
                    group_face.background.color = "white"
                    node.add_face(group_face,  column=nb_col, position = "aligned")
                    nb_col +=1

                    deseq2fam_name = get_val_from_coldata_expr(genus_name,"_familydeseq2")
                    fam_face = TextFace("  " + deseq2fam_name, fsize=14)
                    fam_face.background.color = "white"
                    node.add_face(fam_face,  column=nb_col, position = "aligned")
                    nb_col +=1
                
                # Total/sequences figure only: one circle per sub-dataset labelled
                # "Y" ("yes"), "y" ("only_s") or "N" (anything else).
                if dataset_name == "total" and dataset_type == "sequences":
                    for dataset_to_check in ["total", "murinae", "recent_trans", "ancient_trans"]:
                        is_in = is_a_sp_in_a_dataset(node.name, dataset_to_check)

                        
                        label = {"text" : is_in, "fontsize" : 12, "color": "black" }

                        print(is_in)
                        # Circle fill depends on the presence code and on whether the
                        # species is "CE" (then on a white background, else grey).
                        if analysis_node == "CE":
                            if is_in == "Y":
                                col_dataset = col_Y_CE
                            elif is_in == "y":
                                col_dataset = col_Y_C
                            else :
                                col_dataset = col_N_CE
                        else:
                            if is_in == "y":
                                col_dataset = col_Y_C
                            else:
                                col_dataset = col_N_C
                        
                        DatasetCircle_face = CircleFace(12,col_dataset,label=label)
                        if analysis_node == "CE":
                            DatasetCircle_face.background.color = col_bg_CE
                        else:
                            DatasetCircle_face.background.color = col_bg_C
                        
                        
                        node.add_face(DatasetCircle_face, column=nb_col, position = "aligned")
                        nb_col +=1

        

        # Legend laid out in columns that mirror the leaf annotations: family colour
        # swatches, bio17 classes, sample-count circles, species name and, for the
        # total/sequences figure, the sub-dataset presence circles.
        # NOTE(review): reads col_cercle_New, col_cercle_Pub, col_bg_CE, col_Y_CE,
        # col_Y_C and col_N_CE, which are only defined inside the leaf loop above.
        if legend :
            ts.legend_position = 1
            
            nb_col_l = 0
            ts.legend.add_face(TextFace("Legend:  ", fsize=12), column=nb_col_l)

            nb_col_l+=1
            ts.legend.add_face(TextFace("  ", fsize=8), column=nb_col_l)

            # One swatch per family present in the final tree, grouped by order.
            if add_fam_col and True:
                nb_col_l+=1
                ts.legend.add_face(TextFace("  ", fsize=12), column=nb_col_l)
                nb_col_l += 1
                #ts.legend.add_face(RectFace(10,10,"black",fam_col), column=nb_col_l)
                ts.legend.add_face(TextFace("Family:", fsize=12), column=nb_col_l)
                ts.legend.add_face(TextFace("  ", fsize=8), column=nb_col_l + 1)
                order_d = {}
                for sp in all_sp:
                    sp_ok = sp.replace("_"," ")
                    fam = df_attribut_tsv["Family"][df_attribut_tsv.Species == sp_ok].values[0]
                    order = df_attribut_tsv["Order"][df_attribut_tsv.Species == sp_ok].values[0]
                    if order in order_d:
                        order_d[order].append(fam)
                    else:
                        order_d[order]=[fam]
                for order in order_d:
                    # set() removes duplicates; the swatch order within an order is therefore arbitrary.
                    set_fam = list(set(order_d[order]))
                    for fam in set_fam:
                        fam_col = df_col_fam.fam_col[df_col_fam["family"] == fam].values[0]
                        ts.legend.add_face(RectFace(20,20,"black",fam_col), column=nb_col_l)
                        ts.legend.add_face(TextFace(" "+fam, fsize=10), column=nb_col_l + 1)
                nb_col_l+=1

            # NOTE(review): only two bio17 classes are shown, but rect_bio17_node also
            # uses col_grad3 for values above 1500 mm, which has no legend entry.
            if bio17_yes:
                nb_col_l+=1
                ts.legend.add_face(TextFace(" Precipitation of Driest Quarter"+":", fsize=12), column=nb_col_l)
                ts.legend.add_face(rect_bio17_node("≤ 40mm", value=10), column=nb_col_l)
                ts.legend.add_face(rect_bio17_node("> 40mm", value=51), column=nb_col_l)


            ts.legend.add_face(TextFace("   ", fsize=8), column=nb_col_l)
            


            nb_col_l += 1
            ts.legend.add_face(TextFace("      ", fsize=8), column=nb_col_l)
            

            
            # Sample-count circles: "New" (local samples) and "Pub" (other samples).
            nb_col_l+=1
            #ts.legend.add_face(rect_analysis_node(text="CDS: Cds",value="C"), column=nb_col_l)
            #nb_col_l+=1
            ts.legend.add_face(rect_analysis_node(text="Expression data:",value="E"), column=nb_col_l)
            nb_col_l += 1

            label_N = {"text" : "New", "fontsize" : 10, "color": "black" }
            face_N = CircleFace(15,col_cercle_New,label=label_N)
            label_P = {"text" : "Pub", "fontsize" : 10, "color": "black" }
            face_P = CircleFace(15,col_cercle_Pub,label=label_P)

            face_N.background.color = col_bg_CE
            face_P.background.color = col_bg_CE


            ts.legend.add_face(face_N, column=nb_col_l)
            ts.legend.add_face(face_P, column=nb_col_l+1)
            nb_col_l += 2

            
            ts.legend.add_face(TextFace("Species name", fsize=12), column=nb_col_l)
            nb_col_l += 1

            ts.legend.add_face(TextFace("      ", fsize=8), column=nb_col_l)
            nb_col_l += 1

            # Presence-circle key and sub-dataset column order (total/sequences figure only).
            if dataset_name == "total" and dataset_type == "sequences":

                ts.legend.add_face(TextFace("     ", fsize=12), column=nb_col_l)
                nb_col_l += 1

                ts.legend.add_face(rect_analysis_node(text="Present in:",value="E"), column=nb_col_l)
                ts.legend.add_face(TextFace("Sub dataset", fsize=12), column=nb_col_l+1)


                for is_in in ["Y","y","N"]:
                    label = {"text" : is_in, "fontsize" : 12, "color": "black" }
                    if is_in == "Y":
                        col_dataset = col_Y_CE
                        text = "Expr+CDS"
                    elif is_in == "y":
                        col_dataset = col_Y_C
                        text = "only CDS"
                    else :
                        col_dataset = col_N_CE
                        text = "no"
                        
                    DatasetCircl_face = CircleFace(12,col_dataset,label=label)
                    DatasetCircl_face.background.color = "white"
                    ts.legend.add_face(DatasetCircl_face, column=nb_col_l)

                    ts.legend.add_face(TextFace(text, fsize=12), column=nb_col_l+1)

                nb_col_l += 2
                ts.legend.add_face(TextFace("     ", fsize=12), column=nb_col_l)
                nb_col_l += 1

                ts.legend.add_face(rect_analysis_node(text="Sub dataset order:",value="E"), column=nb_col_l) 
                for dataset_to_check in ["total", "murinae", "recent_trans", "ancient_trans"]:
                    ts.legend.add_face(TextFace(dataset_to_check, fsize=12), column=nb_col_l)
                
                nb_col_l += 1
            






            # NOTE(review): because of this indentation the title is only added when
            # legend is True - confirm that is intended.
            ts.title.add_face(TextFace("Dataset: "+ dataset_name + " (" + dataset_type + " sp)", fsize=14), column=0)

        # Render the figure as SVG, PDF and PNG, write the Newick with display names,
        # then restore the original tip labels and write the "raw" Newick.
        # Output sub-directories are created on demand; os.path.join also avoids
        # the double slash that path_figure's trailing "/" used to produce.
        out_dir = os.path.join(path_figure, dataset_type)
        file_stem = "tree_dataset_" + dataset_name + "_" + dataset_type
        for fmt in ("svg", "pdf", "png"):
            os.makedirs(os.path.join(out_dir, fmt), exist_ok=True)
            t.render(os.path.join(out_dir, fmt, file_stem + "_sp_" + date + "." + fmt), tree_style=ts)

        os.makedirs(os.path.join(out_dir, "nw"), exist_ok=True)
        t.write(outfile=os.path.join(out_dir, "nw", file_stem + "_sp.nw"), format=1)

        # raw_name was stored on every leaf before pruning/relabelling.
        for leaf in t.get_leaves():
            leaf.name = leaf.raw_name

        os.makedirs(os.path.join(out_dir, "raw_nw"), exist_ok=True)
        t.write(outfile=os.path.join(out_dir, "raw_nw", file_stem + "_raw_sp.nw"), format=1)
