import argparse
from get_strictly_mappable import write_bed4
from check_rec_mapping import load_summary, get_overlap

def load_regions_in_summary_file(input_file):
    """
    Load a tab-separated summary file into per-species region dictionaries.

    The first line is a header: its first column names the reference species
    (with "_chr" removed), and every second column starting from the sixth
    names another species (text before the last underscore).
    On every following line, column 4 is the region id, columns 1-3 are the
    reference region, and the species columns hold a whitespace-separated
    region description.

    Any line containing "UnMapped" is skipped entirely.

    Returns a dict {species: {region_id: list of fields}}.
    """
    regions_by_sp = {}
    with open(input_file, 'r') as handle:
        for line_no, raw_line in enumerate(handle):

            # A single unmapped species discards the whole line.
            if "UnMapped" in raw_line:
                continue

            fields = raw_line.strip().split('\t')

            if line_no == 0:
                ref_sp = fields[0].replace("_chr", "")
                # Species columns come in pairs; only the first of each pair names the species.
                species = ['_'.join(col.split('_')[:-1]) for col in fields[5::2]]
                continue

            unique_id = fields[3]
            regions_by_sp.setdefault(ref_sp, {})[unique_id] = fields[:3]

            # Walk the species columns (6th, 8th, ...) in step with the species names.
            for rank, sp in enumerate(species):
                regions_by_sp.setdefault(sp, {})[unique_id] = fields[5 + 2 * rank].split()
    return regions_by_sp

def is_similar_size(interval1, interval2, cutoff):
    """
    Tell whether two intervals have comparable lengths.

    Each interval is a (start, end) pair; its length is end - start.
    Returns True unless the longer length multiplied by `cutoff` is strictly
    greater than the shorter length.
    """
    len1 = interval1[1] - interval1[0]
    len2 = interval2[1] - interval2[0]
    longest, shortest = max(len1, len2), min(len1, len2)
    return not longest * cutoff > shortest

def check_regions(regions, sumreg_sp, setreg_ref, cutoff=0.95, extend=0, si=None):
    """
    Filter `regions` in place, dropping entries that fail the checks below.

    Regions should have the same size in all species (0.95 means +/-5% of reference region size)
    Regions in non-reference species should overlap with regions in the summary file.

    For each region id, in order:
      1. the chromosome must match the one recorded in `sumreg_sp`;
      2. `get_overlap` must report an overlap with the `sumreg_sp` interval
         (its result is only tested for truthiness here);
      3. the region is clipped to the `sumreg_sp` interval, and the clipped size
         must pass `is_similar_size` against the `setreg_ref` interval.

    regions     --> {region_id: (chrom, start, end)}, modified in place
    sumreg_sp   --> {region_id: (chrom, start, end)} for the same species, from the summary file
    setreg_ref  --> {region_id: (chrom, start, end)}; only start and end are used
    cutoff      --> size-similarity cutoff forwarded to `is_similar_size`
    extend (int) --> if regions is smaller than extend, regions are extended from both ends to match it
                --> 0 means no extension ofc
    si          --> {chrom: size}, used to cap the extended end; must contain every
                    chromosome that gets extended when `extend` is non-zero

    Returns None. Coordinates are only written back to `regions` (as strings) when
    `extend` is non-zero; otherwise surviving entries are left untouched.
    Ids missing from `sumreg_sp` or `setreg_ref` raise the underlying lookup error.
    """
    for unique_id in list(regions):
        c, start, end = regions[unique_id]
        prev_c, prev_st, prev_e = sumreg_sp[unique_id]

        if c != prev_c:
            regions.pop(unique_id)
            continue

        # Convert once instead of on every comparison.
        start, end = int(start), int(end)
        prev_st, prev_e = int(prev_st), int(prev_e)

        if not get_overlap([start, end], [prev_st, prev_e]):
            regions.pop(unique_id)
            continue

        # Clip to the summary-file interval of the same species.
        # NOTE(review): the previous code assigned `ori_st` / `ori_e` here, before they
        # were defined for the current region (UnboundLocalError on the first region,
        # stale values from the previous region afterwards). Those are reference-species
        # coordinates, so clipping to `prev_st` / `prev_e` is assumed to be the intent.
        start = max(start, prev_st)
        end = min(end, prev_e)

        _, ori_st, ori_e = setreg_ref[unique_id]

        if not is_similar_size([start, end], [int(ori_st), int(ori_e)], cutoff):
            regions.pop(unique_id)
            continue

        if extend:
            lg = end - start
            if lg < extend:
                # Pad both ends by half of the missing length, staying inside the chromosome.
                to_gain = extend - lg
                end = min(end + int(round(to_gain/2)), si[c])
                start = max(start - int(round(to_gain/2)), 0)
            regions[unique_id] = (c, str(start), str(end))


if __name__ == '__main__':

    # Command-line driver: filter the regions of each labelled species with
    # check_regions, keep the region ids that survive in every species, then write
    # the filtered reference set plus one filtered file per species.

    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    PARSER.add_argument('-ref', '--ref', type=str, required=True,
                        help="summary file read with load_regions_in_summary_file")

    PARSER.add_argument('-s', '--set_of_regions', type=str, required=True,
                        help="tab-separated file of reference regions; the last column is the region id")

    PARSER.add_argument('-i', '--infiles', nargs='+', required=True,
                        help="one input file per label, in the same order as --labels")

    PARSER.add_argument('-ind', '--index_files', nargs='+', required=False,
                        help="tab-separated 'chromosome, size' files: one per label, then one for the "
                             "reference species (only read when region extension is enabled)")

    PARSER.add_argument('-l', '--labels', type=str, required=True,
                        help="comma-separated species labels")

    PARSER.add_argument('-o', "--outfiles", nargs='+', required=True,
                        help="output for the reference regions, followed by one output per label")

    PARSER.add_argument('-c', "--cutoff", type=float, required=False, default=0.85,
                        help="size-similarity cutoff passed to check_regions (default: 0.85)")

    ARGS = vars(PARSER.parse_args())

    SP = ARGS["labels"].split(',')

    # Explicit checks rather than assert, which is stripped when Python runs with -O.
    if len(ARGS["infiles"]) != len(SP):
        PARSER.error("--infiles must list exactly one file per label")
    if len(ARGS["outfiles"]) != len(SP) + 1:
        PARSER.error("--outfiles must list one file for the reference plus one file per label")

    REFREG = load_regions_in_summary_file(ARGS["ref"])

    SETREG = load_summary(ARGS["set_of_regions"])

    # The reference species is the first one in the summary file that is not a label.
    REF_SP = next((sp for sp in REFREG if sp not in SP), None)
    if REF_SP is None:
        PARSER.error("no reference species found in --ref (every species is listed in --labels)")

    # Minimum region length; shorter regions are padded on both ends.
    # 0 disables extension (1000 was used previously).
    EXTEND = 0
    chrom_sizes = {}
    if EXTEND:
        INDEX_FILES = ARGS["index_files"] or []
        if len(INDEX_FILES) < len(SP) + 1:
            PARSER.error("--index_files must list one file per label plus one for the reference")
        # Index files follow the label order, with the reference species last.
        for sp, index_file in zip(SP + [REF_SP], INDEX_FILES):
            sizes = {}
            with open(index_file, 'r') as indfile:
                for line in indfile:
                    fields = line.strip().split('\t')
                    sizes[fields[0]] = int(fields[1])
            chrom_sizes[sp] = sizes

    MYREG = {}
    # None means "no species processed yet". An empty set is a valid result
    # (nothing survived) and must not be re-initialised by the next species.
    REG_TO_KEEP = None
    for sp, infile in zip(SP, ARGS["infiles"]):

        MYREG[sp] = load_summary(infile, select_biggest=True)
        # FI checked that 0.80 would keep 90% --> most are of similar size
        # (previous regions would keep 30% of prom and 50% of others)
        check_regions(MYREG[sp], REFREG[sp], SETREG, cutoff=ARGS["cutoff"],
                      extend=EXTEND, si=chrom_sizes.get(sp, {}))

        kept = set(MYREG[sp])
        REG_TO_KEEP = kept if REG_TO_KEEP is None else REG_TO_KEEP & kept

    if REG_TO_KEEP is None:
        REG_TO_KEEP = set()

    print('****', ARGS["set_of_regions"], len(REFREG[REF_SP]), len(REG_TO_KEEP))

    # First output: the reference regions restricted to the ids kept in all species.
    with open(ARGS["set_of_regions"], 'r') as infile, open(ARGS["outfiles"][0], 'w') as out:
        for line in infile:
            line_split = line.strip().split('\t')

            if line_split[-1] not in REG_TO_KEEP:
                continue

            if not EXTEND:
                out.write(line)
                continue

            c, start, end = line_split[:3]
            lg = int(end) - int(start)

            if lg < EXTEND:
                # Same padding rule as check_regions, capped by the reference chromosome size.
                to_gain = EXTEND - lg
                end = min(int(end) + int(round(to_gain/2)), chrom_sizes[REF_SP][c])
                start = max(0, int(start) - int(round(to_gain/2)))

            line_split[:3] = [c, str(start), str(end)]
            out.write('\t'.join(line_split)+'\n')

    # Remaining outputs: one per species, in label order.
    for sp, outfile in zip(SP, ARGS["outfiles"][1:]):
        write_bed4(MYREG[sp], outfile, subset=REG_TO_KEEP)