#!/usr/bin/python3

from genericpath import exists
import shutil
import subprocess
import glob
import os
import performance_metrics as pm
from argparse import ArgumentParser

# Tools evaluated on per-read binning accuracy (get_fn_fp_tp / calculate_metrics).
# "Taxor" must stay first: get_fn_fp_tp stores its read total and compares the
# totals of the tools processed after it against that value.
tools = [
    "Taxor",
    "Centrifuge",
    "Ganon",
    "KMCP",
    "Kraken2",
    "MetaMaps",
]
# Tools whose taxonomic (profile-based) abundances are collected.
tax_abund_tools = ["Taxor", "Ganon", "KMCP", "Sourmash"]
# Tools whose sequence (read-based) abundances are collected.
seq_abund_tools = ["Taxor", "Centrifuge", "Kraken2", "MetaMaps"]
# Benchmark datasets. "refseq-abfv-sim100" is only used for the per-read
# metrics; the abundance functions skip it and use the remaining mock datasets.
datasets = [
    "refseq-abfv-sim100",
    "ZymoR103-groundTruth",
    "ZymoQ20-groundTruth",
    "HiFi_D6331-groundTruth",
    "Illumina_D6300-groundTruth",
]

def get_fn_fp_tp(results_dir, taxid_map_file, nodes_dmp_file):
    """Collect per-level TP/FP/FN read counts for every dataset/tool pair.

    For each dataset in ``datasets`` and each tool in ``tools`` (MetaMaps is
    skipped for Illumina_D6300), the tool's binning file under
    ``<results_dir>/<tool>/`` is scored with ``pm.confusion_matrix``. The
    ``unclassified`` count it returns is added to FN at every level.

    Taxor's TP + FN + FP total is stored as ``ds_metrics_mat[ds]["reads"]``.
    Because it is reassigned per level, the value kept is Taxor's total at the
    last level returned by ``pm.confusion_matrix``. Every other tool whose
    total at a level is below that count gets the difference added to FN.

    Args:
        results_dir: directory containing one sub-directory per tool.
        taxid_map_file: file passed to ``pm.confusion_matrix``.
        nodes_dmp_file: nodes.dmp file, parsed once with
            ``pm.parse_node_dump_file``.

    Returns:
        ``{dataset: {tool: {level: {"TP", "FP", "FN", ...}}, "reads": count}}``
    """
    # Binning-output suffix per tool; tools not listed here use ".binning".
    suffixes = {
        "Centrifuge": ".centrifuge.search.txt",
        "Ganon": ".search.all",
        "Kraken2": ".output.txt",
        "MetaMaps": ".map.txt.EM.reads2Taxon",
    }

    # The taxonomy does not depend on dataset or tool: parse it once instead
    # of once per (dataset, tool) pair.
    nodes = pm.parse_node_dump_file(nodes_dmp_file)

    # Taxor provides the reference read count, so it must be scored before
    # the tools that are compared against it (sorted() is stable, so the
    # remaining tools keep their order).
    ordered_tools = sorted(tools, key=lambda tool: tool != "Taxor")

    ds_metrics_mat = {}
    for ds in datasets:
        ds_metrics_mat[ds] = {}
        for t in ordered_tools:
            if ds == "Illumina_D6300-groundTruth" and t == "MetaMaps":
                continue
            binning_file = os.path.join(results_dir, t, ds + suffixes.get(t, ".binning"))

            level_mat, unclassified = pm.confusion_matrix(taxid_map_file, binning_file, nodes)

            ds_metrics_mat[ds][t] = {}
            for level, counts in level_mat.items():
                counts["FN"] += unclassified
                ds_metrics_mat[ds][t][level] = counts
                total = counts["TP"] + counts["FN"] + counts["FP"]
                if t == "Taxor":
                    ds_metrics_mat[ds]["reads"] = total
                elif total < ds_metrics_mat[ds]["reads"]:
                    # Reads missing from this tool's output count as missed.
                    counts["FN"] += ds_metrics_mat[ds]["reads"] - total

            print(ds + " : " + t + " ... finished")

    return ds_metrics_mat

def calculate_metrics(ds_mat):
    """Add Utilization, Precision, Recall, F1 and F0.5 scores to ``ds_mat`` in place.

    ``ds_mat`` has the layout built by ``get_fn_fp_tp``. Each dataset maps to
    a ``"reads"`` entry (the reference read count) plus one entry per tool.
    Each tool entry maps the levels species, genus, family, order and class
    to dicts holding ``"TP"``, ``"FP"`` and ``"FN"``.

    Utilization at a level is (TP + FP) / reads minus the accumulated
    utilization of the levels processed before it. The accumulation
    telescopes, so this equals the change in classified fraction relative to
    the previous level.

    A ratio whose denominator is zero (no reads, no positive calls, P + R == 0
    when a tool has no true positives) is reported as 0.0 instead of raising
    ZeroDivisionError.

    Returns:
        The same ``ds_mat`` object.
    """
    def ratio(numerator, denominator):
        # 0/0 happens e.g. when a tool makes no call at a level; report 0.0
        # instead of aborting the whole evaluation.
        return float(numerator) / float(denominator) if denominator else 0.0

    beta_sq = pow(0.5, 2)
    for ds_entry in ds_mat.values():
        for t, levels in ds_entry.items():
            # "reads" holds the per-dataset read count, not a tool's results.
            if t == "reads":
                continue
            util_cum = 0.0
            for level in ("species", "genus", "family", "order", "class"):
                m = levels[level]
                tp, fp, fn = m["TP"], m["FP"], m["FN"]
                m["Utilization"] = ratio(tp + fp, ds_entry["reads"]) - util_cum
                util_cum += m["Utilization"]
                precision = ratio(tp, tp + fp)
                recall = ratio(tp, tp + fn)
                m["Precision"] = precision
                m["Recall"] = recall
                m["F1-Score"] = ratio(2 * precision * recall, precision + recall)
                m["F0.5-Score"] = ratio((1.0 + beta_sq) * precision * recall,
                                        (beta_sq * precision) + recall)
    return ds_mat

def write_output_tsv(ds_mat, output_file):
    """Write the per-level metrics in long format, one metric per row.

    Columns: Dataset, Tool, Level, Metric, Value. Tools are visited in the
    order of the module-level ``tools`` list; MetaMaps is skipped for
    Illumina_D6300-groundTruth.
    """
    metric_names = ("Utilization", "Precision", "Recall", "F1-Score", "F0.5-Score")
    with open(output_file, 'wt') as handle:
        handle.write("Dataset\tTool\tLevel\tMetric\tValue\n")
        for dataset in ds_mat:
            for tool in tools:
                if tool == "MetaMaps" and dataset == "Illumina_D6300-groundTruth":
                    continue
                for level, values in ds_mat[dataset][tool].items():
                    for name in metric_names:
                        row = (dataset, tool, level, name, str(values[name]))
                        handle.write("\t".join(row) + "\n")

def write_output_excel_csv(ds_mat, output_file):
    """Write the per-level metrics in wide format, one row per dataset/tool/level.

    Despite the name the file is tab-separated. Columns: Dataset, Tool, Level,
    Utilization, Precision, Recall, F1-Score, F0.5-Score. MetaMaps is skipped
    for Illumina_D6300-groundTruth.
    """
    metric_columns = ("Utilization", "Precision", "Recall", "F1-Score", "F0.5-Score")
    with open(output_file, 'wt') as handle:
        handle.write("Dataset\tTool\tLevel\tUtilization\tPrecision\tRecall\tF1-Score\tF0.5-Score\n")
        for dataset in ds_mat:
            for tool in tools:
                if tool == "MetaMaps" and dataset == "Illumina_D6300-groundTruth":
                    continue
                for level, values in ds_mat[dataset][tool].items():
                    cells = [dataset, tool, level]
                    cells.extend(str(values[column]) for column in metric_columns)
                    handle.write("\t".join(cells) + "\n")


def parse_cami_file(file, tool, min_abundance=0.01):
    """Read species-level abundances from a CAMI-style profile or Ganon report.

    Lines starting with "@" (CAMI header lines) and blank or truncated lines
    are skipped. For Ganon the rank is in the first column and the taxid in
    the second; for every other tool the taxid is first and the rank second.
    The abundance is always taken from the last column.

    Args:
        file: path of the profile/report.
        tool: tool name; only "Ganon" switches the column layout.
        min_abundance: rows whose abundance is below this value are dropped
            (default 0.01, the previously hard-coded cutoff).

    Returns:
        dict mapping taxid (str) to the abundance string as found in the file.
    """
    if tool == "Ganon":
        rank_col, taxid_col = 0, 1
    else:
        rank_col, taxid_col = 1, 0

    abund_mat = {}
    with open(file, 'rt') as in_file:
        for line in in_file:
            if line.startswith("@"):
                continue
            lsplit = line.strip().split("\t")
            # A usable row needs rank, taxid and abundance. Blank lines (e.g. a
            # trailing newline) previously raised IndexError for non-Ganon tools.
            if len(lsplit) < 3 or lsplit[rank_col] != "species":
                continue
            if float(lsplit[-1]) >= min_abundance:
                abund_mat[lsplit[taxid_col]] = lsplit[-1]

    return abund_mat


def parse_kraken_report(file, min_abundance=0.01):
    """Read species-level abundances from a Kraken-style report.

    Each line is stripped and split on tabs. The first field is read as the
    percentage, the fourth as the rank code and the fifth as the taxid. Only
    rows with rank code "S" and a percentage of at least ``min_abundance`` are
    kept. Blank or truncated lines are skipped.

    Args:
        file: path of the report.
        min_abundance: cutoff on the percentage column (default 0.01, the
            previously hard-coded value).

    Returns:
        dict mapping taxid (str) to the percentage string from the report.
    """
    abund_mat = {}
    with open(file, 'rt') as in_file:
        for line in in_file:
            lsplit = line.strip().split("\t")
            # Blank lines previously raised IndexError on lsplit[3].
            if len(lsplit) < 5 or lsplit[3] != "S":
                continue
            if float(lsplit[0]) >= min_abundance:
                abund_mat[lsplit[4]] = lsplit[0]

    return abund_mat

def parse_WIMP_file(file):
    """Read species-level abundances from a MetaMaps ``.WIMP`` table.

    The header row (starting with "AnalysisLevel") is ignored. For rows whose
    first column is "species", the second column is used as the key (the
    taxon ID) and the last column is read as a fraction. Fractions of at
    least 0.0001 are kept and converted to percentages (multiplied by 100).

    Returns:
        dict mapping the taxon ID (str) to a float percentage.
    """
    species_pct = {}
    with open(file, 'rt') as handle:
        data_rows = (raw.strip().split("\t") for raw in handle
                     if not raw.startswith("AnalysisLevel"))
        for fields in data_rows:
            if fields[0] != "species":
                continue
            fraction = float(fields[-1])
            if fraction >= 0.0001:
                species_pct[fields[1]] = fraction * 100

    return species_pct

def init_ds_taxonomic_abundances():
    """Build the zero-initialised taxonomic-abundance table.

    For every dataset except "refseq-abfv-sim100" the result holds:
    - a "theoretical" entry with the expected species abundances (one species
      set for names starting with "Zymo" or "Illumina", another otherwise);
    - one entry per tool in ``tax_abund_tools`` with each expected species
      plus an "other" bucket, all set to 0.0.

    Returns:
        ``{dataset: {"theoretical" | tool: {species: float}}}``
    """
    zymo_illumina_expected = (
        ("p_aeruginosa", 6.1),
        ("e_coli", 8.5),
        ("s_enterica", 8.7),
        ("l_fermentum", 21.6),
        ("e_faecalis", 14.6),
        ("s_aureus", 15.2),
        ("l_monocytogenes", 13.9),
        ("b_subtilis", 10.3),
        ("s_cerevisiae", 0.57),
        ("c_neoformans", 0.37),
    )
    other_expected = (
        ("f_prausnitzii", 14.77),
        ("e_coli", 8.71),
        ("s_enterica", 0.007),
        ("l_fermentum", 9.68),
        ("e_faecalis", 0.0011),
        ("v_rogosae", 19.94),
        ("r_hominis", 12.43),
        ("b_fragilis", 8.33),
        ("s_cerevisiae", 0.32),
        ("p_corporis", 6.26),
        ("b_adolescentis", 8.83),
        ("f_nucleatum", 7.53),
        ("c_difficile", 1.10),
        ("a_muciniphila", 1.62),
        ("m_smithii", 0.17),
        ("c_perfringens", 0.00009),
        ("c_albicans", 0.31),
    )

    table = {}
    for ds in datasets:
        if ds == "refseq-abfv-sim100":
            continue
        if ds.startswith("Zymo") or ds.startswith("Illumina"):
            expected = dict(zymo_illumina_expected)
        else:
            expected = dict(other_expected)
        table[ds] = {"theoretical": expected}
        for tool in tax_abund_tools:
            zero_row = dict.fromkeys(expected, 0.0)
            zero_row["other"] = 0.0
            table[ds][tool] = zero_row

    return table

def init_ds_sequence_abundances():
    """Build the zero-initialised sequence-abundance table.

    For every dataset except "refseq-abfv-sim100" the result holds:
    - a "theoretical" entry with the expected species abundances (one species
      set for names starting with "Zymo" or "Illumina", another otherwise);
    - one entry per tool in ``seq_abund_tools`` with each expected species
      plus an "other" bucket, all set to 0.0.

    MetaMaps gets no entry for "Illumina_D6300-groundTruth". No MetaMaps
    results are read for that dataset anywhere in this script, and an
    all-zero row would be indistinguishable from a genuine "nothing detected"
    result in the written table.

    Returns:
        ``{dataset: {"theoretical" | tool: {species: float}}}``
    """
    zymo_illumina_expected = (
        ("p_aeruginosa", 12.0),
        ("e_coli", 12.0),
        ("s_enterica", 12.0),
        ("l_fermentum", 12.0),
        ("e_faecalis", 12.0),
        ("s_aureus", 12.0),
        ("l_monocytogenes", 12.0),
        ("b_subtilis", 12.0),
        ("s_cerevisiae", 2.0),
        ("c_neoformans", 2.0),
    )
    other_expected = (
        ("f_prausnitzii", 14.0),
        ("e_coli", 14.0),
        ("s_enterica", 0.01),
        ("l_fermentum", 6.0),
        ("e_faecalis", 0.001),
        ("v_rogosae", 14.0),
        ("r_hominis", 14.0),
        ("b_fragilis", 14.0),
        ("s_cerevisiae", 1.4),
        ("p_corporis", 6.0),
        ("b_adolescentis", 6.0),
        ("f_nucleatum", 6.0),
        ("c_difficile", 1.5),
        ("a_muciniphila", 1.5),
        ("m_smithii", 0.1),
        ("c_perfringens", 0.0001),
        ("c_albicans", 1.5),
    )

    abund_mat = {}
    for ds in datasets:
        if ds == "refseq-abfv-sim100":
            continue
        if ds.startswith("Zymo") or ds.startswith("Illumina"):
            expected = dict(zymo_illumina_expected)
        else:
            expected = dict(other_expected)
        abund_mat[ds] = {"theoretical": expected}
        for t in seq_abund_tools:
            # Same skip rule as get_sequence_abundance / get_fn_fp_tp.
            if ds == "Illumina_D6300-groundTruth" and t == "MetaMaps":
                continue
            zero_row = dict.fromkeys(expected, 0.0)
            zero_row["other"] = 0.0
            abund_mat[ds][t] = zero_row

    return abund_mat

def summarize_abundances(abundance_matrix, species_abundances, dataset: str, t : str):
    """Fold one tool's per-taxid abundances into the expected-species buckets.

    Every taxid reported by tool ``t`` for ``dataset`` is mapped to an
    expected species name. Several taxids (e.g. closely related species) can
    share a name, and unmapped taxids are added to "other". Datasets whose
    name starts with "Zymo" or "Illumina" use one species-specific mapping,
    all others use a second one, and a shared mapping applies to both.

    Args:
        abundance_matrix: table from init_ds_*_abundances; updated in place.
        species_abundances: ``{dataset: {tool: {taxid: abundance}}}``. The
            abundance may be a string or number; it is converted with float().
        dataset: dataset key.
        t: tool key.

    Returns:
        ``abundance_matrix`` (the same object).
    """
    # Taxids mapped only for datasets whose name starts with "Zymo"/"Illumina".
    zymo_illumina_taxa = {
        "s_aureus": ("1280", "2799681", "2807624", "2807625"),
        "p_aeruginosa": ("287",),
        "b_subtilis": ("96241", "1423"),
        # NOTE(review): only 2678528 is mapped here; NCBI taxid 1639
        # (Listeria monocytogenes) would be counted as "other" - confirm intended.
        "l_monocytogenes": ("2678528",),
        "c_neoformans": ("5207",),
    }
    # Taxids mapped only for the remaining datasets.
    other_taxa = {
        "f_nucleatum": ("851", "671211"),
        "f_prausnitzii": ("853",),
        "v_rogosae": ("423477",),
        "r_hominis": ("301301",),
        "b_fragilis": ("817",),
        "p_corporis": ("28128",),
        "b_adolescentis": ("1680", "2985571"),
        "c_difficile": ("1496",),
        "a_muciniphila": ("239935",),
        "m_smithii": ("2173",),
        "c_perfringens": ("1502",),
        "c_albicans": ("5476",),
    }
    # Taxids mapped for every dataset.
    shared_taxa = {
        "s_enterica": ("28901", "2878546", "2816951", "2979871", "2500542",
                       "2833771", "2833772", "2204208", "2816952", "2816953",
                       "2991700", "2874627", "2876620", "2745892"),
        "e_coli": ("2860337", "2857061", "562", "386585", "2725997"),
        "l_fermentum": ("1613",),
        "e_faecalis": ("1351", "2920934"),
        "s_cerevisiae": ("4932",),
    }

    if dataset.startswith("Zymo") or dataset.startswith("Illumina"):
        specific_taxa = zymo_illumina_taxa
    else:
        specific_taxa = other_taxa

    bucket_of = {}
    for taxa in (specific_taxa, shared_taxa):
        for species, taxids in taxa.items():
            for taxid in taxids:
                bucket_of[taxid] = species

    for taxon, abundance in species_abundances[dataset][t].items():
        abundance_matrix[dataset][t][bucket_of.get(taxon, "other")] += float(abundance)
    return abundance_matrix


def get_sequence_abundance(results_dir):
    """Collect sequence (read-based) abundances for the mock datasets.

    For every dataset except "refseq-abfv-sim100" and every tool in
    ``seq_abund_tools`` (MetaMaps is skipped for Illumina_D6300), the tool's
    report ``<results_dir>/<tool>/<dataset><suffix>`` is parsed. The result is
    folded into the table from ``init_ds_sequence_abundances`` with
    ``summarize_abundances``.
    """
    # tool -> (report suffix, parser taking the report path)
    readers = {
        "Centrifuge": (".centrifuge.kreport.txt", parse_kraken_report),
        "Kraken2": (".report", parse_kraken_report),
        "MetaMaps": (".map.txt.EM.WIMP", parse_WIMP_file),
        "Taxor": (".abundance", lambda path: parse_cami_file(path, "Taxor")),
    }
    per_tool = {}
    table = init_ds_sequence_abundances()
    for ds in datasets:
        # only real mock datasets
        if ds == "refseq-abfv-sim100":
            continue
        per_tool[ds] = {}
        for t in seq_abund_tools:
            if t == "MetaMaps" and ds == "Illumina_D6300-groundTruth":
                continue
            if t in readers:
                suffix, reader = readers[t]
                per_tool[ds][t] = reader(results_dir + "/" + t + "/" + ds + suffix)
            table = summarize_abundances(table, per_tool, ds, t)
    return table

def get_taxonomic_abundance(results_dir):
    """Collect taxonomic (profile-based) abundances for the mock datasets.

    For every dataset except "refseq-abfv-sim100" and every tool in
    ``tax_abund_tools``, the tool's profile
    ``<results_dir>/<tool>/<dataset><suffix>`` is parsed. Sourmash is read as
    a Kraken-style report; all others are read with ``parse_cami_file``, and
    tools without a listed suffix use ".cami.profile". The result is folded
    into the table from ``init_ds_taxonomic_abundances``.
    """
    suffixes = {
        "Ganon": ".search.tre",
        "Taxor": ".report",
        "Sourmash": ".kreport.txt",
    }
    per_tool = {}
    table = init_ds_taxonomic_abundances()
    for ds in datasets:
        # only real mock datasets
        if ds == "refseq-abfv-sim100":
            continue
        per_tool[ds] = {}
        for t in tax_abund_tools:
            path = results_dir + "/" + t + "/" + ds + suffixes.get(t, ".cami.profile")
            if t == "Sourmash":
                per_tool[ds][t] = parse_kraken_report(path)
            else:
                per_tool[ds][t] = parse_cami_file(path, t)
            table = summarize_abundances(table, per_tool, ds, t)
    return table

def write_abundance_csv(abundance_matrix, output_file):
    """Write an abundance table as tab-separated long format.

    Columns: Dataset, Tool, Species, Abundance. There is one row per
    (dataset, tool or "theoretical", species) entry, in dict insertion order.
    """
    with open(output_file, 'wt') as handle:
        handle.write("Dataset\tTool\tSpecies\tAbundance\n")
        for dataset, by_tool in abundance_matrix.items():
            for tool, by_species in by_tool.items():
                for species, value in by_species.items():
                    handle.write("\t".join((dataset, tool, species, str(value))) + "\n")

def main():
    """Command-line entry point: compute all metrics and abundance tables.

    Validates the input paths and creates the output directory (including
    missing parents). It then writes results.tsv, level_metrics.tsv,
    sequence_abundances.tsv and taxonomic_abundances.tsv into that directory.
    Invalid inputs terminate the program with exit status 1 and the error
    message on stderr.
    """
    # ArgumentParser's first positional argument is `prog`; passing "" blanked
    # the program name in the usage line, so only a description is given.
    parser = ArgumentParser(
        description="Compute per-read classification metrics and abundance "
                    "tables for the benchmarked taxonomic classifiers.")
    parser.add_argument('-r', '--results_dir', required=True,
                        help='directory path to results of all tools')
    parser.add_argument('-o', '--output_dir', required=True,
                        help="directory for saving output files")
    parser.add_argument('-t', '--taxid_map', required=True,
                        help="taxid map file")
    parser.add_argument('-n', '--nodes_dmp', required=True,
                        help="nodes map file from ncbi taxonomy")

    args = parser.parse_args()

    # SystemExit(<str>) prints the message to stderr and exits with status 1;
    # a plain `return` reported success to the calling shell/workflow.
    if not os.path.isdir(args.results_dir):
        raise SystemExit("[Error] The given directory containing results from all tools does not exist: " + args.results_dir)

    if not os.path.exists(args.taxid_map):
        raise SystemExit("[Error] The given taxid_map file does not exist: " + args.taxid_map)

    if not os.path.exists(args.nodes_dmp):
        raise SystemExit("[Error] The given nodes.dmp file does not exist: " + args.nodes_dmp)

    # Unlike os.mkdir, this also creates missing parent directories.
    os.makedirs(args.output_dir, exist_ok=True)

    ds_mat = get_fn_fp_tp(args.results_dir, args.taxid_map, args.nodes_dmp)
    ds_mat = calculate_metrics(ds_mat)

    write_output_tsv(ds_mat, os.path.join(args.output_dir, "results.tsv"))
    write_output_excel_csv(ds_mat, os.path.join(args.output_dir, "level_metrics.tsv"))

    seq_abundance_matrix = get_sequence_abundance(args.results_dir)
    write_abundance_csv(seq_abundance_matrix, os.path.join(args.output_dir, "sequence_abundances.tsv"))

    tax_abundance_matrix = get_taxonomic_abundance(args.results_dir)
    write_abundance_csv(tax_abundance_matrix, os.path.join(args.output_dir, "taxonomic_abundances.tsv"))


if __name__ == "__main__":
    main()