# %%
import re
import pandas as pd
import numpy as np
from pathlib import Path

# %%
# input locations: the graph directory and the two graph description files inside it

DATA_DIR = Path.home() / "code/malawi_transposon/cloud/graph_genome/malawi_haplochromines_zebra/"
PATH_TO_GRAPH_NODEINFO = DATA_DIR / "malawi_haplochromines_zebra-graph.nodeinfo"
PATH_TO_GRAPH_NODES_BY_BUBBLE = DATA_DIR / "malawi_haplochromines_zebra-variants.nodesbybubble"

# stop right away if the directory or either input file is missing
assert DATA_DIR.is_dir()
assert all(path.is_file() for path in (PATH_TO_GRAPH_NODEINFO, PATH_TO_GRAPH_NODES_BY_BUBBLE))

# destination of the per-threshold sample counts
PATH_OUTPUT = DATA_DIR / "pseudoreference/pav/graph_bbsegments_coverage.csv"

# %% 
# load the segments of each bubble: every line is a comma-separated list of segments
# first entry = source, last entry = sink, everything in between = flex list
with open(PATH_TO_GRAPH_NODES_BY_BUBBLE) as file_object:
    bubble_paths = [raw_line.strip('\n').split(',') for raw_line in file_object]

list_sources = [segments[0] for segments in bubble_paths]
list_sinks = [segments[-1] for segments in bubble_paths]
list_flexlists = [segments[1:-1] for segments in bubble_paths]

# explode the flex lists so that each flex segment gets its own row
# (a bubble with an empty flex list keeps one row with a missing 'flex' value)
df_bubble = pd.DataFrame(
    {'source': list_sources, 'sink': list_sinks, 'flex': list_flexlists}
).explode('flex')

# %%
# load the per-segment information table,
# derive each segment's end coordinate (start + length) and fix the column order
nodeinfo_dtypes = {'segment': str, 'chr': str, 'start': int, 'length': int, 'rank': int}
lookup_columns = ['chr', 'start', 'end', 'segment', 'length', 'rank']
df_segments_lookup = (
    pd.read_table(PATH_TO_GRAPH_NODEINFO, dtype=nodeinfo_dtypes, index_col=False)
    .assign(end=lambda nodeinfo: nodeinfo['start'].add(nodeinfo['length']))
    .loc[:, lookup_columns]
)

# flag core vs flexible segments: a segment is core unless it occurs in some bubble's flex list
# (THIS IS PROBABLY NOT NEEDED)
df_segments_lookup['core_bool'] = ~df_segments_lookup['segment'].isin(df_bubble.flex)


# %%
list_of_samples = "astCal1.2_ensembl mayZeb2.0_ensembl troMau2.0 aulStu5.0 otoArg1.0 copChr1.0 rhaChi1.0 rhaChi2.0".split()

# add one corrected-coverage column (cov_0, cov_1, ...) per sample, in list order
for i, sample in enumerate(list_of_samples):
    PATH_TO_SAMPLE_COVERAGE = DATA_DIR / f"{sample}.coverage"
    assert PATH_TO_SAMPLE_COVERAGE.is_file()

    # column names are supplied explicitly, so the first line of the file is read as data;
    # only the segment id and its coverage are kept
    df_coverage_sample = pd.read_table(
        PATH_TO_SAMPLE_COVERAGE,
        names=['segment', 'len', 'rank', 'chr', 'coverage'],
        dtype={'segment': str, 'len': int, 'rank': int, 'chr': str, 'coverage': float},
        index_col=False,
    ).loc[:, ['segment', 'coverage']]

    # annotate every bubble row with the coverage of its source and of its sink
    coverage_dict = df_coverage_sample.set_index('segment')['coverage'].to_dict()
    df_bubble_sample = df_bubble.assign(
        source_cov=df_bubble['source'].map(coverage_dict),
        sink_cov=df_bubble['sink'].map(coverage_dict),
    )

    # retain bubbles whose source AND sink are both covered;
    # filtering on the flex segments' own coverage would not work, because the
    # query species may simply not have those segments
    cov_threshold = 0.1
    df_bubble_sample_filtered = df_bubble_sample.query(
        "source_cov > @cov_threshold & sink_cov > @cov_threshold"
    )

    # corrected coverage: flex segments of a covered bubble are set to 1,
    # every other segment keeps its measured coverage
    in_covered_bubble = df_coverage_sample['segment'].isin(df_bubble_sample_filtered['flex'])
    df_coverage_sample['cov_corr'] = df_coverage_sample['coverage'].mask(in_covered_bubble, 1)

    # attach this sample's corrected coverage to the main dataframe
    sample_column = (
        df_coverage_sample
        .loc[:, ['segment', 'cov_corr']]
        .rename(columns={'cov_corr': f"cov_{i}"})
    )
    df_segments_lookup = df_segments_lookup.merge(sample_column, how="left")


# %%
# count, for every segment, the number of samples whose corrected coverage reaches
# each threshold; thresholds are given in percent and compared as threshold / 100
list_of_thresholds = [50, 60, 70, 75, 80, 85, 90, 95, 99, 100]

# select the per-sample coverage columns once: the n_cov* columns added below do not
# match '^cov_', so the selection would be the same on every iteration
df_sample_coverage = df_segments_lookup.filter(regex='^cov_', axis=1)

for threshold in list_of_thresholds:
    # missing coverage (NaN) compares as False and therefore does not count
    df_segments_lookup[f'n_cov{threshold}'] = (
        (df_sample_coverage >= threshold / 100).sum(axis=1)
    )

df_segments_lookup

# %%
# export the rank-0 segments: coordinates, core flag and the per-threshold sample counts
output_columns = ['segment', 'chr', 'start', 'end', 'core_bool']
output_columns += [f'n_cov{str(x)}' for x in list_of_thresholds]

df_rank0_segments = df_segments_lookup.query("rank == 0")
df_rank0_segments.loc[:, output_columns].to_csv(PATH_OUTPUT, index=None)

# %%
# write spanning coverage values by species
# labels are positional: entry i names the cov_{i} column, which the coverage loop
# filled from the i-th entry of list_of_samples
list_of_samples2 = 'astCal,mayZeb,troMau,aulStu,otoArg,copChr,rhaChi,rhaChi2'.split(',')
PATH_OUTPUT2 = DATA_DIR.joinpath("pseudoreference/pav/graph_bbsegments_coverage_by_species.csv")

# cov_<i> column -> species label, derived from the label list so the number of
# exported columns always follows the list instead of a hard-coded count
species_by_cov_column = {
    f'cov_{position}': species for position, species in enumerate(list_of_samples2)
}

(df_segments_lookup
    .query("rank == 0")
    .loc[:, ['segment', 'chr', 'start', 'end', 'core_bool'] + list(species_by_cov_column)]
    .rename(columns=species_by_cov_column)
    # index=False is the documented way to omit the row index from the CSV
    .to_csv(PATH_OUTPUT2, index=False)
)

# %%