#!/usr/bin/env python
import numpy as np
import pandas as pd
import argparse

import matplotlib
matplotlib.use('Agg')
font = {'size' : 8}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt

import matplotlib.patches as mpatches 
import matplotlib.colors as mcolors
from scipy.stats import sem 
# import brewer2mpl
import palettable.colorbrewer.sequential
from collections import defaultdict
from sys import argv

# Command-line arguments (positional, in this order):
#   1. comma-separated list of .npy matrices to aggregate
#   2. comma-separated list of colors, one per matrix (palette names such as
#      'red', 'gray', 'green', or any other matplotlib color spec)
#   3. gene ID whose TPM expression is overlaid on every timepoint panel
#   4. output plot path (a .pdf copy is also written)
#
# Example invocation arguments:
#   GR.in.EP300_up_FDR_0.1.norm.log_geom_mean.npy,GR.in.EP300_static_FDR_0.2.norm.log_geom_mean.npy,GR.in.EP300_down_FDR_0.1.norm.log_geom_mean.npy
#   red,gray,green
#   ENSG00000113580.13
#   GR.in.EP300_up_down_static.norm.log_geom_mean.png
_parser = argparse.ArgumentParser(
    description='Plot per-timepoint aggregate signal with a gene TPM overlay.')
_parser.add_argument('matrices', help='comma-separated .npy matrix files')
_parser.add_argument('colors', help='comma-separated colors, one per matrix')
_parser.add_argument('gene', help='gene ID (row label of the TPM table)')
_parser.add_argument('outplot', help='output plot path')
# parse_known_args: like the original argv indexing, ignore any extra
# trailing arguments instead of rejecting them; missing ones now produce a
# usage message rather than a bare IndexError.
_args, _extra_args = _parser.parse_known_args(argv[1:])

matrices = _args.matrices.split(',')
colors = _args.colors.split(',')
gene = _args.gene
outplot = _args.outplot

# TPM table: first column is used as the row index (looked up by gene ID
# below); remaining columns are matched to timepoints by the text before
# their first '_'.
EXPRESSION_TPM = "/data/reddylab/projects/GGR/data/rna_seq/quantified_read_counts/iter0_union/featurecounts.genes.TPM.selected_reps.protein_coding.txt"
expression = pd.read_csv(EXPRESSION_TPM, index_col=0, sep="\t")

# Per-timepoint mean and standard deviation (pandas default, ddof=1) of the
# gene's TPM across the columns belonging to that timepoint.
means, stds = [], []
for t in ['t00', 't05', 't1', 't2', 't3', 't4', 't5', 't6', 't7', 't8', 't10', 't12']:
    rep_cols = [c for c in expression.columns if c.split('_')[0] == t]
    # .loc replaces the deprecated (removed in pandas 1.0) .ix indexer; the
    # label lookup of `gene` in the row index is unchanged.
    gene_tpm = expression.loc[gene, rep_cols]
    means.append(gene_tpm.mean())
    stds.append(gene_tpm.std())

# Numeric timepoint values (presumably hours — TODO confirm); the input
# matrices are split into len(timepoints) column blocks, one per panel.
timepoints = [0, 0.5, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12]

# Color name -> shade index 5 of the matching 9-class ColorBrewer
# sequential palette.
_palettes = {
    'blue': palettable.colorbrewer.sequential.Blues_9,
    'gray': palettable.colorbrewer.sequential.Greys_9,
    'green': palettable.colorbrewer.sequential.Greens_9,
    'orange': palettable.colorbrewer.sequential.Oranges_9,
    'purple': palettable.colorbrewer.sequential.Purples_9,
    'red': palettable.colorbrewer.sequential.Reds_9,
}
cs = {name: pal.mpl_colors[5] for name, pal in _palettes.items()}

# Map each requested color through the palette individually; names the
# palette does not know (e.g. 'black', '#1f77b4') pass through to matplotlib
# unchanged. Previously a single unknown name disabled the mapping for all.
colors = [cs.get(color, color) for color in colors]

# Load each matrix and split its columns into len(timepoints) equal-width
# blocks, one per timepoint panel: Ms[k][j] is block j of matrix k.
Ms = []
cols_per_row = None
for path in matrices:
    data = np.load(path)
    n_cols = data.shape[1]
    # divmod (floor division) keeps the block width an int under Python 3,
    # where `/` would yield a float that cannot be used as a slice bound.
    width, remainder = divmod(n_cols, len(timepoints))
    if remainder or width == 0:
        raise ValueError('%s: %d columns is not a positive multiple of %d timepoints'
                         % (path, n_cols, len(timepoints)))
    if cols_per_row is not None and width != cols_per_row:
        raise ValueError('%s: block width %d differs from earlier matrices (%d)'
                         % (path, width, cols_per_row))
    cols_per_row = width
    Ms.append([data[:, start:start + width] for start in range(0, n_cols, width)])

# Shared x coordinates (column positions within one block).
x = np.arange(cols_per_row)

def plot_aggregate_means_and_stderrs(axes, M, color, ymin, ymax, xvals=None, show_sem=False):
    """Plot the column-wise mean of each sub-matrix of *M* on its own axis.

    Parameters
    ----------
    axes : sequence of matplotlib Axes
        Paired with *M* via ``zip``, so the shorter of the two decides how
        many panels are drawn.
    M : sequence of 2-D arrays
        One array per panel; the plotted curve is its mean over axis 0.
    color : matplotlib color
        Color of the mean line (and of the SEM band when *show_sem*).
    ymin, ymax : float
        Running y-limits to widen (pass ``np.inf`` / ``-np.inf`` to start).
    xvals : array-like, optional
        x coordinates for every panel. Defaults to
        ``np.arange(n_columns)`` of each sub-matrix.
    show_sem : bool, optional
        If True, also shade mean +/- standard error of the mean
        (``scipy.stats.sem`` over axis 0). Off by default.

    Returns
    -------
    (ymin, ymax) : tuple of float
        The inputs widened to include each drawn axis's current y-limits.
    """
    for M_sub, ax in zip(M, axes):
        mean = M_sub.mean(axis=0)
        x_sub = np.arange(M_sub.shape[1]) if xvals is None else xvals
        ax.plot(x_sub, mean, color=color)
        if show_sem:
            mean_sem = sem(M_sub)
            ax.fill_between(x_sub, mean - mean_sem, mean + mean_sem,
                            facecolor=color, edgecolor=color,
                            alpha=0.5, lw=0.01)
        ax.set_xticks([])
        ax.set_yticks([])

        # Hide every spine except the left one (styled by the caller).
        for side in ('right', 'top', 'bottom'):
            ax.spines[side].set_color('none')

        # Read limits after plotting so autoscaling has taken the data into account.
        bottom, top = ax.get_ylim()
        ymin = min(ymin, bottom)
        ymax = max(ymax, top)

    return (ymin, ymax)

# Running y-limits across all panels and matrices, widened by
# plot_aggregate_means_and_stderrs.
ymin = np.inf
ymax = -np.inf

# One panel per timepoint: tie the panel count to the same list used to
# split the matrices instead of hard-coding 12.
fig, axes = plt.subplots(ncols=len(timepoints), figsize=(3, 2))

# zip() would silently skip matrices that have no color; fail loudly instead.
if len(colors) < len(Ms):
    raise ValueError('got %d colors for %d matrices' % (len(colors), len(Ms)))

for M, color in zip(Ms, colors):
    ymin, ymax = plot_aggregate_means_and_stderrs(axes.flatten(), M, color, ymin, ymax)

# Fixed y-range for the aggregate panels. This overrides the data-driven
# limits computed above (presumably so separate figures share one scale —
# confirm before changing).
ymin = 15
ymax = 120
print('%s %s' % (ymin, ymax))

# Five evenly spaced guide lines: ymin, the three quartiles, ymax.
# linspace avoids the Python 2 integer-division step of (ymax - ymin) / 4,
# which placed lines at 15, 41, 67, 93, 119, 145 instead of 15 ... 120.
gridlines = np.linspace(ymin, ymax, 5)

for i, ax in enumerate(axes):
    ax.set_ylim((ymin, ymax))
    for y in gridlines:
        ax.axhline(y=y, c='black', alpha=0.25, zorder=999)

    # A single unlabeled bottom tick at x=1000 in every panel.
    ax.set_xticks([1000])
    ax.set_xticklabels([])
    ax.xaxis.set_tick_params(width=1)
    ax.xaxis.set_ticks_position('bottom')

    if i == 0:
        # Leftmost panel carries the only y scale bar: thick spine,
        # unlabeled ticks at the two limits, offset outward.
        ax.set_yticks((ymin, ymax))
        ax.set_yticklabels([])
        ax.spines['left'].set_linewidth(2)
        ax.yaxis.set_ticks_position('left')
        ax.yaxis.set_tick_params(width=2)
        ax.spines['left'].set_position(('outward', 10))
    else:
        ax.spines['left'].set_visible(False)

# Overlay the gene's TPM (mean +/- std. dev. per timepoint) on a twin y-axis
# of each panel. Only the rightmost twin shows its y axis, drawn in blue.
TPM_X = 1500            # x position of the expression marker in each panel
TPM_COLOR = 'blue'
TPM_YLIM = (0, 300)

# Create all twins first (same order as the panels).
axes_ = [ax.twinx() for ax in axes]

for i, ax_ in enumerate(axes_):
    # Plot TPM expression: mean as a dot, std. dev. as an error bar.
    ax_.scatter(TPM_X, means[i], color=TPM_COLOR, s=10)
    (_, caps, _) = ax_.errorbar(TPM_X, means[i], yerr=stds[i],
                                color=TPM_COLOR, linewidth=2)
    ax_.set_ylim(*TPM_YLIM)
    # Thin error-bar caps.
    for cap in caps:
        cap.set_markeredgewidth(1)

    # Twins never show top/bottom/left spines.
    # The underlying panel's x ticks and spines were configured when the
    # panels were drawn, so they are not touched here.
    ax_.spines['top'].set_visible(False)
    ax_.spines['bottom'].set_visible(False)
    ax_.spines['left'].set_visible(False)

    if i != len(axes_) - 1:  # not the rightmost panel: hide the TPM axis
        ax_.set_yticks([])
        ax_.set_yticklabels([])
        ax_.spines['right'].set_visible(False)
        ax_.yaxis.set_ticks_position('none')
    else:
        # Rightmost panel: thick blue TPM scale bar, offset outward.
        ax_.spines['right'].set_linewidth(2)
        ax_.spines['right'].set_color(TPM_COLOR)
        ax_.yaxis.set_ticks_position('right')
        ax_.yaxis.set_tick_params(width=2, color=TPM_COLOR)
        ax_.spines['right'].set_position(('outward', 10))
        ax_.set_yticks([0, 100, 200, 300])

import os.path

# Save the figure at the requested path, plus a PDF with the same stem.
# splitext (rather than str.replace('.png', '.pdf')) works for any extension
# and never touches a '.png' that appears elsewhere in the path; the old form
# re-saved over `outplot` itself whenever it did not contain '.png'.
plt.savefig(outplot)
pdf_path = os.path.splitext(outplot)[0] + '.pdf'
plt.savefig(pdf_path)
