#!/usr/bin/env python
import numpy as np
# import matplotlib
# font = {'size'   : 10}
# matplotlib.rc('font', **font)
import matplotlib.pyplot as plt
from GGR import utils
from sys import argv
import pandas as pd
from scipy.stats import pearsonr,spearmanr
import seaborn.apionly as sns
import statsmodels.api as sm
from scipy.stats import t

# Positional command-line arguments, in the order they must be given:
#   GR_logFC         path to a tab-separated table of GR logFC values
#   EP300_logFC      path to a tab-separated table of EP300 logFC values
#   GR_timepoint     column of the GR table to analyse
#   EP300_timepoint  column of the EP300 logFC and FDR tables to analyse
#   EP300_FDR        path to a tab-separated table of EP300 FDR values
#   outlmplot        output path of the lmplot figure (a copy with ".png"
#                    replaced by ".pdf" is written as well)
#   outregplot       output path of the regplot figure
#   outcoefs         output path of the tab-separated coefficient table
_ARG_NAMES = ("GR_logFC", "EP300_logFC", "GR_timepoint", "EP300_timepoint",
              "EP300_FDR", "outlmplot", "outregplot", "outcoefs")

# Fail with a readable usage message instead of a bare IndexError when
# arguments are missing. Extra trailing arguments are ignored, as before.
if len(argv) < len(_ARG_NAMES) + 1:
    raise SystemExit("usage: %s %s" % (
        argv[0], " ".join("<%s>" % name for name in _ARG_NAMES)))

(GR_logFC,
 EP300_logFC,
 GR_timepoint,
 EP300_timepoint,
 EP300_FDR,
 outlmplot,
 outregplot,
 outcoefs) = argv[1:len(_ARG_NAMES) + 1]


# GR_logFC = "GR.in.EP300.logFC.txt"
# EP300_logFC = "EP300.in.EP300.logFC.txt"
# GR_timepoint = "GR.t3"
# EP300_timepoint ="EP300.t3"
# EP300_FDR = "EP300.in.EP300.FDR.txt"
# outlmplot = "GR_logFC_vs_EP300_logFC.t3.lmplot.png"
# outregplot ="GR_logFC_vs_EP300_logFC.t3.regplot.png"
# outcoefs = "GR_logFC_vs_EP300_logFC.t3.coefs.txt"

# GR_logFC = "GR.in.EP300.logFC.txt"
# EP300_logFC = "EP300.in.EP300.logFC.txt"
# EP300_FDR = "EP300.in.EP300.FDR.txt"
# GR_timepoint = "GR.t1"
# EP300_timepoint = "EP300.t1"
# outlmplot = "GR_logFC_vs_EP300_logFC.t1.lmplot.png"
# outregplot = "GR_logFC_vs_EP300_logFC.t1.regplot.png"
# outcoefs = "GR_logFC_vs_EP300_logFC.t1.coefs.txt"

# Read the three tab-separated input tables; the first column of each file
# becomes the row index.  Up to here GR_logFC / EP300_logFC / EP300_FDR hold
# the file paths given on the command line.
input_tables = [pd.read_csv(table_path, sep="\t", index_col=0)
                for table_path in (GR_logFC, EP300_logFC, EP300_FDR)]

# subset each table to the requested timepoint column and keep it as a plain
# numpy array; from here on the three names hold the data, not the paths
# NOTE(review): the arrays are combined element-by-element further down, which
# presumably relies on all three files listing the same sites in the same
# row order -- confirm against the upstream pipeline.
timepoint_columns = (GR_timepoint, EP300_timepoint, EP300_timepoint)
GR_logFC, EP300_logFC, EP300_FDR = [
    np.array(tbl[column_name])
    for tbl, column_name in zip(input_tables, timepoint_columns)
]

##########################
# create dataframe that contains all data
# convenient for seaborn


def _regulation_label(fc, fdr):
    """Label one EP300 site from its logFC and FDR.

    Returns 'up' for FDR < 0.1 with positive logFC, 'down' for FDR < 0.1
    with negative logFC, and 'non' for everything else.
    """
    if fdr < 0.1:
        if fc > 0:
            return 'up'
        if fc < 0:
            return 'down'
    return 'non'


data = pd.DataFrame({
    "regulation": [_regulation_label(fc, fdr)
                   for fc, fdr in zip(EP300_logFC, EP300_FDR)],
    "EP300_logFC": EP300_logFC,
    "GR_logFC": GR_logFC,
})

##########################
# create lmplot: scatter of all EP300 sites coloured by regulation, overlaid
# with one regression line fitted to all sites (black) and one fitted to
# _each_ group of EP300 sites by regulation (up, down, non)

regulation_palette = dict(up="#e6a025",
                          down="#4491c5",
                          non="gray")

# lmplot builds its own figure, so no separate plt.subplots() figure is
# created here: an extra figure would stay empty and never be saved.
g = sns.lmplot(x="GR_logFC",
               y="EP300_logFC",
               hue="regulation",
               fit_reg=False,
               scatter_kws={'alpha': 0.2,
                            'rasterized': True},
               data=data,
               palette=regulation_palette)
g.set(ylim=(-3, 4.5),
      xlim=(-1, 6))

# (group label, line colour); None stands for "all sites".
regression_lines = [(None, "black"),
                    ("up", regulation_palette["up"]),
                    ("down", regulation_palette["down"]),
                    ("non", regulation_palette["non"])]
for group, line_color in regression_lines:
    if group is None:
        group_data = data
    else:
        group_data = data[data['regulation'] == group]
    if group_data.empty:
        # nothing to fit for this group; skip its line instead of failing
        continue
    sns.regplot(x="GR_logFC",
                y="EP300_logFC",
                ax=g.axes[0, 0],
                scatter=False,
                line_kws={'color': line_color,
                          "lw": 1},
                data=group_data)

# The scatter points are already rasterized through scatter_kws above;
# `rasterized` is not a savefig() option, so it is not passed here.
plt.savefig(outlmplot)
plt.savefig(outlmplot.replace(".png", ".pdf"))


##########################
# create regplot fitting overall regression to all EP300 sites regardless of regulation

# One colour per site, following the same up / down / non labelling stored in
# data['regulation'].
# NOTE: the regplot code below is commented out, so in the code shown here
# `colors` is not used by any live statement and nothing is written to
# `outregplot`.
regulation_colors = {'up': '#e6a025', 'down': '#4491c5', 'non': 'gray'}
colors = [regulation_colors[site_label] for site_label in data['regulation']]

# fig,ax = plt.subplots(figsize=(3,3))
# ax.axhline(y=0,c='black',lw=1, zorder=-1)
# ax.set_aspect("equal")
# g = sns.regplot(x="GR_logFC", 
#                 y="EP300_logFC", 
#                 scatter_kws={'alpha':0.4, 
#                              'color':colors, 
#                              's':1,
#                              'rasterized':True},
#                 line_kws={'color':"red", 
#                           "lw":1}, 
#                 data=data)
# sns.despine()
# plt.savefig(outregplot)
# plt.savefig(outregplot.replace(".png",".pdf"))


##########################
# fit regression

# set-up: design matrix (constant column added to GR logFC) and response
# (EP300 logFC), for all sites and for each group of sites

X = sm.add_constant(GR_logFC)
y = EP300_logFC

# boolean site selectors, computed once and reused for X and y
significant = EP300_FDR < 0.1
is_down = (EP300_logFC < 0) & significant
is_up = (EP300_logFC > 0) & significant
# NOTE(review): "non" is FDR > 0.2 here, whereas the 'non' label in `data`
# covers every site that is neither 'up' nor 'down'.  Sites with
# 0.1 <= FDR <= 0.2 are therefore in none of the three subset regressions --
# confirm this gap is intended.
is_non = EP300_FDR > 0.2

X_down, y_down = X[is_down], y[is_down]
X_up, y_up = X[is_up], y[is_up]
X_non, y_non = X[is_non], y[is_non]


##############################
# Fit EP300 logFC on GR logFC by OLS, once for all sites and once per group
# of sites, print each fit, and write slope + standard error to `outcoefs`.

_BANNER = "= = = = = = = = = = = = = = = = = = = = = = = = = "


def _fit_and_report(title, response, design):
    """Fit OLS of `response` on `design`, print the fit, return the slope.

    `design` is expected to hold the GR logFC term in its last column, so the
    last coefficient is the GR logFC slope.

    Returns a tuple (slope, standard error of the slope).
    """
    print(_BANNER)
    print(title)
    print(_BANNER)

    results = sm.OLS(response, design).fit()
    print(results.summary())

    # two-sided pvalue = Prob(abs(t)>tt): twice the upper tail, evaluated on
    # the residual degrees of freedom of THIS fit (not the size of the full
    # data set, which is wrong for the subset regressions)
    p = 2 * t.sf(np.abs(results.tvalues[-1]), results.df_resid)
    print('T-test, GR logFC as coefficient, p = %0.2e' % p)
    print("")

    return results.params[-1], results.bse[-1]


# (row label in the output table, printed title, response, design matrix);
# the order defines the row order of `outcoefs`
regressions = [
    ("all",
     'Regression, ALL EP300 logFC on ALL GR logFC', y, X),
    ("decreased",
     'Regression, EP300 logFC decreased binding on GR logFC', y_down, X_down),
    ("non-dynamic",
     'Regression, EP300 logFC non-dynamic binding on GR logFC', y_non, X_non),
    ("increased",
     'Regression, EP300 logFC increased binding on GR logFC', y_up, X_up),
]

betas = []
beta_ses = []
for _, title, response, design in regressions:
    beta, beta_se = _fit_and_report(title, response, design)
    betas.append(beta)
    beta_ses.append(beta_se)

coefs = pd.DataFrame({'beta': betas, 'beta_se': beta_ses})
coefs.index = [row_label for row_label, _, _, _ in regressions]
coefs.to_csv(outcoefs, sep="\t", index=True)

from statsmodels.formula.api import ols

# Type-2 ANOVA table for the model with GR logFC, the EP300 regulation group
# and their interaction as terms.
anova_banner = "= = = = = = = = = = = = = = = = = = = = = = = = = "
anova_title = 'ANOVA, EP300 logFC on GR logFC + "regulation" of EP300 + GR logFC x "regulation" of EP300'
for banner_line in (anova_banner, anova_title, anova_banner):
    print(banner_line)

anova_formula = 'EP300_logFC ~ GR_logFC*C(regulation)'
# which is the same as:
# 'EP300_logFC ~ GR_logFC + C(regulation) + GR_logFC*C(regulation)'
lm = ols(anova_formula, data=data).fit()
table = sm.stats.anova_lm(lm, typ=2)
print(table)
