from __future__ import division
import pyBigWig
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#use this one if your INPUT is a bed file of just the genes' TSSs, OR if you want to center on the middle of a consensus sequence or enhancer
def process_center(center, direction=0):
    """Read a headerless, tab-separated BED-like file and add interval midpoints.

    Parameters
    ----------
    center : str or file-like
        Anything ``pd.read_csv`` accepts.
    direction : int, optional
        Truthy when the directionality of each interval is known. In that case
        the first six columns are read as Chr, start, end, seq, num, direction.
        When falsy, only the first three columns (Chr, start, end) are read and
        every interval is assigned direction '+'.

    Returns
    -------
    pandas.DataFrame
        The parsed columns plus ``middle``, the mean of ``start`` and ``end``
        passed through ``np.round``.
    """
    if direction:
        wanted = [0, 1, 2, 3, 4, 5]
        labels = ['Chr', 'start', 'end', 'seq', 'num', 'direction']
    else:
        wanted = [0, 1, 2]
        labels = ['Chr', 'start', 'end']

    intervals = pd.read_csv(center, sep='\t', header=None, usecols=wanted,
                            names=labels, index_col=False)
    intervals['middle'] = np.round(intervals[['start', 'end']].mean(axis=1))

    if not direction:
        # No orientation available: treat every interval as '+'.
        intervals['direction'] = '+'
    return intervals

#use this one if your INPUT is a bed file of the entire gene and you want to plot over its TSS
def process_center_fromfull(center, direction=0):
    """Read a headerless, tab-separated BED-like file of whole genes and pick each start site.

    Parameters
    ----------
    center : str or file-like
        Anything ``pd.read_csv`` accepts.
    direction : int, optional
        Truthy when directionality is known. The first six columns are then
        read (Chr, start, end, seq, num, direction), and ``middle`` is ``start``
        for rows whose direction is '+' and ``end`` for every other row.
        When falsy, only Chr, start and end are read, ``middle`` is ``start``,
        and every row is assigned direction '+'.

    Returns
    -------
    pandas.DataFrame
        The parsed columns plus ``middle`` (and ``direction`` when it was not read).
    """
    if not direction:
        genes = pd.read_csv(center, sep='\t', header=None, usecols=[0, 1, 2],
                            names=['Chr', 'start', 'end'], index_col=False)
        genes['middle'] = genes['start']
        genes['direction'] = '+'
        return genes

    genes = pd.read_csv(center, sep='\t', header=None, usecols=[0, 1, 2, 3, 4, 5],
                        names=['Chr', 'start', 'end', 'seq', 'num', 'direction'],
                        index_col=False)
    # Keep `start` on '+' rows; take `end` everywhere else.
    on_plus = genes['direction'] == '+'
    genes['middle'] = genes['start'].where(on_plus, genes['end'])
    return genes

#use this one if your INPUT is a bed file of the entire gene plus its promoter and you want to pile up over its TSS
def process_center_frompromgene(center, promoter_length=2000):
    """Read a six-column BED-like file of gene+promoter intervals and place each TSS.

    The file must have (at least) six tab-separated columns with no header:
    Chr, start, end, seq, num, direction. For rows whose direction is '+',
    ``middle`` is ``start + promoter_length``; for every other row it is
    ``end - promoter_length``. That is, the promoter is assumed to occupy the
    first ``promoter_length`` bp of the interval on the row's own orientation.

    Parameters
    ----------
    center : str or file-like
        Anything ``pd.read_csv`` accepts.
    promoter_length : int, optional
        Number of bp of promoter included in each interval (default 2000,
        the value previously hard-coded).

    Returns
    -------
    pandas.DataFrame
        The parsed columns plus ``middle``.
    """
    genes = pd.read_csv(center, sep='\t', header=None, usecols=[0, 1, 2, 3, 4, 5],
                        names=['Chr', 'start', 'end', 'seq', 'num', 'direction'],
                        index_col=False)
    on_plus = genes['direction'] == '+'
    genes['middle'] = (genes['start'] + promoter_length).where(
        on_plus, genes['end'] - promoter_length)
    return genes

def _interval_profile(pileup_type, chrom, middle, direction, window, bins):
    """Return the binned mean signal over middle-window .. middle+window on `chrom`.

    The interval is clipped to [0, chromosome length as reported by
    ``pileup_type.chroms(chrom)``]. Rows whose direction is not '+' have their
    bins reversed so all rows share one orientation.
    """
    middle = int(middle)
    start = max(0, middle - window)
    end = min(middle + window, pileup_type.chroms(chrom))
    values = pileup_type.stats(chrom, start, end, type="mean", nBins=bins, exact=True)
    return values if direction == '+' else values[::-1]


def get_pileup(pileup_type, window, bins, center, averaged=1):
    """Aggregate a bigwig signal over every interval in `center`.

    Parameters
    ----------
    pileup_type : bigwig handle
        Object opened with ``pyBigWig.open`` (anything exposing
        ``chroms(chrom)`` -> chromosome length and ``stats(...)``).
    window : int
        Distance in bp taken on each side of ``middle``.
    bins : int
        Total number of bins per interval.
    center : pandas.DataFrame
        Needs ``Chr``, ``middle`` and ``direction`` columns (output of the
        ``process_center*`` functions). Rows are read positionally, so any
        index (e.g. a filtered subset) is accepted.
    averaged : int, optional
        Truthy -> return the per-bin mean over intervals; falsy -> per-bin sum.

    Returns
    -------
    numpy.ndarray
        Array of length ``bins``. Bins with no data (None/NaN from ``stats``)
        count as 0. An empty `center` yields all zeros.

    NOTE(review): an interval clipped at a chromosome edge spans less than
    2*window bp but is still split into `bins` bins, so its bins do not line
    up with unclipped rows.
    """
    profiles = [
        _interval_profile(pileup_type, chrom, middle, direction, window, bins)
        for chrom, middle, direction in zip(center['Chr'], center['middle'], center['direction'])
    ]
    if not profiles:
        return np.zeros(bins)

    # dtype=float64 turns the None entries returned by stats() into NaN.
    signal = np.array(profiles, dtype=np.float64)
    signal[np.isnan(signal)] = 0
    total = signal.sum(axis=0)
    return total / len(profiles) if averaged else total



def generate_axis_label(window, bins):
    """Return x-axis positions (bp offsets) for a pileup from ``get_pileup``.

    Positions are spaced ``int(2*window/bins)`` apart. They run from -window up
    to (but excluding) 0, then from one step up to window; 0 itself is skipped.
    `window` and `bins` should match the values given to ``get_pileup``.

    NOTE(review): the result has exactly `bins` entries only when 2*window is
    divisible by `bins`; otherwise its length will not match the pileup.

    Raises
    ------
    ValueError
        If `bins` exceeds 2*window, which would make the step zero.
    """
    step = int(2 * window / bins)
    if step == 0:
        raise ValueError('bins (%r) must not exceed 2*window (%r)' % (bins, 2 * window))
    # list(...) because range objects cannot be concatenated with + on Python 3.
    upstream = list(range(-window, 0, step))
    downstream = list(range(step, window + step, step))
    return upstream + downstream



def pileupplotter(outputname, center, pileup1, pileup2, pileup3=None, window=2000,
                  bins=200, ylim=5, averaged=1,
                  labels=('NPC YY1', 'ES YY1', 'ProB YY1'),
                  colors=('green', 'blue', 'red')):
    """Plot two or three aggregate bigwig profiles around `center` and save the figure.

    Parameters
    ----------
    outputname : str
        Destination passed to ``savefig``.
    center : pandas.DataFrame
        Intervals to pile up over (output of a ``process_center*`` function).
    pileup1, pileup2 : bigwig handles
        Tracks accepted by ``get_pileup`` (e.g. from ``pyBigWig.open``).
    pileup3 : bigwig handle or None, optional
        Optional third track; left out of the plot when None.
    window, bins, averaged :
        Forwarded to ``get_pileup`` / ``generate_axis_label``. A truthy
        `averaged` plots the mean over intervals; falsy plots sums.
    ylim : float
        Upper y-axis limit (the lower limit is 0).
    labels, colors : sequences
        Legend text and line colour for each track, in track order. The
        defaults reproduce the previously hard-coded legend. Pass `labels`
        explicitly so the legend matches the tracks actually supplied.

    Raises
    ------
    ValueError
        If fewer labels or colours are given than tracks.
    """
    tracks = [pileup1, pileup2] + ([pileup3] if pileup3 is not None else [])
    if len(labels) < len(tracks) or len(colors) < len(tracks):
        raise ValueError('need one label and one colour per track')

    # Compute everything before opening a figure, so a failure here leaks nothing.
    profiles = [get_pileup(track, window, bins, center, averaged) for track in tracks]
    axis_label = generate_axis_label(window, bins)

    fig = plt.figure(figsize=(10, 5))
    try:
        # NOTE(review): only the right half of the figure holds an axis
        # (a leftover from a multi-panel layout); kept to preserve the output.
        ax = fig.add_subplot(1, 2, 2)
        for profile, label, color in zip(profiles, labels, colors):
            ax.plot(axis_label, profile, zorder=2, color=color, label=label)
        # tick_params also applies to ticks created after this call.
        ax.tick_params(axis='both', labelsize=20)
        ax.set_ylim([0, ylim])
        ax.set_xlim([-window, window])
        ax.set_xlabel('Distance (BP)')
        ax.legend()
        fig.tight_layout()
        fig.savefig(outputname)
    finally:
        # Release the figure; repeated calls otherwise accumulate open figures.
        plt.close(fig)
#examples

#first, load bigwigs

#NPC_CTCF = pyBigWig.open("Project004-IPSHOCH-Exp5-ChIP-NPC-CTCF-IP-Blobel-Rep2-S1_S8_peakcall_MACS2_rendition_1E-8_downsampled_try3_treat_pileup_sorted_clipped.bw")
#ES_CTCF = pyBigWig.open("Project004-IPSHOCH-Exp5-ChIP-ES-CTCF-IP-Blobel-Rep2-S1_S2_MACS2_rendition_1E-8_downsampled_try3_treat_pileup_sorted_clipped.bw")
#ES_YY1=pyBigWig.open("mES_Untreated_YY1_batch3_onlymapped_sorted_noduplicates_downsampled_treat_pileup_sorted_clipped.bw")
#NPC_YY1=pyBigWig.open("YY1_NPC_onlymapped_sorted_noduplicates_downsampled_treat_pileup_sorted_clipped.bw")
#NPC_Cohesin=pyBigWig.open("Phillips_NPC_Smc1_pe1_onlymapped_sorted_noduplicates_downsampled_treat_pileup_sorted_clipped.bw")
#ProB_YY1=pyBigWig.open("Pro-B_cell_Yy1_onlymapped_sorted_noduplicates_downsampled_treat_pileup_sorted_clipped.bw")
#NPC_H3K27ac=pyBigWig.open("Group1_Creghyton_NPC_H3K27ac_Phillips_WCE_MACS2_1E-8_downsampled_treat_pileup_sorted_clipped.bw")
#ES_H3K27ac=pyBigWig.open("Group1_Creghyton_ES_H3K27ac_Creghyton_WCE_MACS2_1E-8_downsampled_treat_pileup_sorted_clipped.bw")



#then, load the things of interest
#constit_genes_2 = process_center_fromfull('constit_genes_2.bed', direction=1)
#constit_genes_4 = process_center_fromfull('constit_genes_4.bed', direction=1)
#ES_genes_2 = process_center_fromfull('Es_genes_2.bed', direction=1)
#ES_genes_4 = process_center_fromfull('Es_genes_4.bed', direction=1)
#NPC_genes_2 = process_center_fromfull('Npc_genes_2.bed', direction=1)
#NPC_genes_4 = process_center_fromfull('Npc_genes_4.bed', direction=1)
#ES_enhancers=process_center('ES_enhancers.bed', direction=0)
#NPC_enhancers=process_center('NPC_enhancers.bed', direction=0)


#constit_genes_2_loop = process_center_fromfull('constit_genes_2_inloops_full.bed', direction=1)
#constit_genes_4_loop = process_center_fromfull('constit_genes_4_inloops_full.bed', direction=1)
#ES_genes_2_loop = process_center_fromfull('es_genes_2_inloops_full.bed', direction=1)
#ES_genes_4_loop = process_center_fromfull('es_genes_4_inloops_full.bed', direction=1)
#NPC_genes_2_loop = process_center_fromfull('npc_genes_2_inloops_full.bed', direction=1)
#NPC_genes_4_loop = process_center_fromfull('npc_genes_4_inloops_full.bed', direction=1)
#ES_enhancers_loop=process_center('ES_enhancers_inloops_full.bed', direction=0)
#NPC_enhancers_loop=process_center('NPC_enhancers_inloops_full.bed', direction=0)


#constit_genes_2_bg = process_center_fromfull('constit_genes_2_notinloops_full.bed', direction=1)
#constit_genes_4_bg = process_center_fromfull('constit_genes_4_notinloops_full.bed', direction=1)
#ES_genes_2_bg = process_center_fromfull('ES_genes_2_notinloops_full.bed', direction=1)
#ES_genes_4_bg = process_center_fromfull('ES_genes_4_notinloops_full.bed', direction=1)
#NPC_genes_2_bg = process_center_fromfull('NPC_genes_2_notinloops_full.bed', direction=1)
#NPC_genes_4_bg = process_center_fromfull('NPC_genes_4_notinloops_full.bed', direction=1)
#ES_enhancers_bg=process_center('ES_enhancers_notinloops_full.bed', direction=0)
#NPC_enhancers_bg=process_center('NPC_enhancers_notinloops_full.bed', direction=0)

#then generate plots
#pileupplotter('ES_genes_2UP_H3K27ac', ES_genes_2, NPC_H3K27ac, ES_H3K27ac)
#pileupplotter('ES_genes_4UP_H3K27ac', ES_genes_4, NPC_H3K27ac, ES_H3K27ac)
#pileupplotter('NPC_genes_2UP_H3K27ac', NPC_genes_2, NPC_H3K27ac, ES_H3K27ac)
#pileupplotter('NPC_genes_4UP_H3K27ac', NPC_genes_4, NPC_H3K27ac, ES_H3K27ac)
#pileupplotter('Constit_genes_2UP_H3K27ac', constit_genes_2, NPC_H3K27ac, ES_H3K27ac)
#pileupplotter('Constit_genes_4UP_H3K27ac', constit_genes_4, NPC_H3K27ac, ES_H3K27ac)
