from snakemake import stats
from os.path import join as opj
import os

import tempfile

# Scratch directory passed to Picard (java.io.tmpdir / TMP_DIR) by the
# mark-duplicates rules. $TMPDIR is honoured when it is set and non-empty;
# otherwise fall back to the platform default so the value is never None
# (None would be rendered as the literal string "None" in the shell command).
tmpdir = os.environ.get("TMPDIR") or tempfile.gettempdir()
#if workflow.run_local:
#    workflow.shadow_prefix = os.environ.get("TMPDIR")

onstart:
    # Make sure the directories for out-of-memory / out-of-time job reports
    # exist before any job starts (same effect as `mkdir -p`).
    for exception_dir in (
        "jobs_exceptions/out_of_memory",
        "jobs_exceptions/out_of_time",
    ):
        os.makedirs(exception_dir, exist_ok=True)

# Root directory under which every rule writes its outputs. It is read from
# the "out_dir" key of the workflow config; a missing key raises KeyError.
out_dir = config["out_dir"]

# Default target rule: requests the final per-sample files of both the
# standard branch and the error-corrected ("ecco") branch.
rule all:
    input:
        # All patterns keyed on the {sample} wildcard, expanded over the
        # configured sample list.
        expand(
            [
                opj(out_dir, "results/trimmed/{sample}_R1.trimmed.fastq.gz"),
                opj(out_dir, "results/mapped/{sample}.sorted.bam"),
                opj(out_dir, "results/marked_duplicates/{sample}.marked.bam"),
                opj(out_dir, "results/filtered/{sample}.filtered.sorted.bam"),
                opj(out_dir, "results/filtered/{sample}.filtered.sorted.bam.bai"),
                opj(out_dir, "results/filtered/{sample}.filtered.CRAM"),
                opj(out_dir, "results/filtered_ecco/{sample}.filtered.sorted.bam"),
                opj(out_dir, "results/filtered_ecco/{sample}.filtered.CRAM"),
                opj(out_dir, "results/fragmentation/{sample}_insert_size_metrics.txt"),
                opj(out_dir, "results/fragmentation_ecco/{sample}_insert_size_metrics.txt"),
            ],
            sample=config["samples"],
        ),
        # Read-depth WIG files: one per sample at the configured bin size.
        expand(
            opj(out_dir, "results/readDepth/{samples}.bin{binSize}.wig"),
            samples=config["samples"],
            binSize=str(config["binSize"]),
        ),
        # Insert-size table extracted from samtools stats (ecco branch).
        expand(
            opj(out_dir, "results/stats_ecco/{samples}.filtered.sorted.bam.stats_len.txt"),
            samples=config["samples"],
        ),


# Concatenate the per-lane gzipped FASTQ files of a sample into a single
# temporary R1 file and a single temporary R2 file.
rule concat_fastq_gz:
    input:
        # One file per configured lane, looked up under config['input_path'].
        R1=expand(
            "{input_path}/{{sample}}_{lane}_R1_001.fastq.gz",
            lane=config["lanes"],
            input_path=config['input_path'],
        ),
        R2=expand(
            "{input_path}/{{sample}}_{lane}_R2_001.fastq.gz",
            lane=config["lanes"],
            input_path=config['input_path'],
        ),
    output:
        # temp(): removed by Snakemake once no downstream job needs them.
        fastq_R1=temp(opj(out_dir, "results/tmpfastq/{sample}_R1.fastq.gz")),
        fastq_R2=temp(opj(out_dir, "results/tmpfastq/{sample}_R2.fastq.gz")),
    threads: 1
    resources:
        mem_mb=4000,
        cpus=1,
        runtime='4h',
    benchmark:
        opj(out_dir, "results/smk-benchmark/concat-fastq-gz_{sample}.txt")
    shell:
        """
        cat {input.R1} > '{output.fastq_R1}';
        cat {input.R2} > '{output.fastq_R2}';
        """

# Trim the concatenated paired-end reads with bbduk.sh, using the adapter
# reference in params.adapters.
rule trimming:
    input:
        R1=rules.concat_fastq_gz.output.fastq_R1,
        R2=rules.concat_fastq_gz.output.fastq_R2,
    output:
        R1=opj(out_dir, "results/trimmed/{sample}_R1.trimmed.fastq.gz"),
        R2=opj(out_dir, "results/trimmed/{sample}_R2.trimmed.fastq.gz"),
    log:
        # "stats" receives bbduk's stats= report; "Log" receives its stderr.
        stats=opj(out_dir, "results/logs/{sample}.trimming.stats.txt"),
        Log=opj(out_dir, "results/logs/{sample}.trimming.log.txt"),
    params:
        adapters="data/bbduk_adapters.fa",
    threads: 8
    resources:
        mem_mb=80000,
        cpus=8,
        runtime='4h',
    benchmark:
        opj(out_dir, "results/smk-benchmark/trimming_{sample}.txt")
    conda:
        "envs/env.yaml"
    # NOTE: the Java heap (-Xmx64G) and worker count (t=7) are hard-coded in
    # the command; they are not derived from `resources` / `threads`.
    shell:
        """
        bbduk.sh -Xmx64G -eoom in1={input.R1} in2={input.R2} out1={output.R1} out2={output.R2} ref={params.adapters} ktrim=r k=23 mink=11 hdist=1 stats={log.stats} t=7 2> {log.Log}
        """

# Align the trimmed read pairs to the configured genome with `bwa mem`, then
# sort the alignments with `samtools sort`.
rule mapping:
    input:
        opj(out_dir,"results/trimmed/{sample}_R1.trimmed.fastq.gz"),
        opj(out_dir,"results/trimmed/{sample}_R2.trimmed.fastq.gz")
    output:
        # Unsorted aligner output; temporary, removed once sorted.
        unsort = temp(opj(out_dir,"results/mapped/{sample}.bam")),
        sorted= opj(out_dir,"results/mapped/{sample}.sorted.bam"),

    params:
        genome = config['genome'],
    log:
        opj(out_dir,"results/logs/{sample}.mapping.log.txt")
    threads: 8
    benchmark: opj(out_dir,"results/smk-benchmark/mapping_{sample}.txt")
    resources:
        mem_mb=64000, 
        cpus=8,
        runtime='24h',
    conda: "envs/env.yaml"
    # The second command appends (2>>) to the log. With a plain `2>` the
    # stderr of bwa mem would be truncated as soon as samtools sort starts.
    shell:
        """
        bwa mem -t {threads} {params.genome} {input}  > {output.unsort} 2> {log};
        samtools sort -@ {threads} -o {output.sorted} {output.unsort} 2>> {log};
        """

# Flag (but do not remove) duplicate reads in the sorted BAM with Picard
# MarkDuplicates and write the duplication metrics file.
rule mark_duplicates:
    input:
        opj(out_dir, "results/mapped/{sample}.sorted.bam"),
    output:
        bam=opj(out_dir, "results/marked_duplicates/{sample}.marked.bam"),
        mtrx=opj(out_dir, "results/marked_duplicates/{sample}.marked_metrics.txt"),
    log:
        opj(out_dir, "results/logs/{sample}.mark_duplicates.log.txt")
    params:
        # Scratch directory used for both java.io.tmpdir and Picard TMP_DIR.
        TMP_DIR=tmpdir,
    threads: 1
    resources:
        runtime='24h',
        disk_mb=500000,
        mem_mb=144000,
        # Forwarded to Picard as MAX_FILE_HANDLES.
        max_file_handles=10,
        cpus=1,
    benchmark:
        opj(out_dir, "results/smk-benchmark/markDuplicates_{sample}.txt")
    conda:
        "envs/env.yaml"
    shell:
        """
        picard -Xmx16g -Djava.io.tmpdir={params.TMP_DIR} MarkDuplicates INPUT={input} OUTPUT={output.bam} METRICS_FILE={output.mtrx} REMOVE_DUPLICATES=false MAX_FILE_HANDLES={resources.max_file_handles} CREATE_INDEX=false COMPRESSION_LEVEL=9 TMP_DIR={params.TMP_DIR} &> {log}
        """

# Filter the duplicate-marked BAM, write the filtered reads as both BAM and
# CRAM, sort and index the filtered BAM, and collect `samtools stats`.
#
# samtools view options used below:
#   -h       include the header
#   -q 5     skip alignments with mapping quality below 5 (keeps MAPQ >= 5)
#   -F 4     drop unmapped reads
#   -F 256   drop secondary alignments
#   -F 1024  drop reads flagged as PCR/optical duplicates
#   -F 2048  drop supplementary alignments
#   -b / -C  write BAM / CRAM (CRAM uses the reference given with -T)
rule filter_and_index:
    input:
        opj(out_dir,"results/marked_duplicates/{sample}.marked.bam"),
    output:
        # Unsorted filtered BAM; temporary, only needed as input for the sort.
        bam=temp(opj(out_dir,"results/filtered/{sample}.filtered.bam")),
        sorted_bam = opj(out_dir,"results/filtered/{sample}.filtered.sorted.bam"),
        CRAM=opj(out_dir,"results/filtered/{sample}.filtered.CRAM"),
        bai=opj(out_dir,"results/filtered/{sample}.filtered.sorted.bam.bai"),
        stats=opj(out_dir,"results/stats/{sample}.filtered.sorted.bam.stats"),
    params:
        genome = config['genome'],
        # Passed to `samtools stats -i` (maximum insert size).
        insert_size = 1200
    conda: "envs/env.yaml"
    benchmark: opj(out_dir,"results/smk-benchmark/filterIndexBam_{sample}.txt")
    threads: 8
    resources:
        cpus=8,
        runtime='8h',
        mem_mb=64000,
    log:
        # One stderr log per command in the shell block.
        view=opj(out_dir,"results/logs/{sample}.filter_and_index.view.log.txt"),
        cram=opj(out_dir,"results/logs/{sample}.filter_and_index.cram.log.txt"),
        sort=opj(out_dir,"results/logs/{sample}.filter_and_index.sort.log.txt"),
        index=opj(out_dir,"results/logs/{sample}.filter_and_index.index.log.txt"),
        stats=opj(out_dir,"results/logs/{sample}.filter_and_index.stats.log.txt"),

    # `samtools index` writes the index to a file, not to stdout, so the
    # index path is passed as an explicit argument instead of a redirect.
    shell:
        """
        samtools view {input} -h -q 5 -F 4 -F 256 -F 1024 -F 2048 -@ {threads} -b > {output.bam} 2> {log.view};
        samtools view {input} -h -q 5 -F 4 -F 256 -F 1024 -F 2048 -@ {threads} -T {params.genome} -C > {output.CRAM} 2> {log.cram};
        samtools sort {output.bam} > {output.sorted_bam} 2> {log.sort};
        samtools index -@ {threads} {output.sorted_bam} {output.bai} 2> {log.index};
        samtools stats -@ {threads} -i {params.insert_size} {output.sorted_bam} > {output.stats} 2> {log.stats};
        """

# Run Picard CollectInsertSizeMetrics on the filtered, sorted BAM, producing
# the insert-size metrics table and the histogram PDF.
rule picard_InsertMatrix:
    input:
        bam=rules.filter_and_index.output.sorted_bam,
    output:
        length=opj(out_dir, "results/fragmentation/{sample}_insert_size_metrics.txt"),
        pdf=opj(out_dir, "results/fragmentation/{sample}_insert_size.pdf"),
    log:
        Log=opj(out_dir, "results/logs/{sample}.picardFragment.log.txt"),
    params:
        # Declared but not referenced by the shell command below.
        TMP_DIR=opj(out_dir, "results/tmpdir"),
    threads: 8
    resources:
        mem_mb=8000,
        cpus=8,
        runtime='4h',
    benchmark:
        opj(out_dir, "results/smk-benchmark/picardInsertMatrix_{sample}.txt")
    conda:
        "envs/env.yaml"
    shell:
        """
        picard -Xmx8g CollectInsertSizeMetrics INPUT={input.bam} OUTPUT={output.length} Histogram_FILE={output.pdf} M=0.001 &> {log}
        """

# Run bbmerge-auto.sh in `ecco` mode on the trimmed read pairs, writing the
# resulting R1/R2 FASTQ files used by the "ecco" branch of the workflow.
rule bbmerge_error_correction:
    input:
        R1=opj(out_dir,"results/trimmed/{sample}_R1.trimmed.fastq.gz"),
        R2=opj(out_dir,"results/trimmed/{sample}_R2.trimmed.fastq.gz")
    output:
        R1=opj(out_dir,"results/trimmed_ecco_fastq/{sample}_R1.trimmed.bbmergeecco.fastq.gz"),
        R2=opj(out_dir,"results/trimmed_ecco_fastq/{sample}_R2.trimmed.bbmergeecco.fastq.gz")

    log:
        opj(out_dir,"results/logs/{sample}.bbmerge.log.txt")
    threads: 8
    # Rule-specific benchmark file. This rule previously reused
    # "mapping_{sample}.txt", which is the benchmark path of `rule mapping`,
    # so the two rules overwrote each other's timings.
    benchmark: opj(out_dir,"results/smk-benchmark/bbmerge_{sample}.txt")
    resources:
        mem_mb=64000,
        cpus=8,
        runtime='24h',
    conda: "envs/env.yaml"
    shell:
        """
        bbmerge-auto.sh in={input.R1} in2={input.R2} out={output.R1} out2={output.R2} ecco mix ratiomode 2> {log}
        """

# Align the bbmerge error-corrected read pairs to the configured genome with
# `bwa mem`, then sort the alignments with `samtools sort`.
rule mapping_ecco:
    input:
        R1=opj(out_dir,"results/trimmed_ecco_fastq/{sample}_R1.trimmed.bbmergeecco.fastq.gz"),
        R2=opj(out_dir,"results/trimmed_ecco_fastq/{sample}_R2.trimmed.bbmergeecco.fastq.gz")
    output:
        # Unsorted aligner output; temporary, removed once sorted.
        unsort = temp(opj(out_dir,"results/mapped_ecco/{sample}.bam")),
        sorted= opj(out_dir,"results/mapped_ecco/{sample}.sorted.bam"),

    params:
        genome = config['genome'],
    log:
        opj(out_dir,"results/logs/{sample}.mapping_ecco.log.txt")
    threads: 8
    benchmark: opj(out_dir,"results/smk-benchmark/mapping_ecco_{sample}.txt")
    resources:
        mem_mb=64000,
        cpus=8,
        runtime='24h',
    conda: "envs/env.yaml"
    # The second command appends (2>>) to the log. With a plain `2>` the
    # stderr of bwa mem would be truncated as soon as samtools sort starts.
    shell:
        """
        bwa mem -t {threads} {params.genome} {input}  > {output.unsort} 2> {log};
        samtools sort -@ {threads} -o {output.sorted} {output.unsort} 2>> {log};
        """

# Flag (but do not remove) duplicate reads in the sorted "ecco" BAM with
# Picard MarkDuplicates and write the duplication metrics file.
rule mark_duplicates_ecco:
    input:
        sorted=opj(out_dir, "results/mapped_ecco/{sample}.sorted.bam"),
    output:
        bam=opj(out_dir, "results/marked_duplicates_ecco/{sample}.marked.bam"),
        mtrx=opj(out_dir, "results/marked_duplicates_ecco/{sample}.marked_metrics.txt"),
    log:
        opj(out_dir, "results/logs/{sample}.mark_duplicates_ecco.log.txt")
    params:
        # Scratch directory used for both java.io.tmpdir and Picard TMP_DIR.
        TMP_DIR=tmpdir,
    threads: 1
    resources:
        runtime='24h',
        disk_mb=500000,
        mem_mb=100000,
        # Forwarded to Picard as MAX_FILE_HANDLES.
        max_file_handles=10,
        cpus=1,
    benchmark:
        opj(out_dir, "results/smk-benchmark/markDuplicates_ecco_{sample}.txt")
    conda:
        "envs/env.yaml"
    shell:
        """
        picard -Xmx16g  -Djava.io.tmpdir={params.TMP_DIR} MarkDuplicates INPUT={input.sorted} OUTPUT={output.bam} METRICS_FILE={output.mtrx} REMOVE_DUPLICATES=false MAX_FILE_HANDLES={resources.max_file_handles} CREATE_INDEX=false COMPRESSION_LEVEL=9 TMP_DIR={params.TMP_DIR} &> {log}
        """

# "ecco" branch counterpart of filter_and_index: filter the duplicate-marked
# BAM, write the filtered reads as BAM and CRAM, sort and index the filtered
# BAM, collect `samtools stats`, and extract the insert-size ("IS") rows of
# the stats report into a separate table.
#
# samtools view options used below:
#   -h       include the header
#   -q 5     skip alignments with mapping quality below 5 (keeps MAPQ >= 5)
#   -F 4     drop unmapped reads
#   -F 256   drop secondary alignments
#   -F 1024  drop reads flagged as PCR/optical duplicates
#   -F 2048  drop supplementary alignments
#   -b / -C  write BAM / CRAM (CRAM uses the reference given with -T)
rule filter_and_index_ecco:
    input:
        opj(out_dir,"results/marked_duplicates_ecco/{sample}.marked.bam")
    output:
        # NOTE(review): unlike rule filter_and_index, this unsorted BAM is not
        # wrapped in temp() and is therefore kept on disk - confirm intent.
        bam=opj(out_dir,"results/filtered_ecco/{sample}.filtered.bam"),
        sorted_bam = opj(out_dir,"results/filtered_ecco/{sample}.filtered.sorted.bam"),
        CRAM=opj(out_dir,"results/filtered_ecco/{sample}.filtered.CRAM"),
        bai=opj(out_dir,"results/filtered_ecco/{sample}.filtered.sorted.bam.bai"),
        stats=opj(out_dir,"results/stats_ecco/{sample}.filtered.sorted.bam.stats"),
        FL=opj(out_dir,"results/stats_ecco/{sample}.filtered.sorted.bam.stats_len.txt"),
    params:
        genome = config['genome'],
        # Passed to `samtools stats -i` (maximum insert size).
        insert_size = 1200
    log:
        opj(out_dir,"results/logs/{sample}.ecco.filter_and_index.log.txt")
    benchmark: opj(out_dir,"results/smk-benchmark/filterIndexBam_ecco_{sample}.txt")
    threads: 8
    resources:
        cpus=8,
        runtime='4h',
        mem_mb=8000,
    conda: "envs/env.yaml"
    # stderr of every command goes to the declared log: the first command
    # creates it (2>), the following ones append (2>>). The log was previously
    # declared but never written. `samtools index` gets its output path as an
    # argument because it writes the index to a file, not to stdout.
    shell:
        """
        samtools view {input} -h -q 5 -F 4 -F 256 -F 1024 -F 2048 -@ {threads} -b > {output.bam} 2> {log};
        samtools view {input} -h -q 5 -F 4 -F 256 -F 1024 -F 2048 -@ {threads} -T {params.genome} -C > {output.CRAM} 2>> {log};
        samtools sort {output.bam} > {output.sorted_bam} 2>> {log};
        samtools index -@ {threads} {output.sorted_bam} {output.bai} 2>> {log};
        samtools stats -@ {threads} -i {params.insert_size} {output.sorted_bam} > {output.stats} 2>> {log};
        grep ^IS {output.stats} 2>> {log} | cut -f 2- > {output.FL}
        """

# Run Picard CollectInsertSizeMetrics on the filtered, sorted "ecco" BAM,
# producing the insert-size metrics table and the histogram PDF.
rule picard_InsertMatrix_ecco:
    input:
        bam=rules.filter_and_index_ecco.output.sorted_bam,
    output:
        length=opj(out_dir, "results/fragmentation_ecco/{sample}_insert_size_metrics.txt"),
        pdf=opj(out_dir, "results/fragmentation_ecco/{sample}_insert_size.pdf"),
    log:
        Log=opj(out_dir, "results/logs/{sample}.ecco.picardFragment.log.txt"),
    params:
        # Declared but not referenced by the shell command below.
        TMP_DIR=opj(out_dir, "results/tmpdir"),
    threads: 8
    resources:
        mem_mb=8000,
        cpus=8,
        runtime='4h',
    benchmark:
        opj(out_dir, "results/smk-benchmark/picardInsertMatrix_ecco_{sample}.txt")
    conda:
        "envs/env.yaml"
    shell:
        """
        picard -Xmx8g CollectInsertSizeMetrics INPUT={input.bam} OUTPUT={output.length} Histogram_FILE={output.pdf} M=0.001 &> {log}
        """

# Run the configured readCounter script on the filtered, sorted BAM to write
# a per-sample WIG file at the requested bin size.
rule read_counter:
    input:
        opj(out_dir, "results/filtered/{samples}.filtered.sorted.bam")
    output:
        opj(out_dir, "results/readDepth/{samples}.bin{binSize}.wig")
    log:
        opj(out_dir, "logs/readDepth/{samples}.bin{binSize}.log")
    params:
        # Path of the readCounter executable/script, taken from the config.
        readCounter=config["readCounterScript"],
        # Passed as -w; read from the config, not from the binSize wildcard.
        binSize=config["binSize"],
        # Passed as -q.
        qual="20",
        # Passed as -c.
        chrs=config["chrs"],
    resources:
        mem_mb=4000,
        cpus=1,
        runtime='30m',
    conda:
        "envs/hmmcopy.yaml"
    shell:
        "{params.readCounter} {input} -c {params.chrs} -w {params.binSize} -q {params.qual} > {output} 2> {log}"

# Workflow-level handlers: print a short status message when the workflow
# finishes, depending on whether it succeeded or failed.
onsuccess:
    print("Workflow finished, no error")
onerror:
    print("An error occurred")
