# %%
import re
import pandas as pd
import natsort as ns

# Run parameters come from the wildcards of the `snakemake` object that is
# available to this script at run time.
_wildcards = snakemake.wildcards
min_variant_size = _wildcards.L
species1, species2 = _wildcards.species1, _wildcards.species2

# Coverage cut-off that a linear (core) segment has to exceed further down
# in order to count as covered.
cov_threshold = 0.1

# %%
# Per-segment table. The statistics below read its columns `segment`,
# `length`, `rank`, `core_bool`, `coverage`, `bubble_covered` and
# `weighted_length`.
df_segments = pd.read_csv(snakemake.input.segments)

# %%
# based on: https://www.vipinajayakumar.com/parsing-text-with-python/#step-3-define-regular-expressions
# One compiled pattern per statistic that is read from the stat file. Every
# pattern captures its number in a named group carrying the same name as the
# dictionary key.
#   * `\d+` requires at least one digit, so a line whose number is missing is
#     simply not matched (an empty capture cannot be converted with int()).
#   * `\s*$` tolerates trailing whitespace and a last line that does not end
#     with a newline character.
rx_dict = {
    'n_segments': re.compile(r'Number of segments: (?P<n_segments>\d+)\s*$'),
    'n_edges': re.compile(r'Number of links: (?P<n_edges>\d+)\s*$'),
    'total_pangenome_size': re.compile(r'Total segment length: (?P<total_pangenome_size>\d+)\s*$'),
    'backbone_size': re.compile(r'Sum of rank-0 segment lengths: (?P<backbone_size>\d+)\s*$'),
}

def _parse_line(line):
    """
    Try every pattern in ``rx_dict`` against ``line``, in dictionary order.

    Returns the ``(key, match)`` pair of the first pattern that is found in
    the line, or ``(None, None)`` when no pattern is found.
    """
    attempts = ((name, pattern.search(line)) for name, pattern in rx_dict.items())
    return next(((name, found) for name, found in attempts if found), (None, None))

# Collect the graph statistics from the stat file. Each line is tested against
# the patterns in rx_dict; the captured number is stored under the key of the
# pattern that matched (every named group has the same name as its key). If a
# statistic shows up more than once, the last occurrence wins.
_stat_values = {}
with open(snakemake.input.stat) as file_object:
    for line in file_object:
        key, match = _parse_line(line)
        if key is not None:
            _stat_values[key] = int(match.group(key))

# Stop right here with a clear message when a statistic is absent, rather than
# leaving its variable undefined until a later calculation trips over it.
_missing = [key for key in rx_dict if key not in _stat_values]
if _missing:
    raise ValueError(
        f"{snakemake.input.stat}: no line found for statistic(s): {', '.join(_missing)}"
    )

n_segments = _stat_values['n_segments']
n_edges = _stat_values['n_edges']
graph_size = _stat_values['total_pangenome_size']
bb_size = _stat_values['backbone_size']


# %%
# using the existing information, we can calculate most statistics that we want
# The bubble count is the number of lines in the bubble file. The file is
# opened in a `with` block so that the handle is closed as soon as the lines
# have been counted.
with open(snakemake.input.bubble) as bubble_file:
    n_bubbles = sum(1 for _ in bubble_file)

# mean of the `length` column, rounded to three decimals
mean_segment_length = round(df_segments['length'].mean(), 3)

# linear segments: rows with core_bool == True
_linear = df_segments.query('core_bool == True')
linear_size = _linear['length'].sum()
linear_count = len(_linear)

# variable segments located within bubbles: rows with core_bool == False
variable_size = df_segments.query('core_bool == False')['length'].sum()
# this count is derived from the segment total of the stat file, not by
# counting rows of the table
variable_count = n_segments - linear_count

# variable segments with rank 0
_variable_bb = df_segments.query('core_bool == False & rank==0')
variable_size_bb = _variable_bb['length'].sum()
variable_count_bb = len(_variable_bb)

# variable segments with any other rank
_variable_nonref = df_segments.query('core_bool == False & rank!=0')
variable_size_nonref = _variable_nonref['length'].sum()
variable_count_nonref = len(_variable_nonref)

# %%
# get the core and flex parts of the backbone with sufficient coverage
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, rounded to three decimals."""
    return round(part / whole * 100, 3)


# linear segments whose coverage exceeds cov_threshold
df_linear_covered = df_segments.query("core_bool == True & coverage > @cov_threshold")
# rank-0 segments that are flagged as bubble_covered
# NOTE(review): this selection has no core_bool filter, so a row that also
# passes the selection above would be present twice after the concat below --
# confirm that this cannot happen in the segment table.
df_variable_bb_covered = df_segments.query('rank == 0 & bubble_covered == True')

# both selections stacked into a single table, ordered by the `segment` column
df_bb_covered = pd.concat([df_linear_covered, df_variable_bb_covered]).sort_values('segment')

# covered backbone: summed weighted_length, absolute and as a percentage of
# the backbone size read from the stat file
bb_cov = df_bb_covered['weighted_length'].sum()
bb_cov_perc = _percent(bb_cov, bb_size)

# rank != 0 variable sequence, relative to the whole backbone and relative to
# the covered backbone
bb_extra_perc = _percent(variable_size_nonref, bb_size)
bb_extra_perc_corr = _percent(variable_size_nonref, bb_cov)

# percentage of the covered backbone that remains after subtracting the
# rank-0 variable segments
bb_detect_perc = _percent(bb_cov - variable_size_bb, bb_cov)

# %%
# estimate pangenome
# core: summed weighted_length of the covered linear segments
# flex: summed weighted_length of the non-core segments flagged bubble_covered
_flex_covered = df_segments.query('core_bool == False & bubble_covered == True')
core_size, flex_size = (
    frame['weighted_length'].sum() for frame in (df_linear_covered, _flex_covered)
)
pangenome_size = core_size + flex_size

# share of each part in the estimated pangenome, in percent
core_perc, flex_perc = (
    round(part / pangenome_size * 100, 3) for part in (core_size, flex_size)
)

# %%
# write results
# All values go into one comma-separated line; no header line is written, so
# the order of this list defines the columns of the output file.
items_to_write = [
    # run parameters
    species1, species2, min_variant_size,
    # graph totals
    n_segments, n_edges, n_bubbles, mean_segment_length, graph_size,
    # linear and variable segments
    linear_size, linear_count, variable_size, variable_count,
    variable_size_bb, variable_count_bb,
    variable_size_nonref, variable_count_nonref,
    # backbone coverage
    bb_size, bb_cov, bb_cov_perc, bb_extra_perc, bb_extra_perc_corr, bb_detect_perc,
    # pangenome estimate
    core_size, flex_size, pangenome_size, core_perc, flex_perc,
]
with open(snakemake.output[0], 'w') as o:
    o.write(','.join(map(str, items_to_write)) + '\n')

# %%
