# Get read pairing: `pair` is False for one input fastq file and True for two.
# config['fastq'] is normally a list of paths; a bare string is counted as a
# single file so that len() is never taken over its characters.
_fastq = config['fastq']
_n_fastq = 1 if isinstance(_fastq, str) else len(_fastq)
if _n_fastq not in (1, 2):
    raise ValueError('1 or 2 input fastq files expected')
pair = _n_fastq == 2

# Default target: the filtered, duplicate-removed BAM file and its index,
# the variant counts table and its .tbi index, and the log file written by
# each filtering / counting step
rule all:
    input:
        config["out_prefix"] + ".filtered.rmdup.bam",
        config["out_prefix"] + ".filtered.rmdup.bai",
        config["out_prefix"] + ".variant_counts.txt.gz",
        config["out_prefix"] + ".variant_counts.txt.gz.tbi",
        config["out_prefix"] + ".first_alignment_log.txt",
        config["out_prefix"] + ".second_alignment_log.txt",
        config["out_prefix"] + ".rmdup_log.txt",
        config["out_prefix"] + ".variant_counts_log.txt"

# Map the input fastq reads against the reference with `bwa mem -M`; the
# resulting SAM file is temporary
rule first_alignment:
    params:
        fasta = config["fasta"]
    input:
        config["fastq"]
    output:
        temp(config["out_prefix"] + ".initial.sam")
    shell:
        "bwa mem -M {params[fasta]} {input} "
        "> {output}"

# Coordinate-sort the initial SAM alignment into a temporary BAM file and
# have picard write the matching index alongside it
rule sort_first_alignment:
    input:
        config["out_prefix"] + ".initial.sam"
    output:
        bam = temp(config["out_prefix"] + ".initial.bam"),
        bai = temp(config["out_prefix"] + ".initial.bai")
    shell:
        "picard SortSam "
        "-I {input} "
        "-O {output.bam} "
        "-SO coordinate "
        "--CREATE_INDEX true "
        "--USE_JDK_DEFLATER true "
        "--USE_JDK_INFLATER true"

# Split the initial alignment: alignments that do not overlap variants go to
# a BAM file, while reads overlapping variants are written, together with
# their allele flipped versions, to a fastq file.
# The script only receives the output prefix; presumably it derives the
# declared output file names from it.
rule filter_first_alignment:
    params:
        script = config["wasp"] + "/generate_variant_reads.py",
        prefix = config["out_prefix"],
        mapq = config["min_mapq"]
    input:
        bam = config["out_prefix"] + ".initial.bam",
        bai = config["out_prefix"] + ".initial.bai",
        vcf = config["vcf"]
    output:
        bam = temp(config["out_prefix"] + ".no_variants.bam"),
        fastq = temp(config["out_prefix"] + ".allele_flipped.fq.gz"),
        log = config["out_prefix"] + ".first_alignment_log.txt"
    shell:
        "python {params[script]} "
        "--min_mapq {params[mapq]} "
        "--bam {input.bam} "
        "--vcf {input.vcf} "
        "--out_prefix {params[prefix]}"

# Map the variant-overlapping reads and their allele flipped versions again.
# `-p` is passed to bwa mem only when two input fastq files were given
# (`pair` is True); otherwise the option is left empty.
rule second_alignment:
    input:
        config["out_prefix"] + ".allele_flipped.fq.gz"
    output:
        temp(config["out_prefix"] + ".allele_flipped.sam")
    params:
        paired = "-p" if pair else "",
        fasta = config["fasta"]
    shell:
        "bwa mem -M {params[paired]} {params[fasta]} "
        "{input} > {output}"

# Order the second alignment by read name (queryname) and store it as a
# temporary BAM file
rule sort_second_alignment:
    input:
        config["out_prefix"] + ".allele_flipped.sam"
    output:
        temp(config["out_prefix"] + ".allele_flipped.bam")
    shell:
        "picard SortSam "
        "-I {input} "
        "-O {output} "
        "-SO queryname "
        "--USE_JDK_DEFLATER true "
        "--USE_JDK_INFLATER true"

# Keep only reads whose original and allele flipped versions align to the
# same location.
# The script only receives the output prefix; presumably it derives the
# declared output file names from it.
rule filter_second_alignment:
    params:
        script = config["wasp"] + "/filter_remapped_reads.py",
        prefix = config["out_prefix"],
        mapq = config["min_mapq"],
        wobble = config["wobble"]
    input:
        config["out_prefix"] + ".allele_flipped.bam"
    output:
        temp(config["out_prefix"] + ".consistent.bam"),
        config["out_prefix"] + ".second_alignment_log.txt"
    shell:
        "python {params[script]} "
        "--wobble {params[wobble]} "
        "--min_mapq {params[mapq]} "
        "{input} {params[prefix]}"

# Merge all passed alignments: the alignments that do not overlap variants
# and the consistently remapped ones are combined into one temporary BAM.
# The JDK deflater/inflater are requested explicitly, as in the other
# BAM-writing picard steps of this workflow.
# NOTE(review): GatherBamFiles concatenates its inputs in the given order;
# presumably both inputs carry compatible headers - confirm.
rule merge_alignments:
    input:
        bam1 = config['out_prefix'] + '.no_variants.bam',
        bam2 = config['out_prefix'] + '.consistent.bam'
    output:
        temp(config['out_prefix'] + '.filtered.bam')
    shell:
        'picard GatherBamFiles -I {input.bam1} -I {input.bam2} -O {output} '
        '--USE_JDK_DEFLATER true --USE_JDK_INFLATER true'

# Sort passed alignments by coordinate and index them. Both outputs are
# temporary.
# The JDK deflater/inflater are requested explicitly so that this SortSam
# call matches the other SortSam steps of this workflow.
rule sort_merged_bam:
    input:
        config['out_prefix'] + '.filtered.bam',
    output:
        bam = temp(config['out_prefix'] + '.filtered.sorted.bam'),
        bai = temp(config['out_prefix'] + '.filtered.sorted.bai')
    shell:
        'picard SortSam -I {input} -O {output.bam} -SO coordinate '
        '--CREATE_INDEX true --USE_JDK_DEFLATER true '
        '--USE_JDK_INFLATER true'

# Drop duplicate alignments with picard MarkDuplicates. The RANDOM scoring
# strategy is used so that no specific allele is favoured when choosing which
# duplicate to keep. Writes the final BAM, its index and a metrics log.
rule randomly_remove_duplicates:
    input:
        bam = config["out_prefix"] + ".filtered.sorted.bam",
        bai = config["out_prefix"] + ".filtered.sorted.bai"
    output:
        bam = config["out_prefix"] + ".filtered.rmdup.bam",
        bai = config["out_prefix"] + ".filtered.rmdup.bai",
        log = config["out_prefix"] + ".rmdup_log.txt"
    shell:
        "picard MarkDuplicates "
        "-I {input.bam} "
        "-M {output.log} "
        "-O {output.bam} "
        "--ASSUME_SORTED "
        "--DUPLICATE_SCORING_STRATEGY RANDOM "
        "--REMOVE_DUPLICATES true "
        "--REMOVE_SEQUENCING_DUPLICATES true "
        "--CREATE_INDEX true "
        "--USE_JDK_DEFLATER true "
        "--USE_JDK_INFLATER true"

# Variant counts: run get_counts.py on the final duplicate-removed BAM for
# the configured sample and VCF.
# The script only receives the output prefix; presumably it derives the
# declared output file names from it.
rule get_counts:
    params:
        script = config["wasp"] + "/get_counts.py",
        prefix = config["out_prefix"],
        sample = config["sample"]
    input:
        bam = config["out_prefix"] + ".filtered.rmdup.bam",
        bai = config["out_prefix"] + ".filtered.rmdup.bai",
        vcf = config["vcf"]
    output:
        counts = config["out_prefix"] + ".variant_counts.txt.gz",
        index = config["out_prefix"] + ".variant_counts.txt.gz.tbi",
        log = config["out_prefix"] + ".variant_counts_log.txt"
    shell:
        "python {params[script]} "
        "--bam {input.bam} "
        "--vcf {input.vcf} "
        "--sample {params[sample]} "
        "--outprefix {params[prefix]}"
