#Snakefile 

import glob
import os

# Configuration ----------------------------------------------------------------

# Run options: switches read from the config that select which optional
# demultiplexing targets are built and which inputs the QC report waits for.
run_freemuxlet = config["run_freemuxlet"]
run_demuxlet = config["run_demuxlet"]
run_demuxlet_all = config["run_demuxlet_all"]
run_vireo = config["run_vireo"]

# Directory holding the helper scripts called by the rules
# (popscle_helper_tools/filter_vcf_file_for_popscle.sh and
# samtools_filter_barcodes.sh are looked up underneath it).
ken_scripts = config["ken_scripts"]

# Individuals: `humans` is passed as the first argument of
# subset_samples_from_vcf; `nind` is passed to freemuxlet (--nsample)
# and vireo (-N).
humans = config["humans"]
nind = config["nind"]

# Genome index: used as the cellranger --transcriptome argument; velocyto
# reads "<index>/genes/genes.gtf" from it.
HumanGenomeIndex = config["human_genome_index"]

# cellranger executable
cellranger = config["cellranger"]

# VCF inputs for demuxlet: `human_vcf` is subset by subset_vcf_for_popscle,
# `all_vcf` is passed unfiltered to the "all" pileup/demuxlet rules.
human_vcf = config["human_vcf"]
all_vcf = config["all_vcf"]

# vireo input: VCF that subset_vcf_for_vireo subsets and filters.
reference_snps = config["reference_snps"]

# Project directory; it is string-concatenated with "{sample}_human/" in the
# report and velocyto rules.
# NOTE(review): presumably must end with "/" -- confirm against the config.
pd = config["proj_dir"]
# FASTQ location: `fastq_dir` is the prefix globbed for sample discovery;
# `fastq_dirs` (fastq_dir + add_fastq_dir) is what cellranger --fastqs gets.
# NOTE(review): both values are joined by plain string concatenation, so
# add_fastq_dir presumably carries its own separator -- confirm.
fastq_dir = config["fastq_dir"]
fastq_dirs = fastq_dir + config["add_fastq_dir"]

#make sure the project directory actually exists
#assert os.path.exists(pd), "Project directory exists"

# Directory to send log files. It is created here because it is not a file
# produced by a Snakemake rule.
dir_log = config["dir_log"]
# makedirs(exist_ok=True) also creates missing parent directories and does
# not fail if another process creates the directory between check and create.
os.makedirs(dir_log, exist_ok=True)

# Directory receiving the per-rule benchmark files.
os.makedirs("benchmarks", exist_ok=True)
    
# Sample names are discovered from the R1 FASTQ files found under fastq_dir.
samples = set(glob_wildcards(fastq_dir + "{samples}_S{order}_L{lane}_R1_{chunk}.fastq.gz").samples)

# Optional demultiplexing targets, one entry per enabled method.
# Each run_* switch must select the output of its own tool: demuxlet writes
# demuxlet.best.gz, freemuxlet writes freemuxlet.clust1.samples.gz.
# The explicit `== True` comparison is kept so that a non-boolean config value
# (for example the string "False") does not switch a target on.
myoutput = list()

if run_demuxlet == True:
    myoutput.append(ancient(expand("{sample}_human/demuxlet.best.gz", sample = samples)))
if run_freemuxlet == True:
    myoutput.append(ancient(expand("{sample}_human/freemuxlet.clust1.samples.gz", sample = samples)))
if run_demuxlet_all == True:
    myoutput.append(ancient(expand("{sample}_human/demuxlet.all.best.gz", sample = samples)))
if run_vireo == True:
    myoutput.append(ancient(expand("{sample}_human/cellsnp/cellSNP.base.vcf.gz", sample = samples)))
    myoutput.append(ancient(expand("{sample}_human/vireo/donor_ids.tsv", sample = samples)))
    
rule all:
    """
    Default target: for every discovered sample, request the Cell Ranger
    matrix and BAM, the index of the barcode-filtered BAM, the Seurat v4 QC
    report and the velocyto loom file.
    """
    input:
        # ancient() tells Snakemake to ignore the timestamps of these files.
        ancient(expand("{sample}_human/outs/filtered_feature_bc_matrix.h5", sample=samples)),
        ancient(expand("{sample}_human/outs/possorted_genome_bam.bam", sample=samples)),
        ancient(expand("{sample}_human/possorted_genome_bam_filtered.bam.bai", sample=samples)),
        # Targets that are currently switched off:
        #   ancient(expand("{sample}_human/saturation.tsv", sample=samples)),
        #   myoutput,
        #   expand("{sample}_human/ebQTL_QC_v3.nb.html", sample=samples),
        expand("{sample}_human/ebQTL_QC_v4.nb.html", sample=samples),
        expand("{sample}_human/velocyto/{sample}_human.loom", sample=samples)
        
    
#############################################################################

# Alignment

#############################################################################

rule align_human:
    """
    Run `cellranger count` for one sample against the human genome index,
    counting intronic reads as well (--include-introns true).

    Before the run, the `{sample}_human` directory is removed when it holds
    no `_invocation` file. After the run the three declared outputs are
    touched.
    NOTE(review): the removal presumably clears a directory that is not a
    Cell Ranger pipestance -- confirm.
    """
    output:
        matrix="{sample}_human/outs/filtered_feature_bc_matrix.h5",
        barcodes="{sample}_human/outs/filtered_feature_bc_matrix/barcodes.tsv.gz",
        bam="{sample}_human/outs/possorted_genome_bam.bam"
    log:
        err=dir_log + "{sample}_rule_align_human.err",
        out=dir_log + "{sample}_rule_align_human.out"
    benchmark:
        "benchmarks/cellranger.{sample}.txt"
    params:
        fastq=fastq_dirs,
        id="{sample}_human",
        pipestance="{sample}_human/_invocation",
        sample="{sample}",
        genomeidx=HumanGenomeIndex
    shell:
        """
	[ ! -f {params.pipestance} ] && rm -r {params.id}
        {cellranger} count --id={params.id} \
                 --transcriptome={params.genomeidx} \
                 --fastqs={params.fastq} \
                 --sample={params.sample} \
                 --include-introns true \
                 --localcores=24 > {log.out} 2> {log.err}
        touch {output.matrix}
        touch {output.bam}
        touch {output.barcodes}
        """


#############################################################################

# Demuxlet/freemuxlet only options

#############################################################################

# First subset the VCFs and filter the BAMs so that the downstream steps
# are computationally cheaper.

# subset vcfs
rule subset_vcf_for_popscle:
    """
    Subset `human_vcf` to the configured individuals, drop the variants
    removed by the three helper filters, sort the result and write BED
    interval files for the filtered VCF and for the full input VCF.

    The helper functions come from filter_vcf_file_for_popscle.sh under
    `ken_scripts`. Only stderr of the last filter in the pipe reaches the log.
    """
    output:
        sorted = "human.popscle.vcf",
        unsorted = "human.popscle.unsorted.vcf",
        bed = "human.popscle.vcf.bed",
        full = "all.popscle.vcf.bed"
    log:
        dir_log + "subset_vcf_human.err"
    params:
        # Must be "<output.sorted>.gz": gunzip strips the ".gz" suffix, so any
        # other name would leave the declared sorted VCF missing.
        compressed = "human.popscle.vcf.gz"
    benchmark:
        "benchmarks/filter_vcf_demuxlet.txt"
    shell:
        """
        source {ken_scripts}/popscle_helper_tools/filter_vcf_file_for_popscle.sh
        subset_samples_from_vcf {humans} {human_vcf} \
        | filter_out_mutations_missing_genotype_for_one_or_more_samples \
        | filter_out_mutations_homozygous_reference_in_all_samples \
        | filter_out_mutations_homozygous_in_all_samples \
        > {output.unsorted} 2> {log}
        bcftools sort -Oz {output.unsorted} -o {params.compressed}
        gunzip {params.compressed}
        bedtools merge -i {output.sorted} > {output.bed}
        bedtools merge -i {human_vcf} > {output.full}
        """

# filter the bam file to run popscle faster
# filter the bam file to run popscle faster
rule human_filtered_bam:
    """
    Decompress the Cell Ranger barcode list and hand it, together with the
    Cell Ranger BAM, to samtools_filter_barcodes.sh, which writes the
    filtered BAM.

    NOTE(review): the BAM index is a declared output but is not created by
    any command visible here; presumably the helper script writes it --
    confirm.
    """
    input:
        inbam = ancient("{sample}_human/outs/possorted_genome_bam.bam"),
        barcodes = ancient("{sample}_human/outs/filtered_feature_bc_matrix/barcodes.tsv.gz")
    output:
        bam = "{sample}_human/possorted_genome_bam_filtered.bam",
        bai = "{sample}_human/possorted_genome_bam_filtered.bam.bai"
    params:
        barcodes = "{sample}_human/barcodes.txt"
    log:
        dir_log + "{sample}_human_filtered_bam.err"
    benchmark:
        "benchmarks/filter_bam.{sample}.txt"
    shell:
        # stderr of the filter script goes to the declared log file, which
        # was previously never written.
        """
        gunzip < {input.barcodes} > {params.barcodes}
        bash {ken_scripts}/samtools_filter_barcodes.sh {params.barcodes} {output.bam} {input.inbam} 2> {log}
        """

# code for running demuxlet on a subset of the VCF

rule make_small_pileup:
    """
    Restrict the barcode-filtered BAM to the regions in human.popscle.vcf.bed
    (samtools view -L), index the result, and run `popscle dsc-pileup` on it
    against human.popscle.vcf.

    The `outdemux` and `demux` params are not referenced by the shell command.
    """
    input:
        inbam="{sample}_human/possorted_genome_bam_filtered.bam",
        inbai="{sample}_human/possorted_genome_bam_filtered.bam.bai",
        inbed="human.popscle.vcf.bed",
        invcf="human.popscle.vcf"
    output:
        outbam="{sample}_human/pileup.bam",
        outbai="{sample}_human/pileup.bam.bai",
        outcel="{sample}_human/pileup.cel.gz"
    log:
        err=dir_log + "{sample}_make_small_pileup.err",
        out=dir_log + "{sample}_make_small_pileup.out"
    benchmark:
        "benchmarks/pileup.{sample}.txt"
    params:
        pileup="{sample}_human/pileup",
        outdemux="{sample}_human/demuxlet",
        demux="{sample}_human/demuxlet.best"
    shell:
        """
        /project2/gilad/kenneth/software/samtools-1.10/samtools view -@ 2 -L {input.inbed} -o {output.outbam} {input.inbam}
        /project2/gilad/kenneth/software/samtools-1.10/samtools index -@ 2 {output.outbam}
        /project2/gilad/kenneth/software/popscle/bin/popscle dsc-pileup --sam {output.outbam} --vcf {input.invcf} --out {params.pileup} > {log.out} 2> {log.err}
        """
        
rule run_demuxlet_human:
    """
    Run `popscle demuxlet` (genotype field GT) on the small pileup with
    human.popscle.vcf, then gzip the resulting `demuxlet.best` file into the
    declared output.
    """
    input:
        inbam="{sample}_human/pileup.bam",
        inbai="{sample}_human/pileup.bam.bai",
        incel="{sample}_human/pileup.cel.gz",
        invcf="human.popscle.vcf"
    output:
        demux="{sample}_human/demuxlet.best.gz"
    log:
        err=dir_log + "{sample}_run_demuxlet_human.err",
        out=dir_log + "{sample}_run_demuxlet_human.out"
    benchmark:
        "benchmarks/demuxlet.{sample}.txt"
    params:
        # `pileup` and `outdemux` are path prefixes handed to popscle;
        # `demux` is the uncompressed file that gzip turns into the output.
        pileup="{sample}_human/pileup",
        outdemux="{sample}_human/demuxlet",
        demux="{sample}_human/demuxlet.best"
    shell:
        """
        /project2/gilad/kenneth/software/popscle/bin/popscle demuxlet --field GT --plp {params.pileup} --vcf {input.invcf} --out {params.outdemux} > {log.out} 2> {log.err}
        gzip {params.demux}
        """

# code for running demuxlet on the entire VCF

rule make_big_pileup:
    """
    Run `popscle dsc-pileup` on the complete Cell Ranger BAM against the
    unfiltered `all_vcf`.

    The declared `pileup.all.bam` / `.bai` outputs are symbolic links to the
    Cell Ranger BAM and its index: this rule does no region filtering, but
    downstream rules take those paths as inputs. Previously nothing created
    them, although dsc-pileup was pointed at `pileup.all.bam`.
    """
    input:
        inbam = "{sample}_human/outs/possorted_genome_bam.bam",
        inbai = "{sample}_human/outs/possorted_genome_bam.bam.bai"
    output:
        outbam = "{sample}_human/pileup.all.bam",
        outbai = "{sample}_human/pileup.all.bam.bai",
        outcel = "{sample}_human/pileup.all.cel.gz"
    params:
        pileup = "{sample}_human/pileup.all"
    log:
        err = dir_log + "{sample}_make_big_pileup.err",
        out = dir_log + "{sample}_make_big_pileup.out"
    benchmark:
        "benchmarks/big_pileup.{sample}.txt"
    shell:
        # Absolute link targets keep the links valid regardless of the
        # directory they are resolved from.
        """
        ln -sf "$(readlink -f {input.inbam})" {output.outbam}
        ln -sf "$(readlink -f {input.inbai})" {output.outbai}
        /project2/gilad/kenneth/software/popscle/bin/popscle dsc-pileup --sam {output.outbam} --vcf {all_vcf} --out {params.pileup} > {log.out} 2> {log.err}
        """
               
rule run_demuxlet_all_human:
    """
    Run `popscle demuxlet` (genotype field GT) on the full pileup with the
    unfiltered `all_vcf`, then gzip the resulting `demuxlet.all.best` file
    into the declared output.
    """
    input:
        inbam="{sample}_human/pileup.all.bam",
        inbai="{sample}_human/pileup.all.bam.bai"
    output:
        demux="{sample}_human/demuxlet.all.best.gz"
    log:
        err=dir_log + "{sample}_demuxlet_all_human.err",
        out=dir_log + "{sample}_demuxlet_all_human.out"
    benchmark:
        "benchmarks/demuxlet_all.{sample}.txt"
    params:
        # `pileup` and `outdemux` are path prefixes handed to popscle;
        # `demux` is the uncompressed file that gzip turns into the output.
        pileup="{sample}_human/pileup.all",
        outdemux="{sample}_human/demuxlet.all",
        demux="{sample}_human/demuxlet.all.best"
    shell:
        """
        /project2/gilad/kenneth/software/popscle/bin/popscle demuxlet --field GT --plp {params.pileup} --vcf {all_vcf} --out {params.outdemux} > {log.out} 2> {log.err}
        gzip {params.demux}
        """
        
# Code for running freemuxlet 

rule run_freemuxlet_human:
    """
    Run `popscle freemuxlet` on the full pileup, asking for `nind` samples
    (--nsample).

    The `freemux` param is not referenced by the shell command.
    NOTE(review): the declared output is presumably written by freemuxlet
    under the `outfreemux` prefix -- confirm.
    """
    input:
        inbam="{sample}_human/pileup.all.bam",
        inbai="{sample}_human/pileup.all.bam.bai",
        incel="{sample}_human/pileup.all.cel.gz"
    output:
        freemux="{sample}_human/freemuxlet.clust1.samples.gz"
    log:
        err=dir_log + "{sample}_freemuxlet_human.err",
        out=dir_log + "{sample}_freemuxlet_human.out"
    benchmark:
        "benchmarks/freemuxlet.{sample}.txt"
    params:
        pileup="{sample}_human/pileup.all",
        outfreemux="{sample}_human/freemuxlet",
        freemux="{sample}_human/freemuxlet.best"
    shell:
        """
        /project2/gilad/kenneth/software/popscle/bin/popscle freemuxlet --plp {params.pileup} --out {params.outfreemux} --nsample {nind} > {log.out} 2> {log.err}
        """
 
#############################################################################

# Vireo only options

#############################################################################

# subset vcfs
rule subset_vcf_for_vireo:
    """
    Subset `reference_snps` to the configured individuals, apply the three
    helper filters, sort the result and write BED interval files for the
    filtered VCF and for the full reference VCF.

    The sorted VCF is written compressed as human.vireo.vcf.gz and then
    gunzipped, which yields the declared human.vireo.vcf.
    """
    output:
        sorted="human.vireo.vcf",
        unsorted="human.vireo.unsorted.vcf",
        bed="human.vireo.vcf.bed",
        full="all.vireo.vcf.bed"
    params:
        compressed="human.vireo.vcf.gz"
    log:
        err=dir_log + "subset_vcf_vireo.err"
    benchmark:
        "benchmarks/subset_vcf_vireo.txt"
    shell:
        """
        source {ken_scripts}/popscle_helper_tools/filter_vcf_file_for_popscle.sh
        subset_samples_from_vcf {humans} {reference_snps} \
        | filter_out_mutations_missing_genotype_for_one_or_more_samples \
        | filter_out_mutations_homozygous_reference_in_all_samples \
        | filter_out_mutations_homozygous_in_all_samples \
        > {output.unsorted} 2> {log}
        bcftools sort -Oz {output.unsorted} -o {params.compressed}
        gunzip {params.compressed}
        bedtools merge -i {output.sorted} > {output.bed}
        bedtools merge -i {reference_snps} > {output.full}
        """

        
rule run_cellsnp:
    """
    Decompress the Cell Ranger barcode list and run `cellsnp-lite` on the
    Cell Ranger BAM, restricted to the variants in human.vireo.vcf (-R),
    writing gzipped results into the sample's cellsnp/ directory.
    """
    input:
        vcf=ancient("human.vireo.vcf"),
        bam=ancient("{sample}_human/outs/possorted_genome_bam.bam"),
        barcodes=ancient("{sample}_human/outs/filtered_feature_bc_matrix/barcodes.tsv.gz")
    output:
        "{sample}_human/cellsnp/cellSNP.base.vcf.gz"
    log:
        err=dir_log + "{sample}_cellsnp.err",
        out=dir_log + "{sample}_cellsnp.out"
    benchmark:
        "benchmarks/cellsnp.{sample}.txt"
    conda:
        "envs/vireo.yaml"
    params:
        barcodes="{sample}_human/barcodes.txt",
        outdir="{sample}_human/cellsnp/"
    shell:
        """
        gunzip < {input.barcodes} > {params.barcodes}
        cellsnp-lite -s {input.bam} -b {params.barcodes} -O {params.outdir} -R {input.vcf} -p 8 --minMAF 0.1 --minCOUNT 20 --gzip > {log.out} 2> {log.err}
        """
        
rule run_vireo:
    """
    Run `vireo` on the sample's cellsnp/ directory with human.vireo.vcf as
    donor genotypes (-t GT) for `nind` donors, without plots, writing into
    the sample's vireo/ directory.
    """
    input:
        vcf=ancient("human.vireo.vcf"),
        cellsnp=ancient("{sample}_human/cellsnp/cellSNP.base.vcf.gz")
    output:
        "{sample}_human/vireo/donor_ids.tsv"
    log:
        err=dir_log + "{sample}_vireo.err",
        out=dir_log + "{sample}_vireo.out"
    benchmark:
        "benchmarks/vireo.{sample}.txt"
    conda:
        "envs/vireo.yaml"
    params:
        vireo="{sample}_human/vireo/",
        cellsnp="{sample}_human/cellsnp/"
    shell:
        """
        vireo -c {params.cellsnp} -d {input.vcf} -o {params.vireo} -t GT --nproc=4 -N {nind} --noPlot > {log.out} 2> {log.err}
        """

# Rebuild the list of optional demultiplexing targets from scratch instead of
# appending to whatever an earlier part of the file already stored in it, so
# every enabled method contributes exactly one entry and each run_* switch
# maps to the output of its own tool.
# The explicit `== True` comparison is kept so that a non-boolean config value
# (for example the string "False") does not switch a target on.
myoutput = list()

if run_demuxlet == True:
    myoutput.append(ancient(expand("{sample}_human/demuxlet.best.gz", sample = samples)))
if run_freemuxlet == True:
    myoutput.append(ancient(expand("{sample}_human/freemuxlet.clust1.samples.gz", sample = samples)))
if run_demuxlet_all == True:
    myoutput.append(ancient(expand("{sample}_human/demuxlet.all.best.gz", sample = samples)))
if run_vireo == True:
    myoutput.append(ancient(expand("{sample}_human/cellsnp/cellSNP.base.vcf.gz", sample = samples)))
    myoutput.append(ancient(expand("{sample}_human/vireo/donor_ids.tsv", sample = samples)))
    
    
def report_inputs(wildcards):
    """Return the input files of the QC report rules for one sample.

    The Cell Ranger matrix is always required. For every enabled run_*
    switch the output of that demultiplexing method is added, wrapped in
    ancient() so that its timestamp is ignored.

    Args:
        wildcards: Snakemake wildcards object; only `wildcards.sample` is read.

    Returns:
        List of input paths, the matrix first, then the enabled optional
        outputs in the order demuxlet, freemuxlet, demuxlet-all, vireo.
    """
    sample_dir = wildcards.sample + "_human/"
    inputs = [sample_dir + "outs/filtered_feature_bc_matrix.h5"]

    optional_outputs = (
        (run_demuxlet, "demuxlet.best.gz"),
        (run_freemuxlet, "freemuxlet.clust1.samples.gz"),
        (run_demuxlet_all, "demuxlet.all.best.gz"),
        (run_vireo, "vireo/donor_ids.tsv"),
    )
    for enabled, relative_path in optional_outputs:
        if enabled:
            # ancient() has to wrap the path itself; wrapping the result of
            # list.append() (which is None) leaves the path unflagged.
            inputs.append(ancient(sample_dir + relative_path))
    return inputs
    
        
rule render_report_seurat3:
    """
    Render scripts/ebQTL_QC_v3.Rmd (located under workflow.basedir) with
    rmarkdown in the seurat3 conda environment.

    The sample directory and the four run_* switches are passed to the
    document as quoted string parameters; the rendered report is written to
    the sample directory.
    """
    input:
        report_inputs
    output:
        "{sample}_human/ebQTL_QC_v3.nb.html"
    params:
        dir=pd + "{sample}_human/"
    log:
        err=dir_log + "{sample}_rule_report_v3.err",
        out=dir_log + "{sample}_rule_report_v3.out"
    benchmark:
        "benchmarks/render_report_seurat3.{sample}.txt"
    conda:
        "envs/seurat3.yaml"
    shell:
        """
        Rscript -e "rmarkdown::render('{workflow.basedir}/scripts/ebQTL_QC_v3.Rmd', 
          params=list(sample            = '{params.dir}',
                      add_freemuxlet    = '{run_freemuxlet}',
                      add_demuxlet      = '{run_demuxlet}',
                      add_demuxlet_all  = '{run_demuxlet_all}',
                      add_vireo         = '{run_vireo}'
                      ), 
          output_dir='{params.dir}')" > {log.out} 2> {log.err}
        """
        
rule render_report_seurat4:
    """
    Render scripts/ebQTL_QC_v4.Rmd (located under workflow.basedir) with
    rmarkdown in the seurat4 conda environment.

    The sample directory and the four run_* switches are passed to the
    document as quoted string parameters; the rendered report is written to
    the sample directory.
    """
    input:
        report_inputs
    output:
        "{sample}_human/ebQTL_QC_v4.nb.html"
    params:
        dir=pd + "{sample}_human/"
    log:
        err=dir_log + "{sample}_rule_report_v4.err",
        out=dir_log + "{sample}_rule_report_v4.out"
    benchmark:
        "benchmarks/render_report_seurat4.{sample}.txt"
    conda:
        "envs/seurat4.yaml"
    shell:
        """
        Rscript -e "rmarkdown::render('{workflow.basedir}/scripts/ebQTL_QC_v4.Rmd', 
          params=list(sample            = '{params.dir}',
                      add_freemuxlet    = '{run_freemuxlet}',
                      add_demuxlet      = '{run_demuxlet}',
                      add_demuxlet_all  = '{run_demuxlet_all}',
                      add_vireo         = '{run_vireo}'
                      ), 
          output_dir='{params.dir}')" > {log.out} 2> {log.err}
        """
        
rule velocyto:
    """
    Run `velocyto run10x` on the sample's Cell Ranger directory, using the
    genes.gtf found under the human genome index.

    NOTE(review): the declared loom path is presumably where run10x writes
    its result inside the sample directory -- confirm.
    """
    input:
        matrix=ancient("{sample}_human/outs/filtered_feature_bc_matrix.h5")
    output:
        "{sample}_human/velocyto/{sample}_human.loom"
    params:
        dir=pd + "{sample}_human/",
        index=HumanGenomeIndex + "/genes/genes.gtf"
    log:
        err=dir_log + "{sample}_velocyto.err",
        out=dir_log + "{sample}_velocyto.out"
    benchmark:
        "benchmarks/velocyto.{sample}.txt"
    conda:
        "envs/velocyto.yaml"
    shell:
        """
        velocyto run10x --dtype uint32 -@ 12 {params.dir} {params.index} > {log.out} 2> {log.err}
        """
  

        
