#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import Counter

from pybedtools import BedTool
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns


def unmerge(repname):
    """Starting from a (possibly merged) repeat name, get the ERE family name.

    A merged name joins the names of its constituent repeats with ``'|'``
    (e.g. ``"A|B-int|A"``). The family is picked as follows:

    1. if exactly one distinct ``-int`` element is present, return it;
    2. otherwise return the most frequent element, preferring an ``-int``
       element among those tied for the highest count, then the element
       that occurs first.

    Parameters
    ----------
    repname : str or iterable of str
        A single repeat name, or an iterable (e.g. a pandas Series) of names.

    Returns
    -------
    str or list of str
        The family name, or a list of family names (one per element) when
        ``repname`` is not a string.
    """
    # isinstance (not type ==) so str subclasses such as numpy.str_ are
    # treated as names instead of being iterated character by character.
    if not isinstance(repname, str):
        return [unmerge(elem) for elem in repname]

    parts = repname.split('|')
    internal = {part for part in parts if '-int' in part}
    if len(internal) == 1:
        return internal.pop()

    counts = Counter(parts)
    top = max(counts.values())
    # Counter keeps first-occurrence order, so ties resolve to the earliest part.
    tied = [part for part, count in counts.items() if count == top]
    tied_internal = [part for part in tied if '-int' in part]
    return (tied_internal or tied)[0]


# Column names of the ERE BED files handled below (7-column BED).
_ERE_COLUMNS = ["chrom", "start", "end", "name", "score", "strand",
                "class.fam"]


def _ere_bed_to_df(bed):
    """Load an ERE BedTool into a DataFrame and add the ``unmerged`` family column."""
    df = bed.to_dataframe()
    df.columns = _ERE_COLUMNS
    df['unmerged'] = common_fun.unmerge(df['name'])
    return df


def _exon_partial_overlaps(all_ere, bedexon):
    """Return the EREs that overlap an exon without being fully covered by one."""
    # f=1.0: 100% of the ERE (feature A) must be overlapped.
    ere_exons_100 = all_ere.intersect(bedexon, u=True, f=1.0)
    ere_exons_any = all_ere.intersect(bedexon, u=True)
    # Drop the fully-covered EREs, keeping only the partial overlaps.
    # NOTE(review): the previous comment here said "remove the non-fully
    # overlapping bed1", the opposite of what this v=True intersect does —
    # confirm which set is intended. This also assumes the EREs do not
    # overlap each other (e.g. merged); otherwise v=True additionally drops
    # EREs that merely overlap a fully-covered one.
    return ere_exons_any.intersect(ere_exons_100, v=True)


def _overlap_hypertest(ere_overlap, ere_all_df, label, outfolder, name):
    """Save one ERE overlap set, test its family enrichment and write the table."""
    ere_overlap.saveas('/tmp/ere_at_{}_{}.bed'.format(label, name))
    ere_overlap.saveas('../results/{}/ere_at_{}_tc.bed'.format(outfolder, label))
    hyp = common_fun.hypertest(_ere_bed_to_df(ere_overlap), ere_all_df,
                               'unmerged')
    hyp['padj'] = common_fun.pval_adjust(hyp['pval'].values)
    hyp.to_csv('../results/{}/hyp_{}.xls'.format(outfolder, label), sep='\t')
    return hyp


def get_hypresults(bedtss, bedexon, bedtes, ere, outfolder, name=""):
    """Overlap chimeric-transcript features with merged EREs and test enrichment.

    For each feature set ('tss' from ``bedtss``, 'exons' from ``bedexon``,
    'tse' from ``bedtes``), the EREs in ``ere`` overlapping it are selected.
    Each selection is saved as ``/tmp/ere_at_<feature>_<name>.bed`` and as
    ``../results/<outfolder>/ere_at_<feature>_tc.bed``. Its family
    (``unmerged``) composition is tested against all EREs with
    ``common_fun.hypertest``, and p-values are adjusted with
    ``common_fun.pval_adjust``. Each table is written to
    ``../results/<outfolder>/hyp_<feature>.xls``. For exons, only EREs that
    overlap an exon without being fully covered by one are kept.

    Parameters
    ----------
    bedtss, bedexon, bedtes : str or BedTool
        Feature intervals, passed as ``b`` to ``BedTool.intersect``.
    ere : str or BedTool
        ERE intervals with the 7 columns chrom, start, end, name, score,
        strand, class.fam.
    outfolder : str
        Sub-directory of ``../results`` receiving the outputs (presumably it
        must already exist).
    name : str, optional
        Suffix of the BED copies written to ``/tmp``.

    Returns
    -------
    hyps : dict
        Hypergeometric result DataFrame per feature ('tss', 'exons', 'tse').
    scores : dict
        Adjusted p-values (``padj`` column) per feature.
    enri : dict
        ``prop / prop_tot`` ratio per feature.
    """
    all_ere = BedTool(ere)
    # Background set of all EREs: built once and shared by the three tests.
    ere_all_df = _ere_bed_to_df(all_ere)

    overlaps = (
        ('tss', all_ere.intersect(bedtss, u=True)),
        ('exons', _exon_partial_overlaps(all_ere, bedexon)),
        ('tse', all_ere.intersect(bedtes, u=True)),
    )

    hyps, scores, enri = {}, {}, {}
    for label, ere_overlap in overlaps:
        hyp = _overlap_hypertest(ere_overlap, ere_all_df, label, outfolder,
                                 name)
        hyps[label] = hyp
        scores[label] = hyp['padj']
        # Presumably the family's share among overlapping EREs over its share
        # among all EREs (columns defined by common_fun.hypertest).
        enri[label] = hyp['prop'] / hyp['prop_tot']

    return hyps, scores, enri


if __name__ == "__main__":
    # Guarded so that importing this module does not run (and crash on) the
    # analysis. NOTE(review): this call does not match the signature
    # get_hypresults(bedtss, bedexon, bedtes, ere, outfolder, name=""):
    # the list lands in `bedtss`, "P_ere_overlaps" in `bedexon`, "primed" in
    # `bedtes`, and `ere`/`outfolder` are missing, so it raises TypeError.
    # Supply the intended arguments before running this file as a script.
    hyps_p, scores_p, enri_p = get_hypresults(
        ['../results/P_ere_overlaps/all_chim_tss.bed',
         '../results/P_ere_overlaps/all_chim_exons.bed'],
        "P_ere_overlaps", "primed")


def make_heatmap(hyps):
    """Plot a heatmap of ERE families enriched at chimeric TSSs and exons.

    Columns are ``tss_n``, ``tss_p``, ``exons_n`` and ``exons_p``. Cell
    colours are ``-log10(padj)``; cell annotations are the ``found`` counts
    of the hypergeometric tables. Only families with padj < 0.05 in at least
    one column are shown, ordered by decreasing total TSS count. The figure
    is saved to ``../results/figs/heatmap_NP_sep.pdf``.

    Parameters
    ----------
    hyps : object
        Unused. NOTE(review): the function reads the module-level results
        ``hyps_p``, ``scores_p``, ``enri_p``, ``hyps_n``, ``scores_n`` and
        ``enri_n`` instead. The ``_n`` names are never assigned in this file,
        so they must come from elsewhere — confirm the intended inputs.
    """
    import copy

    col = ['tss_n', 'tss_p', 'exons_n', 'exons_p']

    founds_df = pd.concat([
        pd.DataFrame({'tss_n': hyps_n['tss']['found'],
                      'exons_n': hyps_n['exons']['found']}),
        pd.DataFrame({'tss_p': hyps_p['tss']['found'],
                      'exons_p': hyps_p['exons']['found']}),
    ], axis=1)[col]
    scores_df = pd.concat([pd.DataFrame.from_dict(scores_n).add_suffix('_n'),
                           pd.DataFrame.from_dict(scores_p).add_suffix('_p')],
                          axis=1)[col]
    enri_df = pd.concat([pd.DataFrame.from_dict(enri_n).add_suffix('_n'),
                         pd.DataFrame.from_dict(enri_p).add_suffix('_p')],
                        axis=1)[col]

    # Sanity check: scores and enrichments must describe the same families.
    assert scores_df.index.equals(enri_df.index)

    scores_df = -np.log10(scores_df)
    # Keep families significant (padj < 0.05) in at least one column.
    sel = (scores_df > -np.log10(0.05)).any(axis=1)
    sig = scores_df.loc[sel, :]
    # Align the counts to the significant rows so the positional reordering
    # below applies to both frames identically.
    founds_s = founds_df.reindex(sig.index).fillna(0)

    order = np.argsort(founds_s[['tss_n', 'tss_p']].sum(axis=1).values)[::-1]
    # Copy so set_bad does not mutate matplotlib's shared global colormap.
    cmap = copy.copy(matplotlib.cm.coolwarm)
    cmap.set_bad('lightgrey', 0.5)

    plt.close('all')
    plt.subplots(figsize=(8, 16))
    sns.heatmap(sig.iloc[order, :], annot=founds_s.iloc[order, :],
                cmap=cmap, fmt='.0f', vmax=6)
    plt.yticks(rotation=0)
    plt.xticks(rotation=45)
    plt.subplots_adjust(left=0.2)
    # Separator between the TSS and exon column pairs. NOTE(review): with
    # four columns, x=4 is the right edge — possibly left over from a layout
    # that also had the 'tse' columns.
    plt.axvline(2, color='darkgrey')
    plt.axvline(4, color='darkgrey')
    plt.savefig('../results/figs/heatmap_NP_sep.pdf')


def run_enrich(args):
    """Run the ERE enrichment on the chimeric TSS/exon files in ``args.outdir``.

    NOTE(review): get_hypresults requires
    (bedtss, bedexon, bedtes, ere, outfolder), but only the first two are
    passed here, so this call raises TypeError as written. The paths also
    lack the ``.bed`` extension used by the script-level call in this file.
    Confirm the intended inputs. The results are not returned.
    """
    tss_path = f"{args.outdir}/all_chim_tss"
    exons_path = f"{args.outdir}/all_chim_exons"
    get_hypresults(tss_path, exons_path)
