#!/usr/bin/env python
from pybedtools import BedTool
import numpy as np
import pandas as pd
import argparse

# Command-line interface. RawTextHelpFormatter keeps the hand-written line
# breaks of the description and help strings below when --help is printed.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""

For more than three lists or for more than three bed files,
then a venn diagram becomes inappropriate. This script computes a dataframe
of percentages of overlap or the jaccard index of overlap between all files. 
Unless you want a Chow-Ruskey diagram, which is ugly as sin, then it 
may be best to feed this dataframe into a heatmap plotting script (e.g. plot_heatmap.py).

For percentage dataframe output, cell ij = percentage of sites in list/bed file i that
overlap with sites in list/bed file j. Jaccard index is symmetric, i.e.
cell ij = cell ji.

""")

##################################################
# required args:

# one or more input paths; read as bed files only when --bed is given
parser.add_argument("-i", "--input", nargs='+',
                    help="""required, file paths to lists
or to bed files (if --bed indicated).

""", required=True)

# labels used as the row and column index of the output table
parser.add_argument("--names", nargs='+',
                    help="""required, list of names 
corresponding to input files.

""", required=True)

parser.add_argument("-o","--out", type=str,
                    help="""required, file path to output
""", required=True)

##################################################
# optional args:

# boolean switch: when set, bed files are compared pairwise instead of being
# compared through their shared union
parser.add_argument("--do_not_take_union",
                    help="""applies only to bed files.
by default, script computes a union across files for bed files (bedtools merge
with overlapping or book-ended features merged), and from that compares
each bed file to the union for percentage and jaccard computations. This can
greatly speed up computations when many files are used as input (basically,
quadratic -> linear in number of files)

""", action='store_true')

parser.add_argument("--bed",
                    help="""input is bed files and not lists

""", action='store_true')
# boolean switch: the default metric is the jaccard index
parser.add_argument("--percentage",
                    help="""indicate to compute percentage overlap instead of jaccard

""", action='store_true')

###################################################

args = parser.parse_args()

# Every input file needs exactly one label: the labels become the row and
# column index of the output table. Checking here turns a mismatch into an
# immediate usage error instead of a pandas shape error raised only after all
# of the overlaps have been computed.
if len(args.names) != len(args.input):
    parser.error("--names must supply exactly one name per --input file "
                 "(got %d names for %d input files)"
                 % (len(args.names), len(args.input)))

###################################################
# read in data

# Build `sites_to_compare`: one entry per input file, in command-line order.
if args.bed:
    # bed input: one BedTool per file, passed through BedTool.sort()
    sites_to_compare = [BedTool(bed_path).sort() for bed_path in args.input]
else:
    # list input: one set of whitespace-stripped lines per file
    sites_to_compare = []
    for list_path in args.input:
        with open(list_path, "r") as handle:
            stripped_lines = (line.strip() for line in handle)
            # Skip blank lines: otherwise the empty string counts as a site,
            # and any two lists that both contain a blank (or trailing empty)
            # line would appear to share an element.
            sites_to_compare.append(set(site for site in stripped_lines if site))

##########
# create union bed, if applicable

# Build `union_dict`, mapping each input's position to the collection that
# stands in for that input in the overlap computation.
if not args.bed or args.do_not_take_union:
    # no union requested (or input is plain lists): use the inputs themselves
    union_dict = dict(enumerate(sites_to_compare))
else:
    from reddylab_utils import pybedtools_utils

    # fold every bed file into a single BedTool with successive cat() calls
    # (per the --do_not_take_union help text, this yields the merged union)
    union = sites_to_compare[0]
    for next_bed in sites_to_compare[1:]:
        union = union.cat(next_bed)

    # label each union interval with its row number in a 'name' column
    union_df = union.to_dataframe()
    union_df['name'] = range(len(union_df))
    union = pybedtools_utils.pandas_dataframe_to_Bedtool(union_df)

    # for each input, record the union intervals it overlaps; field -1 is
    # presumably the 'name' column added above -- confirm against
    # pybedtools_utils.extract_field
    union_dict = {
        idx: set(pybedtools_utils.extract_field(union.intersect(member, u=True), -1))
        for idx, member in enumerate(sites_to_compare)
    }

##########
# compute jaccard or percentage array

def _ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if denominator is 0."""
    if denominator == 0:
        return 0.0
    return numerator / float(denominator)


def _jaccard(sites1, sites2, use_bedtools):
    """Return the jaccard index of two site collections.

    With use_bedtools False both arguments are sets and the index is
    |intersection| / |union|, defined here as 0.0 when both sets are empty.
    With use_bedtools True both arguments are BedTools and the index is
    A / (A + B + C), where A and B count the sites1 intervals that do / do not
    overlap sites2 and C counts the sites2 intervals that do not overlap
    sites1.
    """
    if not use_bedtools:
        return _ratio(len(sites1 & sites2), len(sites1 | sites2))
    if len(sites1) == 0:
        # nothing in sites1 can overlap; skip the intersect calls entirely
        return 0.0
    shared = len(sites1.intersect(sites2, u=True))
    only_first = len(sites1.intersect(sites2, v=True))
    only_second = len(sites2.intersect(sites1, v=True))
    return _ratio(shared, shared + only_first + only_second)


def _percentage(sites1, sites2, use_bedtools):
    """Return the percentage (0-100) of sites1 that overlaps sites2.

    An empty sites1 gives 0.0 rather than a division by zero.
    """
    if len(sites1) == 0:
        return 0.0
    if use_bedtools:
        shared = len(sites1.intersect(sites2, u=True))
    else:
        shared = len(sites1 & sites2)
    return 100.0 * shared / float(len(sites1))


n_files = len(sites_to_compare)
array = np.zeros((n_files, n_files))

if args.do_not_take_union:
    # pairwise mode: compare the inputs directly (BedTools or sets)
    compared = list(sites_to_compare)
    use_bedtools = bool(args.bed)
else:
    # union mode: every entry of union_dict is a set
    compared = [union_dict[i] for i in range(n_files)]
    use_bedtools = False

for i, sites1 in enumerate(compared):
    if args.do_not_take_union:
        # progress indicator for pairwise mode; the call form of print
        # behaves the same on Python 2 and Python 3
        print(i)
    for j, sites2 in enumerate(compared):
        if args.percentage:
            if i == j:  # identity = full overlap
                array[i, j] = 100.
            else:
                array[i, j] = _percentage(sites1, sites2, use_bedtools)
        elif j < i:
            # jaccard: fill the strict lower triangle only; the diagonal and
            # the upper triangle stay 0 here
            array[i, j] = _jaccard(sites1, sites2, use_bedtools)

jaccard_requested = not args.percentage
if jaccard_requested:
    # the jaccard values occupy the lower triangle only: mirror them across
    # the diagonal and put 1 on the diagonal to get the full symmetric matrix
    identity = np.eye(len(array))
    array = array + identity + array.T

# label rows and columns with the user-supplied names, write tab-separated
df = pd.DataFrame(data=array, index=args.names, columns=args.names)
df.to_csv(args.out, sep="\t", index=True, header=True)