# This script adds more information to the segment dataframe extracted from the GAF:
# the bubble(s) each segment belongs to, and whether that bubble is covered.
# It also computes a segment length weighted by coverage.

# %%
import pandas as pd
import natsort as ns

# Run parameters, taken from the Snakemake rule's wildcards.
wildcards = snakemake.wildcards
min_variant_size = wildcards.L
species1, species2 = wildcards.species1, wildcards.species2
# Coverage that both the source and the sink of a bubble must exceed
# for that bubble to be treated as covered.
cov_threshold = 0.1

# %%
# Load the per-segment table (columns: segment id, segment length, rank,
# chr_name, coverage).
df_segments = pd.read_table(
    snakemake.input.segments,
    names=['segment', 'length', 'rank', 'chr_name', 'coverage'],
)

# Turn the segment column into an ordered categorical whose categories follow
# natural sort order (numeric parts compared as numbers, e.g. s2 before s10).
# This defines the ordering used by comparisons/sorting; rows are NOT reordered here.
natural_order = ns.natsorted(df_segments['segment'].unique())
df_segments['segment'] = pd.Categorical(
    df_segments['segment'], categories=natural_order, ordered=True
)

# Collect the flexible segments: on each line of the bubble file, every segment
# strictly between the first entry (source) and the last entry (sink).
# Segments that never appear in that position are marked as core.
flex_segments = []
with open(snakemake.input.bubble) as file_object:
    # Iterate over every line. Blank lines contribute nothing, since [1:-1] of
    # [''] is empty, and they do not stop the scan.
    for line in file_object:
        flex_segments.extend(line.rstrip('\n').split(',')[1:-1])
df_segments['core_bool'] = ~df_segments['segment'].isin(flex_segments)

# %%
# Attach bubble identifiers to segments (useful e.g. for summing flexible
# segments per bubble). A bubble's id is its 1-based row number in the bubble
# file as read by pandas. A segment listed in several bubbles (core segments
# often are: the sink of one bubble can be the source of another) gets a
# comma-separated list of ids.
bubble_lines = pd.read_table(snakemake.input.bubble, names=["segment"])
bubble_lines["bubble_id"] = [str(i) for i in range(1, bubble_lines.shape[0] + 1)]
bubble_lines["segment"] = [entry.split(',') for entry in bubble_lines["segment"]]

segment_rows = bubble_lines.explode("segment").reset_index(drop=True)
df_bubble = (
    segment_rows.groupby('segment')
    .agg({'bubble_id': lambda ids: ",".join(ids)})
    .reset_index()
)

# Left join: segments not found in any bubble get NaN as bubble_id.
df_segments = df_segments.merge(df_bubble, how='left', on='segment')

# %%
# Tag every row with this run's wildcard values (min_variant_size, species1, species2).
df_segments = df_segments.assign(
    min_variant_size=min_variant_size,
    species1=species1,
    species2=species2,
)

# %%
##########
# to calculate the percentage conservation of the backbone in the query species
# we can only take into account regions that are alignable by it

def get_bb_flex_segments_with_coverage_in_bubbles(cov_threshold = 0.1, bubble_path = None, segments = None):
    """
    Retrieve the flexible segments of bubbles whose source and sink are both covered.

    Each row is one flexible segment, together with the source and sink of the
    bubble it belongs to. A bubble contributes rows only if the coverage of both
    its source and its sink is strictly greater than `cov_threshold`. The
    coverage of the flexible segments themselves is not checked, because the
    query species may have no coverage there. Source/sink ids missing from
    `segments` get NaN coverage, which also drops the bubble.

    Only bubbles that pass the filter appear. The result therefore does not
    contain the complete set of sources and sinks. A retained bubble that has
    no flexible segment (source and sink only) appears as a single row with NaN
    in the `flex` column.

    Parameters
    ----------
    cov_threshold : float
        Exclusive lower bound on the coverage of both the source and the sink.
    bubble_path : str, optional
        Path to the bubble file. Each line lists comma-separated segment ids:
        the first is the source, the last is the sink, and those in between are
        the flexible segments. Blank lines are ignored.
        Defaults to `snakemake.input.bubble`.
    segments : pandas.DataFrame, optional
        Table with at least `segment` and `coverage` columns.
        Defaults to the module-level `df_segments`.

    Returns
    -------
    pandas.DataFrame
        Columns `source`, `sink`, `flex`, `source_cov`, `sink_cov`. There is one
        row per flexible segment of each retained bubble, and the index repeats
        the bubble's row label.
    """
    # Fall back to the script's globals so existing callers keep working.
    if bubble_path is None:
        bubble_path = snakemake.input.bubble
    if segments is None:
        segments = df_segments

    # Read the bubble file to retrieve the segments of each bubble.
    list_sources = []
    list_sinks = []
    list_flexlists = []
    with open(bubble_path) as file_object:
        for line in file_object:
            line = line.rstrip('\n')
            if not line:
                # A blank line describes no bubble.
                continue
            current_segments = line.split(',')
            list_sources.append(current_segments[0])
            list_sinks.append(current_segments[-1])
            list_flexlists.append(current_segments[1:-1])
    df = pd.DataFrame({'source': list_sources, 'sink': list_sinks, 'flex': list_flexlists})

    # Look up the coverage of each source and sink (NaN if the id is unknown).
    coverage_dict = segments.set_index('segment')['coverage'].to_dict()
    df['source_cov'] = df['source'].map(coverage_dict)
    df['sink_cov'] = df['sink'].map(coverage_dict)

    # Keep only bubbles where both source and sink are covered (NaN compares False).
    # We can't simply filter for flex segments with coverage, because the query species may not have that.
    df = df[(df['source_cov'] > cov_threshold) & (df['sink_cov'] > cov_threshold)]

    # Split the flex lists into one row per segment; empty lists become a NaN row.
    return df.explode('flex')

# Flag segments that are flexible segments of a bubble whose source and sink
# both have coverage above `cov_threshold`.
df_flex_bb_covered = get_bb_flex_segments_with_coverage_in_bubbles(cov_threshold = cov_threshold)
df_segments['bubble_covered'] = df_segments['segment'].isin(df_flex_bb_covered['flex'])

# Corrected coverage: cap every segment's coverage at 1 ...
df_segments['cov_corr'] = df_segments['coverage'].clip(upper=1)

# ... and treat flexible (non-core) segments inside covered bubbles as fully covered.
covered_flex = df_segments['bubble_covered'] & ~df_segments['core_bool']
df_segments.loc[covered_flex, 'cov_corr'] = 1

# Segment length weighted by the corrected coverage, truncated to an integer.
df_segments['weighted_length'] = (df_segments['length'] * df_segments['cov_corr']).astype(int)

# %%
# Write the augmented table as CSV with a fixed column order and no index column.
output_columns = [
    "species1", "species2", "min_variant_size",
    "segment", "length", "rank", "chr_name", "coverage",
    "core_bool", "bubble_id", "bubble_covered",
    "cov_corr", "weighted_length",
]
df_segments.loc[:, output_columns].to_csv(snakemake.output[0], index=False)

# %%
