#!/usr/bin/env python
import numpy as np
from pybedtools import BedTool
from scipy.stats import fisher_exact
from sys import argv
import brewer2mpl
import os
import matplotlib
matplotlib.use('Agg')
font = {'size'   : 8}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt
from reddylab_utils.reddylab_pybedtools_utils import compute_overlap_between_two_beds
from reddylab_utils.reddylab_plotting_utils import color_blind_color_converter,simple_bar_with_err_plot,remove_spines
import argparse

# Command-line interface.  RawTextHelpFormatter keeps the line breaks written
# in the description/help strings below.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
Compute the overlap between each query bed file and a target bed file,
save a bar plot of the overlap percentages to --outplot, and print the
percentages plus pairwise Fisher's Exact Test p-values to stdout.
""")

##################################################
# required args:

parser.add_argument("-q", "--query_bed_files", nargs='+',
                    help="""required, file paths to query bed file(s)
/path/to/query_bed1.bed /path/to/query_bed2.bed

""", required=True)
parser.add_argument("-t", "--target_bed_file",
                    help="""required, file paths to target bed file
/path/to/sites.bed

""", required=True)
parser.add_argument("--outplot", help="required, path to outplot", required=True)

##################################################
# optional args:

parser.add_argument("--target_name",  help="optional, name for --target_bed_file", default="target sites")
parser.add_argument("--query_names",  nargs='+', help="optional, name(s) for --query_bed_files")
parser.add_argument("--colors",  nargs='+', help="optional, colors for --query_bed_files")
parser.add_argument("--target_to_query",  help="""optional, instead of computing overlap of each query to target,
compute overlap of target to query (essentially, this just changes the denominator)""", action="store_true")

##################################################

args = parser.parse_args()

# Names are paired with files positionally further down (via zip and by
# index), so a count mismatch would either silently drop query files or fail
# with an IndexError after the plot has already been written.  Reject it here.
if args.query_names and len(args.query_names) != len(args.query_bed_files):
    parser.error("--query_names needs exactly one name per --query_bed_files entry "
                 "(got %d name(s) for %d file(s))"
                 % (len(args.query_names), len(args.query_bed_files)))

##################################################

# Load the inputs as BedTool objects.
target_bed_file = BedTool(args.target_bed_file)

# A query path that is not an existing file is not an error: it is replaced
# by an empty BedTool so the remaining queries are still processed.
query_bed_files = []
for query_bed_file in args.query_bed_files:
    if os.path.isfile(query_bed_file):
        query_bed_files.append(BedTool(query_bed_file))
    else:
        query_bed_files.append(BedTool([]))

# Without --query_names, queries are labelled by their positional index.
if args.query_names:
    query_names = args.query_names
else:
    query_names = range(len(query_bed_files))

# Compute the overlap for every query, collecting results in query order:
#   overlaps_y    -- value plotted as bar height and printed as a percentage
#   overlaps_err  -- value passed to the plot helper as the error bar
#   overlaps_nums -- (overlap, N) pairs, later used to build Fisher tables
# NOTE(review): presumably N is the interval count of the first bed passed to
# compute_overlap_between_two_beds and overlap the number of those that hit
# the second bed -- confirm against the helper.
overlaps_y, overlaps_err, overlaps_nums = [], [], []
for name, query_bed_file in zip(query_names, query_bed_files):
    # --target_to_query swaps the argument order (target first, query second).
    if args.target_to_query:
        bed_pair = (target_bed_file, query_bed_file)
    else:
        bed_pair = (query_bed_file, target_bed_file)
    y, err, overlap, N = compute_overlap_between_two_beds(*bed_pair)

    overlaps_y.append(y)
    overlaps_err.append(err)
    overlaps_nums.append((overlap, N))

#########################

# Draw the overlap bar plot and save it to --outplot.
# Figure width grows by 0.5 per query, with a minimum of 4; height is fixed.
fig, ax = plt.subplots(figsize=(max(4, len(query_names) * 0.5), 6))

# User-supplied colors go through the project's color converter; otherwise
# take evenly spaced colors from the rainbow colormap, one per query.
if args.colors:
    colors = [color_blind_color_converter(color) for color in args.colors]
else:
    colors = [plt.cm.rainbow(x) for x in np.linspace(0, 1, len(query_names))]

# plot bar plot
simple_bar_with_err_plot(ax, overlaps_y, overlaps_err, colors)
remove_spines(ax, ['top', 'right'])
ax.spines['left'].set_position(('outward', 10))

# Only label the x axis when the user named the queries.
# NOTE(review): tick positions 0.25, 0.75, ... presumably match the bar
# centers used by simple_bar_with_err_plot -- confirm against the helper.
if args.query_names:
    ax.set_xticks(np.arange(len(query_names)) * 0.5 + 0.5 - 0.25)
    ax.set_xticklabels(query_names, rotation=45, ha="right")
    ax.xaxis.set_ticks_position('none')

if not args.target_to_query:
    ylabel = "Percentage of sites overlapping %s"%(args.target_name)
else:
    ylabel = "Percentage of %s sites overlapping with x dataset"%(args.target_name)

# Break long labels (80+ characters) into two lines at the middle word.
# Floor division (//) keeps the slice index an integer regardless of whether
# "/" means integer or true division in the running interpreter.
if len(ylabel) >= 80:
    ylabel_words = ylabel.split(" ")
    half = len(ylabel_words) // 2
    ylabel = " ".join(ylabel_words[:half]) + "\n" + " ".join(ylabel_words[half:])

ax.set_ylabel(ylabel, fontsize=8)
# Anchor the y axis at zero and add a little room on the right of the x axis.
ax.set_ylim((0, ax.get_ylim()[1]))
ax.set_xlim((ax.get_xlim()[0], ax.get_xlim()[1] + 0.02))
plt.tight_layout()
plt.savefig(args.outplot)

##########
# Print percentage overlap and Fisher's Exact test pairwise comparisons
##########

# The print calls below use the single-argument parenthesised form, which
# produces the same output under the Python 2 print statement and the
# Python 3 print function.

# One line per query with its overlap percentage; the wording follows the
# direction of the comparison selected by --target_to_query.
for i, query_name in enumerate(query_names):
    overlap_y = overlaps_y[i]
    if not args.target_to_query:
        print("Percentage of %s overlapping %s = %0.2f%%"%(query_name, args.target_name, overlap_y))
    else:
        print("Percentage of %s overlapping %s = %0.2f%%"%(args.target_name, query_name, overlap_y))

# Blank separator line between the two sections.
print("")

# Pairwise comparison of every query against each earlier query, using a
# 2x2 table of [overlap, N - overlap] counts per query.
for i, query_name1 in enumerate(query_names):
    overlap_1, N_1 = overlaps_nums[i]
    for j in range(i):
        query_name2 = query_names[j]
        overlap_2, N_2 = overlaps_nums[j]
        _, p = fisher_exact([[overlap_1, N_1 - overlap_1], [overlap_2, N_2 - overlap_2]])
        # The message must describe the same direction as the percentage
        # lines above: by default the queries are compared for overlap with
        # the target; with --target_to_query the target's overlap with each
        # query is compared.  (The two messages were previously attached to
        # the opposite branches.)
        if args.target_to_query:
            print("Fisher's Exact Test, %s overlapping %s vs. %s, p = %0.3e"%(args.target_name, query_name1, query_name2, p))
        else:
            print("Fisher's Exact Test, %s vs. %s, overlapping %s, p = %0.3e"%(query_name1, query_name2, args.target_name, p))


