"""

Source for original WASP pipeline: https://github.com/bmvdgeijn/WASP.git (WASP/mapping/Snakefile)
Source for modified WASP scripts accounting for indels: https://git.embl.de/rabinowi/wasp_indels/-/tree/indel/mapping
"""


import os
import sys
import re
import yaml
import glob
import pandas as pd
from snakemake.utils import min_version
import glob
# Enforce a minimum Snakemake version
min_version("5.20")


# shell.prefix("module load SciPy-bundle; module load pysam/0.15.1-foss-2018b-Python-3.7.0")

onstart:
    print("##########################################\n")
    print("# STARTING PIPELINE\n")
    print("##########‰################################\n")
    print ("Running ChIP-seq pre-processing workflow for the following samples:\n " + ' \n '.join(map(str, SAMPLES)))

# Handler run once after the whole workflow finishes without errors.
onsuccess:
    print("##########################################\n")
    print("# PIPELINE ENDED SUCCESSFULLY \n")
    print("##########################################\n")

# Banner string intended to be printed at the start of rule execution.
# NOTE(review): not referenced anywhere in this file — possibly used by an
# included Snakefile; confirm before removing.
ruleDisplayMessage = """\n\n####### START RULE EXECUTION ##########\n"""


################################################################################
######################## Dependencies ##########################################
################################################################################

"""
The following environment was created to run the pipelines

mamba create -c conda-forge -c bioconda -n wasp_env snakemake python intervaltree pysam

versions:
    - python=3.8.6
    - pysam=0.16.0.1
    - intervaltree=3.1.0
    - snakemake=5.32.0

Versions of common bioinformatics tools are specified in config.yml file

"""

################################################################################
######################## FUNCTIONS #############################################
################################################################################


################################################################################
######################## READ DATA #############################################
################################################################################

### FOLDERS ON TIER1 ###
# Persistent project directories, assembled from the Snakemake config (config.yml).
PROJECT_DIR = config["global"]["projectdir"]
DATA_DIR  = PROJECT_DIR + "/" + config["project_structure"]["datadir"]
LOG_DIR = PROJECT_DIR + "/" + config["project_structure"]["logdir"]
EXEC_DIR = PROJECT_DIR + "/" + config["project_structure"]["utils"]  # helper scripts (e.g. VCF utilities)

### FOLDERS ON SCRATCH ###
# Fast, non-persistent storage for intermediate files.
SCRATCH_DIR = config["global"]["scratchdir"]
FASTQ_DIR = SCRATCH_DIR + "/PREPROCESS_FASTQ/3.SEQTK_TRIMFQ" # processed fastq files
# MAPPABILITY_FILTER_DIR = SCRATCH_DIR + "/MAPPABILITY_FILTER"
# AS_COUNTS_DIR = SCRATCH_DIR + "/AS_COUNTS"

### SOFTWARE ###
# Paths to external tools, all taken from config["tools"]["by_path"].
# NOTE(review): SAMTOOLS, BWA and PICARD look unused below (the rules invoke
# 'bwa'/'picard' from their conda envs instead) — confirm before removing.
#WASP = config["tools"]["by_path"]["wasp"]
WASP_indels = config["tools"]["by_path"]["wasp_indels"] # additional scripts modified by Adam
SAMTOOLS = config["tools"]["by_path"]["samtools"]
BWA = config["tools"]["by_path"]["bwa"]
PICARD = config["tools"]["by_path"]["picard"]
JEv2 = config["tools"]["by_path"]["je"]["v2"]

### DATA ###

# Cleaned fastq files after trimming and adaptor removal (previous step).
# Exclude D. simulans files; the original "*Dsim*" glob pattern is equivalent
# to a substring test, which avoids re-listing the directory via glob.glob1.
FASTQ = [f for f in os.listdir(FASTQ_DIR) if "Dsim" not in f]

# Sample names in format {antibody}.{mother}.{father}.{timePoint}.{replicate}
# (matched by the mapping rules as {ab}.{tp}.{individual} wildcards).
# Deduplicate — each sample contributes a pair1 AND a pair2 fastq with the
# same prefix — and sort for a deterministic order (os.listdir order is
# arbitrary).
SAMPLES = sorted({f.split("-trimmed-")[0] for f in FASTQ})
READ_PAIRS = ['pair1', 'pair2']

# Debugging overrides — uncomment to restrict the run to specific samples:
# SAMPLES = ["input.1012.vgn_714_1"] # temporary
# SAMPLES = ["bin.68.vgn_399_1", "bin.68.vgn_399_2", "bin.1012.vgn_399_1", "bin.1012.vgn_399_2"]
#SAMPLES = ["bin.1012.399_399_1", "bin.1012.399_399_2"]

### Genome ###
# dm6 reference genome fasta and chromosome-size table.
genome_fasta = config["data"]["genome"]["dm6"]["fasta"]
chrom_sizes = config["data"]["genome"]["dm6"]["chrom_sizes"]

### VCF files ###
vcf_dir = config["data"]["VCF"]["vcf_dir"]
# stringent or lenient VCF set - use 'lenient' in the mappability pipeline, and 'stringent' for allelic counts
vcf_set_mf = config["pipelines"]["general"]["VCF_set"]["mappability_filter"]
vcf_set_ac = config["pipelines"]["general"]["VCF_set"]["as_counts"]
# VCF file for the selected set
VCF_mappability_filter = config["data"]["VCF"]["vcf_dir"] + "/" + config["data"]["VCF"][vcf_set_mf]
# VCF_allele_counts = config["data"]["phased_VCF"][vcf_set_ac]["with_genotype_likelihoods"]

#vcf_dir = config["data"]["VCF"]["vcf_dir"] + "/"
# chromosomes with info in VCF
#chrom_list = config["pipelines"]["2_mappability_filter"]["vcf_params"]["chrom_list"].split(" ")

################################################################################
######################## MAIN ##################################################
################################################################################



# Final targets: per-sample allele-specific variant counts. The commented
# entries are earlier/intermediate targets kept for reference and partial reruns.
rule all:
    input:
        # expand(SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{sample}.variant_log.txt", sample = SAMPLES),
        # expand(SCRATCH_DIR + "/MAPPABILITY_FILTER/2_map2/{sample}.filter_log.txt", sample = SAMPLES),
        # expand(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{sample}.filtered.rmdup.bam", sample = SAMPLES),
        # DATA_DIR + "/phased_VCF/haplotypes.vcf.gz",
        # expand(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{sample}.filtered.rmdup_picard.bam", sample = SAMPLES[0]),
        # expand(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{sample}.filtered.rmdup_withUMI.bam", sample = SAMPLES[0]),
        # expand(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{sample}.filtered.rmdup_withUMI_nonrand.bam", sample = SAMPLES[0])
        expand(DATA_DIR + "/ChIPseq/AS_COUNTS/with_indels/1_as_counts_by_variant/{sample}.variant_counts.txt.gz", sample = SAMPLES)
################################################################################
#############  Mapping and Mappability filter  #################################
################################################################################

# Align initial fastq reads to the genome. -M used for compatibility with downstream code (duplicate removal)
rule first_alignment:
    input:
        # Trimmed paired-end fastq files produced by the preprocessing pipeline.
        read1_fastq = FASTQ_DIR + "/{ab}.{tp}.{individual}-trimmed-pair1_seqtk.fastq",
        read2_fastq = FASTQ_DIR + "/{ab}.{tp}.{individual}-trimmed-pair2_seqtk.fastq"
    output:
        # Initial alignment (SAM); temp() — removed once downstream rules finish.
        temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}.initial.sam")
    params:
        fasta = genome_fasta,  # NOTE(review): unused in the shell command — confirm before removing
        # NOTE(review): index path is hard-coded (duplicated in second_alignment); consider moving to config
        bwa_index = "/g/furlong/genome/D.melanogaster/Dm6/indexes/bwa/dm6.UCSC.noMask.fa",
        # Read group; sample name reconstructed from the wildcards.
        rg = "@RG\\tID:1\\tSM:{ab}.{tp}.{individual}\\tPL:ILLUMINA\\tLB:1"
    conda: "bwa_env.yml"
    shell:
        'bwa mem -M -R {params.rg:q} {params.bwa_index} {input} > {output}'

# Convert initial alignment to coordinate sorted BAM file (with .bai index).
rule sort_first_alignment:
    input:
        rules.first_alignment.output
    output:
        bam = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}.initial.bam"),
        bai = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}.initial.bai")
    conda: "picard_env.yml"
    shell:
        'picard SortSam I={input} O={output.bam} SO=coordinate CREATE_INDEX=true'


# Place alignments not overlapping variants in a BAM file and extract reads
# overlapping variants, and their allele flipped versions, to a fastq file
# (WASP mappability-filter step 1).
rule filter_first_alignment:
    input:
        bam = rules.sort_first_alignment.output.bam,
        bai = rules.sort_first_alignment.output.bai,
        vcf = VCF_mappability_filter  # 'lenient' VCF set (see config)
    output:
        # Reads with no overlapping variant go straight into this BAM ...
        bam = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}.no_variants.bam"),
        # ... variant-overlapping reads plus their allele-flipped copies are re-emitted for remapping.
        fastq = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}.allele_flipped.fq.gz"),
        log = SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}.first_alignment_log.txt"
    params:
        # Output prefix; generate_variant_reads.py derives the file names above from it.
        prefix = SCRATCH_DIR + "/MAPPABILITY_FILTER/1_map1/{ab}.{tp}.{individual}",
        mapq = config['pipelines']['2_mappability_filter']['software_params']['minMAPQ'],  # min MAPQ to keep a read
        max_vars = config['pipelines']['2_mappability_filter']['software_params']['max_vars'],  # max variants per read
        max_seqs = config['pipelines']['2_mappability_filter']['software_params']['max_seqs']  # max flipped sequences per read
    conda: "environment.yml"
    shell:
        """
        python {WASP_indels}/mapping/generate_variant_reads.py \
                     --min_mapq {params.mapq} \
                     --max_vars {params.max_vars} \
                     --max_seqs {params.max_seqs} \
                     --bam {input.bam} \
                     --vcf {input.vcf} \
                     --out_prefix {params.prefix}
        """

# Realign reads along with their allele flipped versions.
# -p: interleaved paired-end input (original + flipped mates in one fastq).
rule second_alignment:
    input: rules.filter_first_alignment.output.fastq
    output: temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/2_map2/{ab}.{tp}.{individual}.remap.sam")
    params:
        fasta = genome_fasta,  # NOTE(review): unused in the shell command — confirm before removing
        # NOTE(review): same hard-coded index as first_alignment; consider moving to config
        bwa_index = "/g/furlong/genome/D.melanogaster/Dm6/indexes/bwa/dm6.UCSC.noMask.fa",
        rg = "@RG\\tID:1\\tSM:{ab}.{tp}.{individual}\\tPL:ILLUMINA\\tLB:1"
    conda: "bwa_env.yml"
    shell:
        'bwa mem -M -p -R {params.rg:q} {params.bwa_index} {input} > {output}'

# Sort second alignment by read name (required by filter_remapped_reads.py).
# NOTE(review): CREATE_INDEX=true likely has no effect here — BAM indexing
# requires coordinate sorting, not queryname — confirm and drop if so.
rule sort_second_alignment:
    input: rules.second_alignment.output
    output: temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/2_map2/{ab}.{tp}.{individual}.remap.bam")
    conda: "picard_env.yml"
    shell:
        'picard SortSam I={input} O={output} SO=queryname CREATE_INDEX=true'

# Extract reads for which the original and allele flipped version align
# to the same location (WASP mappability-filter step 2).
rule filter_second_alignment:
    input: rules.sort_second_alignment.output
    output:
        bam = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/2_map2/{ab}.{tp}.{individual}.consistent.bam"),
        log = SCRATCH_DIR + "/MAPPABILITY_FILTER/2_map2/{ab}.{tp}.{individual}.second_alignment_log.txt"
    params:
        # Output prefix; filter_remapped_reads.py derives the file names above from it.
        prefix = SCRATCH_DIR + "/MAPPABILITY_FILTER/2_map2/{ab}.{tp}.{individual}",
        mapq = config['pipelines']['2_mappability_filter']['software_params']['minMAPQ']  # same MAPQ cutoff as step 1
    conda: "environment.yml"
    shell:
        """
        python {WASP_indels}/mapping/filter_remapped_reads.py --paired_end  \
               --min_mapq {params.mapq} {input} {params.prefix}
        """

# Merge all passed alignments: reads that never overlapped a variant (step 1)
# plus reads that remapped consistently after allele flipping (step 2).
rule merge_alignments:
    input:
        bam1 = rules.filter_first_alignment.output.bam,
        bam2 = rules.filter_second_alignment.output.bam
    output: temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.bam"),
    conda: "picard_env.yml"
    shell:
        'picard GatherBamFiles I={input.bam1} I={input.bam2} O={output}'

# Sort passed alignments by coordinate (with .bai index) for duplicate removal.
rule sort_merged_bam:
    input: rules.merge_alignments.output
    output:
        bam = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.sorted.bam"),
        bai = temp(SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.sorted.bai")
    conda: "picard_env.yml"
    shell:
        'picard SortSam I={input} O={output.bam} SO=coordinate CREATE_INDEX=true'


################################################################################
#############  Remove duplicates  #########################################
################################################################################

# Try three different approaches - for benchmarking purposes only

# 1 .Remove duplicate alignments randomly so as not to favour a specific allele - with Picard tool
# rule randomly_remove_duplicates_picard:
#     input:
#         bam = rules.sort_merged_bam.output.bam,
#         bai = rules.sort_merged_bam.output.bai
#     output:
#         bam = SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.rmdup_picard.bam",
#         bai = SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.rmdup_picard.bai",
#         log = LOG_DIR + "/mappability_filter/3_combined/{ab}.{tp}.{individual}.filtered.rmdup_picard.log"
#     shell:
#         """
#         {PICARD} MarkDuplicates I={input.bam} M={output.log} O={output.bam} \
#                 ASSUME_SORTED=true DUPLICATE_SCORING_STRATEGY=RANDOM \
#                 REMOVE_DUPLICATES=true CREATE_INDEX=true
#         """

# 2. Remove duplicate alignments randomly + account for UMI information - with Je tool.
# Random selection avoids favouring reads carrying a specific allele.
# MM=1 - allow 1 mismatch to consider UMIs the same; SLOTS=-1 SLOTS=-2 - expect two UMIs at the ends of the read; read_name_regex is tailored to JE clip output for 2 barcodes
rule randomly_remove_duplicates:
    input:
        bam = rules.sort_merged_bam.output.bam,
        bai = rules.sort_merged_bam.output.bai
    output:
        # Final deduplicated BAM, written to persistent storage (DATA_DIR, not scratch).
        bam = DATA_DIR + "/ChIPseq/Alignments/with_indels/{ab}.{tp}.{individual}.filtered.rmdup_withUMI.bam",
        bai = DATA_DIR + "/ChIPseq/Alignments/with_indels/{ab}.{tp}.{individual}.filtered.rmdup_withUMI.bai",
        log = DATA_DIR + "/ChIPseq/Alignments/with_indels/{ab}.{tp}.{individual}.filtered.rmdup_withUMI.log"
    shell:
        """
        {JEv2} markdupes I={input.bam} M={output.log} O={output.bam} \
                MM=1 SLOTS=-1 SLOTS=-2 \
                READ_NAME_REGEX='(?:.*:)?([0-9]+)[^:]*:([0-9]+)[^:]*:([0-9]+)[^:]*:[^:]+:[^:]+$' \
                ASSUME_SORTED=true DUPLICATE_SCORING_STRATEGY=RANDOM \
                REMOVE_DUPLICATES=true CREATE_INDEX=true
        """

# 3. Remove duplicate (keep best scoring read) + account for UMI information - with Je tool
# rule remove_duplicates:
#     input:
#         bam = rules.sort_merged_bam.output.bam,
#         bai = rules.sort_merged_bam.output.bai
#     output:
#         bam = SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.rmdup_withUMI_nonrand.bam",
#         bai = SCRATCH_DIR + "/MAPPABILITY_FILTER/3_combined/{ab}.{tp}.{individual}.filtered.rmdup_withUMI_nonrand.bai",
#         log = LOG_DIR + "/mappability_filter/3_combined/{ab}.{tp}.{individual}.filtered.rmdup_withUMI_nonrand.log"
#     shell:
#         """
#         {JEv2} markdupes I={input.bam} M={output.log} O={output.bam} \
#                 MM=1 SLOTS=-1 SLOTS=-2 \
#                 READ_NAME_REGEX='(?:.*:)?([0-9]+)[^:]*:([0-9]+)[^:]*:([0-9]+)[^:]*:[^:]+:[^:]+$' \
#                 ASSUME_SORTED=true \
#                 REMOVE_DUPLICATES=true CREATE_INDEX=true
#         """

################################################################################
#############  Allele-specific counts  #########################################
################################################################################

# using stringent VCF set for allele counts
rule get_phased_vcf:
    """Take VCF file for parents and create phased VCF file for F1 (stringent set)."""
    input: expand(vcf_dir + "/F1_haplotype_joint_call_GATK_{set}.vcf.gz", set = vcf_set_ac)
    output: DATA_DIR + "/Genotypes/phased_VCF/haplotypes.vcf.gz"
    message: "Generate phased VCF set"
    log: DATA_DIR + "/Genotypes/phased_VCF/haplotypes.log"
    conda: "genotypes.yml"
    threads: 1
    shell: """
           python {EXEC_DIR}/VCF/construct_phased_vcf_with_replicates_with_gl_tabix.py  \
                  -i {input} -o {output} 2> {log}
           """

# Same as get_phased_vcf but for the lenient variant set (mappability filter).
rule get_phased_lenient_vcf:
    """Take VCF file for parents and create phased VCF file for F1 (lenient set)."""
    input: expand(vcf_dir + "/F1_haplotype_joint_call_GATK_{set}.vcf.gz", set = vcf_set_mf)
    output: DATA_DIR + "/Genotypes/phased_VCF/haplotypes_lenient.vcf.gz"
    message: "Generate phased VCF set"
    log: DATA_DIR + "/Genotypes/phased_VCF/haplotypes_lenient.log"
    conda: "genotypes.yml"
    threads: 1
    shell: """
           python {EXEC_DIR}/VCF/construct_phased_vcf_with_replicates_with_gl_tabix.py  \
                  -i {input} -o {output} 2> {log}
           """

# Variant counts: per-variant allele-specific read counts from the
# deduplicated BAM and the stringent phased VCF (final pipeline output).
rule get_counts:
    input:
        bam = rules.randomly_remove_duplicates.output.bam,
        bai = rules.randomly_remove_duplicates.output.bai,
        vcf = rules.get_phased_vcf.output
    output:
        # bgzipped counts table plus its tabix index, both produced by get_counts.py.
        counts = DATA_DIR + "/ChIPseq/AS_COUNTS/with_indels/1_as_counts_by_variant/{ab}.{tp}.{individual}.variant_counts.txt.gz",
        index = DATA_DIR + "/ChIPseq/AS_COUNTS/with_indels/1_as_counts_by_variant/{ab}.{tp}.{individual}.variant_counts.txt.gz.tbi"
    params:
        prefix = DATA_DIR + "/ChIPseq/AS_COUNTS/with_indels/1_as_counts_by_variant/{ab}.{tp}.{individual}",
        # NOTE(review): sample passed to get_counts.py is the {individual}
        # wildcard only — presumably the VCF sample column name; confirm.
        sample = "{individual}"
    shell:
        """
        python {WASP_indels}/mapping/get_counts.py --bam {input.bam} --vcf {input.vcf} \
            --sample {params.sample} --outprefix {params.prefix}
        """
