import numpy as np
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import os
import pdb
import time
import keras.backend as K
from tqdm import tqdm
import keras_metrics
from keras_metrics import categorical_precision, categorical_recall
from keras.layers import Input
from keras.preprocessing import sequence
from keras.layers import LSTM, Dense, Masking, Concatenate, concatenate, Bidirectional, BatchNormalization, Dropout
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from keras.models import Sequential, Model, load_model
from keras.utils import multi_gpu_model
from sklearn.metrics import precision_score, recall_score, accuracy_score, confusion_matrix


# Pin the CUDA device enumeration order and expose no CUDA devices to this
# process (empty CUDA_VISIBLE_DEVICES), so inference runs on the CPU.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "",
})

# Metric objects handed to load_model as custom objects; presumably the saved
# model was compiled with these keras_metrics metrics under the same names.
dependencies = {
    "categorical_precision": categorical_precision(),
    "categorical_recall": categorical_recall(),
}

# Trained polisher network used below to re-label the caller's CNV predictions.
xhmm_polisher = load_model(
    "../../../DECoNT_batchnorm_bilstm128_batchnorm_dense100_dense3_bs256_padding-1_30epochs_traintestsplitted_weightedcrossent_codex2.h5",
    custom_objects=dependencies,
)

data_path = "./MGISEQ2000_NA12878_processed_data_labeled_codex2/"


def _parse_record(line):
    """Parse one stripped, comma-separated record line.

    Field 0 is ignored. Fields 1 and 2 are integers, field 3 is the caller's
    prediction and field 4 the reference prediction; each has its first
    character dropped (presumably the space that follows the comma -- verify
    against the data files). Every remaining field is one read-depth value,
    possibly wrapped in list brackets; an empty value is read as 0.

    Returns:
        ((start, end), caller_label, reference_label, read_depths)
    """
    fields = line.split(',')
    interval = (int(fields[1][1:]), int(fields[2][1:]))
    caller_label = fields[3][1:]
    reference_label = fields[4][1:].replace(']', '')

    read_depths = []
    for token in fields[5:]:
        token = token.replace('[', '', 1).replace(']', '').replace(' ', '')
        read_depths.append(int(token) if token else 0)

    return interval, caller_label, reference_label, read_depths


# Per-sample accumulators, filled in lock-step across all input files.
data_list_readdepths = []
data_list_indexes = []
data_list_cnvnator_preds = []
data_list_xhmm_preds = []

files_list = os.listdir(data_path)

for filename in files_list:
    with open(os.path.join(data_path, filename)) as f:
        for raw_line in f:
            record = raw_line.strip()
            if not record:
                # A blank (e.g. trailing) line carries no record; indexing
                # its fields would raise IndexError.
                continue

            interval, caller_label, reference_label, read_depths = _parse_record(record)
            data_list_indexes.append(interval)
            data_list_xhmm_preds.append(caller_label)
            data_list_cnvnator_preds.append(reference_label)
            data_list_readdepths.append(read_depths)

# Convert the per-sample metadata lists to numpy arrays.
# The read-depth sequences are deliberately NOT converted here: they have
# different lengths, and building an array from ragged lists raises
# ValueError on NumPy >= 1.24. pad_sequences accepts the list of lists as is.
data_list_indexes = np.asarray(data_list_indexes)
data_list_cnvnator_preds = np.asarray(data_list_cnvnator_preds)
data_list_xhmm_preds = np.asarray(data_list_xhmm_preds)

# Bring every sequence to exactly 3999 steps, filling with -1.
# With the Keras defaults, padding and truncation both happen at the front.
data_list_readdepths = sequence.pad_sequences(data_list_readdepths, value=-1, maxlen=3999)




# Raw label text -> integer class id. The keys are the exact strings found in
# the parsed files, including the embedded single quotes.
#   reference (cnvnator) labels: nan -> 0, '<DUP>' -> 1, '<DEL>' -> 2
#   caller (xhmm) labels:        'del' -> 0, 'dup' -> 1
_CNVNATOR_LABEL_IDS = {'nan': 0, "'<DUP>'": 1, "'<DEL>'": 2}
_XHMM_LABEL_IDS = {"'del'": 0, "'dup'": 1}


def _encode_labels(labels, label_ids):
    """Map raw label strings to an int64 array of class ids.

    Args:
        labels: iterable of label strings.
        label_ids: dict from raw label string to integer class id.

    Raises:
        ValueError: if a label is not a key of ``label_ids``.
    """
    raw = [str(label) for label in labels]
    unknown = sorted(set(raw) - set(label_ids))
    if unknown:
        raise ValueError("Unrecognized label(s): {}".format(unknown))
    return np.array([label_ids[label] for label in raw], dtype='int64')


# One-hot encode both label sets from real integer ids, rather than writing
# ints into a string array and relying on an implicit str -> int cast.
data_list_xhmm_preds = to_categorical(
    _encode_labels(data_list_xhmm_preds, _XHMM_LABEL_IDS), num_classes=2)
data_list_cnvnator_preds = to_categorical(
    _encode_labels(data_list_cnvnator_preds, _CNVNATOR_LABEL_IDS), num_classes=3)

# Add a trailing axis so each time step carries a single feature value.
data_list_readdepths = np.expand_dims(data_list_readdepths, axis=2)

# NOTE: read depths are passed on un-normalized; an earlier "/ 45000" scaling
# was left disabled in the original script.



def _call_precisions(cm):
    """Return (duplication, deletion, overall) call precision.

    ``cm`` is a 3x3 confusion matrix with true classes on the rows and
    predicted classes on the columns, in class-id order 0, 1 (duplication),
    2 (deletion). Precision for a class is its diagonal count divided by its
    column total. The overall figure pools the duplication and deletion
    columns. A column with no predictions yields NaN.
    """
    nan = float('nan')
    dup_calls = cm[:, 1].sum()
    del_calls = cm[:, 2].sum()
    all_calls = dup_calls + del_calls

    dup_precision = cm[1][1] / dup_calls if dup_calls else nan
    del_precision = cm[2][2] / del_calls if del_calls else nan
    overall_precision = (cm[2][2] + cm[1][1]) / all_calls if all_calls else nan
    return dup_precision, del_precision, overall_precision


# Run the polisher on the caller's one-hot predictions plus the read depths.
polished = xhmm_polisher.predict([data_list_xhmm_preds, data_list_readdepths], verbose=1)

# The caller encodes deletion as 0; move it to 2 so that its labels share the
# 3-class id space of the reference labels (1 = duplication, 2 = deletion).
notpolished = np.argmax(data_list_xhmm_preds, axis=1)
notpolished[notpolished == 0] = 2
pollabels = np.argmax(polished, axis=1)
truth = np.argmax(data_list_cnvnator_preds, axis=1)

# Pin the label set so both matrices are always 3x3. Without it, a class
# absent from both truth and predictions shrinks the matrix and the fixed
# indices used for the precision figures no longer line up.
class_ids = [0, 1, 2]

print("Unpolished Confusion Matrix: ")
unpolished_cm = confusion_matrix(truth, notpolished, labels=class_ids)
print(unpolished_cm)

print("Polished Confusion Matrix: ")
polished_cm = confusion_matrix(truth, pollabels, labels=class_ids)
print(polished_cm)

(unpolished_dupprecision,
 unpolished_delprecision,
 unpolished_overallprecision) = _call_precisions(unpolished_cm)

(polished_dupprecision,
 polished_delprecision,
 polished_overallprecision) = _call_precisions(polished_cm)

print("Duplication call precision (Before polishing - After Polishing): ", unpolished_dupprecision, " - ", polished_dupprecision)
print("Deletion call precision (Before polishing - After Polishing): ", unpolished_delprecision, " - ", polished_delprecision)
print("Overall call precision (Before polishing - After Polishing): ", unpolished_overallprecision, " - ", polished_overallprecision)




# Drop into the interactive debugger so the arrays and metrics computed by
# this script can be inspected before it exits.

pdb.set_trace()