"""
Snakemake workflow module of the nmdrdmr pipeline.

Step 4 : Use majority votes to solve cases where elements are defined as different types across
species.
"""


rule reg_type_to_bed4:
    """
    Convert to bed4, as required for liftover.
    """
    input: "{regionfile}"
    output: temp("{regionfile}.bed4")
    # The first awk exits with status 1 as soon as one line does not have exactly 4 tab-separated
    # fields. In that case the file is rewritten as columns 1-3 plus the line number (NR) as the
    # 4th column; otherwise the input already has 4 columns on every line and is copied as is.
    # Doubled braces are escapes for Snakemake's shell formatting and '\\t' is a Python escape:
    # the shell receives single braces and a literal \t.
    shell:"""
        if ! awk -F '\\t' 'NF!=4{{exit 1}}' {input}; 
        then awk 'BEGIN {{OFS="\\t"}} {{print $1, $2, $3, NR}}' {input} > {output}; 
        else cp {input} {output}; fi
    """

def get_regions_per_regtype(regmap_folder, reg_folder, allsp, REF, reg, t):
    """
    List the bed4 region files of one regulatory type and tissue, one per species.

    For the reference species the original regions are taken from `reg_folder`; for every other
    species the regions already mapped to the reference are taken from `regmap_folder`.

    Args:
        regmap_folder: folder holding the '<reg>-<sp>-<t>_mappedTo_<REF>.filtered.bed' files.
        reg_folder: folder holding the per-type sub-folders with '<reg>-<sp>-<t>.bed' files.
        allsp: iterable of species names (the order is kept in the result).
        REF: name of the reference species.
        reg: regulatory type.
        t: tissue.

    Returns:
        list of str: one '.bed.bed4' path per species in `allsp`.
    """
    def _bed4_path(sp):
        # The reference species needs no liftover: use its own regions directly.
        if sp == REF:
            return f"{reg_folder}/{reg}/{reg}-{sp}-{t}.bed.bed4"
        return f"{regmap_folder}/{reg}-{sp}-{t}_mappedTo_{REF}.filtered.bed.bed4"

    return [_bed4_path(sp) for sp in allsp]

def get_signal_per_species1(regmap_folder, reg_folder, sp, REF, reg, t):
    """
    Return the path of the original region file of one species, regulatory type and tissue.

    Args:
        regmap_folder: unused, kept for signature compatibility with callers.
        reg_folder: folder holding the per-type sub-folders.
        sp: species name.
        REF: unused, kept for signature compatibility with callers.
        reg: regulatory type (also the name of the sub-folder).
        t: tissue.

    Returns:
        str: '<reg_folder>/<reg>/<reg>-<sp>-<t>.bed'.
    """
    bed_name = f"{reg}-{sp}-{t}.bed"
    return "/".join((reg_folder, reg, bed_name))


rule merge_final_regions:
    """
    Get all mapped regions after the first liftover to the reference coordinate system and merge
    all overlapping regions of the same regulatory type into a single element.
    """
    # One bed4 file per species: the reference species' own regions and, for every other species,
    # its regions mapped to the reference (see get_regions_per_regtype).
    input:
        lambda wildcards: get_regions_per_regtype(LIFTOVER, REGMAP_DIR, species, REF,
                                                  wildcards.reg_unit, wildcards.tissue)
    output: temp(LIFTOVER  + "/{reg_unit}/{tissue}/all_regions.bed")
    conda: "envs/bedtools_py.yaml"
    # Concatenate all species, sort, merge, then write a 4-column file whose 4th column is the
    # line number (NR) of the merged region.
    shell:"""
      cat {input} | bedtools sort | bedtools merge | awk 'BEGIN {{OFS="\\t"}} {{print $1, $2, $3, NR}}' > {output}
    """


rule inverse_mapping_tmp:
    """
    Liftover back to each species, so that we can already exclude regions that mapped only in one
    direction. We'll also use the results to homogenize regulatory types
    (i.e. save changes of regulatory assignment also in each species coordinates system)
    """
    input:
        chain = CHAIN,
        regions = LIFTOVER  + "/{reg_unit}/{tissue}/all_regions.bed"
    output:
        regions = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map_tmp/{sps}.bed"),
        unMapped = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map_tmp/{sps}.unMapped")
    # Both liftOver options come from the config, falling back to minMatch=0.001 and the
    # '-multiple' flag when the keys 'minMatch' / 'liftover_flag' are absent.
    params: minMatch = config.get("minMatch", 0.001), multiple = config.get('liftover_flag', '-multiple')
    conda: "envs/ucsc_liftover.yaml"
    # liftOver argument order: regions to convert, chain file, converted regions, unmapped regions.
    shell:
        "liftOver {input.regions} {input.chain} {output.regions} {output.unMapped} "
        "-minMatch={params.minMatch} {params.multiple}"


rule filter_inverse_by_coverage_tmp:
    """
    Lenient filter to retain liftovered regions if the size of the converted regions >= 30% of the original region.
    In case of multiple matches, the largest region is selected.
    """
    input: outregions = LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map_tmp/{sps}.bed",
           inregions = LIFTOVER + "/{reg_unit}/{tissue}/all_regions.bed"
    output: regions = temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_inv_map_tmp/{sps}.filtered.bed")
    conda: "envs/bedtools_py.yaml"
    # c: coverage cutoff, read from config key 'liftover_coverage_cutoff' (default 0.30, i.e. the
    #    30% quoted in the docstring).
    # m: extra option for the filter script, looked up in INV from the liftover flag.
    #    NOTE(review): INV is defined outside this module section -- presumably it translates the
    #    liftOver flag into the matching filter_by_coverage.py option; confirm at its definition.
    params: m = INV[config.get('liftover_flag', '-multiple')], c = config.get("liftover_coverage_cutoff", 0.30)
    # After filtering, the output is sorted in place: bedtools sort writes to '<output>_temp',
    # which then replaces the output file.
    shell:
        "python scripts/filter_by_coverage.py -q {input.inregions} -m {input.outregions} -c {params.c} {params.m} "
        "-o {output} && "
        "bedtools sort -i {output} > {output}_temp && mv {output}_temp {output}"


def get_regions_per_species(sp, ref, reg, ti):
    """
    Return the file holding the merged regions expressed in the coordinates of species `sp`.

    Args:
        sp: species name.
        ref: name of the reference species.
        reg: regulatory type.
        ti: tissue.

    Returns:
        str: for the reference species, the merged regions ('all_regions.bed'); for any other
        species, the coverage-filtered inverse liftover of these regions
        ('sets_inv_map_tmp/<sp>.filtered.bed').
    """
    # The merged regions are already in reference coordinates: no inverse mapping needed.
    if sp == ref:
        return LIFTOVER + f"/{reg}/{ti}/all_regions.bed"
    # Always return a plain string (a stray trailing comma used to turn this into a 1-tuple).
    return LIFTOVER + f"/{reg}/{ti}/sets_inv_map_tmp/{sp}.filtered.bed"


rule all_active_regions_per_species:
    """
    Use bedtools intersect to cross first-pass "lenient" mappable regions with peaks previously
    defined in each species.
    """
    input:
        # Merged regions in the coordinates of {sps}: all_regions.bed for the reference species,
        # the filtered inverse mapping otherwise (see get_regions_per_species).
        all_regions = lambda wildcards: get_regions_per_species(wildcards.sps, REF,
                                                                  wildcards.reg_unit,
                                                                  wildcards.tissue),
        # Original regions of {sps} for this regulatory type and tissue.
        species_signal = lambda wildcards: get_signal_per_species1(LIFTOVER,
                                                                  REGMAP_DIR,
                                                                  wildcards.sps, REF,
                                                                  wildcards.reg_unit,
                                                                  wildcards.tissue)
    output: temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_tmp.bed")
    conda: "envs/bedtools_py.yaml"
    # -wa -wb: for each overlap, write the entry of A (merged region) followed by the entry of B
    # (species peak).
    shell:  
        "bedtools intersect -wa -wb -a {input.all_regions} -b {input.species_signal} > {output}"

rule all_active_regions_per_species_to_ref:
    """
    Convert mappable peaks to the reference coordinate system.
    """
    input:
        # Regions overlapping a peak in {sps}, as produced by all_active_regions_per_species.
        sp_regions = LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_tmp.bed",
        # Merged regions in reference coordinates, passed to the script as the mapping (-m).
        mapping = LIFTOVER + "/{reg_unit}/{tissue}/all_regions.bed"
    output: temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions.bed")
    conda: "envs/bedtools_py.yaml"
    # The script writes to stdout; its output is sorted before being saved.
    shell:  
        "python scripts/active_to_ref.py -i {input.sp_regions} -m {input.mapping} | bedtools sort > {output}"

rule clean_regions:
    """
    We used lenient liftover criteria and some regions won't map reciprocally, so in some cases
    we'll end up with regions corresponding to 0 peak --> to remove.
    """
    # active: the active-region files of all species for this regulatory type and tissue.
    input: mapping = LIFTOVER + "/{reg_unit}/{tissue}/all_regions.bed",
           active = lambda wildcards: expand(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions.bed",
                            tissue=wildcards.tissue, reg_unit=wildcards.reg_unit, sps=species),
    # tmp: concatenation of all species' active regions, sorted; written as a declared output.
    output: out = LIFTOVER + "/{reg_unit}/{tissue}/all_regions.clean.bed", tmp = LIFTOVER + "/{reg_unit}/{tissue}/actives_all_tmp.bed"
    conda: "envs/bedtools_py.yaml"
    # Keep the merged regions (-wa) matched by at least one active region. With -f 1 -F 1 the
    # overlap must span 100% of both the A and the B interval (bedtools intersect semantics).
    # `uniq` drops the duplicated lines produced when a region is matched more than once.
    shell: "cat {input.active} | bedtools sort > {output.tmp} && "
           "bedtools intersect -a {input.mapping} -b {output.tmp} -F 1 -f 1 -wa | bedtools sort | uniq > {output.out}"

rule extract_overlaps:
    """
    Identify regulatory regions identified as different regulatory types across species.
    Find all overlaps across the 3 sets in the reference coordinate system :
        A-B, B-C, A-C and A-B-C (A=Prom, B=Enh, C=Primed). 
    """
    # One cleaned region file per regulatory type in REG_TYPES, for the requested tissue.
    input: lambda wildcards: expand(LIFTOVER + "/{reg_unit}/{tissue}/all_regions.clean.bed", tissue=wildcards.tissue, reg_unit=REG_TYPES)
    output:
        out = temp(LIFTOVER + "/{tissue}_reg-elements_overlaps.bed"),
        tmp1 = temp(LIFTOVER + "/{tissue}_reg-elements_overlaps_tmp.bed"),
        tmp2 = temp(LIFTOVER + "/{tissue}_reg-elements_overlaps_tmp2.bed")
    conda: "envs/bedtools_py.yaml"
    # The labels are given to `multiinter -names` in REG_TYPES order, i.e. the same order as the
    # input files.
    params: labels = REG_TYPES
    # Three steps:
    #  1. tmp1: intervals kept by `awk '$4>1'` (4th column of the multiinter output above 1, which
    #     per the bedtools documentation is the number of input files covering the interval),
    #     merged, with the 5th column (labels) collapsed into a list.
    #  2. tmp2: tmp1 and all input regions concatenated, sorted and merged.
    #  3. out: tmp2 regions intersected with tmp1, merged again while collapsing column 7 of the
    #     intersect output (the labels coming from tmp1).
    shell: "bedtools multiinter -i {input} -names {params.labels} | awk '$4>1' | bedtools merge -c 5 -o collapse > {output.tmp1}; " #all intersection between the 3 files with reg type labels
           "cat {output.tmp1} {input} | bedtools sort | bedtools merge  > {output.tmp2}; " #merge all overlapping regions
           "bedtools intersect -a {output.tmp2} -b {output.tmp1} -wb -wa | bedtools merge -c 7 -o collapse > {output.out}" #propagate reg-types labels



rule full_overlap_summary:
    """
    Report original regions for overlapping data (to make a complete summary of overlaps).
    """
    input:
        # Overlapping regions across regulatory types, with their type labels.
        overlap =  LIFTOVER + "/{tissue}_reg-elements_overlaps.bed",
        # Cleaned regions of one regulatory type.
        elem = LIFTOVER + "/{reg_unit}/{tissue}/all_regions.clean.bed"
    output: temp(LIFTOVER + "/{reg_unit}/{tissue}/overlaps.bed")
    conda: "envs/bedtools_py.yaml"
    # -loj (left outer join): every region of `elem` is reported, including those without any
    # overlap in `overlap` (bedtools intersect semantics).
    shell: "bedtools intersect -a {input.elem} -b {input.overlap} -loj > {output}"



rule solve_overlaps_majority_vote:
    """
    Use a majority vote to decide element type
    """
    input:
        overlaps = LIFTOVER + "/{tissue}_reg-elements_overlaps.bed",
        # Active regions of every species for every regulatory type, for this tissue.
        actives = lambda wildcards:\
                  expand(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions.bed",
                         tissue=wildcards.tissue, reg_unit=REG_TYPES, sps=species),
        # Per-type overlap summaries produced by full_overlap_summary.
        full_overlap_summary = lambda wildcards: expand(LIFTOVER + "/{reg_unit}/{tissue}/overlaps.bed", tissue=wildcards.tissue, reg_unit=REG_TYPES)
    output: temp(LIFTOVER + "/{tissue}_vote_results_overlaps.bed")
    # Regulatory type labels and species names are forwarded to the script (-names / -sp).
    params: labels = REG_TYPES, species=species
    conda: "envs/bedtools_py.yaml"
    shell: "python scripts/resolve_overlap_majority_rule.py -over {input.overlaps} -i {input.actives} -oi {input.full_overlap_summary} -o {output} -names {params.labels} -sp {params.species}"


rule update_active_region_type:
    """
    Store individual regions defined as active in >= 1 species, with a reassigned element type.
    """
    input:
        vote = LIFTOVER + "/{tissue}_vote_results_overlaps.bed",
        # Active regions of species {sps} for every regulatory type in REG_TYPES (not only
        # {reg_unit}).
        ori = lambda wildcards:\
              expand(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions.bed",
                     tissue=wildcards.tissue, reg_unit=REG_TYPES, sps=wildcards.sps)
    output: temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_ftmp.bed")
    # reg: the regulatory type requested for this output (-e); labels: all types (-l).
    params: labels = REG_TYPES, reg = lambda wildcards: wildcards.reg_unit
    conda: "envs/bedtools_py.yaml"
    # The script output is sorted, then reduced to columns 1-3 plus the line number (NR).
    shell: """
    python scripts/update_region_type.py -i {input.ori} -v {input.vote} -e {params.reg} -l {params.labels} | bedtools sort | awk 'BEGIN {{OFS="\\t"}} {{print $1, $2, $3, NR}}' > {output}
    """


rule update_region_type:
    """
    Extract and merge new promoters, enhancers, primed_enhancers + write temporary IDs (Promoter_Heart_X etc)
    """
    input:
        vote = LIFTOVER + "/{tissue}_vote_results_overlaps.bed",
        ori = LIFTOVER + "/{reg_unit}/{tissue}/all_regions.clean.bed"
    output: temp(LIFTOVER + "/{reg_unit}/{tissue}/all_regions_after_solve.bed")
    # reg_name: the regulatory type with '_' replaced by '-', used as the prefix of the IDs.
    params: reg = lambda wildcards: wildcards.reg_unit, reg_name = lambda wildcards: wildcards.reg_unit.replace("_", '-')
    conda: "envs/bedtools_py.yaml"
    # The script output is sorted and merged; the 4th column is the temporary ID
    # '<reg_name>_<tissue>_<line number>'.
    shell: """
    python scripts/update_region_type.py -i {input.ori} -v {input.vote} -e {params.reg} --regions_only | bedtools sort | bedtools merge | awk 'BEGIN {{OFS="\\t"}} {{print $1, $2, $3, "{params.reg_name}_{wildcards.tissue}_"NR}}'> {output}
    """


rule add_ids_to_sp_element:
    """
    Add temporary IDs to regulatory elements.
    """
    input:
        # Per-species active regions with reassigned type (from update_active_region_type).
        sp = LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_ftmp.bed",
        # Merged regions carrying the temporary IDs (from update_region_type).
        summary = LIFTOVER + "/{reg_unit}/{tissue}/all_regions_after_solve.bed"
    output: temp(LIFTOVER + "/{reg_unit}/{tissue}/sets_all/{sps}.active_regions_after_solve.bed")
    conda: "envs/bedtools_py.yaml"
    # -wa -wb: each output line holds the ID-carrying region (A) followed by the overlapping
    # species region (B).
    shell: "bedtools intersect -wa -wb -a {input.summary} -b {input.sp} > {output}"