#!/usr/bin/env python
import argparse
import re

import pandas as pd
from scipy.stats import pearsonr,ttest_ind
import matplotlib.pyplot as plt
from beeswarm import *

# Command-line interface. RawTextHelpFormatter prints the description and help
# strings verbatim, so line breaks are written as real newlines below.
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

This script is specifically for time series data and tests whether there is a
significant association between time and surrogate variables (SVs). While SVs
are orthogonal to main effects between individual time points and the baseline
00h time point, they may still be associated strongly with time or may strongly
distinguish 00h from all other time points (in the case of GR binding). This
script plots the association and tests for significance across the time course
with a t-test of the Pearson correlation and a t-test comparing the mean at 00h
to all other samples. If either p-value is below 0.01, the SV should be
discarded; the 1-based indices of all such SVs are printed to stdout as a
comma-separated list.

""")

required = parser.add_argument_group('required arguments')

##################################################
# required args:
required.add_argument("-i", "--mat",
    help="required, input count matrix (tab-separated; only its sample/column names are used).",
    required=True)

required.add_argument("--surrogate_variables",
    help="Surrogate variables in a tab-separated, header-less, index-less dataframe\n"
         "of dimension number_samples x number_surrogate_variables",
    required=True)

required.add_argument("--outerr", type=str,
    help="required, tab-separated output table with, per SV, the Pearson correlation\n"
         "coefficient, its p-value and the p-value of the 00h vs. all other samples t-test",
    required=True)
required.add_argument("--outplot", type=str,
    help="required, output path for the plot of every SV against time", required=True)

##################################################
args = parser.parse_args()

# Sample names come from the count-matrix header: the first column is used as
# the row index, so every remaining column label is one sample.
samples = list(pd.read_csv(args.mat, sep='\t', index_col=0).columns)

# SV table: rows are assumed to be in the same order as `samples`
# (NOTE(review): nothing here checks that; pandas raises on a length mismatch).
surrogate_variables = pd.read_csv(args.surrogate_variables, sep='\t', header=None)
surrogate_variables.index = samples

# Table used for testing and plotting: the sample name, then SV1..SVk.
scatter_mat = pd.DataFrame({'sample': list(surrogate_variables.index)})
for sv_number, column in enumerate(surrogate_variables.columns, 1):
    scatter_mat['SV%d' % sv_number] = list(surrogate_variables[column])

def convert_sample_name_to_timepoint(sample):
    """Return the time-point label embedded in a sample name.

    The label is the second-to-last ``.``-separated field of *sample*,
    e.g. ``"A549.01h30m.rep1"`` -> ``"01h30m"``.  A name without any ``.``
    raises ``IndexError``.
    """
    # Splitting off at most the last two fields is enough to reach the
    # second-to-last one, whatever the prefix contains.
    return sample.rsplit('.', 2)[-2]

def timepoint_conversion_to_numeric(t):
    """Convert a time-point label such as ``"00h"`` or ``"01h30m"`` to hours.

    The label must be a non-empty concatenation of ``<digits><unit>``
    components, where the unit is ``h`` (hours), ``m`` (minutes) or ``s``
    (seconds), e.g. ``"00h"``, ``"120h"``, ``"01h30m"``, ``"00h00m30s"``.
    The components are summed.  Whole-hour totals are returned as ``int``
    (``4`` rather than ``4.0``); other totals are returned as ``float``.

    Raises:
        ValueError: if *t* is empty or is not made up entirely of such
            components.  An empty label is rejected rather than treated as
            0 h, so that it cannot be mistaken for the baseline time point.
    """
    # Same factors as before, so existing labels convert to identical floats.
    units_in_hours = {"h": 1, "m": 1/60., "s": 1/3600.}
    if not re.match(r"(?:\d+[hms])+\Z", t):
        raise ValueError("unrecognised time point label %r; expected e.g. "
                         "'00h', '01h30m' or '00h00m30s'" % (t,))

    hrs = 0
    for amount, unit in re.findall(r"(\d+)([hms])", t):
        hrs += float(amount) * units_in_hours[unit]

    if hrs == int(hrs):
        hrs = int(hrs)

    return hrs

# Derive each sample's time (in hours) from its name, unless a 'timepoint'
# column is already present.
if 'timepoint' not in scatter_mat.columns:
    sample_timepoints = []
    for sample_name in scatter_mat['sample']:
        label = convert_sample_name_to_timepoint(sample_name)
        sample_timepoints.append(timepoint_conversion_to_numeric(label))
    scatter_mat['timepoint'] = sample_timepoints

# An SV is flagged for removal when either test is significant at this level.
P_VALUE_THRESHOLD = 0.01

# Test every SV for association with time and write one tab-separated row per
# SV. The output file is opened once and closed by `with`, even on error.
SVs_to_drop = []
is_t00 = scatter_mat['timepoint'] == 0.0  # baseline (00h) samples
with open(args.outerr, 'w') as f:
    f.write('SV\tPearson correlation coefficient\tPearson p-value\tt00 vs. not t00, T-test p-value\n')
    for i in range(surrogate_variables.shape[1]):
        sv = 'SV' + str(i+1)
        # linear association with time across the whole course
        r, r_p = pearsonr(scatter_mat[sv], scatter_mat['timepoint'])
        # baseline (00h) vs. every later time point
        not_t00 = scatter_mat.loc[~is_t00, sv].values
        t00 = scatter_mat.loc[is_t00, sv].values
        _, t_p = ttest_ind(not_t00, t00)
        f.write('%s\t%s\t%s\t%s' % (sv, r, r_p, t_p) + '\n')

        # Compare each p-value on its own: min() with a NaN operand (e.g. a
        # constant SV) can return NaN and hide the other, significant value.
        if r_p < P_VALUE_THRESHOLD or t_p < P_VALUE_THRESHOLD:
            SVs_to_drop.append(str(i+1))

# 1-based indices of the SVs to discard, comma-separated, on stdout.
print(','.join(SVs_to_drop))

# Common y-range (with a small margin) so every SV panel is on the same scale.
ymin = surrogate_variables.values.min() - 0.1
ymax = surrogate_variables.values.max() + 0.1

n_svs = surrogate_variables.shape[1]
# Distinct time points in increasing order: the beeswarm groups and the x tick
# labels, identical for every panel.
xs = sorted(set(scatter_mat['timepoint']))

# squeeze=False keeps `axes` a 2-D array even for a single SV, so every panel
# is addressed uniformly as axes[i, 0] with no special case for one row.
fig, axes = plt.subplots(ncols=1, nrows=n_svs,
                         figsize=(4, 2*n_svs),
                         sharey=True, squeeze=False)
for i in range(n_svs):
    sv = 'SV' + str(i+1)
    # one list of SV values per time point, in the order of `xs`
    groups = [list(scatter_mat[scatter_mat['timepoint'] == x][sv]) for x in xs]
    _, ax = beeswarm(groups, ax=axes[i, 0], ylim=(ymin, ymax), method="center", labels=xs, col=["black"])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.set_ylabel(sv)
    if i == n_svs - 1:
        # only the bottom panel carries the shared time axis
        ax.set_xlabel('Time (hrs)')
        ax.xaxis.set_ticks_position('bottom')
        ax.set_xticklabels(labels=xs, rotation=45, ha='right', fontsize=8)
    else:
        ax.set_xticklabels([])
        ax.spines['bottom'].set_visible(False)
        ax.xaxis.set_ticks_position('none')

# leave room for the rotated tick labels on the bottom panel
plt.subplots_adjust(bottom=0.2)
plt.savefig(args.outplot)
