#!/usr/bin/env python
import numpy as np
from pybedtools import BedTool

def pandas_dataframe_to_Bedtool(df):
    """
    Convert a pandas dataframe to a bedtool object.

    df: pandas dataframe where chrom, chromStart, and
        chromEnd are the first, second and third columns;
        any further columns are written as extra BED fields
        in column order
    returns:
        bed: pybedtools BedTool object
    """
    # Serialise as tab-separated text instead of with DataFrame.to_string:
    # to_string is a display formatter that pads columns, rounds floats to
    # the display precision, and renders an empty frame as the literal text
    # "Empty DataFrame ...", which would end up inside the BED file.
    # Missing values are written as "." (the BED null) so that every field
    # stays a single non-empty token and columns cannot shift.
    bed_text = df.to_csv(sep="\t", header=False, index=False, na_rep=".")
    return BedTool(bed_text, from_string=True)

# def pandas_dataframe_to_Bedtool(df):
#     """
#     Convert a pandas dataframe to a bedtool object.

#     df: pandas dataframe where chrom, chromStart, and 
#                chromEnd are the first, second and third columns
#     returns:
#         bed: pybedtools BedTool object
#     """
#     cols = []
#     for col in df.columns:
#         cols.append(list(df.ix[:,col].copy()))
    
#     for cols in zip(cols):
#         cols = appen
    
# #     chroms, starts, ends = list(df.ix[:,0].copy()), list(df.ix[:,1].copy()), list(df.ix[:,2].copy())
#     bed = '\n'.join(['\t'.join([chrom,str(start),str(end)]) for chrom,start,end in zip(chroms,starts,ends)])
#     bed = BedTool(bed, from_string=True)
#     return(bed)


def compute_dist_between_beds(bed1, bed2, io=True):
    """
    For every record in bed1, compute the distance to its nearest
    record in bed2.

    bed1: pybedtools BedTool object
    bed2: pybedtools BedTool object
    io: if True, bed2 records that overlap a bed1 record are ignored
        when searching for the nearest record; if False they are kept

    returns:
        dists: numpy array with one integer distance per output line of
               the closest search (the distance is taken from the last
               column of each line)
    """
    # Both inputs are sorted first, as `closest` is run on sorted data here.
    # t="first" keeps a single bed2 hit per bed1 record when several are
    # tied; d=True appends the distance as the final column.
    nearest = bed1.sort().closest(bed2.sort(), t="first", d=True, io=io)

    rows = [row for row in str(nearest).split('\n') if row]
    return np.array([int(row.split()[-1]) for row in rows])

def compute_nearest_bed2_to_every_bed1(bed1, bed2, io=True):
    """
    Find the nearest record in bed2 to each
    record in bed1.

    bed1: pybedtools BedTool object with at least chrom,
        chromStart, chromEnd, and name columns
    bed2: pybedtools BedTool object with at least chrom,
        chromStart, chromEnd, and name columns
    io: if True ignore overlapping bed2 records in bed1,
        otherwise do not ignore overlapping

    returns:
        names1: list of names in bed1 (in sorted bed1 order)
        names2: list of names in bed2 nearest and in
                corresponding order to bed1
    """
    # Reduce both inputs to chrom, chromStart, chromEnd, name so the layout
    # of the closest output is fixed: bed1 fields 0-3, bed2 fields 4-7,
    # and the distance (d=True) in field 8.
    bed1 = bed1.cut([0, 1, 2, 3])
    bed2 = bed2.cut([0, 1, 2, 3])
    # compute the closest bed2 to each bed1 record;
    # "first" keeps a single bed2 record per bed1 record when there are ties
    bed = bed1.sort().closest(bed2.sort(), t="first", d=True, io=io)

    # Parse the output once. Split on tabs only, so that names containing
    # spaces are kept intact.
    records = [line.split('\t') for line in str(bed).split('\n') if line != '']
    names1 = [fields[3] for fields in records]
    names2 = [fields[7] for fields in records]
    return names1, names2

def remove_intersecting_datasets(bed, bed_list):
    """
    Remove all records from bed that intersect with
    any records in the beds in bed_list.

    bed: pybedtools BedTool object
    bed_list: iterable of pybedtools BedTool objects; may be empty

    returns:
        non: pybedtools BedTool object like bed but
             with any records that overlap with beds
             in bed_list removed; bed itself is returned
             unchanged when bed_list is empty
    """
    # Seed with the input so an empty bed_list returns bed instead of
    # raising UnboundLocalError. Each pass keeps only the records with no
    # overlap (v=True) in the current filter bed.
    non = bed
    for rm_bed in bed_list:
        non = non.intersect(rm_bed, v=True)
    return non

def extract_field(bed, field):
    """
    Extract a numbered field (0-indexed) from pybedtools object.

    bed: pybedtools BedTool object; it is read through str(bed) as
         newline-separated, tab-delimited text
    field: 0-indexed column to take from every non-empty line
           (negative indices count from the end of the line)
    returns:
        numpy array of the field values as strings (an empty
        float array when bed has no records)
    """
    populated_lines = filter(None, str(bed).split('\n'))
    return np.array([line.split('\t')[field] for line in populated_lines])

def subset_bed_by_names_in_field(bed, names, field):
    """
    Keep only the records of bed whose given field is one of names.

    bed: pybedtools BedTool object; it is read through str(bed) as
         newline-separated, tab-delimited text
    names: iterable of accepted values (converted to a set for lookup)
    field: 0-indexed column compared against names
    returns:
        pybedtools BedTool object holding the matching records in
        their original order
    """
    wanted = set(names)

    def _matches(line):
        # skip blank lines, then test the selected column for membership
        return line != "" and line.split("\t")[field] in wanted

    kept_lines = filter(_matches, str(bed).split("\n"))
    return BedTool("\n".join(kept_lines), from_string=True)







# def return_overlaps(bed1, bed2, name_col_bed1, name_col_bed2):
    
#     num_cols_bed1 = len(str(bed1).split('\n')[0].split('\t'))
#     num_cols_bed2 = len(str(bed2).split('\n')[0].split('\t'))
    
#     bed1_intersect_bed2 = (bed1.sort().intersect(bed2.sort(), wb=True)).to_dataframe()
#     names_bed1 = list(bed1_intersect_bed2[list(bed1_intersect_bed2.columns)[name_col_bed1]])
#     names_bed2 = list(bed1_intersect_bed2[list(bed1_intersect_bed2.columns)[num_cols_bed1 + name_col_bed2]])
#     bed1_to_bed2_overlap_dict = {name_bed1:name_bed2 for name_bed1, name_bed2 in zip(names_bed1, names_bed2)} 
#     return(bed1_to_bed2_overlap_dict)

# def count_intersections_in_bedtool_wao_format(intersect_wao):
#     """
#     Count the number of intersections per element from
#     bedtools closest command

#     bed1: pybedtools BedTool object with at least chrom, 
#         chromStart, chromEnd, and name columns
#     bed2: pybedtools BedTool object with at least chrom, 
#         chromStart, chromEnd, and name columns
#     io: if True ignore overlapping bed2 records in bed1,
#         otherwise do not ignore overlapping
    
#     returns:
#         names1: list of names in bed1
#         names2: list of names in bed2 nearest and in 
#                 corresponding order to bed1
#     """
#     from collections import defaultdict
#     cnt = defaultdict(int)
#     for line in str(intersect_wao).split('\n'):
#         if (line == ''):
#             continue
#         line = line.split('\t')
#         name = line[3]
#         if line[4] == '.':
#             cnt[name] = 0
#         else:
#             cnt[name] += 1
    
#     counts = np.array(cnt.values())
#     return counts

# def assign_site_to_nearest_gene(sites, genes, site_name_col, gene_name_col):
    
#     sites_to_genes = sites.sort().closest(genes.sort(),  wb=True, d=True)
#     sites_to_genes_df = sites_to_genes.to_dataframe()
    
#     site_name_col = [col for col in sites_to_genes_df.columns][site_name_col]
#     gene_name_col = [col for col in sites_to_genes_df.columns][gene_name_col]
    
#     site_to_genes_dict = defaultdict(list)
#     gene_to_site_dict = defaultdict(list)
#     for site, gene in zip(list(sites_to_genes_df[site_name_col]), list(sites_to_genes_df[gene_name_col])):
#         site_to_genes_dict[site].append(gene)
#         gene_to_site_dict[gene].append(site)
    
#     for site, genes in site_to_genes_dict.iteritems():
#         site_to_genes_dict[site] = np.random.choice(genes)
    
#     return([site_to_genes_dict, gene_to_site_dict])

# def return_overlaps(bed1, bed2, name_col_bed1, name_col_bed2):
#     num_cols_bed1 = len(str(bed1).split('\n')[0].split('\t'))
#     num_cols_bed2 = len(str(bed2).split('\n')[0].split('\t'))
#     bed1_intersect_bed2 = (bed1.sort().intersect(bed2.sort(), wb=True)).to_dataframe()
#     names_bed1 = list(bed1_intersect_bed2[list(bed1_intersect_bed2.columns)[name_col_bed1]])
#     names_bed2 = list(bed1_intersect_bed2[list(bed1_intersect_bed2.columns)[num_cols_bed1 + name_col_bed2]])
#     bed1_to_bed2_overlap_dict = {name_bed1:name_bed2 for name_bed1, name_bed2 in zip(names_bed1, names_bed2)} 
#     return(bed1_to_bed2_overlap_dict)
