#!/usr/bin/env python
import pandas as pd
import numpy as np
import argparse
from GGR import utils

import matplotlib
matplotlib.use('Agg')
font = {'size' : 10}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt
from GGR import utils
import scikits.bootstrap
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

heatmap_median_logFC.py

Given one or more tab-separated dataframes (first column = row index), plot a
heatmap of column-wise median values:
  --indices             one row per (index file, dataframe): the median over
                        the rows listed in that index file.
  --indices_to_compare  one row per dataframe: the median over the rows in the
                        first index file minus the median over the rows in the
                        second index file.

""")
##################################################
parser.add_argument("--dfs", nargs='+', dest="dfs", required=True,
                    help="required, tab-separated dataframe(s); the first column is the row index")
parser.add_argument("--names", nargs='+', required=True,
                    help="required, name(s) corresponding to each dataframe")
parser.add_argument("--indices", nargs='+',
                    help="optional, file(s) of row indices; one heatmap panel per file")
parser.add_argument("--indices_to_compare", nargs='+',
                    help="optional, two files of row indices: CHANGE BASE;\n"
                         "each heatmap row is median(CHANGE) - median(BASE)")
parser.add_argument("--vmax", type=float, dest="vmax", action='store',
                    help="optional, upper colour limit\n"
                         "(default: symmetric limit from the 5th/95th percentiles of the data)")
parser.add_argument("--vmin", type=float, dest="vmin", action='store',
                    help="optional, lower colour limit\n"
                         "(default: symmetric limit from the 5th/95th percentiles of the data)")
parser.add_argument("--colorbar", dest="colorbar", action='store_true',
                    help="optional, draw a colour bar next to the heatmap")
parser.add_argument("--outplots", nargs='+', dest="outplots", action='store', required=True,
                    help="required, output plot path(s); the same figure is written to each")

args = parser.parse_args()

##################################################
# TEST BLOCK
###################################### 

# args = type("args", (object,), {})()
# args.dfs = ["/data/reddylab/projects/GGR/results/integrative/mapping_samples_across_factors/iter0/differential_binding/DNase.in.EP300.logFC.vs.t00.txt"]
# args.names = ["DNase"]
# args.indices_to_compare = \
# ["FOSL2.in.enhancers.strong.txt",
# "FOSL2.in.enhancers.none.txt"]
# args.outplots = ['test.heatmap.png']
# args.vmin=None
# args.vmax=None
# args.colorbar=True

###################################### 

def make_colormap(seq):
    """Build a matplotlib LinearSegmentedColormap from a list of colour stops.

    seq: a sequence mixing RGB 3-tuples and float positions. Each float is a
    breakpoint (increasing, inside (0, 1)); the tuple immediately before it is
    the colour reached from below and the tuple immediately after it is the
    colour leaving it above. Breakpoints at 0.0 and 1.0 are added implicitly.
    """
    blank = (None, None, None)
    padded = [blank, 0.0]
    padded.extend(seq)
    padded.extend([1.0, blank])

    channels = ('red', 'green', 'blue')
    cdict = dict((channel, []) for channel in channels)
    for idx in range(len(padded)):
        pos = padded[idx]
        if not isinstance(pos, float):
            continue  # colour tuples are only read as neighbours of a breakpoint
        r_lo, g_lo, b_lo = padded[idx - 1]
        r_hi, g_hi, b_hi = padded[idx + 1]
        for channel, lo, hi in zip(channels, (r_lo, g_lo, b_lo), (r_hi, g_hi, b_hi)):
            cdict[channel].append([pos, lo, hi])
    return mcolors.LinearSegmentedColormap('CustomMap', cdict)

c = mcolors.ColorConverter().to_rgb
# Diverging map: blue -> white over [0, 0.5], then white -> orange over [0.5, 1].
BlOr = make_colormap([c('blue'), c('white'), 0.5, c('white'), c('orange')])
# BlOr = make_colormap([c('white'), c('orange')])

###################################### 


def _read_table(path):
    """Read a tab-separated table whose first column is the row index."""
    return pd.read_csv(path, index_col=0, sep="\t")


def _median_over(table, labels):
    """Column-wise median (ignoring NaN) of the rows of *table* named in *labels*.

    reindex turns labels absent from *table* into all-NaN rows, which nanmedian
    then ignores; this mirrors the old ``.ix`` lookup (removed in pandas 1.0),
    and, unlike ``.loc``, accepts a set of labels.
    """
    return np.nanmedian(table.reindex(list(labels)).values, axis=0)


if not args.indices and not args.indices_to_compare:
    parser.error("at least one of --indices or --indices_to_compare is required")
if args.indices_to_compare and len(args.indices_to_compare) != 2:
    parser.error("--indices_to_compare takes exactly two files: CHANGE BASE")
if len(args.dfs) != len(args.names):
    # The plotting code slices X and labels rows in groups of len(args.names).
    parser.error("--dfs and --names must have the same number of entries")

# Read every dataframe once; they are reused for each index set below.
tables = [_read_table(path) for path in args.dfs]

# Row order of X: for each index file, one row per dataframe; then (if given)
# one comparison row per dataframe.
X = []
if args.indices:
    for index_path in args.indices:
        labels = utils.read_file_return_set(index_path)
        X.extend(_median_over(table, labels) for table in tables)

if args.indices_to_compare:
    change_path, base_path = args.indices_to_compare
    indices_change = utils.read_file_return_set(change_path)
    indices_base = utils.read_file_return_set(base_path)
    for table in tables:
        X.append(_median_over(table, indices_change) - _median_over(table, indices_base))

X = np.vstack(X)

########################################

# Default colour limits are symmetric around 0 and wide enough to cover the
# 5th-95th percentile range of the data. nanpercentile keeps NaN cells (e.g.
# all-NaN medians) from turning the limits themselves into NaN.
limit = max(abs(np.nanpercentile(X, 5)), np.nanpercentile(X, 95))

# Compare against None so an explicit 0 on the command line is honoured for
# both limits (a plain truthiness test would discard --vmax 0).
vmin = -limit if args.vmin is None else args.vmin
vmax = limit if args.vmax is None else args.vmax

# Reverse the rows so the first row built ends up at the top of the
# matshow(origin='lower') plot.
X = X[::-1]

n_names = len(args.names)

if args.indices:
    # One panel per index file, stacked top to bottom in the order given.
    # Exactly one figure is created (previously an extra, unused figure was
    # also opened).
    fig = plt.figure(figsize=(3, 3))
    gs = gridspec.GridSpec(len(args.indices), 1, wspace=0.0, hspace=0.0)
    axes = [fig.add_subplot(gs[i]) for i in range(len(args.indices))]
    for i, panel in enumerate(axes):
        # X was reversed earlier, so the n_names rows for index file i form
        # the i-th group counted from the end of X.
        stop = len(X) - n_names * i
        X_sub = X[stop - n_names:stop, :]
        im = panel.matshow(X_sub,
                           aspect='auto',
                           origin='lower',
                           cmap=BlOr,
                           vmin=vmin,
                           vmax=vmax)

        panel.set_xticks([])
        panel.xaxis.set_ticks_position('none')
        panel.set_yticks(np.arange(len(X_sub)))
        # Floor division: the repeat count must be an int ("/" gives a float
        # under Python 3, and list * float raises TypeError).
        panel.set_yticklabels(args.names[::-1] * (len(X_sub) // n_names))
        panel.yaxis.set_ticks_position('none')
        panel.set_aspect("equal")
else:
    # Single panel; only --indices_to_compare rows are present in X.
    fig, ax = plt.subplots(figsize=(3, 3))
    im = ax.matshow(X,
                    aspect='auto',
                    origin='lower',
                    cmap=BlOr,
                    vmin=vmin,
                    vmax=vmax)

    ax.set_xticks([])
    ax.xaxis.set_ticks_position('none')
    ax.set_yticks(np.arange(len(X)))
    ax.set_yticklabels(args.names[::-1])
    ax.yaxis.set_ticks_position('none')
    ax.set_aspect("equal")

if args.colorbar:
    # Reserve a strip on the right for the colour bar. `im` is the last image
    # drawn; every panel shares the same vmin/vmax, so any of them will do.
    fig.subplots_adjust(right=0.85, bottom=0.15)
    cbar_ax = fig.add_axes([0.86, 0.15, 0.05, 0.75])
    cb = fig.colorbar(im, cax=cbar_ax)
    cb.ax.tick_params(labelsize=10)

for outplot in args.outplots:
    fig.savefig(outplot)
