#!/usr/bin/env python
import pandas as pd
import numpy as np
from reddylab_utils.reddylab_plotting_utils import discrete_cmap,adjust_spines,simple_barplot,simple_bar_with_err_plot
import matplotlib
matplotlib.use('Agg')
font = {'size'   : 8}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt
import argparse

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

Given a dataframe of proportions or counts, plot a barplot where
each column in the dataframe is represented by a different bar and each row 
is represented by a different subpanel row in the larger plot.

""")

# argparse has no public API to list required options in their own section;
# pop the default "optional" group here and re-append it after the required
# group is created so that "required arguments" appears first in --help.
optional = parser._action_groups.pop()
required = parser.add_argument_group('required arguments')

##################################################
# required args:

required.add_argument("-i", "--df", help="required, file path to dataframe", required=True)
required.add_argument("-o", "--outplot", help="required, path to outplot", required=True)

##################################################
# optional args:

# NOTE(review): --proportions is accepted but not read anywhere in this script.
optional.add_argument("--proportions",
                      help="optional, if indicated, then dataframe contains normalized proportions",
                      action="store_true")
optional.add_argument("--title_rows",
                      help="optional, if indicated, then row names used to title each subplot",
                      action="store_true")
optional.add_argument("--ymax", type=float,
                      help="optional, ymax")
optional.add_argument("--ymin", type=float, default=0,
                      help="optional, ymin; 0 means use the automatic lower limit (default: %(default)s)")
optional.add_argument("--ylabel", help="optional, ylabel")
optional.add_argument("--colors", nargs="+", default=["gray"],
                      help="Specify colors of barplots (default: %(default)s)")
optional.add_argument("--cmap",
                      help="""optional, name of matplotlib colormap for each row, see:
http://matplotlib.org/examples/color/colormaps_reference.html
if not given, then --colors is used for all rows""")
optional.add_argument("--se",
                      help="if --se indicated, rows alternate between values and their standard errors\n"
                           "(every second row holds the standard errors of the row above it)",
                      action="store_true")

##################################################
# Restore the optional group (after the required group) and parse sys.argv.
parser._action_groups.append(optional)
args = parser.parse_args()

# Tab-separated table; the first column is used as the row index (row names).
df = pd.read_csv(args.df, sep="\t", index_col=0)
# Convert to a plain array once rather than on every loop iteration.
values = df.values

# x positions of the column tick labels; assumes simple_barplot places bars at
# 0.5 spacing offset by 0.05 -- TODO confirm against reddylab_plotting_utils.
x = np.arange(df.shape[1]) * 0.5 + 0.05

# With --se the rows come in (values, standard errors) pairs, so there is one
# subpanel per pair of rows instead of one per row.
height = df.shape[0] if not args.se else df.shape[0] / 2.
width = df.shape[1]

nrows = int(height)
ncols = 1

# squeeze=False always yields a 2-D (nrows, 1) array, so a single panel needs
# no special-casing to be iterable.
fig, axes = plt.subplots(figsize=(1 + 0.5*width, 1 + 1*height),
                         ncols=ncols, nrows=nrows, sharex=True, squeeze=False)
axes = axes[:, 0]

# One color entry per panel.
if nrows == 1:
    # A single panel receives the whole --colors list.
    colors = [args.colors]
elif args.cmap:
    colors = discrete_cmap(nrows, args.cmap)[::-1]
elif len(args.colors) == 1:
    colors = [args.colors] * nrows
else:
    # Cycle through --colors so that fewer colors than rows does not raise IndexError.
    colors = [args.colors[j % len(args.colors)] for j in range(nrows)]

# plot values in dataframe as barplots, one ax per row (or per row pair with --se)
for i, ax in enumerate(axes):
    if args.se:
        # Row 2*i holds the values, row 2*i+1 their standard errors.
        row = 2 * i
        simple_bar_with_err_plot(ax, values[row, :], values[row + 1, :], colors[i])
    else:
        row = i
        simple_barplot(ax, values[row, :], colors[i])

    auto_ymin, auto_ymax = ax.get_ylim()
    # `is not None` so that an explicit --ymax 0 is honoured.
    ymax = args.ymax if args.ymax is not None else auto_ymax
    # --ymin defaults to 0, and 0 is deliberately treated as "use the automatic limit".
    ymin = args.ymin if args.ymin else auto_ymin

    # refine plot attributes
    ax.set_ylim((ymin, ymax))
    ax.yaxis.set_ticks_position('left')
    ax.spines['left'].set_position(('outward', 10))
    ax.xaxis.set_ticks_position('bottom')
    ax.spines['top'].set_color('none')
    ax.spines['right'].set_color('none')
    if i + 1 == len(axes):
        # Only the bottom panel shows x tick marks.
        ax.get_xaxis().set_tick_params(direction='out', width=1)
    else:
        ax.xaxis.set_ticks_position('none')
    ax.set_xticks([])
    ax.set_xticklabels([])
    if args.title_rows:
        # Title with the name of the values row (never the standard-error row).
        ax.set_title(df.index[row], fontsize=8)

if args.ylabel:
    # Decide by panel count, not row count: with --se one panel spans two rows.
    if nrows == 1:
        ax.set_ylabel(args.ylabel, fontsize=8)
    else:
        fig.suptitle(args.ylabel, fontsize=8)

# label columns/bars on the bottom panel (`ax` is the last panel from the loop)
ax.set_xticks(x)
ax.set_xticklabels(df.columns, rotation=30, ha='right')

plt.savefig(args.outplot, bbox_inches='tight')
