import sys
import argparse

import pickle

from goatools.obo_parser import GODag
from goatools.gosubdag.gosubdag import GoSubDag


def load_go_annot(input_file, gene_header_name, go_header_name, gene_name_header):

    """
    Load gene-to-GO annotations from a tab-separated file with a header line.

    The first line is treated as the header: the three column names given as
    arguments are located in it and used to index every following row. Rows
    whose gene or GO column is empty are skipped.

    Args:
        input_file (str): path to the tab-separated annotation file.
        gene_header_name (str): header of the column holding gene identifiers.
        go_header_name (str): header of the column holding GO identifiers.
        gene_name_header (str): header of the column holding gene names.

    Returns:
        tuple: (go, gene_names) where `go` maps each GO identifier to the set
        of gene identifiers listed with it, and `gene_names` maps each gene
        identifier to the value of its gene-name column (stored as-is, so it
        is an empty string when the column is empty). If a gene appears on
        several rows, the value from the last retained row wins.

    Raises:
        ValueError: if one of the header names is absent from the first line.
    """

    go = {}
    gene_names = {}
    with open(input_file, 'r') as infile:
        for i, line in enumerate(infile):

            # Remove only the line terminator: a blind `line[:-1]` would eat the
            # last real character when the final line has no trailing newline.
            fields = line.rstrip("\n").split("\t")

            if i == 0:
                idx_gene = fields.index(gene_header_name)
                idx_goid = fields.index(go_header_name)
                idx_symbol = fields.index(gene_name_header)
                continue

            gene = fields[idx_gene]
            go_id = fields[idx_goid]

            # Rows lacking either identifier carry no usable annotation.
            if go_id == '' or gene == '':
                continue

            gene_name = fields[idx_symbol]

            go.setdefault(go_id, set()).add(gene)
            gene_names[gene] = gene_name

    return go, gene_names


def get_infos_go(go_id, dag, subdag, infos_dict):

    """
    Record the name, level and depth of a GO term in `infos_dict`.

    Nothing is done when `go_id` is already a key of `infos_dict`; otherwise
    the entry `infos_dict[go_id] = (name, level, depth)` is added in place.

    Args:
        go_id: identifier of the term to describe.
        dag: mapping from identifier to an object exposing `level` and `depth`.
        subdag: object whose `go2nt` mapping yields a namedtuple with a
            `GO_name` field for `go_id`.
        infos_dict (dict): cache of already-described terms, updated in place.

    Returns:
        None
    """

    if go_id not in infos_dict:
        term = dag[go_id]
        level, depth = term.level, term.depth

        # Render the name through the namedtuple's fields.
        nt_fields = subdag.go2nt[go_id]._asdict()
        label = '{GO_name}'.format(**nt_fields)

        infos_dict[go_id] = (label, level, depth)



def get_all_ancestors(go_dict, godag, domain="BP",
                      optional_relationships = {'regulates','negatively_regulates',
                                                'positively_regulates'}):

    """
    Propagate the genes of each annotated GO term to its ancestor terms.

    For every GO identifier in `go_dict`, a `GoSubDag` is built with the given
    relationships. Identifiers whose namespace (`NS`) differs from `domain`
    are ignored. For the others, the genes of the term are added to the term
    itself and to every ancestor listed in `rcntobj.go2ancestors`, and the
    (name, level, depth) of each term touched is collected via `get_infos_go`.
    Progress is reported on stderr every 100 identifiers.

    Args:
        go_dict (dict): GO identifier -> collection of gene identifiers.
        godag: ontology object passed to `GoSubDag` and `get_infos_go`.
        domain (str): namespace code that a term's `NS` must equal to be kept.
        optional_relationships: relationships forwarded to `GoSubDag`.

    Returns:
        tuple: (all_go, go_infos) where `all_go` maps a GO identifier to the
        set of genes annotated to it or to one of its processed descendants,
        and `go_infos` maps a GO identifier to (name, level, depth).
    """

    genes_by_term = {}
    term_infos = {}

    n_terms = len(go_dict)
    sys.stderr.write(f"Browsing the ontology to extract ancestor terms for {n_terms} GO IDs...\n")

    for rank, term in enumerate(go_dict):

        if rank and rank % 100 == 0:
            sys.stderr.write(f"Processed {rank} GO IDs ({round(rank/n_terms * 100, 2)}%)\n")

        subdag = GoSubDag([term], godag, relationships=optional_relationships, prt=None)

        # Only terms of the requested namespace are kept.
        if subdag.go2nt[term].NS != domain:
            continue

        annotated_genes = go_dict[term]

        # The term itself first...
        genes_by_term.setdefault(term, set()).update(annotated_genes)
        get_infos_go(term, godag, subdag, term_infos)

        # ...then each of its ancestors inherits the same genes.
        for parent in subdag.rcntobj.go2ancestors[term]:
            genes_by_term.setdefault(parent, set()).update(annotated_genes)
            get_infos_go(parent, godag, subdag, term_infos)

    return genes_by_term, term_infos


if __name__ == '__main__':

    # Command-line options as (flags, default) pairs; every option is an
    # optional string. The order below is the order shown in --help.
    OPTIONS = (
        (('-g', '--go_annotations'), "data/Hgla_GO.tsv"),
        (('--gene_header',), "Gene stable ID"),
        (('--go_header',), 'GO ID'),
        (('--name_header',), 'Symbol'),
        (('-gd', '--go_dag'), "data/go-basic.obo"),
        (('-o1', '--output1'), "nmr_gene_go_full.pkl"),
        (('-o2', '--output2'), "go_infos.pkl"),
        (('-o3', '--output3'), "nmr_gene_names.pkl"),
        (('--go_domain',), 'BP'),
    )

    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    for FLAGS, DEFAULT in OPTIONS:
        PARSER.add_argument(*FLAGS, type=str, required=False, default=DEFAULT)

    ARGS = vars(PARSER.parse_args())

    # Step 1: read the annotation table and save the gene -> name mapping
    # straight away (output3), before the ontology is loaded.
    GO_DICT, NAMES = load_go_annot(ARGS["go_annotations"], ARGS["gene_header"], ARGS["go_header"],
                                   ARGS["name_header"])

    with open(ARGS["output3"], 'wb') as out:
        pickle.dump(NAMES, out)

    # Step 2: load the ontology with its 'relationship' attributes and
    # propagate the annotations to ancestor terms.
    GODAG = GODag(ARGS["go_dag"], optional_attrs={'relationship'})

    ALL_GO, GO_INFO = get_all_ancestors(GO_DICT, GODAG, domain=ARGS["go_domain"])

    # Step 3: pickle the expanded GO -> genes mapping (output1), then the
    # per-term information (output2).
    for DEST, PAYLOAD in ((ARGS["output1"], ALL_GO), (ARGS["output2"], GO_INFO)):
        with open(DEST, 'wb') as handle:
            pickle.dump(PAYLOAD, handle)