# %%
import re
import pandas as pd
import numpy as np
from more_itertools import roundrobin
import argparse
from pathlib import Path

def cmdline_args(argv=None):
    """
    Build the command-line parser and parse the arguments.

    argv: optional list of argument strings to parse. When None (the default)
          argparse falls back to the process arguments, which is what the
          script itself relies on.

    Returns the argparse namespace with the attributes `sample`,
    `sample_node_coverages`, `graph_nodes_by_bubble`, `sample_name` and
    `output_dir` (defaults to "./").
    """
    p = argparse.ArgumentParser(
        description="""
        Extract a list of nodes and a list of (chromosome) paths taken by a sample through a graph.
        The path is inferred based on variant calls from `minigraph -xasm --call`.
        
        `sample_node_coverages` is a tab-delimited file of five columns:
        segment, length, rank, chr_name, coverage (columns 2,4,7,5,8 of GAF file)

        `graph_nodes_by_bubble` can be retrieved as the 12th column of the output of gfatools bubble on the GFA.

        """,
        # a hand-written usage line replaces the generated one, so it has to
        # list the optional flags itself
        usage=(
            'python3 extract_genome_path.py [-h] [-o OUTPUT_DIR] '
            'sample.variants sample_node_coverages graph_nodes_by_bubble SampleName'
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument('sample', type=str, help='path to the sample variants')
    p.add_argument('sample_node_coverages', type=str, help='tab-delimited file of node ids and coverages for sample')
    p.add_argument('graph_nodes_by_bubble', type=str, help='single column file ')
    p.add_argument('sample_name', type=str, help='prefix added to paths')
    p.add_argument('-o', '--output_dir', type=str, default="./",
                   help='existing directory where the output files are written (default: ./)')
    return p.parse_args(argv)

###

# parse arguments
args = cmdline_args()
SAMPLE_NAME = args.sample_name
SAMPLE_PATH = Path(args.sample)
SAMPLE_NODE_COVERAGES_PATH = Path(args.sample_node_coverages)
GRAPH_NODES_BY_BUBBLE_PATH = Path(args.graph_nodes_by_bubble)
OUTPUT_DIR = Path(args.output_dir)

###

# validate the paths with explicit exceptions instead of `assert`:
# assertions are skipped when Python runs with -O, and they do not say
# which of the paths is the offending one
for _input_path in (SAMPLE_PATH, SAMPLE_NODE_COVERAGES_PATH, GRAPH_NODES_BY_BUBBLE_PATH):
    if not _input_path.is_file():
        raise FileNotFoundError(f"Input file not found: {_input_path}")
if not OUTPUT_DIR.is_dir():
    raise NotADirectoryError(f"Output directory does not exist: {OUTPUT_DIR}")

# %%
# read in the dataframe about the coverage for each node for the sample
df_segments_lookup = pd.read_table(SAMPLE_NODE_COVERAGES_PATH, 
    names = ['node', 'len', 'rank', 'chr', 'coverage'], 
    dtype = {'node': str, 'len': int, 'rank': int, 'chr': str, 'coverage': float}, 
    index_col = False
)

# extract a list of flex segments by taking the middle segments in graph bubbles
# (each line is a comma-separated list of segments; its first and last entries are dropped)
# then, add this information to the previous dataframe
flex_segments = []
with open(GRAPH_NODES_BY_BUBBLE_PATH) as file_object:
    # iterate over the file directly so a blank line (including a blank first
    # line) contributes nothing instead of ending the scan early
    for line in file_object:
        flex_segments.extend(line.rstrip('\n').split(',')[1:-1])

# every node is core unless it was listed as the middle of a bubble
df_segments_lookup['core_bool'] = ~df_segments_lookup['node'].isin(flex_segments)

# %%
# load the table describing the path taken by the sample at each bubble
variant_columns = ['chr', 'start', 'end', 'source', 'sink', 'query']
df_sample = pd.read_table(
    SAMPLE_PATH,
    names=variant_columns,
    dtype={'chr': str},
    index_col=False,
)

# the query column is colon-delimited: break it up into one column per field
# (information about the segments taken by the query)
# query path is the most important: >s32>s33<s50 (it's the allele in the flexible part of bubble)
new_cols = ['query_path', 'query_len', 'query_strand', 'query_chr', 'query_start', 'query_end']
query_fields = df_sample['query'].str.split(':', expand=True)
df_sample[new_cols] = query_fields

# the bubble name is the source immediately followed by the sink
df_sample['path'] = [
    f"{source}{sink}"
    for source, sink in zip(df_sample['source'], df_sample['sink'])
]

# %%
# retrieve the paths of each observed chromosome (those with bubbles)
def retrieve_pseudoref_nodes_for_chromosome(chr_id):
    """
    Retrieve a dataframe for the specified chromosome, where each row represents a node in the
    personalised pseudogenome.

    The rows of the module-level `df_sample` whose `chr` equals `chr_id` are used, in their
    existing order. The result interleaves each bubble's source with the query path taken
    through it, and ends with the sink of the last bubble.

    Columns: chr, node_w_strand
    (node_w_strand looks like "s30+" or "s30-"; the placeholders "*" and "." are passed through)
    """

    # subfunction to turn ">s30<s31"-like strings into "s30>,s31<"
    # (orientation sign moved behind the node, nodes separated by commas)
    def parse_path(path_string):
        if path_string in ('*', '.'):
            return path_string
        # raw string: "\d" in a normal string literal is an invalid escape sequence
        return ','.join(x[1:] + x[0] for x in re.findall(r"[><]s\d+", path_string))

    df_subset = df_sample.query("chr==@chr_id")

    # note: this is only the core and flex nodes in bubbles!
    # core nodes = every source, followed by the sink of the final bubble
    core_nodes = pd.concat([df_subset['source'], df_subset['sink'].iloc[[-1]]]).apply(parse_path).to_list()
    flexible_nodes = df_subset['query_path'].apply(parse_path).to_list()

    # join them interweavingly into a path: core, flex, core, flex, ..., core
    genome_path = ','.join(roundrobin(core_nodes, flexible_nodes))

    # clean up the syntax of the path: orientation is written as +/- instead of >/<
    genome_path = genome_path.replace('>', '+').replace('<', '-')

    df_chr = pd.DataFrame(
        data = {
            'chr': chr_id,
            'node_w_strand': genome_path.split(',')
        }
    )

    return df_chr[['chr', 'node_w_strand']]

# gather the nodes present in bubbles, one chromosome at a time, into one dataframe
observed_chromosomes = df_sample['chr'].unique()
df_nodes_in_bubbles = pd.concat(
    [retrieve_pseudoref_nodes_for_chromosome(chromosome) for chromosome in observed_chromosomes],
    axis=0,
)
df_nodes_in_bubbles['chr_species'] = [
    f"{SAMPLE_NAME}_{chromosome}" for chromosome in df_nodes_in_bubbles['chr']
]


# %%
# collapse the nodes of every chromosome into a single comma-separated path
df_paths_in_bubbles = (
    df_nodes_in_bubbles
    .groupby('chr')['node_w_strand']
    .agg(lambda nodes: ','.join(nodes))
    .reset_index()
    .rename(columns={'node_w_strand': 'path'})
)


def _split_at_gaps(path):
    # "." marks a gap of no coverage; cut the path there
    return path.split('.')


def _trim_commas(fragment):
    # cutting at a gap leaves a dangling comma at the cut ends
    return re.sub("^,|,$", "", fragment)


# a chromosome with gaps of no coverage (represented by ".") becomes several rows,
# one per uninterrupted portion of the chromosome
df_paths_in_bubbles = (
    df_paths_in_bubbles
    .assign(path=lambda frame: frame['path'].map(_split_at_gaps))
    .explode('path')
    .assign(path=lambda frame: frame['path'].map(_trim_commas))
    .reset_index(drop=True)
)

# %%
# chromosomes that are not part of any bubble are missing from the paths so far:
# take their rank 0 (backbone) segments from the coverage table
# these should be single segment entries
is_backbone = df_segments_lookup['rank'] == 0
without_bubbles = ~df_segments_lookup['chr'].isin(observed_chromosomes)
extra_paths = df_segments_lookup.loc[is_backbone & without_bubbles]
extra_paths = extra_paths.assign(
    path=extra_paths['node'].map(lambda node: f'{node}+')
)[['chr', 'path']]

# %%
# stack the bubble paths and the extra backbone paths to form the complete set
df_paths_agg = pd.concat([df_paths_in_bubbles, extra_paths]).reset_index(drop=True)
df_paths_agg

# %%
# add coverage information to single segment paths
# otherwise coverage is set to 1
def add_coverage_to_single_segment_paths(df_sample, segments_lookup=None):
    """
    Return the `chr` and `path` columns of the input plus a `coverage` column.

    df_sample: dataframe of paths with at least the columns chr and path
               (NOTE: the parameter name shadows the module-level `df_sample`;
               inside this function it is whatever frame the caller passes in)
    segments_lookup: dataframe with the columns node and coverage; when None
                     the module-level `df_segments_lookup` is used

    The path with its last character (the strand sign) removed is looked up as a node.
    A single segment path such as "s12+" therefore gets the coverage of s12, while
    any path without a matching node (e.g. a multi-segment path) gets coverage 1.

    The input dataframe is left untouched.
    """
    if segments_lookup is None:
        segments_lookup = df_segments_lookup

    # a dict keeps the last entry if a node is listed more than once
    coverage_by_node = segments_lookup.set_index('node')['coverage'].to_dict()

    # work on a copy so no helper columns leak into the caller's dataframe
    df_out = df_sample[['chr', 'path']].copy()
    node = df_out['path'].str[:-1]
    df_out['coverage'] = node.map(coverage_by_node).fillna(1)
    return df_out

df_paths_agg = add_coverage_to_single_segment_paths(df_paths_agg)
df_paths_agg

# %%
# drop single segment nodes with no coverage
df_paths_agg_filt = df_paths_agg.query("coverage > 0.5").copy()

# add index for each "scaffold" grouped by chromosome (1-based, in the current row order)
df_paths_agg_filt['path_index'] = df_paths_agg_filt.groupby('chr').cumcount() + 1
df_paths_agg_filt['chr_species_and_path_index'] = (
    SAMPLE_NAME + '_' + df_paths_agg_filt['chr'] + '-' + df_paths_agg_filt['path_index'].astype(str)
)

# sort paths by chromosome and path index
df_paths_agg_filt = df_paths_agg_filt.sort_values(by=['chr', 'path_index'])
df_paths_agg_filt

# %%
# write path of pseudoreference in the format for GFA files
# zero length paths * are removed
# (only in the written file: the "*" entries stay in df_paths_agg_filt)
(df_paths_agg_filt
    .assign(record_type='P', overlap='*',
            # plain literal replacement; no regular expression is needed to drop "*,"
            path = lambda dataframe: dataframe['path'].map(lambda x: x.replace('*,', '')) )
    .loc[:, ['record_type', 'chr', 'chr_species_and_path_index', 'path', 'overlap']]
    .to_csv(OUTPUT_DIR.joinpath(f"{SAMPLE_NAME}.pseudoref.paths"), sep='\t', header=False, index=False)
)

# %%
######################################################################
# RECORDING NODES WITH COVERAGE

# split the paths with coverage into their respective nodes
# (one row per node; the strand signs are stripped first)
df_nodes_filt = (df_paths_agg_filt
    # raw string: "\+" and "\-" in a normal string literal are invalid escape sequences
    .assign(path = lambda dataframe: [re.sub(r"[+-]", "", x) for x in dataframe['path']] )
    .assign(node = lambda dataframe: [x.split(',') for x in dataframe['path']] )
    .explode('node')
    .loc[:, ['chr_species_and_path_index', 'node'] ]
    # from here on every scaffold is treated as its own "chromosome"
    .assign(chr = lambda dataframe: dataframe['chr_species_and_path_index'])
)

# add information about the node length and whether it is in core
# ('node' is the only shared column; naming it makes the join key explicit)
df_nodes_filt = df_nodes_filt.merge(
    df_segments_lookup[['node', 'len', 'core_bool']], how='left', on='node'
)

# "*" is a zero length placeholder, not a graph segment, so it has no entry to join on
is_placeholder = df_nodes_filt['node'] == '*'
df_nodes_filt.loc[is_placeholder, 'len'] = 0
df_nodes_filt.loc[is_placeholder, 'core_bool'] = False

df_nodes_filt

# %%
def calculate_pseudoref_nodes_coord_by_chromosome(df_in):
    """
    Calculates the start and end coordinates for each node grouped by chromosome.

    Within a chromosome the first node starts at 0 and every node starts where the
    previous one ends (end = start + len).

    NOTE THAT INPUT MUST BE SORTED IN THE ORDER THE NODES APPEAR!
    Input columns required: chr, node, len, core_bool
    Output columns: chr, start, end, node, len, core_bool

    Rows are returned grouped by chromosome, chromosomes in order of first appearance,
    with the input order kept inside each chromosome. Rows without a chromosome (NaN)
    are dropped. The input dataframe is not modified; an empty input gives an empty output.
    """
    out_columns = ['chr', 'start', 'end', 'node', 'len', 'core_bool']

    # number the chromosomes in order of first appearance (missing values get -1)
    chr_codes, _ = pd.factorize(df_in['chr'])
    has_chr = chr_codes >= 0

    # a stable sort on those numbers puts the rows of a chromosome next to each other
    # without disturbing their relative order
    row_order = np.argsort(chr_codes[has_chr], kind='stable')
    df_out = df_in.loc[has_chr].iloc[row_order].reset_index(drop=True)

    # one grouped cumulative sum replaces filtering the whole frame once per chromosome
    df_out['len'] = df_out['len'].astype(int)
    df_out['end'] = df_out.groupby('chr', sort=False)['len'].cumsum()
    df_out['start'] = df_out['end'] - df_out['len']

    return df_out[out_columns]

df_nodes_filt = calculate_pseudoref_nodes_coord_by_chromosome(df_nodes_filt)
df_nodes_filt


# %%
# write the coordinates of the nodes in the pseudoreference as tab-delimited tables
# all nodes: written with a header line and with the core_bool column
df_all_nodes = df_nodes_filt[['chr', 'start', 'end', 'node', 'len', 'core_bool']]
df_all_nodes.to_csv(OUTPUT_DIR / f"{SAMPLE_NAME}.nodes", sep='\t', index=None)

# %%
# core nodes only: no header line
df_core_nodes = df_nodes_filt.query('core_bool==True')
df_core_nodes = df_core_nodes[['chr', 'start', 'end', 'node', 'len']]
df_core_nodes.to_csv(
    OUTPUT_DIR / f"{SAMPLE_NAME}.nodes.core.bed",
    sep='\t', header=None, index=None,
)

# %%
# flexible nodes only: no header line
df_flex_nodes = df_nodes_filt.query('core_bool==False')
df_flex_nodes = df_flex_nodes[['chr', 'start', 'end', 'node', 'len']]
df_flex_nodes.to_csv(
    OUTPUT_DIR / f"{SAMPLE_NAME}.nodes.flex.bed",
    sep='\t', header=None, index=None,
)

# %%
