# %%
import re
import numpy as np
import pandas as pd
import natsort as ns

# %%
# look up the species configured for the graph named by the wildcard
graph_name = snakemake.wildcards.graph
species_list = snakemake.config['graphs'][graph_name]

# %%
# load the per-segment table and turn 'segment' into an ordered categorical
# whose category order is the natural sort of the distinct segment names
df_segments = pd.read_csv(snakemake.input.segments)
segment_order = ns.natsorted(df_segments['segment'].unique())
df_segments['segment'] = pd.Categorical(
    df_segments['segment'], categories=segment_order, ordered=True
)

# %%
# read the tab-separated coverage table and index it by segment
df_coverage_raw = pd.read_table(snakemake.input.coverage).set_index('segment')

# number of non-reference samples: every coverage column except one
n_samples = len(df_coverage_raw.columns) - 1

# coverage values above 1 are capped at 1
df_coverage_raw = np.minimum(df_coverage_raw, 1)

# segment -> length lookup, used to convert coverage values into a number
# of covered bases per segment
segment_lengths = df_segments.set_index('segment')['length'].to_dict()
df_coverage_raw_nbases = (
    df_coverage_raw.multiply(segment_lengths, axis='index').reset_index()
)

# %%
# for every coverage column '0' .. str(n_samples), "correct" the coverage of
# backbone flexible segments (rank 0, core_bool False) to a value of 1 when
# both their source and sink segments have more than `cov_threshold_nbases`
# covered bases in that column
df_coverage = df_coverage_raw.copy()
cov_threshold_nbases = 100

# the candidate segments do not depend on the sample, so select them once
# instead of re-running the query on every iteration
bb_flex_candidates = (df_segments
    .query("core_bool==False & rank==0")
    .loc[:, ['segment', 'source', 'sink']]
)

for j in range(0, n_samples+1):
    sample = str(j)
    # covered bases per segment for this column only
    sample_nbases = df_coverage_raw_nbases[['segment', sample]]
    bb_flex_segments = (bb_flex_candidates
        # attach the covered bases of the source segment ...
        .merge(
            sample_nbases.rename(columns={'segment': 'source', sample: 'source_cov'}),
            how='left', on='source')
        # ... and of the sink segment; segments without a match get NaN
        # and therefore fail the threshold test below
        .merge(
            sample_nbases.rename(columns={'segment': 'sink', sample: 'sink_cov'}),
            how='left', on='sink')
        .query("source_cov > @cov_threshold_nbases & sink_cov > @cov_threshold_nbases")
        .loc[:, 'segment']
    )
    df_coverage.loc[bb_flex_segments, sample] = 1

# NOTE: the loop starts at column '0', so the first column is processed as
# well, and it is kept in df_coverage (it is not dropped here).

# %%
# backbone segments are the rank-0 rows; keep their lengths keyed by segment
backbone_rows = df_segments.query("rank==0")
df_segments_bb = backbone_rows[["segment", "length"]].set_index('segment')
bb_size = df_segments_bb['length'].sum()

# restrict the corrected coverage to backbone segments, then scale each
# sample column by the segment lengths to obtain covered bases
df_coverage_nbases = df_coverage.loc[df_segments_bb.index].copy()
bb_lengths = df_segments_bb['length']
for sample_col in map(str, range(n_samples + 1)):
    df_coverage_nbases.loc[:, sample_col] *= bb_lengths

# %%
# then perform a column-wise sum to get the covered backbone bases by sample
cov_by_sample = df_coverage_nbases.apply(sum, axis=0)

# %%
# assemble one output row per coverage column; species_list must supply
# exactly one name per column, otherwise the DataFrame constructor raises
df_out = pd.DataFrame({
    'rank': cov_by_sample.index,
    'species': species_list,
    # round to the nearest base before casting: the sums are floats, and a
    # bare int cast truncates, so e.g. 28.999999999999996 would become 28
    'bb_cov_bysample': cov_by_sample.round().astype(int),
    # percentage of the total backbone length that is covered
    'bb_cov_perc_bysample': np.round(cov_by_sample / bb_size * 100, 3)}
)
df_out

# %%
df_out.to_csv(snakemake.output[0], index=False)

# %%
