import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix as coo

# Cell types (time points) to compare, in order. The pair loop further down
# visits every (earlier, later) pair but pools all pairs into the same '+'/'-'
# aggregates, so the output is only meaningful with exactly two entries.
cells = ['H1', 'endoderm']
T = len(cells)

# Chromosomes processed: chr1 .. chr22.
chroms = ['chr{}'.format(k) for k in range(1, 23)]

# Window geometry: first in base pairs, then converted to bins.
radius = 500000          # half-width of the window around each boundary (bp)
binsize = 10000          # bin size (bp)
r = radius // binsize    # half-width in bins
n = 2 * r                # full window width in bins

# Running sums of (cell1 - cell2) windows and how many windows were added,
# keyed by the sign of the boundary change ('+' or '-').
diffmats, numSigDbs = {}, {}
for direction in '+-':
    diffmats[direction] = np.zeros((n, n))
    numSigDbs[direction] = 0

def _observed_over_expected(mat, max_dist):
    """Return ``mat`` restricted to pairs < ``max_dist`` bins apart, as O/E.

    ``mat`` has bin-index columns ``i``/``j`` and a count column ``cnt``.
    The expected count for a distance is the mean ``cnt`` over the rows at
    that distance; if the input lists only non-zero pairs this is the mean of
    non-zero contacts (NOTE(review): confirm the input format). The returned
    frame has ``cnt`` replaced by observed/expected and keeps the helper
    columns ``dist`` and ``cnt_expected``. The input frame is not modified.
    """
    mat = mat.assign(dist=np.absolute(mat.i - mat.j))
    mat = mat.loc[mat.dist < max_dist]
    expected = mat.groupby('dist')['cnt'].mean().reset_index()
    mat = pd.merge(mat, expected, how='inner', on='dist',
                   suffixes=('', '_expected'))
    mat['cnt'] = mat.cnt / mat.cnt_expected
    return mat


def _window_matrix(mat, window_start, size):
    """Return a dense ``size`` x ``size`` matrix of ``mat.cnt`` for the bins
    ``window_start`` .. ``window_start + size - 1`` (inclusive, both axes).

    Bin indices are shifted so ``window_start`` becomes row/column 0. Window
    cells with no row in ``mat`` stay 0; duplicate (i, j) rows are summed by
    ``coo_matrix``.
    """
    window_end = window_start + size - 1
    inside = (mat.i.between(window_start, window_end)
              & mat.j.between(window_start, window_end))
    sub = mat.loc[inside]
    return coo((sub.cnt, (sub.i - window_start, sub.j - window_start)),
               shape=(size, size)).toarray()


for chro in chroms:
    print(chro)
    # Genomic start of the first bin; used to turn region starts into bin
    # indices. NOTE(review): assumes the i/j indices in the mat files are
    # 0-based relative to this first bin — confirm.
    bed = pd.read_csv('bed/H1_{}_10kb.bed'.format(chro), sep="\t", header=None,
                      names=['chro', 'start', 'end', 'idx'], index_col=False)
    offset = bed.start.values[0]

    boundaries = pd.read_csv('output/{}/significant_boundaries_summit_only.txt'.format(chro),
                             sep="\t", index_col=False)
    boundaries.fillna(0, inplace=True)

    # O/E contact tables (pairs fewer than n bins apart) for every cell type.
    mats = {}
    for cell in cells:
        raw = pd.read_csv('mat/{}_{}_10kb.txt'.format(cell, chro), sep="\t", header=None,
                          names=['i', 'j', 'cnt'], index_col=False)
        mats[cell] = _observed_over_expected(raw, n)

    for t1 in range(T):
        cell1 = cells[t1]
        mat1 = mats[cell1]
        for t2 in range(t1 + 1, T):
            cell2 = cells[t2]
            mat2 = mats[cell2]

            # Direction of change per bin: boundaries[cell1] - boundaries[cell2].
            # With 0/1 boundary flags, +1 means present in cell1 and absent in
            # cell2, and -1 means present in cell2 and absent in cell1.
            # NOTE(review): assumes `boundaries` rows align 1:1 with `bed` rows.
            tmpbed = bed.copy()
            tmpbed['direction'] = boundaries[cell1].values - boundaries[cell2].values

            sigdb_path = 'output/{}/{}_vs_{}_significantly_differential_boundary_regions.txt'.format(
                chro, cell1, cell2)
            try:
                sigdb = pd.read_csv(sigdb_path, sep="\t", header=None, usecols=[0, 1, 2],
                                    names=['chro', 'start', 'end'], index_col=False)
            except pd.errors.EmptyDataError:
                # No significant regions for this pair on this chromosome;
                # skip it instead of aborting the whole run.
                print('no significant differential boundary regions in {}'.format(sigdb_path))
                continue
            sigdb['start_idx'] = ((sigdb.start - offset) / binsize).astype(int)
            # Attach the per-bin 'direction' (and 'idx') from tmpbed. Rows with
            # no match get NaN and are skipped below.
            sigdb = pd.merge(sigdb, tmpbed, how='left', on=['chro', 'start', 'end'])

            for row in sigdb.itertuples(index=False):
                if row.direction > 0:
                    key = '+'
                elif row.direction < 0:
                    key = '-'
                else:
                    # direction 0 or NaN (unmatched): not aggregated
                    continue
                # Window of n = 2r bins: [start_idx - r, start_idx + r - 1].
                window_start = row.start_idx - r
                diff = (_window_matrix(mat1, window_start, n)
                        - _window_matrix(mat2, window_start, n))
                diffmats[key] += diff
                numSigDbs[key] += 1

# Average the accumulated difference windows per direction, mirror them into
# full symmetric matrices and save one heatmap per direction.
for direction in ['+', '-']:
    numSigDb = numSigDbs[direction]
    print(numSigDb)
    if numSigDb == 0:
        # Averaging over zero windows would divide by zero and fill the matrix
        # with NaN; there is nothing meaningful to plot.
        print('no significant differential boundaries for direction {}; skipping plot'.format(direction))
        continue
    # Mean (cell1 - cell2) difference per boundary.
    diffmat = diffmats[direction] / numSigDb
    # Mirror into a full symmetric matrix; the diagonal is counted twice by the
    # sum, so halve it. Done out of place so it does not rely on NumPy's
    # overlap handling for `a += a.T`. NOTE(review): assumes each (i, j) pair
    # appears once in the input (e.g. upper triangle only) — confirm.
    diffmat = diffmat + diffmat.T
    diffmat[np.diag_indices_from(diffmat)] /= 2
    # Store back so diffmats holds the final matrix, as the in-place version did.
    diffmats[direction] = diffmat
    sns.heatmap(diffmat, cmap='RdBu_r', center=0)
    plt.savefig('diffmat' + direction)
    plt.close()
