configfile: "config.yaml"

import itertools

# Wildcard regexes. Raw strings avoid Python's "invalid escape sequence" warning for \d / \w.
# species1/species2 may not contain '-' or '/', so "{species1}-{species2}" in the paths
# below splits unambiguously into the two species names.
# sample1, sample2 and dir are not referenced by the rules in this file.
wildcard_constraints:
    L=r"\d+",
    sample1=r"\w+",
    sample2=r"\w+",
    dir=r"\w+",
    species1=r"[^/-]+",
    species2=r"[^/-]+"

# get variables from config file
# config['species'] maps each species name to the genome file prefix used by the input functions
species_list = config['species'].keys()
min_variant_size_list = config['min_variant_sizes']

# Pair names are built by joining species names with '-', and they are used as path
# components, so a '-' or '/' inside a species name can never be parsed back correctly.
_invalid_species = [s for s in species_list if '-' in str(s) or '/' in str(s)]
if _invalid_species:
    raise ValueError(
        f"species names in config['species'] must not contain '-' or '/': {_invalid_species}"
    )

# generate a list of all ordered pairs of species (n**2 - n)
species_pair_list = ['-'.join(x) for x in itertools.permutations(species_list, 2)]

rule all:
    """
    Default target: every per-pair file for each L in config['min_variant_sizes'],
    the collated CSV tables, all figures, and the bubble aggregate stats.
    """
    input:
        # per-pair outputs for every L and every ordered species pair
        expand("output/L{L}/{species_pair}.{filetype}",
            L = min_variant_size_list,
            species_pair = species_pair_list,
            filetype = ['gfa', 'gaf', 'stat', 'bubble', 'segments',
                        'segments.df', 'out', 'events']),
        # collated tables
        "results/results.csv",
        "results/results_events.csv",
        # one file per L for each of these patterns
        expand([
                "output/segments/segment_dataframes_L{L}.csv",
                "results/fig/percent_conserv/percent_conserv-L{L}.pdf",
                "results/fig/percent_sv/percent_sv-L{L}.pdf",
                "results/fig/bubble/bubble-L{L}.pdf",
                "results/fig/insert/insert-L{L}.pdf",
                "results/fig/insert/insert_bases-L{L}.pdf",
                "results/fig/delete/delete-L{L}.pdf",
                "results/fig/delete/delete_bases-L{L}.pdf",
                "results/fig/net_indel/net_indel-L{L}.pdf",
                "results/fig/net_indel/net_indel_bases-L{L}.pdf",
            ],
            L = min_variant_size_list),
        # single-file robustness figures
        "results/fig/percent_conserv/bb_robustness.pdf",
        "results/fig/percent_sv/bb_robustness.pdf",
        "results/fig/bubble/bb_robustness_bubble.pdf",
        "results/fig/insert/bb_robustness_ins.pdf",
        "results/fig/delete/bb_robustness_del.pdf",
        # Bubble aggregate stats are requested for one L only, set by the optional config
        # key 'bubble_agg_L' (default 50). A value not listed in 'min_variant_sizes' makes
        # Snakemake build every pairwise graph at that L as well.
        f"results/bubble_agg_stats-L{config.get('bubble_agg_L', 50)}.txt"


##########
# input functions: map a rule's wildcards to the input file paths it needs

def retrieve_genomes(wildcards):
    """Return the renamed FASTA paths of species1 and species2, in that order.

    Each species name is mapped to its genome file prefix through config["species"].
    """
    pair = (wildcards.species1, wildcards.species2)
    prefixes = config["species"]
    return [f'dir/genome_malawi/{prefixes[name]}.renamed.fa' for name in pair]

def retrieve_genome_indexes(wildcards):
    """Return the .fai index paths of the species1 and species2 renamed FASTAs, in that order."""
    pair = (wildcards.species1, wildcards.species2)
    prefixes = config["species"]
    indexes = []
    for name in pair:
        indexes.append(f'dir/genome_malawi/{prefixes[name]}.renamed.fa.fai')
    return indexes

def get_query_genome(wildcards):
    """Return the renamed FASTA path of the query genome (the species2 wildcard)."""
    query_prefix = config["species"][wildcards.species2]
    return f'dir/genome_malawi/{query_prefix}.renamed.fa'

##########
# rules

rule build_graph:
    """
    Build a pairwise graph with minigraph (-xggs) from the species1 and species2
    genomes, passed in that order, using {L} as minigraph's -L value.
    """
    input:
        genomes = retrieve_genomes
    output:
        "output/L{L}/{species1}-{species2}.gfa"
    log:
        "log/build_graph/L{L}/{species1}-{species2}.log"
    resources:
        mem_mb = 40000
    threads: 8
    shell:
        "("
        "minigraph -xggs -c -t{threads} -L{wildcards.L} {input.genomes} > {output}"
        ") 2> {log}"

rule gfatools_bubble:
    """Run `gfatools bubble` on the graph; the raw table is a temporary file."""
    input:
        gfa = "output/L{L}/{species1}-{species2}.gfa"
    output:
        temp("output/L{L}/{species1}-{species2}.bubble.tmp")
    shell:
        "gfatools bubble {input.gfa} > {output}"

rule get_segments_by_bubble:
    """
    Keep only column 12 of the gfatools bubble table (presumably the per-bubble
    segment list, per the rule name; TODO confirm against gfatools docs).
    """
    input:
        bubbles = "output/L{L}/{species1}-{species2}.bubble.tmp"
    output:
        "output/L{L}/{species1}-{species2}.bubble"
    shell:
        "cut -f12 {input.bubbles} > {output}"

rule gfatools_stat:
    """Write `gfatools stat` summary statistics for the pairwise graph."""
    input:
        gfa = "output/L{L}/{species1}-{species2}.gfa"
    output:
        "output/L{L}/{species1}-{species2}.stat"
    shell:
        "gfatools stat {input.gfa} > {output}"

rule align_query_to_graph:
    """
    Map the query genome (species2) onto the pairwise graph with minigraph
    (-xasm -c --cov). Despite the .gaf extension, get_coverage_by_graph_segments
    reads 'S' (segment) lines from this file.
    """
    input:
        gfa = "output/L{L}/{species1}-{species2}.gfa",
        seq = get_query_genome
    output:
        "output/L{L}/{species1}-{species2}.gaf"
    log:
        "log/align_query_to_graph/L{L}/{species1}-{species2}.log"
    resources:
        mem_mb = 40000
    threads: 8
    shell:
        "("
        "minigraph -xasm -c --cov -t{threads} {input.gfa} {input.seq} > {output}"
        ") 2> {log}"

rule get_coverage_by_graph_segments:
    """
    Tabulate one tab-separated row per segment ('S' line) of the coverage-annotated
    graph: segment name, length (LN:i), rank (SR:i), chromosome name (SN:Z) and
    coverage (dc:f, written by minigraph --cov).

    Tags are looked up by name rather than by column position, because GFA does not
    fix the order of optional tags. Tag values are kept whole, so names containing ':'
    are not truncated. The input is read in a single pass. A missing tag yields an
    empty field.
    """
    input:
        "output/L{L}/{species1}-{species2}.gaf"
    output:
        "output/L{L}/{species1}-{species2}.segments"
    shell:
        r"""
        awk 'BEGIN {{ FS = OFS = "\t" }}
             $1 == "S" {{
                 ln = ""; sn = ""; sr = ""; dc = ""
                 for (i = 4; i <= NF; i++) {{
                     tag = substr($i, 1, 5)
                     val = substr($i, 6)
                     if (tag == "LN:i:") ln = val
                     else if (tag == "SN:Z:") sn = val
                     else if (tag == "SR:i:") sr = val
                     else if (tag == "dc:f:") dc = val
                 }}
                 print $2, ln, sr, sn, dc
             }}' {input} > {output}
        """

rule add_more_info_to_segments:
    """Combine the segment table with the bubble segment lists via a Python script."""
    output:
        "output/L{L}/{species1}-{species2}.segments.df"
    input:
        segments = "output/L{L}/{species1}-{species2}.segments",
        bubble = "output/L{L}/{species1}-{species2}.bubble"
    script:
        "script/add_more_info_to_segments.py"

rule collate_segment_dataframes:
    """
    Concatenate the per-pair .segments.df tables for one L under a single header.

    Lines starting with 'species1' (the per-file headers) are dropped with awk.
    Unlike `grep -v`, awk exits 0 when every line is filtered out, so header-only
    inputs do not fail the job under bash strict mode. Awk also reads each file
    separately, so a file without a trailing newline cannot glue the next file's
    header onto its last line and let that header slip past the filter.
    """
    input:
        expand("output/L{{L}}/{species_pair}.segments.df", species_pair = species_pair_list)
    output:
        "output/segments/segment_dataframes_L{L}.csv"
    shell:
        "printf '%s\\n' '"
        "species1,species2,min_variant_size,"
        "segment,length,rank,chr_name,coverage,core_bool,"
        "bubble_id,bubble_covered,cov_corr,weighted_length"
        "' > {output}; "
        "awk '!/^species1/' {input} >> {output}"

rule calculate_graph_stats:
    """Compute per-pair graph statistics from the segment table, gfatools stat and bubbles."""
    output:
        "output/L{L}/{species1}-{species2}.out"
    input:
        segments = "output/L{L}/{species1}-{species2}.segments.df",
        stat = "output/L{L}/{species1}-{species2}.stat",
        bubble = "output/L{L}/{species1}-{species2}.bubble"
    script:
        "script/calculate_graph_stats.py"

rule collate_graph_stats:
    """Concatenate every per-pair .out file (all L values) beneath one CSV header line."""
    input:
        expand(
            "output/L{L}/{species_pair}.out",
            L = min_variant_size_list,
            species_pair = species_pair_list,
        )
    output:
        "results/results.csv"
    shell:
        "printf '%s\\n' '"
        "species1,species2,min_variant_size,"
        "n_segments,n_edges,n_bubbles,mean_segment_length,"
        "graph_size,linear_size,linear_count,variable_size,variable_count,"
        "variable_size_bb,variable_count_bb,variable_size_nonref,variable_count_nonref,"
        "bb_size,bb_cov,bb_cov_perc,bb_extra_perc,bb_extra_perc_corr,bb_detect_perc,"
        "core_size,flex_size,pangenome_size,core_perc,flex_perc"
        "' > {output}; "
        "cat {input} >> {output}"

rule count_indel_events_and_bases:
    """Count indel events and bases per pair from the segment table and the genome .fai indexes."""
    output:
        "output/L{L}/{species1}-{species2}.events"
    input:
        segments = "output/L{L}/{species1}-{species2}.segments.df",
        indexes = retrieve_genome_indexes
    script:
        "script/count_indel_events_and_bases.py"

rule collate_indel_events_and_bases:
    """Concatenate every per-pair .events file (all L values) beneath one CSV header line."""
    input:
        expand(
            "output/L{L}/{species_pair}.events",
            L = min_variant_size_list,
            species_pair = species_pair_list,
        )
    output:
        "results/results_events.csv"
    shell:
        "printf '%s\\n' '"
        "species1,species2,min_variant_size,"
        "n_events,n_ins,n_del,n_sub,n_bases_ins,n_bases_del"
        "' > {output}; "
        "cat {input} >> {output}"

rule plot_percent_conserv:
    """Draw one percent-conservation PDF per L plus bb_robustness.pdf from results.csv."""
    input:
        "results/results.csv"
    output:
        heatmap = expand(
            "results/fig/percent_conserv/percent_conserv-L{L}.pdf",
            L = min_variant_size_list,
        ),
        robustness = "results/fig/percent_conserv/bb_robustness.pdf"
    script:
        "script/plot_percent_conserv.R"

rule plot_percent_sv:
    """Draw one percent-SV PDF per L plus bb_robustness.pdf from results.csv."""
    input:
        "results/results.csv"
    output:
        heatmap = expand(
            "results/fig/percent_sv/percent_sv-L{L}.pdf",
            L = min_variant_size_list,
        ),
        robustness = "results/fig/percent_sv/bb_robustness.pdf"
    script:
        "script/plot_percent_sv.R"

rule plot_indel_density_matrices:
    """
    Draw bubble / insertion / deletion / net-indel matrices (one PDF per L for each)
    and three robustness PDFs from the collated graph-stats and indel-events tables.
    """
    input:
        "results/results.csv",
        "results/results_events.csv"
    output:
        bubble = expand("results/fig/bubble/bubble-L{L}.pdf",
                        L = min_variant_size_list),
        inserted = expand("results/fig/insert/insert-L{L}.pdf",
                          L = min_variant_size_list),
        inserted_bases = expand("results/fig/insert/insert_bases-L{L}.pdf",
                                L = min_variant_size_list),
        deleted = expand("results/fig/delete/delete-L{L}.pdf",
                         L = min_variant_size_list),
        deleted_bases = expand("results/fig/delete/delete_bases-L{L}.pdf",
                               L = min_variant_size_list),
        net_indel = expand("results/fig/net_indel/net_indel-L{L}.pdf",
                           L = min_variant_size_list),
        net_indel_bases = expand("results/fig/net_indel/net_indel_bases-L{L}.pdf",
                                 L = min_variant_size_list),
        robustness_bubble = "results/fig/bubble/bb_robustness_bubble.pdf",
        robustness_ins = "results/fig/insert/bb_robustness_ins.pdf",
        robustness_del = "results/fig/delete/bb_robustness_del.pdf"
    script:
        "script/plot_indel_density.R"

rule plot_bubble_histograms:
    """
    From the collated segment table for one L, write a bubble aggregate stats text
    file plus a histogram and cumulative count/length PDFs under results/fig/bubble_agg/L{L}/.
    """
    output:
        stats = "results/bubble_agg_stats-L{L}.txt",
        hist = "results/fig/bubble_agg/L{L}/bubble_hist-L{L}.pdf",
        cum_count = "results/fig/bubble_agg/L{L}/cum_count-L{L}.pdf",
        cum_len = "results/fig/bubble_agg/L{L}/cum_len-L{L}.pdf"
    input:
        "output/segments/segment_dataframes_L{L}.csv"
    script:
        "script/plot_bubble_histograms.R"

