"""
"""

import sys
import argparse

import pickle

import scipy.stats as ss
import pingouin as pg


def check_overlap(a, b):
    """Return True when intervals *a* and *b* share a stretch of positive length.

    Both arguments are ``(start, end)`` pairs. Intervals that only touch
    (one ends exactly where the other starts) are not considered overlapping.
    """
    shared_length = min(a[1], b[1]) - max(a[0], b[0])
    return shared_length > 0


def merge_overlaping_intervals(intervals):
    '''
    Merge overlapping or touching ``(start, end)`` intervals in one sweep.

    *intervals* must already be sorted by start coordinate. An interval
    whose start is at or before the end of the last merged interval is
    folded into it.

    Returns a new list. Intervals that were never merged keep their original
    objects; merged ones are replaced by ``(start, end)`` tuples.
    '''
    result = []
    for current in intervals:
        if result and current[0] <= result[-1][1]:
            previous = result[-1]
            result[-1] = (previous[0], max(previous[1], current[1]))
            continue
        result.append(current)
    return result


def load_domains(input_file):
    """Load gene regulatory domains from a tab-separated file.

    Each non-blank line must hold exactly four tab-separated fields:
    ``chrom, start, end, gene``. Start and end are parsed as integers.

    Returns a nested dict ``{chrom: {(start, end): [gene, ...]}}``. Genes that
    share the exact same interval are collected in one list, in file order.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped.
    Without this they would fail the 4-field unpacking with ValueError.
    """
    d = {}
    with open(input_file, 'r') as infile:
        for line in infile:
            record = line.strip()
            if not record:
                continue
            chrom, s, e, gene = record.split("\t")
            interval = (int(s), int(e))
            d.setdefault(chrom, {}).setdefault(interval, []).append(gene)
    return d


def regions_to_genes(input_regfile, domains, transform_to_go=None):
    """Assign regions to the genes (or GO terms) whose domains they overlap.

    Parameters
    ----------
    input_regfile : path to a tab-separated file with four fields per line,
        ``chrom, start, end, region_name``. Blank lines are skipped.
    domains : nested dict ``{chrom: {(start, end): [gene, ...]}}``, as
        returned by load_domains().
    transform_to_go : optional dict gene -> iterable of GO terms. When given,
        each overlapped gene is replaced by its GO terms, and genes absent
        from the dict are ignored.

    Returns
    -------
    (n_regions, feature_to_regions)
        n_regions counts distinct region names overlapping at least one
        domain. A region is counted even when none of the overlapped genes
        has a GO annotation. feature_to_regions maps each gene (or GO term)
        to the set of region names assigned to it.

    The two counts are also printed to stdout as diagnostics.
    """
    region_features = {}
    feature_regions = {}
    with open(input_regfile, 'r') as infile:
        for line in infile:
            record = line.strip()
            if not record:
                continue
            chrom, s, e, reg = record.split("\t")
            if chrom not in domains:
                continue
            # Parse coordinates once per region, not once per domain interval.
            region = (int(s), int(e))
            for interval, genes in domains[chrom].items():
                if not check_overlap(region, interval):
                    continue
                hits = region_features.setdefault(reg, set())
                for g in genes:
                    if transform_to_go is None:
                        features = (g,)
                    elif g in transform_to_go:
                        features = transform_to_go[g]
                    else:
                        continue
                    for feature in features:
                        hits.add(feature)
                        feature_regions.setdefault(feature, set()).add(reg)

    print(len(region_features))
    print(len(feature_regions))
    return len(region_features), feature_regions


def genomic_background(domains, transform_to_go):
    """Measure how many base pairs of domain sequence each GO term covers.

    Each domain interval is attributed to every GO term of every gene on it
    that appears in *transform_to_go*. Overlapping intervals are merged with
    merge_overlaping_intervals so that shared bases are counted once.

    Returns ``(coverage, total)``:
      * coverage: dict GO term -> merged bp covered by its domains;
      * total: merged bp of all domains carrying at least one GO term,
        summed over chromosomes.
    """

    def covered_bp(intervals_by_chrom):
        # Merge each chromosome's intervals, then add up their lengths.
        length = 0
        for chrom_intervals in intervals_by_chrom.values():
            spans = sorted((int(s), int(e)) for (s, e) in chrom_intervals)
            length += sum(e - s for (s, e) in merge_overlaping_intervals(spans))
        return length

    # GO term -> set of (chrom, interval) domains annotated to it.
    go_domains = {}
    for chrom, chrom_domains in domains.items():
        for interval, genes in chrom_domains.items():
            for gene in genes:
                if gene not in transform_to_go:
                    continue
                for go_term in transform_to_go[gene]:
                    go_domains.setdefault(go_term, set()).add((chrom, interval))

    coverage = {}
    annotated_by_chrom = {}
    for go_term, located in go_domains.items():
        per_chrom = {}
        for chrom, interval in located:
            per_chrom.setdefault(chrom, []).append(interval)
            annotated_by_chrom.setdefault(chrom, set()).add(interval)
        coverage[go_term] = covered_bp(per_chrom)

    return coverage, covered_bp(annotated_by_chrom)


def write_res(res, output_file, gene_names=None, go_names=None, test_type="genes_hypergeom",
              go_to_genes=None, foreground_genes=None, freg=None):
    """Write enrichment results as a tab-separated table with a header row.

    Parameters
    ----------
    res : iterable of ``(id, bh_pval, ratio, nb, NB)`` tuples, as produced
        by hypergeom() / binom(). ratio, nb and NB are written verbatim
        (they must already be strings). bh_pval is formatted as ``{:.2e}``.
    output_file : path of the table to (over)write.
    gene_names : optional dict gene ID -> display name.
    go_names : optional dict GO ID -> display name (GO test types only).
    test_type : one of "genes_hypergeom", "go_over_regions_hypergeom",
        "go_over_genes_hypergeom", "go_over_regions_binom". Selects the header
        and whether the GO-specific Genes/Gene_IDS/Regions columns are added.
    go_to_genes : dict GO ID -> gene IDs. Required for the GO test types.
    foreground_genes : optional collection of gene IDs. When non-empty, the
        Genes/Gene_IDS columns only list genes contained in it.
    freg : optional dict GO ID -> region names for the Regions column.

    An unsupported test_type writes an error to stderr and exits with
    status 9. The check runs before the output file is created or truncated.
    """
    headers = {
        "genes_hypergeom": ["Gene", "Name", "Enrichment", "# foreground reg",
                            "# background reg", "BH p-value"],
        "go_over_regions_hypergeom": ["GO", "Name", "Enrichment", "# foreground reg",
                                      "# background reg", "BH p-value",
                                      "Genes", "Gene_IDS", "Regions"],
        "go_over_genes_hypergeom": ["GO", "Name", "Enrichment", "# foreground genes",
                                    "# expected genes", "BH p-value",
                                    "Genes", "Gene_IDS", "Regions"],
        "go_over_regions_binom": ["GO", "Name", "Enrichment", "# foreground regions",
                                  "# expected regions", "BH p-value",
                                  "Genes", "Gene_IDS", "Regions"],
    }
    if test_type not in headers:
        sys.stderr.write("Error: Unsupported test type.\n")
        sys.exit(9)
    header = headers[test_type]
    is_gene_table = header[0] == "Gene"

    # Callers pass lists; build a set once instead of scanning the list for
    # every gene of every reported GO term. An empty/None value means "no
    # filtering", as before.
    keep = set(foreground_genes) if foreground_genes else None

    def display_name(gene):
        if gene_names is not None and gene in gene_names:
            return gene_names[gene]
        return gene

    with open(output_file, 'w') as out:
        out.write('\t'.join(header) + '\n')

        for myid, bh_pval, ratio, nb, NB in res:
            if is_gene_table:
                name = 'N.A'
                if gene_names is not None and myid in gene_names:
                    name = gene_names[myid]
                # NOTE(review): hard-coded display override for one gene ID.
                if myid == "ENSHGLG00100016419":
                    name = "ENSHGLG00100016419 (Nebl)"
                out.write('\t'.join([myid, name, ratio, nb, NB, "{:.2e}".format(bh_pval)]) + '\n')
                continue

            name = myid
            if go_names is not None and myid in go_names:
                name = go_names[myid]

            regions = freg[myid] if freg else []
            gene_ids = [g for g in go_to_genes[myid] if keep is None or g in keep]
            genes = ','.join(sorted(display_name(g) for g in gene_ids))

            out.write('\t'.join([myid, name, ratio, nb, NB, "{:.2e}".format(bh_pval),
                                 genes, ','.join(sorted(gene_ids)),
                                 ','.join(sorted(regions))]) + '\n')


def genes_with_go(dict_go2genes, foreground_go, foreground_genes=None):
    """Collect the genes annotated to each requested GO term.

    Parameters
    ----------
    dict_go2genes : dict GO term -> iterable of gene IDs.
    foreground_go : iterable of GO terms to look up. Each must be a key of
        dict_go2genes, otherwise KeyError is raised.
    foreground_genes : optional collection of gene IDs. When given (even if
        empty), only genes contained in it are kept.

    Returns
    -------
    (n_genes, go_to_genes)
        n_genes is the number of distinct kept genes across all terms.
        go_to_genes maps each GO term with at least one kept gene to its set
        of kept genes.
    """
    # A set makes the per-gene membership test O(1); callers pass lists.
    keep = None if foreground_genes is None else set(foreground_genes)

    dgo = {}
    tot_genes = set()
    for go in foreground_go:
        for gene in dict_go2genes[go]:
            if keep is not None and gene not in keep:
                continue
            dgo.setdefault(go, set()).add(gene)
            tot_genes.add(gene)
    return len(tot_genes), dgo


def filter_redundant(go2genes_foreground, go_infos):
    """Collapse GO terms annotated to exactly the same items.

    Parameters
    ----------
    go2genes_foreground : dict GO term -> collection (e.g. set of genes or
        regions). Terms whose collections compare equal form one redundant
        group.
    go_infos : dict GO term -> (name, level, depth). Level and depth are
        compared numerically.

    Returns
    -------
    set of GO terms to keep. From each redundant group, this is the term(s)
    with the highest ``(level, depth)`` pair: level is compared first and
    depth breaks ties. Ties on both are all kept.
    """
    seen = set()
    go_to_keep = set()

    for go_id, genes in go2genes_foreground.items():
        if go_id in seen:
            continue
        group = {other for other, other_genes in go2genes_foreground.items()
                 if other_genes == genes}
        group.add(go_id)
        seen.update(group)

        # Rank level and depth together. Maximising them independently and
        # requiring a term to match both would drop the whole group whenever
        # no single term holds both maxima.
        best = max((go_infos[go][1], go_infos[go][2]) for go in group)
        go_to_keep.update(go for go in group
                          if (go_infos[go][1], go_infos[go][2]) == best)
    return go_to_keep


def hypergeom(d_foreground, d_background, n, N, store_expected=False, fdr=0.05, keep_all=False):
    """One-sided hypergeometric over-representation test with BH correction.

    For each key of *d_foreground* (a gene or GO term):
      k = len(d_foreground[key]) hits among the n foreground items,
      K = len(d_background[key]) hits among the N background items.
    Every key must also be present in *d_background*.

    Returns a list of ``(key, bh_pval, ratio, k, K_or_expected)`` tuples,
    sorted by BH-adjusted p-value. ratio is ``(k/n)/(K/N)``; the last field
    is K, or the expected count ``K/N * n`` when *store_expected* is true.
    All fields except bh_pval are strings. Only rows rejected at the *fdr*
    level are kept, unless *keep_all* is true.
    """
    tested = list(d_foreground)
    raw_pvals, ratio_strs, fg_counts, bg_counts = [], [], [], []

    for key in tested:
        k = len(d_foreground[key])
        K = len(d_background[key])
        # P(X >= k) for X ~ Hypergeom(N, K, n): survival function at k - 1.
        raw_pvals.append(ss.hypergeom.sf(k - 1, N, K, n))
        ratio_strs.append(str((k / n) / (K / N)))
        fg_counts.append(str(k))
        bg_counts.append(str(K / N * n) if store_expected else str(K))

    reject, pvals_corr = pg.multicomp(raw_pvals, method='fdr_bh', alpha=fdr)
    enriched = [
        (tested[i], pvals_corr[i], ratio_strs[i], fg_counts[i], bg_counts[i])
        for i, passed in enumerate(reject)
        if passed or keep_all
    ]

    # Most significant first (stable sort keeps input order on ties).
    return sorted(enriched, key=lambda row: row[1])


def _binom_pvalue(k, n, p):
    """Two-sided exact binomial test p-value for k successes in n trials.

    scipy.stats.binom_test was deprecated in SciPy 1.10 and removed in 1.12.
    scipy.stats.binomtest (SciPy >= 1.7) replaces it, so prefer it and fall
    back only on old SciPy releases.
    """
    if hasattr(ss, "binomtest"):
        return ss.binomtest(k, n, p).pvalue
    return ss.binom_test(k, n, p)


def binom(nreg_foreground, d_foreground, d_genome, fdr=0.05, keep_all=False):
    """Binomial test of GO-term enrichment among foreground regions.

    Parameters
    ----------
    nreg_foreground : total number of foreground regions (trials).
    d_foreground : dict GO term -> set of foreground regions hitting it.
        The set size is the observed success count k.
    d_genome : dict GO term -> expected success probability p, e.g. the
        fraction of annotated domain sequence the term covers. Every key of
        d_foreground must be present.
    fdr : BH false-discovery-rate threshold.
    keep_all : if true, return every tested term regardless of significance.

    Returns
    -------
    list of ``(go_id, bh_pval, ratio, k, expected)`` tuples sorted by
    BH-adjusted p-value. ratio is ``(k/n)/p`` and expected is ``p*n``; both
    are strings, as is k. The p-value is two-sided, and only terms that are
    over-represented (ratio > 1) are reported unless *keep_all* is set.
    """
    go_ids = list(d_foreground.keys())
    pvals = []
    ratios = []
    nb = []
    NB = []
    for go_id in go_ids:
        k = len(d_foreground[go_id])
        p = d_genome[go_id]

        pvals.append(_binom_pvalue(k, nreg_foreground, p))
        ratios.append((k / nreg_foreground) / p)
        nb.append(str(k))
        NB.append(str(p * nreg_foreground))

    reject, pvals_corr = pg.multicomp(pvals, method='fdr_bh', alpha=fdr)
    enriched = []
    for i, decision in enumerate(reject):
        # The test is two-sided; keep only significant enrichments.
        if (decision and ratios[i] > 1) or keep_all:
            enriched.append((go_ids[i], pvals_corr[i], str(ratios[i]), nb[i], NB[i]))

    # Sort by adjusted p-value, most significant first.
    enriched = sorted(enriched, key=lambda x: x[1])
    return enriched


if __name__ == '__main__':
    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    # Foreground regions: tab-separated chrom, start, end, region_name.
    PARSER.add_argument('-f', '--foreground', type=str, required=True)

    # go_infos.pkl: GO id -> (go_name, level, depth)
    PARSER.add_argument('-g2', '--go_infos', type=str, required=False, default="")

    # Regulatory domains: tab-separated chrom, start, end, gene.
    PARSER.add_argument('-d', '--domains', type=str, required=False, default='reg_domains_hgla.tsv')

    # Background regions (same format as --foreground); needed for tests 1 and 2.
    PARSER.add_argument('-b', '--background', type=str, required=False, default=None)

    # Pickle: GO id -> gene IDs.
    PARSER.add_argument('-g1', '--go_annotations', type=str, required=False, default="nmr_gene_go_full.pkl")

    # Pickle: gene ID -> display name.
    PARSER.add_argument('-g3', '--gene_names', type=str, required=False, default="nmr_gene_names.pkl")

    # Output tables; each test runs only when its output path is given.
    PARSER.add_argument('-o', '--output_genes', type=str, required=False, default="")

    PARSER.add_argument('-o1', '--output_go1', type=str, required=False, default="")

    PARSER.add_argument('-o2', '--output_go2', type=str, required=False, default="")

    PARSER.add_argument('-o3', '--output_go3', type=str, required=False, default="")

    # Replaces the "c2pathways" placeholder in the -o1/-o2/-o3 paths.
    PARSER.add_argument('-t', '--type', type=str, required=False, default="")

    PARSER.add_argument('-ot', '--output_genetable', type=str, required=False, default="")

    PARSER.add_argument('-oft', '--output_foreground_genetable', type=str, required=False, default="")

    PARSER.add_argument('--fdr', type=float, required=False, default=0.05)

    # Keep only GO ids containing this substring.
    PARSER.add_argument('--filter', type=str, required=False, default="")

    PARSER.add_argument('--keep_all', action='store_true')

    # Only GO terms with strictly fewer than max_size genes and strictly more
    # than min_size genes are tested (drops overly general/specific terms).
    PARSER.add_argument('--max_size', type=int, required=False, default=1000)

    PARSER.add_argument('--min_size', type=int, required=False, default=2)

    ARGS = vars(PARSER.parse_args())

    if ARGS["type"]:
        for outputs in ['output_go1', 'output_go2', 'output_go3']:
            ARGS[outputs] = ARGS[outputs].replace("c2pathways", ARGS["type"])

    # NOTE(review): pickle.load can execute arbitrary code; only pass trusted,
    # locally produced pickle files to -g1/-g2/-g3.

    #Load GO -> genes annotations
    with open(ARGS["go_annotations"], 'rb') as infile:
        GO2GENES = pickle.load(infile)

    #Load GO infos (name, level, depth)
    GO_INFOS = None
    if ARGS["go_infos"]:
        with open(ARGS["go_infos"], 'rb') as infile:
            GO_INFOS = pickle.load(infile)

    #Load gene names
    with open(ARGS["gene_names"], 'rb') as infile:
        GENE_NAMES = pickle.load(infile)

    #Filter out GO terms that are too general or too specific (exclusive bounds)
    GO2GENES = {k: v for (k, v) in GO2GENES.items() if len(v) < ARGS["max_size"] and len(v) > ARGS["min_size"]}

    #Select only GO:pathway containing the requested string
    if ARGS["filter"]:
        GO2GENES = {k: v for (k, v) in GO2GENES.items() if ARGS["filter"] in k}

    GO_NAMES = None

    if GO_INFOS:
        #save go names in simple dict
        GO_NAMES = {k: v[0] for (k, v) in GO_INFOS.items()}

    #Load genes regulatory domains
    DOM = load_domains(ARGS["domains"])

    ##################################################################
    ## Test 1: Hypergeometric test for enriched genes vs background ##
    ##################################################################

    #Get genes associated with foreground regions
    n, gdf = regions_to_genes(ARGS["foreground"], DOM)
    FOREGROUND_GENES = list(gdf.keys())

    if ARGS["output_foreground_genetable"]:
        with open(ARGS["output_foreground_genetable"], 'w') as out:
            out.write("nmr_gene\tmouse_gene_name\tregions\n")
            for gene in gdf:
                regions = ",".join(gdf[gene])
                gene_name = GENE_NAMES.get(gene, "N.A")
                out.write('\t'.join([gene, gene_name, regions])+'\n')

    if ARGS["background"] and ARGS["output_genes"]:
        sys.stderr.write('Hypergeometric test for enriched genes vs background...')
        # The progress message has no newline; flush the stream it was
        # written to so it shows up before the (possibly slow) test runs.
        sys.stderr.flush()

        #Get genes associated with background regions
        N, gdb = regions_to_genes(ARGS["background"], DOM)

        # -ot is optional (defaults to ""); open("") would raise.
        if ARGS["output_genetable"]:
            with open(ARGS["output_genetable"], 'w') as out:
                out.write("nmr_gene\tmouse_gene_name\tregions\n")
                for gene in gdb:
                    regions = ",".join(gdb[gene])
                    gene_name = GENE_NAMES.get(gene, "N.A")
                    out.write('\t'.join([gene, gene_name, regions])+'\n')

        res_genes = hypergeom(gdf, gdb, n, N, fdr=ARGS["fdr"], keep_all=ARGS["keep_all"])

        #Write tables with all info to test for enriched genes
        write_res(res_genes, ARGS["output_genes"], GENE_NAMES)

        sys.stderr.write('OK\n')
        sys.stderr.flush()

    ############################################################################
    ## Test 2: Hypergeometric test over regions for enriched go vs background ##
    ############################################################################

    GENES2GO = {}
    for go in GO2GENES:
        for gene in GO2GENES[go]:
            GENES2GO[gene] = GENES2GO.get(gene, set())
            GENES2GO[gene].add(go)

    #Get GOs associated with foreground regions
    n, df = regions_to_genes(ARGS["foreground"], DOM, transform_to_go=GENES2GO)

    if GO_INFOS:
        #Filter out redundant gos (remove from df and GENES2GO)
        go_to_keep = filter_redundant(df, GO_INFOS)
        GO2GENES = {k: v for (k, v) in GO2GENES.items() if k in go_to_keep}
        df = {k: v for (k, v) in df.items() if k in go_to_keep}

        #Invert dict to get genes to go correspondence
        GENES2GO = {}
        for go in GO2GENES:
            for gene in GO2GENES[go]:
                GENES2GO[gene] = GENES2GO.get(gene, set())
                GENES2GO[gene].add(go)

    if ARGS["background"] and ARGS["output_go1"]:

        sys.stderr.write('Hypergeometric test over regions for enriched GO vs background...')
        sys.stderr.flush()

        #Get number of regions associated with these GOs in background
        N, db = regions_to_genes(ARGS["background"], DOM, transform_to_go=GENES2GO)

        res_go_hypergeom_over_regions = hypergeom(df, db, n, N, fdr=ARGS["fdr"], keep_all=ARGS["keep_all"])
        write_res(res_go_hypergeom_over_regions, ARGS["output_go1"], gene_names=GENE_NAMES,
                  go_names=GO_NAMES, test_type="go_over_regions_hypergeom",
                  go_to_genes=GO2GENES, foreground_genes=FOREGROUND_GENES, freg=df)

        sys.stderr.write('OK\n')
        sys.stderr.flush()

    ######################################################################
    ## Test 3: Hypergeometric test over genes for enriched go vs genome ##
    ######################################################################

    if ARGS["output_go2"]:

        sys.stderr.write('Hypergeometric test over genes for enriched GO vs genome...')
        sys.stderr.flush()

        #Get foreground genes associated to gos
        ng, dfg = genes_with_go(GO2GENES, list(df.keys()), foreground_genes=FOREGROUND_GENES)

        #Get total genes associated to gos
        Ng, dbg = genes_with_go(GO2GENES, list(df.keys()))

        res_go_hypergeom_over_genes = hypergeom(dfg, dbg, ng, Ng, store_expected=True, fdr=ARGS["fdr"], keep_all=ARGS["keep_all"])

        write_res(res_go_hypergeom_over_genes, ARGS["output_go2"], gene_names=GENE_NAMES,
                  go_names=GO_NAMES, test_type="go_over_genes_hypergeom",
                  go_to_genes=GO2GENES, foreground_genes=FOREGROUND_GENES, freg=df)

        sys.stderr.write('OK\n')
        sys.stderr.flush()

    ##################################################################
    ## Test 4: binomial test over regions for enriched go vs genome ##
    ##################################################################

    if ARGS["output_go3"]:

        sys.stderr.write('Binomial test over regions for enriched GO vs genome...')
        sys.stderr.flush()

        dg, TOT = genomic_background(DOM, transform_to_go=GENES2GO)

        # Convert per-GO covered bp into a fraction of all GO-annotated domain bp.
        for go in dg:
            dg[go] = dg[go] / TOT

        res_go_binom = binom(n, df, dg, fdr=ARGS["fdr"], keep_all=ARGS["keep_all"])

        write_res(res_go_binom, ARGS["output_go3"], gene_names=GENE_NAMES,
                  go_names=GO_NAMES, test_type="go_over_regions_binom",
                  go_to_genes=GO2GENES, foreground_genes=FOREGROUND_GENES, freg=df)

        sys.stderr.write('OK\n')
        sys.stderr.flush()