#!/usr/bin/env python
import argparse
import os

import matplotlib.pyplot as plt
from matplotlib_venn import venn2,venn3
from pybedtools import BedTool

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

This script plots the overlap of two to three lists (or bed files if --bed indicated)
as a venn diagram.

""")

##################################################
# required args:

parser.add_argument("-i", "--input", nargs='+', type=str,
                    help="""required, file paths to input files
/path/to/site_type1.txt /path/to/site_type2.txt
OR
/path/to/site_type1.bed /path/to/site_type2.bed

""", required=True)
# --names is optional; labels default to the input file basenames (see below).
parser.add_argument("--names", nargs='+', type=str,
                    help="""names of each input file for plotting, one per input file
(default: the basename of each input file)

""")
parser.add_argument("-o", "--outplot", type=str,
                    help="""required, file paths to outplot, e.g.,
/path/to/plot.png
""", required=True)

##################################################
# optional args:

parser.add_argument("--bed",
                    help="""
                    
If --bed indicated, input files are bed format and bedtools intersect is used.

""", action='store_true')

parser.add_argument("--frac", type=float, default=1e-9,
                    help="""
                    
Only used with --bed: minimum reciprocal overlap, as a fraction of each
interval's length, for two intervals to count as shared
(passed to bedtools intersect as -f FRAC -r).
(default: %(default)s)

""")

##################################################

args = parser.parse_args()

# The plotting code only draws 2- or 3-way diagrams. Reject other counts here
# instead of failing later with an IndexError (1 file) or silently ignoring
# the extra files (4+).
if len(args.input) not in (2, 3):
    parser.error("-i/--input expects 2 or 3 files, got %d" % len(args.input))

# --names is optional, and the plotting code indexes args.names unconditionally,
# so fall back to the input basenames. When names are given, require exactly
# one per input file.
if args.names is None:
    args.names = [os.path.basename(inf) for inf in args.input]
elif len(args.names) != len(args.input):
    parser.error("--names expects one name per input file (%d), got %d"
                 % (len(args.input), len(args.names)))

# bedtools intersect -f takes an overlap fraction; values outside (0, 1] are
# not meaningful.
if args.bed and not 0 < args.frac <= 1:
    parser.error("--frac must be in (0, 1], got %s" % args.frac)

##################################################

def venn2_wrapper_bed(bed1, bed2, name1, name2, overlap, ax):
    """Draw a two-way Venn diagram of interval overlap between two BED files.

    Two intervals count as shared when bedtools intersect reports a reciprocal
    overlap of at least ``overlap`` (``-f overlap -r``).

    Args:
        bed1, bed2: pybedtools BedTool objects to compare.
        name1, name2: labels drawn next to each circle.
        overlap: minimum reciprocal overlap fraction passed to bedtools ``-f``.
        ax: matplotlib axes to draw on.
    """
    # Features of each file with no qualifying overlap in the other (-v).
    bed1_only = bed1.intersect(bed2, f=overlap, r=True, v=True)
    bed2_only = bed2.intersect(bed1, f=overlap, r=True, v=True)
    # bed1 features that overlap bed2, each reported once (-u). The shared
    # region is therefore counted in units of bed1 features.
    in_both = bed1.intersect(bed2, f=overlap, r=True, u=True)

    region_sizes = {
        '10': len(bed1_only),
        '01': len(bed2_only),
        '11': len(in_both),
    }
    venn2(subsets=region_sizes, set_labels=(name1, name2), ax=ax)

def venn2_wrapper(set1, set2, name1, name2, ax):
    """Draw a two-way Venn diagram for two sets of items.

    Args:
        set1, set2: sets of hashable items, for example the lines of two list files.
        name1, name2: labels drawn next to each circle.
        ax: matplotlib axes to draw on.
    """
    # Size of each of the three regions, keyed the way matplotlib_venn expects:
    # '10' = only set1, '01' = only set2, '11' = in both.
    region_sizes = {
        '10': len(set1 - set2),
        '01': len(set2 - set1),
        '11': len(set1 & set2),
    }
    venn2(subsets=region_sizes, set_labels=(name1, name2), ax=ax)

def venn3_wrapper_bed(bed1, bed2, bed3, name1, name2, name3, overlap, ax):
    """Draw a three-way Venn diagram of interval overlap between three BED files.

    Every pairwise comparison uses the same criterion: two intervals count as
    overlapping when bedtools intersect reports a reciprocal overlap of at least
    ``overlap`` (``-f overlap -r``). Previously only the first comparison in each
    chain used this criterion, and the comparison against the third file fell
    back to bedtools' default (any overlap).

    Each region size is a feature count taken from the first file in its chain.
    For example, '110' and '111' count bed1 features, and '011' counts bed2
    features.

    Args:
        bed1, bed2, bed3: pybedtools BedTool objects to compare.
        name1, name2, name3: labels drawn next to each circle.
        overlap: minimum reciprocal overlap fraction passed to bedtools ``-f``.
        ax: matplotlib axes to draw on.
    """
    def hit(a, b):
        # Features of a with a qualifying overlap in b, each reported once (-u).
        return a.intersect(b, f=overlap, r=True, u=True)

    def miss(a, b):
        # Features of a with no qualifying overlap in b (-v).
        return a.intersect(b, f=overlap, r=True, v=True)

    venn3(subsets={'100': len(miss(miss(bed1, bed2), bed3)),
                   '010': len(miss(miss(bed2, bed1), bed3)),
                   '001': len(miss(miss(bed3, bed1), bed2)),
                   '110': len(miss(hit(bed1, bed2), bed3)),
                   '101': len(miss(hit(bed1, bed3), bed2)),
                   '011': len(miss(hit(bed2, bed3), bed1)),
                   '111': len(hit(hit(bed1, bed2), bed3))},
          set_labels=(name1, name2, name3), ax=ax)

def venn3_wrapper(set1, set2, set3, name1, name2, name3, ax):
    """Draw a three-way Venn diagram for three sets of items.

    Args:
        set1, set2, set3: sets of hashable items, for example the lines of list files.
        name1, name2, name3: labels drawn next to each circle.
        ax: matplotlib axes to draw on.
    """
    # Pairwise intersections, reused by the two-set and three-set regions.
    both12 = set1 & set2
    both13 = set1 & set3
    both23 = set2 & set3

    # Region keys follow matplotlib_venn: one character per set, '1' = member.
    region_sizes = {
        '100': len(set1 - (set2 | set3)),
        '010': len(set2 - (set1 | set3)),
        '001': len(set3 - (set1 | set2)),
        '110': len(both12 - set3),
        '101': len(both13 - set2),
        '011': len(both23 - set1),
        '111': len(both12 & set3),
    }
    venn3(subsets=region_sizes, set_labels=(name1, name2, name3), ax=ax)


# Draw a single diagram on one axes; the wrapper functions fill in region sizes.
fig, ax = plt.subplots()

if args.bed:
    # Interval mode: overlaps are computed by bedtools through pybedtools.
    infiles = [BedTool(inf) for inf in args.input]
    if len(infiles) == 2:
        venn2_wrapper_bed(infiles[0], infiles[1], args.names[0], args.names[1], args.frac, ax)
    else:
        venn3_wrapper_bed(infiles[0], infiles[1], infiles[2],
                          args.names[0], args.names[1], args.names[2], args.frac, ax)
else:
    # List mode: each whitespace-stripped, non-blank line is one set element.
    # Blank lines are skipped so they are not counted as a (shared) "" item.
    infiles = []
    for inf in args.input:
        with open(inf, 'r') as f:
            infiles.append(set(item for item in (line.strip() for line in f) if item))

    if len(infiles) == 2:
        venn2_wrapper(infiles[0], infiles[1], args.names[0], args.names[1], ax)
    else:
        venn3_wrapper(infiles[0], infiles[1], infiles[2],
                      args.names[0], args.names[1], args.names[2], ax)

fig.tight_layout()
fig.savefig(args.outplot)
