# Load workflow settings. The code visible in this file reads at least the
# keys 'species', 'n_permutations' and 'variant_sizes' from this config.
configfile: "config.yaml"

import itertools
import random

# Global regex constraints for wildcards used in the file-name patterns below.
# Raw strings are used so that the regex escapes (\d, \w) reach the regex
# engine unchanged instead of relying on Python passing unknown escape
# sequences through, which triggers a SyntaxWarning on recent Python versions.
wildcard_constraints:
    L=r"\d+",
    species_order=r"[a-z]+",
    dir=r"\w+"

# Species identifiers (the keys of config['species']), in config order.
species_list = list(config['species'])

# One random seed per backbone, drawn from a fixed master seed so the sampled
# permutations are reproducible between runs. For up to 8 species this is the
# same call as before (8 seeds); with more species, enough seeds are drawn so
# that the per-backbone lookup below no longer fails with an IndexError.
random.seed(322644)
random_seeds = random.sample(range(10000, 99999), max(8, len(species_list)))

# For every backbone, sample config['n_permutations'] genome orderings that
# start with the backbone and continue with a permutation of the other species.
permutations_by_backbone = {}
for idx, backbone in enumerate(species_list):
    # every species except the backbone itself, keeping config order
    other_genomes = species_list[:idx] + species_list[idx + 1:]
    # NOTE: all orderings of the remaining species are materialised before
    # sampling, so this list grows factorially with the number of species.
    all_permutations = [[backbone, *rest] for rest in itertools.permutations(other_genomes)]
    # re-seed per backbone so each backbone's sample is independent of the others
    random.seed(random_seeds[idx])
    permutations_by_backbone[backbone] = random.sample(all_permutations, config['n_permutations'])

# Replace the first sampled ordering of backbone 'c' with the default ordering
# (all species in config order).
# NOTE(review): this requires a species keyed 'c' and n_permutations >= 1, and
# presumably expects 'c' to be the first species in the config -- confirm.
permutations_by_backbone['c'][0] = list(species_list)

# Default target: every per-run output of both parameter sweeps plus the four
# collated result tables.
rule all:
    input:
        # species-order sweep: every backbone x permutation at L = 50
        expand(
            "output/sweep_species_order/{backbone}_p{perm}_L50.{filetype}",
            backbone=species_list,
            perm=range(1, config['n_permutations'] + 1),
            filetype='gfa,order,stat,bubble,segments,out,bmk'.split(','),
        ),
        # variant-size sweep: backbone 'c', permutation 1, every configured L
        expand(
            "output/sweep_variant_size/c_p1_L{L}.{filetype}",
            L=config['variant_sizes'],
            filetype='gfa,order,stat,bubble,segments,out,bmk'.split(','),
        ),
        # collated tables
        "results/sweep_variant_size.summary",
        "results/sweep_variant_size.benchmark",
        "results/sweep_species_order.summary",
        "results/sweep_species_order.benchmark"

def get_species_order(wildcards):
    """Return the species ordering of a job as a single concatenated string.

    The ordering is taken from ``permutations_by_backbone`` using the
    ``backbone`` wildcard as key and the ``perm`` wildcard as a 1-based index.
    """
    orderings = permutations_by_backbone[wildcards.backbone]
    chosen = orderings[int(wildcards.perm) - 1]
    return ''.join(chosen)

def retrieve_genomes(wildcards):
    """Return the genome FASTA paths of a job, in the job's species ordering.

    The ordering is looked up in ``permutations_by_backbone`` with the
    ``backbone`` wildcard as key and the ``perm`` wildcard as a 1-based index.
    Each species key is mapped through ``config["species"]`` to the file stem.

    The permutation list is iterated directly rather than the characters of
    the joined order string, so species keys longer than one character
    resolve correctly; for single-character keys the result is unchanged.
    """
    ordering = permutations_by_backbone[wildcards.backbone][int(wildcards.perm) - 1]
    # NOTE(review): 'dir/' is a literal path component here, not the {dir}
    # wildcard used by other rules -- confirm this is the intended location.
    return [f'dir/genome_malawi/{config["species"][species]}.renamed.fa' for species in ordering]

# Build a graph with minigraph for one backbone/permutation of the species-order
# sweep. The genomes are passed to minigraph in the order returned by
# retrieve_genomes, and the same ordering is written to the .order file
# (without a trailing newline). -L is set from the {L} wildcard.
#
# The ordering is printed through a fixed '%s' format instead of being used as
# the printf format string itself, and all substituted values are shell-quoted
# with Snakemake's ':q' so unusual characters in names or paths cannot change
# the command.
rule buildgraph_species_order:
    input:
        retrieve_genomes
    log:
        "log/sweep_species_order/{backbone}_p{perm}_L{L}.log"
    params:
        species_order = get_species_order
    output:
        order = "output/sweep_species_order/{backbone}_p{perm}_L{L}.order",
        gfa = "output/sweep_species_order/{backbone}_p{perm}_L{L}.gfa"
    benchmark:
        "benchmark/sweep_species_order/{backbone}_p{perm}_L{L}.bmk"
    threads: 24
    resources: mem_mb=64000
    shell:
        "printf '%s' {params.species_order:q} > {output.order:q}; "
        "(minigraph -xggs -c -t{threads} -L{wildcards.L} {input:q} > {output.gfa:q}) 2> {log:q}"

# Build a graph with minigraph for one run of the variant-size sweep. The
# genomes are passed to minigraph in the order returned by retrieve_genomes,
# and the same ordering is written to the .order file (without a trailing
# newline). -L is set from the {L} wildcard. The benchmark is repeated 3
# times, so the command is run three times per job.
#
# The ordering is printed through a fixed '%s' format instead of being used as
# the printf format string itself, and all substituted values are shell-quoted
# with Snakemake's ':q' so unusual characters in names or paths cannot change
# the command.
rule buildgraph_variant_size:
    input:
        retrieve_genomes
    log:
        "log/sweep_variant_size/{backbone}_p{perm}_L{L}.log"
    params:
        species_order = get_species_order
    output:
        order = "output/sweep_variant_size/{backbone}_p{perm}_L{L}.order",
        gfa = "output/sweep_variant_size/{backbone}_p{perm}_L{L}.gfa"
    benchmark:
        repeat("benchmark/sweep_variant_size/{backbone}_p{perm}_L{L}.bmk", 3)
    threads: 24
    resources: mem_mb=64000
    shell:
        "printf '%s' {params.species_order:q} > {output.order:q}; "
        "(minigraph -xggs -c -t{threads} -L{wildcards.L} {input:q} > {output.gfa:q}) 2> {log:q}"

# Run `gfatools bubble` on a graph. The raw output is marked temp(), so
# Snakemake removes it once the rules consuming it have finished.
rule gfatools_bubble:
    output:
        temp("output/{dir}/{backbone}_p{perm}_L{L}.bubble.tmp")
    input:
        "output/{dir}/{backbone}_p{perm}_L{L}.gfa"
    shell:
        "gfatools bubble {input} > {output}"

# Reduce the raw `gfatools bubble` output to its 12th tab-separated column.
rule get_segments_by_bubble:
    output:
        "output/{dir}/{backbone}_p{perm}_L{L}.bubble"
    input:
        "output/{dir}/{backbone}_p{perm}_L{L}.bubble.tmp"
    shell:
        "cut -f12 {input} > {output}"

# Store the output of `gfatools stat` for a graph.
rule gfatools_stat:
    output:
        "output/{dir}/{backbone}_p{perm}_L{L}.stat"
    input:
        "output/{dir}/{backbone}_p{perm}_L{L}.gfa"
    shell:
        "gfatools stat {input} > {output}"

# Tabulate the S-lines of a graph into three pasted columns: field 2, the
# third ':'-separated part of field 4, and the third ':'-separated part of
# field 7. The graph file is scanned once per column.
# NOTE(review): fields 4 and 7 are presumably typed tags of the form
# TAG:type:value written by minigraph -- confirm which tags they are.
rule get_graph_segments:
    output:
        "output/{dir}/{backbone}_p{perm}_L{L}.segments"
    input:
        "output/{dir}/{backbone}_p{perm}_L{L}.gfa"
    shell:
        "paste <(grep '^S' {input} | cut -f2) "
        "<(grep '^S' {input} | cut -f4 | cut -d: -f3) "
        "<(grep '^S' {input} | cut -f7 | cut -d: -f3) "
        "> {output}"

# Summarise one graph-building run by handing its intermediate files to
# script/parse_graph_stats.py.
#   inputs : the species-order file, the gfatools stat output, the segment
#            table, the bubble column and the Snakemake benchmark file
#   outputs: a per-run statistics file (.out) and a per-run benchmark file
#            (.bmk), both of which are concatenated by the collate_* rules
# The input/output names are the interface to the script, so they must stay
# in sync with it.
# NOTE(review): the script is not visible here; the collate_* rules filter
# lines starting with 'species_order' out of the .bmk files only, so
# presumably the .bmk output carries a header line while .out does not --
# confirm against the script.
rule parse_graph_stats:
    input:
        order    = "output/{dir}/{backbone}_p{perm}_L{L}.order",
        stat     = "output/{dir}/{backbone}_p{perm}_L{L}.stat",
        segments = "output/{dir}/{backbone}_p{perm}_L{L}.segments",
        bubble   = "output/{dir}/{backbone}_p{perm}_L{L}.bubble",
        bmk      = "benchmark/{dir}/{backbone}_p{perm}_L{L}.bmk"
    output:
        out  = "output/{dir}/{backbone}_p{perm}_L{L}.out",
        bmk  = "output/{dir}/{backbone}_p{perm}_L{L}.bmk"
    script:
        "script/parse_graph_stats.py"

# Collate the variant-size sweep (backbone 'c', permutation 1, every
# configured L) into two CSV tables. Each table starts with a header line
# written by printf; the per-run .out files are appended unchanged, while
# lines starting with 'species_order' are dropped from the per-run .bmk files
# before appending.
rule collate_variant_size_runs:
    output:
        stat = "results/sweep_variant_size.summary",
        bmk  = "results/sweep_variant_size.benchmark"
    input:
        stat = expand(
            "output/sweep_variant_size/c_p1_L{L}.out",
            L=config['variant_sizes'],
        ),
        bmk = expand(
            "output/sweep_variant_size/c_p1_L{L}.bmk",
            L=config['variant_sizes'],
        )
    shell:
        # summary table: header, then all per-run statistics
        "printf species_order,min_variant_size,n_segments,n_edges,n_bubbles,"
        "bb_size,graph_size,linear_size,linear_count,variable_size,variable_count,"
        "mean,std,q0,q25,q50,q75,q100'\n' "
        "> {output.stat}; "
        "cat {input.stat} >> {output.stat}; "
        # benchmark table: header, then all per-run benchmark rows
        "printf species_order,min_variant_size,iteration,"
        "cpu_time,wallclock_time,max_pss'\n' "
        "> {output.bmk}; "
        "cat {input.bmk} | grep -v '^species_order' >> {output.bmk}"

# Collate the species-order sweep (every backbone x permutation at L = 50)
# into two CSV tables. Each table starts with a header line written by
# printf; the per-run .out files are appended unchanged, while lines starting
# with 'species_order' are dropped from the per-run .bmk files before
# appending.
rule collate_species_order_runs:
    output:
        stat = "results/sweep_species_order.summary",
        bmk  = "results/sweep_species_order.benchmark"
    input:
        stat = expand(
            "output/sweep_species_order/{backbone}_p{perm}_L50.out",
            backbone=species_list,
            perm=range(1, config['n_permutations'] + 1),
        ),
        bmk = expand(
            "output/sweep_species_order/{backbone}_p{perm}_L50.bmk",
            backbone=species_list,
            perm=range(1, config['n_permutations'] + 1),
        )
    shell:
        # summary table: header, then all per-run statistics
        "printf species_order,min_variant_size,n_segments,n_edges,n_bubbles,"
        "bb_size,graph_size,linear_size,linear_count,variable_size,variable_count,"
        "mean,std,q0,q25,q50,q75,q100'\n' "
        "> {output.stat}; "
        "cat {input.stat} >> {output.stat}; "
        # benchmark table: header, then all per-run benchmark rows
        "printf species_order,min_variant_size,iteration,"
        "cpu_time,wallclock_time,max_pss'\n' "
        "> {output.bmk}; "
        "cat {input.bmk} | grep -v '^species_order' >> {output.bmk}"
