#!/usr/bin/env python
import os
from sys import argv

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from palettable import tableau

# Command-line interface:
#   argv[1]  comma-separated list of tab-separated result files; each is read
#            with its first column as the index and must provide 'R2' and
#            'R2_se' columns
#   argv[2]  comma-separated list of labels, one per result file, same order
#   argv[3]  output path for the plot; a PDF copy is also written
if len(argv) < 4:
    raise SystemExit('usage: %s RESULTS NAMES OUTPLOT' % argv[0])
results = argv[1].split(',')
names = argv[2].split(',')
outplot = argv[3]
# The results are paired with their labels via zip(), which would silently
# drop the surplus entries of the longer list; fail loudly instead.
if len(results) != len(names):
    raise SystemExit('got %d result files but %d names'
                     % (len(results), len(names)))

# results = "mvlr_results_pseudocount_0.01_ctrl_subtracted/elastic_net.TFs_only.ctrl_subtracted.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.all_summed_exclude_proximal.uniform.mean_logFC_pseudocount_0.01.txt,mvlr_results_pseudocount_0.01_ctrl_subtracted/elastic_net.TFs_only.ctrl_subtracted.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.mean_logFC_pseudocount_0.01.txt,mvlr_results_pseudocount_0.01_ctrl_subtracted/elastic_net.TFs_only.ctrl_subtracted.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_shuffled_summed_exclude_proximal.uniform.mean_logFC_pseudocount_0.01.txt,mvlr_results_pseudocount_0.01_ctrl_subtracted/elastic_net.TFs_only.ctrl_subtracted.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.dhss_extended_no_EP300_no_CTCF_size_matched_to_enhancers_summed_exclude_proximal.uniform.mean_logFC_pseudocount_0.01.txt".split(",")
# names = "all_summed_exclude_proximal,enhancer_extended_summed_exclude_proximal,enhancer_extended_shuffled_summed_exclude_proximal,dhss_extended_no_EP300_no_CTCF_size_matched_to_enhancers_summed_exclude_proximal".split(",")
# outplot = "plots/elastic_net.TFs_only.ctrl_subtracted.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.uniform.mean_logFC_pseudocount_0.01.R2.png"

# results = "mvlr_results_pseudocount_0.1/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.all_summed.laplace.weights_1_100.mean_logFC_pseudocount_0.1.txt,mvlr_results_pseudocount_0.1/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.all_summed_exclude_proximal.laplace.weights_1_100.mean_logFC_pseudocount_0.1.txt,mvlr_results_pseudocount_0.1/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.laplace.weights_1_100.mean_logFC_pseudocount_0.1.txt,mvlr_results_pseudocount_0.1/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_inverse_summed_exclude_proximal.laplace.weights_1_100.mean_logFC_pseudocount_0.1.txt,mvlr_results_pseudocount_0.1/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_shuffled_summed_exclude_proximal.laplace.weights_1_100.mean_logFC_pseudocount_0.1.txt,mvlr_results_pseudocount_0.1/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.dhss_extended_no_EP300_no_CTCF_size_matched_to_enhancers_summed_exclude_proximal.laplace.weights_1_100.mean_logFC_pseudocount_0.1.txt".split(',')
# names = "all_summed,all_summed_exclude_proximal,enhancer_extended_summed_exclude_proximal,enhancer_extended_inverse_summed_exclude_proximal,enhancer_extended_shuffled_summed_exclude_proximal,dhss_extended_no_EP300_no_CTCF_size_matched_to_enhancers_summed_exclude_proximal".split(',')
# N = float(11896)
# outplot = "plots/simple_mvlr.TFs_only.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.laplace.weights_1_100.mean_logFC_pseudocount_0.1.R2.png"

# Gather the 'R2' and 'R2_se' columns of every result file into two frames,
# one column per label.  Each file is parsed only once.  Both frames take
# their row index from the first file; later files are aligned to it.
R2_df = pd.DataFrame()
R2_se_df = pd.DataFrame()
for name, result in zip(names, results):
    result_df = pd.read_csv(result, sep='\t', index_col=0)
    R2_df[name] = result_df['R2']
    R2_se_df[name] = result_df['R2_se']

x = R2_df.index
colors = tableau.Tableau_10.mpl_colors

# standard error was not collected on all R2; keep only the rows for which
# every result file has a standard error (dropna drops a row if any is NaN).
R2_se_df = R2_se_df.dropna()

# One R2 curve per result, plus error-bar markers at the rows that have a
# standard error for every result.
fig, ax = plt.subplots(figsize=(5, 5))
for i, col in enumerate(R2_df.columns):
    # Cycle the palette so more results than palette colours still plot.
    color = colors[i % len(colors)]
    ax.plot(x, R2_df[col], color=color, lw=1, label=col)
    # R2_se_df's index holds row labels that also appear in R2_df, so select
    # by label with .loc (the former .ix indexer was removed in pandas 1.0).
    ax.errorbar(R2_se_df.index,
                np.asarray(R2_df[col].loc[R2_se_df.index]),
                yerr=np.asarray(R2_se_df[col]),
                fmt='.', color=color)

ax.set_xlabel('Size of region considered around TSS (kb)')
ax.set_ylabel('R2')
ax.set_xlim((0, max(x)))
# Minimal frame: ticks on bottom/left only, detached left/bottom spines.
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_position(('outward', 10))
ax.spines['bottom'].set_position(('outward', 10))
ax.legend(frameon=False, fontsize=8, loc="best")

# Anchor the y axis at zero, keeping the automatic upper limit.
ax.set_ylim((0, ax.get_ylim()[1]))
plt.tight_layout()
plt.savefig(outplot)
# Swap only the final extension so a PDF copy is written even when outplot
# does not end in '.png' (str.replace would then re-save the same file).
plt.savefig(os.path.splitext(outplot)[0] + '.pdf')