#!/usr/bin/env python
import pandas as pd
import numpy as np
from sklearn import cross_validation
from sklearn.linear_model import SGDClassifier
from sklearn.cross_validation import KFold
from sklearn.metrics import log_loss,roc_auc_score,f1_score,accuracy_score
from sklearn.preprocessing import scale
from scipy.stats import sem
from sys import argv

# Command-line arguments (all file paths):
#   argv[1]  tab-separated feature table; first column is the row index
#   argv[2]  tab-separated response table; first column is the row index
#   argv[3]  output path for the fitted coefficients
#   argv[4]  output path for the accuracy / F1 / AUC scores
#   argv[5]  output path for the selected l1_ratio / alpha
if len(argv) < 6:
    raise SystemExit("usage: %s FEATURES RESPONSE OUT_COEF OUT_SCORE OUT_FIT_PARAMS"
                     % argv[0])

df_path, y_path, out_coef, out_score, out_fit_params = argv[1:6]

# Example arguments:
# df_path = "initial_occupancy_and_GR_motif_score.non_promoter.EP300_and_DNase.txt"
# y_path = "binary_GR.in.EP300_and_DNase.txt"
# out_coef = "predict_binary_GR.in.EP300_and_DNase.by_initial_occupancy.coefs.txt"
# out_score = "predict_binary_GR.in.EP300_and_DNase.by_initial_occupancy.scores.txt"
# out_fit_params = "predict_binary_GR.in.EP300_and_DNase.by_initial_occupancy.fit_params.txt"

df = pd.read_csv(df_path, sep="\t", index_col=0)
y = pd.read_csv(y_path, sep="\t", index_col=0)

# Keep only row labels present in BOTH tables.  Taking the union instead
# would reindex labels missing from either table to all-NaN rows, which the
# scaling / classifier steps cannot use.
idx = sorted(set(df.index) & set(y.index))
if not idx:
    raise ValueError("feature table %r and response table %r share no row labels"
                     % (df_path, y_path))

df = df.loc[idx]
# 2-D array (rows x response columns); later steps flatten it where a 1-D
# label vector is required.
y = np.array(y.loc[idx])

# Standardize each feature column (zero mean, unit variance).
X = scale(df)

#####
# run model
#####

# 5-fold CV with a fixed shuffle seed, so every (l1_ratio, alpha) pair is
# scored on exactly the same train/test splits.
cv = KFold(len(y), n_folds=5, shuffle=True, random_state=1234)

# Elastic-net grid: mostly-L1 mixing ratios, alpha = 1e-6 ... 1e0.
# The following penalty terms should be sufficient.
l1_ratios = [0.75, .9, .95, .99, 1]
alphas = 10 ** np.arange(-6, 1).astype('float')

# sklearn expects a 1-D label vector; y may be stored as a column array.
y_flat = np.ravel(y)

# Score every grid point.  Each entry of score_dict holds the per-fold
# negative log-loss (higher is better) for one (l1_ratio, alpha) pair.
score_dict = {}
for l1_ratio in l1_ratios:
    for alpha in alphas:
        # set up the classifier
        regr = SGDClassifier(l1_ratio=l1_ratio, alpha=alpha,
                             loss="log", penalty="elasticnet",
                             random_state=1234)
        score = cross_validation.cross_val_score(regr, X, y_flat,
                                                 cv=cv,
                                                 scoring="neg_log_loss")
        print("\ttesting l1_ratio = %s, alpha = %s, neg. log-loss = %s"
              % (l1_ratio, alpha, score.mean()))
        score_dict[(l1_ratio, alpha)] = score

# Model selection by the "one-standard-error rule": among all grid points
# whose mean CV negative log-loss is within one standard error of the best
# grid point, keep the model with the fewest non-zero coefficients.
#
# Visit grid points in sorted (l1_ratio, alpha) order so the result does not
# depend on dict iteration order.
param_grid = sorted(score_dict)
fold_scores = [score_dict[params] for params in param_grid]
ordered_params = np.array(param_grid)

# Mean and standard error of the negative log-loss across folds, per model.
mean_score = np.array([np.mean(s) for s in fold_scores])
sem_score = np.array([sem(s) for s in fold_scores])

# Threshold = best mean negative log-loss minus that model's standard error.
best_idx = mean_score.argmax()
max_score_1se = mean_score[best_idx] - sem_score[best_idx]

# Candidate parameter pairs whose mean score clears the threshold.
opt_params = ordered_params[mean_score >= max_score_1se]

# Refit each candidate on all data and keep the sparsest one.
y_flat = np.ravel(y)
smallest_num_coefs = np.inf
for l1_ratio, alpha in opt_params:
    fit = SGDClassifier(l1_ratio=l1_ratio, alpha=alpha,
                        loss="log", penalty="elasticnet",
                        random_state=1234).fit(X, y_flat)
    num_non_zero_coefs = np.count_nonzero(fit.coef_)
    print("\tl1_ratio = %s, alpha = %s, num. non-zero coefs = %s"
          % (l1_ratio, alpha, num_non_zero_coefs))
    # '<=' means that on ties the candidate sorting last (largest l1_ratio,
    # then largest alpha) wins.
    if num_non_zero_coefs <= smallest_num_coefs:
        smallest_num_coefs = num_non_zero_coefs
        best_fit = fit
        best_l1_ratio = l1_ratio
        best_alpha = alpha

# The selected model was already fit on the full data set during candidate
# selection, with identical parameters and a fixed random_state, so reuse it
# rather than refitting.
fit = best_fit

y_true = np.ravel(y)
pred_y = fit.predict(X)

# NOTE: these metrics are computed on the same data the model was trained on
# (in-sample), so they overstate out-of-sample performance.
accuracy = accuracy_score(y_true, pred_y)
f1 = f1_score(y_true, pred_y)
pred_probs = fit.predict_proba(X)
# predict_proba columns follow fit.classes_ (sorted), so column 1 is the
# probability of the larger label.
AUC = roc_auc_score(y_true, pred_probs[:, 1])

# One coefficient per feature column.
coef_df = pd.DataFrame(data={'coef': fit.coef_[0]},
                       index=list(df.columns))
coef_df.to_csv(out_coef, sep="\t", index=True)

score_df = pd.DataFrame(data={'score': [accuracy, f1, AUC]},
                        index=['accuracy', 'F1 score', 'AUC'])
score_df.to_csv(out_score, sep="\t", index=True)

with open(out_fit_params, "w") as f:
    f.write("best_l1_ratio\tbest_alpha\n")
    f.write("%s\t%s\n" % (best_l1_ratio, best_alpha))
