#!/usr/bin/env python
##################################################
#  multivariate_regression_expression_prediction.py
#
#  @ /data/reddylab/projects/GGR/analyses/group_general/top_level_scripts/multivariate_regression_expression_prediction.py
#
##################################################
import numpy as np
import statsmodels.api as sm
import pandas as pd
from sklearn.preprocessing import scale
import argparse

# Command-line interface.  All array arguments are paths to .npy files that
# are read with np.load further down.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
Fit one regression per column of the predictor arrays and write the R2,
adjusted R2, p-values, coefficients and standard errors of every fit to a
tab-separated file.

""")

##################################################
# required args:

parser.add_argument("--y", required=True,
                    help='required, numpy array of response variable')
parser.add_argument("--Xs", nargs='+', required=True,
                    help='required, list of numpy arrays representing predictor variables')
parser.add_argument("--names", nargs='+', required=True,
                    help='required, names of predictor variables (one per --Xs array)')
parser.add_argument("--out", required=True,
                    help='required, file to which to save r2')

##################################################
# optional args:

parser.add_argument("--covariates", nargs='+',
                    help='optional, covariates')
parser.add_argument("--covariate_names", nargs='+',
                    help='covariate names, required when --covariates is given')
# Restricting the choices rejects a misspelled model type here instead of
# letting it surface later as an undefined model.
parser.add_argument("-l", "--model_type", type=str, default="linear",
                    choices=["linear", "logistic", "multinomial"],
                    help='optional, choose <linear|logistic|multinomial> regression, default=linear')
# parser.add_argument("--predictor_names", type=str, nargs='+', default=None,
#                     help="""required, if multiple predictors given to script, backwards feature selection
# will automatically eliminate non-significant predictors.  It can be useful to know which predictors are eliminated,
# in which case predictor names should be given to script """, 
#                     dest="predictor_names")
# NOTE(review): --model_selection is parsed but does not appear to be read
# anywhere else in this script -- confirm whether it is still needed.
parser.add_argument('--model_selection', dest='model_selection', action='store_true')
parser.add_argument('--standardize', dest='standardize', action='store_true',
                    help='optional, standardize predictors and covariates (and the response for linear models)')
parser.add_argument('--regularized', dest='regularized', action='store_true')

args = parser.parse_args()

# Validate name counts up front: the names are only used to label the output
# columns, so a mismatch would otherwise go unnoticed until after every
# regression has already been run.
if len(args.names) != len(args.Xs):
    parser.error("--names must give exactly one name per --Xs array "
                 "(got %d names for %d arrays)" % (len(args.names), len(args.Xs)))
if args.covariates:
    if not args.covariate_names:
        parser.error("--covariate_names is required when --covariates is given")
    if len(args.covariate_names) != len(args.covariates):
        parser.error("--covariate_names must give exactly one name per --covariates array "
                     "(got %d names for %d arrays)"
                     % (len(args.covariate_names), len(args.covariates)))

# ##################################################
# #####
# # test block
# #####

# args = type("args", (object,), {})()
# args.y = "/data/reddylab/projects/GGR/results/expression_prediction/iter0/expression/mean_logFC.protein_coding.filtered_by_expression.npy"
# args.Xs = ["logFC_pseudocount_1/BCL3.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/CEBPB.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/cJun.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/CTCF.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/EP300.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/FOSL2.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/GR.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/H3K27ac.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/H3K4me1.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/H3K4me2.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/H3K4me3.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/H3K9me3.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/HES2.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/JunB.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1/DNase.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy"]
# args.names = ["BCL3","CEBPB","cJun","CTCF","EP300","FOSL2","GR","H3K27ac","H3K4me1","H3K4me2","H3K4me3","H3K9me3","HES2","JunB","DNase"]
# args.covariates = ["/data/reddylab/projects/GGR/results/expression_prediction/iter0/expression/protein_coding.filtered_by_expression.logTPM.t00.npy",\
#                    "/data/reddylab/projects/GGR/data/expression_prediction/iter0/additional_predictors/protein_coding.filtered_by_expression.GC_content.upstream_3000_bp_promoter.npy"]
# args.covariate_names = ["basal_expression","GC_content"]
# args.model_type = 'linear'
# args.model_selection = False
# args.regularized = False
# args.standardize = True

# ##################################################

#####
# set up data
#####

# Load the response variable from the .npy file given by --y.
y = np.load(args.y)

# Load one predictor array per --Xs path, in command-line order.
# NOTE(review): the loop variable X stays bound after this loop (and after the
# standardization loop below); any later module-level use of the bare name X
# sees the last value assigned here, not X_list.
X_list = []
for X in args.Xs:
    X_list.append(np.load(X))

if args.standardize:
    # Standardize every predictor array with sklearn's scale().  NaNs are
    # zero-filled for the call and written back afterwards, so missing
    # entries are still NaN in the standardized copy.
    # NOTE(review): the zero-filled entries presumably take part in the
    # mean/variance that scale() computes -- confirm this is acceptable for
    # columns that contain NaNs.
    X_list_standarized = []
    for X in X_list:
        nan_idx = np.isnan(X)
        X_scaled = X.copy()
        X_scaled[nan_idx] = 0
        X_scaled = scale(X_scaled)
        X_scaled[nan_idx] = np.nan
        X_list_standarized.append(X_scaled)
        
    X_list = X_list_standarized[:]
    # The response is standardized only for linear regression; for the other
    # model types y is left as loaded.
    if args.model_type == "linear":
        y = scale(y)

if args.covariates:
    # Load one covariate array per --covariates path.
    cov_list = []
    for cov in args.covariates:
        cov_list.append(np.load(cov))
    
    # Stack the loaded arrays and transpose.
    # assumes each covariate file is 1-D, giving one column per covariate --
    # TODO confirm
    covariates = np.array(cov_list).T
    if args.standardize:
        covariates = scale(covariates)

# add intercept term
# (a column of ones with one row per element of y)
intercept = np.vstack(np.array([1] * len(y)))

#####
# set up model
#####

# Map --model_type onto the statsmodels model class that is fitted per bin.

# Classes need MORE than this many members to be kept for a multinomial fit.
MIN_CLUSTER_SIZE = 200

if args.model_type == "linear":
    model = sm.OLS
elif args.model_type == "logistic":
    model = sm.Logit
elif args.model_type == "multinomial":
    model = sm.MNLogit
    # only keep predictor and response for clusters with
    # greater than MIN_CLUSTER_SIZE members
    labels, counts = np.unique(y, return_counts=True)
    ks = labels[counts > MIN_CLUSTER_SIZE]
    # Compute the row mask once, from the unfiltered response, and apply the
    # same mask to everything that is later stacked row-wise with y.
    # (Previously the mask was recomputed from the already-filtered y and
    # applied to a leftover loop variable, so X_list, the intercept and the
    # covariates kept their original number of rows.)
    keep = np.in1d(y, ks)
    y = y[keep]
    X_list = [X_pred[keep] for X_pred in X_list]
    intercept = intercept[keep]
    if args.covariates:
        covariates = covariates[keep]
else:
    # Without this branch an unknown type left `model` undefined and failed
    # later with an unrelated NameError.
    raise ValueError("unknown --model_type %r; expected linear, logistic or multinomial"
                     % (args.model_type,))

#####
# run full model
#####

# Per-bin results, keyed by 1-based bin number.  Each value is the flat vector
# [R2, adjusted R2, p-values..., coefficients..., standard errors...], with
# the per-term entries in design-matrix column order
# (intercept, predictors, covariates).
res_dict = {}

if args.regularized:
    # The regularized branch never produced output: it fitted a single model
    # on a leftover loop variable rather than a per-bin design matrix, threw
    # the fit away, and left res_dict empty so the script crashed when the
    # output columns were assigned.  Fail immediately with a clear message.
    raise NotImplementedError(
        "--regularized is not supported: no per-bin regularized fit is "
        "implemented, so there would be no results to write")

# The covariate block is the same for every bin, so prepare it once.
covariate_columns = [covariates] if args.covariates else []

for i in range(X_list[0].shape[1]):
    # Design matrix for bin i: intercept, column i of every predictor array,
    # then the covariates.
    predictor_columns = [X_pred[:, i:i + 1] for X_pred in X_list]
    X_for_regr = np.hstack([intercept] + predictor_columns + covariate_columns)

    # Progress indicator (parenthesised form prints identically on Python 2).
    print(i)

    if np.any(np.isnan(X_for_regr)):
        # A bin with any missing predictor value is not fitted; it gets an
        # all-NaN row of the same width as a fitted one.
        res_dict[i + 1] = np.full(2 + 3 * X_for_regr.shape[1], np.nan)
        continue

    res = model(y, X_for_regr).fit()
    # NOTE(review): rsquared / rsquared_adj are linear-regression result
    # attributes; the logistic and multinomial result objects presumably do
    # not provide them (they expose a pseudo R2 instead) -- confirm before
    # relying on those model types.
    res_dict[i + 1] = np.r_[res.rsquared, res.rsquared_adj,
                            res.pvalues, res.params, res.bse]

# res_dict maps bin number -> result vector, so the frame is built with one
# column per bin and transposed to get one row per bin.
res_df = pd.DataFrame(res_dict).T

# Term names in design-matrix column order: intercept, predictors, covariates.
predictors = ["intercept"] + list(args.names)
if args.covariates:
    if not args.covariate_names:
        # Previously this surfaced as a TypeError from adding None to a list.
        raise ValueError("--covariate_names must be given when --covariates is used")
    predictors = predictors + list(args.covariate_names)

column_names = ['R2', 'R2_adjusted']
for suffix in ("_p", "_beta", "_beta_se"):
    column_names.extend(name + suffix for name in predictors)

# Check the width explicitly so a wrong number of names is reported in terms
# of the script's arguments rather than as a pandas length-mismatch error.
if len(column_names) != res_df.shape[1]:
    raise ValueError(
        "cannot label results: %d column names were built from --names/"
        "--covariate_names but the result table has %d columns"
        % (len(column_names), res_df.shape[1]))

res_df.columns = column_names
res_df.to_csv(args.out, sep="\t", index=True, header=True)