#!/usr/bin/env python
import matplotlib
matplotlib.use('Agg')
font = {'size'   : 8}
matplotlib.rc('font', **font)
from GGR import utils

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import defaultdict
import argparse

# Command-line interface. RawTextHelpFormatter keeps the line breaks written
# in the description and help strings below.
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

Plots expression of one or multiple genes over the time course across one or multiple expression matrices.

""")

##################################################
# required args:

parser.add_argument("-i", "--mats", nargs='+', type=str,
                    help="""required, file paths to expression matrices
""", required=True)
parser.add_argument("-g", "--genes", nargs='+', type=str,
                    help="""required, name of genes for which to plot expression
""", required=True)
parser.add_argument("-o", "--plot", type=str,
                    help="""required, file path to plot output
""", required=True)

##################################################
# optional args:
parser.add_argument("--gene_names", nargs='+', type=str,
                    help="""optional, display names of the genes, in the same order as --genes.
if not indicated, then the gene ids given to --genes are used
""")
parser.add_argument("--colors", nargs='+', type=str,
                    help="""optional, list of colors to plot.
if not indicated, then the --cmap colormap is discretized into colors
""")
parser.add_argument("--cmap", type=str, default="tableau",
                    help="""optional, colormap to discretize into colors (default: %(default)s)
""")
parser.add_argument("--mat_names", nargs='+', type=str,
                    help="""optional, name of expression matrices
""")
parser.add_argument("--unit", type=str,
                    help="""optional, unit of expression (default: %(default)s)
""", default="TPM")
parser.add_argument("--xlabel", type=str,
                    help="""optional, x-axis label (default: %(default)s)
""", default='Time (hr)')
parser.add_argument("--ymin", help="optional, lower limit of the y-axis", type=float)
parser.add_argument("--ymax", help="optional, upper limit of the y-axis", type=float)

parser.add_argument("--logFC", help="optional, if expression is logFC", action='store_true')
parser.add_argument("--legend", help="optional, if a legend is desired", action='store_true')
parser.add_argument("--same_plot",
                    help="optional, plot all genes on one set of axes (requires exactly one expression matrix)",
                    action='store_true')

##################################################

args = parser.parse_args()

# Plotting several genes on the same axes is only supported for a single
# expression matrix, otherwise there would be too much information on one
# plot. Reported through parser.error (usage message + exit) rather than an
# assert, which is skipped when Python runs with -O.
if args.same_plot and len(args.mats) != 1:
    parser.error("--same_plot requires exactly one expression matrix (-i/--mats)")

##################################################
# set up colors

# color blind-friendly palette; every color is an RGB triple scaled to [0, 1]
orange = (230/255.0, 159/255.0, 0)
skyblue = (86/255.0, 180/255.0, 233/255.0)
bluishgreen = (0, 158/255.0, 115/255.0)
yellow = (240/255.0, 228/255.0, 66/255.0)
blue = (0, 114/255.0, 178/255.0)
vermillion = (213/255.0, 94/255.0, 0)
reddishpurple = (204/255.0, 121/255.0, 167/255.0)

# the tableau20 palette written as 8-bit RGB triples, rescaled to [0, 1] below
_TABLEAU20_RGB = [
    (31, 119, 180), (174, 199, 232),
    (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138),
    (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213),
    (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210),
    (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141),
    (23, 190, 207), (158, 218, 229),
]
tableau20 = [tuple(channel/255. for channel in rgb) for rgb in _TABLEAU20_RGB]
# tableau10 keeps every other tableau20 entry (indices 0, 2, 4, ...)
tableau10 = tableau20[::2]

# lookup from the color names accepted on the command line to RGB triples
color_convert = dict(
    orange=orange,
    skyblue=skyblue,
    bluishgreen=bluishgreen,
    yellow=yellow,
    blue=blue,
    vermillion=vermillion,
    reddishpurple=reddishpurple,
)

def color_converter(color):
    """Translate a palette color name into its RGB triple.

    Names that are keys of ``color_convert`` are replaced by the matching
    triple; any other value is returned unchanged.
    """
    return color_convert.get(color, color)

def timepoint_conversion_to_numeric(t):
    """Convert a time point label such as "00h30m" into a number of hours.

    The label is a sequence of ``<value><unit>`` pairs, where the unit is
    ``h`` (hours), ``m`` (minutes) or ``s`` (seconds). Values may have any
    width ("4h", "100h") and may be fractional ("0.5h"); the fixed two-digit
    layout ("00h30m") gives the same result as before.

    Args:
        t: time point label.

    Returns:
        The total in hours, as an int when it is a whole number and as a
        float otherwise. An empty label yields 0.

    Raises:
        ValueError: if a unit has no value in front of it, a value cannot be
            read as a number, or the label ends without a unit.
    """
    convert_dict = {"h": 1, "m": 1/60., "s": 1/3600.}
    hrs = 0
    value = ""
    for char in t:
        if char not in convert_dict:
            # still reading the characters of the current value
            value += char
            continue
        if not value:
            raise ValueError("time point %r has a unit with no value before it" % (t,))
        hrs += float(value) * convert_dict[char]
        value = ""

    if value:
        raise ValueError("time point %r ends without a unit (expected h, m or s)" % (t,))

    # report whole hours as int so they compare and format cleanly
    if hrs == int(hrs):
        hrs = int(hrs)

    return hrs
    
# One color is needed per gene when all genes share a plot, otherwise one per
# expression matrix.
n_colors = len(args.genes) if args.same_plot else len(args.mats)

if args.colors:
    colors = [color_converter(color) for color in args.colors]
elif args.cmap != "tableau":
    cmap = plt.get_cmap(args.cmap)
    # linspace gives exactly n_colors evenly spaced samples of the colormap,
    # including the single-series case (where a step of 1/(n-1) is undefined)
    colors = [cmap(v) for v in np.linspace(0, 1, n_colors)][::-1]
else:
    palette = tableau10 if n_colors <= 10 else tableau20
    # exactly n_colors entries; recycle the palette if there are more series
    # than it has colors
    colors = [palette[k % len(palette)] for k in range(n_colors)]

##################################################

# if gene names are given, use them for the plot, otherwise label with gene ids
gene_names = args.gene_names or args.genes

# genes and names are paired up with zip() when plotting, which would silently
# drop genes if the two lists differed in length
if len(gene_names) != len(args.genes):
    parser.error("--gene_names needs one name per gene given to -g/--genes "
                 "(got %d names for %d genes)" % (len(gene_names), len(args.genes)))

# import expression matrix (matrices): tab-separated, first column is the index
mats = [pd.read_csv(mat, sep='\t', index_col=0) for mat in args.mats]

# link up sample names to time point in the case of multiple replicates:
# the time point is the second-to-last "."-separated field of the column name;
# a column name without a "." is its own time point
timepoint_to_timepoint_reps_list = []
for mat in mats:
    timepoint_to_timepoint_reps = defaultdict(list)
    for timepoint_rep in mat.columns:
        fields = timepoint_rep.split(".")
        timepoint = fields[-2] if len(fields) >= 2 else timepoint_rep
        timepoint_to_timepoint_reps[timepoint].append(timepoint_rep)

    timepoint_to_timepoint_reps_list.append(timepoint_to_timepoint_reps)

# which time points are being tested?
# Order them chronologically (by their value in hours, label as tie-break)
# rather than alphabetically, so that e.g. "30m" comes before "01h" and the
# columns of the mean/std tables follow the time axis.
timepoints_list = [
    sorted(timepoint_to_timepoint_reps,
           key=lambda tp: (timepoint_conversion_to_numeric(tp.split(".")[-1]), tp))
    for timepoint_to_timepoint_reps in timepoint_to_timepoint_reps_list
]

# set up mean and standard deviation dataframes
mat_means = [pd.DataFrame(index=list(mat.index)) for mat in mats]
mat_stds = [pd.DataFrame(index=list(mat.index)) for mat in mats]

# for each expression matrix and for each time point
# compute mean and std. dev. across replicates
for i, mat in enumerate(mats):
    for timepoint in timepoints_list[i]:
        replicates = mat[timepoint_to_timepoint_reps_list[i][timepoint]]
        mat_means[i][timepoint] = replicates.mean(axis=1)
        mat_stds[i][timepoint] = replicates.std(axis=1)

# convert time point to numeric (hours), in the same order as the columns of
# the mean/std tables
timepoints_numeric_list = [
    [timepoint_conversion_to_numeric(timepoint.split(".")[-1]) for timepoint in timepoints]
    for timepoints in timepoints_list
]

# every distinct time point across all matrices, ascending
timepoints_numeric_overall = sorted(set([t for x in timepoints_numeric_list for t in x]))
    
# Figure layout: a single 6x4 panel when all genes share a plot, otherwise one
# 3x2 row per gene.
if args.same_plot:
    nrows = 1
    figsize = (6, 4)
    # running y-range over all genes, updated while plotting
    total_min, total_max = np.inf, -np.inf
else:
    nrows = len(args.genes)
    figsize = (3, 2*len(args.genes))

# x-range covered by the time points of all matrices
xmin = min(timepoints_numeric_overall)
xmax = max(timepoints_numeric_overall)

fig, axes = plt.subplots(ncols=1, nrows=nrows, figsize=figsize)

def _spread_bounds(means, stds):
    """Return (min of mean - std, max of mean + std) for one plotted series."""
    means = np.asarray(means, dtype=float)
    # the std. dev. is NaN when a time point has a single replicate (pandas
    # uses ddof=1 by default); count that as no spread so the bounds still
    # follow the means
    spread = np.nan_to_num(np.asarray(stds, dtype=float))
    return (means - spread).min(), (means + spread).max()


def _format_hours(t):
    """Format a time in hours with as few decimals (0, 1 or 2) as represent it."""
    if int(t) == t:
        return "%0.0f" % t
    if float("%0.1f" % t) == t:
        return "%0.1f" % t
    return "%0.2f" % t


for i, (site, gene_name) in enumerate(zip(args.genes, gene_names)):

    if not args.same_plot:
        # one subplot per gene, one error-bar line per expression matrix
        total_min = np.inf; total_max = -np.inf
        ax = axes if len(args.genes) == 1 else axes[i]
        for j in range(len(mats)):
            # recycle colors if fewer were supplied than there are matrices
            color = colors[j % len(colors)]

            mean_sub = list(mat_means[j].loc[site])
            std_sub = list(mat_stds[j].loc[site])
            this_min, this_max = _spread_bounds(mean_sub, std_sub)

            total_min = this_min if this_min < total_min else total_min
            total_max = this_max if this_max > total_max else total_max

            label = args.mat_names[j] if args.mat_names else ""
            (_, caps, _) = ax.errorbar(timepoints_numeric_list[j],
                                       mean_sub,
                                       yerr=std_sub,
                                       color=color_converter(color),
                                       linewidth=2, label=label)

            for cap in caps:
                cap.set_markeredgewidth(2)

        ax.set_title(gene_name, fontsize=8)

        # if log fold change plotted, then expression may be negative
        # in which case draw an hline at 0, if expression crosses the x-axis
        if args.logFC:
            if np.sign(total_min) != np.sign(total_max):
                # axhline's xmin/xmax are fractions of the axes width, not data
                # coordinates, so they are left at their defaults (full width)
                ax.axhline(0, lw=1, color="black")
        else:
            ax.set_ylim((max(0, ax.get_ylim()[0]), ax.get_ylim()[1]))
    else:
        # all genes on one set of axes, read from the single expression matrix
        ax = axes
        color = colors[i % len(colors)]
        mean_sub = list(mat_means[0].loc[site])
        std_sub = list(mat_stds[0].loc[site])
        this_min, this_max = _spread_bounds(mean_sub, std_sub)

        total_min = this_min if this_min < total_min else total_min
        total_max = this_max if this_max > total_max else total_max

        # x values of this matrix, in the same order as the columns the means
        # were read from
        timepoints_numeric = timepoints_numeric_list[0]
        xmin = min(timepoints_numeric)
        xmax = max(timepoints_numeric)

        ax.plot(timepoints_numeric, mean_sub, color=color, linewidth=1, label=gene_name)
        (_, caps, _) = ax.errorbar(timepoints_numeric, mean_sub, yerr=std_sub, color=color, linewidth=1)

        for cap in caps:
            cap.set_markeredgewidth(2)

    ax.set_xlabel(args.xlabel)
    ax.set_ylabel(args.unit, color='black')
    utils.adjust_spines(ax, ['left', 'bottom'])
    ax.spines['left'].set_linewidth(1)
    ax.spines['bottom'].set_linewidth(1)

    # one tick per sampled time point
    ax.set_xticks(timepoints_numeric_overall)
    ax.set_xticklabels([_format_hours(t) for t in timepoints_numeric_overall])

    # pad the x-axis by 1/20 of the overall time span on each side
    x_pad = (max(timepoints_numeric_overall) - min(timepoints_numeric_overall))/20.
    ax.set_xlim((xmin - x_pad, xmax + x_pad))

# In --same_plot mode the y-range was accumulated over all genes, so the zero
# line / floor at zero is decided once, here. (With one subplot per gene this
# is already done for each subplot inside the plotting loop.)
if args.same_plot:
    # if log fold change plotted, then expression may be negative
    # in which case draw an hline at 0, if expression crosses the x-axis
    if args.logFC:
        if np.sign(total_min) != np.sign(total_max):
            # axhline's xmin/xmax are fractions of the axes width, not data
            # coordinates, so they are left at their defaults (full width)
            ax.axhline(0, lw=1, color="black")
    else:
        ax.set_ylim((max(0, ax.get_ylim()[0]), ax.get_ylim()[1]))

# apply the requested y-limits to every subplot; compare against None so that
# an explicit limit of 0 is honored
for each_ax in fig.axes:
    if args.ymin is not None:
        each_ax.set_ylim((args.ymin, each_ax.get_ylim()[1]))
    if args.ymax is not None:
        each_ax.set_ylim((each_ax.get_ylim()[0], args.ymax))

if args.legend:
    # Put a legend to the right side of the axis (the last one drawn on)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1], fontsize=8, loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)
    # bbox_inches='tight' so that the legend outside the axes is not cropped
    plt.savefig(args.plot, bbox_inches='tight')
else:
    plt.tight_layout()
    plt.savefig(args.plot)