#!/usr/bin/env python
import pandas as pd
import numpy as np
import argparse
import matplotlib
matplotlib.use('Agg')
font = {'size'   : 8}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt
import seaborn.apionly as sns
from reddylab_utils.reddylab_plotting_utils import discrete_cmap,adjust_spines

# Command-line interface. RawTextHelpFormatter keeps the line breaks of the
# description below when --help is printed.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""

box_plot_by_lists.py

Given lists of values, plot a box plot for each list.

""")

# Pop the parser's last built-in argument group (the optional arguments) so
# that the "required arguments" group created next is listed before it in
# --help; the popped group is re-appended once all arguments are declared.
# NOTE(review): this relies on argparse's private `_action_groups` attribute.
optional = parser._action_groups.pop()
required = parser.add_argument_group('required arguments')

##################################################
# required args

required.add_argument("-i", "--lists", nargs="+", required=True,
                      help="required, files holding one numeric value per line (one file per list)")
required.add_argument("--names", nargs="+", required=True,
                      help="required, names of lists (one per list, in the same order)")
required.add_argument("-o", "--outplot", required=True,
                      help="required, output plot")


##################################################
# optional args
# parser.add_argument("--cmap", help="optional, cmap for violin(s)", default="rainbow")

optional.add_argument("--palette", default="muted",
                      help="optional, seaborn palette for box(es)")
optional.add_argument("--ymin", type=float, help="optional, ymin")
optional.add_argument("--ymax", type=float, help="optional, ymax")
optional.add_argument("--ylabel", type=str, default="", help="optional, ylabel")
optional.add_argument("--xlabel", type=str, default="", help="optional, xlabel")
optional.add_argument("--orient", type=str, default="v",
                      help="optional, orientation")
optional.add_argument("--add_regression_line", action="store_true",
                      help="optional, overlay a linear regression of the values on list position (0, 1, 2, ...)")
optional.add_argument("--simple_aesthetics", action="store_true",
                      help="optional, draw unfilled boxes with black outlines (ignores --palette)")

##################################################
# Put the optional group back so it is printed after the required group.
parser._action_groups.append(optional)
args = parser.parse_args()

# lists = ["auxiliary/TFCP2L1.FFL_DEG_type.independent_ambiguous_dn.continuous_motif_score.rsat_clustered.GR.nearest_non_promoter_EP300.within_100kb.cut.txt",
# "auxiliary/TFCP2L1.FFL_DEG_type.independent_ambiguous_up.continuous_motif_score.rsat_clustered.GR.nearest_non_promoter_EP300.within_100kb.cut.txt",
# "auxiliary/TFCP2L1.FFL_DEG_type.independent_dn.continuous_motif_score.rsat_clustered.GR.nearest_non_promoter_EP300.within_100kb.cut.txt",
# "auxiliary/TFCP2L1.FFL_DEG_type.independent_up.continuous_motif_score.rsat_clustered.GR.nearest_non_promoter_EP300.within_100kb.cut.txt"]
# names = ["independent_ambiguous_dn","independent_ambiguous_up", "independent_dn","independent_up"]
# ymin = 0
# ymax = 10
# outplot= "plots/test.png"

# Load the numeric values of every input file: one file per list, one value
# per line.
list_of_lists = []
for path in args.lists:
    with open(path, "r") as handle:
        # Skip blank lines (e.g. a trailing empty line) rather than crashing
        # on float(""); float() itself tolerates surrounding whitespace.
        list_of_lists.append([float(line) for line in handle if line.strip()])

# colors = discrete_cmap(len(list_of_lists), args.cmap)

orig_names = args.names

# zip() below would silently drop the surplus lists (or names) when the two
# counts differ, so refuse to continue with inconsistent arguments.
if len(orig_names) != len(list_of_lists):
    parser.error("--names expects one name per list: got %d name(s) for %d list(s)"
                 % (len(orig_names), len(list_of_lists)))

# Long-format table for seaborn: one row per value, 'x' holding the name of
# the list the value came from and 'y' the value itself.
names = [name
         for name, values in zip(orig_names, list_of_lists)
         for _ in values]
elements = [value for values in list_of_lists for value in values]

data = pd.DataFrame({'x': names, 'y': elements})

# Figure size: for the vertical layout the width grows by one inch per list.
# Any other --orient value falls back to matplotlib's default size
# (figsize=None); previously `figsize` was left undefined in that case and
# plt.subplots() failed with a NameError.
if args.orient == 'v':
    figsize = (len(orig_names), 4)
else:
#     figsize = (len(orig_names), len(orig_names))
    figsize = None

fig, ax = plt.subplots(figsize=figsize)

# ax = sns.violinplot(x="x", y="y", ax=ax, cut=0,
#                     data=data, palette=args.palette,
#                     orient=args.orient)

# order=orig_names pins every list to the x position given by its index in
# --names, even when a list has no values, so boxes stay aligned with the
# tick labels that are later set from orig_names.
if not args.simple_aesthetics:
    ax = sns.boxplot(x="x", y="y", ax=ax, width=0.8,
                     flierprops={"marker": "."},
                     data=data, palette=args.palette,
                     order=orig_names)
else:
    ax = sns.boxplot(x="x", y="y", ax=ax, width=0.8,
                     flierprops={"marker": ".", "markeredgecolor": None},
                     data=data, color="white",
                     order=orig_names)
    # Restyle every box as an unfilled black outline.
    # NOTE(review): assumes the boxes are exposed through ax.artists, which
    # depends on the seaborn/matplotlib version in use -- confirm.
    for i, box in enumerate(ax.artists):
        box.set_edgecolor('black')
        box.set_facecolor('None')

        # Each box has 6 associated Line2D objects (to make the whiskers, fliers, etc.)
        # Recolour them to match the outline. Slicing (instead of indexing)
        # simply yields fewer lines if a box has less than 6 of them.
        for line in ax.lines[i * 6:i * 6 + 6]:
            line.set_color('black')
            line.set_mfc('black')
            line.set_mec('black')

if args.add_regression_line:
    # Least-squares fit of the values against the position of their list
    # (index of the list's name in --names), drawn as a red line.
    from scipy.stats import linregress
    names_to_num = {name: i for i, name in enumerate(orig_names)}
    x_positions = [names_to_num[x] for x in data["x"]]
    slope, intercept, r_value, p_value, std_err = linregress(x_positions, data["y"])
    # Single-argument print() with %-formatting emits the same space-separated
    # text under Python 2 and Python 3 (the former `print a, b` statements
    # are a SyntaxError on Python 3).
    print("%s %s %s %s %s" % (slope, intercept, r_value, p_value, std_err))
    xmin = min(names_to_num.values())
    xmax = max(names_to_num.values())
    # Endpoints of the fitted line; named separately from the axis limits
    # computed below so the two are not confused.
    line_ymin = slope * xmin + intercept
    line_ymax = slope * xmax + intercept
    print("%s %s" % ([xmin, xmax], [line_ymin, line_ymax]))
    ax.plot([xmin, xmax], [line_ymin, line_ymax], color='red', lw=1)

# Project helper applied to the left and bottom spines.
# NOTE(review): presumably hides the other spines -- confirm in reddylab_utils.
adjust_spines(ax, ['left', 'bottom'])

# Use the user-supplied y limits when given (argparse leaves them as None
# otherwise); fall back to the limits chosen by the plot.
ymin = args.ymin if args.ymin is not None else ax.get_ylim()[0]
ymax = args.ymax if args.ymax is not None else ax.get_ylim()[1]

ax.set_ylim((ymin, ymax))
ax.set_xticklabels(orig_names, rotation=45, ha="right")
ax.set_ylabel(args.ylabel)
ax.set_xlabel(args.xlabel)
plt.savefig(args.outplot, bbox_inches='tight')