#!/usr/bin/env python
import numpy as np
import pandas as pd
import argparse

import matplotlib
matplotlib.use('Agg')
font = {'size' : 8}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt

# import matplotlib.patches as mpatches 
from scipy.stats import sem 
import palettable.colorbrewer.sequential
from collections import defaultdict

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                 description="""
Plot binned read-count profiles across timepoints for one or more datasets:
a top row of aggregate mean +/- s.e.m. profiles and, below it, one heatmap
row per dataset (one panel per timepoint, one heatmap line per site).

""")

# Palette names accepted by --colors (also listed as --colormap choices).
COLOR_CHOICES = ['purple', 'purple-blue', 'blue', 'blue-green', 'green',
                 'orange', 'orange-red', 'red', 'gray']

##################################################
# required args:
parser.add_argument("--hstacked_arrays_across_timepoints", nargs='+', type=str,
                    help="""required, numpy array(s) of binned read counts,
e.g. DNase.npy H3K27ac.npy ...
Each array holds one block of columns per --timepoints entry, stacked
horizontally; its column count should be divisible by the number of timepoints.
""", dest="hstacked_arrays_across_timepoints", action='store', required=True)
parser.add_argument("-o", '--outplots', dest='outplots', nargs='+', type=str,
                    help='required, output plot path(s), e.g. out.png out.pdf',
                    action='store', required=True)
parser.add_argument("-t", '--timepoints', dest='timepoints', nargs='+', type=str,
                    help='required, timepoint labels, one per block of columns in each\narray (also used as the panel column labels)',
                    action='store', required=True)

##################################################
# optional args:
parser.add_argument("-c", "--clustered_indices", type=str,
                    help="""of the format:
        cluster
7454    0
800     0
...     ...
450     10
300     10
1200    11

where cluster numbers split epigenetic profile into separate plots by cluster
and indices indicate sorted order of epigenetic profile for each plot.
(for example, index may be pre-sorted by sum of read counts 
or by strength of motif match)
    
""", dest='clustered_indices', action='store')

parser.add_argument("--names", nargs='+', type=str, action='store',
                    help="optional, names of --hstacked_arrays_across_timepoints in corresponding order, e.g. DNase H3K27ac ...")
parser.add_argument("--colors", nargs='+', type=str, action='store',
                    choices=COLOR_CHOICES,
                    help="""optional, colors of the datasets, must be among the following choices:
purple purple-blue blue blue-green green orange orange-red red gray

""")
parser.add_argument("--colormap", type=str, action='store',
                    help="""optional, colormap for all datasets' heatmaps, one of:
purple purple-blue blue blue-green green orange orange-red red gray
or cubehelix, BlOr (diverging), or any matplotlib colormap name""")
parser.add_argument("--delta_from_t00", action='store_true',
                    help="optional, treat the arrays as signed changes relative to t00:\nuse the diverging BlOr colormap with colour limits symmetric about 0\nand let the profile y-axis extend below 0")
parser.add_argument("--colorbar", action='store_true',
                    help="optional, add a colorbar beside each dataset's heatmap row")
parser.add_argument("-i", "--indices", type=str,
                    help="""of the format:
7454
800
...
450
300
1200

indices indicate sorted order of epigenetic profile for each plot.
(for example, index may be pre-sorted by sum of read counts 
or by strength of motif match)
    
""", dest='indices', action='store')

parser.add_argument("--sort_by_sum",
                    help='sort sites (descending) by their summed signal across all bins,\ntimepoints and datasets, each dataset first normalised by its total',
                    dest='sort_by_sum', action='store_true')
parser.add_argument("--save_indices_to_file", type=str, dest='save_indices_to_file',
                    help='optional path; write the heatmap row order (one index per line)\nto this file, e.g. the order produced by --sort_by_sum')

parser.add_argument("--vmaxs", nargs='+', type=float,
                    help="optional, one imshow vmax per dataset (default: 95th percentile)")
parser.add_argument("--vmins", nargs='+', type=float,
                    help="optional, one imshow vmin per dataset (default: 5th percentile)")
parser.add_argument("--ymax", type=float,
                    help="optional, ymax for aggregate profile plot (default: largest mean)")
parser.add_argument("--ymin", type=float,
                    help="optional, ymin for aggregate profile plot (default: smallest mean)")

##################################################
# test block  

# args = type("args", (object,), {})()
# args.hstacked_arrays_across_timepoints = \
# ['/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/GR.in.EP300_down_FDR_0.1.rpm.delta_from_t00.npy',
#  '/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/EP300.in.EP300_down_FDR_0.1.rpm.delta_from_t00.npy',
#  '/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/DNase.in.EP300_down_FDR_0.1.rpm.delta_from_t00.npy', 
#  '/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/H3K27ac.in.EP300_down_FDR_0.1.rpm.delta_from_t00.npy',
#  '/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/H3K4me1.in.EP300_down_FDR_0.1.rpm.delta_from_t00.npy',
#  '/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/H3K4me2.in.EP300_down_FDR_0.1.rpm.delta_from_t00.npy']
# args.names = ['GR','p300','DNase','H3K27ac','H3K4me1','H3K4me2']
# args.colors = ['purple','green','blue','red','orange-red','orange']
# args.delta_from_t00 = True
# args.colorbar = True
# args.outplots = ['/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/epigenetic_read_counts.in.EP300_down_FDR_0.1.rpm.delta_from_t00.png',
# '/data/reddylab/projects/GGR/results/integrative/enhancers_by_EP300_analysis/iter0/tf_binding/epigenetic_read_counts.in.EP300_down_FDR_0.1.rpm.delta_from_t00.pdf']
# args.timepoints = [0.5, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12]
# args.sort_by_sum = True
# args.clustered_indices=False
# args.indices = False
# args.vmins = False
# args.vmaxs = False
# args.colormap = False

#################################################

args = parser.parse_args()

N_datasets = len(args.hstacked_arrays_across_timepoints)

##########
# set up color schemes
##########

# ColorBrewer 9-class sequential palettes, selectable by name via --colors.
_PALETTES = {
    'purple': palettable.colorbrewer.sequential.Purples_9,
    'purple-blue': palettable.colorbrewer.sequential.PuBu_9,
    'blue': palettable.colorbrewer.sequential.Blues_9,
    'blue-green': palettable.colorbrewer.sequential.BuGn_9,
    'green': palettable.colorbrewer.sequential.Greens_9,
    'orange': palettable.colorbrewer.sequential.Oranges_9,
    'orange-red': palettable.colorbrewer.sequential.OrRd_9,
    'red': palettable.colorbrewer.sequential.Reds_9,
    'gray': palettable.colorbrewer.sequential.Greys_9,
}
# heatmap colormap for each palette name
cmap_choices = {name: pal.mpl_colormap for name, pal in _PALETTES.items()}
# profile line/fill colour for each palette name: the 6th of the 9 colours
color_choices = {name: pal.mpl_colors[5] for name, pal in _PALETTES.items()}

# order in which palettes are assigned when --colors is not given
_DEFAULT_COLOR_ORDER = ['blue', 'green', 'orange', 'purple', 'red', 'gray',
                        'purple-blue', 'orange-red', 'blue-green']

if args.colors:
    unknown = [name for name in args.colors if name not in _PALETTES]
    if unknown:
        parser.error('unknown --colors value(s): %s (choose from %s)'
                     % (' '.join(unknown), ' '.join(_DEFAULT_COLOR_ORDER)))
    if len(args.colors) < N_datasets:
        # the plotting code zips datasets with colours, so missing colours
        # would silently drop datasets from the figure
        parser.error('--colors needs one entry per dataset (%d given, %d datasets)'
                     % (len(args.colors), N_datasets))
    color_names = args.colors
else:
    # cycle the defaults so that more than 9 datasets still each get a row
    color_names = [_DEFAULT_COLOR_ORDER[i % len(_DEFAULT_COLOR_ORDER)]
                   for i in range(N_datasets)]

# per-dataset lists, in dataset order
cmaps = [cmap_choices[name] for name in color_names]
colors = [color_choices[name] for name in color_names]

##########
# Diverging "BlOr" colormap, built whenever --delta_from_t00 is set or
# --colormap BlOr is requested (plot_wrapper uses it in exactly those cases).
##########

if args.colormap == "BlOr" or args.delta_from_t00:
    from palettable.cubehelix import Cubehelix
    import matplotlib.colors as mcolors
    # 128 samples from each cubehelix ramp -> 256 colours in the combined map
    half_samples = np.linspace(0, 1, 128)
    # lower half of the colour range (reverse=True ramp)
    low_half = Cubehelix.make(start=0.25, rotation=-0.25, sat=1.5,
                              reverse=True, max_light=0.8).mpl_colormap(half_samples)
    # upper half of the colour range
    high_half = Cubehelix.make(start=0.75, rotation=0.25, sat=1.5,
                               max_light=0.8).mpl_colormap(half_samples)
    # stack low half then high half and interpolate into one colormap
    BlOr = mcolors.LinearSegmentedColormap.from_list(
        'BlOr', np.vstack((low_half, high_half)))

##########
# read indices, if applicable
##########

if args.clustered_indices and args.indices:
    # SystemExit with a message prints it to stderr and exits with status 1
    # (works on both Python 2 and 3, unlike a bare print statement).
    raise SystemExit("ERROR: please provide --clustered_indices OR --indices, not both")
elif args.clustered_indices:
    # read clustered_index file: first column is the row index, the
    # 'cluster' column holds the cluster id (see --clustered_indices help).
    # sep=r'\s+' is the documented equivalent of the deprecated
    # delim_whitespace=True.
    clustered_indices = pd.read_csv(args.clustered_indices,
                                    sep=r'\s+',
                                    index_col=0)
    # cluster id -> list of row indices, in file order
    indices_by_cluster = defaultdict(list)
    for index, cluster in zip(clustered_indices.index,
                              clustered_indices['cluster']):
        indices_by_cluster[cluster].append(index)
elif args.indices:
    # one integer row index per line; blank lines (e.g. a trailing one) are skipped
    with open(args.indices, 'r') as f:
        indices = [int(line) for line in f if line.strip()]

##########
# read each of the epigenetic read count profiles
# and split arrays by timepoint
##########

N_timepoints = len(args.timepoints)
# width (number of columns) of one timepoint's block; must match across arrays
cols_per_row = None
# one entry per dataset: list of N_timepoints equal-width column blocks
Ms = []
for path in args.hstacked_arrays_across_timepoints:
    stacked = np.load(path)
    n_cols = stacked.shape[1]
    if n_cols % N_timepoints != 0:
        raise SystemExit("ERROR: %s has %d columns, which is not divisible by %d timepoints"
                         % (path, n_cols, N_timepoints))
    # floor division keeps this an int on Python 3 (needed for slicing/arange)
    width = n_cols // N_timepoints
    if cols_per_row is not None and width != cols_per_row:
        raise SystemExit("ERROR: %s has %d columns per timepoint, expected %d"
                         % (path, width, cols_per_row))
    cols_per_row = width
    Ms.append(np.hsplit(stacked, N_timepoints))
    del stacked

##########
# split read count profiles by cluster and/or order by indices
##########

def _take_rows(M, rows):
    """Return each per-timepoint matrix of dataset M restricted/reordered to `rows`."""
    rows = np.asarray(rows)
    return [M_sub[rows, :] for M_sub in M]


# key -> list (one entry per dataset) of per-timepoint matrices;
# key is a cluster id with --clustered_indices, otherwise 'all'
Ms_clustered = defaultdict(list)
if args.clustered_indices:
    # .items() works on both Python 2 and 3 (iteritems is Python-2-only)
    for cluster, cluster_rows in indices_by_cluster.items():
        Ms_clustered[cluster] = [_take_rows(M, cluster_rows) for M in Ms]
elif args.indices:
    Ms_clustered['all'] = [_take_rows(M, indices) for M in Ms]
else:
    Ms_clustered['all'] = Ms

del Ms

##########
# Plot epigenetic read count profiles
##########

# x positions (column index within one timepoint block) for every profile panel
x = np.arange(cols_per_row)

def plot_hstacked_read_vector(axes, M, vmin, vmax, cmap, idx):
    """Draw one heatmap per timepoint for a single dataset.

    axes -- sequence of Axes, one per timepoint (zipped with M, so the
            shorter of the two sets how many panels are drawn)
    M    -- list of 2-D arrays, one per timepoint
    vmin, vmax, cmap -- passed straight through to ``imshow``
    idx  -- row order applied to every matrix when the module-level
            ``args.sort_by_sum`` is set

    Returns the image from the last ``imshow`` call (used for a colorbar);
    axes and M must be non-empty. Ticks and all four spines are hidden.
    """
    for ax, mat in zip(axes, M):
        shown = mat[idx, :] if args.sort_by_sum else mat
        # NOTE(review): interpolation=None means "use rcParams
        # image.interpolation", not 'none' (no interpolation) -- confirm intent.
        image = ax.imshow(shown,
                          extent=[-1000, 1000, shown.shape[0], 0],
                          cmap=cmap,
                          vmin=vmin,
                          vmax=vmax,
                          origin='upper',
                          interpolation=None,
                          aspect='auto')
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('left', 'right', 'top', 'bottom'):
            ax.spines[side].set_color('none')
    return image


def plot_aggregate_means_and_stderrs(axes, M, color, ymin, ymax):
    """Plot one dataset's column mean +/- standard error on each profile axis.

    axes  -- sequence of Axes, one per timepoint (zipped with M)
    M     -- list of 2-D arrays, one per timepoint; the mean and
             ``scipy.stats.sem`` are taken over axis 0
    color -- line and fill colour
    ymin, ymax -- y-limits; the lower limit is ``ymin`` only when
             ``args.delta_from_t00`` or ``args.ymin`` is truthy, else 0

    Reads the module-level ``x`` (shared x positions) and ``args``.
    Only the first panel keeps its left spine, a y-label and y-ticks.
    """
    for panel, (ax, M_sub) in enumerate(zip(axes, M)):
        centre = M_sub.mean(axis=0)
        spread = sem(M_sub)
        ax.plot(x, centre, color=color)
        ax.fill_between(x, centre - spread, centre + spread,
                        facecolor=color, edgecolor=color,
                        alpha=0.5, lw=0.01)
        ax.set_xticks([])
        ax.set_yticks([])
        lower = ymin if (args.delta_from_t00 or args.ymin) else 0
        ax.set_ylim((lower, ymax))

        for side in ('right', 'top', 'bottom'):
            ax.spines[side].set_color('none')

        if panel != 0:
            ax.spines['left'].set_color('none')
            continue

        ax.set_ylabel('Norm. signal', fontsize=8)
        ax.set_yticks([ymin, 0, ymax] if args.delta_from_t00 else [0, ymax])

# --colormap's help lists these palette names; map each to matplotlib's
# built-in ColorBrewer colormap of the same family, because a bare name
# such as 'blue' is not a registered matplotlib colormap.
_COLORMAP_ALIASES = {
    'purple': 'Purples',
    'purple-blue': 'PuBu',
    'blue': 'Blues',
    'blue-green': 'BuGn',
    'green': 'Greens',
    'orange': 'Oranges',
    'orange-red': 'OrRd',
    'red': 'Reds',
    'gray': 'Greys',
}


def _aggregate_ylim(Ms):
    """Return (ymin, ymax) for the profile row.

    An explicit --ymin/--ymax wins, including 0; otherwise the extremes of
    the per-timepoint column means over all datasets are used.
    """
    if args.ymax is None:
        ymax = max(M_sub.mean(axis=0).max() for M in Ms for M_sub in M)
    else:
        ymax = args.ymax
    if args.ymin is None:
        ymin = min(M_sub.mean(axis=0).min() for M in Ms for M_sub in M)
    else:
        ymin = args.ymin
    return ymin, ymax


def _site_order(Ms):
    """Return the heatmap row order: descending normalised sum with --sort_by_sum, else identity."""
    n_sites = len(Ms[0][0])
    if not args.sort_by_sum:
        return np.arange(n_sites)
    score = np.zeros(n_sites)
    for M in Ms:
        # cast to float: integer-typed arrays would floor-divide to 0 under
        # Python 2 '/' semantics
        stacked = np.hstack(M).astype(float)
        total = stacked.sum()
        # skip an all-zero dataset instead of filling the score with NaN
        if total != 0:
            score += stacked.sum(axis=1) / total
    return score.argsort()[::-1]


def _color_limits(M, i):
    """Return (vmin, vmax, colorbar ticks) for dataset i's heatmaps."""
    if args.vmins and args.vmaxs:
        values = None
    else:
        values = np.asarray(M).ravel()
    vmin = args.vmins[i] if args.vmins else np.percentile(values, 5)
    vmax = args.vmaxs[i] if args.vmaxs else np.percentile(values, 95)
    ticks = [vmin, vmax]
    if args.delta_from_t00:
        # symmetric limits so that 0 is the midpoint of the colour range
        vmax = max(abs(vmin), abs(vmax))
        vmin = -vmax
        ticks = [vmin, 0, vmax]
    return vmin, vmax, ticks


def _resolve_cmap(default_cmap):
    """Return the heatmap colormap for one dataset, applying --delta_from_t00 / --colormap."""
    if args.delta_from_t00:
        return BlOr
    if not args.colormap:
        return default_cmap
    if args.colormap == "cubehelix":
        from palettable.cubehelix import jim_special_16
        return jim_special_16.mpl_colormap
    if args.colormap == "BlOr":
        return BlOr
    return _COLORMAP_ALIASES.get(args.colormap, args.colormap)


def plot_wrapper(Ms, N_datasets, timepoints, colors, cmaps, outplots):
    """Draw one figure (profile row + one heatmap row per dataset) and save it.

    Ms         -- one entry per dataset; each is a list of 2-D arrays, one per
                  timepoint (rows are sites, shared across datasets)
    N_datasets -- number of datasets, i.e. heatmap rows
    timepoints -- one panel column (and x-label) per entry
    colors     -- profile line/fill colour per dataset
    cmaps      -- default heatmap colormap per dataset
    outplots   -- output path, or list of paths; the same figure is saved to each

    Reads the module-level ``args``. The figure is closed after saving.
    """
    if isinstance(outplots, str):
        # a single path was passed; don't iterate over its characters
        outplots = [outplots]

    nrows = 1 + N_datasets  # 1 extra for profile plot at top
    ncols = len(timepoints)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 8))
    # plt.subplots squeezes singleton dimensions; restore a 2-D grid so rows
    # and columns index the same way for any ncols
    axes = np.asarray(axes).reshape(nrows, ncols)
    mean_axes = axes[0]      # first row: aggregate profiles
    heatmap_rows = axes[1:]  # one row per dataset

    ymin, ymax = _aggregate_ylim(Ms)
    for M, color in zip(Ms, colors):
        plot_aggregate_means_and_stderrs(mean_axes, M, color, ymin, ymax)

    idx = _site_order(Ms)
    if args.save_indices_to_file:
        # NOTE(review): with --clustered_indices this runs once per cluster and
        # each call overwrites the same file; idx are positions within this
        # (possibly subset/reordered) Ms, not necessarily original row ids.
        with open(args.save_indices_to_file, "w") as f:
            f.write('\n'.join(str(i) for i in idx) + "\n")

    for i, (M, cmap, row_axes) in enumerate(zip(Ms, cmaps, heatmap_rows)):
        vmin, vmax, cbar_ticks = _color_limits(M, i)
        img = plot_hstacked_read_vector(row_axes, M, vmin, vmax,
                                        _resolve_cmap(cmap), idx)
        if args.colorbar:
            fig.colorbar(img, ax=row_axes.ravel().tolist(), ticks=cbar_ticks)

    if args.names:
        for row_axes, name in zip(heatmap_rows, args.names):
            row_axes[0].set_ylabel(name, fontsize=8)

    # timepoint labels under the bottom heatmap row
    for ax, timepoint in zip(axes[-1], timepoints):
        ax.set_xlabel(timepoint, fontsize=8)

    fig.suptitle('Duration of dex exposure (hrs)', y=0.06, fontsize=8)
    fig.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1)
    for outplot in outplots:
        fig.savefig(outplot)
    # release the figure: this is called once per cluster, and pyplot keeps
    # every open figure alive otherwise
    plt.close(fig)

def _cluster_outplot(outplot, cluster):
    """Insert '.cluster_<id>' before the file extension, e.g. a/b.png -> a/b.cluster_3.png."""
    import os
    root, ext = os.path.splitext(outplot)
    return '%s.cluster_%s%s' % (root, cluster, ext)


if args.clustered_indices:
    # one figure per cluster, rendered once and saved to every requested path
    # (plot_wrapper expects a list of paths, not a single string)
    for cluster in indices_by_cluster:
        cluster_outplots = [_cluster_outplot(outplot, cluster) for outplot in args.outplots]
        plot_wrapper(Ms_clustered[cluster], N_datasets, args.timepoints, colors, cmaps, cluster_outplots)
else:
    plot_wrapper(Ms_clustered['all'], N_datasets, args.timepoints, colors, cmaps, args.outplots)