# Module to create landscape files from the bedready files.

# %%
import argparse
from pathlib import Path
import re
import numpy as np
import pandas as pd
import plotly.express as px

# %%
def cmdline_args(argv=None):
    """Parse the command-line arguments of the landscape script.

    Args:
        argv: optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list is
            useful from notebooks and tests.

    Returns:
        argparse.Namespace with ``sample``, ``sample_size``, ``control``,
        ``control_size``, ``in_dir`` and ``out_dir``.
    """
    p = argparse.ArgumentParser(
        description="""
        Compares two bedready files to make TE landscape plots.
        """,
        usage='python /path/to/script.py <args (6 required)>',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument('sample', type=str, help='path to sample bedready file')
    p.add_argument('sample_size', type=int, help='sample genome size')
    p.add_argument('control', type=str, help='path to control bedready file')
    p.add_argument('control_size', type=int, help='control genome size')
    p.add_argument('in_dir', type=str, help='input directory')
    p.add_argument('out_dir', type=str, help='output directory')
    return p.parse_args(argv)


# %%
# parse arguments
args = cmdline_args()

# Validate with explicit exceptions rather than `assert`: asserts are stripped
# under `python -O`, and these messages name the offending path.
IN_DIR = Path(args.in_dir)
if not IN_DIR.is_dir():
    raise NotADirectoryError(f"input directory not found: {IN_DIR}")

OUT_DIR = Path(args.out_dir)
if not OUT_DIR.is_dir():
    raise NotADirectoryError(f"output directory not found: {OUT_DIR}")

# joinpath: a relative file name is resolved inside IN_DIR, while an absolute
# path replaces IN_DIR entirely.
SAMPLE_PATH = IN_DIR.joinpath(args.sample)
if not SAMPLE_PATH.is_file():
    raise FileNotFoundError(f"sample bedready file not found: {SAMPLE_PATH}")
# Strip the literal ".bedready.txt" suffix. The dots are escaped and the match
# is anchored, so a look-alike such as "xbedreadyytxt" or an occurrence in the
# middle of the name is not removed.
SAMPLE_NAME = re.sub(r"\.bedready\.txt$", "", SAMPLE_PATH.name)
SAMPLE_SIZE = args.sample_size  # e.g. 957_485_262 (mayZeb)

CONTROL_PATH = IN_DIR.joinpath(args.control)
if not CONTROL_PATH.is_file():
    raise FileNotFoundError(f"control bedready file not found: {CONTROL_PATH}")
CONTROL_NAME = re.sub(r"\.bedready\.txt$", "", CONTROL_PATH.name)
CONTROL_SIZE = args.control_size  # e.g. 880_428_986 (astCal)

# # %%
# SAMPLE_PATH = SAMPLE_DIR.joinpath('mayZeb.bedready.txt')
# assert SAMPLE_PATH.is_file()
# SAMPLE_SIZE = 957_485_262
# SAMPLE_NAME = re.sub(".bedready.txt", "", SAMPLE_PATH.name) 

# # %%
# CONTROL_PATH = SAMPLE_DIR.joinpath('astCal.bedready.txt')
# assert CONTROL_PATH.is_file()
# CONTROL_SIZE = 880_428_986
# CONTROL_NAME = re.sub(".bedready.txt", "", CONTROL_PATH.name) 

# %% 
# get file containing the color mapping
REPO_DIR = Path(__file__).resolve().parent.parent
# Map repeat name -> plot color from the repository's metadata table.
# The CSV has no header row and its first two lines are skipped; column 0
# is used as the key and column 1 as the value (presumably a color string
# such as "#RRGGBB" -- confirm against metadata/color_codes.csv).
color_dict = (
    pd.read_csv(
        REPO_DIR.joinpath("metadata/color_codes.csv"),
        header=None,
        skiprows=2,
    )
    .set_index(0)[1]
    .to_dict()
)

# %%
def read_bedready_file(path_to_file):
    """Load a bedready table into a DataFrame.

    The file is read with ``pd.read_table`` (tab-separated) using the fixed
    15 column names below, so there is no header line; every line is data.
    The ``chr``, ``id``, ``color``, ``block_sizes`` and ``block_starts``
    columns are forced to ``str``; the rest use pandas' type inference.

    Args:
        path_to_file: path (or readable buffer) of the bedready file.

    Returns:
        pandas.DataFrame with one row per line of the file.
    """
    columns = [
        'chr', 'start', 'end', 'repeat_full', 'perc_div',
        'complement', 'thick_start', 'thick_end', 'color',
        'block_count', 'block_sizes', 'block_starts',
        'id', 'len', 'repeat_family',
    ]
    text_columns = ('chr', 'id', 'color', 'block_sizes', 'block_starts')
    return pd.read_table(
        path_to_file,
        names=columns,
        index_col=False,
        dtype={column: str for column in text_columns},
    )

def format_df_for_plotting(df, genome_size, colors=None):
    """Add the derived columns used by the landscape plots.

    ``df`` is modified in place and also returned.

    Args:
        df: table from ``read_bedready_file``; must have ``repeat_full``,
            ``len`` and ``perc_div`` columns.
        genome_size: genome length used to express ``len`` as a percentage.
        colors: mapping whose keys are the repeat names that have a plot
            color. Defaults to the module-level ``color_dict``.

    Returns:
        The same ``df`` with these columns added:
        ``repeat_hier1``/``repeat_hier2``/``repeat_hier3`` (progressively
        coarser repeat names), ``repeat`` (most specific name found in
        ``colors``), ``percent_genome``, ``age_bounds`` (1-wide, left-closed
        ``perc_div`` bins over [0, 50)) and ``age_midpoint`` (bin midpoint,
        NaN when ``perc_div`` falls outside the bins).
    """
    if colors is None:
        colors = color_dict
    known_names = set(colors)

    # repeat_full -> repeat_hier1 -> repeat_hier2 -> repeat_hier3
    # e.g. DNA/TcMar-Mariner-3 ->  DNA/TcMar-Mariner -> DNA/TcMar -> DNA
    # hier1/hier2 each strip only the LAST "-suffix" ("-[^-]*$"). The former
    # "-.*$" cut at the first hyphen, which made hier2 identical to hier1.
    df['repeat_hier1'] = df['repeat_full'].str.replace(r"-[^-]*$", "", regex=True)
    df['repeat_hier2'] = df['repeat_hier1'].str.replace(r"-[^-]*$", "", regex=True)
    df['repeat_hier3'] = df['repeat_hier2'].str.replace(r"/.*$", "", regex=True)

    # Pick the most specific name that has a color, going less and less
    # granular. Rows matching no level keep their repeat_hier3 name. Built as
    # a new Series instead of chained `df['repeat'].loc[...] = ...`
    # assignment, which may not write through to df under copy-on-write.
    repeat = df['repeat_full']
    for level in ['repeat_hier1', 'repeat_hier2', 'repeat_hier3']:
        repeat = repeat.where(repeat.isin(known_names), df[level])
    df['repeat'] = repeat

    # convert length to percentage of genome length
    df['percent_genome'] = df['len'] / genome_size * 100

    # group into bins based on perc_div ([0, 1), [1, 2), ..., [49, 50))
    df['age_bounds'] = pd.cut(df['perc_div'], bins=np.linspace(0, 50, num=51), right=False)

    # Midpoint of each bin: relabel the interval categories with their
    # midpoints (vectorized); out-of-range rows stay NaN.
    bounds = df['age_bounds'].cat
    df['age_midpoint'] = bounds.rename_categories(bounds.categories.mid).astype('float')

    return df

def filter_unwanted_entries(df):
    """Return only the rows of ``df`` that should be plotted.

    A row is dropped when its ``repeat_full`` or its ``repeat`` value is one
    of the excluded classes below. It is also dropped when its ``perc_div``
    is not strictly positive: zero and negative values (and NaN) are removed.
    ``df`` itself is not modified; a filtered selection is returned.
    """
    excluded = ['Low_complexity', 'Simple_repeat', 'rRNA', 'Satellite', 'Structural_RNA']
    keep = (
        ~df['repeat_full'].isin(excluded)
        & ~df['repeat'].isin(excluded)
        & (df['perc_div'] > 0)
    )
    return df.loc[keep]

# %%
# For each genome: read the bedready table, add the plotting columns, then
# drop the unwanted entries (sample first, then control).
df = filter_unwanted_entries(
    format_df_for_plotting(read_bedready_file(SAMPLE_PATH), SAMPLE_SIZE)
)

df_control = filter_unwanted_entries(
    format_df_for_plotting(read_bedready_file(CONTROL_PATH), CONTROL_SIZE)
)

# %%
# if still NA, raise an error
# Every remaining row must carry a repeat name. Raise explicitly rather than
# `assert` so the check still runs under `python -O`.
# NOTE(review): format_df_for_plotting falls back to the repeat_hier3 name
# when no level is found in color_dict, so this only trips when the name
# itself is missing. It does NOT detect repeats that lack a color.
if df['repeat'].isna().any():
    raise ValueError('missing repeats in SAMPLE')
if df_control['repeat'].isna().any():
    raise ValueError('missing repeats in CONTROL')

# to inspect rows without a repeat name, or class names without a color:
# df.loc[df['repeat'].isna()]
# list(filter(lambda x: x not in color_dict.keys(), df['repeat_hier3'].unique()))

# %% 
# Landscape plot: sums percent_genome per divergence bin, one color per
# repeat, colored with color_dict.
fig = px.histogram(df,
    x='age_midpoint', y='percent_genome', color='repeat',
    nbins=50, range_x=[0, 50],
    color_discrete_map=color_dict,
    # category order is reversed here and the legend is reversed again below
    # (traceorder 'reversed'); presumably so the legend follows color_dict order
    category_orders={'repeat': list(color_dict.keys())[::-1]}
)
fig.update_layout(
    legend={'traceorder': 'reversed'},
    template='simple_white',
    title=f'TE landscape: {SAMPLE_NAME} (genome size: {SAMPLE_SIZE:,})',
    xaxis_title="divergence from ancestral consensus (%)",
    yaxis_title="percentage of genome size",
    legend_title="repeat",
    yaxis_range=[0,4]
)
fig.write_html(f"{OUT_DIR}/landscape_{SAMPLE_NAME}.html")
# fig.show()

# %%
# Class overview: same histogram as the landscape plot, but colored by the
# coarsest repeat level (repeat_hier3) with a fixed palette per class.
fig = px.histogram(df,
    x='age_midpoint', y='percent_genome', color='repeat_hier3',
    nbins=50, range_x=[0, 50],
    template='simple_white',
    category_orders={
        'repeat_hier3': ['Unknown', 'Other', 'DNA', 'LINE', 'LTR', 'Retroposon', 'RC', 'SINE'][::-1]
    },
    color_discrete_map={
        'Unknown'    :'#C4C4C4',
        'Other'      :'#4D4D4D',
        'DNA'        :'#FFBAA9',
        'LINE'       :'#98ABE6',
        'LTR'        :'#65C465',
        'Retroposon' :'#FF9500',
        'RC'         :'#FF70FF',
        'SINE'       :'#D2ADF7'
    }
)
fig.update_layout(
    legend={'traceorder': 'reversed'},
    title=f'TE class overview: {SAMPLE_NAME} (genome size: {SAMPLE_SIZE:,})',
    xaxis_title="divergence from ancestral consensus (%)",
    yaxis_title="percentage of genome size",
    legend_title="repeat class",
    yaxis_range=[0,4]
)
fig.write_html(f"{OUT_DIR}/overview_{SAMPLE_NAME}.html")
# fig.show()

# %%
# # I'm not convinced this is needed, given we can just fit a linear model
# df_trace_bt = pd.DataFrame(data=None, index=pd.Index(np.linspace(0.5, 49.5, 50), name='age_midpoint'))
# sample_fraction = 0.9
# for i in range(0, 100):
#     a = (
#         df[['age_midpoint', 'percent_genome']]
#             .sample(frac=sample_fraction)
#             .groupby('age_midpoint')
#             .agg('sum')
#     )

#     b = (
#         df_control[['age_midpoint', 'percent_genome']]
#             .sample(frac=sample_fraction)
#             .groupby('age_midpoint')
#             .agg('sum')
#     )
#     df_trace_bt[f'ratio_{i}'] = a['percent_genome'] / b['percent_genome']
# df_trace_bt = df_trace_bt.reset_index()

# fig = px.scatter(df_trace_bt.query('age_midpoint > 2 & age_midpoint < 35'),
#     x='age_midpoint', y=[f'ratio_{i}' for i in range(0,100)], trendline="ols")
# fig.update_layout(yaxis_range=[0.5, 1.5], xaxis_range=[0, 40])
# fig.show()

# %%
# Total percent_genome per divergence bin for sample and control, their ratio
# and its log2. The join is a left join on the sample's bins, so bins seen only
# in the control are dropped; bins missing from the control give NaN ratios.
sample_totals = df.groupby('age_midpoint')['percent_genome'].sum().rename('sample')
control_totals = df_control.groupby('age_midpoint')['percent_genome'].sum().rename('control')
df_trace = sample_totals.to_frame().join(control_totals).reset_index()
df_trace['ratio'] = df_trace['sample'] / df_trace['control']
df_trace['ratio_log2'] = np.log2(df_trace['ratio'])

# %%
# Line plot of the per-bin totals for sample and control. Both series are
# sums of percent_genome, so the y-axis is a genome percentage, not a count.
fig = px.scatter(df_trace, x='age_midpoint', y=['sample', 'control'], template='seaborn')
fig.update_traces(mode='lines+markers')
fig.update_layout(
    title=f'total elements: {SAMPLE_NAME} ({SAMPLE_SIZE:,}bp) vs {CONTROL_NAME} ({CONTROL_SIZE:,}bp)',
    xaxis_title="divergence from ancestral consensus (%)",
    yaxis_title="percentage of genome size"
)
fig.write_image(f"{OUT_DIR}/trace_{SAMPLE_NAME}.pdf")

# %%
# TO BE REMOVED (UNLESS I CAN FIGURE OUT HOW TO PUT CI)
# fig = px.scatter(df_trace.query('age_midpoint > 2 & age_midpoint < 35'),
#     x='age_midpoint', y='ratio', trendline="ols")
# fig.update_layout(yaxis_range=[1/1.5, 1.5/1], xaxis_range=[0, 40])
# fig.show()


# %%
# Ordinary least squares fit of ratio ~ age_midpoint over the 2-30 window.
import statsmodels.api as sm

df_trace_subset = df_trace.query('age_midpoint > 2 & age_midpoint < 30')
# A bin missing from (or zero in) the control yields a NaN/inf ratio, which
# would poison the least-squares fit; keep only finite ratios.
df_trace_subset = df_trace_subset[np.isfinite(df_trace_subset['ratio'])]

# has_constant='add' always prepends the intercept column, so params[0] is the
# intercept and params[1] the slope even in degenerate (constant-x) windows.
x = sm.add_constant(df_trace_subset['age_midpoint'].to_numpy().reshape((-1, 1)),
                    has_constant='add')
y = df_trace_subset['ratio'].to_numpy()

model = sm.OLS(y, x)
results = model.fit()

print('coefficient of determination:', results.rsquared)
print('adjusted coefficient of determination:', results.rsquared_adj)
print('regression coefficients:', results.params)
print(results.summary())

# %%
# Predicted sample/control ratio at divergence 0.5, 1 and 1.5. The summary
# table below divides the sample's age_midpoint < 2 totals by scaling_factor[1].
scaling_factor = results.predict(sm.add_constant(np.array([0.5, 1, 1.5]).reshape((-1, 1))))
print('predicted scaling factor for 0.5, 1 and 1.5:', scaling_factor)

# %%
import seaborn as sns
import matplotlib.pyplot as plt

# Scatter of the per-bin ratio with seaborn's own regression line, restricted
# to the same 2-30 window as the OLS fit; the title reports that fit.
sns.set_theme()
fit_window = df_trace.query('age_midpoint > 2 & age_midpoint < 30')
ax = sns.regplot(x='age_midpoint', y='ratio', data=fit_window)
ax.set_xlim(0, 40)
ax.set_ylim(1/1.5, 1.5/1)
ax.set_title(f'scale factor: {SAMPLE_NAME}, y = {results.params[0]:.2f} + {results.params[1]:.5f}x, R^2={results.rsquared:.2f}, F={results.fvalue:.1f}')
plt.savefig(f"{OUT_DIR}/scalefactor_{SAMPLE_NAME}.pdf")

# %%
# Per-class (repeat_hier3) totals of percent_genome for bins with
# age_midpoint < 2: raw sample, sample divided by scaling_factor[1], and
# control. Left join on the sample's classes, followed by a 'Total' row.
recent_sample = (
    df.query('age_midpoint < 2')
    .groupby('repeat_hier3')['percent_genome']
    .sum()
)
recent_control = (
    df_control.query('age_midpoint < 2')
    .groupby('repeat_hier3')['percent_genome']
    .sum()
)

df_summary = pd.DataFrame({
    'sample': recent_sample,
    'sample_scaled': recent_sample / scaling_factor[1],
}).join(recent_control.rename('control'))

df_summary.index.name = 'repeat_class'
df_summary.loc['Total'] = df_summary.sum(axis=0)
df_summary.to_csv(f"{OUT_DIR}/counts_{SAMPLE_NAME}.csv", sep=',', index=True, header=True)

# %%
