"""
Snakemake workflow module of the nmdrdmr pipeline.

Step 5: liftover pass 2, get the set of strictly mappable elements (across all species).
"""

######################################################################################
############## CREATE A 'LENIENT' MAPPABLE REGION TABLE BETWEEN SPECIES ##############
######################################################################################

rule inverse_mapping:
    """
    Map the solved reference-coordinate elements back to each species with liftOver.

    At this point we have a list of elements in the reference coordinates
    (all_regions_after_solve.bed), in which the regulatory type of ambiguous cases has been
    resolved. Here, we map these elements back to each individual species' coordinates using
    liftOver. This is necessary because regions may have changed slightly when we merged regions
    that were previously assigned to different regulatory types.

    params:
        minMatch: liftOver -minMatch value, from config["minMatch"] (default 0.001).
        multiple: extra liftOver flag, from config["liftover_flag"] (default "-multiple").
    """
    input:
        chain = CHAIN,
        regions = LIFTOVER  + "/{reg_unit}/{tissue}/all_regions_after_solve.bed"
    output:
        regions = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map/{sps}.bed"),
        unMapped = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map/{sps}.unMapped")
    params: minMatch = config.get("minMatch", 0.001), multiple = config.get('liftover_flag', '-multiple')
    conda: "envs/ucsc_liftover.yaml"
    shell:
        "liftOver {input.regions} {input.chain} {output.regions} {output.unMapped} "
        "-minMatch={params.minMatch} {params.multiple}"


rule filter_inverse_by_coverage:
    """
    Lenient filter: retain lifted-over regions if the size of the converted region is >= a
    fraction of the original region (config["liftover_coverage_cutoff"], default 0.30, i.e. 30%).
    In case of multiple matches, the largest region is selected.

    params:
        cutoff: minimum size fraction passed to scripts/filter_by_coverage.py (-c).
        m: INV[config["liftover_flag"]]; presumably tells the script whether liftOver was run
           with -multiple (TODO confirm against INV's definition).

    The filtered output is then coordinate-sorted in place (via a "<output>_temp" file).
    """
    input: outregions = LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map/{sps}.bed",
           inregions = LIFTOVER + "/{reg_unit}/{tissue}/all_regions_after_solve.bed"
    output: regions = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map/{sps}.filtered.bed")
    conda: "envs/bedtools_py.yaml"
    params: cutoff = config.get("liftover_coverage_cutoff", 0.30), m = INV[config.get('liftover_flag', '-multiple')]
    shell:
        "python scripts/filter_by_coverage.py -q {input.inregions} -m {input.outregions} {params.m} "
        "-c {params.cutoff} -o {output} && "
        "bedtools sort -i {output} > {output}_temp && mv {output}_temp {output}"


def aggregate_inv_map(wildcards):
    """
    Collect the filtered inverse-liftover BED files of every non-reference species.

    Intended as a Snakemake input function: ``wildcards`` must expose ``reg_unit`` and
    ``tissue``. One path is returned per entry of ``species`` other than ``REF``.
    """
    other_species = [name for name in species if name != REF]
    pattern = LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map/{sps}.filtered.bed"
    return expand(pattern,
                  tissue=wildcards.tissue,
                  reg_unit=wildcards.reg_unit,
                  sps=other_species)

rule mappable_regions:
    """
    Extract regions that could be mapped back and forth in each species, using a lenient liftover
    criterion.

    Inputs are the coverage-filtered inverse liftovers of every non-reference species (one file
    per species, see aggregate_inv_map) and the solved reference regions. Both outputs are
    written by scripts/summarize_mappable_regions.py. `labels` lists the reference species
    first, followed by the other species in `species` order.
    """
    input:
        invmaps = aggregate_inv_map,
        regions = LIFTOVER  + "/{reg_unit}/{tissue}/all_regions_after_solve.bed"
    output:
        filtered = temp(LIFTOVER + "/{reg_unit}/{tissue}/all_regions_filtered.bed"),
        summary = temp(LIFTOVER + "/{reg_unit}/{tissue}/all_regions_summary.bed")
    params: labels = [REF] + [i for i in species if i !=REF]
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/summarize_mappable_regions.py -i {input.invmaps} -o {output.filtered} -r "
        "{input.regions} -s {output.summary} -l {params.labels}"


rule make_summary_file:
    """
    Make one table per tissue (and regulatory unit), summarizing the 'lenient' mappable regions.

    Combines each species' active regions (reference species first, then the others in `species`
    order) with the mappable-region summary produced by `mappable_regions`. `labels` uses the
    same species order as the `active` input files. The resulting summary_first_pass.bed is the
    table consumed by the rest of this module.
    """
    input:
        active = lambda wildcards:\
                 expand(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_after_solve.bed",
                        tissue=wildcards.tissue, reg_unit=wildcards.reg_unit,
                        sps=[REF] + [i for i in species if i !=REF]),
        map_summary = LIFTOVER + "/{reg_unit}/{tissue}/all_regions_summary.bed"
    output: LIFTOVER + "/{reg_unit}/{tissue}/summary_first_pass.bed"
    conda: "envs/bedtools_py.yaml"
    params:
        labels = [REF] + [i for i in species if i !=REF],
    shell:
        "python scripts/make_res_summary.py -r {input.active} -s {input.map_summary} -o {output} "
        "-l {params.labels}"

rule mappable_active_regions_per_species:
    """
    For each region in the set of 'lenient' mappable regions, extract those active in each
    species (i.e. with a peak).

    Column 4 of every summary_first_pass.bed line not containing "UnMapped" is written to a
    temporary pattern file. That file is then used as a fixed-string, whole-word pattern list
    (grep -Fwf) to select the matching lines of the species' active regions.

    NOTE(review): grep exits with status 1 when it selects nothing; under Snakemake's default
    strict bash mode this would presumably fail the job for a species with no mappable active
    region. Confirm whether that can happen.
    """
    input:
        mappable_regions = LIFTOVER + "/{reg_unit}/{tissue}/summary_first_pass.bed",
        species_signal = LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_after_solve.bed"
    output:
            map_active = LIFTOVER + "/{reg_unit}/{tissue}/sets_mappable_lenient/{sps}.active_regions.bed",
            t = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_mappable_{sps}_tmp.txt")
    conda: "envs/bedtools_py.yaml"
    shell:
        "grep -v UnMapped {input.mappable_regions} | cut -f 4 > {output.t}; grep -Fwf {output.t} {input.species_signal} > {output.map_active}"


######################################################################################
##############  CREATE A STRICT MAPPABLE REGION TABLE BETWEEN SPECIES  ###############
######################################################################################

def species_code(spname):
    """
    Build a short species code from a binomial name.

    The code is the first character of ``spname`` followed by the first three characters of
    its last underscore-separated token, e.g. "Mus_musculus" -> "Mmus".
    Raises IndexError for an empty name.
    """
    *_, epithet = spname.split('_')
    return f"{spname[0]}{epithet[:3]}"

rule extract_sort_mappable_regions_per_species:
    """
    Prepare data for the strict liftover: extract and sort the 'lenient' mappable regions in the
    coordinates of each species.

    Lines containing "UnMapped" are dropped from summary_first_pass.bed, and the rest is piped to
    scripts/extract_mappable_regions.py. The script must write one <sps>.bed per species (all
    species, reference included) into MAPPABLE_REG/<reg_unit>/<tissue>/.
    """
    input: LIFTOVER + "/{reg_unit}/{tissue}/summary_first_pass.bed"
    output: temp(expand(MAPPABLE_REG + "/{{reg_unit}}/{{tissue}}/{sps}.bed", sps=species))
    params: output_dir = lambda wildcards: MAPPABLE_REG + f"/{wildcards.reg_unit}/{wildcards.tissue}/"
    conda: "envs/bedtools_py.yaml"
    shell:
        "grep -v UnMapped {input} | python scripts/extract_mappable_regions.py -o {params.output_dir}"


rule liftover_for_strict_map:
    """
    First strict liftover for the reciprocal mapping definition: each species to the reference
    (mouse).

    Uses the lastZ chain "<REFASSMBL>.v.<sps>_lastz_net.all.ok.chain". Unlike inverse_mapping,
    -minMatch is fixed to 0.001 here (not read from config). The multiple-mapping flag still
    comes from config["liftover_flag"] (default "-multiple").
    """
    input:
        chain = LIFTOVER+"/lastZ/"+REFASSMBL+".v.{sps}_lastz_net.all.ok.chain",
        regions = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.bed"
    output:
        regions = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/lift.{sps}.bed"),
        unMapped = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/lift.{sps}.unMapped")
    params: minMatch = 0.001, multiple = config.get('liftover_flag', '-multiple')
    conda: "envs/ucsc_liftover.yaml"
    shell:
        "liftOver {input.regions} {input.chain} {output.regions} {output.unMapped} "
        "-minMatch={params.minMatch} {params.multiple}"

rule check_liftovered_regions:
    """
    Check the first strict liftover pass: each lifted region must map back onto the original
    reference (mouse) region defined in the lenient table (i.e. check that the mapping is
    reciprocal).

    scripts/check_rec_mapping.py writes the accepted regions to a temporary file, which is then
    coordinate-sorted with bedtools. `m` is INV[config["liftover_flag"]]; it presumably tells the
    script whether liftOver was run with -multiple (TODO confirm).
    """
    input:
        lift =  MAPPABLE_REG + "/{reg_unit}/{tissue}/lift.{sps}.bed",
        ori = LIFTOVER + "/{reg_unit}/{tissue}/summary_first_pass.bed"
    output:
        out = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/lift.ok.{sps}.bed"),
        tmp = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/lift.ok.{sps}.tmp.bed")
    params:  m = INV[config.get('liftover_flag', '-multiple')]
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/check_rec_mapping.py -l {input.lift} -i {input.ori} -o {output.tmp} {params.m} && "
        "bedtools sort -i {output.tmp} > {output.out}"

rule intersect_regions_for_strict_map:
    """
    Intersect regions from all species with a reciprocal mapping to the original reference (mouse).

    Steps:
      1. From the lenient summary table, drop lines containing "UnMapped", keep columns 1-4
         (coordinates and, as elsewhere in this module, the element ID in column 4) and drop the
         first line (presumably the header).
      2. With `bedtools multiinter`, keep only the sub-intervals covered by the reciprocally
         mapped regions of *every* non-reference species, i.e. intervals whose overlap count
         (column 4 of multiinter) equals the number of input files.
      3. Annotate those sub-intervals with the overlapping lenient-table rows (element IDs).

    params:
        n_species: number of non-reference species, i.e. the number of files given to
            multiinter. It is derived from `species` so the "present in all" filter stays
            correct for any number of species.
    """
    input:
        inf = expand(MAPPABLE_REG + "/{{reg_unit}}/{{tissue}}/lift.ok.{sps}.bed", sps=[i for i in species if i!=REF]),
        ori = LIFTOVER + "/{reg_unit}/{tissue}/summary_first_pass.bed"
    output:
        out = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/intersect_mappable_regions.bed"),
        tmp = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/tmp_intersect_mappable_regions.bed"),
        ori = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/tmp_summary.bed")
    params:
        n_species = len([i for i in species if i != REF])
    conda: "envs/bedtools_py.yaml"
    shell:
        "grep -v UnMapped {input.ori} | cut -f 1,2,3,4 | tail -n +2 > {output.ori}; "
        "bedtools multiinter -i {input.inf} | awk '$4 == {params.n_species}' | cut -f 1,2,3 > {output.tmp}; " # find overlapping in all species
        "bedtools intersect -a {output.tmp} -b {output.ori}  -wb -wa > {output.out}" # add element IDs

rule find_biggest_intersection:
    """
    Identify the largest (sub-)region that remaps to the reference (mouse) from all other
    species.

    The input holds the sub-intervals shared by all non-reference species, each followed by the
    overlapping lenient-table row (output of intersect_regions_for_strict_map). The selection
    itself is done by scripts/get_strictly_mappable.py.
    """
    input: MAPPABLE_REG + "/{reg_unit}/{tissue}/intersect_mappable_regions.bed"
    output: temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/strictly_mappable_regions_set.bed")
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/get_strictly_mappable.py -i {input} -o {output}"

rule inverse_liftover_regions_for_strict_map:
    """
    Second strict liftover pass, from the reference to the other species, for the largest
    sub-regions with a reciprocal mapping in all species.

    -minMatch is fixed to 0.30 here.
    NOTE(review): the liftOver flag defaults to '' (no -multiple) when config["liftover_flag"] is
    unset, whereas the other liftOver rules in this module default to '-multiple'. This may be
    deliberate (the downstream get_strictly_mappable_per_sp.py receives no INV flag), but it
    should be confirmed.
    """
    input:
        chain = CHAIN,
        regions =  MAPPABLE_REG + "/{reg_unit}/{tissue}/strictly_mappable_regions_set.bed"
    output:
        regions = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/liftover.strict.reg.{sps}.bed"),
        unMapped = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/liftover.{sps}.unMapped")
    params: minMatch = 0.30, multiple = config.get('liftover_flag', '')
    conda: "envs/ucsc_liftover.yaml"
    shell:"""
        liftOver {input.regions} {input.chain} {output.regions} {output.unMapped} -minMatch={params.minMatch} {params.multiple};\
        """

def get_index(r, t, species):
    """
    Return the read-coverage index file of one sample per species.

    For each species in *species*, pick the first entry of the global ``samples`` whose tissue
    is *t* and whose mark is the one associated with regulatory unit *r*. From its library name,
    build ``READ_COVERAGE + "/bams/" + <library> + "_index.txt"``. Samples lacking the species
    field are skipped.

    Args:
        r: regulatory unit, one of "Primed_enhancers", "Promoters" or "Enhancers".
        t: tissue name.
        species: iterable of species names (shadows the module-level ``species``).

    Returns:
        List of index file paths, in the same order as *species*.

    Raises:
        KeyError: if *r* is not a known regulatory unit.
        IndexError: if no sample matches a (species, tissue, mark) combination; the message
            names the missing combination.
    """
    reg_marks_dict = {"Primed_enhancers": "H3K4me1", "Promoters": "H3K27ac", "Enhancers": "H3K27ac"}
    mark = reg_marks_dict[r]  # hoisted: constant across all species
    missing = object()  # sentinel, so a falsy/None library name is not mistaken for "no match"

    inp = []
    for s in species:
        # Take the first matching sample without materialising the whole match list.
        rep = next((samples[i][lib] for i in samples
                    if sp in samples[i] and samples[i][ti] == t
                    and samples[i][ma] == mark and samples[i][sp] == s), missing)
        if rep is missing:
            raise IndexError(f"get_index: no sample found for species={s!r}, tissue={t!r}, "
                             f"mark={mark!r} (regulatory unit {r!r})")
        inp.append(READ_COVERAGE + "/bams/" + rep + "_index.txt")
    return inp

rule filter_strictly_mappable_regions_per_species:
    """
    Final filter after the two strict liftovers: check that reference sub-regions are mapped to
    regions of similar size in each species. The cutoff comes from
    config["strict_liftover_cutoff"] (default 0.85) and is presumably a minimum size ratio.

    Writes one <sps>.strict.bed per species, reference first. `labels` is a comma-separated list
    of the non-reference species, in the same order as the `b` inputs.
    """
    input:
        a = MAPPABLE_REG + "/{reg_unit}/{tissue}/strictly_mappable_regions_set.bed",
        b = expand(MAPPABLE_REG + "/{{reg_unit}}/{{tissue}}/liftover.strict.reg.{sps}.bed", sps=[i for i in species if i!=REF]),
        c = LIFTOVER + "/{reg_unit}/{tissue}/summary_first_pass.bed"
    output: temp(expand(MAPPABLE_REG + "/{{reg_unit}}/{{tissue}}/{sps}.strict.bed", sps=[REF] + [i for i in species if i!=REF]))
    params: labels = ','.join([i for i in species if i!=REF]), cutoff = config.get("strict_liftover_cutoff", 0.85)
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/get_strictly_mappable_per_sp.py -s {input.a} -l {params.labels} "
        "-i {input.b} -ref {input.c} -o {output} -c {params.cutoff}"



######################################################################################
##############  REMOVE GREYLISTED ELEMENTS AND ONE-TO-MANY REGIONS  ##################
######################################################################################


rule extract_redundant:
    """
    After the liftovers, we can still have one-to-many orthologous regions (one in ref == many in
    others).

    The strict regions are intersected with themselves, so every line overlaps at least itself.
    An element ID (column 4) reported more than once therefore either occurs on several lines or
    overlaps another region. Those IDs are written to the output.
    """
    input: MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.bed"
    output: temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.redundant.txt")
    conda: "envs/bedtools_py.yaml"
    shell: "bedtools intersect -a {input} -b {input} | cut -f 4 | sort | uniq -d > {output}"

rule all_redundant:
    """
    Concatenate all redundant regions (i.e. one-to-many in any species) into a single sorted,
    de-duplicated list of element IDs.
    """
    input: expand(MAPPABLE_REG + "/{{reg_unit}}/{{tissue}}/{sps}.redundant.txt", sps=species)
    output: temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/redundant.txt")
    shell: "cat {input} | sort | uniq > {output}"

rule greylisted_reg_elements_per_sp:
    """
    Cross mappable regions with greylists defined for each species.

    Writes the element ID (column 4) of every strict region of <sps> that overlaps the species'
    greylist for this tissue. IDs may repeat (one line per overlap); they are de-duplicated
    downstream.
    """
    input:
        grey = GREYLIST + '/results/grey_{sps}_{tissue}.bed',
        reg = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.bed"
    output: GREYLIST + '/results/greylisted_{reg_unit}_{sps}_{tissue}.txt'
    conda: "envs/bedtools_py.yaml"
    shell: "bedtools intersect -a {input.reg} -b {input.grey} | cut -f 4 > {output}"

rule all_greylisted_reg_elements:
    """
    Extract all elements greylisted in at least one species, as a sorted, de-duplicated list of
    element IDs.
    """
    input: expand(GREYLIST + '/results/greylisted_{{reg_unit}}_{sps}_{{tissue}}.txt', sps=species)
    output: GREYLIST + '/results/greylisted_{reg_unit}_{tissue}_all.txt'
    shell: "cat {input} | sort | uniq > {output}"


rule remove_greylisted_and_redundant:
    """
    Remove greylisted and redundant elements from the set of mappable regions.

    The union of redundant IDs and greylisted IDs (from any species) is written to a temporary
    file. Every line of <sps>.strict.bed containing one of those IDs as a whole word is then
    dropped (grep -vFwf).

    NOTE(review): grep exits with status 1 if every line is removed, which would presumably fail
    the job under Snakemake's strict bash mode.
    """
    input:
        grey = GREYLIST + '/results/greylisted_{reg_unit}_{tissue}_all.txt',
        red = MAPPABLE_REG + "/{reg_unit}/{tissue}/redundant.txt",
        reg = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.bed"
    output:
        out = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.ok.bed",
        tmp = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.to_remove.txt")
    shell:
        "cat {input.red} {input.grey} | sort | uniq > {output.tmp}; "
        "grep -vFwf {output.tmp} {input.reg} > {output.out}"


rule prepare_active_regions:
    """
    Prepare an updated BED file of peaks in each species.

    Keeps columns 5, 6, 7 and 4 (in that order) of the species' active_regions_after_solve.bed,
    tab-separated.
    NOTE(review): columns 5-7 are presumably the peak location in reference coordinates (it is
    later intersected with the reference strict regions) and column 4 the element ID. Confirm
    against the rule producing sets_all/.
    """
    input: LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_after_solve.bed"
    output: temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.bed")
    shell:"""
        awk 'BEGIN {{OFS="\\t"}} {{print $5, $6, $7, $4}}' {input} > {output}
    """

rule cross_with_active_regions:
    """
    Cross mappable regions with peaks, to obtain files of mappable active elements in each species.

    Reference-species strict regions (after greylist/redundancy removal) that overlap a peak of
    species <sps> are reported as-is (-wa).
    NOTE(review): without -u, a region overlapping several peaks is written once per overlap.
    Confirm that duplicates are acceptable downstream.
    """
    input:
        active = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.bed", 
        regions =  MAPPABLE_REG + "/{reg_unit}/{tissue}/"+REF+".strict.ok.bed"
    output: MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.ok.bed"
    conda: "envs/bedtools_py.yaml"
    shell:
        "bedtools intersect -a {input.regions} -b {input.active} -wa > {output}"


rule cross_with_active_regions_sp_coord:
    """
    Build files of mappable active elements in each species, in their respective coordinate
    system.

    The element IDs (column 4) of <sps>.active.ok.bed (reference coordinates) are used as
    fixed-string, whole-word patterns to select the matching lines of <sps>.strict.ok.bed.
    That file presumably holds species coordinates.
    NOTE(review): grep exits with status 1 if no ID matches (e.g. a species without mappable
    active elements), which would presumably fail the job under strict bash mode.
    """
    input:
        active = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.ok.bed", 
        regions =  MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.strict.ok.bed"
    output:
        out = MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.sp_coord.ok.bed",
        tmp = temp(MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.sp_coord_tmp")
    conda: "envs/bedtools_py.yaml"
    shell:
        "cut -f 4 {input.active} > {output.tmp} && "
        "grep -Fwf {output.tmp} {input.regions} > {output.out}"