
# GLOBAL VARIABLE DEFINITION FOR OUTPUT PATHS
# Peak files are read from DATADIR; every output directory lives under
# out_<jobname>/ (jobname defaults to 'nmrdmr').
DATADIR = config.get("peak_data_folder", "data/peaks")
OUTPUT_ROOT = 'out_' + config.get('jobname', 'nmrdmr')
CONSENSUS_DIR = OUTPUT_ROOT + '/consensus_peaks'
REGMAP_DIR = OUTPUT_ROOT + '/regulatory_elements'
PLOTS_DIR = OUTPUT_ROOT + '/plots'
LIFTOVER = OUTPUT_ROOT + '/liftovers'
MAPPABLE_REG = OUTPUT_ROOT + '/mappable_regulatory_elements'
READ_COVERAGE = OUTPUT_ROOT + '/coverage'
GREYLIST = OUTPUT_ROOT + '/greylists'


# Parse config params
SAMPLESHEET = config["samplesheet"]
TARGET = config.get("peak_file_column", "Peak file")
SAMPLES_COL = config["colnames"]
COLUMNS = list(SAMPLES_COL.values())
IGNORE_REP = config.get("ignore_replicate", None)
COMMENT_TAGS = config.get("comment_tags", [])
REF = config["mapping_ref_sp"]

# Default is to predict promoters, enhancers and primed enhancers.
# It can be changed in the config file, either as a comma-separated string
# ("Promoters,Enhancers") or as a YAML list.
_reg_types = config.get("reg_types", "Promoters,Enhancers,Primed_enhancers")
if isinstance(_reg_types, str):
    _reg_types = _reg_types.split(',')
# Strip surrounding spaces ("Promoters, Enhancers") and drop empty items (trailing
# comma): both would otherwise end up as malformed reg_unit values in output paths.
REG_TYPES = [str(reg).strip() for reg in _reg_types if str(reg).strip()]

# If no chain file is provided the default is to download mafs from Ensembl and
# generate chains. The default path is only built when it is needed, so that
# config["assemblies"] is not looked up here when chain files are supplied.
if "chain_files" in config:
    CHAIN = config["chain_files"]
else:
    CHAIN = LIFTOVER + "/lastZ/" + config["assemblies"][REF] + ".v.{sps}_lastz_net.all.chain"


def parse_samplesheet(samplesheet, columns, target="Peak file", input_dir="", comment_tags=None,
                      ignore=None):

    """
    Parses a tab-separated summary samplesheet into a dict of sample entries.

    The first line that is neither a comment nor blank is the header. Every following
    line is zipped against the header and handed to ``parse_line``. The non-empty
    dicts returned by ``parse_line`` are merged into one result; on duplicate keys
    the later line wins.

    Args:
        samplesheet: path to the tab-separated samplesheet.
        columns: header names to keep for each sample (forwarded to ``parse_line``).
        target: name of the column holding the peak file path.
        input_dir: directory used to build each entry key (forwarded to ``parse_line``).
        comment_tags: line prefixes marking lines to skip entirely (default: none).
        ignore: optional ``{column: values}`` filter (forwarded to ``parse_line``).

    Returns:
        dict merging every non-empty ``parse_line`` result.

    TODO: simplify input samplesheet format to simplify parsing here
    """

    # None instead of a mutable default; a tuple lets str.startswith test every
    # tag in a single call (an empty tuple never matches).
    tags = tuple(comment_tags) if comment_tags else ()

    header = None
    res = {}
    with open(samplesheet, 'r') as infile:
        for raw_line in infile:
            if raw_line.startswith(tags):
                continue

            # Only strip the right side: a full strip() would also remove a leading
            # tab and shift every field of a row whose first column is empty.
            stripped = raw_line.rstrip()
            if not stripped:
                # Blank/whitespace-only line: nothing to parse (previously it reached
                # parse_line as a one-field row and raised KeyError there).
                continue

            fields = stripped.split('\t')
            if header is None:
                header = fields
                continue

            # zip() stops at the shorter sequence, so extra trailing fields are ignored.
            line_dict = dict(zip(header, fields))
            tmp = parse_line(line_dict, columns, target, input_dir=input_dir, ignore=ignore)

            if tmp:
                res.update(tmp)
    return res



def parse_line(line_dict, columns, target, input_dir='', ignore=None):

    """
    Parses one row of the input summary table into a sample entry.

    Args:
        line_dict: mapping of header name -> field value for one samplesheet row.
        columns: header names to copy into the entry. ``columns[-2]`` is treated as
            the library id column, and ``columns[-1]`` is the only value kept for
            rows that have no peak file.
        target: name of the column holding the peak file path.
        input_dir: directory prefixed to the peak file stem (basename without
            extension) to build the entry key; when empty, the bare stem is used.
        ignore: optional ``{column: values}`` mapping; a row whose ``column`` equals
            any of ``values`` is skipped. Not applied to rows without a peak file.

    Returns:
        * ``{library_id: {columns[-1]: value}}`` for a row with no peak file but with
          both ``columns[-2]`` and ``columns[-1]`` filled in;
        * ``{key: {column: value, ..., "inputlib": id}}`` for a row with a peak file
          and every column present (``"inputlib"`` only if it could be extracted);
        * ``None`` otherwise (missing fields or an ignored row).

    FIXME: here I extract matching input control library from the peak file name, an ideal summary
    table file should contain a column for this.
    """

    if not line_dict.get(target):
        # No peak file on this row: keep only the last column's value, keyed by the
        # library id. .get() makes short rows (missing trailing fields) return None
        # instead of raising KeyError.
        last_value = line_dict.get(columns[-1])
        lib_id = line_dict.get(columns[-2]) if last_value else None
        if lib_id:
            return {lib_id: {columns[-1]: last_value}}
        return None

    if ignore:
        for key, values in ignore.items():
            if key in line_dict and any(line_dict[key] == val for val in values):
                return None

    stem = os.path.splitext(os.path.basename(line_dict[target]))[0]
    # With an empty input_dir, plain concatenation would produce '/<stem>',
    # i.e. a path at the filesystem root.
    filename = (input_dir + '/' + stem) if input_dir else stem

    if any(key not in line_dict for key in columns):
        return None
    entry = {filename: {key: line_dict[key] for key in columns}}

    # FIXME function was generalized before this part (see docstring)
    # This assumes that input lib id follows the lib id in the peak file name (sep by '_') as in nmrdmr data :/
    try:
        input_lib = filename.split(entry[filename][columns[-2]])[1].split('_')[1]
    except (IndexError, ValueError):
        # IndexError: lib id absent from the name or nothing after it;
        # ValueError: empty lib id (str.split refuses an empty separator).
        print("Unable to load the id of the input lib, the read coverage module will likely crash.")
    else:
        entry[filename]["inputlib"] = input_lib
    return entry


### WORKFLOW TARGETS
samples = parse_samplesheet(SAMPLESHEET, COLUMNS, target=TARGET, input_dir=DATADIR, ignore=IGNORE_REP,
                            comment_tags=COMMENT_TAGS)

# Samplesheet header names configured for each field.
ti, sp, ma, lib, libsize = (SAMPLES_COL[i] for i in ["tissue", "species", "mark", "lib", "libsize"])

# Only entries that carry every field read below contribute to the value sets.
# parse_line also returns single-column entries (rows without a peak file), so
# checking `sp` alone could raise KeyError on ti/ma/lib whenever `sp` happens to be
# that single column.
_peak_entries = [entry for entry in samples.values()
                 if all(col in entry for col in (sp, ti, ma, lib))]

# Distinct species, tissues, marks and libraries across the samples.
species = {entry[sp] for entry in _peak_entries}
tissues = {entry[ti] for entry in _peak_entries}
marks = {entry[ma] for entry in _peak_entries}
libraries = {entry[lib] for entry in _peak_entries}

# Global regex restrictions on the characters each wildcard may match:
#   species, s, reg_unit, reg -> letters and underscores
#   sps     -> two letter runs joined by exactly one underscore
#              (presumably genus_species-style names — confirm)
#   tissue  -> letters only; no underscore, so patterns such as
#              "{tissue}_overlap_numbers.svg" split on the '_' separator
#   mark    -> letters, digits and underscores
#   library -> letters and digits
wildcard_constraints:
    species='[A-Za-z_]+',
    sps='[A-Za-z]+_[A-Za-z]+',
    s='[A-Za-z_]+',
    tissue='[A-Za-z]+',
    mark='[A-Za-z_0-9]+',
    library='[A-Za-z0-9]+',
    reg_unit='[A-Za-z_]+',
    reg='[A-Za-z_]+'


# Plot files requested by the final `Target` rule.
PLOTS = expand(PLOTS_DIR + "/mappable_regulatory_elements/{reg_unit}/{tissue}/Intervene_upset.svg",
               reg_unit=REG_TYPES, tissue=tissues)

# Summary plot kinds, produced per species (and per mark for consensus peaks).
_SUMMARY_PLOTS = ["barplot_nb_reg", "boxplot_size_reg", "tss_dist"]

# Per-species summaries in the "mappable_regulatory_elements" and "regmap" plot folders.
for _plot_dir in ["mappable_regulatory_elements", "regmap"]:
    for _plot in _SUMMARY_PLOTS:
        PLOTS += expand(PLOTS_DIR + "/" + _plot_dir + "/" + _plot + "-{species}.svg",
                        species=species)

# Consensus-peak summaries, one per species and mark.
for _plot in _SUMMARY_PLOTS:
    PLOTS += expand(PLOTS_DIR + "/consensus_peaks/" + _plot + "-{species}-{mark}.svg",
                    species=species, mark=marks)

PLOTS += expand(PLOTS_DIR + "/solve_reg_type/{tissue}_overlap_numbers.svg", tissue=tissues)

PLOTS += expand(PLOTS_DIR + "/solve_reg_type/{tissue}_overlap_solver.svg", tissue=tissues)

# Before/after-mapping tissue-overlap plots are only requested when exactly two
# tissues are present.
if len(tissues) == 2:
    PLOTS += expand(PLOTS_DIR + "/{sps}_before_mapping/tissue_overlap.svg", sps=species)
    PLOTS += expand(PLOTS_DIR + "/{sps}_after_mapping/tissue_overlap.svg", sps=species)


# Final output. Target is the first rule in this Snakefile, so Snakemake uses it as the
# default target: requesting these files drives the pipeline (all plots, the
# mappable regulatory elements per reg type/tissue/species, greylist stats, and the
# read-coverage clustering / normalisation figures).
# NOTE(review): the read-coverage targets hard-code reg_unit=["Promoters", "Enhancers"]
# rather than REG_TYPES — confirm this is intended when reg_types is customised.
rule Target:
    input: 
        PLOTS,
        expand(MAPPABLE_REG + "/{reg_unit}/{tissue}/{sps}.active.ok.bed", tissue=tissues, sps=species, reg_unit=REG_TYPES),
        expand(GREYLIST + '/stats/stats_{sps}_{tissue}.txt', sps=species, tissue=tissues),
        expand(READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_clustering.pdf', reg_unit=["Promoters", "Enhancers"], tissue=tissues),
        expand(READ_COVERAGE + '/tables_for_eve/qnorm_for_eve/{reg_unit}_{tissue}_fpkm_normalized.svg', reg_unit=["Promoters", "Enhancers"], tissue=tissues)


# Include the pipeline modules. They are included after the globals and the Target rule
# defined above, so those names already exist when each module is parsed.
include: "module_consensus_peaks.smk"
include: "module_regulatory_elements.smk"
include: "module_liftover_to_ref.smk"
include: "module_solve_element_overlaps.smk"
include: "module_mappable_regions.smk"
include: "module_read_coverage.smk"
include: "module_quality_control_plots.smk"
include: "module_greylists.smk"