#!/usr/bin/env python
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

import pandas as pd
from scipy.stats import fisher_exact
import numpy as np
from GGR import utils
import argparse
import operator

# Command-line interface. RawTextHelpFormatter keeps the line breaks and
# indentation of the multi-line help strings exactly as written below.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""

This script plots GO enrichment results. [Describe in more depth here.]

""")

##################################################
# required args:
parser.add_argument("-i", "--GO", nargs='+', type=str,
                    help="""required, file paths to goatools enrichment files.
                    Typically, input may be results from, e.g. up-reg, dn-reg, and all DEGs
                    or input may be different gene expression clusters. Results can be 
                    distinguished in plot by --colors.
                    """, action='store', required=True)

parser.add_argument("-n", '--names', nargs='+', type=str,
                    help='required, names corresponding (in order) to GO enrichment results.',
                    action='store', required=True)

parser.add_argument("-c", '--colors', nargs='+', type=str,
                    help='required, colors corresponding (in order) to GO enrichment results.',
                    action='store', required=True)

parser.add_argument("-o", '--outplot', type=str,
                    help='required, outplot',
                    action='store', required=True)

##################################################
# optional args:

parser.add_argument("--FDR", type=float, default=.1,
                    help="""optional, FDR of enrichment results [default=0.1].
                    Used to draw a line across the plot at the level of accepted significance.
                    """, dest="FDR", action='store')

###################################################

args = parser.parse_args()

# Fail fast with a usage message when there are fewer names or colors than
# input files. Without these checks the mismatch only shows up later, as an
# IndexError raised from inside the plotting loop.
if len(args.names) < len(args.GO):
    parser.error('--names needs one name per --GO file (got %d names for %d files)'
                 % (len(args.names), len(args.GO)))

# A single --colors value is always accepted: with several inputs it is
# interpreted downstream as a palette/colormap name rather than one color.
if len(args.colors) != 1 and len(args.colors) < len(args.GO):
    parser.error('--colors needs one color per --GO file, or a single '
                 'colormap name (got %d colors for %d files)'
                 % (len(args.colors), len(args.GO)))

# The twenty palette colors as 8-bit RGB triples, listed in consecutive pairs.
_TABLEAU20_RGB = [
    (31, 119, 180), (174, 199, 232),
    (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138),
    (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213),
    (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210),
    (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141),
    (23, 190, 207), (158, 218, 229),
]

# Rescale every channel from 0-255 to the 0-1 float range.
tableau20 = [tuple(channel / 255. for channel in rgb) for rgb in _TABLEAU20_RGB]

# Ten-color variant: the first color of each consecutive pair.
tableau10 = tableau20[::2]

def process_GO_results(df):
    """Index one enrichment table by GO term ID.

    Parameters
    ----------
    df : mapping of column name -> sequence
        Must provide the columns 'GO_term_ID', 'Enrichment', 'GO_term_name'
        and 'p_fdr_bh' (in this script, a DataFrame from ``pd.read_csv``).

    Returns
    -------
    dict
        ``{GO_term_ID: (GO_term_name, Enrichment, float(p_fdr_bh))}``.
        If an ID appears in several rows, the last row wins.
    """
    term_ids = df['GO_term_ID']
    directions = df['Enrichment']
    term_labels = df['GO_term_name']
    fdr_values = df['p_fdr_bh']

    return {
        term_id: (label, direction, float(fdr))
        for term_id, direction, label, fdr
        in zip(term_ids, directions, term_labels, fdr_values)
    }

# Parse every tab-separated enrichment table into
# {GO term ID: (name, enrichment, FDR)}, one dict per input file.
GOs = [process_GO_results(pd.read_csv(path, sep='\t')) for path in args.GO]

# A single --colors value given for several inputs is not one color but the
# name of a palette/colormap, which is expanded into one color per input.
if len(args.colors) == 1 and len(GOs) > 1:
    if args.colors[0] == 'tableau':
        palette = tableau10 if len(GOs) <= 10 else tableau20
        # Exactly one color per input; wrap around so that more inputs than
        # palette entries reuse colors instead of running out of them.
        colors = [palette[i % len(palette)] for i in range(len(GOs))]
    else:
        # Sample the named matplotlib colormap at evenly spaced positions.
        cmap = plt.get_cmap(args.colors[0])
        colors = [cmap(i) for i in np.linspace(0, 1, len(GOs))]
else:
    colors = args.colors
    

# Find the terms that are significant (FDR below --FDR) in at least one input
# and remember, for each of them, the smallest FDR seen across inputs.
sig_terms = []; term_to_min_FDR = {}; term_to_name = {}
all_terms = set(term for GO in GOs for term in GO)
for term in sorted(all_terms):
    for GO in GOs:
        # all_terms is the union over all inputs, so any single input may
        # simply not contain this term; indexing it blindly raises KeyError.
        if term not in GO:
            continue

        entry = GO[term]
        FDR = entry[-1]
        term_to_name[term] = entry[0]
        if FDR < args.FDR and FDR < term_to_min_FDR.get(term, float('inf')):
            term_to_min_FDR[term] = FDR

    # Record each significant term once, however many inputs it passes in.
    if term in term_to_min_FDR:
        sig_terms.append(term)

# Rank terms by significance of enrichment/depletion (most significant
# first). The term ID is the secondary key so that ties are ordered
# deterministically rather than by dict iteration order.
ranked = sorted(term_to_min_FDR.items(), key=operator.itemgetter(1, 0))
FDRs = [FDR for term, FDR in ranked]
ranked_sig_terms = [term for term, FDR in ranked]

# One series of y-values per input, aligned with ranked_sig_terms:
#   * +(-log10 FDR) when the term is flagged 'e' in that input,
#   * log10 FDR (i.e. the same magnitude, negated) for any other flag,
#   * NaN when the input does not contain the term at all.
ys = []
for GO in GOs:
    signed_scores = []
    for term in ranked_sig_terms:
        if term in GO:
            log_fdr = np.log10(GO[term][-1])
            is_enriched = GO[term][1] == 'e'
            signed_scores.append(-log_fdr if is_enriched else log_fdr)
        else:
            signed_scores.append(np.nan)

    ys.append(signed_scores)

# Tick labels (term names) and x positions, in the same ranked order.
term_names = [term_to_name[ranked_term] for ranked_term in ranked_sig_terms]

x = np.arange(len(ranked_sig_terms))

def common_ax_cmds(ax, x, term_names, fdr=None):
    """Apply the shared axis styling, guide lines, legend and striping.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that already holds the scatter series.
    x : numpy array
        X position of every term drawn on ``ax``.
    term_names : sequence of str
        Tick label for each position in ``x`` (same order, same length).
    fdr : float, optional
        Significance cut-off marked by the dashed lines at +/- log10(fdr).
        Defaults to the ``--FDR`` command-line value.

    Returns
    -------
    The legend artist, so the caller can pass it to ``savefig`` through
    ``bbox_extra_artists`` and keep it inside the saved bounding box.
    """
    if fdr is None:
        fdr = args.FDR

    # min()/max() of an empty array raise, so only set limits if there is data.
    if len(x):
        ax.set_xlim((x.min() - 0.5, x.max() + 0.5))

    # NOTE(review): presumably keeps only the left spine visible -- confirm
    # against GGR.utils.adjust_spines.
    utils.adjust_spines(ax, ['left'])
    ax.yaxis.set_tick_params(width=2)
    # tick_params replaces per-tick `tick.label.set_fontsize`, since the
    # `Tick.label` attribute no longer exists in current Matplotlib releases.
    ax.tick_params(axis='y', labelsize=10)

    ax.spines['left'].set_linewidth(2)
    ax.set_xticks(x + 0.25)
    ax.set_xticklabels(term_names, rotation=45, ha='right')
    ax.tick_params(axis='x', labelsize=8)

    ax.xaxis.set_ticks_position('none')
    ax.set_ylabel('log10 FDR\n(enrichment > 0, depletion < 0)', fontsize=10)
    ax.set_xlabel('Gene ontology slim term')

    # Draw on the axes that was passed in, not on pyplot's "current" axes.
    ax.axhline(-np.log10(fdr), ls='--', color='black', alpha=0.5, label='FDR cut-off')
    ax.axhline(np.log10(fdr), ls='--', color='black', alpha=0.5)
    ax.axhline(0, color='black')

    # Shrink current axis by 30%
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])
    # Put a legend to the right of the current axis
    lgd = ax.legend(loc='center left', frameon=False, fontsize=10,
                    scatterpoints=1, bbox_to_anchor=(1, 0.5))

    # Faint alternating gray/white column behind every term, spanning the
    # current y-range, so that points can be traced down to their tick label.
    ymin, ymax = ax.get_ylim()
    yspan = ymax - ymin
    for i, xstart in enumerate(x):
        stripe_color = 'gray' if i % 2 == 0 else 'white'
        rect = Rectangle((xstart - 0.5, ymin), 1, yspan,
                         color=stripe_color, alpha=0.1, zorder=3)
        ax.add_patch(rect)

    return(lgd)

# With no significant term there is nothing to draw, and the axis set-up would
# fail on the empty position array; stop with a readable message instead.
if len(x) == 0:
    raise SystemExit('No GO terms with FDR < %s; nothing to plot.' % args.FDR)

if len(x) < 35:
    # Few enough terms to fit on a single figure.
    fig, ax = plt.subplots(figsize=(6,6))

    for i, y in enumerate(ys):
        ax.scatter(x, y, color=colors[i], label=args.names[i])

    lgd = common_ax_cmds(ax, x, term_names)
    fig.savefig(args.outplot, bbox_extra_artists=(lgd,), bbox_inches='tight')
else:
    # Too many terms for one figure: split them into two halves and write one
    # file per half, inserting ".1" / ".2" before the file extension.
    half = len(x) // 2  # floor division: slice indices must be integers
    xs = [x[:half], x[half:]]
    ys_split = [(y[:half], y[half:]) for y in ys]
    term_names_split = [term_names[:half], term_names[half:]]

    outplot_parts = args.outplot.split('.')
    outplot_stem = '.'.join(outplot_parts[:-1])
    outplot_ext = outplot_parts[-1]

    # A separate loop name keeps the full-length global `x` intact.
    for j, x_part in enumerate(xs):
        fig, ax = plt.subplots(figsize=(6,6))

        for i, y_halves in enumerate(ys_split):
            ax.scatter(x_part, y_halves[j], color=colors[i], label=args.names[i])

        lgd = common_ax_cmds(ax, x_part, term_names_split[j])

        this_outplot = outplot_stem + '.' + str(j+1) + '.' + outplot_ext
        fig.savefig(this_outplot, bbox_extra_artists=(lgd,), bbox_inches='tight')
        # Function-call form with %-formatting prints the same text under
        # both Python 2 and Python 3.
        print('Output saved to: %s' % this_outplot)