import csv
import pdb
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

# Load the precomputed arrays for the test split. The arrays are indexed in
# parallel below (entry i of each array is used together).
# NOTE(review): presumably one entry per CNV call -- confirm against the
# script that wrote these files.
xhmmcalls = np.load('data_list_xhmm_preds_test_weightedcrossent_embeddingout2.npy')
cnvnatorcalls = np.load('data_list_cnvnator_preds_test_weightedcrossent_embeddingout2.npy')
chrs = np.load('data_list_chrs_test_weightedcrossent_embeddingout2.npy')
inds = np.load('data_list_indexes_test_weightedcrossent_embeddingout2.npy')
polished = np.load('polishedembedding2.npy')

# Length of each entry: second element of `inds[i]` minus the first
# (presumably end - start coordinates in base pairs -- TODO confirm).
lens = np.asarray([x[1] - x[0] for x in inds])

# Filled in below with one category label per call.
corrected = []

def _classify_call(xhmm_label, decont_label, truth_label):
    """Return the category string for one call.

    Both `xhmm_label` and `decont_label` are compared against `truth_label`;
    a label counts as "correct" when it equals `truth_label`.
    """
    xhmm_correct = xhmm_label == truth_label
    decont_correct = decont_label == truth_label
    if xhmm_correct and decont_correct:
        return 'XHMM call is correct and DECoNT agrees'
    if xhmm_correct:
        return 'XHMM call is correct, DECoNT call is incorrect'
    if decont_correct:
        return 'DECoNT call is correct, XHMM call is incorrect'
    return 'Both calls are incorrect'


# The argmax of the CNVnator prediction is used as the reference label.
# NOTE(review): the XHMM argmax is shifted by +1 before comparison while
# `polished` and the CNVnator argmax are used as-is -- presumably the label
# encodings differ by one; verify against the code that produced the arrays.
corrected = np.asarray([
    _classify_call(
        np.argmax(xhmm_probs) + 1,
        polished[i],
        np.argmax(cnvnatorcalls[i]),
    )
    for i, xhmm_probs in enumerate(xhmmcalls)
])
# Reorder every per-call array by chromosome name. A stable sort keeps calls
# on the same chromosome in their original relative order, which is the same
# ordering as sorting (name, original_index) pairs. Unlike the zip/unzip
# idiom this also works when there are no calls at all.
# Names are compared as strings, so the order is lexicographic, not numeric.
indexes = np.argsort(np.asarray(chrs), kind="stable")

chrs = list(np.asarray(chrs)[indexes])
corrected = corrected[indexes]
lens = lens[indexes]


def _chromosome_number(name):
    """Map a chromosome name to an integer: X -> 23, Y -> 24, else int(suffix).

    The first three characters of `name` are skipped (presumably a 'chr'
    prefix -- TODO confirm against the input file).
    """
    suffix = str(name[3:])
    if suffix == 'X':
        return 23
    if suffix == 'Y':
        return 24
    return int(suffix)


# Replace chromosome names by their integer index for plotting.
chrs = [_chromosome_number(c) for c in chrs]

# Scale lengths by 1/1000; the y axis is labelled in kbp further down.
lens = lens / 1000

# For every chromosome present, compute the fraction of its calls labelled
# 'DECoNT call is correct, XHMM call is incorrect'.
#   percs      -- the raw fraction (0..1)
#   chr_ratios -- the same fraction mapped linearly onto 400..600
#                 (presumably so it can be drawn on the length axis)
#   ccs        -- chromosome number minus one (presumably the zero-based
#                 x position of that chromosome in the plot)
_fixed_label = 'DECoNT call is correct, XHMM call is incorrect'
_chrom_array = np.asarray(chrs)

percs = []
chr_ratios = []
ccs = []
for _chrom in np.unique(_chrom_array):
    _labels_here = corrected[_chrom_array == _chrom]
    _fixed_fraction = np.sum(_labels_here == _fixed_label) / len(_labels_here)
    percs.append(_fixed_fraction)
    chr_ratios.append(400 + _fixed_fraction * 200)
    ccs.append(_chrom - 1)






# One colour per call category, in the order the categories are encountered.
palette = ['green', 'gray', 'black', 'purple']

# Jittered strip of CNV lengths per chromosome, coloured by call category.
sns.stripplot(x=chrs, y=lens, alpha=0.8, hue=corrected, palette=palette, size=4, jitter=0.4)

# Overlay the per-chromosome fraction (already rescaled into y-axis units).
# NOTE(review): `size` is a seaborn grouping semantic, not a line width; if a
# thicker line was intended this should probably be `linewidth=` -- confirm.
sns.lineplot(x=ccs, y=chr_ratios, size=7, legend=None, label='Percentage of correctly polished calls')

# despine() acts on the axes of the current figure, so it has to run after
# the axes have been created; called before any plotting it does nothing.
sns.despine(bottom=True)  # removes the top, right and bottom axis lines

plt.xlabel("Chromosome", fontweight='normal', fontsize=14)
plt.ylabel("Length of CNV (kbp)", fontweight='normal', fontsize=14)
plt.axhline(y=np.median(lens), color='blue', linestyle='--', linewidth=1, label='Median CNV length')

# Tick labels assume all 24 chromosomes (1-22, X, Y) are present in the data.
plt.xticks(ticks=range(0, 24), labels=[str(i) for i in range(1, 23)] + ['X', 'Y'], fontsize=14)

# Vertical placement of the per-chromosome value labels: 15 units above the
# line by default, 15 below for these hand-picked positions so the text does
# not collide with the line.
_labels_below = {5, 6, 7, 13, 17, 20, 22}
locations = [
    ratio - 15 if i in _labels_below else ratio + 15
    for i, ratio in enumerate(chr_ratios)
]
# Extra manual nudge for position 7; guarded so that data with fewer
# chromosomes does not raise IndexError.
if len(locations) > 7:
    locations[7] = locations[7] + 10

# Format with two decimals instead of slicing str(value), which truncated
# rather than rounded and mangled values printed in scientific notation.
for x_pos, y_pos, fraction in zip(ccs, locations, percs):
    plt.annotate('{:.2f}'.format(fraction), (x_pos, y_pos), fontsize=14)

# Thin dashed separators at the half-integer x positions -0.5, 0.5, ..., 23.5,
# i.e. on both sides of each of the 24 integer x positions 0..23.
for _column in range(25):
    plt.axvline(x=_column - 0.5, color='black', linestyle='--', linewidth=0.2)



# Fix the visible length range before computing the layout.
plt.ylim(bottom=0, top=750)

plt.legend(loc='upper left', borderaxespad=0., framealpha=0.5, facecolor='white', frameon=True, fontsize=14)

# Run tight_layout last so it accounts for the final limits and tick labels.
plt.tight_layout()

# To write the figure to disk, call savefig before show(), e.g.:
# plt.savefig('chromosomestratification.jpeg', bbox_inches='tight', dpi=500)
plt.show()









#ax = sns.boxplot(y=chrs, x=lens, hue=corrected, palette="muted")
#pdb.set_trace()
#ax.set(ylabel='length of CNV in base-pair')
#plt.show()
