import glob


def get_length(chromsfile, chrom):
    """
    Return the length of chromosome *chrom* as listed in *chromsfile*.

    chromsfile: tab-separated text file whose first column is a chromosome
        name and whose second column is its length as an integer
        (presumably a ``chrom.sizes``-style file — confirm with callers).
    chrom: chromosome name to look up; the first matching line wins.

    Raises ValueError if no line of *chromsfile* names *chrom*.
    """
    with open(chromsfile, 'r') as infile:
        for line in infile:
            fields = line.strip().split('\t')
            if fields[0] == chrom:
                return int(fields[1])
    # Previously this fell through to an unbound local (UnboundLocalError);
    # report the actual problem instead.
    raise ValueError(f"Chromosome {chrom!r} not found in {chromsfile}")

def get_all_inputs_for_sp(s, t):
    """
    List the chipseq-greylist BED files of every input library of one species and tissue.

    s: value the sample's ``sp`` field must equal (species).
    t: value the sample's ``ti`` field must equal (tissue).

    Relies on names defined outside this function: ``samples`` (a mapping
    whose values are per-sample dicts), ``sp``/``ti`` (the field keys),
    ``GREYLIST`` and Snakemake's ``expand``. Samples without an
    ``"inputlib"`` entry, or without the ``sp`` field, are ignored.

    Returns one path per distinct input library, ordered by library name.
    Sorting matters: iterating a set of strings gives a different order in
    each Python process (hash randomization), which would make the file
    list non-deterministic between runs.
    """
    # NOTE: a chromatin-mark filter (samples[i][ma].lower() == "h3k27ac")
    # was once applied here and is currently disabled.
    libraries = sorted({
        samples[i]["inputlib"]
        for i in samples
        if "inputlib" in samples[i]
        and sp in samples[i]
        and samples[i][sp] == s
        and samples[i][ti] == t
    })
    return expand(GREYLIST + "/chipseq-greylist/{library}.sorted-grey.bed", library=libraries)

def get_all_stats(s, t):
    """
    List the per-library stats files of every input library of one species and tissue.

    s: value the sample's ``sp`` field must equal (species).
    t: value the sample's ``ti`` field must equal (tissue).

    Relies on names defined outside this function: ``samples`` (a mapping
    whose values are per-sample dicts), ``sp``/``ti`` (the field keys),
    ``GREYLIST`` and Snakemake's ``expand``. Samples without an
    ``"inputlib"`` entry, or without the ``sp`` field, are ignored.

    Returns one path per distinct input library, ordered by library name.
    The order decides the row order of the species/tissue summary table.
    A plain set would give a different order in each Python process (hash
    randomization), so the table would not be reproducible between runs.
    """
    # NOTE: a chromatin-mark filter (samples[i][ma].lower() == "h3k27ac")
    # was once applied here and is currently disabled.
    libraries = sorted({
        samples[i]["inputlib"]
        for i in samples
        if "inputlib" in samples[i]
        and sp in samples[i]
        and samples[i][sp] == s
        and samples[i][ti] == t
    })
    return expand(GREYLIST + "/stats/stats_{library}.txt", library=libraries)

def get_all_inputs_for_sp_except(lib):
    """
    List the greylist BED files of the *other* input libraries sharing lib's species and tissue.

    lib: an input-library name, as stored in a sample's ``"inputlib"`` entry.

    The species and tissue come from the first sample (in ``samples``
    iteration order) whose ``"inputlib"`` is *lib*. The result covers every
    other input library whose samples have that same species and tissue.
    It is ordered by library name so the list is deterministic between runs.
    It may be empty when *lib* is the only input library of its group.

    Relies on ``samples``, ``sp``, ``ti``, ``GREYLIST`` and Snakemake's
    ``expand``, all defined outside this function.

    Raises ValueError if no sample uses *lib* as its input library.
    Raises KeyError if that sample lacks the ``sp`` or ``ti`` field.
    """
    # Single pass to find the sample that owns this input library.
    owner = next(
        (samples[i] for i in samples
         if "inputlib" in samples[i] and samples[i]["inputlib"] == lib),
        None,
    )
    if owner is None:
        raise ValueError(f"No sample uses {lib!r} as its input library")
    s, t = owner[sp], owner[ti]

    # NOTE: a chromatin-mark filter (samples[i][ma].lower() == "h3k27ac")
    # was once applied here and is currently disabled.
    libraries = sorted({
        samples[i]["inputlib"]
        for i in samples
        if "inputlib" in samples[i]
        and samples[i]["inputlib"] != lib
        and sp in samples[i]
        and samples[i][sp] == s
        and samples[i][ti] == t
    })
    return expand(GREYLIST + "/chipseq-greylist/{library}.sorted-grey.bed", library=libraries)


rule greylist_from_each_replicate:
    """
    Run chipseq-greylist on one input library's sorted BAM (one job per {library}).

    Writes the greylist BED, a greystats CSV and a greydepth TSV into
    GREYLIST/chipseq-greylist (passed as --outdir). Uses 500 bootstraps.
    The tool's stdout and stderr both go to the per-library log (bash `&>`).
    """
    input: READ_COVERAGE + "/bams/{library}.sorted.bam"
    # NOTE(review): these file names look derived by chipseq-greylist from the
    # BAM basename ("{library}.sorted" + "-grey.bed" / "-greystats.csv" /
    # "-greydepth.tsv"), so they must stay in sync with the input name above —
    # confirm against the tool's documentation.
    output: GREYLIST + "/chipseq-greylist/{library}.sorted-grey.bed", GREYLIST + "/chipseq-greylist/{library}.sorted-greystats.csv", GREYLIST + "/chipseq-greylist/{library}.sorted-greydepth.tsv"
    params: outdir = GREYLIST + "/chipseq-greylist"
    log: GREYLIST + "/chipseq-greylist/{library}_chipseq-greylist.log"
    conda: "envs/greylist.yaml"
    shell: "chipseq-greylist {input} --outdir {params.outdir} --bootstraps 500 &> {log}"


rule greylist_per_tissue_per_sp:
    """
    Build the greylist of one species and tissue.

    Takes the union of the regions that chipseq-greylist flagged in at least
    one input library of that species/tissue. The regions are sorted and
    merged into non-overlapping intervals with bedtools merge.
    """
    input: bed = lambda wildcards: get_all_inputs_for_sp(wildcards.species, wildcards.tissue)
    output: GREYLIST + '/results/grey_{species}_{tissue}.bed'
    conda: "envs/bedtools_py.yaml"
    # Pass the concatenated stream explicitly with `-i stdin` (the documented
    # input form for bedtools sort/merge) instead of relying on implicit stdin.
    shell: "cat {input.bed} | bedtools sort -i stdin | bedtools merge -i stdin > {output}"


rule greylist_stats:
    """
    Count the greylisted regions of one input library.

    Writes two lines to a temporary file:
      1. the number of regions of {library} that overlap no region greylisted
         in any other input library of the same species and tissue
         ("unique" regions);
      2. the total number of regions greylisted in {library}.
    If {library} is the only input library of its species/tissue, every
    region is unique, so line 1 equals line 2.
    """
    input:
        bed = lambda wildcards: get_all_inputs_for_sp_except(wildcards.library),
        ref = GREYLIST + "/chipseq-greylist/{library}.sorted-grey.bed"
    # NOTE(review): 'stats_{library}.txt' can also match the summary file
    # 'stats_{species}_{tissue}.txt' of greylist_stats_all; a wildcard
    # constraint or ruleorder may be needed to avoid ambiguity — confirm.
    output: temp(GREYLIST + '/stats/stats_{library}.txt')
    conda: "envs/bedtools_py.yaml"
    # `wc -l < file` prints only the line count. Parsing plain `wc` output
    # with tr/cut depended on its platform- and width-specific padding.
    # With no other libraries, `-b` would receive no files and bedtools
    # would fail, so the total count is reused as the unique count.
    shell:
        """
        if [ -n "{input.bed}" ]; then
            bedtools intersect -v -a {input.ref} -b {input.bed} | wc -l > {output}
        else
            wc -l < {input.ref} > {output}
        fi
        wc -l < {input.ref} >> {output}
        """


rule greylist_stats_all:
    """
    Summarize greylisted regions per species and tissue.

    Writes one tab-separated row per input library with:
      - the number of unique regions (greylisted only in this library);
      - the total number of greylisted regions;
      - the fraction of regions that are unique ("NA" when the total is 0).
    The library is identified by the path of its per-library stats file.
    """
    input: stats = lambda wildcards: get_all_stats(wildcards.species, wildcards.tissue)
    output: out = GREYLIST + '/stats/stats_{species}_{tissue}.txt'
    run:
        with open(output.out, 'w') as out:
            # Header written once (it used to be repeated for every library)
            # and now names all four columns of the data rows.
            out.write("input_library\tunique_greylisted_regions\tall_greylisted_regions\tfraction_unique\n")
            for input_file in input.stats:
                with open(input_file, 'r') as infile:
                    counts = [line.strip() for line in infile if line.strip()]
                # Each per-library stats file holds exactly two counts:
                # unique regions first, then total regions.
                if len(counts) != 2:
                    raise ValueError(
                        f"{input_file}: expected 2 count lines (unique, total), found {len(counts)}"
                    )
                uniq, tot = counts
                # A library with no greylisted regions has no meaningful fraction.
                fraction = float(uniq) / float(tot) if float(tot) else "NA"
                out.write(f"{input_file}\t{uniq}\t{tot}\t{fraction}\n")
