import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import sqlite3

# Module-wide handle on the splice-graph database; sr_gene_exons() queries it.
conn = sqlite3.connect(database='splicegraph.sql')
# Turn on foreign-key enforcement for this connection.
conn.execute('pragma foreign_keys=ON')

def sr_gene_exons(gene_id, connection=None, batch_size=100):
    """Yield the exon rows stored for one gene in the splice-graph database.

    Rows are pulled with ``fetchmany`` in batches of ``batch_size``, so a gene
    with many exons is streamed rather than materialized at once.

    Args:
        gene_id: Value matched against ``exon.gene_id``.
        connection: Open ``sqlite3.Connection`` to query. Defaults to the
            module-level ``conn``.
        batch_size: Number of rows requested per ``fetchmany`` call.

    Yields:
        dict with keys ``gene_id``, ``start``, ``end``, ``annotated_start``,
        ``annotated_end`` and ``annotated``, in the order SQLite returns the
        rows (the query has no ORDER BY).
    """
    if connection is None:
        connection = conn
    fields = ('gene_id', 'start', 'end', 'annotated_start', 'annotated_end', 'annotated')

    query = connection.execute('''
                        SELECT gene_id, start, end, annotated_start, annotated_end, annotated 
                        FROM exon 
                        WHERE gene_id=?
                        ''', (gene_id,))
    # fetchmany returns [] once the cursor is exhausted, which ends iteration.
    for batch in iter(lambda: query.fetchmany(batch_size), []):
        for row in batch:
            yield dict(zip(fields, row))

           
def adjusted_psi(all_junc_reads):
    """Return pseudocount-smoothed PSI values for a group of junctions.

    With ``nj`` junctions and ``R`` total reads, junction ``i`` with ``r_i``
    reads gets

        a = r_i + 1/nj
        b = (R - r_i) + (nj - 1)/nj
        psi_i = a / (a + b) = (r_i + 1/nj) / (R + 1)

    i.e. the posterior mean under a Beta(1/nj, (nj - 1)/nj) prior. The values
    therefore sum to 1, and with no reads every junction gets ``1/nj``.

    Args:
        all_junc_reads: Sequence of read counts, one per junction.

    Returns:
        list of floats in the same order as ``all_junc_reads``; ``[]`` for an
        empty input.
    """
    nj = len(all_junc_reads)
    if nj == 0:
        return []

    alpha_prior = 1 / nj
    # The other nj - 1 junctions each carry 1/nj of prior mass, so together
    # they contribute (nj - 1)/nj. Scaling this by (nj - 1) again would make
    # the PSIs of a group sum to less than 1 whenever nj > 2.
    beta_prior = (nj - 1) / nj
    # One total instead of re-summing list slices per junction (was O(n^2)).
    total_reads = sum(all_junc_reads)

    adj_psi = []
    for junc_reads in all_junc_reads:
        a = junc_reads + alpha_prior
        b = (total_reads - junc_reads) + beta_prior
        adj_psi.append(a / (a + b))
    return adj_psi


def _coords_with_reads(transcripts, coord_key, reads_key):
    """Collect the distinct coordinates of one feature type across transcripts.

    Args:
        transcripts: Iterable of per-transcript mappings.
        coord_key: Key holding hashable coordinate pairs (e.g. ``'junctions'``).
        reads_key: Key holding the read counts aligned with ``coord_key``
            (e.g. ``'junction_reads'``).

    Returns:
        ``(reads_by_coord, coord_array)``: a dict mapping each coordinate to a
        read count, and a numpy array with one row per distinct coordinate
        (empty when there are none).
    """
    pairs = {
        (coord, reads)
        for transcript in transcripts
        for coord, reads in zip(transcript[coord_key], transcript[reads_key])
    }
    # NOTE(review): if one coordinate is reported with different read counts by
    # different transcripts, the dict keeps one of them arbitrarily (set order);
    # confirm whether such counts should be summed or are guaranteed to agree.
    reads_by_coord = dict(pairs)
    # Build the array from the dict keys so each coordinate appears once;
    # building it from ``pairs`` would emit duplicate rows for such coordinates.
    coord_array = np.array(list(reads_by_coord))
    return reads_by_coord, coord_array


def _source_coords(coord_array, exon_start, exon_end):
    """Return ``[start, end]`` rows that leave the exon.

    A row matches when its start lies in ``[exon_start, exon_end + 1]`` and its
    end lies beyond ``exon_end + 1``.
    """
    if len(coord_array) == 0:
        return []
    lo, hi = exon_start, exon_end + 1
    starts, ends = coord_array[:, 0], coord_array[:, 1]
    return coord_array[(starts >= lo) & (starts <= hi) & (ends > hi)].tolist()


def _target_coords(coord_array, exon_start, exon_end):
    """Return ``[start, end]`` rows that enter the exon.

    A row matches when its end lies in ``[exon_start - 1, exon_end]`` and its
    start lies before ``exon_start - 1``.
    """
    if len(coord_array) == 0:
        return []
    lo, hi = exon_start - 1, exon_end
    starts, ends = coord_array[:, 0], coord_array[:, 1]
    return coord_array[(ends <= hi) & (ends >= lo) & (starts < lo)].tolist()


def extract_psi_from_lr(voila_file):
    """Build a per-exon table of long-read junction / intron-retention reads and PSI.

    For every gene in the pickled long-read data, each exon returned by
    ``sr_gene_exons`` that has ``start != -1``, ``end != -1``, ``start != end``
    and ``annotated == 1`` is matched against the gene's junctions and intron
    retentions:

    * ``'source'`` rows start within ``[exon_start, exon_end + 1]`` and end past
      ``exon_end + 1``;
    * ``'target'`` rows end within ``[exon_start - 1, exon_end]`` and start
      before ``exon_start - 1``.

    Reads are converted to adjusted PSI (``adjusted_psi``) within each
    ``(gene_id, exon_number, junction_name)`` group and rounded to 3 decimals.

    Args:
        voila_file: Path to a pickle mapping gene_id -> dict whose
            ``'transcripts'`` entry is a list of mappings with ``'junctions'``,
            ``'junction_reads'``, ``'intron_retention'`` and
            ``'intron_retention_reads'``.

    Returns:
        pandas.DataFrame with columns gene_id, exon_coordinate, exon_number,
        junction_coordinate, length, event, junction_name, reads, adjusted_psi.
        ``event`` is ``'junction'`` or ``'intron'``; ``junction_name`` is
        ``'source'`` or ``'target'``. ``exon_number`` is the 1-based position of
        the exon in the order ``sr_gene_exons`` yields it.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(voila_file, 'rb') as f:
        lr_data = pickle.load(f)

    rows = []
    for gene_id in tqdm(list(lr_data.keys())):
        exons = [
            (e['start'], e['end'])
            for e in sr_gene_exons(gene_id)
            if e['start'] != -1 and e['end'] != -1 and e['start'] != e['end'] and e['annotated'] == 1
        ]

        transcripts = lr_data[gene_id]['transcripts']
        junc_reads, junc_array = _coords_with_reads(transcripts, 'junctions', 'junction_reads')
        ir_reads, ir_array = _coords_with_reads(
            transcripts, 'intron_retention', 'intron_retention_reads')

        for exon_number, (exon_start, exon_end) in enumerate(exons, start=1):
            # The exon window is computed per exon for both features, so a gene
            # with intron retentions but no junctions never sees a stale window.
            matches = (
                (_source_coords(junc_array, exon_start, exon_end), junc_reads, 'junction', 'source'),
                (_source_coords(ir_array, exon_start, exon_end), ir_reads, 'intron', 'source'),
                (_target_coords(junc_array, exon_start, exon_end), junc_reads, 'junction', 'target'),
                (_target_coords(ir_array, exon_start, exon_end), ir_reads, 'intron', 'target'),
            )
            for coords, reads_by_coord, event, direction in matches:
                for start, end in coords:
                    rows.append([gene_id, exon_start, exon_end, exon_number,
                                 start, end, event, direction,
                                 reads_by_coord[(start, end)], 0])

    # Build the frame once; appending via df.loc[len(df)] was quadratic.
    df = pd.DataFrame(rows, columns=['gene_id', 'exon_start', 'exon_end',
                                     'exon_number', 'junc_start', 'junc_end',
                                     'event', 'junction_name', 'reads', 'adjusted_psi'])
    df['exon_coordinate'] = df.exon_start.map(str) + '-' + df.exon_end.map(str)
    df['junction_coordinate'] = df.junc_start.map(str) + '-' + df.junc_end.map(str)
    df['length'] = df.junc_end - df.junc_start + 1

    cols = ['gene_id', 'exon_coordinate', 'exon_number',
            'junction_coordinate', 'length', 'event',
            'junction_name', 'reads', 'adjusted_psi']
    # .copy() so the assignment below writes to an independent frame
    # rather than a possible view of ``df``.
    df_results = df[cols].copy()
    df_results['adjusted_psi'] = (
        df_results.groupby(['gene_id', 'exon_number', 'junction_name'])['reads']
        .transform(lambda s: pd.Series(adjusted_psi(s.to_list()), index=s.index))
        .astype(float)
        .round(3)
    )
    return df_results
    

