#!/usr/bin/env python
import os
from sys import argv

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import interpolate

# Usage: <script> <proportions.tsv> <output.png>
# A PDF copy of the plot is written next to <output.png>.
if len(argv) < 3:
    raise SystemExit("usage: %s <proportions.tsv> <output.png>" % argv[0])

prop_path = argv[1]
outplot = argv[2]

# Example inputs used during development:
# prop_path = "num_differential.in.enhancers.with_factor_peaks.up_novel.cumulative.scaled_to_1.txt"
# outplot = "num_differential.in.enhancers.with_factor_peaks.up_novel.cumulative.scaled_to_1.spline.png"
# prop_path = "heatmap_proportion_of_discoveries_over_time.up.txt"
# outplot = "/data/reddylab/projects/GGR/results/integrative/data_resource_manuscript/spline_test.png"

# Time points of the data columns, in file order
# (NOTE(review): presumably hours of treatment — confirm).
TIMEPOINTS = [0.5, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12]

# One row per dataset (first column is the row label), one column per time point.
prop = pd.read_csv(prop_path, sep="\t", index_col=0)
if prop.shape[1] != len(TIMEPOINTS):
    raise ValueError("expected %d time-point columns in %s, found %d"
                     % (len(TIMEPOINTS), prop_path, prop.shape[1]))
prop.columns = TIMEPOINTS
# Anchor every curve at 0 for time 0, then restore chronological column order.
prop[0] = 0
prop = prop[sorted(prop.columns)]

x = list(prop.columns)
# Dense evaluation grid over the full observed range [0, 12], both endpoints
# included. It uses the same 12/10000 step as before; np.arange silently
# dropped the last time point.
xnew = np.linspace(x[0], x[-1], 10001)

# Tableau 20 palette as 0-255 RGB triples, in the original order
# (dark/light pairs alternating).
_TABLEAU20_RGB = [
    (31, 119, 180), (174, 199, 232),
    (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138),
    (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213),
    (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210),
    (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141),
    (23, 190, 207), (158, 218, 229),
]

# matplotlib expects RGB channels as floats in [0, 1].
tableau20 = [tuple(channel / 255. for channel in rgb) for rgb in _TABLEAU20_RGB]


# For each dataset (row of `prop`), fit a PCHIP interpolant through its values
# at the time points `x`, draw it on the dense grid `xnew`, and record the
# grid time at which the curve comes closest to 0.5.
medians = {}
fig, ax = plt.subplots(figsize=(6, 4))
for i, idx in enumerate(prop.index):
    # .loc replaces the deprecated .ix indexer, which was removed in pandas 1.0.
    y = list(prop.loc[idx])
    # PCHIP preserves monotonicity of the data. A splrep(k=5) fit was tried
    # earlier; unlike PCHIP, it can oscillate between the knots.
    interp = interpolate.PchipInterpolator(x, y)
    ynew = interp(xnew)
    medians[idx] = xnew[np.abs(ynew - 0.5).argmin()]

    # Cycle the palette so that more than 20 datasets don't raise IndexError.
    ax.plot(xnew, ynew, lw=2, color=tableau20[i % len(tableau20)], label=idx)

ax.set_xlim((-0.5, 12.5))
ax.set_ylim((0, 1.05))
ax.legend(loc="lower right", fontsize=8)
fig.savefig(outplot)
# Derive the PDF name from the file extension. str.replace(".png", ".pdf")
# rewrote ".png" anywhere in the path and re-saved the same file for other
# extensions.
fig.savefig(os.path.splitext(outplot)[0] + ".pdf")

import operator

# Report datasets ordered from earliest to latest time at which their
# interpolated curve is closest to 0.5, as a tab-separated table on stdout.
sorted_medians = sorted(medians.items(), key=operator.itemgetter(1))

# Single-argument print(...) produces identical output on Python 2 and 3;
# the former print statements were a SyntaxError under Python 3.
print("Dataset\tmedian response time")
for TF, median in sorted_medians:
    print("%s\t%0.3f" % (TF, median))