# %%
import math
import pandas as pd
import numpy as np
from subfunctions import read_and_parse_BED_wao

# %%
# Label this run with the Snakemake wildcard `i` (iteration number) when it
# exists; otherwise fall back to the literal label "default".
try:
    perm = str(snakemake.wildcards.i)
except (NameError, AttributeError):
    # NameError: no `snakemake` object is in scope (e.g. run outside Snakemake).
    # AttributeError: the rule defines no `i` wildcard.
    # Anything else (KeyboardInterrupt, real bugs) is allowed to propagate.
    perm = "default"

# %%
# Pull the analysis parameters out of the Snakemake config in one place.
config = snakemake.config
age_filter_default = config["age_filter_default"]  # default age cutoff used in the summaries
te_classes = config["major_te_classes"]  # major TE classes reported one row each
classes_to_remove = config["entries_to_remove"]  # repeat classes handed to the BED parser for removal
size_filter = config["size_filter"]  # size threshold handed to the BED parser

# age_filters = [1,5,10,25,50]
# age_filter_default = 10
# te_classes = "DNA,LINE,LTR,SINE,Retroposon,Helitron,Unknown".split(',')
# classes_to_remove = "Simple_repeat,Low_complexity,tRNA,snRNA,rRNA,Satellite".split(',')
# size_filter = 50

# %%
# Read the BED file of the intersections (first Snakemake input) into a
# DataFrame. The parser also receives the repeat classes to remove and the
# size filter from the config; how it applies them is defined in
# `subfunctions` and is not visible from this script.
# NOTE(review): the "wao" suffix suggests `bedtools intersect -wao` output — confirm.
df_bed = read_and_parse_BED_wao(snakemake.input[0], classes_to_remove, size_filter)
# df_bed = read_and_parse_BED_wao("../results/astCal-intersect–te_as_whole.bed", classes_to_remove, size_filter)
# df_bed = read_and_parse_BED_wao("../output/astCal/intersect_whole/s2.bed", classes_to_remove, size_filter)

# %%
# The length columns can carry a -1 placeholder (presumably rows without an
# intersecting record — confirm against read_and_parse_BED_wao). Replace it
# with 0 so that both columns have a minimum of 0.
df_bed['overlap_len'] = df_bed['overlap_len'].replace(-1, 0)
df_bed['sv_len'] = df_bed['sv_len'].replace(-1, 0)

# calculate overlap percentages (kept as fractions, not multiplied by 100)
# amount of TE fragment inside the intersect
df_bed['overlap_perc_te'] = df_bed['overlap_len'] / (df_bed['end'] - df_bed['start'])

# amount of SV inside the intersect
# 0 / 0 (an sv_len of 0) yields NaN; report those rows as 0 overlap instead.
df_bed['overlap_perc_sv'] = (df_bed['overlap_len'] / df_bed['sv_len']).fillna(0)

# %%
# REVISIT? add a filter for simple bubbles??

# %%
# calculation 1: perfect intersection between SV and TE
# A row is counted when both overlap fractions exceed 0.9 and `age` is below
# the cutoff. One output column is produced per age cutoff.
df_out1 = pd.DataFrame(
    dict(te_class=['All Transposons', 'All Transposons except Unknown'] + te_classes, perm=perm)
)

for age_filter in (age_filter_default, 50):
    is_perfect = (
        (df_bed['overlap_perc_sv'] > 0.9)
        & (df_bed['overlap_perc_te'] > 0.9)
        & (df_bed['age'] < age_filter)
    )
    tmp_df = df_bed[is_perfect]
    class_labels = tmp_df['te_class']

    # row order must match the `te_class` column of df_out1:
    # all rows, all rows except Unknown, then one count per major TE class
    results1 = [len(tmp_df), int((class_labels != 'Unknown').sum())]
    results1 += [int((class_labels == class_name).sum()) for class_name in te_classes]

    df_out1[f'age_{age_filter}'] = results1

df_out1.to_csv(snakemake.output.perfect, index=False)
    

# %%
# calculation 2: number of SVs that are mostly TEs
# Only the two aggregate rows are reported here (no per-class breakdown).
df_out2 = pd.DataFrame(
    dict(te_class=['All Transposons', 'All Transposons except Unknown'], perm=perm)
)

def count_SVs_made_of_TEs(df):
    """Count SVs whose summed overlap fraction exceeds 0.9.

    Rows are grouped by ``sv_name`` and their ``overlap_perc_sv`` values are
    added up, so an SV that appears in several rows is judged on the total.

    Args:
        df: DataFrame with ``sv_name`` and ``overlap_perc_sv`` columns.

    Returns:
        int: number of ``sv_name`` groups whose summed ``overlap_perc_sv``
        is strictly greater than 0.9 (0 for an empty frame).
    """
    covered_per_sv = df.groupby('sv_name')['overlap_perc_sv'].sum()
    # Count with the vectorised Series.sum() rather than iterating the
    # boolean Series element by element with the builtin sum().
    return int((covered_per_sv > 0.9).sum())

for age_filter in (age_filter_default, 50):
    # keep rows with a non-zero SV length whose `age` is below the cutoff
    tmp_df = df_bed[(df_bed['sv_len'] != 0) & (df_bed['age'] < age_filter)]
    without_unknown = tmp_df[tmp_df['te_class'] != 'Unknown']

    # row order must match df_out2: all rows first, then all except Unknown
    results2 = [
        count_SVs_made_of_TEs(tmp_df),
        count_SVs_made_of_TEs(without_unknown),
    ]

    df_out2[f'age_{age_filter}'] = results2

df_out2.to_csv(snakemake.output.madeofTE, index=False)

# %%
# calculation 3: number of young TE fragments significantly inside flexible regions
# REVISIT? do if it fits well in the text

