import os
import argparse
import pandas as pd
from gtfparse import read_gtf
from tqdm.notebook import tqdm
from collections import defaultdict
import json
import matplotlib.pyplot as plt
from matplotlib import pyplot
import matplotlib.ticker as mticker
import seaborn as sns
import numpy as np
from scipy import stats
from collections import Counter
from matplotlib import gridspec
import shelve
import sqlite3
import upsetplot
from scipy.interpolate import interp1d
from upsetplot import plot
from upsetplot import from_memberships
# warnings.filterwarnings("ignore")


def comparative_analysis(comparison_file, output_folder):
    """Draw the summary figures for a MAJIQ / long-read comparison table.

    The tab-separated table is summed column-wise and the resulting totals
    are read by column *position* (not by name):

    * positions 9-10  -> pie chart of MAJIQ combination vs. novel
    * positions 11-16 -> UpSet plot of LR de novo junction categories
    * positions 18-32 -> UpSet plot of LR novel splice-site categories

    Three PDFs (``majiq_piechart.pdf``, ``lr_denovo.pdf``, ``lr_novel.pdf``)
    are written to ``output_folder``, which is created if it does not exist.

    Args:
        comparison_file: Path to the tab-separated comparison table.
        output_folder: Directory that receives the PDF figures.

    Raises:
        IndexError: If the table has fewer columns than the positions above.
    """
    totals = pd.read_csv(comparison_file, sep='\t', engine='python').sum(axis=0)

    def totals_at(first, last):
        # Explicit positional access: integer keys on a string-labelled
        # Series are deprecated, and .iloc raises on a too-narrow table.
        return [totals.iloc[pos] for pos in range(first, last + 1)]

    os.makedirs(output_folder, exist_ok=True)

    denovo = from_memberships(
        [
            ['combination'],
            ['novel'],
            ['combination', 'novel'],
            ['combination', 'pTSS/pTES'],
            ['novel', 'pTSS/pTES'],
            ['combination', 'novel', 'pTSS/pTES'],
        ],
        data=totals_at(11, 16),
    )

    novel = from_memberships(
        [
            ["Alt 3'ss"],
            ["Alt 5'ss"],
            ['Novel exon'],
            ['Intron retention'],
            ["Alt 3'ss", "Alt 5'ss"],
            ["Alt 3'ss", 'Novel exon'],
            ["Alt 3'ss", 'Intron retention'],
            ["Alt 5'ss", 'Novel exon'],
            ["Alt 5'ss", 'Intron retention'],
            ['Novel exon', 'Intron retention'],
            ["Alt 3'ss", "Alt 5'ss", 'Novel exon'],
            ["Alt 3'ss", "Alt 5'ss", 'Intron retention'],
            ["Alt 3'ss", 'Novel exon', 'Intron retention'],
            ["Alt 5'ss", 'Novel exon', 'Intron retention'],
            ["Alt 3'ss", "Alt 5'ss", 'Novel exon', 'Intron retention'],
        ],
        data=totals_at(18, 32),
    )

    # pie chart
    plt.figure(figsize=(8, 8))
    plt.pie(
        totals_at(9, 10),
        labels=['MAJIQ combination', 'MAJIQ novel'],
        colors=['#09B141', '#99ff99'],
        autopct='%.1f%%',
        textprops={'fontsize': 18},
    )
    plt.savefig(os.path.join(output_folder, "majiq_piechart.pdf"), format="pdf", bbox_inches="tight")
    plt.show()

    # denovo upset plot
    # Each UpSet plot gets its own figure: without an explicit ``fig`` the
    # library draws on the current figure, which may still be the previous
    # plot when plt.show() does not close it.
    denovo_fig = plt.figure()
    plot(denovo, fig=denovo_fig, show_counts=True, sort_by='cardinality', element_size=45,
         min_subset_size=0, totals_plot_elements=5, intersection_plot_elements=4,
         shading_color="lightgray")
    denovo_fig.suptitle('LR de novo junction', fontsize=20, y=1.05)
    denovo_fig.savefig(os.path.join(output_folder, "lr_denovo.pdf"), format="pdf", bbox_inches="tight")
    plt.show()

    # novel upset plot
    novel_fig = plt.figure()
    plot(novel, fig=novel_fig, show_counts=True, sort_by='cardinality', element_size=45,
         min_subset_size=0, totals_plot_elements=5, intersection_plot_elements=5,
         shading_color="lightgray")
    novel_fig.suptitle('LR novel splice site', fontsize=25, y=1.05)
    novel_fig.savefig(os.path.join(output_folder, "lr_novel.pdf"), format="pdf", bbox_inches="tight")
    plt.show()


def _upper_edge_bin(value, upper_edges):
    """Return the index of the first bin whose upper edge is above ``value``.

    Bins are half-open except the last one, which also accepts a value equal
    to its upper edge. Returns ``None`` for values above the last edge or NaN.
    """
    idx = int(np.searchsorted(upper_edges, value, side='right'))
    if idx < len(upper_edges):
        return idx
    return len(upper_edges) - 1 if value == upper_edges[-1] else None


def _load_pair_keys(path):
    """Read a text file holding the repr of a set of ``(start, end)`` tuples.

    Returns the set of ``'start, end'`` strings found between the tuple
    parentheses.
    """
    with open(path) as handle:
        # strip() so a trailing newline does not corrupt the last pair
        text = handle.read().strip()
    return set(text[2:-2].split('), ('))


def _count_pairs_in_reference(coords, values, upper_edges, reference_keys):
    """Count, per value bin, the distinct junctions present in ``reference_keys``.

    ``coords`` are ``'start-end'`` strings; each is normalised to the
    ``'start, end'`` key format used by ``_load_pair_keys``.
    """
    bins = [set() for _ in upper_edges]
    for coord, value in zip(coords, values):
        parts = coord.split('-')
        key = f'{int(parts[0])}, {int(parts[1])}'
        idx = _upper_edge_bin(value, upper_edges)
        if idx is not None:
            bins[idx].add(key)
    return [len(keys & reference_keys) for keys in bins]


def read_psi(junction_file, gtf_file):
    """Compare MAJIQ junction PSI / read distributions with long-read junctions.

    Four figures are shown: a PSI histogram, a reverse-cumulative PSI
    histogram, a read-count histogram and a reverse-cumulative read-count
    histogram. In each, all MAJIQ junctions are overlaid with the subset whose
    coordinates also occur in the long-read junction set.

    Args:
        junction_file: Tab-separated MAJIQ table; the first 4 lines are
            skipped and columns 15, 16 and 17 (0-based) are renamed to
            ``median_reads``, ``median_psi`` and ``var_psi``. Must also
            contain ``gene_id`` and ``junction_coord`` columns.
        gtf_file: Despite the name, a text file containing the repr of a set
            of ``(start, end)`` junction tuples.

    Returns:
        None. The figures are displayed with ``plt.show()``.
    """
    psi_edges = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
                 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
    read_edges = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
    read_labels = ['0-10', '10-20', '20-30', '30-40', '40-50', '50-60',
                   '60-70', '70-80', '80-90', '90-100', '>100']
    percent_ticks = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
    percent_labels = ['0', '20', '40', '60', '80', '100']
    psi_centres = [(lo + hi) / 2 for lo, hi in zip(psi_edges, psi_edges[1:])]
    read_centres = [(lo + hi) / 2 for lo, hi in zip(read_edges, read_edges[1:])]

    df = pd.read_csv(junction_file, sep='\t', engine='python', skiprows=4)
    df.rename(columns={df.columns[15]: 'median_reads',
                       df.columns[16]: 'median_psi',
                       df.columns[17]: 'var_psi'}, inplace=True)

    # One row per junction, keeping the column-wise minimum of each group.
    df_psi = df.copy()
    df_psi['median_psi'] = pd.to_numeric(df['median_psi'], errors='coerce')
    df_psi = df_psi[df_psi['median_psi'].notna()]  # remove all blank psi rows
    df_psi = df_psi.groupby('junction_coord').min()
    df_psi['junction_coord'] = df_psi.index  # add back junction_coord column

    df_reads = df.copy()
    df_reads['median_reads'] = pd.to_numeric(df['median_reads'], errors='coerce')
    df_reads = df_reads[df_reads['median_reads'].notna()]  # remove all blank reads rows
    df_reads = df_reads[df_reads['junction_coord'].notna()]  # remove all blank junction_coord rows
    df_reads = df_reads[['gene_id', 'junction_coord', 'median_reads']].drop_duplicates(keep='first')
    # everything above 110 reads is collapsed into the last (">100") bin
    df_reads = df_reads.assign(median_reads=df_reads['median_reads'].clip(upper=110))

    # Count per bin directly. Reading bar heights back from a pandas plot is
    # unreliable because Series plots reuse the current axes, so patches of
    # earlier histograms leak into the counts.
    psi_count = np.histogram(df_psi['median_psi'], bins=psi_edges)[0].tolist()
    read_count = np.histogram(df_reads['median_reads'], bins=read_edges)[0].tolist()

    lr_keys = _load_pair_keys(gtf_file)
    psi_in_a = _count_pairs_in_reference(
        df_psi['junction_coord'], df_psi['median_psi'], psi_edges[1:], lr_keys)
    reads_in_a = _count_pairs_in_reference(
        df_reads['junction_coord'], df_reads['median_reads'], read_edges[1:], lr_keys)

    # One representative value (the bin centre) per counted junction.
    majiq_psi = np.repeat(psi_centres, psi_count)
    lr_psi = np.repeat(psi_centres, psi_in_a)

    # psi combined histogram
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(majiq_psi, alpha=0.3, label='MAJIQ', color='slategrey', bins=psi_edges)
    ax.hist(lr_psi, alpha=0.8, label='IsoQuant PacBio', color='tomato', bins=psi_edges)
    ax.legend(loc='best', fontsize=13)
    ax.set_title(r'MAJIQ and IsoQuant PSI ($\Psi$)', fontsize=23)
    ax.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax.tick_params(axis='y', labelsize=18)
    ax.set_xlabel(r"PSI ($\Psi$)", fontsize=21)
    ax.set_ylabel("Number of splice junctions", fontsize=21)
    plt.show()

    # psi cumulative combined histogram
    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.hist(majiq_psi, alpha=0.4, label='MAJIQ', histtype='step', cumulative=-1,
             color='slategrey', bins=psi_edges, linewidth=1.2)
    ax1.hist(lr_psi, alpha=1.0, label='LR PacBio', histtype='step', cumulative=-1,
             color='tomato', bins=psi_edges, linewidth=1.3)
    ax1.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax1.tick_params(axis='y', labelsize=18)
    ax1.set_xlabel(r"PSI ($\Psi$)", fontsize=21)
    ax1.set_ylabel("Number of splice junctions", fontsize=21)
    ax1.legend(loc='upper left', fontsize=12)

    ax2 = ax1.twinx()
    ax2.hist(majiq_psi, alpha=0.5, label='MAJIQ', density=True, histtype='step',
             cumulative=-1, bins=psi_edges, linewidth=1.3)
    ax1.invert_xaxis()
    # Pin the tick positions so the percent labels cannot drift or mismatch.
    ax2.set_yticks(percent_ticks)
    ax2.set_yticklabels(percent_labels)
    ax2.tick_params(axis='y', labelsize=18)
    ax2.axvline(x=0.2, color='black', label='axvline - full height', linestyle="dashed")
    # horizontal guides at the number of junctions with PSI >= 0.2
    ax1.axhline(y=sum(psi_count[4:]), color='black', linestyle="dashed", xmin=0.775, xmax=0.955)
    ax1.axhline(y=sum(psi_in_a[4:]), color='black', linestyle="dashed", xmin=0.775, xmax=0.955)
    ax2.set_ylabel("Cumulative probability", fontsize=21)
    ax2.set_title(r'MAJIQ and LR PSI ($\Psi$)', fontsize=23)
    plt.show()

    # reads combined histogram (on its own figure)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(np.repeat(read_centres, read_count), alpha=0.3, label='MAJIQ',
            align='left', color='slategrey', bins=read_edges)
    ax.hist(np.repeat(read_centres, reads_in_a), alpha=0.8, label='LR PacBio',
            align='left', color='tomato', bins=read_edges)
    ax.legend(loc='upper right', fontsize=13)
    ax.set_title('MAJIQ and LR reads', fontsize=23)
    ax.set_xticks(read_edges[:-1])
    ax.set_xticklabels(read_labels)
    ax.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax.tick_params(axis='y', labelsize=18)
    ax.set_xlabel("Number of reads", fontsize=21)
    ax.set_ylabel("Number of splice junctions", fontsize=21)
    plt.show()

    # reads cumulative combined histogram
    # Bins are mirrored (highest read bin first) so the cumulative curve
    # accumulates from ">100" down to "0-10".
    mirrored_centres = read_centres[::-1]
    majiq_reads_rev = np.repeat(mirrored_centres, read_count)
    lr_reads_rev = np.repeat(mirrored_centres, reads_in_a)

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.hist(majiq_reads_rev, alpha=0.4, label='MAJIQ', histtype='step', cumulative=True,
             align='left', color='slategrey', bins=read_edges, linewidth=1.3)
    ax1.hist(lr_reads_rev, alpha=1.0, label='LR PacBio', histtype='step', cumulative=True,
             align='left', color='tomato', bins=read_edges, linewidth=1.3)
    ax1.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax1.tick_params(axis='y', labelsize=18)
    ax1.set_xlabel("Number of reads", fontsize=21)
    ax1.set_ylabel("Number of splice junctions", fontsize=21)
    ax1.legend(loc='upper left', fontsize=12)

    ax2 = ax1.twinx()
    ax2.hist(majiq_reads_rev, alpha=0.5, label='majiq', density=True, histtype='step',
             cumulative=True, align='left', bins=read_edges)
    ax2.set_yticks(percent_ticks)
    ax2.set_yticklabels(percent_labels)
    ax1.set_xticks(read_edges[:-1])
    ax1.set_xticklabels(read_labels[::-1])
    ax2.tick_params(axis='y', labelsize=14)
    ax2.set_ylabel("Cumulative probability", fontsize=21)
    ax2.set_title('MAJIQ and LR reads', fontsize=23)
    plt.show()


def _load_supported_junctions(path, min_count):
    """Return ``'start, end'`` keys of junctions whose count exceeds ``min_count``.

    The file is expected to hold text of the form
    ``{(start, end): count, (start, end): count, ...}``.
    """
    with open(path) as handle:
        # strip() so a trailing newline does not end up inside the last count
        text = handle.read().strip()
    # "{(a, b): c, ...}" -> ["a", "b", "c", ...]
    tokens = text[1:-1].replace('(', '').replace('):', ',').replace(' ', '').split(',')
    return {
        f'{tokens[i]}, {tokens[i + 1]}'
        for i in range(0, len(tokens) - 2, 3)
        if int(tokens[i + 2]) > min_count
    }


def _count_supported_per_read_bin(coords, reads, supported):
    """Count junctions found in ``supported`` per 10-read bin (11 bins, last is >=100)."""
    lower_bounds = list(range(10, 101, 10))
    counts = [0] * (len(lower_bounds) + 1)
    for coord, value in zip(coords, reads):
        parts = coord.split('-')
        if f'{parts[0]}, {parts[1]}' in supported:
            counts[int(np.searchsorted(lower_bounds, value, side='right'))] += 1
    return counts


def lsv_plot(majiq_file, pacbio_file, ont_file):
    """Plot the share of MAJIQ junctions not supported by PacBio / ONT reads.

    A histogram of MAJIQ ``median_reads`` (capped at 110) is shown first,
    followed by a grouped bar chart of ``(all - supported) / all`` per read
    bin, in percent, for each long-read technology. The ``0-10`` bin is left
    out of the bar chart.

    Args:
        majiq_file: Tab-separated table with ``gene_id``, ``junction_coord``
            and ``median_reads`` columns.
        pacbio_file: Text file of the form ``{(start, end): count, ...}``
            with PacBio junction counts.
        ont_file: Same format as ``pacbio_file``, for ONT.

    Returns:
        None. The figures are displayed with ``plt.show()``.
    """
    read_edges = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
    filter_value = 10  # minimum long-read count (exclusive); change as needed

    # NOTE: the table is used row-wise; it must not be collapsed with
    # df.sum(), otherwise the per-junction columns below no longer exist.
    df = pd.read_csv(majiq_file, sep='\t', engine='python')
    df_lsv_id = df.copy()
    df_lsv_id['median_reads'] = pd.to_numeric(df_lsv_id['median_reads'], errors='coerce')
    df_lsv_id = df_lsv_id[df_lsv_id['median_reads'].notna()]  # remove all blank reads rows
    # rows without coordinates cannot be matched against the long-read sets
    df_lsv_id = df_lsv_id[df_lsv_id['junction_coord'].notna()]
    df_lsv_id = df_lsv_id[['gene_id', 'junction_coord', 'median_reads']].drop_duplicates(keep='first')
    # everything above 110 reads is collapsed into the last ("> 100") bin
    df_lsv_id = df_lsv_id.assign(median_reads=df_lsv_id['median_reads'].clip(upper=110))

    fig, ax = plt.subplots(figsize=(10, 6))
    lsv_count, _, _ = ax.hist(df_lsv_id['median_reads'], bins=read_edges, align='left')
    ax.set_title('junctions! histogram median_reads')
    ax.set_xlabel('reads numbers')
    ax.set_ylabel('# splice junctions')
    ax.set_xticks(read_edges[:-1])
    ax.set_xticklabels(['0-10', '10-20', '20-30', '30-40', '40-50', '50-60',
                        '60-70', '70-80', '80-90', '90-100', '> 100'])
    plt.show()

    coords = df_lsv_id['junction_coord']
    reads = df_lsv_id['median_reads']
    lsv_id_reads_in_a = _count_supported_per_read_bin(
        coords, reads, _load_supported_junctions(pacbio_file, filter_value))
    lsv_id_reads_in_b = _count_supported_per_read_bin(
        coords, reads, _load_supported_junctions(ont_file, filter_value))

    A = np.asarray(lsv_count, dtype=float)
    B = np.array(lsv_id_reads_in_a)
    C = np.array(lsv_id_reads_in_b)

    # (A-B)/A % graph; an empty bin yields NaN and simply draws no bar
    with np.errstate(divide='ignore', invalid='ignore'):
        ip = ((A - B) / A)[1:] * 100
        io = ((A - C) / A)[1:] * 100

    x_labels = ['10-20', '20-30', '30-40', '40-50', '50-60', '60-70', '70-80', '80-90', '90-100', '>100']
    bar_width = 0.42
    x = np.arange(len(x_labels))

    plt.figure(figsize=(10, 6))
    plt.bar(x - bar_width / 2, io, color='dodgerblue', label='ONT', width=bar_width)
    plt.bar(x + bar_width / 2, ip, color='tomato', label='PacBio', width=bar_width)
    plt.legend(loc='upper right', fontsize=13)
    plt.xticks(x, x_labels, fontsize=18, rotation=30)
    plt.yticks(fontsize=18)
    plt.xlabel("Number of reads per LSV", fontsize=21)
    plt.ylabel("% of non-quantifiability", fontsize=21)
    plt.title('LR LSV non-quantifiability', fontsize=23)
    plt.show()





def _count_introns_in_reference(coords, values, upper_edges, reference_pairs):
    """Count, per value bin, the distinct introns present in ``reference_pairs``.

    ``coords`` are ``'start-end'`` strings and are compared as
    ``(start, end)`` string tuples. Bins are half-open except the last, which
    also accepts a value equal to its upper edge; values beyond are ignored.
    """
    bins = [set() for _ in upper_edges]
    last = len(upper_edges) - 1
    for coord, value in zip(coords, values):
        parts = coord.split('-')
        pair = (parts[0], parts[1])
        idx = int(np.searchsorted(upper_edges, value, side='right'))
        if idx > last and value == upper_edges[-1]:
            idx = last
        if idx <= last:
            bins[idx].add(pair)
    return [len(pairs & reference_pairs) for pairs in bins]


def lr_intron_psi_distribution(majiq_intron_file, voila_file, gene_id_list):
    """Compare long-read intron-retention reads and PSI with MAJIQ introns.

    Four figures are shown: histogram and reverse-cumulative histogram of
    long-read intron read counts, then the same pair for long-read intron PSI
    (introns with at least 10 reads). Each overlays all long-read introns with
    the subset whose coordinates also appear among the MAJIQ introns.

    Args:
        majiq_intron_file: Tab-separated MAJIQ table (first 4 lines skipped)
            with ``junction_name`` and ``junction_coord`` columns.
        voila_file: Passed through to ``extract_psi_from_lr``.
        gene_id_list: Passed through to ``extract_psi_from_lr``.

    Returns:
        None. The figures are displayed with ``plt.show()``.
    """
    read_edges = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
    psi_edges = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    read_labels = ['0-10', '10-20', '20-30', '30-40', '40-50', '50-60',
                   '60-70', '70-80', '80-90', '90-100', '>100']
    psi_labels = ['0-0.1', '0.1-0.2', '0.2-0.3', '0.3-0.4', '0.4-0.5',
                  '0.5-0.6', '0.6-0.7', '0.7-0.8', '0.8-0.9', '0.9-1.0']
    percent_ticks = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
    percent_labels = ['0', '20', '40', '60', '80', '100']
    read_centres = [(lo + hi) / 2 for lo, hi in zip(read_edges, read_edges[1:])]
    psi_centres = [(lo + hi) / 2 for lo, hi in zip(psi_edges, psi_edges[1:])]

    # Long-read intron events; loaded once and reused for reads and PSI.
    # NOTE(review): assumes extract_psi_from_lr returns the same CSV for the
    # same arguments, so a single call is sufficient - confirm.
    lr_file = extract_psi_from_lr(voila_file, gene_id_list)
    lr_all = pd.read_csv(lr_file, engine='python')
    lr_introns = lr_all[lr_all['event'] == 'intron']

    # MAJIQ introns as a set of (start, end) string pairs.
    sr_ir = pd.read_csv(majiq_intron_file, engine='python', skiprows=4, sep='\t')
    sr_ir = sr_ir.loc[sr_ir['junction_name'].isin(['C1_C2_intron', 'C2_C1_intron'])]
    sr_ir = sr_ir.drop_duplicates(subset='junction_coord')
    sr_pairs = set()
    for coord in sr_ir['junction_coord']:
        parts = coord.split('-')
        sr_pairs.add((str(parts[0]), str(parts[1])))

    # ---- read counts -----------------------------------------------------
    lr_read = lr_introns.groupby('junction_coordinate').min()
    lr_read['junction_coord'] = lr_read.index
    # everything above 110 reads is collapsed into the last (">100") bin
    lr_read['reads'] = lr_read['reads'].clip(upper=110)

    lr_read_count = np.histogram(lr_read['reads'], bins=read_edges)[0].tolist()
    ir_bin = _count_introns_in_reference(
        lr_read['junction_coord'], lr_read['reads'], read_edges[1:], sr_pairs)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(np.repeat(read_centres, lr_read_count), alpha=0.6, label='LR IR',
            align='left', color='tomato', bins=read_edges)
    ax.hist(np.repeat(read_centres, ir_bin), alpha=0.4, label='MAJIQ IR',
            align='left', color='dimgrey', bins=read_edges)
    ax.legend(loc='upper right', fontsize=13)
    ax.set_xticks(read_edges[:-1])
    ax.set_xticklabels(read_labels)
    ax.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax.tick_params(axis='y', labelsize=18)
    ax.set_xlabel("Number of reads", fontsize=21)
    ax.set_ylabel("Number of introns", fontsize=21)
    ax.set_title('IR: MAJIQ and LR reads', fontsize=23)
    plt.show()

    # Mirrored bins (highest first) so the cumulative curve accumulates from
    # ">100" down to "0-10".
    lr_reads_rev = np.repeat(read_centres[::-1], lr_read_count)
    sr_reads_rev = np.repeat(read_centres[::-1], ir_bin)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(lr_reads_rev, alpha=1.0, label='LR IR', histtype='step', cumulative=True,
            align='left', color='tomato', bins=read_edges, linewidth=1.3)
    ax.hist(sr_reads_rev, alpha=0.4, label='MAJIQ IR', histtype='step', cumulative=True,
            align='left', color='dimgrey', bins=read_edges, linewidth=1.3)
    ax.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax.tick_params(axis='y', labelsize=18)
    ax.set_xlabel("Number of reads", fontsize=21)
    ax.set_ylabel("Number of introns", fontsize=21)
    ax.legend(loc='upper left', fontsize=12)

    ax2 = ax.twinx()
    ax2.hist(lr_reads_rev, alpha=0.5, label='majiq', density=True, histtype='step',
             cumulative=True, align='left', bins=read_edges)
    # Pin the tick positions so the percent labels cannot drift or mismatch.
    ax2.set_yticks(percent_ticks)
    ax2.set_yticklabels(percent_labels)
    ax.set_xticks(read_edges[:-1])
    ax.set_xticklabels(read_labels[::-1])
    ax2.tick_params(axis='y', labelsize=18)
    ax2.set_ylabel("Cumulative probability", fontsize=18)
    ax2.set_title('IR: MAJIQ and LR reads', fontsize=23)
    plt.show()

    # ---- PSI (introns with at least 10 reads) ----------------------------
    lr_psi = lr_introns[lr_introns['reads'] >= 10]
    lr_psi = lr_psi.groupby('junction_coordinate').min()
    lr_psi['junction_coord'] = lr_psi.index

    lr_psi_count = np.histogram(lr_psi['adjusted_psi'], bins=psi_edges)[0].tolist()
    ir_bin2 = _count_introns_in_reference(
        lr_psi['junction_coord'], lr_psi['adjusted_psi'], psi_edges[1:], sr_pairs)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(np.repeat(psi_centres, lr_psi_count), alpha=0.6, label='LR IR',
            align='left', color='tomato', bins=psi_edges)
    ax.hist(np.repeat(psi_centres, ir_bin2), alpha=0.4, label='MAJIQ IR',
            align='left', color='dimgrey', bins=psi_edges)
    ax.legend(loc='upper left', fontsize=13)
    ax.set_xticks(psi_edges[:-1])
    ax.set_xticklabels(psi_labels)
    ax.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax.tick_params(axis='y', labelsize=18)
    ax.set_xlabel(r"PSI ($\Psi$)", fontsize=21)
    ax.set_ylabel("Number of introns", fontsize=21)
    ax.set_title(r'IR: MAJIQ and LR PSI ($\Psi$)', fontsize=23)
    plt.show()

    # The cumulative PSI figure must use the PSI data (mirrored bins), not
    # the read-count samples from the figures above.
    lr_psi_rev = np.repeat(psi_centres[::-1], lr_psi_count)
    sr_psi_rev = np.repeat(psi_centres[::-1], ir_bin2)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(lr_psi_rev, alpha=1.0, label='LR IR', histtype='step', cumulative=True,
            align='left', color='tomato', bins=psi_edges, linewidth=1.3)
    ax.hist(sr_psi_rev, alpha=0.4, label='MAJIQ IR', histtype='step', cumulative=True,
            align='left', color='dimgrey', bins=psi_edges, linewidth=1.3)
    ax.tick_params(axis='x', labelsize=18, labelrotation=30)
    ax.tick_params(axis='y', labelsize=18)
    ax.set_xlabel(r"PSI ($\Psi$)", fontsize=21)
    ax.set_ylabel("Number of introns", fontsize=21)
    ax.legend(loc='upper left', fontsize=12)

    ax2 = ax.twinx()
    ax2.hist(lr_psi_rev, alpha=0.5, label='majiq', density=True, histtype='step',
             cumulative=True, align='left', bins=psi_edges)
    ax2.set_yticks(percent_ticks)
    ax2.set_yticklabels(percent_labels)
    ax.set_xticks(psi_edges[:-1])
    ax.set_xticklabels(psi_labels[::-1])
    ax2.tick_params(axis='y', labelsize=18)
    ax2.set_ylabel("Cumulative probability", fontsize=18)
    ax2.set_title(r'IR: MAJIQ and LR PSI ($\Psi$)', fontsize=23)
    plt.show()


if __name__ == "__main__":
    # Command-line entry point: only the comparative analysis is exposed.
    parser = argparse.ArgumentParser(description='Run comparative analysis on a tsv file.')
    parser.add_argument('--comparison', type=str, help='Path to the comparison tsv file', required=True)
    parser.add_argument('--output-path', type=str, help='Path to the output folder', required=True)
    args = parser.parse_args()

    # argparse stores "--output-path" under the attribute "output_path".
    comparative_analysis(args.comparison, args.output_path)






