import os
import numpy as np
import random
import time
import glob
from Bio import SeqIO
from pybedtools import BedTool
from keras.models import load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adadelta
from sklearn import metrics
import h5py

# NOTE(review): not referenced anywhere in this file; the name suggests the
# expected input sequence length -- confirm against the data-preparation code.
INPUT_LENGTH = 1000
# Maximum number of epochs for each of the two training phases in run_model
# (EarlyStopping with patience=10 may end a phase earlier).
EPOCH = 200
# NOTE(review): not referenced in this file -- run_model passes batch_size=500
# to model.fit instead. Confirm which value is intended.
BATCH_SIZE = 200
# Base directory from which train_model resolves "src/model.hdf5".
WORK_DIR = "./"

def _threshold_at_fpr(fprs, thrs, target_fpr):
    """Return the ROC threshold whose false-positive rate is closest to ``target_fpr``.

    ``fprs`` and ``thrs`` are the parallel arrays returned by
    ``sklearn.metrics.roc_curve``. ``np.argmin`` returns the first index on
    ties, so the selection is deterministic (``np.argsort(...)[0]`` with the
    default, non-stable sort is not guaranteed to be).
    """
    return thrs[np.argmin(np.abs(fprs - target_fpr))]


def _fpr_thresholds(y_true, y_score, target_fprs=(0.1, 0.05, 0.03, 0.01)):
    """Return, for one output column, the ROC threshold closest to each target FPR."""
    fprs, _tprs, thrs = metrics.roc_curve(y_true, y_score)
    return [_threshold_at_fpr(fprs, thrs, target) for target in target_fprs]


def run_model(data, model, save_dir, batch_size=500, epochs=None):
    """Train ``model`` in two optimizer phases, evaluate on the test split, write metrics.

    Phase 1 trains with Adam. Phase 2 recompiles with Adagrad, reloads the
    checkpointed weights and trains again. The checkpointed weights are then
    used to predict the test split. Per-column ROC AUCs go to ``auc.txt`` and
    the score thresholds at 10/5/3/1 % false-positive rate go to
    ``fpr_threshold_scores.txt``.

    Parameters
    ----------
    data : dict
        Must provide ``train_data``/``train_labels``, ``val_data``/``val_labels``
        and ``test_data``/``test_labels``. Labels and predictions are indexed
        as 2-D arrays: column 0 is reported as "enhancer", column 1 as "silencer".
    model : keras Model
        Model that is compiled and trained in place.
    save_dir : str
        Directory that receives ``model_weights.hdf5``, ``auc.txt`` and
        ``fpr_threshold_scores.txt``.
    batch_size : int, optional
        Mini-batch size for both training phases (default 500, the value
        previously hard-coded).
    epochs : int, optional
        Maximum epochs per phase; defaults to the module-level ``EPOCH``.
    """
    if epochs is None:
        epochs = EPOCH

    weights_file = os.path.join(save_dir, "model_weights.hdf5")
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # summary() prints the table itself and returns None; wrapping it in
    # print() only emitted a stray "None" line.
    model.summary()

    X_train, Y_train = data["train_data"], data["train_labels"]
    validation = (data["val_data"], data["val_labels"])
    X_test, Y_test = data["test_data"], data["test_labels"]

    # The same callback instances are reused in both phases. NOTE(review):
    # this presumably lets ModelCheckpoint keep its best val_loss across fit()
    # calls, so phase 2 only overwrites a better checkpoint -- confirm for the
    # installed Keras version.
    callbacks = [
        ModelCheckpoint(filepath=weights_file, verbose=1, save_best_only=True),
        EarlyStopping(monitor="val_loss", patience=10, verbose=1),
    ]

    def _fit():
        # One training phase with the currently compiled optimizer.
        model.fit(X_train, Y_train,
                  batch_size=batch_size,
                  epochs=epochs,
                  validation_data=validation,
                  shuffle=True,
                  callbacks=callbacks, verbose=1)

    _fit()

    # Phase 2: switch to Adagrad and resume from the checkpointed weights.
    model.compile(loss='categorical_crossentropy', optimizer='adagrad', metrics=['accuracy'])
    model.load_weights(weights_file)
    _fit()

    # Evaluate with the checkpointed (best val_loss) weights.
    model.load_weights(weights_file)
    Y_pred = model_predict(X_test, model)

    auc_enhancer = metrics.roc_auc_score(Y_test[:, 0], Y_pred[:, 0])
    auc_silencer = metrics.roc_auc_score(Y_test[:, 1], Y_pred[:, 1])
    with open(os.path.join(save_dir, "auc.txt"), "w") as of:
        of.write("enhancer AUC: %f\n" % auc_enhancer)
        of.write("silencer AUC: %f\n" % auc_silencer)

    # Thresholds at FPR 10 %, 5 %, 3 %, 1 %, in the order written below.
    sil = _fpr_thresholds(Y_test[:, 1], Y_pred[:, 1])
    enh = _fpr_thresholds(Y_test[:, 0], Y_pred[:, 0])
    with open(os.path.join(save_dir, "fpr_threshold_scores.txt"), "w") as of:
        of.write("silencer 10 \t %f\n" % sil[0])
        of.write("5 \t %f\n" % sil[1])
        of.write("3 \t %f\n" % sil[2])
        of.write("1 \t %f\n\n" % sil[3])
        of.write("enhancer 10 \t %f\n" % enh[0])
        of.write("5 \t %f\n" % enh[1])
        of.write("3 \t %f\n" % enh[2])
        of.write("1 \t %f\n" % enh[3])

def model_predict(data, model, batch_size=1000):
    """Run ``model.predict`` on ``data`` and return the predictions.

    Parameters
    ----------
    data : array-like
        Inputs passed unchanged to ``model.predict``.
    model : object
        Anything exposing a Keras-style ``predict(x, batch_size=..., verbose=...)``.
    batch_size : int, optional
        Prediction batch size (default 1000, the value previously hard-coded).
    """
    print("prediction on test samples ...")
    return model.predict(data, batch_size=batch_size, verbose=1)

def load_dataset(datafile):
    """Read every entry of the HDF5 file's root group fully into memory.

    Returns a dict mapping each root-level key to ``handle[key][()]``.
    """
    print("reading samples...")
    with h5py.File(datafile, "r") as handle:
        contents = {name: handle[name][()] for name in handle}
    return contents


def train_model(data, results_dir):
    """Load the model at ``WORK_DIR/src/model.hdf5`` and train it on an HDF5 dataset.

    Parameters
    ----------
    data : str
        Path to the HDF5 dataset file read by ``load_dataset``.
    results_dir : str
        Output directory for ``run_model``; created (with parents) if missing.

    Raises
    ------
    SystemExit
        With a non-zero status and an error message when ``data`` is not an
        existing file.
    """
    import sys

    # Check the cheap precondition before the comparatively slow model load.
    # sys.exit(msg) writes msg to stderr and exits with status 1; the previous
    # bare exit() reported success (status 0) on this error.
    if not os.path.isfile(data):
        sys.exit("no data file " + data)

    model = load_model(os.path.join(WORK_DIR, "src", "model.hdf5"))

    # exist_ok avoids a check-then-create race; makedirs also creates parents.
    os.makedirs(results_dir, exist_ok=True)

    dataset = load_dataset(data)
    run_model(dataset, model, results_dir)


if __name__ == "__main__":

    import sys

    # Usage: <script> DATA_FILE RESULTS_DIR (extra arguments are ignored).
    # Fail with a usage line instead of an IndexError traceback.
    if len(sys.argv) < 3:
        sys.exit("usage: %s DATA_FILE RESULTS_DIR" % sys.argv[0])
    data_file, results_dir = sys.argv[1], sys.argv[2]
    train_model(data_file, results_dir)
