#!/usr/bin/env python
import pandas as pd
import numpy as np
import seaborn.apionly as sns
import matplotlib.pyplot as plt
import argparse

# Command-line interface. RawTextHelpFormatter preserves the line breaks of
# the description below in --help output.
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

plot_jointplot.py

Given two simple dataframes of the format:

site1	x1
site2	x2

and

site1	y1
site2	y2

plot a joint distribution of x vs. y.

""")

# Temporarily remove the default optional-arguments group so that the
# "required arguments" group is listed first in --help; the optional group is
# re-appended after all options are registered.
# NOTE: this relies on the private argparse attribute `_action_groups`.
optional = parser._action_groups.pop()
required = parser.add_argument_group('required arguments')

##################################################
# required args:

required.add_argument("-x", type=str, help="required, dataframe", dest="x_df", required=True)
required.add_argument("-y", type=str, help="required, dataframe", dest="y_df", required=True)
required.add_argument("--outplot", type=str, help="required, output plot", required=True)

##################################################
# optional args:

optional.add_argument("--x_name", type=str, default="x",
                      help="optional, x-value name")
optional.add_argument("--y_name", type=str, default="y",
                      help="optional, y-value name")
optional.add_argument("--plot_regression_line", action="store_true",
                      help="optional, add regression line to jointplot")
optional.add_argument("--plot_pc1",
                      help="plot the direction of the first principal component",
                      action='store_true')

# Axis bounds; any bound left unset stays None.
optional.add_argument("--xmin", type=float,
                      help="optional, x-axis minimum")
optional.add_argument("--xmax", type=float,
                      help="optional, x-axis maximum")
optional.add_argument("--ymin", type=float,
                      help="optional, y-axis minimum")
optional.add_argument("--ymax", type=float,
                      help="optional, y-axis maximum")

optional.add_argument("--no_header",
                      help="if file does not contain a header, then indicate with flag",
                      action='store_true')
# Help text previously duplicated --no_header's; it now describes what the flag does.
optional.add_argument("--colorbar",
                      help="optional, add a colorbar for the shaded KDE in the joint plot",
                      action='store_true')

##################################################
parser._action_groups.append(optional)
args = parser.parse_args()

# Both inputs are tab-separated tables whose first column is used as the row
# index (e.g. a site ID). With --no_header the first line is treated as data.
header = None if args.no_header else 'infer'

x_df = pd.read_csv(args.x_df, sep="\t", index_col=0, header=header)
y_df = pd.read_csv(args.y_df, sep="\t", index_col=0, header=header)

# Keep only row labels present in both tables, in sorted order, so that
# x[i] and y[i] come from the same row label.
common_indices = sorted(set(x_df.index) & set(y_df.index))
if not common_indices:
    parser.error("no common row labels between %s and %s" % (args.x_df, args.y_df))

# Label-based selection (.loc replaces the removed .ix indexer); every label in
# common_indices is known to exist in both frames.
x = x_df.loc[common_indices].values.flatten()
y = y_df.loc[common_indices].values.flatten()

data = pd.DataFrame({args.x_name: x, args.y_name: y})

# Unset bounds are None (argparse default). Test against None explicitly:
# a truthiness test would wrongly discard a user-supplied bound of 0.
xmin, xmax = args.xmin, args.xmax
ymin, ymax = args.ymin, args.ymax

# Pass a limits tuple only if at least one bound was given; a None entry in
# the tuple leaves that side to matplotlib's autoscaling.
xlim = (xmin, xmax) if xmin is not None or xmax is not None else None
ylim = (ymin, ymax) if ymin is not None or ymax is not None else None


# jointplot creates (and makes current) its own figure, exposed as g.fig;
# everything below draws onto g's axes rather than a separately created figure.
g = sns.jointplot(x=args.x_name, y=args.y_name, data=data,
                  xlim=xlim, ylim=ylim, stat_func=None,
                  kind="kde", space=0, color="black")

# Shaded KDE overlaid on the joint axes using a light-to-dark cubehelix map.
cmap = sns.cubehelix_palette(start=0.3, rot=-.5, dark=0, light=1, reverse=True, as_cmap=True)
if args.colorbar:
    cax = g.fig.add_axes([.85, .25, .05, .4])  # x, y, width, height (figure fraction)
    sns.kdeplot(data[args.x_name], data[args.y_name], cmap=cmap,
                n_levels=60, shade=True, ax=g.ax_joint, cbar=True, cbar_ax=cax)
    # Relabel the colorbar with 4 evenly spaced ticks. The value at the top of
    # the bar is extrapolated from the last existing tick label / position.
    # NOTE(review): this reads tick-label text before the figure is drawn and
    # assumes colorbar ticks sit in normalized [0, 1] positions, as in older
    # matplotlib releases; newer versions may return empty text here — confirm.
    cax_yticklabel_max = float(cax.get_yticklabels()[-1].get_text())/cax.get_yticks()[-1]
    cax.yaxis.set_ticks(np.linspace(0,1,4))
    cax.set_yticklabels(np.linspace(0,cax_yticklabel_max,4))
else:
    sns.kdeplot(data[args.x_name], data[args.y_name], cmap=cmap,
                n_levels=60, shade=True, ax=g.ax_joint)

if args.plot_regression_line:
    sns.regplot(x=args.x_name, y=args.y_name, data=data,
                ax=g.ax_joint, scatter=False)

if args.plot_pc1:
    # First principal component of the centred (x, y) points, drawn as a
    # segment through the mean spanning the extent of the projected data.
    xy = data[[args.x_name, args.y_name]].values.astype(float)
    center = xy.mean(axis=0)
    _, _, vt = np.linalg.svd(xy - center, full_matrices=False)
    pc1 = vt[0]
    proj = np.dot(xy - center, pc1)
    ends = center + np.outer([proj.min(), proj.max()], pc1)
    g.ax_joint.plot(ends[:, 0], ends[:, 1], color="red", linestyle="--")

# Label the joint axes of the jointplot itself (previously the labels went to
# an unrelated, never-saved figure).
g.set_axis_labels(args.x_name, args.y_name)
g.ax_marg_x.xaxis.set_ticks_position('none')
g.ax_marg_y.yaxis.set_ticks_position('none')

# Save the jointplot's figure explicitly rather than relying on "current figure".
g.fig.savefig(args.outplot)