"""Draw a Hi-C figure from a config file.

This script takes a config file containing parameters and a list of
input files, and draws a figure with a Hi-C matrix heatmap, TADs, and
a set of signal tracks.

usage: python pyrmaid.py [config].ini
dependencies: heatmap.py, bar.py, readdata.py (local modules);
numpy, matplotlib, seaborn
"""
import matplotlib.patches as patches
import sys
import csv
import numpy as np
from configparser import ConfigParser
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import gridspec
from matplotlib.colors import Normalize
import heatmap
import bar
import readdata as rd
import seaborn as sns

def main(config_file):
	# Read in parameters from config file:
	config = ConfigParser()
	config.read(config_file)
	chro = config.get('coordinates','chro')
	start = config.getint('coordinates','start')
	end = config.getint('coordinates', 'end')
	binsize = config.getint('coordinates','binsize')
	contexts = config.get('content','contexts').split(',')
	signals = config.get('content','signals').split(',')
	nCols = len(contexts)
	nSignals = len(signals)
	# hic file names
	hicfiles = []
	for i in range(nCols):
		fname = config.get('hicfiles','file'+str(i))
		hicfiles.append(fname)
	#pval = np.loadtxt(config.get('files','pval'))
	score = np.loadtxt(config.get('files','score'))
	sig = np.loadtxt(config.get('files','sig'))
	# output file:
	output=config.get('files','output')
	
	#Drawing a grid to fit all plots
	fig = plt.figure(figsize = (nCols*4,4.8))
	height_ratios = [3.5, 0.5, 0.1]
	gs = gridspec.GridSpec(figure=fig,ncols=nCols,nrows=nSignals+1,
		height_ratios=height_ratios)

	#Draw subplots to fill in
	heatmaps = []
	tracks = []
	for i in range(nCols):
		heatax = fig.add_subplot(gs[0,i],frame_on=False)
		heatmaps.append(heatax)
	for i in range(nSignals):
		row = []
		for j in range(nCols):
			trackax = fig.add_subplot(gs[i+1,j],frame_on=False)
			row.append(trackax)
		tracks.append(row)

	# draw HiC matrices
	for i in range(nCols):
		C = rd.readSymmetricSparseMatrix(hicfiles[i],0,200)
		np.fill_diagonal(C, 0)
		C = np.sqrt(C)
		#C = np.log2(C+1)
		vmax = np.percentile(C, 95)
		#vmax = C.max()
		ax = heatmaps[i]
		print('Drawing heatmap {:d}...'.format(i+1))
		#heatmap.drawRotatedHalfHeatmapUp(fig,ax,C,vmax)
		#ax.axvline(x=101,ymin=-0.3,c='grey',linestyle='--',alpha=0.5,linewidth=3,clip_on=False)
		sns.heatmap(C, ax=ax, cmap='Reds', vmax=vmax, cbar=False)
		ax.tick_params(right=False,left=False,top=False,labelbottom=False,labelleft=False,bottom=False)		

	# draw rest of tracks
	#colors=sns.husl_palette(nSignals+1,l=0.55)
	max_score = np.percentile(score, 99.7) 
	colors = ['#AA4499','#88ccee','#44aa99','#117733']
	markers = ['*','^','d','p']
	for j in range(nCols):
		bar.drawTrack(tracks[0][j],score[:,j],'#0077bb',max_score)
		#bar.drawTrack(tracks[1][j],pval[:,j],'#009988',2)
		bar.drawMarks(tracks[1][j],200,np.nonzero(sig[:,j])[0],'#ee3377','*')
		for x in range(2):
			ax = tracks[x][j]
		#ax.set_xticklabels([int(x) for x in np.arange(start,end,(end-start)/4)],ha='left')
			ax.tick_params(which='both',
				right=False,left=False,top=False,bottom=False,
				labelbottom=False,labelleft=False)

	gs.tight_layout(fig)
	plt.savefig(output +'.png')

if __name__ == '__main__':
	# Print a usage hint instead of an IndexError traceback when the
	# config-file argument is missing; extra arguments are ignored as before.
	if len(sys.argv) < 2:
		sys.exit('usage: python {} [config].ini'.format(sys.argv[0]))
	main(sys.argv[1])
