# %%
import pandas as pd
import natsort as ns
import numpy as np

# %%
# Load the segment table and make 'segment' an ordered categorical whose category
# order is the natural sort order of the labels (natsort compares embedded numbers
# numerically), so sorting on 'segment' follows that order.
df_segments = pd.read_csv(snakemake.input.segments)
segment_order = ns.natsorted(df_segments['segment'].unique())
df_segments['segment'] = pd.Categorical(
    df_segments['segment'], categories=segment_order, ordered=True
)

# Load the (tab-separated) coverage table, indexed by segment label
df_coverage_raw = pd.read_table(snakemake.input.coverage).set_index('segment')

# Non-reference samples are all coverage columns except the first (backbone) one
n_samples = len(df_coverage_raw.columns) - 1

# %%
# Clip every coverage value to an upper bound of 1; smaller values (and NaN)
# are left unchanged.
df_coverage_raw = df_coverage_raw.clip(upper=1)

# %%
# Convert fractional coverage into a number of covered bases by multiplying each
# coverage row by the length of its segment.
# Lengths are looked up per coverage row by segment label. Passing a dict directly to
# DataFrame.multiply makes pandas coerce it to a Series only after checking that it has
# exactly one entry per row, so any mismatch between the segment and coverage tables
# raised a ValueError; with a per-row lookup, segments missing from the segment table
# simply get NaN.
segment_lengths = dict(zip(df_segments['segment'], df_segments['length']))
row_lengths = df_coverage_raw.index.map(segment_lengths).to_numpy()
df_coverage_raw_nbases = (df_coverage_raw
    .multiply(row_lengths, axis='index')
    .reset_index()
)

# %%
# Now iterate through each non-reference sample and "correct" the coverage of
# backbone flexible segments (non-core segments of rank 0) to a value of 1 when both
# their source and sink segments have more than `cov_threshold_nbases` bases covered
# in that sample.
df_coverage = df_coverage_raw.copy()
cov_threshold_nbases = 100

# The candidate segments do not depend on the sample, so select them once.
bb_flex_candidates = (df_segments
    .query("core_bool==False & rank==0")
    .loc[:, ['segment', 'source', 'sink']]
)

for j in range(1, n_samples + 1):
    sample_col = str(j)
    sample_nbases = df_coverage_raw_nbases[['segment', sample_col]]
    bb_flex_segments = (bb_flex_candidates
        .merge(
            sample_nbases.rename(columns={'segment': 'source', sample_col: 'source_cov'}),
            how='left', on='source')
        .merge(
            sample_nbases.rename(columns={'segment': 'sink', sample_col: 'sink_cov'}),
            how='left', on='sink')
        .query("source_cov > @cov_threshold_nbases & sink_cov > @cov_threshold_nbases")
        .loc[:, 'segment']
    )
    df_coverage.loc[bb_flex_segments, sample_col] = 1

# drop the first column, which is the backbone, keeping only the samples
df_coverage = df_coverage.iloc[:, 1:]
df_coverage

# %%
# Per segment, aggregate the corrected sample coverages:
#   cov_cuml - highest coverage across samples (cumulative coverage)
#   cov_ubiq - lowest coverage across samples (ubiquitous / universal coverage)
# DataFrame.max/min are vectorized and skip NaN consistently, unlike a row-wise apply
# of the builtin max/min, whose result with NaN depends on the NaN's position in the row.
max_cov_values = df_coverage.max(axis=1).rename("cov_cuml")
min_cov_values = df_coverage.min(axis=1).rename("cov_ubiq")

# convert to a coverage-weighted length (number of bases)
df_cov_agg = (df_segments
    .loc[:, ['segment', 'length']]
    .merge(pd.concat([max_cov_values, min_cov_values], axis=1).reset_index(), on='segment')
    .assign(
        len_cuml=lambda dataframe: dataframe["cov_cuml"] * dataframe["length"],
        len_ubiq=lambda dataframe: dataframe["cov_ubiq"] * dataframe["length"],
    )
)

# combine with the original segment dataframe
df_segments_w_cov = df_segments.merge(df_cov_agg, on=['segment', 'length'])

# %%
# Estimate cumulative-coverage statistics over the backbone (rank 0 segments).
# pandas .sum() is used throughout, consistent with the other sums in this script.
df_backbone = df_segments_w_cov.query("rank==0")
bb_size = df_backbone['length'].sum()

bb_cov = round(df_backbone['len_cuml'].sum(), 2)
bb_cov_perc = round(bb_cov / bb_size * 100, 3)

# Covered linear segments: core backbone segments with more than cov_threshold_nbases
# bases covered (the same bases threshold used for the flexible-segment correction).
covered_linear_segments = df_segments_w_cov.query(
    "core_bool==True & rank==0 & len_cuml > @cov_threshold_nbases"
)["segment"]
# Non-core segments whose source and sink are both covered linear segments
segments_in_covered_bubble = df_segments_w_cov.query(
    "core_bool==False & source in @covered_linear_segments & sink in @covered_linear_segments"
)

# Extra (rank != 0) covered bases inside those bubbles, as a percentage of bb_cov
bb_extra_corr = segments_in_covered_bubble.query("rank!=0")["len_cuml"].sum()
bb_extra_perc_corr = round(bb_extra_corr / bb_cov * 100, 3)
bb_extra_perc_corr

# %%
# Estimate ubiquitous-coverage statistics (mostly pangenome related).
# pandas .sum() is used throughout, consistent with the other sums in this script.
bb_ubiq = round(df_segments_w_cov.query("rank==0")['len_ubiq'].sum(), 2)
bb_ubiq_perc = round(bb_ubiq / bb_size * 100, 3)

# core: linear (core, rank 0) segments, weighted by their lowest coverage over all samples
core_size = round(df_segments_w_cov.query("core_bool==True & rank==0")["len_ubiq"].sum(), 2)

# flex: non-core segments whose source and sink are both linear segments with more than
# cov_threshold_nbases ubiquitously covered bases (same bases threshold as the
# flexible-segment correction)
ubiq_covered_linear_segments = df_segments_w_cov.query(
    "core_bool==True & rank==0 & len_ubiq > @cov_threshold_nbases"
)["segment"]
segments_in_ubiq_covered_bubble = df_segments_w_cov.query(
    "core_bool==False & source in @ubiq_covered_linear_segments & sink in @ubiq_covered_linear_segments"
)
flex_size = round(segments_in_ubiq_covered_bubble["len_ubiq"].sum(), 2)
# rank 0 / rank != 0 split of flex_size; these two are not included in the written output
flex_size_bb = round(segments_in_ubiq_covered_bubble.query("rank==0")["len_ubiq"].sum(), 2)
flex_size_nonref = round(segments_in_ubiq_covered_bubble.query("rank!=0")["len_ubiq"].sum(), 2)

# total pangenome size and component percentages
pangenome_size = core_size + flex_size
core_perc = round(core_size / pangenome_size * 100, 3)
flex_perc = round(flex_size / pangenome_size * 100, 3)

# estimated conservation of backbone sequence in all species, based on ubiquitously covered regions
bb_conserv_perc = round(core_size / bb_ubiq * 100, 3)

# %%
# Summary statistics in output order; the keys double as the CSV header, so names
# and values cannot drift apart.
stats_to_write = {
    'bb_cov': bb_cov,
    'bb_cov_perc': bb_cov_perc,
    'bb_extra_perc_corr': bb_extra_perc_corr,
    'bb_ubiq': bb_ubiq,
    'bb_ubiq_perc': bb_ubiq_perc,
    'bb_conserv_perc': bb_conserv_perc,
    'core_size': core_size,
    'flex_size': flex_size,
    'pangenome_size': pangenome_size,
    'core_perc': core_perc,
    'flex_perc': flex_perc,
}
header_to_write = ','.join(stats_to_write)
items_to_write = list(stats_to_write.values())

# Two-line CSV: header line, then the values rendered with str()
with open(snakemake.output[0], 'w') as out_file:
    out_file.write(header_to_write + '\n')
    out_file.write(','.join(map(str, items_to_write)) + '\n')

# %%