#!/usr/bin/env python
import numpy as np
from pybedtools import BedTool
from scipy.stats import mannwhitneyu,fisher_exact
from reddylab_utils.reddylab_pybedtools_utils import pandas_dataframe_to_Bedtool,extract_field
import json
import argparse

# Command-line interface.  RawTextHelpFormatter keeps the hand-wrapped help
# text below exactly as written instead of re-flowing it.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""

This script considers proximity of sites of interest and target
sites by "overlap", which is understood as presence/absence within some window 
around the target using Fisher's Exact Test (by default, 10 Kb). A heatmap will 
be plotted in which each cell will reflect log odds where increased odds means 
increased proximity/overlap. If the log odds in a given cell is not significant 
(i.e. p >= 0.05), then it will be labeled as such ("N.S."). An alternative method is 
available as well, see --mann_whitney option.

""")

##################################################
# required args:

parser.add_argument("--query_sites_of_interest_list", nargs='+', type=str,
                    help="""required, file paths to beds of query sites of interest, e.g.:
/path/to/site_type1.bed /path/to/site_type2.bed
""", dest="query_sites_of_interest_list", required=True)

parser.add_argument("--target_sites_of_interest_list", nargs='+', type=str,
                    help="""required, file paths to beds of target sites of interest, e.g.:
/path/to/site_type1.bed /path/to/site_type2.bed
""", dest="target_sites_of_interest_list", required=True)

# Names are paired positionally with the corresponding list of bed files.
parser.add_argument('--query_set_names', dest='query_set_names', nargs='+', type=str,
                    help='required, names corresponding (in order) to list of query sites of interest',
                    required=True)
parser.add_argument('--target_set_names', dest='target_set_names', nargs='+', type=str,
                    help='required, names corresponding (in order) to list of target sites of interest',
                    required=True)

# Both output paths are mandatory (required=True), so the help text says so.
parser.add_argument('--out_ratios', dest='out_ratios', type=str,
                    help='required, file to which to save log2 odds ratio dataframe', required=True)
parser.add_argument('--out_pvalues', dest='out_pvalues', type=str,
                    help='required, file to which to save p-values dataframe', required=True)

###################################################
# optional args:

parser.add_argument("--null_sites", type=str,
                    help="""optional, file path to bed of null (bkg.) sites of the format:
                    
chrX	1234	2345	siteX

if not provided, background is the union of all sets of query sites (vs.
particular set of query sites under consideration).

""", dest="null_sites")

parser.add_argument("-d", '--dist', dest='dist', type=int, default=10000,
                    help="""optional, "dist" is distance 
to be considered for elements to be "overlapping" (default: %(default)s).
No effect if --mann_whitney indicated.
""")

parser.add_argument("--mann_whitney", 
                    help="""
                    
Consider proximity of sites of interest and target sites by 
comparing distributions of distance of sites of interest to 
nearest target site versus distance of sites of interest to
nearest null site using a Mann-Whitney Test. Plotted heatmap
cells in this case will reflect 
log(median dist to null / median dist to target).

""", action='store_true')

# NOTE(review): --ordinal is accepted, but nothing in the visible script reads
# args.ordinal; its help text is still a placeholder.
parser.add_argument("--ordinal", 
                    help="""
                    
if --ordinal indicated....

""", action='store_true')

parser.add_argument("--group_size_threshold", dest='group_size_threshold',
                    type=int, default=0,
                    help="""
Only include comparisons for groups with greater than --group_size_threshold members (default: %(default)s)
""")

###################################################

# Parse the command line once at start-up; every later section of the script
# reads its settings from this namespace.
args = parser.parse_args()

###################################################

def extract_dist_from_BedTool(bed):
    """Return the non-negative distances stored in the last field of *bed*.

    The last field of every record is pulled out with ``extract_field`` and
    cast to integers.  Records whose distance is -1 are dropped: per the
    original author's note these are sites with no matching site on the same
    chromosome, assumed rare enough that excluding them does not affect the
    analysis.

    NOTE(review): ``extract_field`` is a project helper; it presumably returns
    a NumPy array (``astype`` and boolean-mask indexing are applied to its
    result) -- confirm against its definition.
    """
    last_field = extract_field(bed, -1).astype('int')
    has_match = last_field > -1
    return last_field[has_match]

# Each set name is paired positionally with a bed file further down via zip(),
# which silently drops unmatched entries.  Refuse mismatched counts up front
# so that no site set is discarded without the user noticing.
if len(args.query_set_names) != len(args.query_sites_of_interest_list):
    parser.error("--query_set_names must give exactly one name per file in "
                 "--query_sites_of_interest_list")
if len(args.target_set_names) != len(args.target_sites_of_interest_list):
    parser.error("--target_set_names must give exactly one name per file in "
                 "--target_sites_of_interest_list")

# read in data
query_sites_of_interest_list = [BedTool(bed_path) for bed_path in args.query_sites_of_interest_list]
target_sites_of_interest_list = [BedTool(bed_path) for bed_path in args.target_sites_of_interest_list]

# if null_sites not given, then take a union across all query sites
# (built from every query set, before the group-size filter below is applied)
if args.null_sites:
    null_sites = BedTool(args.null_sites)
else:
    # nargs='+' guarantees at least one query file, so [0] always exists
    query_sites_all = query_sites_of_interest_list[0]
    for additional_query_sites in query_sites_of_interest_list[1:]:
        query_sites_all = query_sites_all.cat(additional_query_sites)


def _filter_by_group_size(site_sets, set_names, threshold):
    """Keep only the site sets with more than *threshold* members.

    Returns a pair ``(kept_site_sets, kept_names)`` in the original order.
    """
    kept_pairs = [(sites, name) for sites, name in zip(site_sets, set_names)
                  if len(sites) > threshold]
    return [sites for sites, _ in kept_pairs], [name for _, name in kept_pairs]


# check that query and target sites are larger in size
# (number of group members) than group_size_threshold
query_sites_of_interest_list, query_set_names = _filter_by_group_size(
    query_sites_of_interest_list, args.query_set_names, args.group_size_threshold)
target_sites_of_interest_list, target_set_names = _filter_by_group_size(
    target_sites_of_interest_list, args.target_set_names, args.group_size_threshold)

# Run analysis
#
# For every (query set, target set) pair, fill in one cell of:
#   ratio_array -- Fisher mode: odds ratio of lying within --dist of a target
#                  (query sites vs. background sites);
#                  Mann-Whitney mode: median(dist to target from background)
#                  / median(dist to target from query), each shifted by +1.
#   p_array     -- the two-sided p-value of the corresponding test.
# Rows follow query_sites_of_interest_list, columns follow
# target_sites_of_interest_list.


def _distances_to_nearest(sorted_sites, sorted_targets):
    """Return the distance from each site to its nearest target.

    Both arguments must already be reduced to three columns and sorted.
    ``t="first"`` makes ``bedtools closest`` report a single target per site;
    by default every tied target is reported, which would count one site
    several times (e.g. a site overlapping two targets).
    """
    nearest = sorted_sites.closest(sorted_targets, d=True, io=False, t="first")
    return extract_dist_from_BedTool(nearest)


ratio_array = np.zeros((len(query_sites_of_interest_list), len(target_sites_of_interest_list)))
p_array = np.zeros((len(query_sites_of_interest_list), len(target_sites_of_interest_list)))

# Cutting/sorting does not depend on the target, so do it once per query set
# (and once for a user-supplied background) instead of once per pair.
sorted_query_sets = [query_sites.cut([0, 1, 2]).sort()
                     for query_sites in query_sites_of_interest_list]
if args.null_sites:
    sorted_user_null = null_sites.cut([0, 1, 2]).sort()

# for each target site
for j, target_sites_of_interest in enumerate(target_sites_of_interest_list):

    sorted_targets = target_sites_of_interest.cut([0, 1, 2]).sort()

    # compute distance of background to nearest target site
    if args.null_sites:
        sorted_null = sorted_user_null
    else:
        # no background supplied: use the pooled query sites that do not
        # overlap this target set
        null_sites = query_sites_all.intersect(target_sites_of_interest, v=True)
        sorted_null = null_sites.cut([0, 1, 2]).sort()
    dists_to_null = _distances_to_nearest(sorted_null, sorted_targets)

    for i, sorted_query in enumerate(sorted_query_sets):

        # compute distance of query to nearest target site
        dists_to_sig = _distances_to_nearest(sorted_query, sorted_targets)

        if not args.mann_whitney:  # then run fisher's exact

            # 2x2 table: rows = query / background, columns = within / beyond
            # --dist.  "Within" is inclusive so that a site at exactly --dist
            # lands in one of the two cells instead of being dropped.
            a = len(dists_to_sig[dists_to_sig <= args.dist])
            b = len(dists_to_sig[dists_to_sig > args.dist])
            c = len(dists_to_null[dists_to_null <= args.dist])
            d = len(dists_to_null[dists_to_null > args.dist])

            # don't want to divide by zero for OR
            b_pseudo = 1 if b == 0 else b
            c_pseudo = 1 if c == 0 else c

            OR = float(a * d) / float(b_pseudo * c_pseudo)

            _, p = fisher_exact([[a, b], [c, d]])

        else:

            # +1 keeps the ratio finite when a median distance is 0
            OR = np.median(dists_to_null + 1) / np.median(dists_to_sig + 1)

            # State the alternative explicitly: the default has differed
            # between SciPy releases, and fisher_exact above is two-sided.
            U, p = mannwhitneyu(dists_to_sig, dists_to_null, alternative='two-sided')

        ratio_array[i, j] = OR
        p_array[i, j] = p

# Report ratios on a log2 scale.  Exact zeros are first replaced by a tiny
# positive value so that np.log2 never receives 0.
zero_ratio_mask = ratio_array == 0
ratio_array[zero_ratio_mask] = 1e-10
ratio_array = np.log2(ratio_array)

# Write each matrix as a tab-separated table whose rows are labelled with the
# query set names and whose columns are labelled with the target set names.
# A matrix is skipped when its output path is empty.
for out_path, matrix in ((args.out_ratios, ratio_array), (args.out_pvalues, p_array)):
    if not out_path:
        continue
    import pandas as pd
    table = pd.DataFrame(matrix)
    table.index = query_set_names
    table.columns = target_set_names
    table.to_csv(out_path, sep="\t", index=True, header=True)
