#!/usr/bin/env python
from pybedtools import BedTool
import pandas as pd
import numpy as np
from collections import defaultdict
from scipy.stats import fisher_exact,mannwhitneyu
from sys import argv
from GGR import pybedtools_utils
import rpy2.robjects as robjects

# Command-line arguments:
#   1: loop BEDPE file
#   2: comma-separated enhancer BED files
#   3: comma-separated group names, one per enhancer file (same order)
#   4-7: output TSV paths for log2 OR, lower CI bound, upper CI bound, p-values
if len(argv) < 8:
    raise SystemExit(
        "usage: %s LOOPS_BEDPE ENHANCER_BEDS NAMES ORS_OUT ORS_LOW_OUT "
        "ORS_HI_OUT PVALS_OUT" % argv[0])

loops = argv[1]
enhancers = argv[2].split(",")
names = argv[3].split(",")
ORs_out = argv[4]
ORs_low_out = argv[5]
ORs_hi_out = argv[6]
pvals_out = argv[7]

# Names are paired with enhancer files positionally; a length mismatch would
# otherwise be silently truncated when the two lists are zipped together.
if len(enhancers) != len(names):
    raise SystemExit(
        "expected one name per enhancer file, got %d files and %d names"
        % (len(enhancers), len(names)))

################################################

# Map each enhancer group name to a BedTool built from its BED file.
enhancer_bed_dict = {label: BedTool(bed_path)
                     for label, bed_path in zip(names, enhancers)}

# Read the loop BEDPE: six anchor columns plus an optional seventh logFC
# column. The file is read without `names` and trimmed afterwards because
# pandas turns surplus leading columns into the index when a row has more
# fields than names (e.g. a standard 10-column BEDPE). Chromosome columns are
# kept as strings so names like "1" are not parsed as integers.
loop_columns = ['chrom1', 'start1', 'end1', 'chrom2', 'start2', 'end2', 'logFC']
loops = pd.read_csv(loops, sep="\t", header=None, dtype={0: str, 3: str})
if loops.shape[1] < 6:
    raise ValueError("loop BEDPE must have at least 6 columns, found %d"
                     % loops.shape[1])
loops = loops.iloc[:, :len(loop_columns)].copy()
loops.columns = loop_columns[:loops.shape[1]]
if 'logFC' not in loops.columns:
    # Without a logFC column, fill it with NaN as before.
    loops['logFC'] = np.nan
# Row index of each loop; used as the name field of the anchor BED records.
loops['num'] = np.arange(len(loops))

################################################
# create a bed file of upstream and downstream anchors

def pandas_dataframe_to_Bedtool(df):
    """
    Convert a pandas dataframe to a bedtool object.

    df: pandas dataframe whose first four columns are chrom, chromStart,
        chromEnd and name (any further columns are ignored). Every value is
        written with str(), so non-string chromosome names are accepted.
    returns:
        bed: pybedtools BedTool object built from the tab-separated text
    """
    # .iloc is the positional indexer; the former .ix was removed in pandas 1.0.
    first_four = df.iloc[:, :4]
    lines = ['\t'.join(str(value) for value in row)
             for row in first_four.itertuples(index=False, name=None)]
    bed = BedTool('\n'.join(lines), from_string=True)
    return bed

# One BED4 record (chrom, start, end, loop number) per loop for each anchor;
# the shared loop number ties both anchors back to the same row of `loops`.
upstream_cols = ['chrom1', 'start1', 'end1', 'num']
downstream_cols = ['chrom2', 'start2', 'end2', 'num']
loop_upstream = pandas_dataframe_to_Bedtool(loops[upstream_cols])
loop_dnstream = pandas_dataframe_to_Bedtool(loops[downstream_cols])

################################################
# find the loop anchors for groups of enhancers

# Snapshot the entries first: the loop adds "<name>_upstream"/"<name>_downstream"
# keys to the same dict, which raises RuntimeError under Python 3 when the live
# keys view is being iterated.
for name, enhancer_bed in list(enhancer_bed_dict.items()):
    # wao=True keeps enhancers with no overlapping anchor as well (their loop
    # fields come back as "."); overlapping ones get the anchor's fields.
    enhancer_bed_dict[name + "_upstream"] = enhancer_bed.intersect(loop_upstream, wao=True)
    enhancer_bed_dict[name + "_downstream"] = enhancer_bed.intersect(loop_dnstream, wao=True)

################################################
# create dictionary linking each enhancer peak to its associated loop (up- or downstream)

# For each "<name>_upstream" / "<name>_downstream" intersection build:
#   enhancer_peak_loop_dicts[key]: peak name -> loop number (one loop per peak;
#       when a peak overlaps several anchors the last overlap wins)
#   enhancer_loop_peak_dicts[key]: loop number -> list of peak names
# The loop -> peaks map is filled from every overlap row rather than inverted
# from the peak -> loop map, so a peak overlapping the anchors of several loops
# is credited to all of them instead of only one.
enhancer_peak_loop_dicts = {}
enhancer_loop_peak_dicts = {}
for name in names:
    for side in ("_upstream", "_downstream"):
        key = name + side
        enhancer_df = enhancer_bed_dict[key].to_dataframe()
        # Expected layout: 4 enhancer columns, 4 anchor columns, overlap size.
        enhancer_df.columns = ['_','_','_','peak_name','_','_','_','loop_num','_']
        # Normalise loop numbers to str: the column is parsed as int when no
        # row is a "." non-overlap, and they must match the string loop ids.
        loop_nums = enhancer_df['loop_num'].astype(str)
        overlapping = (loop_nums != ".").values
        peak_names = list(enhancer_df['peak_name'].values[overlapping])
        loop_nums = list(loop_nums.values[overlapping])

        enhancer_peak_loop_dicts[key] = dict(zip(peak_names, loop_nums))

        loop_peak_d = defaultdict(list)
        for peak_name, loop_num in zip(peak_names, loop_nums):
            loop_peak_d[loop_num].append(peak_name)
        enhancer_loop_peak_dicts[key] = loop_peak_d

################################################


# Set of every loop number 0 .. len(loops)-1, stored as strings.
all_loops = {str(loop_index) for loop_index in range(len(loops))}

################################################

def fisher_exact_wrapper(set1, set2, all_set):
    """
    Fisher's exact test for the overlap of two sets within a universe.

    set1, set2: sets of items (presumably loop numbers, as in this script)
    all_set: universe that set1 and set2 are drawn from
    Prints the table counts as "A B C D" and returns the result of
    scipy.stats.fisher_exact on [[A, B], [C, D]] (odds ratio, p-value), where
        A = items in both sets, B = only in set1, C = only in set2,
        D = items of all_set in neither set.
    """
    A = len(set1 & set2)
    B = len(set1 - set2)
    C = len(set2 - set1)
    D = len(all_set - (set1 | set2))
    # Single-argument print() behaves the same under Python 2 and 3 and
    # produces the same "A B C D" line as the former print statement.
    print("%d %d %d %d" % (A, B, C, D))
    return fisher_exact([[A, B], [C, D]])

################################################
# Find enrichment (Fisher's exact test, FET) for co-occurrence of different
# enhancer types on opposite anchors of the same loop

n_names = len(names)
ORs = np.zeros((n_names, n_names))
ORs_low = np.zeros((n_names, n_names))
ORs_hi = np.zeros((n_names, n_names))
pvals = np.zeros((n_names, n_names))

# Loops with at least one enhancer of a given type on a given anchor, built
# once per "<name>_upstream"/"<name>_downstream" key instead of once per pair.
loop_sets = {key: set(loop_peak_d)
             for key, loop_peak_d in enhancer_loop_peak_dicts.items()}

result_fmt = ("Enrichment of %s enhancer in anchor if other anchor has %s enhancer, "
              "log2 OR = %0.2f, p = %0.2e")

for i, name1 in enumerate(names):
    up1 = loop_sets[name1 + "_upstream"]
    dn1 = loop_sets[name1 + "_downstream"]
    # Only the strict lower triangle (j < i) is tested; the diagonal is left
    # at its initial value.
    for j in range(i):
        name2 = names[j]
        up2 = loop_sets[name2 + "_upstream"]
        dn2 = loop_sets[name2 + "_downstream"]

        # Loops with name1 on one anchor and name2 on the opposite anchor.
        positives = (up1 & dn2) | (dn1 & up2)
        # Loops carrying each type on either anchor, excluding positives.
        set1_not_set2 = (up1 | dn1) - positives
        set2_not_set1 = (up2 | dn2) - positives
        negatives = all_loops - (positives | set1_not_set2 | set2_not_set1)

        # NOTE(review): loops carrying both types only on the same anchor
        # (set1_not_set2 & set2_not_set1) are counted in A together with the
        # opposite-anchor positives -- confirm this is intended.
        A = len(positives | (set1_not_set2 & set2_not_set1))
        B = len(set1_not_set2 - set2_not_set1)
        C = len(set2_not_set1 - set1_not_set2)
        D = len(negatives)

        # R's matrix() fills column-major, giving [[A, C], [B, D]]; the odds
        # ratio A*D/(B*C) is unchanged by that transpose.
        m = robjects.r['matrix'](robjects.IntVector([A, B, C, D]), nrow=2)
        res = robjects.r['fisher.test'](m)

        # Pull fields of R's htest result by name rather than position.
        p = res.rx2('p.value')[0]
        conf_int = res.rx2('conf.int')
        OR_low, OR, OR_hi = np.log2([conf_int[0], res.rx2('estimate')[0], conf_int[1]])

        print(result_fmt % (name1, name2, OR, p))

        ORs_low[i, j] = OR_low
        ORs_hi[i, j] = OR_hi
        ORs[i, j] = OR
        pvals[i, j] = p

# The result matrices are filled only below the diagonal; adding each one's
# transpose mirrors the values into the upper triangle. The log2-OR matrices
# keep 0 (OR = 1) on the diagonal, while the p-value diagonal is set to 1.
ORs, ORs_low, ORs_hi = [
    pd.DataFrame(matrix + matrix.T, index=names, columns=names)
    for matrix in (ORs, ORs_low, ORs_hi)
]
pvals = pd.DataFrame(pvals + pvals.T + np.eye(len(pvals)),
                     index=names, columns=names)

for frame, out_path in ((ORs, ORs_out),
                        (ORs_low, ORs_low_out),
                        (ORs_hi, ORs_hi_out),
                        (pvals, pvals_out)):
    frame.to_csv(out_path, sep="\t")