import gzip
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from collections import Counter
import shelve

# On-disk cache of per-file read-length tallies, keyed by FASTQ path, so a
# file that was parsed on an earlier run does not have to be parsed again.
counter_shelf = shelve.open('shelve.db')

# "pcb" replicates (labelled "PacBio" in the plots).
pcb1, pcb2, pcb3 = (
    "/data/long_read/lr_consoritum/pcb/ENCFF563QZR.fq",
    "/data/long_read/lr_consoritum/pcb/ENCFF370NFS.fq",
    "/data/long_read/lr_consoritum/pcb/ENCFF245IPA.fq",
)

# "ont" replicates (labelled "ONT" in the plots).
ont1, ont2, ont3 = (
    "/data/long_read/lr_consoritum/ont/ENCFF263YFG.fq",
    "/data/long_read/lr_consoritum/ont/ENCFF023EXJ.fq",
    "/data/long_read/lr_consoritum/ont/ENCFF961HLO.fq",
)

# "sr" replicates (labelled "Illumina" in the plots).
sr1, sr2, sr3 = (
    '/data/long_read/lr_consoritum/illumina/ENCFF766OAK.fastq',
    '/data/long_read/lr_consoritum/illumina/ENCFF198RQU.fastq',
    '/data/long_read/lr_consoritum/illumina/ENCFF247XJT.fastq',
)

input_file_list = [pcb1, pcb2, pcb3, ont1, ont2, ont3, sr1, sr2, sr3]

# Change files to plot here
FILE1 = ont1
FILE2 = pcb1

# One empty length -> count tally per input file; filled in further down.
counters = {path: Counter() for path in input_file_list}

def _open_fastq_text(filename):
    """Return a text-mode handle for *filename*, gzip-compressed or plain.

    Compression is detected from the two-byte gzip magic number instead of
    the file extension, so uncompressed files are still accepted.
    """
    with open(filename, 'rb') as probe:
        is_gzip = probe.read(2) == b'\x1f\x8b'
    if is_gzip:
        return gzip.open(filename, 'rt')
    return open(filename)


def _tally_read_lengths(lines, tally):
    """Add the length of each FASTQ sequence line in *lines* to *tally*.

    Records are expected in the four-line layout (``@header``, sequence,
    ``+``, quality) with the sequence on a single line.  Returns *tally*.
    """
    count_next = False
    skip_quality = False
    for line in lines:
        if skip_quality:
            # Quality strings may themselves start with '@' or '+', so the
            # line after a '+' separator is skipped without being inspected.
            skip_quality = False
        elif count_next:
            # Strip the line terminator so it is not counted as a base.
            tally[len(line.rstrip('\r\n'))] += 1
            count_next = False
        elif line.startswith('@'):
            count_next = True
        elif line.startswith('+'):
            skip_quality = True
    return tally


def create_counter_gzipped(filename, counter_store=None):
    """Tally read lengths of a FASTQ file that may be gzip-compressed.

    Args:
        filename: Path of the FASTQ file; gzip is detected automatically and
            uncompressed files are read as-is.
        counter_store: Optional mapping of path -> Counter to update.  When
            omitted, the module-level ``counters`` dict is used.

    Returns:
        The Counter stored under *filename* (created if it was missing),
        mapping read length to number of reads.
    """
    store = counters if counter_store is None else counter_store
    tally = store.setdefault(filename, Counter())
    with _open_fastq_text(filename) as f:
        _tally_read_lengths(f, tally)
    return tally


# This is for regular (uncompressed) fastq files only

def create_counter(filename, counter_store=None):
    """Tally read lengths of an uncompressed FASTQ file.

    Args:
        filename: Path of the plain-text FASTQ file.
        counter_store: Optional mapping of path -> Counter to update.  When
            omitted, the module-level ``counters`` dict is used.

    Returns:
        The Counter stored under *filename* (created if it was missing),
        mapping read length to number of reads.
    """
    store = counters if counter_store is None else counter_store
    tally = store.setdefault(filename, Counter())
    with open(filename) as f:
        _tally_read_lengths(f, tally)
    return tally


# Fill `counters` for every input file: parse files that are not in the
# on-disk cache yet (and store the result), otherwise reuse the cached tally.
# The shelf is closed in `finally` so it is released even if parsing a file
# raises part-way through the list.
try:
    for current_file in input_file_list:
        if current_file in counter_shelf:
            counters[current_file] = counter_shelf[current_file]
        else:
            create_counter_gzipped(current_file)
            counter_shelf[current_file] = counters[current_file]
finally:
    counter_shelf.close()

# Expand each length -> count tally into one entry per read, which is the
# form the summary statistics, the KDE and np.histogram below consume.
elements1 = list(counters[FILE1].elements())
elements2 = list(counters[FILE2].elements())

for label, lengths in ((FILE1, elements1), (FILE2, elements2)):
    print(f"{label} mean = {np.mean(lengths):.2f} median = {np.median(lengths):.2f} min = {np.min(lengths)} max = {np.max(lengths)}")

# The KDE curves are only evaluated on [0, READ_LENGTH_MAX), every
# CALC_PDF_EVERY_X_POINTS along the x axis.
READ_LENGTH_MAX = 7500
CALC_PDF_EVERY_X_POINTS = 50

kernel1 = stats.gaussian_kde(elements1)
kernel1_xs = np.arange(0, READ_LENGTH_MAX, CALC_PDF_EVERY_X_POINTS)
kernel1_ys = kernel1.pdf(x=kernel1_xs)

kernel2 = stats.gaussian_kde(elements2)
kernel2_xs = np.arange(0, READ_LENGTH_MAX, CALC_PDF_EVERY_X_POINTS)
kernel2_ys = kernel2.pdf(x=kernel2_xs)

print("KDE calculation completed...")

# Create Final Plot

# Both histograms share one binning (about 40 units per bin up to the longest
# read of either file) so their bars line up.
MAX_X_HIST = max(max(elements1), max(elements2))
NUM_BINS = max(1, round(MAX_X_HIST / 40))  # at least one bin for tiny inputs
bin_width = MAX_X_HIST / NUM_BINS

hist1, bins1 = np.histogram(elements1, range=(0, MAX_X_HIST), bins=NUM_BINS)
# Convert counts to a density (total bar area of 1)
hist1 = hist1 / (sum(hist1) * bin_width)

# Plot the density histogram using plt.bar
plt.bar(bins1[:-1], hist1, width=np.diff(bins1),
        align='edge', alpha=0.75, color='steelblue', label=FILE1)

hist2, bins2 = np.histogram(elements2, range=(0, MAX_X_HIST), bins=NUM_BINS)
# Convert counts to a density (total bar area of 1)
hist2 = hist2 / (sum(hist2) * bin_width)

# Plot the density histogram using plt.bar
plt.bar(bins2[:-1], hist2, width=np.diff(bins2),
        align='edge', alpha=0.50, color='deeppink', label=FILE2)

# Draw KDE lines
plt.plot(kernel1_xs, kernel1_ys, color='steelblue')
plt.plot(kernel2_xs, kernel2_ys, color='deeppink')

# Vertical marker lines: x = 610 in the FILE1 colour and x = 2209 in the
# FILE2 colour.  Both positions and heights are hard-coded.
MAX_Y_FOR_BLUE_VERTICAL_LINES = 0.00085
MAX_Y_FOR_PINK_VERTICAL_LINES = 0.0003
plt.plot([610, 610], [0, MAX_Y_FOR_BLUE_VERTICAL_LINES],
         color='steelblue', linewidth=2, alpha=1)
plt.plot([2209, 2209], [0, MAX_Y_FOR_PINK_VERTICAL_LINES],
         color='deeppink', linewidth=2, alpha=1)

plt.xlabel('Read length', fontsize=16)
plt.ylabel('Read density', fontsize=16)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.xlim(0, READ_LENGTH_MAX)
plt.ylim(0, 0.0016)
plt.title('Non weighted Histogram of Read Length', fontsize=18)
plt.legend(fontsize=14)

# Name the output after the two file stems.  The last path component is used:
# index 1 of an absolute path split on '/' is the first directory ("data"),
# not the file name.  FILE1/FILE2 themselves are left untouched.
file1_name = FILE1.split('/')[-1]
file2_name = FILE2.split('/')[-1]
output_name = f"read_length_distribution_{file1_name.split('.')[0]}_{file2_name.split('.')[0]}.png"

# Write the figure out now; the current figure is cleared further down
# (plt.clf()), which would otherwise discard this plot unsaved.
plt.savefig(output_name, bbox_inches='tight')


def _mean_kde_curve(paths, grid):
    """Average, over *paths*, the Gaussian KDE of read lengths sampled on *grid*."""
    total = np.zeros(shape=grid.shape)
    for path in paths:
        lengths = list(counters[path].elements())
        total += stats.gaussian_kde(lengths).pdf(x=grid)
    return total / len(paths)


def _mean_density_histogram(paths, upper, n_bins, width):
    """Average, over *paths*, the density histogram of read lengths.

    Returns the averaged bar heights and the bin edges of the last histogram
    computed (every file uses the same range and bin count).
    """
    total = np.zeros(shape=(n_bins,))
    edges = []
    for path in paths:
        lengths = list(counters[path].elements())
        counts, edges = np.histogram(lengths, range=(0, upper), bins=n_bins)
        total += counts / (sum(counts) * width)
    total /= len(paths)
    return total, edges


# Grid on which the averaged KDE curves are evaluated.
READ_LENGTH_MAX = 7500
CALC_PDF_EVERY_X_POINTS = 50
xs = np.arange(0, READ_LENGTH_MAX, CALC_PDF_EVERY_X_POINTS)

pcb_avg_ys = _mean_kde_curve([pcb1, pcb2, pcb3], xs)
ont_avg_ys = _mean_kde_curve([ont1, ont2, ont3], xs)

# Fixed histogram range shared by every file, cut into bins about 40 wide.
MAX_X_HIST = 154773
NUM_BINS = round(MAX_X_HIST / 40)
bin_width = MAX_X_HIST / NUM_BINS

avg_pcb_hist, avg_pcb_bins = _mean_density_histogram(
    [pcb1, pcb2, pcb3], MAX_X_HIST, NUM_BINS, bin_width)
avg_ont_hist, avg_ont_bins = _mean_density_histogram(
    [ont1, ont2, ont3], MAX_X_HIST, NUM_BINS, bin_width)

# Start from an empty figure, then draw averaged bars followed by KDE lines.
plt.clf()
for edges, heights, opacity, colour, tag in (
    (avg_ont_bins, avg_ont_hist, 0.75, 'steelblue', "avg_ont"),
    (avg_pcb_bins, avg_pcb_hist, 0.50, 'deeppink', "avg_pcb"),
):
    plt.bar(edges[:-1], heights, width=np.diff(edges),
            align='edge', alpha=opacity, color=colour, label=tag)

for curve, colour in ((ont_avg_ys, 'steelblue'), (pcb_avg_ys, 'deeppink')):
    plt.plot(xs, curve, color=colour)

plt.xlabel('Read length', fontsize=16)
plt.ylabel('Read density', fontsize=16)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.xlim(0, READ_LENGTH_MAX)
plt.ylim(0, 0.0016)
plt.title('Non weighted Histogram of Read Length', fontsize=18)
plt.legend(fontsize=14)

plt.savefig("test_avg.png", bbox_inches='tight')


from matplotlib import gridspec

# Average the three "sr" replicates the same way as the ont/pcb replicates
# (mean KDE curve on `xs`, mean density histogram on the shared bins).  The
# plotting calls below need avg_sr_bins, avg_sr_hist and sr_avg_ys.
sr_files = [sr1, sr2, sr3]
sr_avg_ys = np.zeros(shape=xs.shape)
avg_sr_hist = np.zeros(shape=(NUM_BINS,))
avg_sr_bins = []
for file in sr_files:
    tmp_elements = list(counters[file].elements())
    # NOTE(review): gaussian_kde needs some spread in the data; if every read
    # in a file has the same length this call is expected to fail - confirm
    # against the actual short-read files.
    sr_avg_ys += stats.gaussian_kde(tmp_elements).pdf(x=xs)
    tmp_hist, avg_sr_bins = np.histogram(tmp_elements, range=(0, MAX_X_HIST), bins=NUM_BINS)
    avg_sr_hist += tmp_hist / (sum(tmp_hist) * bin_width)
sr_avg_ys = sr_avg_ys / len(sr_files)
avg_sr_hist /= len(sr_files)

# Two stacked axes form a "broken" y axis: ax1 (top, 1/4 of the height) shows
# the tall bars, ax2 (bottom, 3/4) shows the bulk of the data.
fig = plt.figure()
spec = gridspec.GridSpec(ncols=1, nrows=2, hspace=0.05, height_ratios=[1, 3])
ax1 = fig.add_subplot(spec[0])
ax2 = fig.add_subplot(spec[1])

# Plot the same bars on both axes; only ax2 carries legend labels.
for bin_edges, heights, opacity, colour, legend_label in (
    (avg_ont_bins, avg_ont_hist, 0.75, 'dodgerblue', "ONT"),
    (avg_pcb_bins, avg_pcb_hist, 0.50, 'tomato', "PacBio"),
    (avg_sr_bins, avg_sr_hist, 0.50, 'slategrey', "Illumina"),
):
    ax1.bar(bin_edges[:-1], heights, width=np.diff(bin_edges),
            align='edge', alpha=opacity, color=colour)
    ax2.bar(bin_edges[:-1], heights, width=np.diff(bin_edges),
            align='edge', alpha=opacity, color=colour, label=legend_label)

# KDE curves on both axes; the short-read curve is drawn thinner on ax2.
ax1.plot(xs, ont_avg_ys, color='dodgerblue')
ax2.plot(xs, ont_avg_ys, color='dodgerblue')
ax1.plot(xs, pcb_avg_ys, color='tomato')
ax2.plot(xs, pcb_avg_ys, color='tomato')
ax1.plot(xs, sr_avg_ys, color='slategrey')
ax2.plot(xs, sr_avg_ys, color='slategrey', linewidth=0.2)

ax1.set_ylim(0.015, 0.025)             # outliers
ax2.set_ylim(0, .0016)                 # most of the data
ax1.set_xlim(0, READ_LENGTH_MAX)
ax2.set_xlim(0, READ_LENGTH_MAX)

ax2_yticks = np.arange(0.000, 0.0018, 0.0002).tolist()
ax1_yticks = np.arange(0.015, 0.030, 0.005).tolist()
ax1.set_yticks(ax1_yticks)
ax2.set_yticks(ax2_yticks)
# Blank out the lowest tick label of the top axis (it sits on the break).
ax1.set_yticklabels(["", "0.020", "0.025"])

# Hide the spines between ax1 and ax2
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False, top=False, labelsize=17)  # don't put tick labels or ticks at the top
ax2.tick_params(labeltop=False, labelsize=17)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()

# Axis labels live on the bottom axes; y=0.65 moves the y label upwards so it
# sits nearer the middle of the combined (ax1 + ax2) plot.
ax2.set_xlabel('Read length', fontsize=23)
ax2.set_ylabel('Read density', fontsize=23, y=0.65)
ax1.set_title('Short and long-read length distribution', fontsize=18)
ax2.legend(fontsize=16)

# Slanted tick marks at the facing edges of the two axes to signal the break.
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)