# %%
import pandas as pd

# Parameters of this rule, read from the `snakemake` object that Snakemake
# makes available to the script at run time.
_wildcards = snakemake.wildcards
_inputs = snakemake.input

# The two species being compared, and the `L` wildcard (only echoed into the
# output row, never used in a computation here).
species1, species2 = _wildcards.species1, _wildcards.species2
min_variant_size = _wildcards.L

# Input files: one index per species (species1 first) plus the segment table.
species1_index, species2_index = _inputs.indexes[0], _inputs.indexes[1]
path_to_df_segment = _inputs.segments

# %% 
# import dataframe of segment information
# chr_name and bubble_id are read as strings so that numeric-looking
# identifiers are not converted to numbers.
df_segments = pd.read_csv(
    path_to_df_segment,
    dtype={'chr_name': str, 'bubble_id': str},
)


def _read_chr_names(index_path, species):
    """Read the first column of an index file as `chr_name`, tagged with `species`.

    Because `names` is passed, pandas treats the file as header-less. The
    column is forced to `str` so that it has the same dtype as
    `df_segments['chr_name']`; without this, chromosome names such as "1" or
    "2" would be parsed as integers and the later merge on `chr_name` would
    fail on mismatched key dtypes.
    """
    return (
        pd.read_table(
            index_path,
            usecols=[0],
            names=['chr_name'],
            dtype={'chr_name': str},
        )
        .assign(segment_species=species)
    )


# import chr names of the species
df_index_species1 = _read_chr_names(species1_index, species1)
df_index_species2 = _read_chr_names(species2_index, species2)

# one lookup table: chr_name -> species it belongs to
df_index = pd.concat([df_index_species1, df_index_species2])

# %%
# join chr names onto the main data frame
# A left join is required here: the default inner join silently drops every
# segment whose chromosome is missing from both indexes, which would make the
# check below impossible to trigger.
df_segments_joined = df_segments.merge(df_index, how='left')

# quick check that there are no segments not from the two specified species
# (explicit raise instead of `assert`, so the check survives `python -O` and
# does not rely on `~` being applied to a boolean)
_unassigned = df_segments_joined['segment_species'].isnull()
if _unassigned.any():
    _unknown_chr = sorted(
        df_segments_joined.loc[_unassigned, 'chr_name'].astype(str).unique()
    )
    raise ValueError(
        'There are segments not from species 1 or 2!!! '
        f'Unmatched chr_name values: {_unknown_chr}'
    )

# %%
# keep only the rows where core_bool is False, restricted to the columns
# needed for the per-bubble size summary
_flex_columns = ['segment', 'length', 'segment_species', 'bubble_id']
_is_flex = df_segments_joined['core_bool'] == False  # noqa: E712 (element-wise comparison)
df_flex = df_segments_joined.loc[_is_flex, _flex_columns]

# %%
# Total segment length per bubble and species, as a wide table:
# one row per bubble_id, one column per species.
df_bubble_sizes = (df_flex
    .groupby(["bubble_id", "segment_species"])
    .agg({'length': 'sum'})
    .reset_index(drop=False)
    .pivot(index='bubble_id', columns='segment_species', values='length')
    # pivot only creates columns for species that actually occur in df_flex;
    # force both columns to exist so that selecting them by species name
    # cannot raise KeyError when one species has no rows (or df_flex is empty)
    .reindex(columns=[species1, species2])
    # a bubble with no segment from a species contributes 0 bases for it
    .fillna(0)
)

# %%
# Net length change per bubble going from species1 to species2:
# positive -> species2 carries more sequence, negative -> less, zero -> same.
net_sequence_diff = df_bubble_sizes[species2] - df_bubble_sizes[species1]
n_events = len(net_sequence_diff)

_is_ins = net_sequence_diff > 0
_is_del = net_sequence_diff < 0
_is_sub = net_sequence_diff == 0

# count number of various events
# (vectorised Series.sum() instead of the builtin sum(), which iterates the
# Series element by element in Python)
n_ins = int(_is_ins.sum())
n_del = int(_is_del.sum())
n_sub = int(_is_sub.sum())

# count total sequence length of various events
# The totals are cast to int because the table becomes float whenever missing
# cells had to be filled; without the cast the output would read "1234.0" in
# some runs and "1234" in others.
# Note: n_bases_del is the sum of the negative differences, so it is <= 0.
n_bases_ins = int(net_sequence_diff[_is_ins].sum())
n_bases_del = int(net_sequence_diff[_is_del].sum())

# %%
# Single comma-separated summary row (no header), in this column order:
# species1, species2, min_variant_size,
# n_events, n_ins, n_del, n_sub, n_bases_ins, n_bases_del
items_to_write = [
    species1,
    species2,
    min_variant_size,
    n_events,
    n_ins,
    n_del,
    n_sub,
    n_bases_ins,
    n_bases_del,
]
with open(snakemake.output[0], 'w') as out_handle:
    # print() applies str() to every item, joins them with `sep` and
    # terminates the row with a newline
    print(*items_to_write, sep=',', file=out_handle)

# %%
