#!/usr/bin/env python
from sys import argv
import pandas as pd
from GGR import utils
from scipy.stats import pearsonr

import matplotlib
matplotlib.use('Agg')
font = {'size'   : 8}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt
from beeswarm import *

# Command-line arguments (all required, in this order):
#   1. mat                  tab-separated matrix; only its header row (sample
#                           names, first column treated as the index) is used
#   2. surrogate_variables  headerless tab-separated table with one row per
#                           sample (same order as mat's columns), one column per SV
#   3. covariates           tab-separated sample metadata with a 'sample' column
#   4. outplot              output path for the SV-versus-covariate plot
#   5. outerr               output TSV of per-SV Pearson correlation with time
if len(argv) < 6:
    raise SystemExit('usage: %s <mat> <surrogate_variables> <covariates> <outplot> <outerr>' % argv[0])
mat, surrogate_variables, covariates, outplot, outerr = argv[1:6]

# Example paths used during development:
# mat = "/data/reddylab/projects/GGR/data/rna_seq/quantified_read_counts/iter0_union/featurecounts.genes.counts.selected_reps.protein_coding.txt"
# surrogate_variables = "/data/reddylab/projects/GGR/data/rna_seq/quantified_read_counts/iter0_union/featurecounts.genes.counts.selected_reps.surrogate_variables.protein_coding.txt"
# covariates = "/data/reddylab/projects/GGR/analyses/reports/rna_seq_metadata_iter0.tsv"
# outplot = "/data/reddylab/projects/GGR/results/rna_seq/differential_expression/iter0/plots/surrogate_variables_versus_covariates.protein_coding.png"

# Sample names come from the matrix header; the SV table has no header, so its
# rows are labelled positionally with those names.
samples = list(pd.read_csv(mat, sep='\t', index_col=0).columns)
surrogate_variables = pd.read_csv(surrogate_variables, sep='\t', header=None)
surrogate_variables.index = samples

# Drop the first '_'-separated token of each metadata sample name.
# NOTE(review): presumably this makes the names match the matrix's sample
# names, which are later used as lookup keys; confirm against the data.
covariates = pd.read_csv(covariates, sep='\t')
covariates['sample'] = ['_'.join(sample.split('_')[1:]) for sample in covariates['sample']]

# One row per sample: a 'sample' column followed by SV1..SVk.
scatter_mat = pd.DataFrame()
scatter_mat['sample'] = surrogate_variables.index
for i in range(surrogate_variables.shape[1]):
    scatter_mat['SV' + str(i + 1)] = list(surrogate_variables[i])

def convert_sample_name_to_timepoint(sample):
    """Return the timepoint token embedded in a sample name.

    The token is the text before the first '_'. If that text contains a '.',
    only the piece between the first and second '.' is returned.

    >>> convert_sample_name_to_timepoint('cell.t05_rep1')
    't05'
    >>> convert_sample_name_to_timepoint('t05_rep1')
    't05'
    """
    prefix = sample.partition('_')[0]
    return prefix.split('.')[1] if '.' in prefix else prefix

# Derive a numeric timepoint for each sample from its name.
# NOTE(review): scatter_mat is built earlier with only 'sample' and SV* columns,
# so this guard always holds; it may have been meant to test `covariates`.
if 'timepoint' not in set(scatter_mat.columns):
    scatter_mat['timepoint'] = [
        utils.timepoint_conversion_to_numeric(convert_sample_name_to_timepoint(name))
        for name in scatter_mat['sample']
    ]

# Copy the optional 'lane' / 'batch' covariates onto each sample. The lookup
# key is the sample name with every 'plus' substring removed; a sample absent
# from the metadata raises KeyError.
for cov in ('lane', 'batch'):
    if cov not in set(covariates.columns):
        continue
    lookup = dict(zip(covariates['sample'], covariates[cov]))
    scatter_mat[cov] = [lookup[name.replace('plus', '')] for name in scatter_mat['sample']]

# Test each surrogate variable (SV) for Pearson correlation with the numeric
# timepoint. One row per SV (name, r, p-value) is written to `outerr`.
# SVs with p below the threshold are collected as "SVs to drop", and their
# 1-based indices are printed to stdout as a comma-separated list (an empty
# line if none qualify).
SV_TIME_PVALUE_THRESHOLD = 0.01

SVs_to_drop = []
# Open the report once, rather than reopening it in append mode per SV.
with open(outerr, 'w') as f:
    f.write('SV\tPearson correlation coefficient\tp-value\n')
    for i in range(surrogate_variables.shape[1]):
        sv_name = 'SV' + str(i + 1)
        r, p = pearsonr(scatter_mat[sv_name], scatter_mat['timepoint'])
        f.write('%s\t%s\t%s\n' % (sv_name, r, p))
        # A NaN p-value (e.g. a constant SV) compares False and is kept.
        if p < SV_TIME_PVALUE_THRESHOLD:
            SVs_to_drop.append(str(i + 1))

# print(...) with a single argument behaves identically under Python 2 and 3;
# the bare print statement was a SyntaxError under Python 3.
print(','.join(SVs_to_drop))

#####
# Plot
#####

# Grid of beeswarm panels: one row per SV, one column per available covariate.
# Each panel shows the SV's values grouped by the covariate's distinct levels.
covs = sorted(set(scatter_mat.columns) & set(['timepoint', 'lane', 'batch']))
svs = ['SV' + str(i + 1) for i in range(surrogate_variables.shape[1])]

# Common y-range across all panels, padded by 0.1 on each side.
ymin = surrogate_variables.values.min() - 0.1
ymax = surrogate_variables.values.max() + 0.1

# squeeze=False always returns a 2-D array of Axes, so axes[i][j] is valid for
# every grid shape. With the default squeezing, a single SV row yields a 1-D
# array and axes[0][j] fails; a single cell yields a bare Axes.
fig, axes = plt.subplots(ncols=len(covs), nrows=len(svs), sharey=True, squeeze=False)
for i, sv in enumerate(svs):
    for j, cov in enumerate(covs):
        xs = sorted(set(scatter_mat[cov]))
        d = [list(scatter_mat.loc[scatter_mat[cov] == x, sv]) for x in xs]
        bs, ax = beeswarm(d, ax=axes[i][j], ylim=(ymin, ymax), method="center", labels=xs, col=["black"])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        # Only the first column carries the SV label.
        if j == 0:
            ax.set_ylabel(sv)
        # Only the bottom row shows x tick labels and the covariate name.
        if i == len(svs) - 1:
            ax.set_xlabel(cov)
            ax.xaxis.set_ticks_position('bottom')
            ax.set_xticklabels(labels=xs, rotation=45, ha='right', fontsize=8)
        else:
            ax.set_xticklabels([])
            ax.spines['bottom'].set_visible(False)
            ax.xaxis.set_ticks_position('none')

plt.subplots_adjust(bottom=0.2)
plt.savefig(outplot)
plt.close(fig)
