"""
Snakemake workflow module of the nmdrdmr pipeline.

Step 6: extract and normalize read coverage in mappable regulatory elements.
"""


import glob
import os

import pandas as pd

# Root folder holding one sub-folder per library, each with that library's .bam
# (the layout get_bam() globs). Trailing slashes are stripped so paths can be
# built as f"{BAM_DIR}/{lib}/..."; a bare "/" is kept as "/".
BAM_DIR = config.get("bam_folder", "../data")
if not BAM_DIR:
    # Fail early with a clear message rather than building paths from nothing.
    raise ValueError("config 'bam_folder' must be a non-empty path")
BAM_DIR = BAM_DIR.rstrip('/') or '/'

def get_bam(folder, lib):
    """
    Locate the .bam file of library ``lib`` inside ``{folder}/{lib}/``.

    Exactly one ``*.bam`` file is expected in that folder.

    Returns:
        The path of that file (with any ``//`` collapsed to ``/``) when exactly
        one match exists. Otherwise a warning is written to stderr and the
        sorted list of matches (empty, or several paths) is returned, so
        Snakemake input/params functions receive that list instead of a path.
    """
    import sys

    pattern = f"{folder}/{lib}/*.bam"
    # Sorted so that the result (and the warning text) is deterministic.
    matches = sorted(glob.glob(pattern))

    if len(matches) != 1:
        # stderr, not stdout: stdout is consumed by e.g. `snakemake --dag`.
        if not matches:
            print(f"Warning: Cannot find .bam file expected in {folder}/{lib}/", file=sys.stderr)
        else:
            print(
                f"Warning: found {len(matches)} .bam files in {folder}/{lib}/, "
                "expected exactly one: " + ", ".join(matches),
                file=sys.stderr,
            )
        return matches

    return matches[0].replace('//', '/')



if config.get('bam_download', ''):

    rule download_bams:
        """
        Download bam files.

        {input} is the file named by config['bam_download'], read by `wget -i`
        as a list of URLs. Files are fetched into BAM_DIR (existing ones are
        skipped, `wget -nc`); every .zip in BAM_DIR is then extracted into
        BAM_DIR and deleted. The protected flag file marks completion.
        NOTE(review): assumes the downloads are .zip archives; if none are,
        the unzip step presumably fails — confirm.
        """
        input: config['bam_download']
        output: protected(touch(READ_COVERAGE + '/.bamDownload_ok_tmp'))
        params: bamdir = BAM_DIR
        shell: "mkdir -p {params.bamdir}/ && wget -nc -i {input} -P {params.bamdir}/ && unzip '{params.bamdir}/*.zip' -d {params.bamdir}/ && rm {params.bamdir}/*.zip"

    # Alternative code for another input type (a file with two columns: lib, url):
    # wget each line into the corresponding lib folder.
    #"bash scripts/download_bams.sh {input} {params.bamdir}"

    rule organize_bams:
        """
        Run scripts/move_bams.sh on BAM_DIR once the download flag exists, then
        write the flag used by the sorting rules. Presumably the script moves
        each bam into a per-library sub-folder, the layout get_bam() globs —
        TODO confirm.
        """
        input: READ_COVERAGE + '/.bamDownload_ok_tmp'
        output: touch(READ_COVERAGE  + '/.bamDownload_ok')
        params: bamdir = BAM_DIR
        shell: 'bash scripts/move_bams.sh {params.bamdir}'

else:

    rule pass_download:
        """
        No download configured: only create the flag file, so downstream rules
        use the bams already present in BAM_DIR.
        """
        output: touch(READ_COVERAGE + '/.bamDownload_ok')

if not config.get("sorted_bam", False):

    rule sort_bams:
        """
        Sort input bams to use BAMscale cov to extract read coverage.
        """
        input: c = READ_COVERAGE + '/.bamDownload_ok'
        output:
            bam = READ_COVERAGE + "/bams/{library}.sorted.bam",
        conda: "envs/sambamba.yaml"
        params: bam = lambda wildcards: get_bam(BAM_DIR, wildcards.library)
        threads: 1
        shell:
            "sambamba sort {params.bam} -t {threads} --memory-limit=1G -o {output.bam} --tmpdir tmp && "
            "sambamba index -t {threads} {output.bam}"
else:

    rule pass_sort_bams:
        """
        Do not sort bams if user specifies in config that bams are sorted already.

        The user's bam is symlinked into bams/ with an absolute target: a
        relative target (e.g. the default '../data/...') would be resolved
        relative to bams/ and give a dangling link. The link is then indexed
        next to itself, as sort_bams does, so downstream tools find a .bai.
        """
        input: bam = lambda wildcards: get_bam(BAM_DIR, wildcards.library), c = READ_COVERAGE + '/.bamDownload_ok'
        output: bam = READ_COVERAGE  + "/bams/{library}.sorted.bam"
        conda: "envs/sambamba.yaml"
        params: target = lambda wildcards, input: os.path.abspath(input.bam)
        threads: 1
        shell:
            "ln -s {params.target} {output.bam} && "
            "sambamba index -t {threads} {output.bam}"


rule index_bam_for_coverage:
    """
    Write the reference sequence names and lengths of a sorted bam (columns 1-2
    of `samtools idxstats`) to {library}_index.txt. Despite the rule name, this
    does not create a .bai; the list is used by prepare_bed_for_coverage to sort
    beds in the bam's contig order for BAMscale cov.
    The last idxstats line is dropped (`sed '$ d'`); presumably this is the "*"
    unmapped-reads line — confirm.
    """
    input:
        bam = READ_COVERAGE  + "/bams/{library}.sorted.bam",
    output: READ_COVERAGE + "/bams/{library}_index.txt"
    conda: "envs/samtools.yaml"
    shell:
        "samtools idxstats {input} | cut -f 1-2 | sed -e '$ d' > {output}"


rule prepare_bed_for_coverage:
    """
    Sort regulatory regions bed files using the bam index, to use BAMscale cov.

    The mappable elements of {sps}/{tissue} are sorted with `bedtools sort
    -faidx` in the contig order listed for {library}, presumably so the bed
    order matches the bam header as BAMscale cov expects. The output is
    temporary and is deleted once its consumers have run.
    """
    input:
        bed = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.ok.bed",
        index = READ_COVERAGE + "/bams/{library}_index.txt"
    output: temp(READ_COVERAGE + "/{reg_unit}/{tissue}/{sps}_sorted_for_{library}.bed")
    conda: "envs/bedtools_py.yaml"
    shell:
        "bedtools sort -faidx {input.index} -i {input.bed} > {output}"


rule reads_coverage:
    """
    Extract read coverage in regulatory elements.

    Runs BAMscale cov on the signal library {library} over the sorted element
    bed. Greylisted regions of {sps}/{tissue} are passed via --bedsubtract.
    Output tables are named '{sps}.{library}' (-n) in the reg_unit/tissue folder.
    NOTE(review): -q 13 is presumably a minimum mapping-quality filter — confirm
    against the BAMscale documentation.
    """
    input:
        bed = READ_COVERAGE + "/{reg_unit}/{tissue}/{sps}_sorted_for_{library}.bed",
        bam =  READ_COVERAGE  + "/bams/{library}.sorted.bam",
        index = READ_COVERAGE + "/bams/{library}_index.txt",
        grey = GREYLIST + '/results/grey_{sps}_{tissue}.bed'
    output:
        READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{library}.FPKM_normalized_coverages.tsv",
        READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{library}.raw_coverages.tsv"
    conda: "envs/bamcov.yaml"
    threads: 5
    params: odir = READ_COVERAGE
    shell:
        "BAMscale cov -q 13 --bed {input.bed} -o '{params.odir}/bamscale_cov/{wildcards.reg_unit}/{wildcards.tissue}' "
        "--bam {input.bam} -t {threads} -n '{wildcards.sps}.{wildcards.library}' --bedsubtract {input.grey}"


rule input_reads_coverage:
    """
    Extract input read coverage in regulatory elements.

    Same BAMscale cov call as reads_coverage, run on the control library
    {control_lib}. {library} appears only in the output name
    ('{sps}.{control_lib}.input_for_{library}'), which records the signal
    library this control is paired with.
    """
    input:
        bed = READ_COVERAGE + "/{reg_unit}/{tissue}/{sps}_sorted_for_{control_lib}.bed",
        bam_input = READ_COVERAGE  + "/bams/{control_lib}.sorted.bam",
        index = READ_COVERAGE  + "/bams/{control_lib}_index.txt",
        grey = GREYLIST + '/results/grey_{sps}_{tissue}.bed'
    output:
        READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{control_lib}.input_for_{library}.FPKM_normalized_coverages.tsv",
        READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{control_lib}.input_for_{library}.raw_coverages.tsv"
    conda: "envs/bamcov.yaml"
    params: odir = READ_COVERAGE
    threads: 4
    shell:
        "BAMscale cov -q 13 --bed {input.bed} -o '{params.odir}/bamscale_cov/{wildcards.reg_unit}/{wildcards.tissue}' "
        "--bam {input.bam_input} -t {threads} -n '{wildcards.sps}.{wildcards.control_lib}.input_for_{wildcards.library}' --bedsubtract {input.grey}"


rule format_reads_coverage:
    """
    Transform BAMscale cov output (tsv) to .bed.

    Every ':' and '-' is replaced by a tab and the header line is dropped.
    This assumes regions are encoded as chrom:start-end in the table.
    NOTE(review): contig names that contain ':' or '-' would be split as well.
    """
    input: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.FPKM_normalized_coverages.tsv"
    output: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.FPKM_normalized_coverages.bed"
    shell:
        "sed 's/[:-]/\t/g' {input} | tail -n +2 > {output}"


rule format_reads_coverage2:
    """
    Put elements IDs back in BAMscale cov FPKM output.

    Coverage regions are intersected with the mappable elements, requiring 100%
    reciprocal overlap (-f 1 -r). Columns 1-4 of the coverage bed and column 8
    are kept. Column 8 is presumably the reference element ID (the 4th ref
    column, assuming a 4-column coverage bed).
    NOTE(review): the '_0.85' in the output name does not match the -f 1
    overlap fraction used here.
    """
    input:
        rcov = READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.FPKM_normalized_coverages.bed",
        ref = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.ok.bed"
    output: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.FPKM_normalized_coverages_withIDs_0.85.bed"
    conda: "envs/bedtools_py.yaml"
    shell:
        "bedtools intersect -f 1 -r -a {input.rcov} -b {input.ref} -wb | cut -f 1,2,3,4,8 > {output}"



def get_cov_per_elem_per_tissue(r, t, species, ctrl=False):
    """
    Collect the per-library FPKM coverage beds (with element IDs) that feed one
    regulatory-unit / tissue table.

    Args:
        r: regulatory unit name; must be a key of the local unit -> mark map
           ("Promoters" or "Enhancers"). Otherwise a KeyError is raised as
           soon as a sample carrying the species key and this tissue is met.
        t: tissue name.
        species: iterable of species names; results are grouped in this order.
        ctrl: False -> files of the signal libraries; True -> files of each
              library's control ("inputlib") computed for that library.

    Returns:
        List of paths under READ_COVERAGE/bamscale_cov/{r}/{t}/.
    """
    unit_to_mark = {"Promoters": "H3K27ac", "Enhancers": "H3K27ac"}
    paths = []

    for species_name in species:
        # (library id, sample key) for every sample of this species/tissue/mark.
        # `sp`, `ti`, `ma`, `lib` are defined elsewhere in the workflow
        # (presumably sample-table column names).
        matched = []
        for sample_key in samples:
            info = samples[sample_key]
            is_match = (
                sp in info
                and info[ti] == t
                and info[ma] == unit_to_mark[r]
                and info[sp] == species_name
            )
            if is_match:
                matched.append((info[lib], sample_key))

        for library, sample_key in matched:
            if ctrl:
                control_lib = samples[sample_key]["inputlib"]
                tag = f"{control_lib}.input_for_{library}"
            else:
                tag = library
            paths.append(
                READ_COVERAGE
                + f"/bamscale_cov/{r}/{t}/{species_name}.{tag}.FPKM_normalized_coverages_withIDs_0.85.bed"
            )

    return paths


rule read_coverage_table_signal:
    """
    Aggregate read coverage data (signal and input) from all species and replicates for a given
    Regulatory_element - Tissue combination.

    Inputs are the signal and control bed lists built by
    get_cov_per_elem_per_tissue over the workflow-level `species`.
    scripts/fpkm_table.py is called with the output prefix
    '{reg_unit}_{tissue}_'. It is expected to write the two csv tables and the
    two svg plots declared below; the plots are included in the report.
    """
    input: signal = lambda wildcards: get_cov_per_elem_per_tissue(wildcards.reg_unit, wildcards.tissue, species),
           control = lambda wildcards: get_cov_per_elem_per_tissue(wildcards.reg_unit, wildcards.tissue, species, ctrl=True)
    output:
        table1 = READ_COVERAGE + '/tables_for_eve/raw/{reg_unit}_{tissue}_fpkm_signal.csv',
        table2 = READ_COVERAGE + '/tables_for_eve/raw/{reg_unit}_{tissue}_fpkm_input.csv',
        plot1 = report(READ_COVERAGE + '/tables_for_eve/raw/{reg_unit}_{tissue}_fpkm_signal.svg', category="5. read density"),
        plot2 = report(READ_COVERAGE + '/tables_for_eve/raw/{reg_unit}_{tissue}_fpkm_input.svg', category="5. read density")
    params: prefix = lambda wildcards: READ_COVERAGE + '/tables_for_eve/raw/' +wildcards.reg_unit +'_' + wildcards.tissue +'_'
    conda: 'envs/bedtools_py.yaml'
    shell: "python scripts/fpkm_table.py -s {input.signal} -c {input.control} -p {params.prefix}"


rule qnormalize_read_density_for_eve:
    """
    Normalize read coverage across species and output a table which can be used for phylogenetic
    modelling under the eve model.
    Normalization is quantile_normalization of log2(signal_fpkm/input_fpkm)

    The computation is delegated to scripts/quantile_normalize_read_cov.R,
    which reads the raw signal and input FPKM tables of one reg_unit/tissue.
    """
    input:
        signal = READ_COVERAGE + '/tables_for_eve/raw/{reg_unit}_{tissue}_fpkm_signal.csv',
        control = READ_COVERAGE + '/tables_for_eve/raw/{reg_unit}_{tissue}_fpkm_input.csv'
    output: READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_fpkm_normalized.csv'
    conda: "envs/quantnorm.yaml"
    shell: "Rscript scripts/quantile_normalize_read_cov.R -s {input.signal} -c {input.control} -o {output}"



#################### ALTERNATIVE NORMALIZATION, NOT USED BY DEFAULT ####################

# TODO: add a config option to use this DiffBind-like normalization (input-subtracted raw counts) instead of the default.

rule format_raw_reads_coverage:
    """
    Transform BAMscale cov raw count output (tsv) to .bed.

    Same transformation as for the FPKM tables: every ':' and '-' becomes a tab
    (assumes chrom:start-end region names), and the header line is dropped.
    """
    input: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.raw_coverages.tsv"
    output: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.raw_coverages.bed"
    shell:
        "sed 's/[:-]/\t/g' {input} | tail -n +2 > {output}"


rule format_raw_reads_coverage2:
    """
    Put elements IDs back in BAMscale cov raw count output.

    Keeps regions with 100% reciprocal overlap (-f 1 -r) with the mappable
    elements. The output is columns 1-4 plus column 8, presumably the reference
    element ID. NOTE(review): the '_0.85' in the file name does not match -f 1.
    """
    input:
        rcov = READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.raw_coverages.bed",
        ref = MAPPABLE_REG +"/{reg_unit}/{tissue}/{sps}.strict.ok.bed"
    output: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{tag}.raw_coverages_withIDs_0.85.bed"
    conda: "envs/bedtools_py.yaml"
    shell:
        "bedtools intersect -f 1 -r -a {input.rcov} -b {input.ref} -wb | cut -f 1,2,3,4,8 > {output}"


rule normalize_raw_coverage:
    """
    Normalize raw count coverage with input library (subtract input).
    This is not the normalization used by default (the rule will not be run).
    It could be added as an option in the config file.

    Signal and input tables are inner-joined on element id, so elements missing
    from either table are dropped. Coordinates come from the signal table, and
    negative differences are clipped to 0.
    Output (tab-separated, no header): ch, st, end, id, input-subtracted count.
    """
    input: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{library}.raw_coverages_withIDs_0.85.bed",
           READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{control_lib}.input_for_{library}.raw_coverages_withIDs_0.85.bed"

    output: READ_COVERAGE + "/bamscale_cov/{reg_unit}/{tissue}/{sps}.{library}.minus.{control_lib}.raw.bed"
    run:
        import pandas as pd

        coords = ["ch", "st", "end"]
        signal = pd.read_csv(input[0], sep='\t', names=coords + ["count_s", "id"])
        control = pd.read_csv(input[1], sep='\t', names=coords + ["count_i", "id"])
        df = signal.merge(control.drop(columns=coords), on='id')
        # Clip in a single assignment: a chained `df["normed"][mask] = 0` is a
        # silent no-op under pandas Copy-on-Write, leaving negative values.
        df["normed"] = (df["count_s"] - df["count_i"]).clip(lower=0)
        df = df.drop(columns=["count_s", "count_i"])
        df.to_csv(output[0], sep='\t', header=False, index=False)

