#!/usr/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
from reddylab_utils.reddylab_plotting_utils import simple_bar_with_err_plot,discrete_cmap
import numpy as np
from matplotlib import gridspec
import argparse

# Text shown by --help; it documents the expected tab-separated input layout.
DESCRIPTION = """

This script takes a file of this format:

GO_term_name	OR_low_CI	OR	OR_high_CI	p_fdr_bh
chromosome organization	-6.394	-3.337	-1.467	9.79e-05
mRNA processing	-6.333	-3.278	-1.407	0.000105
lipid metabolic process	0.321	0.694	1.051	0.00397
small molecule metabolic process	0.257	0.569	0.870	0.00449
DNA metabolic process	-1.990	-1.166	-0.446	0.00648
...

and plots the ORs as bars with CIs where each bar is
colored by the significance (FDR).

"""

parser = argparse.ArgumentParser(
    description=DESCRIPTION,
    formatter_class=argparse.RawTextHelpFormatter,
)

# Detach the parser's last (default optional-arguments) group, register a
# dedicated "required arguments" group, and re-attach the detached group at
# the end so that the required group precedes it in the parser's group list.
optional_group = parser._action_groups.pop()
required_group = parser.add_argument_group('required arguments')

# Required: input table and output figure path.
required_group.add_argument(
    "-i", "--df",
    required=True,
    help="required, path to input",
)
required_group.add_argument(
    "-o", "--outplot",
    required=True,
    help="required, path to output",
)

# Optional: name of the matplotlib colormap used to encode significance.
optional_group.add_argument(
    "--cmap",
    default='BuGn',
    help="optional, sequential colormap [default=BuGn]",
)

parser._action_groups.append(optional_group)
args = parser.parse_args()

cmap = plt.get_cmap(args.cmap)

# Load the tab-separated table; the first column becomes the index and the
# rows are ordered from the largest to the smallest OR.
df = pd.read_csv(args.df, sep="\t", index_col=0)
df['neglog10_p_fdr_bh'] = -np.log10(df['p_fdr_bh'].values)
df = df.sort_values('OR', ascending=False)

# The color scale covers FDR values from 0.2 down to 10^-8, expressed as
# -log10(FDR). Values outside that range are clamped to the nearest end of
# the scale instead of raising.
vmin = -np.log10(0.2)
vmax = -np.log10(10**-8)

# NOTE(review): the palette is reversed here, while the color scale panel
# drawn later uses `cmap` unreversed -- confirm the ordering returned by
# discrete_cmap so that bars and color scale agree.
palette = discrete_cmap(1000, cmap)[::-1]
span = np.linspace(vmin, vmax, 1000)

neglog_fdr = df['neglog10_p_fdr_bh'].values

# For each row, pick the index of the last grid point strictly below its
# -log10(FDR). `searchsorted(..., side='left')` counts the grid points that
# are strictly smaller, so subtracting one yields that index.
bin_idx = np.searchsorted(span, neglog_fdr, side='left') - 1

# Rows with FDR >= 0.2 have no grid point below them (index -1); clamp them
# to the first bin. The upper bound is kept for symmetry.
bin_idx = np.clip(bin_idx, 0, len(span) - 1)

# Missing FDR values are drawn with the first (FDR = 0.2) bin's color rather
# than being sorted past the end of the grid.
bin_idx[np.isnan(neglog_fdr)] = 0

colors = [palette[i] for i in bin_idx]

# Figure width scales with the number of rows in the table.
fig = plt.figure(figsize=(len(df) / 4., 6))

# Two panels side by side: a wide one for the bars, a narrow one for the
# color scale.
layout = gridspec.GridSpec(1, 2, width_ratios=[10, 1])
ax = plt.subplot(layout[0])

# Error bars are given as distances from the OR to each CI bound.
odds = df['OR']
err_below = odds - df['OR_low_CI']
err_above = df['OR_high_CI'] - odds
simple_bar_with_err_plot(ax, odds, [err_below, err_above], colors)

# Keep only the left and bottom axes, with the left spine pushed outward.
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('outward', 10))
ax.xaxis.set_ticks_position('bottom')
for hidden_side in ('top', 'right'):
    ax.spines[hidden_side].set_color('none')

# One tick per row, labelled with the row's index value.
# NOTE(review): the 0.5 spacing / 0.25 offset presumably matches the bar
# layout used by simple_bar_with_err_plot -- confirm against that helper.
tick_centers = 0.25 + np.arange(len(df)) / 2.
ax.set_xticks(tick_centers)
ax.set_xticklabels(list(df.index), ha="right", rotation=45, fontsize=4)

# Color scale: render the -log10(FDR) grid as a single-column image with
# the smallest value at the bottom.
cbar_ax = plt.subplot(layout[1])
cbar_ax.imshow(span.reshape(-1, 1), aspect='auto', cmap=cmap, origin="lower")
cbar_ax.set_xticks([])
cbar_ax.yaxis.set_ticks_position('right')
cbar_ax.yaxis.tick_right()

# Label the scale at -log10(FDR) = 2..8; each tick is placed on the image
# row holding the last grid value that does not exceed that level.
tick_levels = np.arange(2, 9)
tick_rows = [np.flatnonzero(span <= level)[-1]
             for level in tick_levels.astype('float')]

cbar_ax.set_yticks(tick_rows)
cbar_ax.set_yticklabels(tick_levels)

plt.subplots_adjust(left=0.2, right=0.85)
plt.savefig(args.outplot)