import os
import re

from snakemake.utils import min_version

# Abort at parse time if the running Snakemake is older than 7.8.
min_version("7.8")

# Workflow version string.
version: "0.1.0"

# Pipeline-level options are popped out of `config`, so that (almost) every
# remaining entry is a sample name mapped to its subreads BAM.
scatter_threads = config.pop("scatter_threads", 8)
input_ccs = config.pop("ccs", None)
env = config.pop("env", "envs/env.yml")
ref = config.pop("ref", None)
gmm_model = config.pop("gmm_model", False)
gff = config.pop("gff", False)
process_first_n = config.pop("process_first_n", None)
# output wildcards
output_types = ["m6a", "cpg", "nuc", "msp"]
aligned = ["aligned", "unaligned"]

# "ipd-threads" is a pipeline option that is NOT popped: the ipdSummary rule
# reads it later via config.get(). It therefore has to be excluded here,
# otherwise it would be treated as a sample (and its value checked for a .pbi).
_non_sample_keys = {"ipd-threads"}
subreads_by_sample = {
    sm: subreads for sm, subreads in config.items() if sm not in _non_sample_keys
}
samples = list(subreads_by_sample)

# make sure there is an existing index for the optional CCS input
# (explicit raise instead of assert, so the check survives `python -O`)
if input_ccs is not None and not os.path.exists(f"{input_ccs}.pbi"):
    raise FileNotFoundError(f"pbi for {input_ccs} does not exist")

# make sure all subreads have an index
for subreads in subreads_by_sample.values():
    if not os.path.exists(f"{subreads}.pbi"):
        raise FileNotFoundError(f"pbi for {subreads} does not exist")


# Two chunks per GB of subreads, plus one. The scatter group is shared by all
# samples, so the smallest sample determines the chunk count. Fall back to a
# single chunk when no samples are configured instead of failing in min().
default_n_chunks = min(
    (
        int(2 * os.path.getsize(subreads) / 1024**3) + 1
        for subreads in subreads_by_sample.values()
    ),
    default=1,
)


# Declare the "chunks" scatter group with default_n_chunks items. Per-chunk
# files in the rules below are keyed by the {scatteritem} wildcard.
scattergather:
    chunks=default_n_chunks,


# Helper functions used below (e.g. get_ccs_bam, get_subreads, align_results)
# are not defined in this file; presumably they come from these includes.
include: "rules/common.smk"
include: "rules/ccs.smk"
include: "rules/qc.smk"
include: "rules/ml.smk"


# Number of items in the "chunks" scatter group, counted by expanding a
# placeholder path template. Training rules below use it to name the first
# chunk ("1-of-{n_chunks}").
n_chunks = len(scatter.chunks("temp/{scatteritem}.fake"))


# Restrict wildcards to the known values. Constraints are regular expressions,
# so each alternative is escaped: sample names may contain regex
# metacharacters (e.g. ".") that would otherwise match unintended strings.
wildcard_constraints:
    sm="|".join(map(re.escape, samples)),
    data="|".join(map(re.escape, output_types)),
    aligned="|".join(map(re.escape, aligned)),


# Default target. For every sample this requests: the merged unaligned BAM and
# its .pbi, the QC plots/stats/reports, whatever align_results(samples)
# returns (defined in an included rules file), gzipped BED tracks for both
# aligned and unaligned data, and bigBed tracks for the aligned data.
rule all:
    input:
        expand(
            [
                "results/{sm}/unaligned.fiberseq.bam",
                "results/{sm}/unaligned.fiberseq.bam.pbi",
                "results/{sm}/qc/qc_msp_lengths.pdf",
                "results/{sm}/qc/qc_nuc_lengths.pdf",
                "results/{sm}/qc/qc_number_nucs_per_read.pdf",
                "results/{sm}/qc/qc_readlength_per_nuc.pdf",
                "results/{sm}/qc/qc_readlengths.pdf",
                "results/{sm}/qc/qc_readquality.pdf",
                "results/{sm}/qc/qc_m6a_per_read.pdf",
                "results/{sm}/qc/qc_stats.txt",
                "results/{sm}/qc/qc.html",
                "results/{sm}/qc/overview.html",
            ],
            sm=samples,
        ),
        align_results(samples),
        expand(
            "results/{sm}/{aligned}.{data}.bed.gz",
            aligned=aligned,
            sm=samples,
            data=output_types,
        ),
        expand(
            "results/{sm}/aligned.{data}.bed.bb",
            sm=samples,
            data=output_types,
        ),


# Run actc on one scatter chunk of a sample, given the sample's subreads BAM and
# CCS BAM. Only {output.bam} is named on the command line. The .fasta output is
# presumably written by actc next to the BAM with the same prefix — confirm.
rule actc:
    input:
        ccs=get_ccs_bam,
        # not referenced in the shell command; presumably listed so the CCS
        # index exists before actc runs
        pbi=get_ccs_pbi,
        subreads=get_subreads,
    output:
        bam=temp("temp/{sm}/actc.{scatteritem}.bam"),
        fasta=temp("temp/{sm}/actc.{scatteritem}.fasta"),
    threads: scatter_threads
    conda:
        env
    log:
        "logs/{sm}/actc/{scatteritem}.log",
    benchmark:
        "benchmarks/{sm}/actc/{scatteritem}.tbl"
    priority: 10
    shell:
        """
        actc -j {threads} {input.subreads} {input.ccs} {output.bam} 2> {log}
        """


# Build the PacBio index (.pbi) for one actc chunk BAM with pbindex. The output
# path is derived from the actc BAM path, so it stays next to that BAM.
rule index:
    input:
        bam=rules.actc.output.bam,
    output:
        pbi=temp(f"{rules.actc.output.bam}.pbi"),
    threads: 1
    conda:
        env
    log:
        "logs/{sm}/index/{scatteritem}.log",
    benchmark:
        "benchmarks/{sm}/index/{scatteritem}.tbl"
    priority: 20
    shell:
        """
        pbindex {input.bam} &> {log}
        """


# Build a samtools .fai index for the FASTA produced by the actc rule. The
# ipdSummary rule below uses that FASTA as its reference.
rule index_ccs_fasta:
    input:
        fasta=rules.actc.output.fasta,
    output:
        fai=temp(f"{rules.actc.output.fasta}.fai"),
    threads: 1
    conda:
        env
    log:
        "logs/{sm}/index_ccs_fasta/{scatteritem}.log",
    benchmark:
        "benchmarks/{sm}/index_ccs_fasta/{scatteritem}.tbl"
    priority: 20
    shell:
        """
        samtools faidx {input.fasta} &> {log}
        """


# Call m6A on one chunk with ipdSummary (p-value cutoff 0.001). The input is
# the actc BAM, and the actc FASTA serves as the reference. Results are
# written as CSV and GFF.
rule ipdSummary:
    input:
        ccs_fasta=rules.actc.output.fasta,
        # .fai and .pbi are not named in the command; listed so both indexes
        # are built before ipdSummary runs
        fai=rules.index_ccs_fasta.output.fai,
        actc=rules.actc.output.bam,
        pbi=rules.index.output.pbi,
    output:
        csv=temp("temp/{sm}/ipdSummary.{scatteritem}.csv"),
        gff=temp("temp/{sm}/ipdSummary.{scatteritem}.gff"),
    # worker count comes from the optional "ipd-threads" config key (default 16)
    threads: config.get("ipd-threads", 16)
    conda:
        env
    resources:
        time=120,
        mem_mb=64 * 1024,
        disk_mb=16 * 1024,
    log:
        "logs/{sm}/ipdSummary/{scatteritem}.log",
    benchmark:
        "benchmarks/{sm}/ipdSummary/{scatteritem}.tbl"
    priority: 10
    shell:
        """
        ipdSummary \
            --reference {input.ccs_fasta} \
            --pvalue 0.001 \
            --numWorkers {threads} \
            --quiet --identify m6A \
            --csv {output.csv} \
            --gff {output.gff} \
            {input.actc} &> {log}
        """


# Run primrose on one CCS chunk (reads with at least 3 passes), keeping the
# kinetics tags in the output BAM. The .pbi output is not named on the
# command line; presumably primrose writes it alongside the BAM — confirm.
rule primrose:
    input:
        bam=get_ccs_bam,
        pbi=get_ccs_pbi,
    output:
        bam=temp("temp/{sm}/primrose.{scatteritem}.bam"),
        pbi=temp("temp/{sm}/primrose.{scatteritem}.bam.pbi"),
    threads: scatter_threads
    conda:
        env
    log:
        "logs/{sm}/primrose/{scatteritem}.log",
    benchmark:
        "benchmarks/{sm}/primrose/{scatteritem}.tbl"
    priority: 100
    shell:
        """
        primrose --min-passes 3 -j {threads} \
             --keep-kinetics \
            {input.bam} {output.bam} \
            &> {log}
        """


# Train a per-sample model with push_m6a_to_bam.py --train. Only the FIRST
# scatter chunk (primrose + ipdSummary outputs of "1-of-{n_chunks}") is used
# as training data.
rule train_gmm:
    input:
        # {{sm}} stays a wildcard; n_chunks is interpolated by the f-string
        bam=f"temp/{{sm}}/primrose.1-of-{n_chunks}.bam",
        csv=f"temp/{{sm}}/ipdSummary.1-of-{n_chunks}.csv",
    output:
        model="results/{sm}/gmm_model.pkl",
    conda:
        env
    log:
        "logs/{sm}/train_gmm/train.log",
    params:
        gmm=workflow.source_path("scripts/push_m6a_to_bam.py"),
    benchmark:
        "benchmarks/{sm}/train_gmm/train.tbl"
    resources:
        disk_mb=16 * 1024,
        mem_mb=16 * 1024,
        time=200,
    threads: 4
    priority: 60
    shell:
        """
        python {params.gmm} -v --threads {threads} \
            {input.csv} {input.bam} \
            --train -o {output.model} \
            2> {log}
        """


# Push m6A calls into one primrose chunk BAM using push_m6a_to_bam.py. The
# script takes the chunk's ipdSummary CSV and writes the new BAM to stdout.
# When the gmm_model config option is truthy, a pre-trained model is passed
# with --model and --min-prediction-value 0.999999.
rule gmm:
    input:
        get_gmm_model,  # can be empty if we are training per fiber
        ccs=rules.primrose.output.bam,
        pbi=rules.primrose.output.pbi,
        csv=rules.ipdSummary.output.csv,
    output:
        bam=temp("temp/{sm}/gmm.{scatteritem}.bam"),
    threads: 4
    resources:
        mem_mb=16 * 1024,
    conda:
        env
    log:
        "logs/{sm}/gmm/{scatteritem}.log",
    params:
        gmm=workflow.source_path("scripts/push_m6a_to_bam.py"),
        # Parsed as ("..." + get_gmm_model(wc)) if gmm_model else "":
        # the conditional expression binds looser than "+".
        model=lambda wc: " --min-prediction-value 0.999999 --model "
        + get_gmm_model(wc)
        if gmm_model
        else "",
    benchmark:
        "benchmarks/{sm}/gmm/{scatteritem}.tbl"
    priority: 1000  # Run this as fast as possible so we can delete the csv from ipdSummary.
    shell:
        """
        python {params.gmm} -v {params.model} --threads {threads} {input.csv} {input.ccs} > {output.bam} 2> {log}
        """


# Train the nucleosome-calling model with add_nucleosomes.py. Only the FIRST
# gmm chunk ("1-of-{n_chunks}") is used, and the model is written as a
# temporary JSON file.
rule train_hmm:
    input:
        bam=f"temp/{{sm}}/gmm.1-of-{n_chunks}.bam",
    output:
        model=temp("temp/{sm}/hmm_model.json"),
    conda:
        env
    log:
        "logs/{sm}/train_hmm/train.log",
    params:
        nuc=workflow.source_path("scripts/add_nucleosomes.py"),
    benchmark:
        "benchmarks/{sm}/train_hmm/train.tbl"
    resources:
        disk_mb=16 * 1024,
        mem_mb=16 * 1024,
        time=200,
    threads: 4
    # highest priority in this file, presumably because every nucleosome
    # chunk waits on this model
    priority: 2000
    shell:
        """
        python {params.nuc} --threads {threads} {input.bam} {output.model} 2> {log}
        """


# Add nucleosome calls to one gmm chunk BAM with add_nucleosomes.py. It uses
# the model trained by train_hmm (passed with -m).
rule nucleosome:
    input:
        bam=rules.gmm.output.bam,
        model=rules.train_hmm.output.model,
    output:
        bam=temp("temp/{sm}/nuc.{scatteritem}.bam"),
    conda:
        env
    log:
        "logs/{sm}/nucleosome/{scatteritem}.log",
    params:
        nuc=workflow.source_path("scripts/add_nucleosomes.py"),
    benchmark:
        "benchmarks/{sm}/nucleosome/{scatteritem}.tbl"
    threads: 4
    resources:
        disk_mb=16 * 1024,
    priority: 70
    shell:
        """
        python {params.nuc} -m {input.model} --threads {threads} {input.bam} {output.bam} 2> {log}
        """


# Concatenate a sample's per-chunk nucleosome BAMs (from get_nucleosome_bam)
# into the final unaligned fiberseq BAM, using samtools cat.
rule merge:
    input:
        bam=get_nucleosome_bam,
    output:
        bam="results/{sm}/unaligned.fiberseq.bam",
    conda:
        env
    log:
        "logs/{sm}/merge/samtools.cat.log",
    resources:
        disk_mb=8000,
        time=120,
    threads: 4
    benchmark:
        "benchmarks/{sm}/merge/samtools.cat.tbl"
    priority: 100
    shell:
        """
        samtools cat -@ {threads} -o {output.bam} {input.bam} 2> {log}
        """


# Build the PacBio index (.pbi) for the merged unaligned fiberseq BAM with
# pbindex. The output path is derived from the merge rule's BAM path.
rule index_merge:
    input:
        bam=rules.merge.output.bam,
    output:
        pbi=f"{rules.merge.output.bam}.pbi",
    conda:
        env
    log:
        # log file, not a benchmark table: use .log like every other rule
        "logs/{sm}/index_merge/index.log",
    benchmark:
        "benchmarks/{sm}/index_merge/index.tbl"
    threads: 1
    priority: 100
    shell:
        """
        pbindex {input.bam} &> {log}
        """


# Align the merged unaligned BAM to the configured reference with pbmm2 (CCS
# preset). The output is sorted, and unmapped reads are kept via --unmapped.
# The .bai output is not named on the command line; presumably pbmm2 writes
# it because of --sort — confirm.
rule align:
    input:
        bam=rules.merge.output.bam,
        ref=ref,
    output:
        bam="results/{sm}/aligned.fiberseq.bam",
        bai="results/{sm}/aligned.fiberseq.bam.bai",
    conda:
        env
    log:
        "logs/{sm}/align/align.log",
    resources:
        disk_mb=8000,
        time=240,
        mem_mb=32 * 1024,
    threads: 40
    benchmark:
        "benchmarks/{sm}/align/align.tbl"
    priority: 200
    shell:
        """
        pbmm2 align \
            -j {threads} \
            --preset CCS --sort \
            --sort-memory 2G \
            --log-level INFO \
            --unmapped \
            {input.ref} {input.bam} {output.bam} \
        2> {log}
        """


# Extract CpG, MSP, m6A and nucleosome BED files from a fiberseq BAM (aligned
# or unaligned) with `ft extract`.
rule make_beds:
    input:
        bam="results/{sm}/{aligned}.fiberseq.bam",
    output:
        cpg=temp("temp/{sm}/{aligned}.cpg.bed"),
        msp=temp("temp/{sm}/{aligned}.msp.bed"),
        m6a=temp("temp/{sm}/{aligned}.m6a.bed"),
        nuc=temp("temp/{sm}/{aligned}.nuc.bed"),
    conda:
        env
    log:
        "logs/{sm}/make_beds/{aligned}_extract.bed.log",
    resources:
        disk_mb=8000,
        time=240,
    threads: 8
    benchmark:
        "benchmarks/{sm}/make_beds/{aligned}_extract_bed.tbl"
    params:
        # "-r" is only passed for the aligned BAM; presumably it switches ft to
        # reference coordinates — confirm against the fibertools docs
        aligned=lambda wc: "-r" if wc.aligned == "aligned" else "",
    priority: 300
    shell:
        """
        ft --threads {threads} extract \
            {params.aligned} {input.bam} \
            --cpg {output.cpg} --msp {output.msp} --m6a {output.m6a} --nuc {output.nuc} \
        2> {log}
        """


# Compress one extracted BED file (per sample, aligned state and data type)
# with bgzip.
rule compress_bed:
    input:
        bed="temp/{sm}/{aligned}.{data}.bed",
    output:
        bed="results/{sm}/{aligned}.{data}.bed.gz",
    conda:
        env
    log:
        "logs/{sm}/compress_bed/{aligned}_{data}.log",
    threads: 4
    benchmark:
        "benchmarks/{sm}/compress_bed/{aligned}_{data}.tbl"
    priority: 300
    # bgzip reads the file directly (-c writes to stdout), so read errors on
    # the input land in the log instead of on an uncaptured `cat` stderr.
    shell:
        """
        bgzip -@ {threads} -c {input.bed} > {output.bed} 2> {log}
        """


# Convert one aligned BED file into a bigBed track. The BED is first sorted by
# chromosome and start into a temporary file, which is passed to bedToBigBed
# together with the reference .fai (used as the chrom-sizes file).
rule bigbed:
    input:
        fai=f"{ref}.fai",
        bed="temp/{sm}/aligned.{data}.bed",
    output:
        bb="results/{sm}/aligned.{data}.bed.bb",
        bed=temp("temp/{sm}/aligned.{data}.bed.pre.bb"),
    conda:
        env
    log:
        "logs/{sm}/bigbed/{data}.log",
    benchmark:
        "benchmarks/{sm}/bigbed/{data}.tbl"
    priority: 300
    # bedToBigBed expects chromosomes in bytewise (C locale) order. A
    # locale-aware `sort` can order names differently, and bedToBigBed then
    # rejects the input, so force the C locale.
    shell:
        """
        LC_ALL=C sort -k1,1 -k2,2n {input.bed} > {output.bed} 2> {log}
        bedToBigBed -allow1bpOverlap {output.bed} {input.fai} {output.bb} 2>> {log}
        """
