import sys, os

import pickle
import subprocess

from math import log
from time import sleep
from rpy2 import robjects
from random import sample
from itertools import repeat
from tempfile import NamedTemporaryFile
from statsmodels.sandbox.stats.multicomp import multipletests

# import plotting packages
import matplotlib as mpl
mpl.use('pdf')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

# When True, gene ids read from input files have '_' replaced by ':' before
# being looked up in the pickled gene index (see calc_enrichments).
CONV_GENE_IDS = False
# When True, each line of the expressed-genes file is split on '+' into
# several gene ids (gene-set files are always split on '+').
SPLIT_GENES = False
# Print status information to stderr. parse_arguments() overwrites this from
# --verbose; the default here lets the functions be used after a plain import.
VERBOSE = False

def calc_p_hypers(mer_dist, samp, expr_genes):
    """Return upper-tail hypergeometric enrichment p-values for every k-mer.

    For each column ``i`` of ``mer_dist`` this computes ``P(X >= x_i)``.
    ``x_i`` is the total count of k-mer ``i`` over the rows in ``samp``.
    ``X`` is hypergeometric with:

    * ``mer_sums[i]`` "success" k-mers, i.e. the count of k-mer ``i`` over
      the ``expr_genes`` rows;
    * ``num_mers - mer_sums[i]`` other k-mers in that background;
    * ``num_mers_drawn`` draws, i.e. the total k-mer count over ``samp``.

    The p-values are evaluated by R's ``phyper`` through rpy2.

    Args:
        mer_dist: 2-D count matrix indexed ``[gene_row, kmer_column]``
            (presumably a dense integer numpy array -- TODO confirm).
        samp: row indices of the gene set being tested.
        expr_genes: row indices of the background (expressed) genes.

    Returns:
        list of p-values, one per column of ``mer_dist``, in column order.
    """
    mer_sums = mer_dist[expr_genes, :].sum(axis=0)
    num_mers = mer_sums.sum()
    # per-k-mer counts in the sample, computed once instead of per column
    samp_sums = mer_dist[samp, :].sum(axis=0)
    num_mers_drawn = samp_sums.sum()

    # R's phyper(q, m, n, k, lower.tail=FALSE) is P(X > q), so the
    # enrichment p-value P(X >= x) needs q = x - 1.
    return [
        robjects.r('phyper({:d}, {:d}, {:d}, {:d}, lower.tail=FALSE)'.format(
            int(samp_sums[i]) - 1, int(mer_sums[i]),
            int(num_mers - mer_sums[i]), int(num_mers_drawn)))[0]
        for i in range(mer_dist.shape[1])]

def sim_and_plot(all_mer_dist, num_gene_inds, hyper_p_vals,
                 gene_set_fn, expr_gene_inds, pp, num_sims=100):
    """Plot a gene set's sorted k-mer p-values against random gene sets.

    Draws ``num_sims`` random samples of ``num_gene_inds`` row indices of
    ``all_mer_dist`` and computes their p-values with ``calc_p_hypers``.
    Appends one page to ``pp`` containing:

    * the sorted p-values of the gene set (green);
    * the sorted simulated p-values (translucent black);
    * a 7x zoomed inset on the first 1000 ranks and p-values up to 0.0625.

    Args:
        all_mer_dist: count matrix indexed ``[gene_row, kmer_column]``.
        num_gene_inds: size of each simulated gene set.
        hyper_p_vals: dict of k-mer sequence -> p-value for the gene set.
        gene_set_fn: plot title.
        expr_gene_inds: background row indices passed to ``calc_p_hypers``.
        pp: open ``PdfPages`` object; one page is appended.
        num_sims: number of random gene sets to simulate.
    """
    if VERBOSE: sys.stderr.write('simulating gene sets\n')
    # simulate gene sets of similar sizes to gene set of interest.
    # NOTE(review): samples are drawn from *all* rows of all_mer_dist while
    # the hypergeometric background is only expr_gene_inds -- confirm that
    # this is the intended null.
    num_genes = all_mer_dist.shape[0]
    sim_phypers = [
        calc_p_hypers(all_mer_dist, sample(range(num_genes), num_gene_inds),
                      expr_gene_inds)
        for _ in range(num_sims)]

    # sort once; both the main axes and the inset reuse these curves
    ranks = range(len(hyper_p_vals))
    sorted_sims = [sorted(sim_phyp) for sim_phyp in sim_phypers]
    sorted_gene_set = sorted(hyper_p_vals.values())

    def plot_lines(iax, plot_legend=True):
        """Draw the simulated and gene-set curves on ``iax``."""
        sim_line = None
        for sim_p_vals in sorted_sims:
            sim_line, = iax.plot(ranks, sim_p_vals, alpha=0.2, color='black')

        set_line, = iax.plot(ranks, sorted_gene_set, 'green')
        if plot_legend:
            handles, labels = [set_line], ['Gene Set']
            # no simulated curve exists when num_sims == 0
            if sim_line is not None:
                handles.append(sim_line)
                labels.append('Simulated Sets')
            iax.legend(handles, labels, loc='upper left', fontsize='small')

    if VERBOSE: sys.stderr.write('plotting\n')
    fig, ax = plt.subplots()
    try:
        ax.set_title(gene_set_fn)
        plot_lines(ax)

        # create inset to zoom in on low p-values
        axinset = zoomed_inset_axes(ax, 7, loc=4)
        plot_lines(axinset, False)
        axinset.set_xlim(0, 1000)
        axinset.set_ylim(0, 0.0625)
        plt.setp(axinset.get_xticklabels(), visible=False)
        plt.setp(axinset.get_yticklabels(), visible=False)
        mark_inset(ax, axinset, loc1=2, loc2=3, fc="none", ec="k")

        pp.savefig(fig)
    finally:
        # close (not just clear) so figures do not accumulate per gene set
        plt.close(fig)

    return

# The default max_p_val below is the 1st percentile of Beta(1, 4^7). That is
# the distribution of the minimum of 4^7 independent uniform p-values, i.e.
# of the smallest k-mer p-value expected under the null hypothesis.
def produce_memes(hyper_p_vals, save_dir, max_p_val=6.134236*10**-7,
                  min_seqs=10, score_diff=float(50), max_uniq_seqs=50,
                  num_seqs=5):
    """Run MEME on the most significant k-mers of a gene set.

    K-mers with p-value below ``max_p_val`` are sorted by p-value, and at
    most ``max_uniq_seqs`` of the most significant ones are kept. If fewer
    than ``min_seqs`` k-mers pass the threshold, MEME is not run.

    A temporary FASTA file is written with:

    * a ``>WEIGHTS`` line in which more significant k-mers get larger
      weights (see ``_meme_weight``);
    * ``num_seqs`` copies of each kept k-mer ('U' replaced by 'T'), each
      padded with 8 random bases on both sides.

    MEME is then run on that file with output in ``save_dir``.

    Args:
        hyper_p_vals: dict of k-mer sequence -> p-value.
        save_dir: MEME output directory (passed to ``-oc``).
        max_p_val: significance threshold (see the comment above).
        min_seqs: minimum number of significant k-mers needed to run MEME.
        score_diff: number of discrete weight levels.
        max_uniq_seqs: maximum number of distinct k-mers passed to MEME.
        num_seqs: copies written per k-mer.
    """
    # (p-value, k-mer) pairs, most significant first
    seqs_to_write = sorted((p, m) for m, p in hyper_p_vals.items()
                           if p < max_p_val)
    if len(seqs_to_write) < min_seqs:
        if VERBOSE: sys.stderr.write('Too few signif. sequences in ' +
                                     save_dir + ' so MEME will not be run.\n')
        return

    seqs_to_write = seqs_to_write[:max_uniq_seqs]
    # least significant kept k-mer -> bottom of the weight scale,
    # most significant -> top
    min_w = _meme_score(seqs_to_write[-1][0])
    max_w = _meme_score(seqs_to_write[0][0])

    weights_str = '>WEIGHTS ' + ' '.join(
        weight for p_val, _ in seqs_to_write
        for weight in repeat('{0:.2f}'.format(
            _meme_weight(p_val, min_w, max_w, score_diff)), num_seqs))

    bases = ('A', 'C', 'G', 'T')

    def random_flank(length=8):
        """Return ``length`` uniformly random bases."""
        return ''.join(sample(bases, 1)[0] for _ in range(length))

    # Default buffering plus an explicit flush (instead of the Python-2-only
    # bufsize=0) so the complete file is on disk before MEME opens it by name.
    with NamedTemporaryFile(mode='w') as motifs_fasta:
        motifs_fasta.write(weights_str + '\n')
        for p_val, mer_seq in seqs_to_write:
            dna_mer = mer_seq.replace('U', 'T')
            for i in range(num_seqs):
                motifs_fasta.write('>' + dna_mer + str(i) + '\n' +
                                   random_flank() + dna_mer +
                                   random_flank() + '\n')
        motifs_fasta.flush()

        if VERBOSE: sys.stderr.write(
            'Running MEME on ' + str(len(seqs_to_write)) +
            ' unique kmers (' + str(num_seqs) + 'x sequences per kmer). ' +
            'Storing output here: ' + save_dir + '\n')
        # run meme from command line ('-oc' must not carry trailing spaces,
        # otherwise MEME does not recognise the option)
        ret_code = subprocess.call(
            ['meme', motifs_fasta.name, '-dna', '-nmotifs', '5', '-minw', '3',
             '-maxw', '11', '-mod', 'oops', '-oc', save_dir, '-nostatus'])

    if ret_code != 0:
        sys.stderr.write('MEME exited with status ' + str(ret_code) +
                         ' for ' + save_dir + '.\n')
    elif VERBOSE:
        sys.stderr.write('Finished MEME for ' + save_dir + '.\n\n')

    return


def _meme_score(p_val):
    """Return ``(-log10 p_val) ** 0.75``, clamping ``p_val`` at 1e-200."""
    # the clamp keeps p == 0 finite
    return (-log(max(p_val, 10**-200), 10)) ** 0.75


def _meme_weight(p_val, min_w, max_w, score_diff):
    """Map a p-value to a MEME sequence weight on ``score_diff`` levels.

    ``_meme_score(p_val)`` is rescaled linearly from ``[min_w, max_w]`` onto
    the integer levels ``0 .. score_diff - 1``. The result is
    ``(level + 1) / score_diff``. The least significant k-mer gets
    ``1 / score_diff``; the most significant gets at most 1.0.

    If ``max_w == min_w`` (all k-mers equally significant), every weight is
    1.0 instead of dividing by zero.
    """
    if max_w == min_w:
        return 1.0
    level = int((_meme_score(p_val) - min_w) * (score_diff - 1) /
                float(max_w - min_w))
    # float() avoids integer division if an int score_diff is passed
    return (level + 1) / float(score_diff)

def calc_enrichments(mer_dist_fn, gene_set_fns, expr_genes_fn,
                     run_meme=True, do_sim_and_plot=True):
    """Compute k-mer enrichment p-values for each gene-set file.

    Loads a pickled ``(matrix, gene_indicies, mer_indicies)`` triple from
    ``mer_dist_fn``:

    * ``gene_indicies`` maps gene id -> matrix row;
    * ``mer_indicies`` maps k-mer sequence -> matrix column.

    The genes listed in ``expr_genes_fn`` are the background. For every
    file in ``gene_set_fns`` this:

    * writes ``<gene_set_fn>.hyper_p_vals.txt``, one tab-separated line per
      k-mer (in column order) with the sequence, the hypergeometric p-value
      and the Benjamini-Hochberg corrected p-value;
    * if ``run_meme``, runs MEME with output in ``<gene_set_fn>.meme_out``;
    * if ``do_sim_and_plot``, appends a page to ``mer_distributions.pdf``.

    Gene-id handling:

    * ids are looked up after replacing '_' with ':' when CONV_GENE_IDS is
      set;
    * gene-set lines are always split on '+';
    * expressed-gene lines are split on '+' only when SPLIT_GENES is set;
    * blank lines and empty ids are skipped.

    Raises:
        KeyError: if a listed gene id is not in ``gene_indicies``.
    """
    if VERBOSE: sys.stderr.write('loading mer dist matrix\n')
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    # Pickles must be read in binary mode.
    with open(mer_dist_fn, 'rb') as fp:
        all_mer_dist, gene_indicies, mer_indicies = pickle.load(fp)

    def to_gene_inds(gene_ids):
        """Map gene ids to matrix row indices."""
        if CONV_GENE_IDS:
            gene_ids = [g_id.replace('_', ':') for g_id in gene_ids]
        return [gene_indicies[g_id] for g_id in gene_ids]

    if VERBOSE: sys.stderr.write('loading expressed genes\n')
    with open(expr_genes_fn) as fp:
        expr_gene_inds = to_gene_inds(_read_gene_ids(fp, SPLIT_GENES))

    # k-mer seqs ordered by column index (inverts mer_indicies); this does
    # not depend on the gene set, so it is computed only once
    ordered_mer_seqs = [mer_seq for _, mer_seq in sorted(
        (ind, mer_seq) for mer_seq, ind in mer_indicies.items())]

    pp = PdfPages('mer_distributions.pdf') if do_sim_and_plot else None
    try:
        for gene_set_fn in gene_set_fns:
            if VERBOSE: sys.stderr.write(
                'loading gene set: ' + gene_set_fn + '\n')
            with open(gene_set_fn) as gene_set_fp:
                gene_inds = to_gene_inds(_read_gene_ids(gene_set_fp))

            if VERBOSE: sys.stderr.write('calculating gene set p-vals\n')
            # calculate hypergeometric p-values for all k-mers
            mer_p_pairs = list(zip(
                ordered_mer_seqs,
                calc_p_hypers(all_mer_dist, gene_inds, expr_gene_inds)))
            hyper_p_vals = dict(mer_p_pairs)

            if VERBOSE: sys.stderr.write('calculating corrections\n')
            # correct the same ordered list that is written out, rather than
            # relying on dict keys()/values() ordering agreeing
            corr_p_vals = multipletests(
                [p_val for _, p_val in mer_p_pairs], method='fdr_bh')[1]
            with open(gene_set_fn + '.hyper_p_vals.txt', 'w') as hypers_fp:
                for (mer_seq, p_val), corr_p_val in zip(mer_p_pairs,
                                                        corr_p_vals):
                    hypers_fp.write('\t'.join(map(str, (
                        mer_seq, p_val, corr_p_val))) + '\n')

            if run_meme:
                produce_memes(hyper_p_vals, gene_set_fn + '.meme_out')

            if do_sim_and_plot:
                sim_and_plot(all_mer_dist, len(gene_inds), hyper_p_vals,
                             gene_set_fn, expr_gene_inds, pp)
    finally:
        # close the PDF even if a gene set fails part-way through
        if pp is not None:
            pp.close()

    return


def _read_gene_ids(lines, split_genes=True):
    """Return gene ids from ``lines``, optionally splitting each on '+'.

    Surrounding whitespace is stripped, and blank lines and empty ids are
    skipped.
    """
    gene_ids = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        parts = line.split('+') if split_genes else [line]
        gene_ids.extend(g_id for g_id in parts if g_id)
    return gene_ids

def parse_arguments():
    """Parse command-line options and set the module-level VERBOSE flag.

    Returns:
        tuple ``(mer_dist_fn, gene_set_fns, expr_genes_fn)``: the pickled
        k-mer distribution path, a list of gene-set file paths, and the
        expressed-genes file path.
    """
    global VERBOSE
    import argparse
    parser = argparse.ArgumentParser(
        description='Calculate hypergeometric k-mer enrichment for gene ' +
        'sets against a background of expressed genes, run MEME on the ' +
        'enriched k-mers and plot the p-value distributions.')
    # all three paths are needed by calc_enrichments; requiring them gives a
    # clear usage error instead of a later TypeError on None
    parser.add_argument('--mer-dist', required=True,
                        help='Pickled mer distribution with no filtering.')

    parser.add_argument('--gene-sets', nargs='+', required=True,
                        help='Lists of genes with one gene name per line.')
    parser.add_argument('--expr-genes', required=True,
                        help='Lists of genes which are expressed ' +
                        'in this experiment.')

    parser.add_argument('--verbose', '-v', default=False, action='store_true',
                        help='Whether or not to print status information.')
    args = parser.parse_args()

    VERBOSE = args.verbose

    return args.mer_dist, args.gene_sets, args.expr_genes

def main():
    """Command-line entry point: read the options, then run the analysis."""
    cli_paths = parse_arguments()
    calc_enrichments(*cli_paths)


if __name__ == '__main__':
    main()
