import argparse
import sqlite3
import pickle
from tqdm import tqdm
import sys, os
from csv import DictWriter
import csv
import pandas as pd

# Command-line interface. Arguments are parsed at import time, so this file is
# meant to be run as a script rather than imported.
parser = argparse.ArgumentParser(
    description='Obtain raw counts of junctions of overlapping existence between long and short reads. '
                'This tool does not take read counts or quantifications into account, only existing or '
                'non-existing junctions.')
parser.add_argument('--lr-voila-file', type=str, required=True,
                    help='Path to the .lr.voila processed long reads file')
parser.add_argument('--splicegraph', type=str, required=True,
                    help='Path to the short read MAJIQ splicegraph file (.sql)')
parser.add_argument('--output', type=str, required=True,
                    help='Path to the tsv file to write counts')
parser.add_argument('--fuzziness5', type=int, default=0,
                    help='5 prime fuzziness of long-read sequencing (number of basepairs)')
parser.add_argument('--fuzziness3', type=int, default=0,
                    help='3 prime fuzziness of long-read sequencing (number of basepairs)')
parser.add_argument('--only-junctions', action='store_true',
                    help='Only include junctions in the analysis, not introns')


args = parser.parse_args()

# Fuzziness is used as a +/- window in basepairs around a coordinate, so a
# negative value can never match anything; reject it instead of silently
# running a fuzzy comparison that cannot succeed.
if args.fuzziness3 < 0 or args.fuzziness5 < 0:
    # sys.exit(<str>) writes the message to stderr and exits with status 1.
    sys.exit("Error: --fuzziness3 and --fuzziness5 must be non-negative")

# Only one end of a junction may be fuzzy at a time.
if args.fuzziness3 and args.fuzziness5:
    sys.exit("Error: you may only specify one of --fuzziness3 or --fuzziness5")


def compare_fuzzyjunc(lr_junc, sr_junc, strand, fuzziness3=None, fuzziness5=None):
    """Tell whether a long-read junction matches a short-read junction fuzzily.

    A fuzzy match requires one coordinate to be identical and the other to lie
    within ``fuzziness`` basepairs (inclusive) of the short-read coordinate.
    Which coordinate may wobble depends on the strand:

    * 3' fuzziness: index 1 on the "+" strand, index 0 on any other strand.
    * 5' fuzziness: index 0 on the "+" strand, index 1 on any other strand.

    Args:
        lr_junc: Long-read junction, indexable as ``[0]`` and ``[1]``
            (assumed to be ``(start, end)`` like the short-read rows).
        sr_junc: Short-read junction as ``(start, end)``.
        strand: Gene strand; ``"+"`` selects the plus-strand rule, anything
            else is treated as the minus strand.
        fuzziness3: 3' tolerance in basepairs. Defaults to ``args.fuzziness3``.
        fuzziness5: 5' tolerance in basepairs. Defaults to ``args.fuzziness5``.

    Returns:
        bool: True if the junctions match under an enabled (non-zero)
        fuzziness, False otherwise. With both tolerances at zero the result is
        always False, even for identical junctions, because exact matches are
        handled by the caller.
    """
    if fuzziness3 is None:
        fuzziness3 = args.fuzziness3
    if fuzziness5 is None:
        fuzziness5 = args.fuzziness5

    on_plus = strand == "+"
    # (tolerance, index of the coordinate that is allowed to wobble)
    rules = (
        (fuzziness3, 1 if on_plus else 0),
        (fuzziness5, 0 if on_plus else 1),
    )
    for tolerance, fuzzy_idx in rules:
        if not tolerance:
            continue
        exact_idx = 1 - fuzzy_idx
        if (lr_junc[exact_idx] == sr_junc[exact_idx]
                and abs(lr_junc[fuzzy_idx] - sr_junc[fuzzy_idx]) <= tolerance):
            return True
    return False
def adjust_lr_juncs_fuzziness(lr_juncs, sr_juncs_denovo, sr_juncs_annotation, strand):
    """Snap long-read junctions onto fuzzily matching short-read junctions.

    Each long-read junction is resolved as follows:

    1. If it equals any short-read junction (de novo or annotated), keep it.
    2. Otherwise replace it with the first de novo short-read junction that
       ``compare_fuzzyjunc`` accepts, falling back to the first annotated one.
    3. If nothing matches, keep it unchanged.

    Several long-read junctions may resolve to the same short-read junction.
    The result is de-duplicated, because the counting code only cares whether
    a junction exists; a duplicate would otherwise survive the removal of its
    short-read match and be counted as a long-read-only junction.

    Args:
        lr_juncs: Iterable of hashable long-read junctions.
        sr_juncs_denovo: Sequence of hashable de novo short-read junctions.
        sr_juncs_annotation: Sequence of hashable annotated short-read junctions.
        strand: Gene strand, forwarded to ``compare_fuzzyjunc``.

    Returns:
        list: The unique resolved junctions, in first-seen order.
    """
    # One set for the exact-match test instead of scanning both lists per junction.
    exact_sr = set(sr_juncs_denovo).union(sr_juncs_annotation)
    # De novo candidates take priority over annotated ones.
    candidates = list(sr_juncs_denovo) + list(sr_juncs_annotation)

    resolved = {}  # used as an insertion-ordered set
    for lr_junc in lr_juncs:
        if lr_junc in exact_sr:
            target = lr_junc
        else:
            target = next(
                (sr_junc for sr_junc in candidates
                 if compare_fuzzyjunc(lr_junc, sr_junc, strand) is True),
                lr_junc,
            )
        resolved[target] = None
    return list(resolved)


# Short-read splicegraph (SQLite database), queried throughout the script.
sg = sqlite3.connect(args.splicegraph)
sg.execute('pragma foreign_keys=ON')

genes_counts = {}
# gene id -> set of long-read junctions (plus retained introns unless
# --only-junctions was given)
lr_genes_junctions = {}

print("Gathering LR junctions...")
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point --lr-voila-file at trusted inputs.
with open(args.lr_voila_file, 'rb') as lr_file:
    lr = pickle.load(lr_file)
for gene in lr.keys():
    gene_juncs = set()
    for transcript in lr[gene]['transcripts']:
        gene_juncs.update(transcript['junctions'])
        if not args.only_junctions:
            gene_juncs.update(transcript['intron_retention'])
    lr_genes_junctions[gene] = gene_juncs
print("...done")

print("Gathering SR genes...")
sr_genes = [row[0] for row in sg.execute("SELECT id from gene ORDER BY id").fetchall()]
print("...done")


# Output columns, in the order they are written.
# NOTE(review): 'LR & Annotaion' is misspelled, but it is an output column name
# that downstream consumers may rely on, so it is left untouched here.
COUNT_FIELDS = ['gene_id', 'All', 'Both de novo', 'MAJIQ & Annotation', 'MAJIQ de novo',
                'LR & Annotaion', 'LR de novo', 'Annotation']

# Short-read junction classes in the order they are resolved against the
# long-read junctions:
# (annotated, has_reads, column when LR also has it, column when LR lacks it)
SR_CLASSES = (
    (1, 0, 'LR & Annotaion', 'Annotation'),      # annotation only
    (1, 1, 'All', 'MAJIQ & Annotation'),         # annotation + MAJIQ reads
    (0, 1, 'Both de novo', 'MAJIQ de novo'),     # MAJIQ de novo
)

# Splicegraph tables to read; retained introns are skipped with --only-junctions.
sr_tables = ('junction',) if args.only_junctions else ('junction', 'intron_retention')


def _fetch_sr_juncs(gene_id, annotated, has_reads=None):
    """Return the (start, end) rows of a gene from every table in ``sr_tables``.

    Rows are filtered by ``annotated`` and, when ``has_reads`` is not None, by
    ``has_reads`` as well. Values are bound as SQL parameters so that gene ids
    containing quotes cannot break the query; the table name comes from the
    fixed ``sr_tables`` tuple, never from data.
    """
    rows = []
    for table in sr_tables:
        query = f"SELECT start, end FROM {table} WHERE gene_id=? AND annotated=?"
        params = [gene_id, annotated]
        if has_reads is not None:
            query += " AND has_reads=?"
            params.append(has_reads)
        rows.extend(sg.execute(query, params).fetchall())
    return rows


print("Resolving SR junctions...")

# newline='' lets the csv module control line endings itself, as its
# documentation recommends for files handed to a writer.
with open(args.output, 'w', newline='') as fw:
    writer = DictWriter(fw, fieldnames=COUNT_FIELDS, delimiter='\t')
    writer.writeheader()

    for gene_id in tqdm(sr_genes):

        counts = dict.fromkeys(COUNT_FIELDS, 0)
        counts['gene_id'] = gene_id

        if args.fuzziness3 or args.fuzziness5:
            strand = sg.execute("SELECT strand FROM gene WHERE id=?", (gene_id,)).fetchone()[0]
            fuzzy_annot = _fetch_sr_juncs(gene_id, annotated=1)
            fuzzy_denovo = _fetch_sr_juncs(gene_id, annotated=0)

            # Stored as a set: several LR junctions may snap onto the same SR
            # junction, and a leftover duplicate would otherwise be counted
            # as 'LR de novo' after its SR match has been removed.
            lr_genes_junctions[gene_id] = set(adjust_lr_juncs_fuzziness(
                lr_genes_junctions.get(gene_id, []), fuzzy_denovo, fuzzy_annot, strand))

        # Genes without long-read data get an empty, throwaway set.
        lr_juncs = lr_genes_junctions.get(gene_id, set())

        for annotated, has_reads, shared_field, sr_only_field in SR_CLASSES:
            for junc in _fetch_sr_juncs(gene_id, annotated, has_reads):
                if junc in lr_juncs:
                    counts[shared_field] += 1
                    # Consume the LR junction so that only unmatched ones remain.
                    lr_juncs.remove(junc)
                else:
                    counts[sr_only_field] += 1

        # Whatever is left was never seen in the short-read splicegraph.
        counts['LR de novo'] += len(lr_juncs)

        writer.writerow(counts)

print("...done")

# Summarise the per-gene table into one total per count column.
output_file = args.output
df = pd.read_csv(output_file, sep='\t')

# Sum every count column across genes. The gene identifier is not a count, and
# the 'Annotation' column is deliberately left out of the totals.
sum_df = df.drop(columns=['gene_id', 'Annotation']).sum()

# Turn the resulting Series into a two-column table.
sum_df = sum_df.reset_index()
sum_df.columns = ['source', 'total count']

# Insert "_total" before the file extension. Splitting the extension (rather
# than replacing the text '.tsv') guarantees a distinct path even when --output
# has no '.tsv' suffix, so the per-gene table is never overwritten, and it
# leaves any '.tsv' inside directory names alone.
output_root, output_ext = os.path.splitext(output_file)
output_total_file = f"{output_root}_total{output_ext}"

# Write the totals to the new file.
sum_df.to_csv(output_total_file, sep='\t', index=False)
