#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import re
from glob import glob
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from pybedtools import BedTool
from pybedtools.helpers import chromsizes
from multiprocessing import Pool
from transpo import transcripts_counter
from transpo import util


def get_genes_bed(gtf):
    """Extract one BED interval per gene from a GTF file.

    Consecutive GTF records that share the same gene id (the first group of
    ``util.GENE_ID_REG`` matched at the start of the attribute column) are
    collapsed into a single 6-column BED line: chromosome, smallest start,
    largest end, gene id, score and strand.  Chromosome, score and strand are
    taken from the first record of each gene.  Coordinates are copied from
    the GTF without any offset conversion.

    Parameters
    ----------
    gtf : str
        Path to a GTF file in which the records of a gene are contiguous.

    Returns
    -------
    BedTool
        Bedtool object built from the tab-separated gene lines.
    """
    genes = []
    # Record of the gene currently being accumulated:
    # [chrom, start, end, gene_id, score, strand]
    current = None
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            (ch, _source, _feat, start, end,
             score, strand, _frame, attr) = line.rstrip('\n').split('\t')
            gene_id = util.GENE_ID_REG.match(attr).group(1)
            start, end = int(start), int(end)
            if current is None or gene_id != current[3]:
                # A new gene starts: flush the previous one, if any.
                if current is not None:
                    genes.append(current)
                current = [ch, start, end, gene_id, score, strand]
            else:
                current[1] = min(current[1], start)
                current[2] = max(current[2], end)
    # The last gene has no following record to trigger its flush.
    if current is not None:
        genes.append(current)
    gene_bed = "".join("\t".join(map(str, rec)) + '\n' for rec in genes)
    return BedTool(gene_bed, from_string=True)


def get_transcripts_bed(gtf=None):
    """Extract one BED interval per transcript from a GTF file.

    Consecutive GTF records sharing the same transcript id (first group of
    ``util.TRANSCRIPT_ID_REG`` searched in the attribute column) are collapsed
    into one 7-column line: chromosome, smallest start, largest end,
    transcript id, score, strand and sample.  The sample is the first group
    of ``util.SAMPLE_REG`` matched against the transcript id.  Chromosome,
    score and strand are taken from the first record of each transcript.

    Parameters
    ----------
    gtf : str
        Path to a GTF file in which the records of a transcript are
        contiguous.  The argument is required; it only has a default so that
        a call without arguments fails with an explicit message.

    Returns
    -------
    BedTool
        Bedtool object built from the tab-separated transcript lines.

    Raises
    ------
    ValueError
        If no GTF path is given.
    """
    if gtf is None:
        raise ValueError("get_transcripts_bed() requires the path to a GTF file")
    transcripts = []
    # Record of the transcript currently being accumulated:
    # [chrom, start, end, transcript_id, score, strand, sample]
    current = None
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            (ch, _source, _feat, start, end,
             score, strand, _frame, attr) = line.rstrip('\n').split('\t')
            transcript_id = util.TRANSCRIPT_ID_REG.search(attr).group(1)
            start, end = int(start), int(end)
            if current is None or transcript_id != current[3]:
                # A new transcript starts: flush the previous one, if any.
                if current is not None:
                    transcripts.append(current)
                sample = util.SAMPLE_REG.match(transcript_id).group(1)
                current = [ch, start, end, transcript_id, score, strand,
                           sample]
            else:
                current[1] = min(current[1], start)
                current[2] = max(current[2], end)
    # The last transcript has no following record to trigger its flush.
    if current is not None:
        transcripts.append(current)
    transcript_bed = "".join("\t".join(map(str, rec)) + '\n'
                             for rec in transcripts)
    return BedTool(transcript_bed, from_string=True)


def make_idx(df):
    """Replace the index of ``df`` in place with a key built from its first
    four columns.

    Each of the four leading columns is converted to ``str`` and the pieces
    are concatenated row-wise without a separator.  Nothing is returned.
    """
    pieces = [df.iloc[:, position].map(str) for position in range(4)]
    key = pieces[0]
    for piece in pieces[1:]:
        key = key + piece
    df.index = key


def count_trans_chr(chromgtf):
    """Count the transcripts of one per-chromosome GTF file.

    The chromosome label used in the log messages is the part of the file
    name that precedes the first underscore.  The counts come straight from
    ``transcripts_counter.get_transc_counts_from_gtf`` and are returned
    unchanged.
    """
    filename = os.path.basename(chromgtf)
    chrom = filename.partition('_')[0]
    start_msg = "Working on {chrom}".format(chrom=chrom)
    util.logger.info(start_msg)
    counts = transcripts_counter.get_transc_counts_from_gtf(chromgtf)
    done_msg = "Finished {chrom}".format(chrom=chrom)
    util.logger.info(done_msg)
    return counts


def get_transc_df(gtf, args, logger):
    """Count transcripts in every per-chromosome GTF and return one dataframe.

    The ``*.gtf`` files found in ``<args.outdir>/chromcat`` are processed in
    sorted order by a pool of ``args.threads`` workers running
    ``count_trans_chr``.  The per-file results are flattened into a single
    table, and a ``sample`` column is extracted from ``transcript_id`` with
    ``util.SAMPLE_REG``.  The ``gtf`` argument is not used.
    """
    logger.info("Counting all transcripts occurences...")
    chrom_gtfs = sorted(glob(os.path.join(args.outdir, 'chromcat/*.gtf')))
    with Pool(args.threads) as pool:  # one task per chromosome file
        per_chrom = pool.map(count_trans_chr, chrom_gtfs)
    rows = []
    for chrom_rows in per_chrom:
        rows.extend(chrom_rows)
    df = pd.DataFrame(rows)
    df.columns = ['chr', 'start', 'end', 'transcript_id', 'n_exons',
                  'strand', 'n_found', 'tr_list', 'found_in']
    transcript_ids = df['transcript_id']
    df['sample'] = transcript_ids.str.extract(util.SAMPLE_REG, expand=False)
    return df


def p_adjust_bh(p):
    """Benjamini-Hochberg p-value correction for multiple hypothesis testing.

    Parameters
    ----------
    p : array-like of float
        Raw p-values.

    Returns
    -------
    numpy.ndarray
        Adjusted p-values, in the same order as the input and capped at 1.
    """
    # np.asfarray was removed in NumPy 2.0; asarray with an explicit float
    # dtype performs the same conversion on every NumPy version.
    p = np.asarray(p, dtype=float)
    n_tests = len(p)
    by_descend = p.argsort()[::-1]
    by_orig = by_descend.argsort()
    # Scaling factor n / rank, laid out from the largest p-value downwards.
    steps = float(n_tests) / np.arange(n_tests, 0, -1)
    # The running minimum enforces monotonicity of the adjusted values.
    q = np.minimum(1, np.minimum.accumulate(steps * p[by_descend]))
    return q[by_orig]


def anno_transc_df(df, logger, args):
    """Add the genes overlapping each transcript to the dataframe.

    The first six columns of ``df`` are intersected (same strand) with the
    packaged ``assets/<args.org>_ens_coding_genes_symbol.bed`` file, and the
    overlapping gene symbols / ids are joined with ``|`` per transcript.
    Transcripts overlapping multiple gene bodies (MGB) are then refined: their
    exons are pulled from ``<args.outdir>/all_chim_trans_cat.gtf`` and
    intersected with ``args.bed2``, and the exon-level annotation replaces the
    transcript-level one.

    Parameters
    ----------
    df : pandas.DataFrame
        Transcript table; its index is replaced in place by ``make_idx``.
    logger : logging.Logger
        Logger used for progress messages.
    args : argparse.Namespace
        Must provide ``org``, ``outdir`` and ``bed2``.

    Returns
    -------
    pandas.DataFrame
        ``df`` with ``assoc_gene`` and ``gene_id`` columns added, indexed by
        transcript id (index name ``idx``).
    """
    import tempfile  # local import: only needed for the scratch files below

    make_idx(df)
    abed = BedTool.from_dataframe(df.iloc[:, :6])
    this_dir = os.path.dirname(__file__)
    logger.info("Annotating the results table with the genes for "
                "{}..".format(args.org))
    target_bed_path = os.path.join(this_dir, 'assets',
                                   '{}_ens_coding_genes_symbol.bed'.format(args.org))
    cols = (['chr', 'start', 'end', 'transcript_id', 'n_exons', 'strand']
            + ['chr_g', 'start_g', 'end_g', 'gene_id', 'n_transcripts',
               'strand_g', 'assoc_gene'])
    target_bed = BedTool(target_bed_path)
    tmp_df = abed.intersect(target_bed, wa=True, wb=True, s=True).to_dataframe(
        names=cols)
    g = tmp_df.groupby(['chr', 'start', 'end', 'transcript_id'])
    assoc_target = g[['assoc_gene', 'gene_id']].agg("|".join).reset_index()
    assoc_target.fillna('NA', inplace=True)

    # handle the transcripts overlapping multiple gene bodies (MGB)
    sel = assoc_target['assoc_gene'].str.contains(r'\|')
    mgb_ids = assoc_target.loc[sel.values, 'transcript_id']
    # Without any MGB transcript there is nothing to refine, and grepping
    # with an empty pattern file would leave an unparsable empty GTF.
    if not mgb_ids.empty:
        # A private scratch directory avoids clashes between concurrent runs
        # and is removed even when one of the steps below fails.
        with tempfile.TemporaryDirectory() as scratch:
            ids_path = os.path.join(scratch, 'tcmgb.txt')
            gtf_path = os.path.join(scratch, 'tcmgb.gtf')
            # header=False: the file is a grep pattern list, so a
            # "transcript_id" header line would match every GTF record.
            mgb_ids.to_csv(ids_path, index=False, header=False)
            util.call_w_check((f"grep -wF -f {ids_path} "
                               f"{args.outdir}/all_chim_trans_cat.gtf "
                               f"> {gtf_path}"))
            tcmgb_df = pd.read_csv(gtf_path, sep='\t', header=None)
        tcmgb_df.columns = ['chr', 'method', 'feature', 'start', 'end',
                            'score', 'strand', 'dot', 'desc']
        tcmgb_df['tid'] = tcmgb_df['desc'].apply(
            lambda x: util.TRANSCRIPT_ID_REG.search(x).group(1))
        # grep -w can also match ids that merely contain a selected id
        # followed by a non-word character; keep exact matches only.
        tcmgb_df = tcmgb_df.loc[tcmgb_df['tid'].isin(set(mgb_ids)), :]
        exons_mgb = tcmgb_df.loc[tcmgb_df['feature'] == 'exon', :]
        exons_mgb_bed = BedTool.from_dataframe(exons_mgb[['chr', 'start', 'end',
                                                          'tid', 'dot', 'strand']])
        exons_all_bed = BedTool(args.bed2)
        tmp_df = exons_mgb_bed.intersect(
            exons_all_bed, wa=True, wb=True, loj=True, s=True).to_dataframe(
                names=['chr', 'start', 'end', 'tid', 'dot', 'strand', 'chr_g',
                       'start_g', 'end_g', 'exon_gene_id', 'score', 'strand_g',
                       'exon_assoc_gene'])
        g = tmp_df.groupby('tid')
        assoc_target_mgb = g[['exon_gene_id', 'exon_assoc_gene']].agg(
            util.join_pipe).reset_index()
        assoc_target_mgb.index = assoc_target_mgb.tid
        # Index by transcript id so the exon-level values can be written by
        # label.
        assoc_target.index = assoc_target['transcript_id']
        assoc_target.loc[assoc_target_mgb.tid, "gene_id"] = \
            assoc_target_mgb.exon_gene_id.values
        assoc_target.loc[assoc_target_mgb.tid, "assoc_gene"] = \
            assoc_target_mgb.exon_assoc_gene.values

    # Same positional key as the one set on df above, so join() lines up.
    make_idx(assoc_target)
    anno_df = df.join(assoc_target[['assoc_gene', 'gene_id']])
    anno_df.index = anno_df.transcript_id
    # Plain assignment instead of a chained inplace fillna, which newer
    # pandas versions may apply to a temporary copy only.
    anno_df['assoc_gene'] = anno_df['assoc_gene'].fillna('NA')
    anno_df.index.name = "idx"
    return anno_df


def plot_pval(pvals, padj, group, args, logger):
    """Plot the density of the raw and adjusted p-values of one group.

    Both densities are Gaussian kernel estimates evaluated on 100 points of
    [0, 1] with a fixed bandwidth factor of 0.25.  The figure is saved as
    ``<args.outdir>/figs/<group>_pval_density.pdf``.

    Parameters
    ----------
    pvals, padj : array-like of float
        Raw and adjusted p-values.
    group : str
        Group name, used as prefix of the output file name.
    args : argparse.Namespace
        Must provide ``outdir``.
    logger : logging.Logger
        Logger used for progress and warning messages.

    Returns
    -------
    None
        The plot is skipped with a warning when a density cannot be
        estimated (identical values or fewer than two values).
    """
    logger.info("Computing and ploting the p-values density...")
    # The bandwidth is handed to the constructor so it really applies to both
    # curves; overriding covariance_factor afterwards has no effect unless the
    # covariance of that same estimator is recomputed.
    bandwidth = .25
    try:
        density_pval = stats.gaussian_kde(pvals, bw_method=bandwidth)
        density_padj = stats.gaussian_kde(padj, bw_method=bandwidth)
    except (np.linalg.LinAlgError, ValueError):
        # LinAlgError: singular covariance (all values identical);
        # ValueError: too few values for a kernel estimate.
        logger.warning("All the pvalues are the same or they are too few, "
                       "skipping pvalue plot")
        return
    xs = np.linspace(0, 1, 100)
    fig, ax = plt.subplots()
    ax.set_title('P-values density')
    ax.plot(xs, density_pval(xs), color='blue')
    ax.plot(xs, density_padj(xs), color='red')
    ax.legend(['pval', 'padj'])
    fig.savefig('{dir}/figs/{name}'.format(dir=args.outdir,
                                           name=group + '_pval_density.pdf'))
    # One figure is created per group; release it once written.
    plt.close(fig)


def plot_against(data, g, score, prop_g, sel, args, logger):
    """Scatter the per-row counts of group ``g`` against the control group.

    Every row of ``data`` is drawn as a dot coloured by ``score``; the rows
    flagged by ``sel`` are additionally circled.  A black line runs from the
    origin to the maxima of both axes.  The figure is written to
    ``<args.outdir>/figs/<g>_vs_<control>.pdf``.  ``prop_g`` is accepted but
    not used.
    """
    logger.info("Computing and ploting the groups versus plots...")

    control = args.control
    plt.close('all')
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(data[control], data[g], marker='.', c=score, cmap='Oranges',
               vmin=-3, vmax=10)
    if sum(sel) != 0:
        # Circle the significant rows on top of the full cloud.
        ax.scatter(data.loc[sel, control], data.loc[sel, g], s=35,
                   edgecolors='lightgrey', facecolors='none')
    else:
        logger.warning("No significant chimeric transcripts detected")

    ax.set_title('{s} vs {c} plot'.format(s=g, c=control))
    axis_label = 'Number of time transcripts found in {}'
    ax.set_xlabel(axis_label.format(control))
    ax.set_ylabel(axis_label.format(g))

    ax.axis('equal')
    max_control = np.max(data[control])
    max_group = np.max(data[g])
    ax.set_xlim((-10, max_control))
    ax.set_ylim((-10, max_group))
    ax.plot([0, max_control], [0, max_group], 'k-')

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.get_yaxis().tick_left()

    figname = '{g}_vs_{c}.pdf'.format(g=g, c=control)
    target = '{dir}/figs/{name}'.format(dir=args.outdir, name=figname)
    fig.savefig(target)


def merge_gtf(gtf, outputdir, logger):
    """Merge the chimeric GTF files of ``outputdir`` with ``stringtie --merge``.

    All ``*_chimeric.gtf`` files directly under ``outputdir`` are passed to
    stringtie, and the path of the merged file
    (``<outputdir>/all_chim_merged.gtf``) is returned.  The ``gtf`` argument
    is not used.
    """
    gtf_merge = "all_chim_merged.gtf"
    pattern = '{}/*_chimeric.gtf'.format(outputdir)
    if sys.version_info < (3, 5):
        chimeric_gtfs = util.myglob(pattern)
    else:
        chimeric_gtfs = glob(pattern, recursive=True)
    template = "stringtie --merge {files} -o {out}/{gtf_merge} -f 0.001"
    cmd = template.format(files=" ".join(chimeric_gtfs), out=outputdir,
                          gtf_merge=gtf_merge)
    logger.info("Merging all the chimeric GTFs files...")
    logger.debug(cmd)
    util.call_w_check(cmd)
    merged_path = os.path.join(outputdir, gtf_merge)
    return merged_path


def myfisher(x, y, tot_x, tot_y):
    """Run Fisher's exact test on two counts and return the p-value.

    The contingency table has the observed counts ``x`` and ``y`` in its
    first row and the remainders ``tot_x - x`` and ``tot_y - y`` in its
    second row.
    """
    observed = [x, y]
    remainder = [tot_x - x, tot_y - y]
    _statistic, pvalue = stats.fisher_exact([observed, remainder])
    return pvalue


def count_groups(gene_table, args, groups):
    """Count, for every row, how many of its samples belong to each group.

    ``gene_table`` is modified in place: a count column is added for the
    control group and for every other group.  Unless ``args.counts_only`` is
    set, each non-control group also gets ``<g>_delta``, ``<g>_pval``,
    ``<g>_padj`` and ``<g>_sig`` columns and two plots are produced
    (``plot_pval`` and ``plot_against``).  With ``args.counts_only`` a
    ``<g>_perc`` column is added instead.

    Parameters
    ----------
    gene_table : pandas.DataFrame
        Table with a ``sample`` column holding ``|``-separated sample names.
    args : argparse.Namespace
        Provides ``metadata`` (tab-separated file with a ``groups`` column),
        ``control``, ``counts_only`` and ``usepval``.
    groups : mapping
        Looked up with each sample name to obtain its group label.

    Returns
    -------
    pandas.DataFrame
        Selection of ``gene_table`` columns with ``sample`` moved last.
    """
    # Parse the metadata to get totals for the pval
    tot_groups = pd.read_csv(args.metadata, sep='\t',
                               header=0)['groups'].value_counts()
    groups_set = set(groups.values())
    if not args.counts_only:
        # The control is the reference every other group is tested against.
        groups_set.remove(args.control)
    # Translate the '|'-separated sample names of each row into group labels.
    groups_col = gene_table['sample'].str.split('|').apply(
        lambda x: "|".join([groups[u] for u in x]))
    # NOTE(review): g_col is filled below but never read afterwards.
    g_col = []
    split = groups_col.str.split('|')
    gene_table[args.control] = split.apply(lambda x: (x.count(args.control)))

    for g in groups_set:
        counts = split.apply(lambda x: (x.count(g)))
        gene_table[g] = counts
        if not args.counts_only:
            # Fisher exact test of the group count against the control count,
            # using the per-group totals from the metadata.
            pvals = [myfisher(cg, cc, tot_groups[g], tot_groups[args.control])
                     for cg, cc in
                     zip(counts, gene_table[args.control])]
            pvals = np.array(pvals)
            padj = p_adjust_bh(pvals)
            score = -np.log10(padj)
            # Share of the group among the group + control occurrences.
            prop_g = counts / (counts + gene_table[args.control])
            psel = pvals if args.usepval else padj
            # A row is significant when p < 0.05, one side holds more than
            # 75% of the occurrences, and either count exceeds a quarter of
            # its column maximum.
            sel = ((psel < 0.05)
                   & ((prop_g > 0.75) | (1-prop_g > 0.75))
                   & ((counts > 0.25*np.max(counts)) | (gene_table[args.control] >
                      0.25*np.max(gene_table[args.control]))))
            gene_table[g + '_delta'] = gene_table[g] - gene_table[args.control]
            gene_table[g + '_pval'] = pvals
            gene_table[g + '_padj'] = padj
            gene_table[g + '_sig'] = 0
            gene_table.loc[sel, g + '_sig'] = 1
            g_col.append(g)
            g_col.append(g + '_delta')
            g_col.append(g + '_pval')
            g_col.append(g + '_padj')
            g_col.append(g + '_sig')
            plot_pval(pvals, padj, g, args, util.logger)
            plot_against(gene_table, g, score, prop_g, sel, args, util.logger)
        else:
            # Fraction of the group's samples (per metadata) found in the row.
            gene_table[g + '_perc'] = counts / tot_groups[g]

    gene_table['groups'] = groups_col
    # NOTE(review): this is a substring test, so every column whose name
    # contains 'sample' (e.g. 'n_samples') is left out of the result, not
    # only 'sample' itself - confirm this is intended.
    old_col = [i for i in gene_table.columns if 'sample' not in i]
    col = (old_col + ["sample"])
    return gene_table[col]


def construct_gene_table(anno_df, gtf_cat, logger, outdir, groups):
    """Build the per-gene table used for the final export.

    For every gene the number of distinct entries in the ``|``-joined
    ``found_in`` values is computed and merged back as ``n_samples``.  The
    transcript-level table, sorted by decreasing ``n_samples``, is written to
    ``<outdir>/all_transcripts_table.xls``.  The returned table holds one row
    per gene.  ``gtf_cat``, ``logger`` and ``groups`` are not used.
    """
    found_per_gene = anno_df.groupby('gene_id')['found_in'].agg("|".join)
    n_distinct = found_per_gene.str.split("|").apply(
        lambda names: len(set(names)))
    counts = n_distinct.reset_index()
    counts.columns = ['gene_id', 'n_samples']

    res = pd.merge(anno_df, counts, left_on='gene_id', right_on='gene_id')
    transcripts_path = '{}/all_transcripts_table.xls'.format(outdir)
    res.sort_values('n_samples', ascending=False).to_csv(transcripts_path,
                                                         sep='\t')

    col = ['chr', 'start', 'end', 'gene_id', 'n_samples', 'strand',
           'assoc_gene', 'sample']
    # The key order fixes the column order of the aggregated table.
    aggregators = {
        'chr': util.get_first,
        'start': np.median,
        'end': np.median,
        'strand': util.get_first,
        'sample': util.join_pipe,
        'n_samples': util.get_first,
        'assoc_gene': util.join_pipe,
    }
    gene_table = res[col].groupby('gene_id').agg(aggregators)
    gene_table['gene_id'] = gene_table.index
    return gene_table


def construct_tss_table(outdir, logger):
    """Build the per-TSS transcript table from the exported transcript table.

    ``<outdir>/all_transcripts_table.xls`` is read back, the TSS of each
    transcript is taken as ``start`` (``end`` on the minus strand), and
    within every gene the transcripts whose TSS windows (+/- 100) merge
    together are aggregated into one row.

    Parameters
    ----------
    outdir : str
        Directory holding ``all_transcripts_table.xls``.
    logger : logging.Logger
        Not used.

    Returns
    -------
    pandas.DataFrame
        One row per merged TSS window, with columns ``chr``, ``start``,
        ``end``, ``transcript_id``, ``strand``, ``n_found``, ``found_in``,
        ``sample``, ``assoc_gene``, ``gene_id`` and ``tss``.  The frame is
        empty when the input table has no gene.
    """
    col = ['chr', 'start', 'end', 'transcript_id', 'strand', 'n_found',
           'found_in', 'sample', 'assoc_gene', 'gene_id', 'tss']
    df = pd.read_csv('{}/all_transcripts_table.xls'.format(outdir),
                     sep='\t')
    # fillna returns a new Series: assign it back, otherwise missing gene
    # names stay NaN and the rows are discarded by dropna() below.
    df['assoc_gene'] = df['assoc_gene'].fillna('---')
    tss_window = 100
    df['tss'] = df.start
    df.loc[df.strand == '-', 'tss'] = df.end

    per_gene = []
    for _name, grouped in df.groupby('gene_id'):
        tss_l = grouped.tss.unique()
        tss_l.sort()
        # 'chr1' is only a placeholder so bedtools accepts the intervals.
        tmp_arr = [('chr1', i-tss_window, i+tss_window) for i in tss_l]
        if len(tmp_arr) == 1:
            tss_arr = np.array(tmp_arr[0][1:])
        else:
            b = BedTool.from_dataframe(pd.DataFrame(tmp_arr))
            b = b.merge(d=tss_window)
            # Flattened [start, end, start, end, ...] edges of the merged
            # windows, used as bin edges below.
            tss_arr = b.to_dataframe().iloc[:, 1:].values.flatten()
        gg = grouped.groupby(pd.cut(grouped.tss, tss_arr, include_lowest=True,
                                    right=True))
        df_agg = gg.agg({'chr': util.get_first,
                         'start': np.median,
                         'end': np.median,
                         'transcript_id': util.join_pipe,
                         'strand': util.get_first,
                         'n_found': np.sum,
                         'found_in': util.join_pipe,
                         'sample': util.join_pipe,
                         'assoc_gene': util.join_pipe,
                         'gene_id': util.get_first,
                         'tss': lambda x: int(np.median(x))}).reset_index(drop=True).dropna()
        per_gene.append(df_agg)

    if not per_gene:
        return pd.DataFrame(columns=col)
    # A single concat replaces the repeated DataFrame.append, which was
    # removed in pandas 2.0 and copied the whole table at every iteration.
    df_tss = pd.concat(per_gene, ignore_index=True)
    return df_tss[col]


def concat_gtf(outputdir, args, groups, logger):
    """Concatenate the chimeric GTFs and split the result per chromosome.

    Every ``*_chimeric.gtf.gz`` found recursively under ``args.gtfdir`` whose
    file-name prefix (the part before ``_chimeric``) is a key of ``groups``
    is decompressed into ``<outputdir>/all_chim_trans_cat.gtf``.  That file
    is then split into ``<outputdir>/chromcat/<chrom>_chimeric_cat.gtf``, one
    file per chromosome of ``chromsizes(args.org)`` whose name has no
    underscore and is not ``chrM``.

    Parameters
    ----------
    outputdir : str
        Directory receiving the concatenated and per-chromosome files.
    args : argparse.Namespace
        Must provide ``gtfdir`` and ``org``.
    groups : mapping
        Its keys are the sample names to keep.
    logger : logging.Logger
        Logger used for progress messages.

    Returns
    -------
    str
        Path of the concatenated GTF.  The process exits with status 1 when
        no matching GTF is found.
    """
    import gzip  # local import: only the pure-Python fallback needs it

    outname = "all_chim_trans_cat.gtf"
    cat_path = "{out}/{name}".format(out=outputdir, name=outname)
    # The f-strings in this module already require Python >= 3.6, so the
    # recursive glob is always available.
    chimeric_gtfs = glob(f'{args.gtfdir}/**/*_chimeric.gtf.gz', recursive=True)
    # subset for the one in the metadata which is being processed
    chimeric_gtfs = [path for path in chimeric_gtfs
                     if os.path.basename(path).split('_chimeric')[0] in groups]
    logger.debug('Number of GTF to merge: {}'.format(len(chimeric_gtfs)))
    if not chimeric_gtfs:
        logger.error(("Cannot find chimeric gtf, check if samples folder "
                      "exists in the results dir and the metadata match!"))
        sys.exit(1)
    cmd = ("zcat {files} > {out}/{name}").format(out=outputdir, name=outname,
                                                files=" ".join(chimeric_gtfs))
    logger.info("Concatenating all the GTF for the chimeric transcripts...")
    try:
        logger.debug(cmd)
        util.call_w_check(cmd)
    except OSError:
        logger.warning("many files; trying concat with python...")
        with open(cat_path, 'w') as outw:
            for path in chimeric_gtfs:
                # The inputs are gzip-compressed: decompress while copying.
                with gzip.open(path, 'rt') as gin:
                    for line in gin:
                        outw.write(line)
    # Split chromosomes
    logger.info("Splitting concat file per chromosomes...")
    chromlist = [c for c in chromsizes(args.org)
                 if '_' not in c and c != 'chrM']

    chromdir = '{out}/chromcat/'.format(out=outputdir)
    os.makedirs(chromdir, exist_ok=True)

    chromgtf = "{chromdir}/{chrom}_chimeric_cat.gtf"

    chromfiles = {}
    try:
        for chrom in chromlist:
            chromfiles[chrom] = open(chromgtf.format(chromdir=chromdir,
                                                     chrom=chrom), 'w')
        with open(cat_path) as catgtf:
            for line in catgtf:
                fields = line.split()
                if not fields:
                    continue
                # Comment lines and chromosomes outside chromlist have no
                # output file and are skipped.
                handle = chromfiles.get(fields[0])
                if handle is not None:
                    handle.write(line)
    finally:
        # Close every handle, also when the split fails midway.
        for f in chromfiles.values():
            f.close()
    return os.path.join(outputdir, outname)


def get_ere_overlaps(gtf, outdir):
    """Write the TSS windows and the non-first exons of every transcript.

    ``<outdir>/all_chim_tss.bed`` receives one 50-base window per transcript:
    ``[start, start + 50]`` on the plus strand and ``[end - 50, end]`` on the
    minus strand.  ``<outdir>/all_chim_exons.bed`` receives the BED form of
    the remaining elements of the transcript: all but the first one on the
    plus strand, all but the last one on the minus strand.  Transcripts with
    any other strand value are skipped.  ``outdir`` is created if missing.
    """
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    transcripts, _n_transcripts = transcripts_counter.get_transcr_from_gtf(gtf)
    tss_path = '{}/all_chim_tss.bed'.format(outdir)
    exons_path = '{}/all_chim_exons.bed'.format(outdir)
    with open(tss_path, 'w') as tss_f, open(exons_path, 'w') as exons_f:
        for tr in transcripts:
            chrom, start, end, name, n_exons, strand, _1, _2, _3 = tr.to_list()
            if strand == '+':
                window = (start, start + 50)
                remaining = slice(1, None)
            elif strand == '-':
                window = (end - 50, end)
                remaining = slice(None, -1)
            else:
                continue
            fields = [chrom, str(window[0]), str(window[1]), name,
                      str(n_exons), strand]
            tss_f.write("\t".join(fields) + '\n')
            for exon in tr[remaining]:
                exons_f.write(str(exon.to_bed()) + '\n')


def get_ere_at_tss(data, args, tss_analysis=False):
    """Annotate each row of ``data`` with the ERE overlapping its TSS.

    A window of +/- 50 bases is built around ``start`` (around ``end`` on
    the minus strand) and intersected with the intervals of ``args.bed1``.
    The names of the overlapping intervals are joined with
    ``util.join_pipe`` per ``transcript_id`` (when ``tss_analysis`` is true)
    or per ``gene_id``, and outer-merged into ``data`` as a ``tss_ere``
    column placed in fifth position.  The index name of ``data`` is set to
    ``idx`` as a side effect.
    """
    key = 'transcript_id' if tss_analysis else 'gene_id'

    # Anchor of the window: 'end' on the minus strand, 'start' otherwise.
    on_minus = data['strand'] == '-'
    anchor = data['start'].where(~on_minus, data['end'])
    tss = pd.DataFrame({
        'chr': data['chr'],
        'start': (anchor - 50).map(int),
        'end': (anchor + 50).map(int),
        'name': data[key],
    })

    tss_bed = BedTool.from_dataframe(tss[['chr', 'start', 'end', 'name']])
    target_bed = BedTool(args.bed1)
    ov = target_bed.intersect(tss_bed, wa=True, wb=True).to_dataframe()
    ov.columns = ['ere_chr', 'ere_start', 'ere_end', 'tss_ere', 'score',
                  'ere_strand', 'erefam', 'chr', 'start', 'end', key]
    ere_per_key = ov.groupby(key)['tss_ere'].agg(util.join_pipe)

    original_cols = list(data.columns)
    data.index.name = "idx"
    merged = pd.merge(data, ere_per_key.to_frame(), left_on=key,
                      right_index=True, how='outer')
    ordered = original_cols[:4] + ['tss_ere'] + original_cols[4:]
    return merged[ordered]


def main(args, logger, groups):
    """Run the analysis of the discovered chimeric transcripts.

    The annotated transcript table is computed (concatenation of the GTFs,
    counting, gene annotation) and pickled to
    ``<args.outdir>/all_transcripts_counts_df.pickle``, or read back from
    that pickle when it exists and ``args.force`` is not set.  The per-gene
    table is then built, counted per group, annotated with the ERE at the TSS
    and written to ``<args.outdir>/chimeric_genes_table.xls``.  With
    ``args.tss_analysis`` a per-TSS table is also written to
    ``<args.outdir>/chimeric_tss_table.xls``.

    Parameters
    ----------
    args : argparse.Namespace
        Run options (``outdir``, ``force``, ``counts_only``,
        ``tss_analysis`` and those needed by the helpers).
    logger : logging.Logger
        Logger used for progress messages.
    groups : mapping
        Passed through to the helpers that map samples to groups.
    """
    logger.info("\n" + 53*"-")
    logger.info("### Analysing the discovered chimeric transcripts ###\n")
    logger.info(53*"-")
    # Getting the dataframe with annotation of genes overlapping
    picklef = os.path.join(args.outdir, 'all_transcripts_counts_df.pickle')
    if not os.path.exists(picklef) or args.force:
        gtf_cat = concat_gtf(args.outdir, args, groups, logger)
        df = get_transc_df(gtf_cat, args, logger)
        df = anno_transc_df(df, logger, args)
        logger.info("Pickling the dataframe...")
        df.to_pickle(picklef)
    else:
        logger.warning(("Pickle detected, reading from it (use force to force "
                       "re-counting of the chimeric transcripts)..."))
        gtf_cat = "{}/all_chim_trans_cat.gtf".format(args.outdir)
        df = pd.read_pickle(picklef)
        df.index.name = "idx"

    # Transcripts with an empty gene id get a unique placeholder so they are
    # not merged together in the per-gene table.
    notfound = (df.gene_id == "")
    n_notfound = int(notfound.sum())
    if n_notfound:
        notfound_tr = "\n".join(df.loc[notfound, 'transcript_id']) + '\n'
        logger.warning(f"Warning: {n_notfound} transcript have not been anotated:\n{notfound_tr}")
        df.loc[notfound, 'gene_id'] = [f'unkown{i}' for i in range(n_notfound)]

    logger.info("Extracting EREs under chimeric transcripts")
    logger.info("Finalizing dataframe and exporting results...")
    gene_table = construct_gene_table(df, gtf_cat, logger,
                                      args.outdir, groups)
    gene_table_export = count_groups(gene_table, args, groups)
    gene_table_export = get_ere_at_tss(gene_table_export, args)

    gs = gene_table_export
    # Same truthiness test as count_groups, so the sort happens exactly when
    # the p-value columns were computed.
    if not args.counts_only:
        padj_cols = [c for c in gene_table_export.columns if '_padj' in c]
        delta_cols = [c for c in gene_table_export.columns if '_delta' in c]
        # Without any tested group there is nothing to sort on.
        if padj_cols and delta_cols:
            gs = gene_table_export.sort_values([padj_cols[0], delta_cols[0]],
                                               ascending=[True, False])
    gs.to_csv('{}/chimeric_genes_table.xls'.format(args.outdir),
              sep='\t', index=False)

    if args.tss_analysis:
        logger.info("Computing the TSS table (slow)...")
        df_tss = construct_tss_table(args.outdir, logger)
        # count_groups adds its columns to df_tss in place.
        count_groups(df_tss, args, groups)
        df_tss = get_ere_at_tss(df_tss, args, tss_analysis=True)
        df_tss.to_csv('{}/chimeric_tss_table.xls'.format(args.outdir),
                      sep='\t', index=False)
