#!/usr/bin/env python
import pandas as pd
import numpy as np
import argparse
import matplotlib.pyplot as plt
import seaborn.apionly as sns
# from reddylab_utils.reddylab_plotting_utils import discrete_cmap,adjust_spines

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

kde_plot_by_lists.py

Given lists of values, plot a kde for each list.

""")

# Temporarily remove argparse's default "optional arguments" group so the
# "required arguments" group is listed first in --help; it is re-appended
# just before parsing.
optional = parser._action_groups.pop()
required = parser.add_argument_group('required arguments')

##################################################
# required args

required.add_argument("--lists", nargs="+", help="required, lists of values",  required=True)
required.add_argument("--names", nargs="+", help="required, names of lists",  required=True)
required.add_argument("--outplot", help="required, output plot", required=True)

##################################################
# optional args

# default=None (not a palette name string) so that "no --colors given" is
# falsy and the script falls back to its built-in tableau palette.
optional.add_argument("--colors", nargs="+",
                      help="optional, colors for the kde curves (one per list, "
                           "cycled if fewer); defaults to the tableau palette",
                      default=None)
optional.add_argument("--xmin", help="optional, xmin", type=float)
optional.add_argument("--xmax", help="optional, xmax", type=float)
optional.add_argument("--ylabel", help="optional, ylabel", type=str, default="")
optional.add_argument("--xlabel", help="optional, xlabel", type=str, default="")

optional.add_argument("--shade", help="optional, shade the area under each kde curve",
                      action="store_true")
optional.add_argument("--add_medians", help="optional, draw a vertical line at each list's median",
                      action="store_true")

##################################################
parser._action_groups.append(optional)
args = parser.parse_args()

# The plotting loop looks up a legend label in args.names for every list, so
# fewer names than lists used to crash later with an IndexError; report it here.
if len(args.names) < len(args.lists):
    parser.error("--names needs one name per --lists entry "
                 "(got %d names for %d lists)" % (len(args.names), len(args.lists)))

####################
# read lists
# Each file given to --lists is read as one numeric value per line. Blank
# lines (e.g. a trailing empty line) are skipped rather than raising ValueError.
list_of_lists = []
for path in args.lists:
    with open(path, "r") as f:
        values = [float(line.strip()) for line in f if line.strip()]
    if not values:
        # A KDE of zero points cannot be computed; fail with a clear message.
        parser.error("no values found in %s" % path)
    list_of_lists.append(values)

####################
# define plot attributes
# Tableau 20 palette as RGB tuples scaled to [0, 1]; the even-indexed entries
# form the darker "tableau 10" subset used when there are few lists.
tableau20 = [(31/255., 119/255., 180/255.),
             (174/255., 199/255., 232/255.),
             (255/255., 127/255., 14/255.),
             (255/255., 187/255., 120/255.),
             (44/255., 160/255., 44/255.),
             (152/255., 223/255., 138/255.),
             (214/255., 39/255., 40/255.),
             (255/255., 152/255., 150/255.),
             (148/255., 103/255., 189/255.),
             (197/255., 176/255., 213/255.),
             (140/255., 86/255., 75/255.),
             (196/255., 156/255., 148/255.),
             (227/255., 119/255., 194/255.),
             (247/255., 182/255., 210/255.),
             (127/255., 127/255., 127/255.),
             (199/255., 199/255., 199/255.),
             (188/255., 189/255., 34/255.),
             (219/255., 219/255., 141/255.),
             (23/255., 190/255., 207/255.),
             (158/255., 218/255., 229/255.)]
tableau10 = tableau20[::2]

n_lists = len(list_of_lists)
if args.colors:
    base_colors = args.colors
elif n_lists <= 10:
    base_colors = tableau10
else:
    base_colors = tableau20

# Exactly one color per list. The base palette is cycled so that more than 20
# lists, or fewer --colors than lists, no longer raises IndexError when plotting.
colors = [base_colors[i % len(base_colors)] for i in range(n_lists)]

# Small fonts to suit the compact 3x2 inch figure created below.
sns.set(font_scale = 0.5,
        style="ticks")

####################
# plot
# One kde curve per input list on a shared axis, optionally with a vertical
# median line per list, then the figure is saved to --outplot.

fig, ax = plt.subplots(figsize=(3, 2))
for i, values in enumerate(list_of_lists):
    # Cycle colors so a palette shorter than the number of lists cannot
    # raise IndexError.
    color = colors[i % len(colors)]
    sns.kdeplot(values, ax=ax, color=color, label=args.names[i],
                ls="-", lw=1, shade=args.shade)
    if args.add_medians:
        ax.axvline(x=np.median(values), ls="-", lw=1, color=color)

# Test against None rather than truthiness so that an explicit bound of 0
# (e.g. --xmin 0) is honoured instead of silently ignored.
if args.xmin is not None or args.xmax is not None:
    lo, hi = ax.get_xlim()
    if args.xmin is not None:
        lo = args.xmin
    if args.xmax is not None:
        hi = args.xmax
    ax.set_xlim((lo, hi))

sns.despine()
plt.savefig(args.outplot)