# telomereSignal.py
# Jonathan M. Galazka, Andrew D. Klocko
# 2015
# Usage: python ./telomereSignal.py
#
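# Plots the average telomere-telomere Hi-C signal, i.e. contacts between the left and the
# right telomere of the same chromosome, averaged over all chromosomes except LGV, whose
# left telomere is not mapped. Also prints the "strength" of these interactions: the summed
# signal in a telomere-telomere corner, in a background corner, and their ratio.
# Also plots the whole genome with each chromosome's telomere-telomere corner outlined.
# Also plots the telomere-telomere maps of all 7 chromosomes for each strain.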

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.cm as cm

# --- Analysis settings -------------------------------------------------------
DATASETS = ['NMF39', 'N3944', 'N5580']  # datasets to open
BASE_DIR = '/Volumes/HD/HiC2/'  # root holding one sub-directory per dataset
RESOLUTION = 40  # multiplied by 1000 to form the '_ic-<N>-diag-txt' directory suffix
N_CHROMOSOMES = 7  # chromosomes analysed per dataset
EXCLUDED_CHROMOSOME = 4  # 0-based index of LGV; left out of the average (left telomere unmapped)
TELOMERE_BINS = 20  # bins taken from each chromosome end as its "telomere"
BOX_BINS = 4  # side length of the corner boxes that are outlined and summed
COLOR_LIMITS = (-3, 3)  # colour range for every log2(obs/exp) heatmap


def dataset_dir(dataset, resolution=RESOLUTION):
    """Return the directory holding *dataset*'s diagonal text output (trailing '/')."""
    return (BASE_DIR + dataset + '/' + dataset + '_ic-' + str(resolution * 1000)
            + '-diag-txt/')


def load_chromosome_bounds(path, n_chromosomes=N_CHROMOSOMES):
    """Return (starts, ends) bin indices of the first *n_chromosomes* chromosomes.

    The space-delimited file at *path* is read with np.loadtxt; value i is used as
    the start of chromosome i and value i + 1 as its exclusive end, so at least
    n_chromosomes + 1 values are expected.
    """
    bounds = np.loadtxt(path, delimiter=' ').astype(int)
    return bounds[0:n_chromosomes], bounds[1:n_chromosomes + 1]


def extract_telomere_maps(array, starts, ends, width=TELOMERE_BINS):
    """Return, per chromosome, the right-telomere x left-telomere block of *array*.

    Rows cover the last *width* bins of the chromosome and columns its first
    *width* bins, so entry [-1, 0] pairs the two outermost chromosome-end bins.
    """
    return [array[end - width:end, start:start + width]
            for start, end in zip(starts, ends)]


def average_telomere_map(telomeres, excluded=EXCLUDED_CHROMOSOME):
    """Average same-shaped telomere maps bin by bin, ignoring NaNs.

    The map at index *excluded* (LGV by default) is left out. A bin that is NaN
    in every remaining map stays NaN (np.nanmean emits a RuntimeWarning for it).
    """
    kept = [telomere for idx, telomere in enumerate(telomeres) if idx != excluded]
    return np.nanmean(np.array(kept), axis=0)


def corner_signal(mean_array, box=BOX_BINS):
    """Return (telomere_sum, background_sum) over two box x box corners of *mean_array*.

    The telomere-telomere corner is bottom-left (last rows, first columns) and the
    background corner is top-right (first rows, last columns): exactly the bins
    outlined on the averaged plot. Slice ends are exclusive, so the last row/column
    must be included explicitly. Any NaN bin makes the corresponding sum NaN.
    """
    n_rows, n_cols = mean_array.shape
    bottom_left = mean_array[n_rows - box:n_rows, 0:box]
    top_right = mean_array[0:box, n_cols - box:n_cols]
    return np.sum(bottom_left), np.sum(top_right)


def log2_heatmap(ax, data):
    """Draw log2(*data*) on *ax* with the shared colour map and colour limits."""
    image = ax.imshow(np.log2(data), interpolation='nearest', cmap=cm.coolwarm)
    image.set_clim(*COLOR_LIMITS)
    return image


def hide_ticks(ax):
    """Remove every tick and tick label from *ax*."""
    # Booleans, not 'off': matplotlib deprecated (and later dropped) the 'on'/'off'
    # strings, and any non-empty string is truthy, which would keep the ticks.
    ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False,
                   right=False, labelbottom=False, labeltop=False, labelleft=False,
                   labelright=False)


def outline_box(ax, col, row_end, size=BOX_BINS, lw=1.0):
    """Outline columns col..col+size-1 and rows row_end-size..row_end-1 of an image."""
    # Pixel centres sit on integers, hence the -0.5; the negative height draws upwards.
    ax.add_patch(Rectangle((col - 0.5, row_end - 0.5), size, -size,
                           fill=False, alpha=1, lw=lw))


def main():
    """Analyse every dataset, print telomere signal strengths and show the figures."""
    # One averaged telomere-telomere map per dataset; squeeze=False keeps a 2-D grid.
    fig, axes = plt.subplots(nrows=len(DATASETS), ncols=1, squeeze=False)

    for row, dataset in enumerate(DATASETS):
        data_dir = dataset_dir(dataset)
        starts, ends = load_chromosome_bounds(data_dir + 'chromosomeStarts')

        array = np.loadtxt(data_dir + 'obs-exp.txt', delimiter=' ')
        # Zeros aren't real contacts, just bins with poor quality: mask them as NaN.
        array[array == 0.0] = np.nan

        # Whole-genome map with each chromosome's extreme telomere-telomere corner boxed.
        _fig2, ax2 = plt.subplots(1)
        log2_heatmap(ax2, array)
        for start, end in zip(starts, ends):
            outline_box(ax2, start, end, lw=2.0)

        telomeres = extract_telomere_maps(array, starts, ends)
        mean_array = average_telomere_map(telomeres)

        # Averaged telomere-telomere map (LGV excluded) with signal/background boxes.
        ax = axes[row, 0]
        log2_heatmap(ax, mean_array)
        hide_ticks(ax)
        ax.set_title(dataset)
        n_rows, n_cols = mean_array.shape
        outline_box(ax, 0, n_rows)                    # telomere-telomere (bottom-left)
        outline_box(ax, n_cols - BOX_BINS, BOX_BINS)  # background (top-right)

        signal, background = corner_signal(mean_array)
        print('Strain: ' + dataset)
        print('Telomere signal: ' + str(signal))
        print('Non telomere signal: ' + str(background))
        print('Ratio: ' + str(signal / background) + '\n')

        # Every chromosome's telomere-telomere map for this strain.
        fig3, axes2 = plt.subplots(nrows=1, ncols=len(telomeres))
        fig3.suptitle(dataset)
        for idx, (ax3, telomere) in enumerate(zip(axes2, telomeres)):
            log2_heatmap(ax3, telomere)
            hide_ticks(ax3)
            ax3.set_title(str(idx + 1))

    plt.show()
    #fig.savefig('/Users/galazkaj/Desktop/telomere_heatmap_wt_dim5_hpo.pdf')


if __name__ == '__main__':
    main()

	