import glob
import itertools

def consensus_and_rep(sample, col, cons_folder, species, mark, tissues, rt="LibraryID"):
    """List consensus peak files, then replicate peak files, for one species/mark.

    Args:
        sample: mapping of sample name -> per-sample metadata mapping
            (column name -> value).
        col: sequence of metadata column names; ``col[1]`` is compared with
            `species` and ``col[2]`` with `mark`.
        cons_folder: directory containing the consensus peak files.
        species: species value a sample must have in column ``col[1]``.
        mark: mark value a sample must have in column ``col[2]``.
        tissues: iterable of tissue names; one consensus file is listed per tissue.
        rt: unused; kept so existing callers remain valid.

    Returns:
        A list of paths: first one
        ``<cons_folder>/Consensus_peaks-<species>-<tissue>-<mark>.bed`` per tissue
        (in `tissues` order), then ``<sample name>.narrowPeak`` for every sample
        matching `species` and `mark`, without duplicates, in the iteration order
        of `sample`.
    """
    cons = [
        cons_folder + "/Consensus_peaks-" + species + "-" + t + "-" + mark + ".bed"
        for t in tissues
    ]

    # The replicate selection does not depend on the tissue, so it is computed once
    # from the `sample` argument (not from a module-level global).
    rep = []
    for name in sample:
        info = sample[name]
        # Samples lacking either metadata column are skipped instead of raising.
        if col[1] not in info or col[2] not in info:
            continue
        if info[col[1]] == species and info[col[2]] == mark:
            rep.append(name + '.narrowPeak')

    # dict.fromkeys removes duplicates while keeping a deterministic order,
    # unlike list(set(...)) whose order varies with string hash randomisation.
    return cons + list(dict.fromkeys(rep))

def consensus_and_reg(cons_folder, reg_folder, species, marks, tissues, reg_units):
    """List consensus peak files followed by regulatory element files for a species.

    Consensus paths are built for every (tissue, mark) pair and regulatory paths
    for every (tissue, reg_unit) pair, both in ``itertools.product`` order.

    Returns:
        list of str: ``<cons_folder>/Consensus_peaks-<species>-<tissue>-<mark>.bed``
        entries, then ``<reg_folder>/<unit>/<unit>-<species>-<tissue>.bed`` entries.
    """
    consensus_files = [
        cons_folder + "/Consensus_peaks-" + species + "-" + tissue + "-" + histone_mark + ".bed"
        for tissue, histone_mark in itertools.product(tissues, marks)
    ]
    regulatory_files = [
        reg_folder + '/' + unit + '/' + unit + '-' + species + '-' + tissue + '.bed'
        for tissue, unit in itertools.product(tissues, reg_units)
    ]
    return consensus_files + regulatory_files



# Run scripts/control_plots.py for one {species}/{mark} on the file list returned by
# consensus_and_rep (consensus peak files first, then replicate .narrowPeak files).
# The three SVG outputs are registered in the report under "1. consensus peaks".
rule plot_consensus_peaks:
    input:
        infiles = lambda wildcards: consensus_and_rep(
            samples, COLUMNS, CONSENSUS_DIR, wildcards.species, wildcards.mark, tissues
        ),
        tss = lambda wildcards: "data/TSS/TSS.biomart." + wildcards.species + '.bed'
    output:
        bar = report(
            PLOTS_DIR + "/consensus_peaks/barplot_nb_reg-{species}-{mark}.svg",
            category="1. consensus peaks"
        ),
        bx = report(
            PLOTS_DIR + "/consensus_peaks/boxplot_size_reg-{species}-{mark}.svg",
            category="1. consensus peaks"
        ),
        tss = report(
            PLOTS_DIR + "/consensus_peaks/tss_dist-{species}-{mark}.svg",
            category="1. consensus peaks"
        )
    params:
        # One label per consensus file, in the same order as `tissues`.
        labels_1 = ["consensus-" + tissue for tissue in tissues],
        # One label per remaining input file: the first len(tissues) entries of
        # `infiles` are skipped, and each remaining file name minus its last
        # extension is used as the key into `samples`.
        labels_2 = lambda wildcards, input: [
            samples[name].get("LibraryID", "") + '-' + samples[name]["Tissue"]
            for name in (
                '.'.join(path.split('.')[:-1])
                for path in input.infiles[len(tissues):]
            )
        ],
        groups = "Tissues=" + ','.join(tissues)
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/control_plots.py -i {input.infiles} "
        "-l {params.labels_1} {params.labels_2} "
        "-t '{wildcards.species} {wildcards.mark}' "
        "-o {output.bar} {output.bx} {output.tss} "
        "-tss {input.tss} --groupby {params.groups}"



# Run scripts/control_plots.py for one {species} on the file list returned by
# consensus_and_reg (consensus peak files for every tissue/mark pair, then regulatory
# element files for every tissue/REG_TYPES pair).
# The three SVG outputs are registered in the report under "2. regulatory elements".
rule plot_regulatory_elements:
    input:
        infiles = lambda wildcards: consensus_and_reg(
            CONSENSUS_DIR, REGMAP_DIR, wildcards.species, marks, tissues, REG_TYPES
        ),
        tss = lambda wildcards: "data/TSS/TSS.biomart." + wildcards.species + '.bed'
    output:
        bar = report(
            PLOTS_DIR + "/regmap/barplot_nb_reg-{species}.svg",
            category="2. regulatory elements"
        ),
        bx = report(
            PLOTS_DIR + "/regmap/boxplot_size_reg-{species}.svg",
            category="2. regulatory elements"
        ),
        tss = report(
            PLOTS_DIR + "/regmap/tss_dist-{species}.svg",
            category="2. regulatory elements"
        )
    params:
        # Labels follow the same product order as the files built by consensus_and_reg.
        labels = (
            ["consensus-" + tis + "-" + mk for tis, mk in itertools.product(tissues, marks)]
            + [tis + "-" + reg for tis, reg in itertools.product(tissues, REG_TYPES)]
        ),
        groups = "Tissues=" + ','.join(tissues)
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/control_plots.py -i {input.infiles} "
        "-l {params.labels} "
        "-t '{wildcards.species}' "
        "-o {output.bar} {output.bx} {output.tss} "
        "-tss {input.tss} --groupby {params.groups}"

def mappable_reg(reg_folder, sp, tissues, reg_units):
    """List the ``<sp>.active.sp_coord.ok.bed`` paths for every tissue/unit pair.

    Paths have the form ``<reg_folder>/<unit>/<tissue>/<sp>.active.sp_coord.ok.bed``
    and are returned in ``itertools.product(tissues, reg_units)`` order.
    """
    return [
        reg_folder + '/' + unit + '/' + tissue + '/' + sp + '.active.sp_coord.ok.bed'
        for tissue, unit in itertools.product(tissues, reg_units)
    ]


# Run scripts/control_plots.py for one {species} on the files listed by mappable_reg
# (one file per tissue/REG_TYPES pair under MAPPABLE_REG).
# The three SVG outputs are registered in the report under "4. mappable elements".
rule plot_mappable_active_regions:
    input:
        infiles = lambda wildcards: mappable_reg(
            MAPPABLE_REG, wildcards.species, tissues, REG_TYPES
        ),
        tss = lambda wildcards: "data/TSS/TSS.biomart." + wildcards.species + '.bed'
    output:
        bar = report(
            PLOTS_DIR + "/mappable_regulatory_elements/barplot_nb_reg-{species}.svg",
            category="4. mappable elements"
        ),
        bx = report(
            PLOTS_DIR + "/mappable_regulatory_elements/boxplot_size_reg-{species}.svg",
            category="4. mappable elements"
        ),
        tss = report(
            PLOTS_DIR + "/mappable_regulatory_elements/tss_dist-{species}.svg",
            category="4. mappable elements"
        )
    params:
        # Labels follow the same product order as the files built by mappable_reg.
        labels = [tis + "-" + reg for tis, reg in itertools.product(tissues, REG_TYPES)],
        groups = "Tissues=" + ','.join(tissues)
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/control_plots.py -i {input.infiles} "
        "-l {params.labels} "
        "-t '{wildcards.species}' "
        "-o {output.bar} {output.bx} {output.tss} "
        "-tss {input.tss} --groupby {params.groups}"



# Run scripts/plot_overlaps.py on the per-tissue regulatory element overlaps file.
# Both SVG outputs are passed to -o (in the order nb, size) and registered in the
# report under "3. solve regulatory type".
rule plot_overlaps:
    input:
        LIFTOVER + "/{tissue}_reg-elements_overlaps.bed"
    output:
        nb = report(
            PLOTS_DIR + "/solve_reg_type/{tissue}_overlap_numbers.svg",
            category="3. solve regulatory type"
        ),
        size = report(
            PLOTS_DIR + "/solve_reg_type/{tissue}_overlap_size.svg",
            category="3. solve regulatory type"
        )
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/plot_overlaps.py -i {input} -o {output}"

# Run scripts/plot_overlaps_solving.py on the per-tissue vote results file; the SVG
# output is registered in the report under "3. solve regulatory type".
rule plot_overlap_solver:
    input:
        LIFTOVER + "/{tissue}_vote_results_overlaps.bed"
    output:
        report(
            PLOTS_DIR + "/solve_reg_type/{tissue}_overlap_solver.svg",
            category="3. solve regulatory type"
        )
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/plot_overlaps_solving.py -i {input} -o {output}"


## TODO: add greylist plots (one Venn diagram per species?)

# rule format_for_intervene:
#     input: MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active_regions.ok.bed"
#     output: PLOTS_DIR + "/mappable_regulatory_elements/{reg_unit}/{tissue}/sets_for_upset/{sps}.active_regions_final.bed"
#     shell:
#         "cut -f 4 {input} > {output}"

def aggregate_signal(wildcards):
    """Return the ``<sps>.active.ok.bed`` paths for the rule's reg_unit and tissue.

    One path under MAPPABLE_REG is produced per entry of the module-level `species`,
    with `reg_unit` and `tissue` taken from `wildcards`.
    """
    return expand(
        MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.ok.bed",
        reg_unit=wildcards.reg_unit,
        tissue=wildcards.tissue,
        sps=species,
    )

# Colour string per regulatory-unit type; looked up by the {reg_unit} wildcard and
# passed to `intervene upset` as --sbcolor / --mbcolor.
COLOR = {
    "Enhancers": '#5ebaffff',
    "Promoters": '#ffa34aff',
    "Primed_enhancers": '#5efa59ff',
}


# Run `intervene upset` on the per-species files returned by aggregate_signal for one
# {reg_unit}/{tissue}. The declared output is Intervene_upset.svg inside the directory
# given to --output; it is registered in the report under "4. mappable elements".
rule upset_mappable_regions:
    input:
        infiles = aggregate_signal
    output:
        report(
            PLOTS_DIR + "/mappable_regulatory_elements/{reg_unit}/{tissue}/Intervene_upset.svg",
            category="4. mappable elements"
        )
    params:
        # Set names, in the same order as the input files (both follow `species`).
        labels = ",".join(species),
        # Directory that contains the declared output file.
        dir_out = lambda wildcards: os.path.dirname(
            "/".join([
                PLOTS_DIR,
                "mappable_regulatory_elements",
                wildcards.reg_unit,
                wildcards.tissue,
                "Intervene_upset.svg",
            ])
        ),
        # Same colour is used for both --sbcolor and --mbcolor.
        c = lambda wildcards: COLOR[wildcards.reg_unit]
    conda: "envs/intervene.yaml"
    shell:
        "intervene upset -i {input.infiles} --output {params.dir_out} "
        "--sbcolor '{params.c}' --mbcolor '{params.c}' "
        "--figtype svg --type list --order degree --showzero --save-overlaps "
        "--names {params.labels}"


# Run `intervene upset` for one species ({sps}) on the Promoters and Enhancers files
# under REGMAP_DIR, restricted to the first two entries of `tissues`.
# Both outputs are marked temp().
rule tissue_intervene:
    input:
        lambda wildcards: expand(
            REGMAP_DIR + "/{reg_unit}/{reg_unit}-{sps}-{tissue}.bed",
            reg_unit=["Promoters", "Enhancers"],
            tissue=list(tissues)[:2],
            sps=wildcards.sps
        )
    output:
        fig = temp(PLOTS_DIR + "/{sps}_before_mapping/Intervene_upset.svg"),
        summary = temp(PLOTS_DIR + "/{sps}_before_mapping/Intervene_upset_combinations.txt")
    params:
        # "<reg_unit>_<tissue>" names, built in the same product order as the inputs.
        labels = ','.join(
            reg + '_' + tis
            for reg, tis in itertools.product(["Promoters", "Enhancers"], list(tissues)[:2])
        ),
        # Directory that contains the declared outputs.
        dir_out = lambda wildcards: os.path.dirname(
            PLOTS_DIR + "/" + wildcards.sps + "_before_mapping/Intervene_upset.svg"
        )
    conda: "envs/intervene.yaml"
    shell:
        "intervene upset -i {input} --output {params.dir_out} "
        "--names {params.labels} "
        "--figtype svg --ninter=20 --showzero --save-overlaps"

# Run `intervene upset` for one species ({sps}) on the Promoters and Enhancers
# `.active.ok.bed` files under MAPPABLE_REG, restricted to the first two entries of
# `tissues`. Both outputs are marked temp().
rule tissue_intervene_after_map:
    input:
        lambda wildcards: expand(
            MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.ok.bed",
            reg_unit=["Promoters", "Enhancers"],
            tissue=list(tissues)[:2],
            sps=wildcards.sps
        )
    output:
        fig = temp(PLOTS_DIR + "/{sps}_after_mapping/Intervene_upset.svg"),
        summary = temp(PLOTS_DIR + "/{sps}_after_mapping/Intervene_upset_combinations.txt")
    params:
        # "<reg_unit>_<tissue>" names, built in the same product order as the inputs.
        labels = ','.join(
            reg + '_' + tis
            for reg, tis in itertools.product(["Promoters", "Enhancers"], list(tissues)[:2])
        ),
        # Directory that contains the declared outputs.
        dir_out = lambda wildcards: os.path.dirname(
            PLOTS_DIR + "/" + wildcards.sps + "_after_mapping/Intervene_upset.svg"
        )
    conda: "envs/intervene.yaml"
    shell:
        "intervene upset -i {input} --output {params.dir_out} "
        "--names {params.labels} "
        "--figtype svg --ninter=20 --showzero --save-overlaps"


# Run scripts/tissue_heatmaps.py on the "before mapping" Intervene combinations file
# for one species ({sps}); the SVG is registered in the report under
# "2. regulatory elements".
rule tissue_heatmap:
    input:
        PLOTS_DIR + "/{sps}_before_mapping/Intervene_upset_combinations.txt"
    output:
        report(
            PLOTS_DIR + "/{sps}_before_mapping/tissue_overlap.svg",
            category="2. regulatory elements"
        )
    params:
        reg = "Promoters,Enhancers",
        # Only the first two entries of `tissues` are passed to the script.
        t = ",".join(list(tissues)[:2])
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/tissue_heatmaps.py -i {input} -o {output} -r {params.reg} -t {params.t}"


# Run scripts/tissue_heatmaps.py on the "after mapping" Intervene combinations file
# for one species ({sps}); the SVG is registered in the report under
# "4. mappable elements".
rule tissue_heatmap_after_map:
    input:
        PLOTS_DIR + "/{sps}_after_mapping/Intervene_upset_combinations.txt"
    output:
        report(
            PLOTS_DIR + "/{sps}_after_mapping/tissue_overlap.svg",
            category="4. mappable elements"
        )
    params:
        reg = "Promoters,Enhancers",
        # Only the first two entries of `tissues` are passed to the script.
        t = ",".join(list(tissues)[:2])
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/tissue_heatmaps.py -i {input} -o {output} -r {params.reg} -t {params.t}"


# Run scripts/dendrogram.R on the normalized FPKM table for one {reg_unit}/{tissue};
# the PDF is registered in the report under "5. read density".
rule post_norm_dendrograms:
    input:
        READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_fpkm_normalized.csv'
    output:
        report(
            READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_clustering.pdf',
            category="5. read density"
        )
    conda: "envs/bclust.yaml"
    shell:
        "Rscript scripts/dendrogram.R -i {input} -o {output}"


# Run scripts/plot_normed_data.py on the normalized FPKM table for one
# {reg_unit}/{tissue}; the SVG is registered in the report under "5. read density".
rule plot_normalized_coverage:
    input:
        READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_fpkm_normalized.csv'
    output:
        report(
            READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_fpkm_normalized.svg',
            category="5. read density"
        )
    conda: "envs/bedtools_py.yaml"
    shell:
        "python scripts/plot_normed_data.py -i {input} -o {output}"

