import numpy as np
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import os
import pdb
import time
import keras.backend as K
from tqdm import tqdm
import keras_metrics
from keras_metrics import categorical_precision, categorical_recall
from keras.layers import Input
from keras.preprocessing import sequence
from keras.layers import LSTM, Dense, Masking, Concatenate, concatenate, Bidirectional, BatchNormalization, Dropout
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from keras.models import Sequential, Model, load_model
from sklearn.metrics import precision_score, recall_score, accuracy_score, confusion_matrix
from keras.utils import multi_gpu_model

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# An empty device list hides every GPU from CUDA, so inference runs on the CPU.
# NOTE(review): keras is imported above this line; confirm the backend has not
# already initialised CUDA before these variables are read.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Custom metric objects passed to load_model so the saved checkpoints that
# reference them can be deserialized.
dependencies = {
    "categorical_precision": categorical_precision(),
    "categorical_recall": categorical_recall(),
}

# Which caller's trained polisher to evaluate.
polishing_model = "conifer"

# polishing_model -> (checkpoint path, read-depth sequence length used for padding).
_POLISHER_CONFIGS = {
    'xhmm': ("../DECoNT_batchnorm_bilstm128_batchnorm_dense100_dense3_bs256_padding-1_30epochs_traintestsplitted_weightedcrossent_xhmm.h5", 8501),
    'conifer': ("../DECoNT_batchnorm_bilstm128_batchnorm_dense100_dense3_bs256_padding-1_30epochs_traintestsplitted_weightedcrossent_conifer.h5", 29456),
    'codex2': ("../DECoNT_batchnorm_bilstm128_batchnorm_dense100_dense3_bs256_padding-1_30epochs_traintestsplitted_weightedcrossent_codex2.h5", 3999),
}

try:
    _polisher_path, maxlength = _POLISHER_CONFIGS[polishing_model]
except KeyError:
    # Fail here with a clear message instead of a NameError further down
    # when xhmm_polisher / maxlength would otherwise be undefined.
    raise ValueError(
        "Unknown polishing_model {!r}; expected one of {}".format(
            polishing_model, sorted(_POLISHER_CONFIGS)
        )
    ) from None

# The historical name `xhmm_polisher` is kept because later code uses it.
# It holds whichever caller's polisher was selected above.
xhmm_polisher = load_model(_polisher_path, custom_objects=dependencies)


# Load the saved test-fold arrays: read-depth sequences, CNVnator labels and
# the caller's predictions.
_test_fold_paths = (
    './xhmm_test_fold/data_list_readdepths_test_weightedcrossent.npy',
    './xhmm_test_fold/data_list_cnvnator_preds_test_weightedcrossent.npy',
    './xhmm_test_fold/data_list_xhmm_preds_test_weightedcrossent.npy',
)
_raw_readdepths, data_list_cnvnator_preds, data_list_xhmm_preds = (
    np.load(_path) for _path in _test_fold_paths
)
data_list_indexes = []  # NOTE(review): not used anywhere in this script

# Bring every read-depth sequence to the fixed length the selected model was
# trained with, filling with -1.
data_list_readdepths = sequence.pad_sequences(
    _raw_readdepths,
    value=-1,
    maxlen=maxlength,
)

for _caption, _matrix in (
    ("Read depths data matrix shape: ", data_list_readdepths),
    ("Xhmm predictions data matrix shape: ", data_list_xhmm_preds),
    ("Cnvnator predictions (labels) data matrix shape: ", data_list_cnvnator_preds),
):
    print(_caption, _matrix.shape)

# Run the polisher on (caller predictions, padded read depths).
polished = xhmm_polisher.predict([data_list_xhmm_preds, data_list_readdepths], verbose=1)

# Unpolished baseline: the caller's own per-sample class, with class 0 remapped
# to class 2 before comparison.
# NOTE(review): the 0 -> 2 remap implies the caller's class order differs from
# the label order; confirm against the data-prep code.
notpolished = data_list_xhmm_preds.argmax(axis=1)
notpolished = np.where(notpolished == 0, 2, notpolished)

# Polished labels and ground truth, both as class indices.
pollabels = polished.argmax(axis=1)
truth = data_list_cnvnator_preds.argmax(axis=1)

# Fix the class set so both matrices are always 3x3, even when a class never
# appears in truth or predictions (otherwise the [2][2] indexing breaks).
_CNV_LABELS = [0, 1, 2]


def _call_precisions(cm):
    """Return (duplication, deletion, overall) call precision from a 3x3 confusion matrix.

    Rows are true labels and columns are predicted labels (sklearn's
    convention). Class 1 is scored as duplication and class 2 as deletion.
    "Overall" pools the predictions in both classes. A precision whose
    predicted column(s) are empty is returned as nan.
    """
    cm = np.asarray(cm)
    predicted_totals = cm.sum(axis=0)

    def _ratio(numerator, denominator):
        # Explicit nan instead of a (globally silenced) divide-by-zero warning.
        return float(numerator) / denominator if denominator else float('nan')

    dup = _ratio(cm[1, 1], predicted_totals[1])
    deletion = _ratio(cm[2, 2], predicted_totals[2])
    overall = _ratio(cm[1, 1] + cm[2, 2], predicted_totals[1] + predicted_totals[2])
    return dup, deletion, overall


print("Unpolished Confusion Matrix: ")
unpolished_cm = confusion_matrix(truth, notpolished, labels=_CNV_LABELS)
print(unpolished_cm)

print("Polished Confusion Matrix: ")
polished_cm = confusion_matrix(truth, pollabels, labels=_CNV_LABELS)
print(polished_cm)

unpolished_dupprecision, unpolished_delprecision, unpolished_overallprecision = _call_precisions(unpolished_cm)
polished_dupprecision, polished_delprecision, polished_overallprecision = _call_precisions(polished_cm)

print("Duplication call precision (Before polishing - After Polishing): ", unpolished_dupprecision, " - ", polished_dupprecision)
print("Deletion call precision (Before polishing - After Polishing): ", unpolished_delprecision, " - ", polished_delprecision)
print("Overall call precision (Before polishing - After Polishing): ", unpolished_overallprecision, " - ", polished_overallprecision)