from tqdm import tqdm
import pandas as pd
pd.set_option('display.max_columns', None)
import ast
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter
import scipy.stats as stats 
import os
from concurrent.futures import ThreadPoolExecutor
from Bio import BiopythonWarning
from Bio import SeqIO
from Bio.Seq import Seq
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=BiopythonWarning)
    from Bio import pairwise2
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import mannwhitneyu
from matplotlib.patches import Patch
import numpy as np
import pysam

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Junction list file to compare against: "isq_pcb1_gtf.txt" OR "isq_ont1_gtf.txt".
sequencing_type = "isq_pcb1_gtf.txt"
# Junction subset to process: "only_in_junctions" OR "intersected_junctions".
group = "only_in_junctions"
# Name of the output directory; also embedded in the output table file names.
out_suffix = "isq_pcb1_unique"

# Create the output directory. exist_ok=True replaces the racy
# "check os.path.exists(), then os.makedirs()" sequence with one atomic call.
os.makedirs(out_suffix, exist_ok=True)
    
## Functions to calculate the junction-to-3'-end distance
def _exon_number(exon_id):
    """Return the integer that follows the last ':' in *exon_id*, or None.

    NOTE(review): assumes exon ids look like '<transcript>:<exon number>'
    (what remains after stripping the 'ID=exon:' prefix) -- confirm against
    the annotation actually used.
    """
    tail = str(exon_id).rsplit(':', 1)[-1]
    try:
        return int(tail)
    except ValueError:
        return None


def distance_calculation(row, all_exons, load_pbar):
    """Append the summed exon length from a junction's exon onwards to the distance table.

    Parameters
    ----------
    row : sequence
        ``row[1]`` is the gene id, ``row[2]`` the read count and ``row[3]``
        the junction written as ``"<exon end>-<exon start>"`` (presumably a
        tuple from ``DataFrame.itertuples()`` -- verify against the caller).
    all_exons : pandas.DataFrame
        Must provide the columns ``gene_id``, ``transcript_id``, ``exon_id``,
        ``start``, ``end`` and ``length``.
    load_pbar : object
        Progress bar; ``update(1)`` is called once per processed row.

    Returns
    -------
    str or None
        The gene id when a line was appended to
        ``{out_suffix}/{out_suffix}_distance.tsv``; ``None`` when no exon row
        matches the gene id.

    Notes
    -----
    The output file is opened in append mode, so repeated runs accumulate rows.
    """
    gene_ID = row[1]
    # regex=False: match the id as a literal substring. Without it characters
    # such as '.' (e.g. in versioned ids) are interpreted as regex wildcards.
    mask = all_exons['gene_id'].str.contains(gene_ID, regex=False)
    matched_rows = all_exons[mask]

    if matched_rows.empty:
        # Still advance the bar so that it can reach its total.
        load_pbar.update(1)
        return None

    coverage = row[2]
    junction = row[3]
    end_jc, start_jc = junction.split('-')

    # Exons whose end (upstream side) or start (downstream side) sits on the junction.
    end_jc_exons = matched_rows[matched_rows['end'] == int(end_jc)]
    start_jc_exons = matched_rows[matched_rows['start'] == int(start_jc)]
    transcript_list = (
        pd.concat([end_jc_exons['transcript_id'], start_jc_exons['transcript_id']])
        .drop_duplicates()
        .tolist()
    )

    #### Identify the smallest transcript by summing the exon lengths per isoform
    candidate_exons = matched_rows[matched_rows['transcript_id'].isin(transcript_list)]
    sum_per_transcript = candidate_exons.groupby('transcript_id')['length'].sum().reset_index()
    smaller_transcript = sum_per_transcript.sort_values(by='length').head(1)
    transcript_ID = ''.join(smaller_transcript['transcript_id'].astype(str))

    #### Identify in which exon of the smallest isoform the junction occurs
    exon_with_jc = end_jc_exons[end_jc_exons['transcript_id'] == transcript_ID]
    exon_with_jc = ''.join(exon_with_jc['exon_id'].astype(str))

    transcript_exons = matched_rows[matched_rows['transcript_id'] == transcript_ID]

    # Compare exon *numbers*, not raw strings: lexicographically '...:10' sorts
    # before '...:2', which dropped exons 10+ for long transcripts.
    jc_number = _exon_number(exon_with_jc)
    exon_numbers = transcript_exons['exon_id'].map(_exon_number)
    if jc_number is not None and exon_numbers.notna().all():
        length_df = transcript_exons[exon_numbers >= jc_number]
    else:
        # No numeric suffix available (e.g. no exon of this isoform ends at
        # the junction): keep the original string comparison.
        length_df = transcript_exons[transcript_exons['exon_id'] >= exon_with_jc]
    length_to_3UTR = length_df['length'].sum()

    with open(f"{out_suffix}/{out_suffix}_distance.tsv", 'a') as f:
        f.write(f"{gene_ID}\t{coverage}\t{junction}\t{length_to_3UTR}\n")
    load_pbar.update(1)
    return str(gene_ID)

#### Prepare input to calculate distances
#### Drop duplicates and junctions with < 40 reads:
df = pd.read_csv('junctions.tsv', sep='\t', engine='python')
#df = pd.read_csv('junctions_missing.tsv', sep='\t', engine='python')

df_main = (
    df[['gene_id', 'ENCFF766OAK_median_reads', 'junction_coord']]
    .drop_duplicates(keep='first')
    .copy()
)
# Non-numeric read counts become NaN and are dropped by the >= 40 filter below.
df_main['ENCFF766OAK_median_reads'] = pd.to_numeric(df_main['ENCFF766OAK_median_reads'], errors='coerce')
df_main_min40 = df_main[df_main['ENCFF766OAK_median_reads'] >= 40]

#### Create junction dataframe from ont OR pacbio files
# Stop here on an invalid choice; previously the message was printed and the
# script then crashed with a NameError on the undefined `data_str`.
if sequencing_type not in ("isq_ont1_gtf.txt", "isq_pcb1_gtf.txt"):
    raise SystemExit('Provide either "isq_ont1_gtf.txt" OR "isq_pcb1_gtf.txt", sequencing_type not found!')
with open(sequencing_type, "r") as file:
    data_str = file.read()

# The file holds a Python literal (sequence of 2-item tuples); literal_eval
# parses it without executing code.
data_tuples = ast.literal_eval(data_str)
juntions_txt = pd.DataFrame(data_tuples, columns=['JC1', 'JC2'])
juntions_txt['junction_coord'] = juntions_txt['JC1'].astype(str) + '-' + juntions_txt['JC2'].astype(str)

#### Get common junctions between dataframes
junctions_txt_values = set(juntions_txt['junction_coord'])

intersected_junctions = df_main_min40[df_main_min40['junction_coord'].isin(junctions_txt_values)]
#### Get unique junctions from junctions.tsv
only_in_junctions_tsv = df_main_min40[~df_main_min40['junction_coord'].isin(junctions_txt_values)]

#### Import gff3 annotation file
complete_gene = pd.read_csv("gencode_v45_junc_genes.gff3", header=None, sep="\t",
                            usecols=[0, 1, 2, 3, 4, 5, 6, 7, 8],
                            names=['chr', 'source', 'feature', 'start', 'end', 'dot', 'strand', 'dot2', 'ID'])

#### Create exons dataframe
all_exons = complete_gene[complete_gene['feature'].isin(["exon"])].copy()
all_exons['gene_id'] = all_exons['ID'].str.replace(r'ID=.*?gene_id=', '', regex=True)
all_exons['gene_id'] = all_exons['gene_id'].str.split(';').str[0]

# regex=True is required: since pandas 2.0 str.replace defaults to a literal
# match, which left the attribute string untouched and broke transcript_id.
all_exons['transcript_id'] = all_exons['ID'].str.replace(r'ID=.*Parent=', '', regex=True)
all_exons['transcript_id'] = all_exons['transcript_id'].str.split(';').str[0]

all_exons['exon_id'] = all_exons['ID'].str.split(';').str[0]
# Plain prefixes: strip them literally.
for id_prefix in ('ID=UTR5:', 'ID=UTR3:', 'ID=exon:'):
    all_exons['exon_id'] = all_exons['exon_id'].str.replace(id_prefix, '', regex=False)
# NOTE(review): GFF3 coordinates are 1-based and inclusive, so the feature
# length would normally be end - start + 1 -- confirm whether the missing
# +1 is intentional before changing downstream distances.
all_exons['length'] = all_exons['end'] - all_exons['start']
# print(all_exons)

#### Create transcripts dataframe
all_transcripts = complete_gene[complete_gene['feature'] == "transcript"].copy()
all_transcripts['ID'] = all_transcripts['ID'].str.split(';').str[0]
all_transcripts['ID'] = all_transcripts['ID'].str.replace("ID=", '', regex=False)
all_transcripts['length'] = all_transcripts['end'] - all_transcripts['start']
all_transcripts = all_transcripts[['ID', 'length']]

def extract_sequence(chromosome, start, end, strand):
    """Return the upper-cased hg38 sequence of ``chromosome:start-end`` on *strand*.

    Parameters
    ----------
    chromosome : str
        Reference name as it appears in ``hg38.fa``.
    start, end : int
        Interval bounds; ``start - 1`` and ``end`` are handed to
        ``FastaFile.fetch`` (0-based, half-open), i.e. the interval is read
        as 1-based and inclusive.
    strand : str
        ``'+'`` returns the fetched sequence, ``'-'`` its reverse complement.

    Returns
    -------
    str
        The sequence, or ``""`` (after printing a message) when *chromosome*
        is not present in the FASTA file.

    Raises
    ------
    ValueError
        If *strand* is neither ``'+'`` nor ``'-'`` (this used to surface as
        an UnboundLocalError).
    """
    if strand not in ('+', '-'):
        raise ValueError(f"Unsupported strand {strand!r}; expected '+' or '-'.")

    fasta_file = "hg38.fa"
    # pysam needs the FASTA to be indexed (hg38.fa.fai next to the file).
    fasta = pysam.FastaFile(fasta_file)
    try:
        if chromosome not in fasta.references:
            print(f"Chromosome {chromosome} not found in the file.")
            return ""
        sequence = fasta.fetch(chromosome, start - 1, end).upper()
    finally:
        # Always release the handle, including on the early return above.
        fasta.close()

    if strand == '-':
        sequence = reverse_complement(sequence)
    return sequence

def calculate_GC_content(sequence):
    """Return the percentage of 'G' and 'C' characters in *sequence*.

    Counting is case-sensitive (only upper-case 'G'/'C' are counted).
    An empty sequence yields ``nan`` instead of raising ZeroDivisionError,
    so that a junction whose sequence could not be fetched does not abort
    the whole run.
    """
    if not sequence:
        return float('nan')
    gc_count = sequence.count('G') + sequence.count('C')
    return (gc_count / len(sequence)) * 100

def calculate_GA_content(sequence):
    """Return the percentage of 'G' and 'A' characters in *sequence*.

    Counting is case-sensitive (only upper-case 'G'/'A' are counted).
    An empty sequence yields ``nan`` instead of raising ZeroDivisionError,
    so that a junction whose sequence could not be fetched does not abort
    the whole run.
    """
    if not sequence:
        return float('nan')
    ga_count = sequence.count('G') + sequence.count('A')
    return (ga_count / len(sequence)) * 100

def GC_GA_calc(row_it, ref_df):
    """Append GC and GA percentages of one junction's flanks to the content table.

    ``row_it[1]`` is the gene id, ``row_it[4]`` the upstream exon end and
    ``row_it[5]`` the downstream exon start. The 50 positions ending at the
    exon end and the 50 positions beginning at the exon start are fetched,
    concatenated, and their GC/GA percentages (rounded to 2 decimals) are
    appended to ``{out_suffix}/GC_GA_{out_suffix}_content_output.tsv``.
    Nothing is written when *ref_df* has no row whose ``ID`` equals the gene id.
    """
    gene_id = row_it[1]
    upstream_last = row_it[4]
    downstream_first = row_it[5]
    upstream_first = int(upstream_last) - 49
    downstream_last = int(downstream_first) + 49

    gene_rows = ref_df[ref_df['ID'] == gene_id]
    if gene_rows.empty:
        return

    # One lookup serves both columns; values of all matching rows are joined.
    chromosome = ''.join(gene_rows['chr'].astype(str))
    strand = ''.join(gene_rows['strand'].astype(str))

    flank_bounds = ((upstream_first, upstream_last), (downstream_first, downstream_last))
    # .upper() unmasks soft-masked (lower-case) nucleotides.
    flanks = [
        extract_sequence(chromosome, int(first), int(last), str(strand)).upper()
        for first, last in flank_bounds
    ]
    junction_seq = flanks[0] + flanks[1]

    gc_pct = round(calculate_GC_content(junction_seq), 2)
    ga_pct = round(calculate_GA_content(junction_seq), 2)

    with open(f"{out_suffix}/GC_GA_{out_suffix}_content_output.tsv", 'a') as out:
        out.write(f"{gene_id}\t{chromosome}\t{strand}\t{upstream_last}\t{downstream_first}\t{gc_pct}\t{ga_pct}\n")


#### Gene-level table (ID, chromosome, strand) from the annotation
complete_gene = pd.read_csv("gencode_v45_junc_genes.gff3", header=None, sep="\t",
                            usecols=[0, 1, 2, 3, 4, 5, 6, 7, 8],
                            names=['chr', 'source', 'feature', 'start', 'end', 'dot', 'strand', 'dot2', 'ID'])
gene_strand = complete_gene[complete_gene['feature'] == "gene"].copy()
gene_strand['ID'] = gene_strand['ID'].str.split(';').str[0]
gene_strand['ID'] = gene_strand['ID'].str.replace('ID=', '', regex=False)
gene_strand = gene_strand[['ID', 'chr', 'strand']]
print(gene_strand)

#### Drop duplicates and junctions with < 40 reads:
df = pd.read_csv('junctions.tsv', sep='\t', engine='python')
df_main = (
    df[['gene_id', 'ENCFF766OAK_median_reads', 'junction_coord']]
    .drop_duplicates(keep='first')
    .copy()
)
# Non-numeric read counts become NaN and are dropped by the >= 40 filter below.
df_main['ENCFF766OAK_median_reads'] = pd.to_numeric(df_main['ENCFF766OAK_median_reads'], errors='coerce')
df_main_min40 = df_main[df_main['ENCFF766OAK_median_reads'] >= 40]


#### Create junction dataframe from ont OR pacbio files
# Fail with a clear message instead of a NameError on `data_str` below.
if sequencing_type not in ("isq_ont1_gtf.txt", "isq_pcb1_gtf.txt"):
    raise SystemExit('Provide either "isq_ont1_gtf.txt" OR "isq_pcb1_gtf.txt", sequencing_type not found!')
with open(sequencing_type, "r") as file:
    data_str = file.read()

#### Merge junction coordinates in the format exon_end-exon_start
data_tuples = ast.literal_eval(data_str)
juntions_txt = pd.DataFrame(data_tuples, columns=['JC1', 'JC2'])
juntions_txt['junction_coord'] = juntions_txt['JC1'].astype(str) + '-' + juntions_txt['JC2'].astype(str)

#### Get common junctions between dataframes
junctions_txt_values = set(juntions_txt['junction_coord'])

intersected_junctions = df_main_min40[df_main_min40['junction_coord'].isin(junctions_txt_values)]

#### Get unique junctions from junctions.tsv
only_in_junctions_tsv = df_main_min40[~df_main_min40['junction_coord'].isin(junctions_txt_values)]
only_in_junctions_tsv = only_in_junctions_tsv.copy()
only_in_junctions_tsv[['end_jc', 'start_jc']] = only_in_junctions_tsv['junction_coord'].str.split('-', expand=True)

print(only_in_junctions_tsv)
intersected_junctions = intersected_junctions.copy()
intersected_junctions[['end_jc', 'start_jc']] = intersected_junctions['junction_coord'].str.split('-', expand=True)

# NOTE(review): the GC_GA_calc calls are commented out, so these loops only
# advance the progress bar; the tables read further down must already exist
# on disk from an earlier run.
if group == "only_in_junctions":
    print("Only in junctions.tsv")
    with tqdm(total=len(only_in_junctions_tsv)) as pbar: ##### GC/GA content for "only in junctions.tsv"
        for row in only_in_junctions_tsv.itertuples():

            #GC_GA_calc(row, gene_strand)

            pbar.update(1)
elif group == "intersected_junctions":
    print("intersected_junctions.tsv")
    with tqdm(total=len(intersected_junctions)) as pbar: ##### GC/GA content for shared junctions between junctions.tsv and "data_str"
        for row in intersected_junctions.itertuples():

            #GC_GA_calc(row, gene_strand)

            pbar.update(1)
else:
    # SystemExit reports the message on stderr with a non-zero status and,
    # unlike the site-provided exit(), is always available.
    raise SystemExit('You have to choose between "only_in_junctions" OR "intersected_junctions", change group variable!')


#### Load the per-junction distance table (written without a header row)
distance_df = pd.read_csv(f"{out_suffix}/{out_suffix}_distance.tsv",
                          sep="\t", header=None)
col_names = ['gene_id', 'reads', 'junction', 'distance']
distance_df.columns = col_names

# Split "end-start" into integer columns so they can serve as merge keys.
distance_df[['End', 'Start']] = distance_df['junction'].str.split('-', expand=True)
distance_df.drop('junction', axis=1, inplace=True)
distance_df['End'] = distance_df['End'].astype(int)
distance_df['Start'] = distance_df['Start'].astype(int)
print(distance_df)

#### Load the GC/GA content table (written without a header row)
GC_GAdf = pd.read_csv(f"{out_suffix}/GC_GA_{out_suffix}_content_output.tsv",
                      sep="\t", header=None)
col_names = ['gene_id', 'chr', 'strand', 'End', 'Start', 'GC', 'GA']
GC_GAdf.columns = col_names
print(GC_GAdf)

#### Keep only junctions present in both tables
merged_df = pd.merge(distance_df, GC_GAdf, on=['gene_id', 'End', 'Start'], how='inner')
print(merged_df)
merged_df.to_csv(f"{out_suffix}/{out_suffix}_complete.tsv",
                 header=True, sep="\t", index=False)
if os.path.exists(f"{out_suffix}/{out_suffix}_complete.tsv"):
    print("Merged file has been created!")


# Translation table for the four bases in both cases. Characters that are not
# in the table (e.g. 'N') are left unchanged by str.translate.
_COMPLEMENT_TABLE = str.maketrans("ACGTacgt", "TGCAtgca")


def reverse_complement(sequence):
    """Return the reverse complement of a DNA *sequence*.

    A/C/G/T (either case) are complemented; any other character, such as the
    ambiguity code 'N', is kept as-is instead of raising KeyError.
    """
    return sequence[::-1].translate(_COMPLEMENT_TABLE)


def blend_with_white(color, alpha):
    """Blend an (r, g, b) colour given in 0-255 with white at opacity *alpha*.

    Returns a 3-tuple of floats in the 0-1 range used by matplotlib:
    ``alpha=1`` gives the colour itself, ``alpha=0`` gives white.
    """
    red, green, blue = color
    # Contribution of the white background, shared by all three channels.
    white_share = 1 * (1 - alpha)
    return tuple(channel / 255 * alpha + white_share for channel in (red, green, blue))

# Module-level colour approximations: each named colour blended with white
# at the given opacity.
dodgerblue_approx, slategrey_approx, tomato_approx = (
    blend_with_white(rgb, opacity)
    for rgb, opacity in (
        ((30, 144, 255), 0.7),
        ((112, 128, 144), 0.3),
        ((255, 99, 71), 0.8),
    )
)

def boxplots(combined_df, content_group, seq_method):
    """Draw per-bin box plots of *content_group* for both datasets and save them.

    Parameters
    ----------
    combined_df : pandas.DataFrame
        Needs the columns ``combined_bin`` ("Bin 0" .. "Bin 4"), ``dataset``
        ("PacBio" / "Illumina") and the column named by *content_group*.
    content_group : str
        Column plotted on the y axis and tested between datasets (e.g. "GC").
    seq_method : str
        Used as the plot title and in the output file name.

    The figure is written to ``{seq_method}_{content_group}.pdf``. For every
    bin a Mann-Whitney U p-value (Illumina vs. PacBio) is printed along the
    top of the axes; ``nan`` is shown when a group is too small to test.
    """
    # Colours approximated by alpha blending with white.
    tomato_approx = blend_with_white((255, 99, 71), 0.8)
    slategrey_approx = blend_with_white((112, 128, 144), 0.3)

    def _to_hex(rgb):
        # 0-1 float channels -> '#rrggbb'
        return '#' + ''.join(f'{int(255 * channel):02x}' for channel in rgb)

    custom_palette = {
        'PacBio': _to_hex(tomato_approx),
        'Illumina': _to_hex(slategrey_approx),
    }

    # Order of the bins on the x axis
    bin_order = ["Bin 0", "Bin 1", "Bin 2", "Bin 3", "Bin 4"]

    plt.figure(figsize=(10, 6))
    ax = sns.boxplot(data=combined_df, x='combined_bin', y=content_group, hue='dataset',
                     palette=custom_palette, order=bin_order, showfliers=False)

    plt.title(str(seq_method), fontsize=18)
    plt.xlabel("Distance from 3' end", fontsize=18)
    # Label follows the plotted column (it was hard-coded to "GC").
    plt.ylabel(f"{content_group} content (%)", fontsize=18)
    plt.xticks(fontsize=15, rotation=30)
    plt.yticks(fontsize=15)
    plt.grid(True)

    # Annotate each pair of boxes with a Mann-Whitney U p-value.
    for position, bin_value in enumerate(bin_order):
        # Centre of the bin in axes coordinates: 0.1, 0.3, 0.5, 0.7, 0.9
        x_axes = (2 * position + 1) / (2 * len(bin_order))

        in_bin = combined_df['combined_bin'] == bin_value
        # Test the column that is actually plotted (was always 'GA').
        pure_data = combined_df[in_bin & (combined_df['dataset'] == 'Illumina')][content_group]
        overlapped_data = combined_df[in_bin & (combined_df['dataset'] == 'PacBio')][content_group]

        p_value2 = float('nan')
        if len(pure_data) > 1 and len(overlapped_data) > 0:
            # Resample the PacBio values (with replacement) to the Illumina sample size.
            # NOTE(review): the resampling is unseeded, so p-values vary between
            # runs, and Mann-Whitney U does not require equal sample sizes --
            # confirm that this step is wanted.
            overlapped_data_sampled = overlapped_data.sample(len(pure_data), replace=True)
            _, p_value2 = stats.mannwhitneyu(pure_data, overlapped_data_sampled)

        ax.text(x_axes, 0.97, f'p-value = {p_value2:.2e}', horizontalalignment='center',
                verticalalignment='center', transform=ax.transAxes)

    custom_legend = [
        Patch(facecolor=tomato_approx, label='PacBio+Illumina', edgecolor='k', linewidth=0.6),
        Patch(facecolor=slategrey_approx, label='Illumina only', edgecolor='k', linewidth=0.6)
    ]
    ax.legend(handles=custom_legend, loc='lower right')
    plt.tight_layout()

    #### Save the plot as a PDF file
    plt.savefig(f"{seq_method}_{content_group}.pdf", dpi=300)


def extract_sequence(chromosome, start, end, strand):
    """Return the upper-cased hg38 sequence of ``chromosome:start-end`` on *strand*.

    Parameters
    ----------
    chromosome : str
        Reference name as it appears in ``hg38.fa``.
    start, end : int
        Interval bounds; ``start - 1`` and ``end`` are handed to
        ``FastaFile.fetch`` (0-based, half-open), i.e. the interval is read
        as 1-based and inclusive.
    strand : str
        ``'+'`` returns the fetched sequence, ``'-'`` its reverse complement.

    Returns
    -------
    str
        The sequence, or ``""`` (after printing a message) when *chromosome*
        is not present in the FASTA file.

    Raises
    ------
    ValueError
        If *strand* is neither ``'+'`` nor ``'-'`` (this used to surface as
        an UnboundLocalError).
    """
    if strand not in ('+', '-'):
        raise ValueError(f"Unsupported strand {strand!r}; expected '+' or '-'.")

    fasta_file = "hg38.fa"
    # pysam needs the FASTA to be indexed (hg38.fa.fai next to the file).
    fasta = pysam.FastaFile(fasta_file)
    try:
        if chromosome not in fasta.references:
            print(f"Chromosome {chromosome} not found in the file.")
            return ""
        sequence = fasta.fetch(chromosome, start - 1, end).upper()
    finally:
        # Always release the handle, including on the early return above.
        fasta.close()

    if strand == '-':
        sequence = reverse_complement(sequence)
    return sequence

def _write_junction_fasta(subset_df, fasta_path):
    """Append one FASTA record (both junction flanks) per row of *subset_df*."""
    if subset_df.empty:
        # As before, no file is created for an empty subset.
        return

    with open(fasta_path, 'a') as f_out:
        for i, row_tup in enumerate(subset_df.itertuples(index=False)):
            # Positional columns: 3 = exon end, 4 = exon start, 5 = chromosome,
            # 6 = strand (presumably the layout of the merged
            # "*_complete.tsv" table -- verify against the caller).
            chrom = row_tup[5]
            strand = row_tup[6]
            end_jc1 = row_tup[3]
            start_jc1 = row_tup[3] - 49
            start_jc = row_tup[4]
            end_jc2 = row_tup[4] + 49

            if strand not in ("+", "-"):
                # Previously such a row silently re-used the sequence of the
                # preceding record (or failed with a NameError on the first row).
                raise ValueError(f"Unsupported strand {strand!r} in record {i} for {fasta_path}")

            seq_jc1 = extract_sequence(chrom, start_jc1, end_jc1, strand)
            seq_jc2 = extract_sequence(chrom, start_jc, end_jc2, strand)
            if strand == "-":
                # NOTE(review): extract_sequence already reverse-complements
                # minus-strand sequence, so this second pass restores the
                # forward-strand sequence -- confirm that this is intended.
                seq_jc1 = reverse_complement(seq_jc1)
                seq_jc2 = reverse_complement(seq_jc2)
            seq_complete = str(seq_jc1 + seq_jc2)

            f_out.write(f">{i}\n{seq_complete}\n")


def seq_extract(binned_data, out_f):
    """Write junction-flank FASTA files per bin and dataset.

    For every bin label in ("0.5", "1.0", "1.5", "2.0", "2.5") the rows of
    *binned_data* with that ``combined_bin`` are split by ``dataset``
    ("Pure" / "Overlapped") and appended to ``{out_f}_pure_{bin}.fa`` and
    ``{out_f}_overlapped_{bin}.fa``. Each record holds the 50 positions
    ending at the exon end followed by the 50 positions starting at the
    exon start; records are named by their position within the subset.

    Raises
    ------
    ValueError
        If a row's strand is neither "+" nor "-".
    """
    bin_order = ["0.5", "1.0", "1.5", "2.0", "2.5"]
    for bin_n in bin_order:
        bin_out = bin_n.replace(" ", "")
        in_bin = binned_data['combined_bin'] == bin_n
        for dataset_name, file_tag in (('Pure', 'pure'), ('Overlapped', 'overlapped')):
            subset_df = binned_data[in_bin & (binned_data['dataset'] == dataset_name)]
            _write_junction_fasta(subset_df, f"{out_f}_{file_tag}_{bin_out}.fa")
                
def binning_equal_length(intersected_df, unique_df, num_bins=5):
    """Bin two distance tables into equally populated bins and stack them.

    Parameters
    ----------
    intersected_df, unique_df : str or file-like
        Tab-separated tables (anything ``pandas.read_csv`` accepts) with a
        ``distance`` column. Despite the names these are inputs to read,
        not DataFrames.
    num_bins : int, default 5
        Number of bins per table.

    Returns
    -------
    pandas.DataFrame
        Both tables sorted by distance and concatenated (fresh index), with
        ``combined_bin`` ("Bin 0" .. "Bin {num_bins - 1}") and ``dataset``
        ("PacBio" for *intersected_df*, "Illumina" for *unique_df*).

    Each table is binned on its own: after sorting, consecutive runs of
    ``len(table) // num_bins`` rows get the same label and the remainder is
    added to the last bin. Tables shorter than *num_bins* end up entirely
    in "Bin 0".

    Raises
    ------
    ValueError
        If *num_bins* is smaller than 1.
    """
    if num_bins < 1:
        raise ValueError("num_bins must be at least 1")

    def cut_into_bins(df):
        #### Transform zero to one
        df['distance'] = df['distance'].replace(0, 1)

        #### Sort distance in ascending order
        df = df.sort_values(by='distance')

        #### Convert 'distance' column to float64
        df['distance'] = df['distance'].astype('float64')

        rows_per_bin = len(df) // num_bins
        if rows_per_bin:
            # Row position -> bin number, capped so that leftover rows fall
            # into the last bin.
            bin_numbers = np.minimum(np.arange(len(df)) // rows_per_bin, num_bins - 1)
        else:
            bin_numbers = np.zeros(len(df), dtype=int)

        df['combined_bin'] = [f'Bin {number}' for number in bin_numbers]
        return df

    ### Read data from the tab-separated inputs
    result_intersected = pd.read_csv(intersected_df, sep="\t")
    result_unique = pd.read_csv(unique_df, sep="\t")

    ### Apply binning to both data frames
    result_intersected = cut_into_bins(result_intersected)
    result_unique = cut_into_bins(result_unique)

    ### Combine data frames
    combined_df = pd.concat([result_intersected.assign(dataset='PacBio'),
                             result_unique.assign(dataset='Illumina')], ignore_index=True)
    return combined_df


def binning_log_transf(intersected_df, unique_df):
    """Bin two distance tables on a natural-log scale and stack them.

    Both arguments are tab-separated inputs with a ``distance`` column.
    Zero distances are set to 1, distances are replaced by their natural
    log, and each table is cut independently into 5 equal-width bins
    labelled "Bin 0" .. "Bin 4" (column ``combined_bin``). The tables are
    tagged ``dataset='PacBio'`` / ``'Illumina'``, concatenated (original
    row indices kept), written to ``res_log_jypyter.tsv`` and returned.
    """
    num_bins = 5

    def _load_log_binned(source):
        frame = pd.read_csv(source, sep="\t")
        # log(0) is undefined, so zero distances are moved to 1 (log -> 0).
        frame.loc[frame['distance'] == 0, 'distance'] = 1
        frame['distance'] = np.log(frame['distance'])
        bin_codes = pd.cut(frame['distance'], bins=num_bins, labels=False)
        # Prefix the integer bin codes with "Bin ".
        frame['combined_bin'] = 'Bin ' + bin_codes.astype(str)
        return frame

    labelled_frames = [
        _load_log_binned(source).assign(dataset=dataset_name)
        for source, dataset_name in ((intersected_df, 'PacBio'), (unique_df, 'Illumina'))
    ]

    #### Combine dataframes and keep a copy on disk
    combined_df = pd.concat(labelled_frames)
    combined_df.to_csv("res_log_jypyter.tsv", sep="\t")
    return combined_df