from gtfparse import read_gtf
import pandas as pd
import pyranges as pr
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore")
from tqdm.notebook import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib import pyplot
import matplotlib.ticker as mticker
import numpy as np
from scipy.interpolate import interp1d
from collections import defaultdict


class Distance3:
    """Exon index used to measure the exonic length preceding a splice junction.

    The constructor keeps only the ``exon`` rows of an annotation table and
    builds four lookup tables:

    * ``exon_end_2_genes``: exon ``End`` -> list of ``gene_id`` values
    * ``exon_end_2_transcripts``: exon ``End`` -> list of ``transcript_id`` values
    * ``transcript_2_exons``: ``transcript_id`` -> sorted list of ``[Start, End]``
    * ``transcript_len``: ``transcript_id`` -> sum of ``End - Start + 1`` over its exons
    """

    def __init__(self, gtf_annotation, splice_junctions: list = None):
        """Index the exon records of *gtf_annotation*.

        Args:
            gtf_annotation: Table providing the columns ``gene_id``, ``Start``,
                ``End``, ``transcript_id`` and ``Feature``. It is sliced with a
                column list and ``.loc``, so a pandas DataFrame is expected.
            splice_junctions: Accepted for backward compatibility only; the
                value is never read. Defaults to ``None`` instead of a shared
                mutable list.
        """
        exons = gtf_annotation[['gene_id', 'Start', "End", "transcript_id", "Feature"]]
        exons = exons.loc[exons["Feature"] == "exon"]

        self.exon_end_2_genes = defaultdict(list)
        self.exon_end_2_transcripts = defaultdict(list)
        self.transcript_2_exons = defaultdict(list)
        self.transcript_len = defaultdict(int)

        # Walk the four columns in lockstep; this avoids building one Series
        # per row the way DataFrame.iterrows() does.
        records = zip(
            exons['gene_id'].tolist(),
            exons['Start'].tolist(),
            exons['End'].tolist(),
            exons['transcript_id'].tolist(),
        )
        for gene_id, start, end, transcript_id in records:
            self.exon_end_2_genes[end].append(gene_id)
            self.exon_end_2_transcripts[end].append(transcript_id)
            self.transcript_2_exons[transcript_id].append([start, end])
            self.transcript_len[transcript_id] += end - start + 1

        # Keep each transcript's exons ordered by (Start, End).
        for exon_list in self.transcript_2_exons.values():
            exon_list.sort()

    def findTrancriptIdandDist(self, splice_junction):
        """Return the exonic length that precedes a splice junction.

        Steps:
            1. Take the exon end equal to ``splice_junction[0]`` or, failing
               that, the indexed exon end numerically closest to it.
            2. Among the transcripts owning that exon end, pick the one with
               the largest summed exon length (first one wins on ties).
            3. Sum ``End - Start + 1`` over that transcript's exons whose
               ``End`` is ``<= splice_junction[0]``.

        NOTE(review): coordinates are compared numerically only; chromosome
        and strand are not part of the index -- confirm this is intended.

        Args:
            splice_junction: Indexable whose first element is the junction
                coordinate compared against exon ends.

        Returns:
            The summed exon length, or ``0`` when no candidate transcript has a
            positive length or no exon ends at or before the junction.

        Raises:
            ValueError: If no exon records were indexed.
        """
        junction_start = splice_junction[0]
        exon_ends = self.exon_end_2_transcripts

        if not exon_ends:
            # The original surfaced this as min() complaining about an empty
            # sequence; keep the exception type but say what is wrong.
            raise ValueError("no exon records were indexed; cannot locate a transcript")

        if junction_start in exon_ends:
            closest_end = junction_start
        else:
            # Linear scan over every indexed exon end; ties resolve to the end
            # that was indexed first.
            closest_end = min(exon_ends, key=lambda end: abs(int(end) - junction_start))

        candidates = exon_ends[closest_end]

        # .get() keeps read-only lookups from inserting keys into the
        # defaultdicts.
        lengths = self.transcript_len
        longest_transcript = max(
            candidates, key=lambda transcript: lengths.get(transcript, 0), default=None
        )
        if longest_transcript is None or lengths.get(longest_transcript, 0) <= 0:
            return 0

        return sum(
            end - start + 1
            for start, end in self.transcript_2_exons.get(longest_transcript, ())
            if end <= junction_start
        )


# Module-level annotation table read by calculate_distance(). It is loaded as a
# side effect of importing/running this module, from a path relative to the
# current working directory.
# NOTE(review): `full=True` presumably keeps every GTF attribute column -- confirm
# against the pyranges documentation.
gtf_annotation = pr.read_gtf("gencode.v42.annotation.gtf", full=True).df


def calculate_distance(gtf_file, output_file="final_ont.txt"):
    """Compute the distance for every junction in a dump and write the results.

    The input file must contain a Python expression that evaluates to a mapping
    whose keys are splice junctions. For each key, the value returned by
    ``Distance3.findTrancriptIdandDist`` is written to *output_file* in the form
    ``{(a, b): dist,(c, d): dist}``.

    Args:
        gtf_file: Path of the junction dump to read. (Despite the name, this is
            evaluated as a Python mapping, not parsed as a GTF file.)
        output_file: Path of the text file to (over)write. Defaults to the
            previously hard-coded ``"final_ont.txt"``.

    Returns:
        None. The number of junctions is printed and the result is written to
        *output_file*.
    """
    # SECURITY: eval() executes whatever code the file contains, so only feed
    # this trusted, locally generated dumps. If the dump is always a plain
    # literal, ast.literal_eval would be the safe replacement.
    with open(gtf_file, 'r') as junction_handle:
        spliced_juncs = eval(junction_handle.read())

    all_junctions = list(spliced_juncs.keys())

    print(len(all_junctions))

    DistanceCalculator = Distance3(gtf_annotation, splice_junctions=all_junctions)

    # Open the output once instead of re-opening it in append mode for every
    # junction, and always close the brace so an empty input still yields "{}".
    with open(output_file, "w") as outfile:
        outfile.write("{")
        for i, splice_junction in enumerate(all_junctions):
            dist = DistanceCalculator.findTrancriptIdandDist(splice_junction)
            if i:
                outfile.write(",")
            outfile.write(str(splice_junction) + ": " + str(dist))
        outfile.write("}")


def _read_junction_tokens(path):
    """Return the flat token list of a ``{(a, b): v,(c, d): w}`` text dump.

    The result has the layout ``[a, b, v, c, d, w, ...]``: every third token is
    a value and the two tokens before it are the junction coordinates.
    """
    with open(path) as handle:
        text = handle.read()

    text = text[1:-1]  # drop the enclosing { }
    text = text.replace('(', '').replace('):', ',').replace(' ', '')
    return text.split(',')


def _binned_drop_fraction(all_values, kept_values, bins_num):
    """Return ``(bin upper edges, (A - B) / A per bin)`` for two value lists.

    ``A`` counts *all_values* per bin and ``B`` counts *kept_values* per bin.
    Values above ``p90 + 1.5 * (p90 - p5)`` of *all_values* are discarded first.
    """
    if not all_values:
        raise ValueError("no junction distances available to build a histogram")

    # 5th and 90th percentiles (not quartiles) define the upper outlier fence.
    p5 = np.percentile(all_values, 5)
    p90 = np.percentile(all_values, 90)
    upper_bound = p90 + 1.5 * (p90 - p5)

    trimmed_all = [v for v in all_values if v <= upper_bound]
    trimmed_kept = [v for v in kept_values if v <= upper_bound]

    # NOTE(review): the edges start at 0 and span (max - min), not min..max;
    # that only covers the data when the smallest value is 0 -- confirm intent.
    bin_size = (max(trimmed_all) - min(trimmed_all)) / bins_num
    bins_list = [round(k * bin_size) for k in range(bins_num + 1)]

    A, _ = np.histogram(trimmed_all, bins=bins_list)
    B, _ = np.histogram(trimmed_kept, bins=bins_list)

    # Bins with A == 0 produce nan, exactly as plain division would; only the
    # floating-point warning is silenced here.
    with np.errstate(divide="ignore", invalid="ignore"):
        fractions = (A - B) / A

    return list(bins_list[1:]), list(fractions)


def generate_hist(splice_files, last_coordinate_files, bins_num, filter_value=10, filepath=""):
    """Build one binned ``(A - B) / A`` profile per input file pair.

    For every index ``i`` the function reads ``last_coordinate_files[i]`` (a
    ``{(a, b): int}`` dump, population "A") and ``splice_files[i]`` (a
    ``{(a, b): number}`` dump). Junctions of the splice file whose value exceeds
    *filter_value* and that also occur in A form population "B". Both
    populations are histogrammed over the A values and the per-bin fraction
    ``(A - B) / A`` is recorded.

    Every file pair is now processed. Previously the statistics were computed
    outside the loop, so only the last pair contributed and an empty
    *splice_files* raised ``NameError``.

    Args:
        splice_files: File names of the per-junction value dumps.
        last_coordinate_files: File names of the per-junction coordinate dumps,
            aligned by position with *splice_files*.
        bins_num: Number of histogram bins.
        filter_value: A splice-file entry counts only when its value is
            strictly greater than this threshold.
        filepath: Prefix prepended to every file name.

    Returns:
        ``(x1s, y1s)``: two lists with one entry per file pair; ``x1s[i]`` holds
        the upper bin edges and ``y1s[i]`` the matching ``(A - B) / A`` values
        (``nan`` for empty bins).

    Raises:
        IndexError: If *last_coordinate_files* is shorter than *splice_files*.
        ValueError: If a coordinate dump contains no entries.
    """
    x1s = []
    y1s = []

    for i, f in enumerate(splice_files):
        last_coordinate_file = filepath + last_coordinate_files[i]
        final_file = filepath + f

        print(last_coordinate_file)

        # Population A: every junction of the coordinate dump.
        tokens = _read_junction_tokens(last_coordinate_file)
        last_coordinate = {}
        for idx in tqdm(range(2, len(tokens), 3)):
            last_coordinate[f'{tokens[idx-2]}, {tokens[idx-1]}'] = int(tokens[idx])

        # Population B: junctions above the threshold that are also in A.
        tokens = _read_junction_tokens(final_file)
        b_in_a = {}
        for idx in tqdm(range(2, len(tokens), 3)):
            if float(tokens[idx]) > filter_value:
                pair = f'{tokens[idx-2]}, {tokens[idx-1]}'
                # Membership test, so a stored coordinate of 0 is not mistaken
                # for "missing".
                if pair in last_coordinate:
                    b_in_a[pair] = last_coordinate[pair]

        x1, y1 = _binned_drop_fraction(
            list(last_coordinate.values()), list(b_in_a.values()), bins_num
        )
        x1s.append(x1)
        y1s.append(y1)

    return x1s, y1s


def combine_plot(x1s, y1s, bins_num):
    """Interpolate several curves onto a shared x grid, plot them, and average.

    The shared grid spans the x interval covered by every curve: it runs from
    the largest per-curve minimum to the smallest per-curve maximum. The
    original expression compared the inner lists lexicographically, which
    produced a descending grid at best and out-of-range interpolation errors
    for some overlapping inputs.

    Args:
        x1s: Sequence of x-coordinate sequences, one per curve.
        y1s: Sequence of y-coordinate sequences, aligned with *x1s*.
        bins_num: Number of points in the shared grid.

    Returns:
        ``(x_range, interpolated_lines_mean)``: the shared ascending grid and
        the point-wise mean of all interpolated curves.

    Raises:
        ValueError: If *x1s* is empty, if the curves share no common x
            interval, or if ``interp1d`` rejects the input.
    """
    x_values = x1s
    y_values = y1s

    # Overlap of all curves: max of the minima .. min of the maxima.
    overlap_start = max(min(xs) for xs in x_values)
    overlap_end = min(max(xs) for xs in x_values)
    if overlap_start > overlap_end:
        raise ValueError("the curves do not share a common x interval")

    x_range = np.linspace(overlap_start, overlap_end, bins_num)

    # Cubic interpolation of every curve onto the shared grid.
    interpolated_lines = []
    for i, xs in enumerate(x_values):
        print(xs, y_values[i])
        curve = interp1d(xs, y_values[i], kind='cubic')
        interpolated_lines.append(curve(x_range))

    # Original points in red, interpolated curve in blue.
    for i, xs in enumerate(x_values):
        plt.plot(xs, y_values[i], 'ro', x_range, interpolated_lines[i], 'b-', label='junction' + str(i))
    plt.legend(loc='upper right')
    plt.show()

    interpolated_lines_mean = np.mean(interpolated_lines, axis=0)
    return x_range, interpolated_lines_mean



