# getHetFlankInteractions.py
# Jonathan M. Galazka, Andrew D. Klocko
# 2015
# Usage: python ./getHetFlankInteractions.py
#
# For each heterochromatin region, outputs the strength of the interaction
# between its two flanks relative to the interactions within each flank.


import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import sys


# Load the chromosome start offsets and one contact heatmap per dataset.
datasets = ['NMF39_1', 'NMF39_2', 'N3944_1', 'N3944_2', 'N5580_1', 'N5580_2']
resolution = 10000
res_string = str(resolution)

# Offsets are read from the NMF39 output only; this presumes every dataset
# was binned identically at this resolution (TODO confirm).
chr_starts_path = '/Volumes/HD/HiC2/NMF39/NMF39_ic-{0}-diag-txt/ChromosomeStarts'.format(res_string)
chr_starts_array = np.loadtxt(chr_starts_path, delimiter=' ')
chr_offsets = chr_starts_array.astype(int)
chr_starts = np.transpose(chr_offsets[0:7])   # entries 0..6
chr_ends = np.transpose(chr_offsets[1:8])     # entries 1..7 (the next start)


def _load_heatmap(name):
    """Return the heatmap for dataset *name* with pseudocounts added.

    A constant 0.1 is added first, then the mean of that shifted matrix, which
    keeps entries away from zero for the log2 plots (assuming the raw heatmap
    is non-negative).
    """
    heatmap_path = '/Volumes/HD/HiC2/{0}/{0}_ic-hm-{1}-diag-txt/heatmap'.format(name, res_string)
    shifted = np.loadtxt(heatmap_path, delimiter=' ') + 0.1
    return shifted + np.mean(shifted)


arrays = [_load_heatmap(name) for name in datasets]

# Read the heterochromatin regions: one tab-separated line per region holding
# a linkage-group index, a start and a stop. The stop is stored +1 (presumably
# turning an inclusive end into an exclusive slice bound).
het_file = '/Users/galazkaj/Dropbox/hicScripts/Figure4/heterochromatin_all.txt'

lgs = []
starts = []
ends = []
# `with` guarantees the file is closed; the original handle was never closed.
with open(het_file, 'r') as het_filehandle:
    for line in het_filehandle:
        fields = line.strip().split('\t')
        if not fields[0]:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in int().
            continue
        lgs.append(int(fields[0]))
        starts.append(int(fields[1]))
        ends.append(int(fields[2]) + 1)
	
# For each heterochromatin region, compare how strongly its left and right
# flanks contact each other with how strongly each flank contacts itself.
# One tab-separated line is printed per region: the label "lg_start_end"
# followed by one ratio per entry of `arrays`. A log2 heatmap of the window
# around the region is then shown for the datasets in PLOT_DATASETS.
FLANK_PAD = 15             # minimum number of bins kept on each side of a region
PLOT_DATASETS = (0, 2, 4)  # indices into `arrays` that are plotted (wt1, dim51, hpo1)

# Smallest matrix dimension, used to reject windows that fall off the matrix.
n_bins = min(hm.shape[0] for hm in arrays)

for lg, start, end in zip(lgs, starts, ends):
    size = end - start
    chr_start = chr_starts[lg]
    actual_start = chr_start + start
    actual_end = chr_start + end
    label = str(lg) + '_' + str(start) + '_' + str(end)

    # Each flank is `size` bins wide, so the window must reach at least that
    # far past the region. A fixed 15-bin pad gave negative (wrapping) or
    # truncated flank slices whenever size > 15.
    pad = max(FLANK_PAD, size)
    lo = actual_start - pad
    hi = actual_end + pad
    if lo < 0 or hi > n_bins:
        # A negative start would wrap around to the end of the matrix and an end
        # past the matrix would silently shorten the flanks, so skip the region.
        sys.stderr.write('skipping ' + label + ': window [' + str(lo) + ', ' + str(hi)
                         + ') lies outside the ' + str(n_bins) + '-bin matrix\n')
        continue
    # NOTE(review): the window is not clipped to chr_ends[lg], so flanks of
    # regions near a chromosome edge can extend into the neighbouring chromosome.

    windows = [hm[lo:hi, lo:hi] for hm in arrays]

    # Offsets inside a window: the region occupies [pad, pad + size).
    # `gap` is the number of bins skipped between the region and each flank.
    gap = 0
    left_s, left_e = pad - gap - size, pad - gap
    # The right flank mirrors the left one on the far side of the region; the
    # previous code started it at `pad + gap`, i.e. inside the region itself.
    right_s, right_e = pad + size + gap, pad + 2 * size + gap

    ratios = []
    for window in windows:
        # Contacts between the right flank (rows) and the left flank (columns).
        flanks = np.sum(window[right_s:right_e, left_s:left_e])
        # Diagonal blocks are halved, presumably because the matrix is symmetric
        # and each pair appears twice within such a block.
        left = np.sum(window[left_s:left_e, left_s:left_e]) / 2.0
        right = np.sum(window[right_s:right_e, right_s:right_e]) / 2.0
        ratios.append(str(flanks / ((left + right) / 2.0)))

    print(label + '\t' + '\t'.join(ratios))

    fig, axes = plt.subplots(nrows=1, ncols=len(PLOT_DATASETS))
    # Title shows the ratios of the datasets actually plotted in the panels.
    fig.suptitle(' '.join([str(lg), str(start), str(end)]
                          + [ratios[idx] for idx in PLOT_DATASETS]))
    for ax, idx in zip(axes, PLOT_DATASETS):
        heat = ax.imshow(np.log2(windows[idx]), interpolation='nearest', cmap=cm.coolwarm)
        heat.set_clim(1, 7)

    plt.show()
	
