#!/usr/bin/python

from argparse import ArgumentParser

def read_taxid_file(taxid_file):
    """Load the accession-to-taxid mapping from a tab-separated file.

    Each non-blank line supplies the accession in its first column and the
    taxid in its second column; further columns are ignored. When an
    accession occurs more than once, the last occurrence wins.

    Args:
        taxid_file: path of the mapping file.

    Returns:
        dict mapping accession strings to taxid strings.

    Raises:
        IndexError: if a non-blank line has fewer than two tab-separated
            columns.
    """
    accessions = {}
    with open(taxid_file, 'rt') as file:
        for line in file:
            stripped = line.strip()
            # Blank lines (typically a trailing newline at the end of the
            # file) carry no mapping; without this guard they abort the load
            # with an IndexError.
            if not stripped:
                continue
            lsplit = stripped.split("\t")
            accessions[lsplit[0]] = lsplit[1]

    return accessions

# Verdict codes carried in ``low_level`` (the verdict reached on the level
# below, reused for every level that is skipped on the way up):
# low_level = 1 => TP
# low_level = 2 => FP
# low_level = 3 => FN
_LOW_LEVEL_COUNTERS = {1: "TP", 2: "FP", 3: "FN"}

# Ranks that are replaced by their direct parent before any comparison.
_SUB_RANKS = ("strain", "species group", "subspecies", "species subgroup",
              "subgenus", "subfamily", "subclass", "suborder", "no rank",
              "clade", "tribe")

# Ranks at which an upward walk gives up and is treated as the root ("1").
_TOP_RANKS = ("phylum", "subphylum", "kingdom", "superkingdom")

# True taxid -> predicted taxids that are accepted as the same organism.
_SYNONYMOUS_TAXIDS = {
    # similar s_aureus
    "2807625": {"1280", "2799681", "2807624"},
    # similar p aeruginosa
    "287": {"2944252", "2545800", "2866282"},
    # similar s_enterica
    "28901": {"2878546", "2816951", "2979871", "2500542", "2833771",
              "2833772", "2204208", "2816952", "2816953"},
    # similar b_subtilis
    "96241": {"1963032"},
    # similar e_coli
    "2857061": {"2860337", "2857061", "562"},
    "562": {"2860337", "2857061", "562"},
    # similar f nucleatum
    "851": {"671211"},
}


def _tally_low_level(level_mat, level, low_level):
    """Increment the counter of *level* selected by the *low_level* code.

    Codes other than 1, 2 and 3 leave the matrix untouched.
    """
    counter = _LOW_LEVEL_COUNTERS.get(low_level)
    if counter is not None:
        level_mat[level][counter] += 1


def _climb_to_level_rank(nodes, taxid, levels):
    """Walk up from *taxid* to the first node whose rank is in *levels*.

    Returns "1" instead when the walk meets the root or one of the ranks in
    ``_TOP_RANKS`` before finding such a node.
    """
    while nodes[taxid]["rank"] not in levels:
        if nodes[taxid]["rank"] in _TOP_RANKS or taxid == "1":
            return "1"
        taxid = nodes[taxid]["parent"]
    return taxid


def count_on_level(nodes, true_taxid, exp_taxid, level_mat, levels, cur_level_index, low_level = 3):
    """Tally TP/FP/FN for one read on the current level and all higher ones.

    The function recurses from ``levels[cur_level_index]`` upwards. When both
    taxids sit on the current level they are compared (identical or listed as
    synonymous => TP, otherwise FP) and the recursion continues with both
    parents. When one of them is not on the current level, every level that
    has to be skipped receives the verdict given by *low_level*, and the
    other taxid is walked up until both have the same rank.

    Args:
        nodes: dict mapping a taxid to a dict with the keys "rank" and
            "parent".
        true_taxid: ground-truth taxid of the read; must be a key of *nodes*.
        exp_taxid: taxid assigned by the classifier. A taxid that is not a
            key of *nodes* is ignored without counting anything.
        level_mat: dict mapping each level name to a dict with the integer
            counters "TP", "FP" and "FN"; updated in place.
        levels: list of rank names ordered from lowest to highest.
        cur_level_index: index into *levels* of the level to evaluate.
        low_level: verdict code (1 => TP, 2 => FP, 3 => FN) applied to
            levels that are skipped.

    Returns:
        None after normal tallying, or 1 when one lineage reaches taxid "1"
        without ever showing the rank of the other taxid.

    Raises:
        KeyError: if *true_taxid* or an ancestor met during a walk is missing
            from *nodes*.
    """
    cur_level = levels[cur_level_index]

    if exp_taxid not in nodes:
        return

    # Sub-ranks are folded into their direct parent (one step only).
    if nodes[exp_taxid]["rank"] in _SUB_RANKS:
        exp_taxid = nodes[exp_taxid]["parent"]

    if nodes[true_taxid]["rank"] in _SUB_RANKS:
        true_taxid = nodes[true_taxid]["parent"]

    # expected taxid is on the same level as the true taxid
    if nodes[exp_taxid]["rank"] == cur_level and nodes[true_taxid]["rank"] == cur_level:
        if true_taxid == exp_taxid or exp_taxid in _SYNONYMOUS_TAXIDS.get(true_taxid, ()):
            # correct or synonymous taxon assignment
            level_mat[cur_level]["TP"] += 1
            new_low_level = 1
        else:
            # false taxon assignment on current level
            level_mat[cur_level]["FP"] += 1
            new_low_level = 2
        # now count on the next higher level using both parents
        if cur_level_index + 1 < len(levels):
            return count_on_level(nodes, nodes[true_taxid]["parent"], nodes[exp_taxid]["parent"],
                                  level_mat, levels, cur_level_index + 1, new_low_level)
        return

    # expected taxid is not on the current level
    if nodes[exp_taxid]["rank"] != cur_level:
        exp_taxid = _climb_to_level_rank(nodes, exp_taxid, levels)

        if exp_taxid == "1":
            new_level_index = len(levels) - 1
        else:
            new_level_index = levels.index(nodes[exp_taxid]["rank"])
            if cur_level_index == new_level_index:
                # The walk ended on the current level: re-evaluate here.
                # low_level is deliberately not forwarded (default applies).
                return count_on_level(nodes, true_taxid, exp_taxid, level_mat, levels, cur_level_index)

        # Every level below the one of the expected taxid gets the verdict
        # inherited from the level below.
        while cur_level_index < new_level_index:
            _tally_low_level(level_mat, cur_level, low_level)
            cur_level_index += 1
            cur_level = levels[cur_level_index]

        if exp_taxid == "1":
            _tally_low_level(level_mat, cur_level, low_level)
            return

        # Lift the true taxid to the rank of the expected taxid.
        while nodes[true_taxid]["rank"] != nodes[exp_taxid]["rank"]:
            true_taxid = nodes[true_taxid]["parent"]
            if str(true_taxid) == "1":
                return 1

        return count_on_level(nodes, true_taxid, exp_taxid, level_mat, levels, cur_level_index, low_level)

    # Remaining case: the expected taxid is on the current level, hence the
    # true taxid is not.
    true_taxid = _climb_to_level_rank(nodes, true_taxid, levels)

    if true_taxid == "1":
        new_level_index = len(levels) - 1
    else:
        new_level_index = levels.index(nodes[true_taxid]["rank"])
        if cur_level_index == new_level_index:
            return count_on_level(nodes, true_taxid, exp_taxid, level_mat, levels, cur_level_index)

    while cur_level_index < new_level_index:
        _tally_low_level(level_mat, cur_level, low_level)
        cur_level_index += 1
        cur_level = levels[cur_level_index]

    if true_taxid == "1":
        _tally_low_level(level_mat, cur_level, low_level)
        return

    # Lift the expected taxid to the rank of the true taxid. The stop test
    # must look at the taxid that is actually moving (exp_taxid); testing
    # true_taxid here never fires and loops forever once the root is reached.
    while nodes[true_taxid]["rank"] != nodes[exp_taxid]["rank"]:
        exp_taxid = nodes[exp_taxid]["parent"]
        if str(exp_taxid) == "1":
            return 1

    return count_on_level(nodes, true_taxid, exp_taxid, level_mat, levels, cur_level_index, low_level)


def find_lca(taxids, nodes, levels):
    """Return the first taxid, scanning *levels* in order, shared by all *taxids*.

    For every level each taxid is walked up its lineage until a node of that
    rank is found; the walk also stops at taxid "1" or at a phylum,
    subphylum, kingdom or superkingdom node. The first level on which all
    walks end at the same node yields that node.

    Args:
        taxids: taxids to combine.
        nodes: dict mapping a taxid to a dict with the keys "rank" and
            "parent".
        levels: rank names, tried in the given order.

    Returns:
        The common taxid, or "1" when no level agrees or when a taxid or one
        of its ancestors is missing from *nodes*.
    """
    stop_ranks = ["phylum", "subphylum", "kingdom", "superkingdom"]

    for rank in levels:
        ancestors = []
        for start in taxids:
            if start not in nodes:
                return "1"

            current = start
            while nodes[current]["rank"] != rank:
                current = nodes[current]["parent"]
                if current not in nodes:
                    return "1"
                if nodes[current]["rank"] in stop_ranks or current == "1":
                    break
            ancestors.append(current)

        if all(ancestor == ancestors[0] for ancestor in ancestors):
            return ancestors[0]

    return "1"


def _true_taxid_for_read(read_id, acc):
    """Return the ground-truth taxid encoded in a read identifier.

    Identifiers starting with "GCF_" (simulated reads) are looked up in *acc*
    using their first two "_"-separated fields; every other identifier (real
    mock reads) must carry the taxid as its second "|"-separated field.
    """
    if read_id.startswith("GCF_"):
        fields = read_id.split("_")
        return acc[fields[0] + "_" + fields[1]]
    return read_id.split("|")[1]


def _flush_pending_hits(pending_ids, read_id, acc, nodes, levels, level_mat):
    """Score the LCA of the buffered taxids for *read_id*, then empty the buffer.

    Does nothing when the buffer is empty.
    """
    if not pending_ids:
        return
    lca_taxid = find_lca(pending_ids, nodes, levels)
    count_on_level(nodes, _true_taxid_for_read(read_id, acc), lca_taxid, level_mat, levels, 0)
    pending_ids.clear()


def confusion_matrix(taxid_file, binning_file, nodes):
    """Build per-level TP/FP/FN counts from a classifier's per-read output.

    The tab-separated *binning_file* is read line by line. Lines starting
    with "@", "#" or "readID" and blank lines are skipped. Unassigned reads
    are recognised by "-" (Taxor), "unclassified" (Centrifuge) or "0"
    (MetaMaps) in the second column, or "U" (Kraken2) in the first column.
    Rows with at least eight columns are treated as Centrifuge output: a read
    whose eighth column is not "1" spans several rows, whose taxids are
    buffered and scored once through their LCA.

    Args:
        taxid_file: path of the accession-to-taxid mapping used for reads
            whose identifier starts with "GCF_".
        binning_file: path of the classifier output.
        nodes: dict mapping a taxid to a dict with the keys "rank" and
            "parent".

    Returns:
        Tuple ``(level_mat, unclassified)``: *level_mat* maps each of
        "species", "genus", "family", "order" and "class" to a dict with the
        counters "TP", "FP" and "FN"; *unclassified* is the number of reads
        without an assignment.
    """
    levels = ["species", "genus", "family", "order", "class"]
    level_mat = {lev: {"TP": 0, "FP": 0, "FN": 0} for lev in levels}
    acc = read_taxid_file(taxid_file)

    unclassfied = 0
    last_read = ""
    # Taxids of the multi-hit read currently being collected (Centrifuge).
    exp_ids = []

    with open(binning_file, 'rt') as file:
        for line in file:
            if line.startswith(("@", "#", "readID")):
                continue
            stripped = line.strip()
            if not stripped:
                continue
            lsplit = stripped.split("\t")

            # read has no assignment in Taxor
            if lsplit[1] == "-":
                unclassfied += 1
                continue

            # read has no assignment in Centrifuge
            if lsplit[1] == "unclassified":
                unclassfied += 1
                # Score the previous multi-hit read against its own truth,
                # taken from last_read rather than from a leftover variable.
                _flush_pending_hits(exp_ids, last_read, acc, nodes, levels, level_mat)
                last_read = lsplit[0]
                continue

            # read has no assignment in MetaMaps
            if lsplit[1] == "0":
                unclassfied += 1
                continue

            # read has no assignment in Kraken2
            if lsplit[0] == "U":
                unclassfied += 1
                continue

            # read classified in Kraken2: read id in column 2, taxid in column 3
            if lsplit[0] == "C":
                count_on_level(nodes, _true_taxid_for_read(lsplit[1], acc), lsplit[2],
                               level_mat, levels, 0)
                continue

            # all others: read id in column 1, taxid in column 2
            read_id = lsplit[0]
            simulated = read_id.startswith("GCF_")
            true_taxid = _true_taxid_for_read(read_id, acc)
            exp_taxid = lsplit[1]

            # Centrifuge output: the eighth column (index 7) is consulted, so
            # at least eight columns are required.
            if len(lsplit) >= 8:
                if lsplit[7] == "1":
                    # Single-hit read: first settle the buffered multi-hit read.
                    exp_taxid = lsplit[2]
                    _flush_pending_hits(exp_ids, last_read, acc, nodes, levels, level_mat)
                else:
                    # Multi-hit read: buffer the taxid; a change of read id
                    # closes the previously buffered read.
                    if exp_ids and read_id != last_read:
                        _flush_pending_hits(exp_ids, last_read, acc, nodes, levels, level_mat)
                    exp_ids.append(lsplit[2])
                    last_read = read_id
                    continue
                last_read = read_id

            if exp_taxid == "0":
                # NOTE(review): only real mock reads are echoed and counted as
                # unclassified here; simulated reads are skipped silently.
                # Kept as found - confirm whether the asymmetry is intended.
                if not simulated:
                    print(line)
                    unclassfied += 1
                continue

            err = count_on_level(nodes, true_taxid, exp_taxid, level_mat, levels, 0)
            if err == 1:
                print(line)
                break

    # A multi-hit read at the very end of the file has no following row that
    # would trigger its scoring, so settle it here.
    _flush_pending_hits(exp_ids, last_read, acc, nodes, levels, level_mat)

    return level_mat, unclassfied



def parse_node_dump_file(node_dmp):
    """Read a "|"-delimited taxonomy nodes dump into a lookup table.

    Every line contributes one entry: the first field is the taxid, the
    second its parent taxid and the third its rank, each with surrounding
    whitespace removed.

    Args:
        node_dmp: path of the nodes dump file.

    Returns:
        dict mapping a taxid to ``{"rank": ..., "parent": ...}``.
    """
    taxonomy = {}
    with open(node_dmp, 'rt') as handle:
        for record in handle:
            fields = [field.strip() for field in record.strip().split("|")]
            taxonomy[fields[0]] = {"rank": fields[2], "parent": fields[1]}

    return taxonomy

def main():
    """Parse the command line, evaluate the binning file and print the counts.

    One tab-separated line is printed per level (species, genus, family,
    order, class) with the TP count, the FP count, the FN count including the
    unclassified reads, and the sum of all of them.
    """
    cli = ArgumentParser("")
    cli.add_argument('-t', '--taxid_map', required=True,
                     help='file mapping ref seq accessions to taxids')
    cli.add_argument('-b', '--bin_file', required=True,
                     help="binning file with classification to taxids")
    cli.add_argument('-n', '--nodes_dump', required=True,
                     help="ncbi taxonomy nodes dump file")
    options = cli.parse_args()

    taxonomy = parse_node_dump_file(options.nodes_dump)
    level_mat, unclassified = confusion_matrix(options.taxid_map, options.bin_file, taxonomy)

    for level in ("species", "genus", "family", "order", "class"):
        counts = level_mat[level]
        # Unclassified reads are reported as false negatives on every level.
        false_negatives = counts["FN"] + unclassified
        total = counts["TP"] + counts["FP"] + counts["FN"] + unclassified
        print("{}\tTP:{}\tFP:{}\tFN:{}\t{}".format(
            level, counts["TP"], counts["FP"], false_negatives, total))


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()