# this script adds more information to the segment dataframe extracted from GAF
# adding information about bubble of origin,
# as well as the source and sink nodes for flexible segments

# %%
import re
import pandas as pd
import natsort as ns

# %%
# load the segment table; column names are supplied here, so the first line
# of the file is read as data
df_segments = pd.read_table(
    snakemake.input.segments,
    names=['segment', 'length', 'rank', 'chr_name'],
)

# make the segment column an ordered categorical whose category order is the
# natural sort order of the distinct segment names
df_segments['segment'] = df_segments['segment'].astype(
    pd.CategoricalDtype(
        categories=ns.natsorted(df_segments['segment'].unique()),
        ordered=True,
    )
)

# species = everything in chr_name before the first underscore
# (a chr_name without an underscore raises AttributeError here)
df_segments['species'] = [
    re.match(r"^([^_]*)_", chr_name).group(1)
    for chr_name in df_segments["chr_name"]
]

# %%
# collect the flexible segments: every line of the bubble file is a
# comma-separated list of segments, and all entries except the first and the
# last one are the flexible (middle) segments of that bubble
flex_segments = []
with open(snakemake.input.bubble) as file_object:
    # iterate over the file directly so every line is handled the same way;
    # a blank line simply contributes no segments instead of ending the scan
    for line in file_object:
        flex_segments.extend(line.strip('\n').split(',')[1:-1])

# a segment is core unless it shows up as a middle segment of some bubble
df_segments['core_bool'] = ~df_segments['segment'].isin(flex_segments)

# %%
# attach bubble identifiers to the segments
# bubbles are numbered 1..n in the order their lines are parsed; a segment
# that occurs in several bubbles (typically a core segment that is the sink of
# one bubble and the source of another) gets all ids joined by commas
df_bubble = (
    pd.read_table(snakemake.input.bubble, names=["segment"])
    .assign(
        bubble_id=lambda lines: list(map(str, range(1, len(lines) + 1))),
        segment=lambda lines: lines['segment'].map(lambda entry: entry.split(',')),
    )
    # one row per (segment, bubble) pair
    .explode("segment")
    .groupby('segment')['bubble_id']
    .agg(",".join)
    .reset_index()
)
# left join: segments that are in no bubble keep a missing bubble_id
df_segments = df_segments.merge(df_bubble, how='left')

# %%
# read bubble file again to retrieve the source and sink nodes for the flexible segments
def get_source_and_sink_nodes_for_flex_segments(bubble_path=None):
    """Map every flexible segment to the source and sink node of its bubble.

    Each line of the bubble file is read as a comma-separated list of
    segments: the first entry is taken as the source node, the last entry as
    the sink node, and every entry in between as a flexible segment.

    Parameters
    ----------
    bubble_path : str or path-like, optional
        Bubble file to read. When omitted, ``snakemake.input.bubble`` is used.

    Returns
    -------
    pandas.DataFrame
        Columns ``segment``, ``source`` and ``sink`` with one row per distinct
        (flexible segment, source, sink) combination. Lines with fewer than
        three entries (blank lines, bubbles without middle segments) contain
        no flexible segment and therefore contribute no rows.
    """
    if bubble_path is None:
        bubble_path = snakemake.input.bubble

    rows = []
    with open(bubble_path) as file_object:
        for line in file_object:
            nodes = line.strip('\n').split(',')
            # nothing lies between source and sink: no flexible segment here
            if len(nodes) < 3:
                continue
            source, *flexible, sink = nodes
            rows.extend((segment, source, sink) for segment in flexible)

    # identical bubbles listed more than once would otherwise yield repeated rows
    return pd.DataFrame(rows, columns=['segment', 'source', 'sink']).drop_duplicates()

# add the source and sink columns for the flexible segments; segments
# without a match keep missing values there (left join)
df_segments = df_segments.merge(
    get_source_and_sink_nodes_for_flex_segments(),
    how='left',
)

# %%
# write the annotated segment table as CSV without the row index
df_segments.to_csv(snakemake.output[0], index=False)

# %%
