# Default configuration file. Keys read below: IN, BIN, SCRIPTS, SOFT and,
# optionally, QUAST, MIN_CONTIG_LENGTH, THREADS and REFS.
configfile: "config.yaml"

import functools
import os
import os.path

from scripts.common import detect_reads, gather_refs, dump_dict

#Config parameters
IN = config["IN"]            # directory holding the sample1..sampleN inputs
BIN = config["BIN"]          # directory holding the compiled `stats` tool
SCRIPTS = config["SCRIPTS"]  # directory holding the helper scripts
SOFT = config["SOFT"]
# Use the QUAST installation from config["QUAST"] when it is set (and non-empty);
# otherwise rely on `quast`/`metaquast` being available on PATH.
QUAST_DIR = config.get("QUAST")
if QUAST_DIR:
    QUAST = os.path.join(QUAST_DIR, "quast.py")
    METAQUAST = os.path.join(QUAST_DIR, "metaquast.py")
else:
    QUAST = "quast"
    METAQUAST = "metaquast"
MIN_CONTIG_LENGTH = int(config.get("MIN_CONTIG_LENGTH", 2000))
# `threads:` directives need an integer; coerce like MIN_CONTIG_LENGTH above.
THREADS = int(config.get("THREADS", 16))

#Autodetect samples and their reads
# Sample directories must be named sample1..sampleN with no gaps; the wildcard
# regex is a raw string so `\d` is not treated as a (invalid) string escape.
SAMPLE_DIRS = set(glob_wildcards(IN + r"/{sample,sample\d+}")[0])
SAMPLE_COUNT = len(SAMPLE_DIRS)
SAMPLES = list()
for i in range(1, SAMPLE_COUNT + 1):
    sample_name = "sample" + str(i)
    if sample_name not in SAMPLE_DIRS:
        raise WorkflowError("Samples must be consecutive; missing " + sample_name)
    SAMPLES.append(sample_name)

#Autodetect bins
# A CAG is detected when binning/CAG<N>/left.fastq exists; raw string keeps `\d` literal.
CAGS, = glob_wildcards(r"binning/{cag,CAG\d+}/left.fastq")
CAGS.sort()

# Mapping of reference name -> path, built by gather_refs from config["REFS"].
REFS = dict(gather_refs(config.get("REFS", [])))
# All reference paths joined with commas, as passed to metaQUAST's -R option.
ALL_REFS = ",".join(REFS.values())

def ref_path(wildcards):
    """Input function: resolve the ``ref`` wildcard to its path stored in REFS.

    Raises KeyError when the wildcard names an unknown reference.
    """
    ref_name = wildcards.ref
    return REFS[ref_name]

onstart:
    # Ensure a local tmp/ directory exists. Failure is reported but does not
    # abort the run (the original swallowed every error silently).
    try:
        os.makedirs("tmp", exist_ok=True)
    except OSError as err:
        print("Warning: could not create tmp directory:", err)
    print("Detected", SAMPLE_COUNT, "samples in", IN)
    if CAGS:
        print("Detected good (abundant) CAGs:", " ".join(CAGS))
    if REFS:
        print("Detected references:", " ".join(REFS))

#===============================================================================
#---- Statistics section -------------------------------------------------------
#===============================================================================

#---- Single alignments for samples per reference -------------------------------
#TODO: use alignments from meta version instead
rule quast_all_samples:
    message: "Aligning all samples on {wildcards.ref}"
    input:   ref=ref_path, contigs=expand("assembly/{sample}.fasta", sample=SAMPLES)
    output:  "stats/summary/q_{ref}.tsv"
    log:     "stats/{ref}/quast.log"
    params:  outdir="stats/{ref}"
    # Console output is discarded. The declared log sits inside QUAST's -o
    # directory (presumably QUAST's own quast.log -- confirm before redirecting
    # stdout there). Only report.tsv is kept as the summary table.
    shell:   "{QUAST} -R {input.ref} {input.contigs} -o {params.outdir} >/dev/null 2>&1 && "
             "cp {params.outdir}/report.tsv {output}"

#---- Contigs of interest ------------------------------------------------------
rule filter_all_samples:
    message: "Filtering interesting long contigs from all samples for {wildcards.ref}"
    input:   "stats/summary/q_{ref}.tsv"
    output:  expand("stats/{{ref}}/{sample}.cont", sample=SAMPLES)
    params:  qdir="stats/{ref}/"
    # The script is only given the QUAST directory; it must write the
    # stats/{ref}/<sample>.cont outputs itself. The meaning of the trailing
    # 70 is defined by filter_nucmer.py (presumably a percentage threshold -- confirm).
    shell:   "{SCRIPTS}/filter_nucmer.py {params.qdir} {MIN_CONTIG_LENGTH} 70"

#---- GF of combined sample ----------------------------------------------------
rule combine_filtered:
    message: "Gathering all interesting contigs for {wildcards.ref} into a single assembly"
    output:  "stats/{ref}.fasta"
    input:   contigs=expand("assembly/{sample}.fasta", sample=SAMPLES),
             filters=expand("stats/{{ref}}/{sample}.cont", sample=SAMPLES)
    # Arguments: sample count, output FASTA, every sample's contigs, then every sample's filter list.
    shell:   "{SCRIPTS}/filter_contigs.py {SAMPLE_COUNT} {output} {input.contigs} {input.filters}"

rule quast_combined:
    message: "Aligning combined sample on {wildcards.ref}"
    threads: THREADS
    input:   ref=ref_path, contigs="stats/{ref}.fasta"
    output:  "stats/q_{ref}_all/report.tsv"
    log:     "stats/q_{ref}_all.log"
    params:  outdir="stats/q_{ref}_all"
    # QUAST writes report.tsv directly into the declared output directory.
    shell:   "{QUAST} -t {threads} -R {input.ref} {input.contigs} -o {params.outdir} >{log} 2>&1"

# Run this
rule quast_combined_all:
    # Aggregate target: one QUAST report per configured reference.
    message: "Calculated QUAST metrics on all combined samples"
    input:   expand("stats/q_{ref}_all/report.tsv", ref=REFS)


#---- Bins of interest ---------------------------------------------------------
rule int_bins:
    message: "Filtering interesting bins for {wildcards.sample} aligned to {wildcards.ref}"
    input:   ann="annotation/{sample}.ann", cont="stats/{ref}/{sample}.cont"
    output:  "stats/{ref}/{sample}.bin"
    # filter_bins.py receives the annotation then the contig list and prints the result to stdout.
    shell:   "{SCRIPTS}/filter_bins.py {input.ann} {input.cont} > {output}"

rule int_bins_all_samples:
    input:   expand("stats/{{ref}}/{sample}.bin", sample=SAMPLES)
    output:  "stats/{ref}/total.bin"
    message: "Gathering interesting bins for {wildcards.ref} from all samples"
    run:
        # Union of the bin names listed in every sample's .bin file.
        # Line endings are stripped before de-duplication: previously each
        # entry kept its "\n" and print() added another, producing blank
        # lines (and "X\n" vs "X" on a final line were treated as distinct).
        bins = set()
        for in_fn in input:
            with open(in_fn) as infile:
                for line in infile:
                    name = line.rstrip("\r\n")
                    if name:
                        bins.add(name)
        with open(output[0], "w") as outfile:
            # Sorted so the output is reproducible between runs.
            for name in sorted(bins):
                print(name, file=outfile)

# Run this
rule int_bins_all:
    # Aggregate target: the merged bin list for each configured reference.
    message: "Gathered all interesting bins"
    input:   expand("stats/{ref}/total.bin", ref=REFS)

#---- GF per bin per reference -------------------------------------------------
#Helper formatters for determining input files from different stages
PROP = {
    # stage -> (contigs FASTA template, annotation template); "{}" is filled with the sample name
    "prelim": ("assembly/{}_splits.fasta", "annotation/{}.ann"),
    "prop": ("propagation/{}_edges.fasta", "propagation/{}_edges.ann"),
}

#TODO: split into different directories per sample
rule split_bins:
    message: "Splitting assembly of {wildcards.sample} between {wildcards.prop} bins"
    input:   contigs=lambda w: PROP[w.prop][0].format(w.sample),
             annotation=lambda w: PROP[w.prop][1].format(w.sample)
    output:  touch("binning/{prop}/{sample}.log")
    log:     "binning/{prop}/split_{sample}.log"
    params:  outdir="binning/{prop}"
    # The output .log is only a touch()-ed completion marker; the script's stdout goes to the log.
    shell:   "{SCRIPTS}/split_bins.py {input.contigs} {input.annotation} {params.outdir} >{log}"

rule cat_binned_contigs:
    input:   expand("binning/{{prop}}/{sample}.log", sample=SAMPLES)
    output:  r"binning/{prop}/{cag,CAG\d+}.fasta"
    # Plain shell glob over every sample's piece of this CAG. With the former
    # `ls` substitution an unmatched pattern left cat with no arguments, so it
    # read stdin; an unmatched glob now makes cat fail loudly instead.
    params:  "binning/{prop}/*-{cag}.fasta"
    message: "Combine binned contigs ({wildcards.prop}) for {wildcards.cag}"
    shell:   "cat {params} > {output}"

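#Two helpers for QUAST targets: stats_input gives the dependencies, stats_data the files handed to metaQUAST.
#Reassemblies need only their FASTA files; split contigs need split_bins to have run for every sample.
#Combined (per-CAG) contigs need their split pieces glued together first (cat_binned_contigs).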
def stats_input(wildcards):
    """Input function for quast_stats: files required before running metaQUAST on a stage.

    ``wildcards.stage`` is either "reassembly" or "<kind>_<prop>", where kind is
    "split" (wait for the split_bins markers of every sample) or "bin" (wait for
    the combined per-CAG FASTA files), and prop is e.g. "prelim" or "prop".

    Raises ValueError for any other stage name (previously "a_b_c" crashed on
    unpacking and an unknown kind silently returned None).
    """
    if wildcards.stage == "reassembly":
        return expand("reassembly/{cag}.fasta", cag=CAGS)
    # partition() always yields exactly (kind, sep, rest), so unpacking cannot fail.
    w_bin, _, w_prop = wildcards.stage.partition("_")
    if w_bin == "split":
        return expand("binning/{prop}/{sample}.log", prop=w_prop, sample=SAMPLES)
    if w_bin == "bin":
        return expand("binning/{prop}/{cag}.fasta", prop=w_prop, cag=CAGS)
    raise ValueError("Unknown statistics stage: " + wildcards.stage)

def stats_data(wildcards):
    """Return the shell fragment that lists the FASTA files metaQUAST evaluates.

    The result is a backtick ``ls`` command expanded by the shell when the job
    runs. ``wildcards.stage`` is "reassembly" or "<kind>_<prop>" with kind
    "bin" (combined CAG FASTAs) or "split" (per-sample split pieces).

    Raises ValueError for an unknown stage (previously an opaque KeyError, or a
    ValueError from unpacking when prop itself contained "_").
    """
    if wildcards.stage == "reassembly":
        return "`ls reassembly/CAG*.fasta`"
    w_bin, _, w_prop = wildcards.stage.partition("_")
    masks = {"bin": "CAG*", "split": "*-CAG*"}
    if w_bin not in masks:
        raise ValueError("Unknown statistics stage: " + wildcards.stage)
    return "`ls binning/{}/{}.fasta`".format(w_prop, masks[w_bin])

rule quast_stats:
    message: "Aligning {wildcards.stage} assemblies on all references"
    threads: THREADS
    input:   stats_input
    output:  "stats/summary/gf_{stage}.tsv"
    log:     "stats/q_{stage}.log"
    params:  data=stats_data, out="stats/q_{stage}"
    # Runs metaQUAST against every reference at once; only the
    # Genome_fraction_(%) summary table is copied out as the rule's output.
    shell:   "{METAQUAST} -t {threads} -R {ALL_REFS} {params.data} -o {params.out} >{log} 2>&1 && "
             "cp '{params.out}/summary/TSV/Genome_fraction_(%).tsv' {output}"

# Run this AFTER 'all'
rule stats_all:
    # Genome-fraction tables for the combined bins of both stages.
    message: "Gathered some numbers, deal with them."
    input:   expand("stats/summary/gf_bin_{prop}.tsv", prop=["prelim", "prop"])

#---- Reassembly statistics ----------------------------------------------------

# Run this AFTER 'reassembly_all'
rule stats_reassembly:
    message: "Gathered bins stats"
    input:   "stats/summary/gf_reassembly.tsv"
    output:  "stats/summary/reassembly.tsv"
    params:  qdir="stats/q_reassembly"
    # gather_stats.py reads quast_stats' metaQUAST directory and writes its table to stdout.
    shell:   "{SCRIPTS}/gather_stats.py {params.qdir} > {output}"

#---- Propagator statistics ----------------------------------------------------
rule prop_stats:
    # Per-sample propagation statistics for one CAG, computed by {BIN}/stats.
    # NOTE(review): K and SAVES are not defined in this file; unless they come from
    # elsewhere (e.g. an included Snakefile), formatting the shell command fails -- confirm.
    # NOTE(review): {input.ref} expands to all REFS paths separated by spaces;
    # confirm the stats tool accepts several values after -r.
    # NOTE(review): the edge inputs here (assembly/..._edges.fasta, annotation/..._edges.ann)
    # differ from PROP["prop"] (propagation/...); confirm which location is current.
    # Only stdout is redirected to the log; stderr goes to the console.
    input:   prelim="annotation/{sample}.ann", prop="annotation/{sample}_edges.ann",
             contigs="assembly/{sample}.fasta", edges="assembly/{sample}_edges.fasta",
             ref=REFS.values() #, bins="{sample}/{ref}.bin"
    output:  "stats/prop_{cag}/{sample}.tsv"
    log:     "stats/prop_{cag}/{sample}.log"
    message: "Calculating propagation statistics for {wildcards.sample}"
    shell:   "{BIN}/stats -k {K} -s {wildcards.sample}/assembly/{SAVES} -r {input.ref}"
             " -c {input.contigs} -a {input.prelim} -e {input.edges} -p {input.prop}"
             " -b {wildcards.cag} -o {output} >{log}"

# Run this
rule prop_stats_all:
    # Aggregate target: propagation statistics for every (CAG, sample) pair.
    message: "Calculated propagation statistics"
    input:   expand("stats/prop_{cag}/{sample}.tsv", cag=CAGS, sample=SAMPLES)

#---- PCA ----------------------------------------------------------------------
rule pca:
    message: "Doing some visualization"
    output:  "stats/{sample}.png"
    input:   "profile/canopy.in", "profile/canopy.out", "{sample}.cont"
    # pca.R takes the canopy input, canopy output and contig list, then the PNG path.
    shell:   "Rscript {SCRIPTS}/pca.R {input} {output}"

rule combine_contigs:
    input:   expand("stats/{ref}/{sample}.cont", sample=SAMPLES, ref=REFS)
    output:  "stats/total.cont"
    message: "Combining interesting contigs for all samples/references"
    run:
        # Copy every line of each stats/<ref>/<sample>.cont (sample-major order)
        # and append the reference name as an extra tab-separated column.
        # Done in Python instead of one shell+awk process per file, so reference
        # names are never interpolated into a shell/awk program unquoted.
        with open(output[0], "w") as outfile:
            for sample in SAMPLES:
                for ref in REFS:
                    cont_fn = "stats/{}/{}.cont".format(ref, sample)
                    with open(cont_fn) as infile:
                        for line in infile:
                            outfile.write(line.rstrip("\n") + "\t" + ref + "\n")

# Run this
rule pca_total:
    output:  "stats/summary/pca.png"
    # Same plotting script as `pca`, fed the contigs of all samples/references.
    input:   "profile/canopy.in", "profile/canopy.out", "stats/total.cont"
    shell:   "Rscript {SCRIPTS}/pca.R {input} {output}"
