#!/usr/bin/env python
##################################################
#  elastic_net_expression_prediction.py
#
##################################################
import numpy as np
import pandas as pd
from sklearn import cross_validation
from sklearn.linear_model import ElasticNet
from sklearn.cross_validation import KFold
from sklearn.metrics import r2_score
from sklearn.preprocessing import scale
from scipy.stats import sem
import argparse

# Command-line interface.  RawTextHelpFormatter keeps the line breaks of the
# description and help strings exactly as written here.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
Predict a response variable from several predictor matrices with ElasticNet.

For each requested weight, the matching column of every --Xs array (plus any
covariates) is used to fit an ElasticNet model.  Penalty terms are chosen by
5-fold cross-validation, keeping the sparsest model whose mean R2 is within
one standard error of the best.  R2 and coefficients (optionally with
bootstrap standard errors) are written to --out as a tab-separated table.

""")

##################################################
# required args:

parser.add_argument("--y", help='required, numpy array of response variable', required=True)
parser.add_argument("--Xs", nargs='+',
                    help='required, list of numpy arrays representing predictor variables',
                    required=True)
parser.add_argument("--names", nargs='+',
                    help='required, names of predictor variables',
                    required=True)
parser.add_argument("--out", help='required, file to which to save r2', required=True)

##################################################
# optional args:

parser.add_argument("--covariates", nargs='+',
                    help='optional, covariates')
parser.add_argument("--covariate_names", nargs='+',
                    help='optional, covariate names (needed when --covariates is given)')
parser.add_argument("-l", "--model_type", type=str, default="linear",
                    help='optional, choose <linear|logistic|multinomial> regression, default=linear')
parser.add_argument("--weights_to_fit", nargs='+', type=int,
                    help='optional, which columns of Xs (1-indexed to correspond with "weights")\n'
                         'to fit the model to; saves time on the costly ElasticNet grid search procedure')
parser.add_argument("--bootstrap_weights", nargs='+', type=int,
                    help='optional, which columns of Xs (1-indexed to correspond with "weights") to bootstrap\n'
                         'in order to get standard error estimates of estimated betas')
parser.add_argument("--nbootstrap", type=int, default=1000,
                    help='optional, number of bootstrap replicates used when --bootstrap_weights is given,\n'
                         'default=1000')
parser.add_argument('--standardize', action='store_true',
                    help='optional, standardize the predictors and covariates\n'
                         '(and the response when --model_type is linear)')

args = parser.parse_args()

# Fail fast on inconsistent name lists.  The names are only used to label the
# output columns at the very end of the script, so without these checks a
# mismatch surfaces only after the whole grid search / bootstrap has run
# (as a TypeError or a pandas length-mismatch error), or silently mislabels
# columns when the two mismatches happen to cancel out.
if len(args.names) != len(args.Xs):
    parser.error("--names must give exactly one name per --Xs file "
                 "(got %d names for %d files)" % (len(args.names), len(args.Xs)))
if args.covariates:
    if not args.covariate_names:
        parser.error("--covariate_names is required when --covariates is given")
    if len(args.covariate_names) != len(args.covariates):
        parser.error("--covariate_names must give exactly one name per --covariates file "
                     "(got %d names for %d files)"
                     % (len(args.covariate_names), len(args.covariates)))

# ##################################################
# #####
# # test block
# #####

# args = type("args", (object,), {})()
# args.y = "/data/reddylab/projects/GGR/results/expression_prediction/iter0/expression/mean_logFC.protein_coding.filtered_by_expression.npy"
# args.Xs = ["logFC_pseudocount_1_ctrl_subtracted/BCL3.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/CEBPB.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/cJun.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/CTCF.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/EP300.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/FOSL2.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/GR.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/H3K27ac.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/H3K4me1.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/H3K4me2.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/H3K4me3.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/H3K9me3.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/HES2.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/JunB.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy",
#            "logFC_pseudocount_1_ctrl_subtracted/DNase.ctrl_subtracted.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.enhancer_extended_summed_exclude_proximal.uniform.weights_1_100.mean_logFC_pseudocount_1.npy"]
# args.names = ["BCL3","CEBPB","cJun","CTCF","EP300","FOSL2","GR","H3K27ac","H3K4me1","H3K4me2","H3K4me3","H3K9me3","HES2","JunB","DNase"]
# args.covariates = ["/data/reddylab/projects/GGR/results/expression_prediction/iter0/expression/protein_coding.filtered_by_expression.logTPM.t00.npy",\
#                    "/data/reddylab/projects/GGR/data/expression_prediction/iter0/additional_predictors/protein_coding.filtered_by_expression.GC_content.upstream_3000_bp_promoter.npy"]
# args.covariate_names = ["basal_expression","GC_content"]
# args.bootstrap_weights = range(10,101,10)
# args.weights_to_fit = range(5,101,5)
# args.model_type = 'linear'
# args.standardize = True

# ##################################################

#####
# set up data
#####

# Response vector, read from the .npy file given by --y.
y = np.load(args.y)

# One predictor array per --Xs file, kept in command-line order.
X_list = [np.load(path) for path in args.Xs]

if args.standardize:
    standardized = []
    for raw in X_list:
        # NaN entries are zeroed only so that scale() can run; the very same
        # positions are put back to NaN once the array has been rescaled.
        missing = np.isnan(raw)
        filled = raw.copy()
        filled[missing] = 0
        rescaled = scale(filled)
        rescaled[missing] = np.nan
        standardized.append(rescaled)

    X_list = list(standardized)

    # The response is only rescaled for the linear model type.
    if args.model_type == "linear":
        y = scale(y)

if args.covariates:
    # Stack the loaded covariate arrays and transpose, so that each
    # covariate file contributes along the second axis.
    covariates = np.array([np.load(path) for path in args.covariates]).T
    if args.standardize:
        covariates = scale(covariates)

#####
# run model
#####

# One seeded 5-fold splitter, shared by every penalty setting.
cv = KFold(len(y), n_folds=5, shuffle=True, random_state=1234)

# the following penalty terms should be sufficient
l1_ratios = [0.5, 0.75, .9, .95, .99, 1]
alphas = 10**np.arange(-8, 1).astype('float')

# --weights_to_fit is optional; when it is omitted fit every column of the
# predictor arrays (weights are 1-indexed) instead of failing on None.
if args.weights_to_fit:
    weights_to_fit = args.weights_to_fit
else:
    weights_to_fit = range(1, X_list[0].shape[1] + 1)

# The row layout depends only on whether bootstrapping was requested at all:
#   without bootstrap: [R2, betas...]
#   with bootstrap:    [R2, R2_se, betas..., beta_ses...]
do_bootstrap = bool(args.bootstrap_weights)

# Maps weight (1-indexed) -> result row.
res_dict = {}
for weight in weights_to_fit:
    i = weight - 1  # 0-indexed column matching the 1-indexed weight
    # Column i of every predictor array, side by side (one column per array).
    X_for_regr = np.hstack([np.vstack(X[:, i]) for X in X_list])
    if args.covariates:
        X_for_regr = np.hstack([X_for_regr, covariates])
    n_predictors = X_for_regr.shape[1]

    print("Weight = %s, testing for optimal penalty terms" % (i + 1))

    # skip model building for NaN columns (happens when
    # proximal windows are excluded)
    if np.any(np.isnan(X_for_regr)):
        # Emit an all-NaN row of the right width.  (Previously the
        # non-bootstrap case reused r2 / best_fit from an earlier weight,
        # or raised NameError when there was no earlier weight.)
        n_values = 2 + 2 * n_predictors if do_bootstrap else 1 + n_predictors
        res_dict[weight] = np.full(n_values, np.nan)
        continue

    # Cross-validated R2 for every (l1_ratio, alpha) pair.  Results are kept
    # in grid order so that later tie-breaking is deterministic rather than
    # dependent on dict ordering.
    cv_results = []
    for l1_ratio in l1_ratios:
        for alpha in alphas:
            regr = ElasticNet(l1_ratio=l1_ratio, alpha=alpha, random_state=1234)
            score = cross_validation.cross_val_score(regr, X_for_regr, y, cv=cv, scoring="r2")
            print("\ttesting l1_ratio = %s, alpha = %s, R2 = %s" % (l1_ratio, alpha, score.mean()))
            cv_results.append(((l1_ratio, alpha), score))

    ordered_params = np.array([params for params, _ in cv_results])
    # mean R2 across the folds for each model
    mean_r2 = np.array([np.mean(score) for _, score in cv_results])
    # standard error of R2 across the folds for each model
    sem_r2 = np.array([sem(score) for _, score in cv_results])

    # maximal mean R2 minus one standard error of that best model
    max_r2_1se = mean_r2.max() - sem_r2[mean_r2.argmax()]

    # parameters of all models whose mean R2 is at least that threshold
    opt_params = ordered_params[mean_r2 >= max_r2_1se]

    # Among those candidates keep the model with the fewest non-zero
    # coefficients; on ties the last candidate in grid order wins
    # (largest l1_ratio, then largest alpha).
    print("Weight = %s, finding optimal coefficients" % (i + 1))
    smallest_num_coefs = np.inf
    for l1_ratio, alpha in opt_params:
        fit = ElasticNet(l1_ratio=l1_ratio, alpha=alpha, random_state=1234).fit(X_for_regr, y)
        num_non_zero_coefs = len(fit.coef_[fit.coef_ != 0])
        print("\tl1_ratio = %s, alpha = %s, num. non-zero coefs = %s" % (l1_ratio, alpha, num_non_zero_coefs))
        if num_non_zero_coefs <= smallest_num_coefs:
            smallest_num_coefs = num_non_zero_coefs
            best_fit = fit
            best_l1_ratio = l1_ratio
            best_alpha = alpha

    # In-sample R2 of the selected model.  best_fit was already fitted on
    # the full data with the selected penalties, so no refit is needed.
    r2 = r2_score(y, best_fit.predict(X_for_regr))

    if not do_bootstrap:
        res_dict[weight] = np.r_[r2, best_fit.coef_]
    elif weight in args.bootstrap_weights:
        # Bootstrap rows with replacement and refit with the selected
        # penalties to estimate standard errors of R2 and the betas.
        # NOTE(review): np.random is not seeded here, so these standard
        # errors vary between runs unlike the rest of the script.
        coefs, r2s = [], []
        for j in range(args.nbootstrap):
            print("\tfitting bootstrap models for standard error estimates, perm %s" % (j))
            perm = np.random.choice(len(y), len(y))
            X_boot, y_boot = X_for_regr[perm, :], y[perm]
            boot_fit = ElasticNet(l1_ratio=best_l1_ratio, alpha=best_alpha, random_state=1234).fit(X_boot, y_boot)
            coefs.append(boot_fit.coef_)
            r2s.append(r2_score(y_boot, boot_fit.predict(X_boot)))

        coefs_se = np.array(coefs).std(axis=0)
        r2_se = np.array(r2s).std()

        res_dict[weight] = np.r_[r2, r2_se, best_fit.coef_, coefs_se]
    else:
        # Bootstrapping requested, but not for this weight: pad the
        # standard-error slots with NaN so every row has the same width.
        res_dict[weight] = np.r_[r2, np.nan, best_fit.coef_, [np.nan] * n_predictors]

# res_dict is keyed by weight; transposing puts one weight per row.
res_df = pd.DataFrame(res_dict).T

# Covariate names follow the predictor names, matching the order in which
# the covariate columns were appended to the design matrix.
if args.covariates:
    predictors = args.names + args.covariate_names
else:
    predictors = args.names

# Column labels mirror the layout of each result row:
# R2, [R2_se,] betas, [beta standard errors].
columns = ['R2']
if args.bootstrap_weights:
    columns.append('R2_se')
columns.extend([name + "_beta" for name in predictors])
if args.bootstrap_weights:
    columns.extend([name + "_beta_se" for name in predictors])

res_df.columns = columns
res_df.to_csv(args.out, sep="\t", index=True, header=True)