import sys

import argparse

from collections import OrderedDict

def load_chr(input_file):
    """Read a chromosome-sizes file into a dict.

    Each non-blank line must hold exactly two tab-separated fields: the
    chromosome name and its size. Blank lines (e.g. a trailing empty line)
    are ignored.

    Args:
        input_file: path to the 2-column, tab-separated chromosome-sizes file.

    Returns:
        dict mapping chromosome name -> size as ``int``.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields or
            its size is not an integer.
    """
    chr_sizes = {}
    with open(input_file, 'r') as infile:
        for raw_line in infile:
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines instead of failing on the unpack below.
                continue
            chrom, size = line.split("\t")
            chr_sizes[chrom] = int(size)
    return chr_sizes


def load_genes(input_file):
    """Read a 5-column TSS BED file into an ordered gene -> location mapping.

    Expected columns (tab-separated): chrom, tss, tss+1, gene, strand. The
    third column is ignored. Blank lines and BED header lines (starting with
    ``#``, ``track`` or ``browser``) are skipped.

    Args:
        input_file: path to the TSS BED file.

    Returns:
        OrderedDict mapping gene -> (chrom, tss as ``int``, strand as ``str``),
        in file order. If a gene name appears more than once, its entry keeps
        its first position but holds the values of the last occurrence.

    Raises:
        ValueError: if a data line does not have exactly five fields or the
            TSS is not an integer.
    """
    genes = OrderedDict()
    with open(input_file, 'r') as infile:
        for raw_line in infile:
            line = raw_line.strip()
            # Skip blank lines and BED header/annotation lines.
            if not line or line.startswith(("#", "track", "browser")):
                continue
            chrom, tss, _, gene, strand = line.split("\t")
            genes[gene] = chrom, int(tss), strand
    return genes


def basal_domains(genes, chr_sizes, *, upstream=5000, downstream=1000):
    """Compute the basal regulatory domain of every gene.

    The basal domain spans ``upstream`` bp upstream to ``downstream`` bp
    downstream of the TSS (defaults: 5 kb / 1 kb, the GREAT defaults),
    oriented by strand and clipped to ``[0, chromosome size]``. Genes whose
    basal domains are identical share one entry.

    Args:
        genes: mapping gene -> (chrom, tss, strand), with strand "1" or "-1".
        chr_sizes: mapping chrom -> chromosome size.
        upstream: distance covered upstream of the TSS.
        downstream: distance covered downstream of the TSS.

    Returns:
        dict chrom -> OrderedDict((start, end) -> [gene, ...]); within each
        chromosome, intervals follow the iteration order of ``genes``.

    Raises:
        ValueError: if a strand is not "1" or "-1".
        KeyError: if a gene's chromosome is missing from ``chr_sizes``.
    """
    domains = {}
    for gene, (chrom, tss, strand) in genes.items():
        if strand == "1":
            # + strand: upstream lies towards lower coordinates.
            domain_start = max(tss - upstream, 0)
            domain_end = min(tss + downstream, chr_sizes[chrom])
        elif strand == "-1":
            # - strand: upstream lies towards higher coordinates.
            domain_start = max(tss - downstream, 0)
            domain_end = min(tss + upstream, chr_sizes[chrom])
        else:
            # An explicit error (not ``assert``) so the check survives ``python -O``.
            raise ValueError(
                f"Unexpected strand {strand!r} for gene {gene!r}; expected '1' or '-1'"
            )
        chrom_domains = domains.setdefault(chrom, OrderedDict())
        chrom_domains.setdefault((domain_start, domain_end), []).append(gene)
    return domains


def basal_extension(domains, chr_sizes, *, max_extension=1000000):
    """Extend each basal domain towards its neighbouring basal domains.

    For each chromosome, the basal intervals are visited in the stored order
    of ``domains[chrom]``. The CLI requires a sorted TSS bed, so this order is
    assumed to be positional. For each interval:

    * Upstream side: extend down to the end of the previous interval, by at
      most ``max_extension``. If the previous interval already reaches or
      overlaps ``start``, the start is left unchanged. The first interval
      extends towards 0.
    * Downstream side: extend up to the start of the next interval, by at
      most ``max_extension``. If the next interval already overlaps, the end
      is left unchanged. The last interval extends towards the chromosome size.

    Args:
        domains: dict chrom -> mapping((start, end) -> [gene, ...]), as
            returned by ``basal_domains``.
        chr_sizes: mapping chrom -> chromosome size.
        max_extension: maximum extension on each side (default 1 Mb).

    Returns:
        dict chrom -> dict((start, end) -> [gene, ...]). Distinct basal
        domains that end up with the same extended interval have their genes
        merged into a single list.

    Raises:
        KeyError: if the chromosome is missing from ``chr_sizes`` (looked up
            only for the last interval of each chromosome).
    """
    extended_domains = {}
    for chrom, chrom_domains in domains.items():
        extended = extended_domains[chrom] = {}
        intervals = list(chrom_domains)
        n_intervals = len(intervals)
        for i, interval in enumerate(intervals):
            start, end = interval

            if i > 0:  # every interval except the first has a predecessor
                prev_end = intervals[i - 1][1]
                if prev_end < start:
                    start = max(prev_end, start - max_extension)
            else:
                start = max(start - max_extension, 0)

            if i + 1 < n_intervals:
                next_start = intervals[i + 1][0]
                if next_start > end:
                    end = min(end + max_extension, next_start)
            else:
                end = min(end + max_extension, chr_sizes[chrom])

            # Two basal domains can extend to the same interval (e.g. both
            # clipped at a chromosome end). Accumulate genes, don't overwrite.
            extended.setdefault((start, end), []).extend(chrom_domains[interval])
    return extended_domains

            

def write_out(domains, output_file):
    """Write regulatory domains as a tab-separated, 4-column BED file.

    One line is written per (domain, gene) pair, with the fields chrom,
    domain start, domain end and gene. Chromosomes and domains are written
    in the iteration order of ``domains``. ``output_file`` is created or
    truncated.
    """
    with open(output_file, 'w') as handle:
        for chrom, chrom_domains in domains.items():
            for dom, gene_list in chrom_domains.items():
                first, second = str(dom[0]), str(dom[1])
                handle.writelines(
                    '\t'.join([chrom, first, second, gene]) + '\n'
                    for gene in gene_list
                )

if __name__ == '__main__':
    # Command-line entry point: parse the options, build the basal plus
    # extension regulatory domains and write them to the output file.
    PARSER = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    PARSER.add_argument(
        '--genes',
        type=str,
        required=False,
        default="tss_longest_coding_transcript_hgla_sorted.bed",
        help="5-columns .bed file with genes TSS (chr, tss, tss+1, gene, strand), WARNING:: the code assumes the bed is sorted !!!",
    )
    PARSER.add_argument(
        '--chr',
        type=str,
        required=False,
        default="../../eve_analysis/repeats/data/Hgla_chrSizes.txt",
        help="file with 2 col: chr_name and chr_size",
    )
    PARSER.add_argument('--output', type=str, required=True)
    ARGS = vars(PARSER.parse_args())

    for NOTICE in (
        "Assigning regulatory domains to genes, using GREAT default association: basal 5kb upstream, 1kb downstream + extension.\n",
        "WARNING: the script assumes the input bed with tss coord is sorted, please run bedtools sort first. (Ignore the warning if you did).\n",
    ):
        sys.stderr.write(NOTICE)

    OUTPUT_PATH = ARGS["output"]
    CHR_SIZE = load_chr(ARGS["chr"])
    GENES = load_genes(ARGS["genes"])
    REG_DOMAINS = basal_extension(basal_domains(GENES, CHR_SIZE), CHR_SIZE)
    write_out(REG_DOMAINS, OUTPUT_PATH)

    sys.stderr.write(f"DONE! output saved in {OUTPUT_PATH}.\n")


