# Workflow settings are loaded from this file into the global `config` mapping
# that the parameter section below reads.
configfile: "config.yaml"

import functools
import os
import os.path

from scripts.common import detect_reads, dump_dict

#Config parameters
#Input directory holding the sampleN subdirectories (required)
IN = config["IN"]
#SPAdes root; LOCAL_DIR is only consulted (and only required) when SPADES is
#absent, so a config that sets SPADES explicitly no longer needs LOCAL_DIR
SPADES = config["SPADES"] if "SPADES" in config else config["LOCAL_DIR"]
SPADES_REASSEMBLY = config.get("SPADES_REASSEMBLY", SPADES)
BIN = config.get("BIN", os.path.join(SPADES, "build/release/bin"))
SCRIPTS = config.get("SCRIPTS", os.path.join(SPADES, "src/projects/mts/scripts"))
#Directory with the external binaries used below (kmc, cc.bin) (required)
SOFT = config["SOFT"]
K = int(config.get("K", 55))
SMALL_K = int(config.get("SMALL_K", 21))
MIN_CONTIG_LENGTH = int(config.get("MIN_CONTIG_LENGTH", 2000))
#Cast like the other numeric options so that a string value still works
THREADS = int(config.get("THREADS", 16))
#Key into binning_inputs further down: "canopy" or "concoct"
BINNER = config.get("BINNER", "canopy")

#Path to saves of necessary assembly stage (relative to a sample's SPAdes output directory)
SAVES = "K{0}/saves/01_before_repeat_resolution/graph_pack".format(K)

#Autodetect samples and their reads
#Every entry of IN named sample<digits> counts as a sample. The pattern is a raw
#string so that \d reaches the wildcard regex without going through Python's
#string escape processing.
SAMPLE_DIRS = set(glob_wildcards(os.path.join(IN, r"{sample,sample\d+}"))[0])
SAMPLE_COUNT = len(SAMPLE_DIRS)
#The pipeline expects exactly sample1..sampleN; a gap or an oddly numbered
#directory (e.g. sample01) makes one of the expected names go missing.
SAMPLES = ["sample" + str(index) for index in range(1, SAMPLE_COUNT + 1)]
for sample_name in SAMPLES:
    if sample_name not in SAMPLE_DIRS:
        raise WorkflowError("Samples must be consecutive; missing " + sample_name)

#sample name -> reads detected in its directory (indexed 0 = left, 1 = right below)
SAMPLE_READS = {sample: detect_reads(os.path.join(IN, sample)) for sample in SAMPLES}

def sample_reads(dir, wildcards):
    """Return one reads entry of the sample named by ``wildcards.sample``.

    ``dir`` is the index into that sample's SAMPLE_READS entry
    (0 for the left reads, 1 for the right reads).
    """
    reads_of_sample = SAMPLE_READS[wildcards.sample]
    return reads_of_sample[dir]

# Input functions for rules: each takes only the wildcards object.
left_reads = functools.partial(sample_reads, 0)
right_reads = functools.partial(sample_reads, 1)

onstart:
    # Make sure the scratch directory used by the kmc/profile rules exists.
    # An already existing directory is the normal case on reruns; any other
    # failure is reported but deliberately left non-fatal (best effort).
    try:
        os.makedirs("tmp", exist_ok=True)
    except OSError as err:
        print("Warning: could not create directory 'tmp':", err)
    print("Detected", SAMPLE_COUNT, "samples in", IN)

# ---- Main pipeline -----------------------------------------------------------

# Default target: requests the reassembled contigs of every CAG. The {cag}
# values are not known before the workflow runs, hence the dynamic() marker
# (the matching dynamic outputs are declared by rule choose_samples).
rule all:
    input:   dynamic("reassembly/{cag}.fasta")
    message: "Dataset of {SAMPLE_COUNT} samples from {IN} has been processed."

# Assemble the paired reads of one sample with SPAdes in --meta mode into
# assembly/{sample}/ and copy the resulting scaffolds to assembly/{sample}.fasta.
# The output pattern is a raw string so that the \d of the wildcard constraint
# is not treated as a (invalid) Python string escape.
rule assemble:
    input:   left=left_reads, right=right_reads
    output:  r"assembly/{sample,sample\d+}.fasta"
    params:  "assembly/{sample}"
    log:     "assembly/{sample}.log"
    threads: THREADS
    message: "Assembling {wildcards.sample} with SPAdes"
    shell:   "{SPADES}/spades.py --meta -m 400 -t {threads} -1 {input.left} -2 {input.right}"
             " -o {params} >{log} 2>&1 && "
             "cp {params}/scaffolds.fasta {output}"

# Convenience target without outputs: requests the assembly of every detected sample.
rule assemble_all:
    input:   expand("assembly/{sample}.fasta", sample=SAMPLES)
    message: "Assembled all samples"

# Write one description file per sample: a text file naming the left and the
# right reads file, one per line (rule kmc passes it to kmc as "@{input}").
rule descriptions:
    output:  expand("profile/{sample}.desc", sample=SAMPLES)
    message: "Generating sample descriptions"
    run:
        for sample in SAMPLES:
            # Look the reads up directly instead of injecting the sample name
            # into this rule's wildcards object just to reuse the input functions.
            reads = SAMPLE_READS[sample]
            with open("profile/{}.desc".format(sample), "w") as out:
                print(reads[0], file=out)
                print(reads[1], file=out)

# Count SMALL_K-mers of one sample with kmc, producing the database
# tmp/{sample}.kmc_pre / .kmc_suf. params.tmp is kmc's private scratch
# directory; it is removed only after a successful run, so it is created with
# "mkdir -p" to let a rerun after a failed attempt proceed.
rule kmc:
    input:   "profile/{sample}.desc"
    output:  temp("tmp/{sample}.kmc_pre"), temp("tmp/{sample}.kmc_suf")
    params:  min_mult=2, tmp="tmp/{sample}_kmc", out="tmp/{sample}"
    log:     "profile/kmc_{sample}.log"
    threads: THREADS
    message: "Running kmc for {wildcards.sample}"
    shell:   "mkdir -p {params.tmp} && "
             "{SOFT}/kmc -k{SMALL_K} -t{threads} -ci{params.min_mult} -cs65535"
             " @{input} {params.out} {params.tmp} >{log} 2>&1 && "
             "rm -rf {params.tmp}"

# Gather the per-sample kmc databases into a single multiplicity table
# (profile/kmers.mpl). The shell command does not use {input}: it refers to the
# databases through params.kmc_files, the same "tmp/{sample}" prefixes that
# rule kmc writes to; the declared inputs make sure those databases exist.
# Afterwards tmp/*.sorted is removed (presumably intermediates of the counter).
# NOTE(review): the meaning of "--sample 3" is not visible here - confirm
# against kmer_multiplicity_counter's options.
rule multiplicities:
    input:   expand("tmp/{sample}.kmc_pre", sample=SAMPLES), expand("tmp/{sample}.kmc_suf", sample=SAMPLES)
    output:  "profile/kmers.mpl"
    params:  kmc_files=" ".join(expand("tmp/{sample}", sample=SAMPLES)), out="profile/kmers"
    log:     "profile/kmers.log"
    message: "Gathering {SMALL_K}-mer multiplicities from all samples"
    shell:   "{BIN}/kmer_multiplicity_counter -k {SMALL_K} -o {params.out} --sample 3"
             " -f {params.kmc_files} >{log} 2>&1 && "
             "rm tmp/*.sorted"

# Compute contig abundance profiles for one sample from the shared k-mer
# multiplicities, skipping nothing shorter than what -l MIN_CONTIG_LENGTH allows
# (exact semantics of -l are defined by contig_abundance_counter).
# NOTE: the shell command spells its paths out by hand: "-m profile/kmers" must
# stay equal to rule multiplicities' params.out, and the -o prefix / -f file
# must stay in sync with the outputs declared here (presumably -o is the prefix
# of the .id and .mpl files).
rule profile:
    input:   contigs="assembly/{sample}.fasta", mpl="profile/kmers.mpl"
    output:  "profile/{sample}.id", "profile/{sample}.mpl", "assembly/{sample}_splits.fasta"
    log:     "profile/{sample}.log"
    message: "Counting contig abundancies for {wildcards.sample}"
    shell:   "{BIN}/contig_abundance_counter -k {SMALL_K} -w tmp -c {input.contigs}"
             " -n {SAMPLE_COUNT} -m profile/kmers -o profile/{wildcards.sample}"
             " -f assembly/{wildcards.sample}_splits.fasta"
             " -l {MIN_CONTIG_LENGTH} >{log} 2>&1"

#Clustering: prepare the profile input for the selected binner
# Convert the per-sample profiles in profile/ into the input file of the binner
# named by the {binner} wildcard (the rules below request "canopy" and
# "concoct"). {SAMPLES} is a Python list; Snakemake renders it space-separated.
rule binning_pre:
    input:   expand("profile/{sample}.id", sample=SAMPLES)
    output:  "binning/{binner}/profiles.in"
    message: "Preparing input for {wildcards.binner}"
    shell:   "{SCRIPTS}/make_input.py -t {wildcards.binner} -d profile -o {output} {SAMPLES}"

#Default clustering with canopy
# Cluster the contig profiles with canopy (cc.bin). Produces the contig-to-bin
# assignment (out) and the bin profiles (prof) that later rules refer to via
# rules.canopy.output. The log file is declared with the log directive, like in
# the other rules, instead of being hard-coded inside the shell command.
rule canopy:
    input:   "binning/canopy/profiles.in"
    output:  out="binning/canopy/binning.out", prof="binning/canopy/bins.prof"
    log:     "binning/canopy/canopy.log"
    threads: THREADS
    message: "Running canopy clustering"
    shell:   "{SOFT}/cc.bin -n {threads} -i {input} -o {output.out} -c {output.prof} >{log} 2>&1"

#Alternative clustering with CONCOCT
# Merge the split contigs of all samples into one FASTA file, which rule
# concoct uses as its composition file.
rule combine_splits:
    input:   expand("assembly/{sample}_splits.fasta", sample=SAMPLES)
    output:  "assembly/samples_splits.fasta"
    message: "Combine splitted contigs"
    shell:   "{SCRIPTS}/combine_contigs.py -r {input} > {output}"

# Cluster the combined split contigs with CONCOCT inside the "concoct_env"
# environment; results go to the directory given by params.
# The "set +u ... set -u" pair switches the shell's unset-variable check off
# only while the environment is activated (presumably because the activate
# script trips over it - confirm).
rule concoct:
    input:   contigs=rules.combine_splits.output[0], profiles="binning/concoct/profiles.in"
    output:  out="binning/concoct/clustering_gt1000.csv"
    params:  "binning/concoct"
    message: "Running CONCOCT clustering"
    shell:   "mkdir -p {params} && "
             "set +u; source activate concoct_env; set -u && "
             "concoct --composition_file {input.contigs} --coverage_file {input.profiles} -b {params}"

# Clustering result file of each supported binner, keyed by the BINNER setting.
binning_inputs = {"canopy": rules.canopy.output.out, "concoct": rules.concoct.output.out}
# Fail with a readable message instead of a bare KeyError on a mistyped BINNER.
if BINNER not in binning_inputs:
    raise WorkflowError("Unknown BINNER '{}'; expected one of: {}".format(
        BINNER, ", ".join(sorted(binning_inputs))))

# Turn the selected binner's output into one raw annotation file per sample
# (annotation/{sample}.ann).
rule binning_post:
    input:   binning_inputs[BINNER]
    output:  expand("annotation/{sample}.ann", sample=SAMPLES)
    message: "Preparing raw annotations"
    shell:   "{SCRIPTS}/parse_output.py -t {BINNER} -o annotation {input}"

#Post-clustering pipeline
# Propagate the contig annotation of one sample and bin its reads, using the
# graph saved by that sample's SPAdes run (params.saves).
# params.splits is produced by rule profile but is not declared as an input
# here; it is presumably guaranteed to exist because the annotation input
# depends on the profiling chain - confirm.
# prop_binning is given "-o binning" as output location and
# propagation/{sample}_edges as the -d prefix of the declared output.
rule read_binning:
    input:   contigs="assembly/{sample}.fasta", ann="annotation/{sample}.ann",
             left=left_reads, right=right_reads
    output:  "propagation/{sample}_edges.ann"
    params:  saves=os.path.join("assembly/{sample}/", SAVES),
             splits="assembly/{sample}_splits.fasta",
             out="propagation/{sample}_edges"
    log:     "binning/{sample}.log"
    message: "Propagating annotation & binning reads for {wildcards.sample}"
    shell:   "{BIN}/prop_binning -k {K} -s {params.saves} -c {input.contigs}"
             " -a {input.ann} -f {params.splits} -l {input.left} -r {input.right}"
             " -o binning -n {wildcards.sample} -d {params.out} >{log} 2>&1"

#TODO: bin profiles for CONCOCT
# Produce the left/right reads of every CAG under binning/{cag}/. The set of
# CAGs is only known after the script has run, hence the dynamic() outputs.
# The `binned` inputs are not passed to the script; they make this rule wait
# until read binning has finished for all samples.
# `prof` is always the canopy bin profile, even when BINNER is "concoct".
rule choose_samples:
    input:   binned=expand("propagation/{sample}_edges.ann", sample=SAMPLES),
             prof=rules.canopy.output.prof
    output:  dynamic("binning/{cag}/left.fastq"),
             dynamic("binning/{cag}/right.fastq")
    log:     "binning/choose_samples.log"
    message: "Choosing samples for all CAGs"
    shell:   "{SCRIPTS}/choose_samples.py {input.prof} binning/ >{log} 2>&1"

# Write the per-CAG settings file that rule reassemble hands to SPAdes through
# --series-analysis.
rule reassembly_config:
    input:   "binning/{cag}/left.fastq"
    output:  "reassembly/{cag}.yaml"
    message: "Generated config file for reassembly of {wildcards.cag}"
    run:
        # Collect the settings first, then serialise them in one go.
        settings = dict()
        settings["k"] = SMALL_K
        settings["sample_cnt"] = SAMPLE_COUNT
        settings["kmer_mult"] = str(rules.multiplicities.params.out)
        settings["bin"] = wildcards.cag
        settings["bin_prof"] = str(rules.canopy.output.prof)
        with open(output[0], "w") as config_out:
            dump_dict(settings, config_out)

# Reassemble the reads chosen for one CAG with SPAdes (SPADES_REASSEMBLY build)
# in --meta mode, passing the generated per-CAG config via --series-analysis,
# and copy the resulting scaffolds to reassembly/{cag}.fasta.
rule reassemble:
    input:   left="binning/{cag}/left.fastq", right="binning/{cag}/right.fastq",
             config="reassembly/{cag}.yaml"
    output:  "reassembly/{cag}.fasta"
    params:  "reassembly/reassembly_{cag}"
    log:     "reassembly/reassembly_{cag}.log"
    threads: THREADS
    message: "Reassembling reads for {wildcards.cag}"
    shell:   "{SPADES_REASSEMBLY}/spades.py --meta -t {threads}"
             " --pe1-1 {input.left} --pe1-2 {input.right} --pe1-ff"
             " -o {params} --series-analysis {input.config} >{log} 2>&1 && "
             "cp {params}/scaffolds.fasta {output}"
