import numpy as np
import pandas as pd
import os
import pdb
import time
import keras.backend as K
from tqdm import tqdm
from keras.layers import Input
from keras.preprocessing import sequence
from keras.layers import LSTM, Dense, Masking, Concatenate, concatenate, BatchNormalization, Bidirectional
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from keras.models import Sequential, Model
from keras import metrics
from tensorflow.keras.models import Sequential, Model, load_model

from sklearn.model_selection import train_test_split

# Load the labeled read-depth data whose Control-FREEC calls will be polished.
# Each line is parsed positionally:
#   field 1, 2 -> integer index pair (presumably region start/end — confirm)
#   field 3    -> Control-FREEC call (kept as a string here)
#   field 4    -> CaNaVaR call (kept as a string; may carry a trailing ']')
#   field 5+   -> read-depth values
# Field 0 is not used.

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Hide all GPUs so inference runs on the CPU.
# NOTE(review): this is set after keras/tensorflow were imported above; it only
# takes effect if TensorFlow has not initialized its devices yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

start = time.time()
data_path = "./MGISEQ2000_NA12878_processed_data_labeled_freec/"

data_list_readdepths = []
data_list_indexes = []
data_list_canavar_preds = []
data_list_freec_preds = []

# Sort for a deterministic event order (os.listdir order is arbitrary) and skip
# entries that are not regular files (e.g. subdirectories), which open() rejects.
files_list = sorted(
    name for name in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, name))
)

for filename in tqdm(files_list):
    with open(os.path.join(data_path, filename)) as f:
        # Blank lines (e.g. a trailing empty line) have no fields and would
        # raise IndexError below, so drop them.
        lines = [line.strip() for line in f]
        lines = [line for line in lines if line]

    for line in lines:
        # Split once per line instead of once per extracted field.
        fields = line.split(',')
        # Fields after the first start with one separator character (presumably
        # a space), which the [1:] slices drop.
        data_list_indexes.append((int(fields[1][1:]), int(fields[2][1:])))
        data_list_freec_preds.append(fields[3][1:])
        data_list_canavar_preds.append(fields[4][1:])
        # Read-depth tokens still carry list brackets and spaces; strip them.
        # An empty token (e.g. from an empty list) becomes 0.
        tokens = [
            tok.replace('[', '', 1).replace(']', '').replace(' ', '')
            for tok in fields[5:]
        ]
        data_list_readdepths.append([int(tok) if tok else 0 for tok in tokens])

end = time.time()
print("Loading of the data took ", end-start," seconds.")

# Remove the ']' that the positional line parser leaves on the CaNaVaR field.
data_list_canavar_preds = [x.replace(']', '') for x in data_list_canavar_preds]

# Convert the per-event label lists to NumPy arrays. The calls are still strings
# here; they are cast to numbers where they are used.
data_list_indexes = np.asarray(data_list_indexes)
data_list_canavar_preds = np.asarray(data_list_canavar_preds)
data_list_freec_preds = np.asarray(data_list_freec_preds)

max_length = 192000  # maximum length of read depth signals
bin_size = 100  # consecutive read-depth values averaged into one model timestep

# The read-depth signals have different lengths, so they are handed to
# pad_sequences as a plain list of lists. Calling np.asarray on such ragged lists
# first raises ValueError on NumPy >= 1.24, and pad_sequences does not need it.
# Keras defaults pad/truncate at the start; -1 marks padding (the model file name
# suggests it was trained with -1 padding — confirm).
data_list_readdepths = sequence.pad_sequences(data_list_readdepths, maxlen=max_length, value=-1)

# Average each run of bin_size values: (N, max_length) -> (N, max_length // bin_size).
# The shape is spelled out explicitly so that N == 0 still reshapes cleanly.
data_list_readdepths = data_list_readdepths.reshape(
    len(data_list_readdepths), max_length // bin_size, bin_size
).mean(axis=2)

# Add a trailing feature axis for the recurrent model input: (N, steps, 1).
data_list_readdepths = np.expand_dims(data_list_readdepths, axis=2)

#model
max_length = 192000  # maximum length of read depth signals
# Model timesteps after 100-value binning; integer division so it is usable as a shape.
inpsize = max_length // 100

# Pretrained DECoNT polisher for Control-FREEC calls (path relative to the working directory).
freec_polisher = load_model("../../DECoNTfreec_batchnorm_bilstm128_batchnorm_dense100_dense1_bs256_padding-1_60epochs_traintestsplitted_mae.h5")

# The Control-FREEC calls were parsed as strings. Cast them to float explicitly
# instead of relying on implicit string->float conversion of the model input,
# which TF2 Keras does not perform.
freec_calls = data_list_freec_preds.astype(np.float32)
polished_preds = freec_polisher.predict([freec_calls, data_list_readdepths], verbose=1)

n_events = polished_preds.shape[0]
print("Number of CNV events: ", n_events)

# Flatten the (N, 1) model output and round the polished calls up to integers.
polished_preds = np.ceil(polished_preds.reshape(-1,))

# Absolute error of each caller against the CaNaVaR calls, which serve as the reference.
canavar_calls = data_list_canavar_preds.astype(np.float64)
ae_freec = np.abs(canavar_calls - data_list_freec_preds.astype(np.float64))
ae_decont = np.abs(canavar_calls - polished_preds.astype(np.float64))

# Positive where DECoNT is closer to the reference than Control-FREEC, negative
# where it is farther. Only the signs are counted, so no sorting is needed.
dif = ae_freec - ae_decont
print("Percentage of improved events: ", 100 * np.sum(dif > 0) / n_events)
print("Percentage of deteriorated events: ", 100 * np.sum(dif < 0) / n_events)
print("Mean Absolute Error (MAE) difference decreased by: ", np.mean(ae_freec) - np.mean(ae_decont))