# Script that creates TE landscape plots and a TE treemap from a bedready file.
# Usage: python /path/to/script.py -h

# %%
import argparse
from pathlib import Path
import re
import numpy as np
import pandas as pd
import plotly.express as px

# %%
def cmdline_args(argv=None):
    """Parse the command-line arguments of the landscape script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, i.e. the normal command-line behaviour.

    Returns
    -------
    argparse.Namespace
        With attributes ``sample`` (str), ``sample_size`` (int) and
        ``out_dir`` (str, ``'./'`` when the argument is omitted).
    """
    p = argparse.ArgumentParser(
        description="""
        Generates TE landscape plot from a bedready formatted file.
        """,
        usage='python plot_repmask_landscape.py sample.bedready.txt <genome_size> [out_dir/]',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument('sample', type=str, help='path to sample bedready file')
    p.add_argument('sample_size', type=int, help='sample genome size')
    # nargs='?' is required for a positional argument's default to apply;
    # without it argparse makes out_dir mandatory and never uses './'
    p.add_argument('out_dir', type=str, nargs='?', default='./',
                   help='output directory (default: current)')
    return p.parse_args(argv)

# %%
# parse arguments
args = cmdline_args()

# explicit checks rather than `assert`, which is stripped when Python runs with -O
OUT_DIR = Path(args.out_dir)
if not OUT_DIR.is_dir():
    raise NotADirectoryError(f"output directory does not exist: {OUT_DIR}")

BEDREADY_SUFFIX = ".bedready.txt"
SAMPLE_PATH = Path(args.sample)
if not SAMPLE_PATH.name.endswith(BEDREADY_SUFFIX):
    raise ValueError(f"sample file name must end with '{BEDREADY_SUFFIX}': {SAMPLE_PATH}")
if not SAMPLE_PATH.is_file():
    raise FileNotFoundError(f"sample file not found: {SAMPLE_PATH}")
# strip only the trailing suffix (an unanchored regex would treat '.' as a
# wildcard and remove matches anywhere in the name)
SAMPLE_NAME = SAMPLE_PATH.name[:-len(BEDREADY_SUFFIX)]

SAMPLE_SIZE = args.sample_size
# the genome size is used as a divisor when converting lengths to percentages
if SAMPLE_SIZE <= 0:
    raise ValueError(f"genome size must be a positive integer, got {SAMPLE_SIZE}")

# %% 
# Load the category -> colour mappings from CSV files stored next to this
# script. The first two rows of each file are skipped; column 0 holds the
# category label and column 1 the colour used for it.
_SCRIPT_DIR = Path(__file__).resolve().parent


def _load_color_map(csv_name):
    table = pd.read_csv(_SCRIPT_DIR / csv_name, header=None, skiprows=2)
    return table.set_index(0)[1].to_dict()


color_dict = _load_color_map("TE_color.csv")
color_dict_broad = _load_color_map("TE_color_broad.csv")

# %%
def read_bedready_file(path_to_file):
    """Load a tab-separated bedready table into a DataFrame.

    The columns 'chr', 'id', 'color', 'block_sizes' and 'block_starts' are
    read as ``str`` (presumably so values such as chromosome '1' are not
    turned into numbers); every other column keeps the dtype pandas infers.
    No column is used as the index.
    """
    text_columns = ('chr', 'id', 'color', 'block_sizes', 'block_starts')
    return pd.read_table(
        path_to_file,
        index_col=False,
        dtype=dict.fromkeys(text_columns, str),
    )

def format_df_for_plotting(df, genome_size, *, max_div=50, n_bins=50):
    """Return a copy of a bedready DataFrame with the columns used for plotting.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain numeric 'len' and 'perc_div' columns. It is not modified.
    genome_size : int or float
        Genome size in the same units as 'len'; must be positive.
    max_div : float, keyword-only
        Upper (exclusive) edge of the divergence binning range.
    n_bins : int, keyword-only
        Number of equal-width bins between 0 and ``max_div``. The defaults
        give 1%-wide bins over [0, 50), as before.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with extra columns:
        'perc_seq'     -- 'len' as a percentage of ``genome_size``;
        'age_bounds'   -- left-closed interval bin containing 'perc_div';
        'age_midpoint' -- float midpoint of that bin, NaN when 'perc_div'
                          lies outside [0, max_div).

    Raises
    ------
    ValueError
        If ``genome_size`` is not positive or ``n_bins`` is less than 1.
    """
    if genome_size <= 0:
        raise ValueError(f"genome_size must be positive, got {genome_size}")
    if n_bins < 1:
        raise ValueError(f"n_bins must be at least 1, got {n_bins}")

    # work on a copy so the caller's DataFrame is left untouched
    df = df.copy()

    # convert length to percentage of genome length
    df['perc_seq'] = df['len'] / genome_size * 100

    # group into bins based on perc_div; right=False gives [a, b) intervals
    edges = np.linspace(0, max_div, num=n_bins + 1)
    df['age_bounds'] = pd.cut(df['perc_div'], bins=edges, right=False)

    # midpoint of each bin, looked up by category code: vectorised and not
    # dependent on how Series.apply treats NaN in categoricals across pandas
    # versions. Code -1 marks values that fell outside every bin.
    codes = df['age_bounds'].cat.codes.to_numpy()
    mids = np.asarray(df['age_bounds'].cat.categories.mid, dtype=float)
    df['age_midpoint'] = np.where(codes >= 0, mids[codes], np.nan)

    return df

# %%
# read the sample table and add the plotting columns (perc_seq, age bins)
df = format_df_for_plotting(read_bedready_file(SAMPLE_PATH), SAMPLE_SIZE)

# drop entries whose perc_div is negative (rows with a missing perc_div also
# fail the >= 0 test and are dropped); the divergence bins start at 0
nrow_before = df.shape[0]
perc_seq_before = df['perc_seq'].sum()
df = df[df['perc_div'] >= 0]
print(f"""
removed {nrow_before-df.shape[0]} entries where perc_div < 0 ({round(perc_seq_before - df['perc_seq'].sum(), 2)}% sequence)""")

# Drop non-transposon entries: low-complexity and simple repeats, satellites
# and RNA classes. A row is dropped when either its repeat_class or its
# repeat_family is one of the listed labels.
unwanted = [
    'Low_complexity', 'Simple_repeat', 'rRNA', 'Satellite', 'Structural_RNA', 'tRNA', 'snRNA'
]
nrow_before, perc_seq_before = len(df), df['perc_seq'].sum()
is_unwanted = df['repeat_class'].isin(unwanted) | df['repeat_family'].isin(unwanted)
df = df[~is_unwanted]

n_removed = nrow_before - len(df)
seq_removed = round(perc_seq_before - df['perc_seq'].sum(), 2)
print(f"""
removed {n_removed} entries for low complexity, simple repeats, satellites and non-transposon RNA regions ({seq_removed}% sequence)
""")

# Keep only entries whose fine (repeat_class) and broad (repeat_class_broad)
# categories both appear in the colour tables; report anything else.
in_color_dict = df['repeat_class'].isin(color_dict.keys())
in_color_dict_broad = df['repeat_class_broad'].isin(color_dict_broad.keys())
is_known = in_color_dict & in_color_dict_broad
unknown = df.loc[~is_known]
df = df.loc[is_known]

if len(unknown) > 0:
    # map(str, ...) guards against non-string labels (e.g. NaN), which would
    # otherwise make str.join raise a TypeError
    unknown_names = ', '.join(map(str, pd.unique(unknown['repeat_class'])))
    unknown_categories = ', '.join(map(str, pd.unique(unknown['repeat_class_broad'])))
    print(f"""
WARNING: there are {len(unknown)} previously undiscovered TE entries ({round(unknown['perc_seq'].sum(), 5)}% sequence)
please report this to the developer, quote the following...

repeat name: {unknown_names}    
repeat category: {unknown_categories}""")

# print message about what's remaining
print(f"""{'*'*50}\n
number of entries remaining: {df.shape[0]} ({round(df['perc_seq'].sum(), 2)}% sequence)
""")

# %% plot TE landscape using the most granular category
# Stacked histogram of perc_seq per divergence bin, one colour per
# repeat_class (with y given, px.histogram sums y within each bin).
# category_orders is reversed and the legend traceorder is reversed too,
# presumably so the legend lists classes in colour-table order.
fig = px.histogram(
    df,
    x='age_midpoint', y='perc_seq', color='repeat_class',
    range_x=[0, 40],
    color_discrete_map=color_dict,
    category_orders={
        'repeat_class': list(color_dict.keys())[::-1]
    },
)
# age_midpoint holds centres of the 1%-wide bins over [0, 50) produced by
# format_df_for_plotting; pin the histogram bins to exactly those intervals
# instead of relying on plotly's automatic bin width
fig.update_traces(xbins=dict(start=0, end=50, size=1))
fig.update_layout(
    legend={'traceorder': 'reversed'},
    template='simple_white',
    title=f'TE landscape: {SAMPLE_NAME} (genome size: {SAMPLE_SIZE:,})',
    xaxis_title="divergence from ancestral consensus (%)",
    yaxis_title="percentage sequence",
    legend_title="repeat",
    font=dict(family='Helvetica'),
)
fig.write_html(f"{OUT_DIR}/{SAMPLE_NAME}_landscape.html")

# %% plot TE landscape using the broadest category
# Same landscape as the granular plot, coloured by the broad repeat class.
fig = px.histogram(
    df,
    x='age_midpoint', y='perc_seq', color='repeat_class_broad',
    range_x=[0, 40],
    template='simple_white',
    color_discrete_map=color_dict_broad,
    category_orders={
        'repeat_class_broad': list(color_dict_broad.keys())[::-1]
    },
)
# pin bins to the 1%-wide divergence bins over [0, 50) that age_midpoint
# was computed from, rather than plotly's automatic bin width
fig.update_traces(xbins=dict(start=0, end=50, size=1))

fig.update_layout(
    legend={'traceorder': 'reversed'},
    title=f'TE class overview: {SAMPLE_NAME} (genome size: {SAMPLE_SIZE:,})',
    xaxis_title="divergence from ancestral consensus (%)",
    yaxis_title="percentage of genome size",
    legend_title="repeat class",
    font=dict(family='Helvetica'),
)
fig.write_html(f"{OUT_DIR}/{SAMPLE_NAME}_landscape_broad.html")

# %% plot treemap down to the repmask_id level (area: perc_seq, color: perc_div, i.e. percent divergence)
# Label each broad class with its total share of sequence, e.g. "<class> (12.3%)",
# and attach it to every row as 'repeat_class_broad_agg'.
broad_totals = df.groupby('repeat_class_broad')['perc_seq'].sum()
broad_labels = {
    name: f"{name} ({round(float(total), 1)}%)"
    for name, total in broad_totals.items()
}
df = df.assign(repeat_class_broad_agg=df['repeat_class_broad'].map(broad_labels))

fig = px.treemap(
    df,
    path=[px.Constant("Transposons"), 'repeat_class_broad_agg', 'repeat_class', 'repmask_id'],
    color='perc_div', values='perc_seq',
    color_continuous_scale='RdBu',
    range_color=[0, 32],
)
fig.update_layout(
    margin=dict(t=50, l=25, r=25, b=25),
    font=dict(family='Helvetica', size=16),
    # perc_div is drawn on a continuous colour axis, whose title lives on the
    # colour bar; a legend title has no effect on this treemap
    coloraxis_colorbar_title_text="percent divergence",
)

fig.write_html(f"{OUT_DIR}/{SAMPLE_NAME}_TE_treemap.html")
fig.write_image(f"{OUT_DIR}/{SAMPLE_NAME}_TE_treemap.svg", width=1000, height=700)
