# observedExpectedMethod.py
# Jonathan M. Galazka, Andrew D. Klocko
# 2015
# Usage: python ./observedExpectedMethod.py
# Outputs 5 plots that demonstrate how observed/expected heatmaps are calculated.

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from mpl_toolkits.axes_grid1 import make_axes_locatable
import scipy.stats as stats
import matplotlib.cm as cm

# given array (a) and a diagonal index (k) returns all array indices
# along that diagonal
def kth_diag_indices(a, k):
	"""
	Return the (rows, cols) index arrays of the k-th diagonal of square array `a`.

	Adapted from this Stack Overflow post, credited to Hans Then:
	http://stackoverflow.com/questions/10925671/numpy-k-th-diagonal-indices

	The sign convention is the reverse of numpy's `np.diagonal`: a positive k
	selects cells BELOW the main diagonal (row == col + k) and a negative k
	selects cells ABOVE it (col == row - k). Code that only needs
	abs(row - col) is unaffected by this. `np.diag_indices_from` raises
	ValueError when `a` is not square.
	"""
	idx_r, idx_c = np.diag_indices_from(a)
	if k == 0:
		return idx_r, idx_c
	if k > 0:
		# rows start at k, columns at 0: (k, 0), (k + 1, 1), ...
		return idx_r[k:], idx_c[:-k]
	# k < 0: rows start at 0, columns at |k|: (0, |k|), (1, |k| + 1), ...
	return idx_r[:k], idx_c[-k:]

# given a colormap (cmap), shifts the midpoint		
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max when you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax/(vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name of the new colormap; it is also registered under
          this name (replacing any earlier registration of it).

    Returns
    -------
      The new matplotlib LinearSegmentedColormap.
    '''
    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # 257 evenly spaced sample points of the original colormap in [start, stop]
    reg_index = np.linspace(start, stop, 257)

    # 257 anchor positions in the new colormap: the first 128 are squeezed into
    # [0, midpoint) and the last 129 into [midpoint, 1], so sample 128 (the
    # center of [start, stop]) lands exactly on `midpoint`
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # plt.register_cmap was deprecated in matplotlib 3.6 and removed in 3.9;
    # use the colormap registry when it supports registration, otherwise fall
    # back to the legacy function on older matplotlib versions.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        # force=True keeps the legacy behaviour of replacing an earlier
        # registration of the same (non-builtin) name instead of raising
        registry.register(newcmap, name=name, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap
 
# given an array, finds the background by fitting a curve y = a*exp(-b*x) + c to
# a plot of medians vs. distance. background is c.
def getBackground(array):
	"""
	Estimate the background level of a 2-D contact matrix.

	For every offset d in 0..nrows-1 the median of np.diagonal(array, offset=d)
	(cells with col - row == d) is taken. The curve y = a*exp(-b*x) + c is then
	fit to those medians against x = d with scipy's curve_fit, using its default
	starting point. The fitted asymptote c is the background.

	Returns (background, diagmedians), where diagmedians is a list holding one
	median per offset. Propagates curve_fit errors (e.g. RuntimeError when the
	fit does not converge).
	"""
	n_rows, _ = array.shape

	def fitFunc(x, a, b, c):
		return a*np.exp(-b*x) + c

	diagmedians = [np.median(np.diagonal(array, offset=d)) for d in range(n_rows)]
	offsets = np.arange(len(diagmedians))
	fitted, _ = curve_fit(fitFunc, offsets, diagmedians)
	return fitted[2], diagmedians

# given an array and background, returns a 2D heatmap of expected values along each diagonal
def calculateExpected(array, background, cutoff):
	"""
	Build a matrix of expected values from the per-distance diagonal medians.

	For every cell (row, col) let d = abs(row - col). Cells with d < cutoff get
	the median of np.diagonal(array, offset=d); all other cells get `background`.

	Parameters
	----------
	array : square 2-D numpy array of observed values.
	background : value assigned to cells at distance >= cutoff.
	cutoff : distance threshold in bins (int or float, compared with `<`).

	Returns
	-------
	numpy float array with the same shape as `array`.

	Raises
	------
	ValueError if `array` is not square.

	A previous version also fit a power law a + k*x**alpha to the medians but
	never used the fitted parameters (and shadowed `k` with a loop variable).
	That dead fit could only cost time or raise on non-convergence, so it was
	removed. The per-cell Python loop is replaced by an equivalent vectorized form.
	"""
	arraywidth, arraylength = array.shape
	if arraywidth != arraylength:
		raise ValueError('calculateExpected requires a square array, got shape %r' % (array.shape,))

	# index i holds the median of the offset-i (upper) diagonal, i.e. distance i
	diagmedians = np.array(
		[np.median(np.diagonal(array, offset=i)) for i in range(arraywidth)],
		dtype=float,
	)

	# distance[row, col] == abs(row - col); for an n x n array every distance
	# is < n, so indexing diagmedians with it is always in range
	positions = np.arange(arraywidth)
	distance = np.abs(positions[:, np.newaxis] - positions[np.newaxis, :])

	expectedarray = np.where(distance < cutoff, diagmedians[distance], background)
	return expectedarray.astype(float)

# determines the slope in windows of size windowsize of a plot of medians	
def rollingLinearRegression(medians, windowsize):
	"""
	Return the least-squares slope of every full sliding window over `medians`.

	Window i covers medians[i : i + windowsize] and is regressed against
	x = 0..windowsize-1, so each slope is in units of "median per bin".
	There are len(medians) - windowsize + 1 full windows. An empty list is
	returned when windowsize > len(medians).

	(The previous loop stopped at len(medians) - windowsize and so silently
	dropped the last full window.)
	"""
	window_index = np.arange(windowsize) # same x values for every window
	n_windows = len(medians) - windowsize + 1
	# linregress returns (slope, intercept, r_value, p_value, std_err)
	return [stats.linregress(window_index, medians[i : i + windowsize])[0] for i in range(n_windows)]
	
def detectReachBackground(slopes, threshold=-0.001, bin_mb=0.04, pad_bins=10):
	"""
	Find the first sliding-window slope that has flattened out.

	Scans `slopes` in order and stops at the first index i where
	slopes[i] > threshold.

	Parameters
	----------
	slopes : sequence of slopes, e.g. from rollingLinearRegression.
	threshold : slopes strictly greater than this count as flat.
	bin_mb : size of one bin in megabases (0.04 corresponds to 40 kb bins).
	pad_bins : number of extra bins added to the returned cutoff.

	Returns
	-------
	(mb, cutoff) with mb = i * bin_mb and cutoff = i + pad_bins, or None when
	no slope exceeds the threshold (callers must handle that case).
	"""
	for slope_num, slope in enumerate(slopes):
		if slope > threshold:
			return slope_num * bin_mb, slope_num + pad_bins
	# no window ever flattened out
	return None
				
# Dataset names to process. Each name is also used as the sub-directory and
# file prefix when the heatmap path is built in the first dataset loop.
datasets = [
	'NMF39_1', 'NMF39_2',
	'N3944_1', 'N3944_2',
	'N5580_1', 'N5580_2',
]

# Plot label for each dataset (passed to plot() as `label=`), in the same
# order as `datasets`.
labels_list = [
	'WT 1', 'WT 2',
	'dim-5 1', 'dim-5 2',
	'hpo 1', 'hpo 2',
]

# Per-dataset accumulators, filled by the processing loops below.
arrays, norm_arrays, expected_arrays = [], [], []
percent_oes, observed_expecteds, cutoffs = [], [], []

###
# set-up 5 figures: figure1 (medians), figure3 (slopes), figure4-6 (heatmaps)

# figure1: raw (left) and normalized (right) diagonal medians vs. distance
figure1 = plt.figure()
ax1 = figure1.add_axes([0.1,0.1,0.35,0.8])
ax1.set_title('Medians')
ax1.set_ylim([0,75])
ax1.set_xlim([-0.1,4])
ax1.set_xlabel('Distance (mb)')
ax1.set_ylabel('Median number of contacts')
ax2 = figure1.add_axes([0.5,0.1,0.35,0.8])
ax2.set_title('Normalized medians')
ax2.set_ylim([0,75])
ax2.set_xlim([-0.1,4])
ax2.set_xlabel('Distance (mb)')
ax2.yaxis.set_ticklabels([]) # same y range as ax1, so its tick labels are hidden

# figure3: sliding-window slopes, with an inset zoomed in around zero slope
figure3 = plt.figure()
main_ax3 = figure3.add_axes([0.1,0.1,0.8,0.8])
inset_ax3 = figure3.add_axes([0.4, 0.25, 0.4, 0.4])
main_ax3.set_title('Slopes')
main_ax3.set_xlim([-0.1,4])
main_ax3.set_ylim([-10,1])
# slopes are plotted against distance in mb (see the x limits and the inset's
# label), so the main axis must not say "bins"
main_ax3.set_xlabel('Distance (mb)')
main_ax3.set_ylabel('Slope (0.4 mb sliding window)')
inset_ax3.set_xlim([-0.1,4])
inset_ax3.set_ylim([-0.5,0.2])
inset_ax3.set_xlabel('Distance (mb)')
inset_ax3.set_ylabel('Slope (0.4 mb sliding window)')

# figure4: normalized observed heatmap
figure4 = plt.figure()
ax4 = figure4.add_subplot(111)
ax4.set_title('Wt normalized')

# figure5: expected heatmap
figure5 = plt.figure()
ax5 = figure5.add_subplot(111)
ax5.set_title('Expected')

# figure6: percent observed over expected heatmap
figure6 = plt.figure()
ax6 = figure6.add_subplot(111)
ax6.set_title('Percent observed over Expected')

###
# cycle through datasets and calculate distance cutoff at which none of the datasets
# have a contact frequency trend any longer

# settings shared by every dataset (loop-invariant, so set once)
resolution = '40000' # resolution of datasets, in bp per bin
lg_to_plot = 1 # lg to plot. zero based.

# first and last bins of chromosomes
if resolution == '40000':
	chr_starts = [0, 245, 357, 489, 640, 801, 907]
	chr_ends = [245, 357, 489, 640, 801, 907, 1014]
elif resolution == '20000':
	chr_starts = [0, 490, 714, 978, 1279, 1601, 1812]
	chr_ends = [490, 714, 978, 1279, 1601, 1812, 2025]
else:
	raise ValueError('no chromosome bin boundaries defined for resolution ' + resolution)

bin_kb = int(resolution) // 1000 # bin size in kb, used to convert bin indices to mb
slope_window = 10 # bins per rolling window (10 x 40 kb = 0.4 mb at 40 kb resolution)

for counter, dataset in enumerate(datasets):

	label = labels_list[counter]

	# dataset path. these must have been exported using the Mirny Lab h5DictToTxt.py script
	datasetpath = '/Volumes/HD/HiC2/' + dataset + '/' + dataset + '_ic-' + resolution + '-diag-txt/heatmap'

	array = np.loadtxt(datasetpath, delimiter=' ') # load array
	array = array + 1.0 # add a pseudocount of 1.0 to get rid of zeros
	norm_array = array / np.mean(array) # normalize by dividing by the mean of the array

	arrays.append(array)
	norm_arrays.append(norm_array)

	background, diagmedians = getBackground(array) # background level and medians of the diagonals
	norm_background, norm_diagmedians = getBackground(norm_array) # same for the normalized array

	slopes = rollingLinearRegression(diagmedians, slope_window) # slopes of 10-bin rolling windows

	# (mb, bins) at which the slope first approaches zero, or None if it never does
	cutoff = detectReachBackground(slopes)
	cutoffs.append(cutoff)

	# convert bin indices to mb for plotting (derived from `resolution`)
	median_mb = [i * bin_kb / 1000.0 for i in range(len(diagmedians))]
	slope_mb = [i * bin_kb / 1000.0 for i in range(len(slopes))]

	ax1.plot(median_mb, diagmedians, label = label) # plot diagonal medians
	ax2.plot(median_mb, norm_diagmedians, label = label) # plot normalized diagonal medians
	main_ax3.plot(slope_mb, slopes, label = label) # plot slopes
	inset_ax3.plot(slope_mb, slopes, label = label) # plot slopes (zoomed inset)
	
###
# determine maximum cutoff of all datasets to use for all datasets.
# each entry of `cutoffs` is the (mb, bins) tuple from detectReachBackground,
# or None when that dataset's slopes never flattened out.
missing = [name for name, c in zip(datasets, cutoffs) if c is None]
if missing:
	raise ValueError('slope never approached zero for: ' + ', '.join(missing))
# take the bins element explicitly; np.amax over the (mb, bins) pairs mixed
# units and only gave this answer because bins always exceeds mb
cutoff_max = max(bins for mb, bins in cutoffs)
# NOTE: despite the name, this adds 0.4 *bins* (the 10-bin padding is already
# included in detectReachBackground's result). Because distances are integers,
# `distance < cutoff_plus10` in calculateExpected also includes distance == cutoff_max.
cutoff_plus10 = cutoff_max + 0.4
cutoff_to_mb = cutoff_plus10 * 0.04 # bins -> mb at 40 kb per bin
main_ax3.plot([cutoff_to_mb,cutoff_to_mb],[-0.5,0.5],color='black')
inset_ax3.plot([cutoff_to_mb,cutoff_to_mb],[-0.1,0.1],color='black')


###
# cycle back through normalized arrays and use cutoff determined above to make
# expected heatmaps. resolution, lg_to_plot, chr_starts and chr_ends are set by
# the first dataset loop and are not redefined here.
for norm_array in norm_arrays:

	norm_background, _ = getBackground(norm_array) # background of the normalized array

	expectedarray = calculateExpected(norm_array, norm_background, cutoff_plus10) # calculate expected array
	expected_arrays.append(expectedarray)

	observed_expected = norm_array / expectedarray # observed over expected ratio
	observed_expecteds.append(observed_expected)

	percent_oe = (100 * observed_expected) - 100 # percent observed over expected (0 = as expected)
	percent_oes.append(percent_oe)

# region of each heatmap to display: chromosome index lg_to_plot (zero based)
lg_span = slice(chr_starts[lg_to_plot], chr_ends[lg_to_plot])

def _attach_colorbar(ax, image):
	"""Draw a colorbar for `image` in a thin axis appended to the right of `ax`."""
	bar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
	plt.colorbar(image, cax=bar_ax)

# plot log2 normalized heatmap
im2 = ax4.imshow(np.log2(norm_arrays[0][lg_span, lg_span]))
_attach_colorbar(ax4, im2)

# plot log2 expected heatmap
im3 = ax5.imshow(np.log2(expected_arrays[0][lg_span, lg_span]))
_attach_colorbar(ax5, im3)

# plot percent observed over expected; the colormap is shifted so that its
# middle maps to 0% within the [lowerlimit, upperlimit] color range
lowerlimit, upperlimit = -100.0, 200.0
midpointvalue = 1 - (upperlimit / (upperlimit + abs(lowerlimit)))
orig_cmap = cm.bwr
shifted_cmap = shiftedColorMap(orig_cmap, start = 0.0, midpoint = midpointvalue, stop = 1.0, name = 'shiftedcmap')
im4 = ax6.imshow(percent_oes[0][lg_span, lg_span], cmap=shifted_cmap)
im4.set_clim(lowerlimit, upperlimit)
_attach_colorbar(ax6, im4)

# write every panel, in the same order as before
panel_files = (
	(figure1, 'panelB.eps'),
	(figure3, 'panelC.eps'),
	(figure4, 'panelA.eps'),
	(figure5, 'panelD.eps'),
	(figure6, 'panelE.eps'),
)
for fig, filename in panel_files:
	fig.savefig(filename)

### end program