# %%
import re
import pandas as pd
import natsort as ns
import numpy as np

# Samples belonging to the graph selected by the `graph` wildcard, as listed
# under config["graphs"]. The first entry is treated as the backbone sample.
graph_name = snakemake.wildcards.graph
sample_list = snakemake.config["graphs"][graph_name]
bb_sample = sample_list[0]

# %%
# Load the per-segment table and turn the `segment` column into an ordered
# categorical whose category order is the natural sort of the segment names.
df_segments = pd.read_csv(snakemake.input.segments)
segment_order = ns.natsorted(df_segments['segment'].unique())
df_segments['segment'] = pd.Categorical(
    df_segments['segment'],
    categories=segment_order,
    ordered=True,
)

# %%
# Growth table: one row per tracked metric, one column per sample (in the
# order the samples appear in sample_list). Cells start out empty.
growth_metrics = [
    "core_size",
    "core_count",
    "flex_size_bb",
    "flex_count_bb",
    "flex_size_nonref",
    "flex_count_nonref",
    "n_bubble",
]
df_growth = pd.DataFrame(index=pd.Index(growth_metrics), columns=sample_list)

# Backbone-only column: every metric is zero except the core, which is made
# of all segments whose `species` equals the backbone sample.
bb_segments = df_segments.query("species==@bb_sample")
df_growth.loc[:, bb_sample] = 0
df_growth.loc['core_size', bb_sample] = bb_segments['length'].sum()
df_growth.loc['core_count', bb_sample] = len(bb_segments)
df_growth

# %%
# Load the variant table (one row per bubble) and label the rows with a
# 1-based running index named `bubble_id`.
df_variants = pd.read_table(snakemake.input.variants)
bubble_ids = pd.RangeIndex(start=1, stop=len(df_variants) + 1, name='bubble_id')
df_variants = df_variants.set_index(bubble_ids)

# %%
# Record, for every bubble, the first non-backbone sample (in sample_list
# order) whose entry differs from the backbone's and is not the '.' marker.
# Bubbles for which no such sample exists keep first_occurence = None.
df_variants['first_occurence'] = None
for sample in sample_list[1:]:
    sample_calls = df_variants[sample]
    still_unassigned = df_variants['first_occurence'].isnull()
    differs_from_bb = sample_calls.ne(df_variants[bb_sample])
    is_called = sample_calls.ne('.')
    df_variants.loc[
        still_unassigned & differs_from_bb & is_called, 'first_occurence'
    ] = sample

# The backbone column is only needed for the comparison above.
df_variants = df_variants.drop(columns=bb_sample)

# %%
# Iterate through the remaining samples in the order of their addition to the
# graph and fill in one column of the growth dataframe per sample.
#
# `explored` flags the bubbles whose first_occurence is the current sample or
# any sample processed before it; the flag is cumulative across iterations.
df_variants['explored'] = False

# Subsets of the segment table that do not depend on the current sample are
# computed once here instead of being re-queried on every iteration.
flex_segments = df_segments.query("core_bool==False")
flex_bb_segments = flex_segments.query("rank==0")
core_bb_segments = df_segments.query("core_bool==True & rank==0")
core_bb_size = core_bb_segments['length'].sum()
core_bb_count = core_bb_segments.shape[0]

for rank, current_sample in enumerate(sample_list[1:]):

    # mark the bubbles first formed by the current sample as explored
    df_variants.loc[df_variants['first_occurence']==current_sample, 'explored'] = True

    # ids of all bubbles explored so far, as strings
    # NOTE(review): assumes df_segments['bubble_id'] stores ids as strings — confirm
    explored_bubbles = [str(x) for x in df_variants.query("explored==True").index]

    # number of explored bubbles, including the newly added ones
    df_growth.loc['n_bubble', current_sample] = df_variants['explored'].sum()

    # split the rank-0 (backbone) flexible segments by whether their bubble
    # has been explored yet
    bb_in_explored = flex_bb_segments['bubble_id'].isin(explored_bubbles)
    tmp_df_bb = flex_bb_segments[bb_in_explored]
    tmp_df_unexplored = flex_bb_segments[~bb_in_explored]

    # flexible segments of explored bubbles with 0 < rank <= rank + 1
    # (here `rank` inside the query is the column, `@rank` the loop counter)
    tmp_df_nonref = flex_segments.query(
        "bubble_id in @explored_bubbles & rank!=0 & rank<=@rank+1"
    )

    # bits of the backbone that are in explored bubbles
    df_growth.loc['flex_size_bb',  current_sample] = tmp_df_bb['length'].sum()
    df_growth.loc['flex_count_bb', current_sample] = tmp_df_bb.shape[0]

    # non-backbone segments inside explored bubbles; these rows feed the
    # flex_size / total_size statistics derived from df_growth afterwards
    df_growth.loc['flex_size_nonref',  current_sample] = tmp_df_nonref['length'].sum()
    df_growth.loc['flex_count_nonref', current_sample] = tmp_df_nonref.shape[0]

    # core genome + backbone flexible segments whose bubble is not yet explored
    df_growth.loc['core_size',  current_sample] = tmp_df_unexplored['length'].sum() + core_bb_size
    df_growth.loc['core_count', current_sample] = tmp_df_unexplored.shape[0] + core_bb_count

# release the hoisted intermediate dataframes
del flex_segments, flex_bb_segments, core_bb_segments

# %%
def _percent(part, whole):
    """Return part / whole as a percentage, cast to float and rounded to 3 decimals."""
    return (part / whole * 100).astype('float').round(3)


# One row per sample; sample names are replaced by their 0-based position.
df_out = df_growth.T
df_out.index = pd.RangeIndex(start=0, stop=len(df_out), name="sample")

# Derived sizes.
df_out['bb_size'] = df_out['core_size'] + df_out['flex_size_bb']
df_out['flex_size'] = df_out['flex_size_bb'] + df_out['flex_size_nonref']
df_out['total_size'] = df_out['bb_size'] + df_out['flex_size_nonref']

# Shares of the total size, in percent.
total_size = df_out['total_size']
df_out['core_percent'] = _percent(df_out['core_size'], total_size)
df_out['flex_percent'] = _percent(df_out['flex_size'], total_size)
df_out['bb_extra_perc'] = _percent(df_out['flex_size_nonref'], total_size)

df_out

# %%
# Write the per-sample growth table to the rule's first declared output.
output_path = snakemake.output[0]
df_out.to_csv(output_path)

# %%