# 04_observedExpected.py
# Jonathan M Galazka
#
# First, builds an 'expected' heatmap composed of the median interaction frequency
# at each genomic distance until the background level has been reached, at which
# point the background value is used instead.
# Second, calculates the ratio of the observed values to these expected values.
#
# Usage: python ./04_observedExpected.py
# Use after 03_correctBatch.py to filter datasets and produce uncorrected heatmaps.
# Be sure to use a matching directory structure as in 01_mapBatch.py


import numpy as np
from scipy.optimize import curve_fit
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.cm as cm


def kth_diag_indices(a, k):
	"""Return the (row, column) index arrays of the k-th diagonal of square array *a*.

	Adapted from the Stack Overflow answer by Hans Then:
	http://stackoverflow.com/questions/10925671/numpy-k-th-diagonal-indices

	NOTE: the sign convention is the reverse of numpy's ``offset``.
	A positive ``k`` selects the diagonal *below* the main diagonal
	(row = col + k), and a negative ``k`` selects the one *above* it
	(col = row + |k|). ``k == 0`` is the main diagonal.

	Raises ValueError (from ``np.diag_indices_from``) if *a* is not square.
	"""
	main_rows, main_cols = np.diag_indices_from(a)
	if k == 0:
		return main_rows, main_cols
	if k > 0:
		# below the main diagonal: drop the first k rows and the last k columns
		return main_rows[k:], main_cols[:-k]
	# above the main diagonal: drop the last |k| rows and the first |k| columns
	return main_rows[:k], main_cols[-k:]
 
# given an array, finds the background by fitting a curve y = a^(-bx) + c to
# a plot of medians vs. distance. background is c.  		    
def getBackground(array):

	arraywidth, arraylength = array.shape
	
	diagmedians = list()

	for i in range(0, arraywidth):
		diagonalmedian = np.median(np.diagonal(array, offset = i))
		diagmedians.append(diagonalmedian)
		
	x = np.asarray(range(0, len(diagmedians)))

	def fitFunc(x, a, b, c):
		return a*np.exp(-b*x) + c

	fitParams, fitCovariances = curve_fit(fitFunc, x, diagmedians)
	
	background = fitParams[2]
	
	return background, diagmedians

def calculateExpected(array, background, cutoff):
	"""Build the 'expected' heatmap for a square contact matrix.

	Cell (i, j) receives the median of the diagonal at distance |i - j|
	from the main diagonal while |i - j| < cutoff. At larger distances the
	constant *background* is used instead.

	Args:
		array: square 2-D heatmap. Medians are taken over its upper
			diagonals (``np.diagonal`` with non-negative offsets) and applied
			to both triangles, so a symmetric matrix is assumed.
		background: value used for every distance >= cutoff.
		cutoff: first distance at which *background* replaces the median.
			``None`` (what detectReachBackground returns when the background
			is never reached) means no cutoff: medians are used at every
			distance.

	Returns:
		float ndarray with the same shape as *array*.

	Raises:
		ValueError: if *array* is not square.
	"""
	n_rows, n_cols = array.shape
	if n_rows != n_cols:
		raise ValueError('calculateExpected expects a square heatmap, got shape %s' % (array.shape,))

	# Median contact value at each distance from the main diagonal.
	diagmedians = np.array(
		[np.median(np.diagonal(array, offset=d)) for d in range(n_rows)],
		dtype=float,
	)

	if cutoff is None:
		# Background never reached: every distance (< n_rows) keeps its median.
		cutoff = n_rows

	# |i - j| for every cell, i.e. its distance from the main diagonal.
	idx = np.arange(n_rows)
	distance = np.abs(idx[:, np.newaxis] - idx[np.newaxis, :])

	return np.where(distance < cutoff, diagmedians[distance], float(background))

def rollingLinearRegression(medians, windowsize):
	"""Return the least-squares slope of every full window of *medians*.

	Window i covers ``medians[i : i + windowsize]`` and is regressed against
	x = 0 .. windowsize - 1, so ``slopes[i]`` is the local trend starting
	at diagonal offset i. All ``len(medians) - windowsize + 1`` full
	windows are used. The result is an empty list if *medians* is shorter
	than one window.

	Raises:
		ValueError: if windowsize < 2 (a line needs at least two points).
	"""
	if windowsize < 2:
		raise ValueError('windowsize must be at least 2, got %r' % (windowsize,))

	# The x-values are the same for every window, so build them once.
	window_index = np.arange(windowsize)

	# "+ 1" so the last full window (starting at len - windowsize) is included.
	return [
		stats.linregress(window_index, medians[i : i + windowsize])[0]
		for i in range(len(medians) - windowsize + 1)
	]
	
def detectReachBackground(slopes, threshold=-0.001, offset=10):
	"""Return the distance at which the diagonal medians stop decreasing.

	Scans *slopes* (e.g. the output of rollingLinearRegression) for the
	first index whose slope is strictly greater than *threshold*, i.e. the
	medians have flattened out. It returns that index plus *offset*.

	The default offset (10) equals the window size this script passes to
	rollingLinearRegression. Presumably it is meant to place the cutoff at
	the end of the first flat window -- TODO confirm.

	Returns:
		int cutoff, or None if no slope exceeds *threshold* (the background
		is never reached).
	"""
	for slope_num, slope in enumerate(slopes):
		if slope > threshold:
			return slope_num + offset
	return None
		
		
datasets = ['dataset1','dataset2']

resolutions = [50, 40, 20, 10] # resolution in kb

num_chromosomes = 7 # number of leading chromosomes whose intra-chromosomal blocks get expected values

for dataset in datasets:

	for resolution in resolutions:

		res_string = str(resolution * 1000)

		# Every input and output for this dataset/resolution lives in one directory.
		basepath = '/Volumes/HD/HiC/' + dataset + '/' + dataset + '-ic-hm-' + res_string + '-txt/'
		datasetpath = basepath + 'heatmap'
		chr_starts_path = basepath + 'chromosomeStarts'
		savepath = basepath + 'obs-exp.txt'

		# Chromosome i spans bins [chr_starts[i], chr_ends[i]); each chromosome ends where the
		# next one starts, so the file must hold at least num_chromosomes + 1 boundaries.
		chr_boundaries = np.loadtxt(chr_starts_path, delimiter=' ').astype(int)
		if chr_boundaries.size < num_chromosomes + 1:
			raise ValueError('%s holds %d values; expected at least %d chromosome boundaries'
				% (chr_starts_path, chr_boundaries.size, num_chromosomes + 1))
		chr_starts = chr_boundaries[0:num_chromosomes]
		chr_ends = chr_boundaries[1:num_chromosomes + 1]

		array = np.loadtxt(datasetpath, delimiter=' ')

		width, height = array.shape

		zero_indices = np.where(array == 0) # Remember cells with no observed contacts

		array = array + 1.0 # Add a pseudocount of 1 to every cell

		background, diagmedians = getBackground(array) # Detect the background of the array

		slopes = rollingLinearRegression(diagmedians, 10) # Slope of the medians in sliding windows of 10 diagonals

		cutoff = detectReachBackground(slopes) # Distance at which the slope flattens out (background reached)
		if cutoff is None:
			# Background never reached: use the median at every distance instead of crashing.
			cutoff = len(diagmedians)

		expectedarray = calculateExpected(array, background, cutoff) # Calculate expected array

		# Start from the background everywhere (inter-chromosomal regions), then overwrite each
		# intra-chromosomal block with the distance-dependent expected values.
		background_array = np.full((width, height), background, dtype=float)
		for start, end in zip(chr_starts, chr_ends):
			background_array[start:end, start:end] = expectedarray[start:end, start:end]

		observed_expected = array / background_array # Observed / expected ratio

		observed_expected[zero_indices] = 0 # Reset cells that had no observed contacts to zero

		np.savetxt(savepath, observed_expected, fmt='%.10f', delimiter=' ')

		# Uncomment to plot
		#heatmap = plt.imshow(np.log2(observed_expected), cmap=cm.coolwarm)
		#heatmap.set_clim(-3,3)
		#plt.colorbar(heatmap)
		#plt.show()