import csv
import pdb
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

# Per-call model outputs for the test set.
# NOTE(review): the array layouts are not visible here. The code below
# presumably expects one prediction vector per call in the *calls/polished
# arrays, one 'chrN' label per call in `chrs`, and (start, end)-like rows in
# `inds`. Confirm against the script that wrote these files.
xhmmcalls = np.load('data_list_xhmm_preds_test_weightedcrossent_embeddingout.npy')
cnvnatorcalls = np.load('data_list_cnvnator_preds_test_weightedcrossent_embeddingout.npy')
chrs = np.load('data_list_chrs_test_weightedcrossent_embeddingout.npy')
inds = np.load('data_list_indexes_test_weightedcrossent_embeddingout.npy')
polished = np.load('polished_rev12.npy')
# Region length of each call: second entry minus first entry of its `inds` row.
lens = np.asarray([region[1] - region[0] for region in inds])

# Repeat-content table, read as strings with the header row skipped.
# Column 0 holds a 'chrN' label; the last column holds the value used as the
# repeat percentage.
repeats = np.loadtxt('duprepcontent.csv', skiprows=1, delimiter=',', dtype=str)
# NOTE(review): the last three rows are dropped, presumably summary/footer
# rows that are not per-chromosome entries. Confirm against the CSV.
repeats = repeats[:-3]
# Strip the 3-character 'chr' prefix. This assumes numeric chromosome names
# (int() fails on e.g. 'chrX').
rp_chrs = [int(label[3:]) for label in repeats[:, 0]]
# Use the built-in float: `np.float` was only an alias for it, was deprecated
# in NumPy 1.20 and removed in NumPy 1.24 (AttributeError on current NumPy).
rp_percs = [float(value) for value in repeats[:, -1]]

# Accumulator for the per-call agreement labels computed afterwards.
corrected = []

def _label_call(xhmm_pred, cnvnator_pred, polished_pred):
    """Return the agreement label for a single call.

    Each prediction is reduced to its argmax class. The call is
    'DECoNT corrects' when XHMM disagrees with CNVnator but the polished
    prediction matches CNVnator; every other case is 'DECoNT agrees'.
    """
    target = np.argmax(cnvnator_pred)
    if np.argmax(xhmm_pred) != target and np.argmax(polished_pred) == target:
        return 'DECoNT corrects'
    # NOTE(review): this branch also covers calls where neither XHMM nor the
    # polished prediction matches CNVnator. They are labelled 'DECoNT agrees'
    # exactly as before. Confirm whether a separate label was intended; the
    # ratio computed in this script counts only 'DECoNT corrects'.
    return 'DECoNT agrees'


# One label per XHMM call. The other arrays are indexed by the same position
# (the loop runs over len(xhmmcalls)).
corrected = np.asarray([
    _label_call(xhmm_pred, cnvnatorcalls[i], polished[i])
    for i, xhmm_pred in enumerate(xhmmcalls)
])

# Reorder calls by chromosome label. Python's sort is stable, so calls on the
# same chromosome keep their original relative order. This matches sorting
# (chromosome, index) pairs, without shadowing the `tuple` builtin, and it
# also tolerates empty input.
indexes = sorted(range(corrected.shape[0]), key=lambda i: chrs[i])
chrs = [chrs[i] for i in indexes]

corrected = corrected[indexes]
lens = lens[indexes]


# Convert 'chrN' labels to integer chromosome numbers. This assumes numeric
# names after the 3-character prefix; 'chrX'/'chrY' would raise ValueError.
chrs = [int(label[3:]) for label in chrs]

# Fraction of calls on each chromosome that are labelled 'DECoNT corrects'.
chrs_arr = np.asarray(chrs)  # built once instead of on every iteration
unique_chrs = np.unique(chrs_arr)
decont_percs = []
for chrom in unique_chrs:
    on_chrom = chrs_arr == chrom
    n_corrected = np.count_nonzero(corrected[on_chrom] == 'DECoNT corrects')
    decont_percs.append(n_corrected / np.count_nonzero(on_chrom))

# Order chromosomes by ascending ratio (ties broken by chromosome number).
ranked = sorted(zip(decont_percs, unique_chrs))
decont_percs = [perc for perc, _ in ranked]
uniquechrs = [chrom for _, chrom in ranked]

# Repeat percentage per chromosome. The first CSV row for a chromosome wins,
# as with the previous `[...][0]` lookup.
repeat_by_chr = {}
for chrom, perc in zip(rp_chrs, rp_percs):
    repeat_by_chr.setdefault(chrom, perc)
missing = [int(chrom) for chrom in uniquechrs if chrom not in repeat_by_chr]
if missing:
    raise KeyError(f'No repeat percentage in duprepcontent.csv for chromosomes: {missing}')
repeatpercs = [repeat_by_chr[chrom] for chrom in uniquechrs]

# Scatter each chromosome's DECoNT-correction ratio against its repeat
# percentage, labelling every point with its chromosome name.
fig, ax = plt.subplots()
ax.scatter(decont_percs, repeatpercs)

for chrom, x, y in zip(uniquechrs, decont_percs, repeatpercs):
    ax.annotate(f'chr{chrom}', (x, y))
ax.set_xlabel("Ratio of DECoNT corrected CNVs", fontweight='normal', fontsize=14)
ax.set_ylabel("Repeat Percentage", fontweight='normal', fontsize=14)
# Dashed diagonal from the bottom-left to the top-right corner of the axes.
# It is drawn in axes coordinates, so it does not depend on the data limits.
ax.plot([0, 1], [0, 1], transform=ax.transAxes, color='black', linestyle='--', linewidth=0.4)
plt.show()
# The leftover `pdb.set_trace()` debugging breakpoint was removed: it dropped
# every run into the interactive debugger and hung non-interactive runs.
