#!/bin/env python
import pandas as pd
import numpy as np
from pybedtools import BedTool
from GGR import pybedtools_utils
from sys import argv

# Positional command-line arguments:
#   argv[1]  FIMO hit table (tab-separated, no header row)
#   argv[2]  BED file of NR sites to classify
#   argv[3]  comma-separated list of BED files to intersect with the NR sites
#   argv[4]  output path for the per-quintile fraction table
if len(argv) < 5:
    # Fail with a usage line rather than a bare IndexError traceback.
    raise SystemExit(
        "usage: %s FIMO_TXT NR_BED OTHER_BED[,OTHER_BED...] OUT_TXT" % argv[0])
fimo = argv[1]
NR_bed = argv[2]
other_beds = argv[3].split(",")
out = argv[4]

# Example argument values for running this script interactively (assign these
# instead of reading argv). The `genome` paths are not referenced anywhere below.
# fimo = "AtT20/AtT20.GR.dex.20m.fimo_best.txt"
# NR_bed = "AtT20/AtT20.GR.dex.20m.bed"
# other_beds = "AtT20/AtT20.EP300.t00.bed".split(",")
# # genome = "/data/reddylab/Reference_Data/Genomes/mm10/GRCm38.sizes"
# # genome = "/data/reddylab/projects/GGR/auxiliary/hg38.sizes"
# out = "AtT20/AtT20.GR.dex.20m.motif_quintiles.by_TF_intersections.txt"

# Read the FIMO hit table: tab-separated, no header row, columns named here.
names = ["name", "pval", "start", "end", "strand", "match"]
fimo = pd.read_csv(fimo, names=names, header=None, sep="\t")

# Load the NR sites and the companion BED files.
NR_bed = BedTool(NR_bed)
other_beds = [BedTool(bed) for bed in other_beds]

# Partition the NR sites by their overlap with each companion BED file.
# A key holds one flag per companion file, in command-line order, joined by
# commas: '+' = overlaps that file, '-' = does not.  With two files, '+,-'
# means "overlaps the first, not the second"; with one file the keys are
# just '+' and '-'.  Any number of files is supported (2**N categories);
# previously only one or two files were handled and other counts left
# sites_dict silently empty.
from itertools import product

sites_dict = {}
for flags in product('-+', repeat=len(other_beds)):
    subset = NR_bed
    for flag, bed in zip(flags, other_beds):
        # bedtools intersect: u=True keeps features with any overlap in
        # `bed`, v=True keeps features with none.
        if flag == '+':
            subset = subset.intersect(bed, u=True)
        else:
            subset = subset.intersect(bed, v=True)
    # NOTE(review): field -1 is presumably the site-name column that matches
    # FIMO's 'name' column -- confirm against pybedtools_utils.extract_field.
    sites_dict[",".join(flags)] = pybedtools_utils.extract_field(subset, -1)

# All FIMO p-values in ascending order (most significant first).
pvals = np.sort(np.asarray(fimo['pval']))

# Quintile edges in decreasing order: 1 (the largest possible p-value),
# followed by the smallest p-value in each fifth of the hits taken from
# least to most significant.  Consecutive entries bound one quintile.
# array_split would yield empty chunks (and x[-1] an IndexError) with
# fewer than 5 hits, so fail clearly instead.
if len(pvals) < 5:
    raise SystemExit("need at least 5 FIMO hits to form quintiles, got %d"
                     % len(pvals))
cutoffs = [1] + [x[-1] for x in np.array_split(pvals[::-1], 5)]

def bin_by_quintiles(array, cutoffs):
    """Count how many values of `array` fall between consecutive cutoffs.

    `cutoffs` must be in decreasing order, e.g. ``[1, c1, c2, ..., cN]``.
    Bin ``i`` is the half-open interval ``[cutoffs[i+1], cutoffs[i])``,
    except the first bin, which also includes its upper edge ``cutoffs[0]``.
    Every value between ``cutoffs[-1]`` and ``cutoffs[0]`` (inclusive) is
    therefore counted exactly once; values outside that range are ignored.

    Parameters
    ----------
    array : array-like of float
        Values to bin (here, motif p-values).
    cutoffs : sequence of float
        Bin edges in decreasing order; yields ``len(cutoffs) - 1`` bins.

    Returns
    -------
    list of int
        Count per bin, ordered like `cutoffs` (largest values first).
    """
    values = np.asarray(array)
    nums = []
    for i, (high, low) in enumerate(zip(cutoffs[:-1], cutoffs[1:])):
        # Inclusive lower edge: each cutoff is itself a data value, so a
        # strict '>' would drop every quintile's boundary p-value.  The top
        # bin also keeps values equal to its upper edge.
        if i == 0:
            in_bin = (values >= low) & (values <= high)
        else:
            in_bin = (values >= low) & (values < high)
        nums.append(int(np.count_nonzero(in_bin)))
    return nums

# For each site category, gather the FIMO p-values of motif hits whose name
# is in that category and count how many fall into each quintile.
quintile_counts = {}
pvals_by_site_type = {}
for k in sorted(sites_dict):
    site_names = set(sites_dict[k])
    sub_pvals = fimo.loc[fimo['name'].isin(site_names), 'pval'].values
    pvals_by_site_type[k] = sub_pvals
    quintile_counts[k] = bin_by_quintiles(sub_pvals, cutoffs)

# One column per site category, one row per quintile.  Row 1 is the first
# bin below cutoffs[0] == 1, i.e. the least significant p-values.  Each
# column is divided by its own total, giving fractions; a category with no
# binned hits becomes NaN (0 / 0).
quintile_counts_df = pd.DataFrame(quintile_counts)
quintile_counts_df = quintile_counts_df / quintile_counts_df.sum(axis=0)
# Derive the row labels from the cutoffs instead of hard-coding 5 bins.
quintile_counts_df.index = range(1, len(cutoffs))
quintile_counts_df.to_csv(out, sep="\t")

# Pairwise Mann-Whitney U tests between site categories, comparing the
# -log10 p-value distributions of their motif hits.  Each category is
# compared against every category that sorts before it.
from scipy.stats import mannwhitneyu

site_types = sorted(pvals_by_site_type)
for i, first in enumerate(site_types):
    for second in site_types[:i]:
        scores_first = -np.log10(pvals_by_site_type[first])
        scores_second = -np.log10(pvals_by_site_type[second])
        # NOTE(review): no `alternative` is passed, so whether the p-value is
        # one- or two-sided depends on the installed SciPy's default, which
        # changed between releases -- confirm which is intended.
        stat, p = mannwhitneyu(scores_first, scores_second)
        print("MWU, %s sites vs. %s sites, p = %0.2e" % (first, second, p))