import argparse
from filter_by_coverage import load_bed4

def get_overlap(a, b):
    """Return the length of the intersection of two intervals.

    ``a`` and ``b`` are two-item ``(start, end)`` sequences. The result is
    ``min(ends) - max(starts)`` when that is positive, otherwise 0, so
    disjoint intervals and intervals that merely touch both give 0.
    """
    # Same tie-breaking as builtin max()/min(): keep the first argument on ties.
    lo = b[0] if b[0] > a[0] else a[0]
    hi = b[1] if b[1] < a[1] else a[1]
    width = hi - lo
    return width if width > 0 else 0

def load_summary(input_file, select_biggest=False):
    """Load regions keyed by their unique identifier from a tab-separated file.

    Each non-blank line must have at least four tab-separated columns:
    chromosome, start, end and unique identifier. Extra columns are ignored.
    Blank lines (e.g. a trailing empty line) are skipped.

    Args:
        input_file: path to the tab-separated region file.
        select_biggest: when False, the last line seen for an identifier
            wins. When True, the longest region (``end - start``) is kept;
            on ties the first one seen is retained.

    Returns:
        dict mapping unique_id -> ``(chrom, start, end)``, with the
        coordinates kept as the original strings.

    Raises:
        ValueError: if a non-blank line has fewer than four columns, or (with
            ``select_biggest``) if a compared coordinate is not an integer.
    """
    res = {}
    with open(input_file, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Previously a blank line crashed the 4-way unpack below.
                continue
            fields = line.split('\t')
            if len(fields) < 4:
                raise ValueError(
                    "{}:{}: expected at least 4 tab-separated columns, "
                    "got {}".format(input_file, line_number, len(fields)))
            c, start, end, unique_id = fields[:4]
            if select_biggest and unique_id in res:
                _, prev_start, prev_end = res[unique_id]
                # Only a strictly longer region replaces the stored one.
                if int(end) - int(start) <= int(prev_end) - int(prev_start):
                    continue
            res[unique_id] = (c, start, end)
    return res


def filter_reciprocal(lifted, ori):
    """Keep lifted regions that overlap their original region, clipped to it.

    For every identifier in ``lifted``, the original region ``ori[unique_id]``
    (a ``(chrom, start, end)`` triple) is looked up; a missing identifier
    raises KeyError. Each lifted region is a 4-item sequence whose last three
    items are ``chrom, start, end`` (the first item is ignored). A lifted
    region is kept only if it is on the same chromosome and has a non-zero
    overlap with the original region. Its bounds are then clipped so that it
    lies inside the original region, reusing the original's coordinate
    strings where clipping happens.

    Returns:
        dict mapping unique_id -> list of ``(chrom, start, end, unique_id)``
        tuples; identifiers with no surviving region are absent.
    """
    kept = {}
    for uid, regions in lifted.items():
        ref_chrom, ref_start, ref_end = ori[uid]
        for _, chrom, start, end in regions:
            if chrom != ref_chrom:
                continue
            lo, hi = int(start), int(end)
            ref_lo, ref_hi = int(ref_start), int(ref_end)
            if not get_overlap([lo, hi], [ref_lo, ref_hi]):
                continue
            # Bound the region by the original interval.
            new_start = ref_start if lo < ref_lo else start
            new_end = ref_end if hi > ref_hi else end
            kept.setdefault(uid, []).append((chrom, new_start, new_end, uid))
    return kept

def write_bed(bed_dict, output_file):
    """Write every region of ``bed_dict`` as one tab-separated line.

    ``bed_dict`` maps an identifier to an iterable of regions; each region is
    a sequence of strings that is joined with tabs. Identifiers are visited
    in dict order and regions in their stored order. The output file is
    overwritten.
    """
    lines = (
        '\t'.join(region) + '\n'
        for regions in bed_dict.values()
        for region in regions
    )
    with open(output_file, 'w') as out:
        out.writelines(lines)

if __name__ == '__main__':

    # The file has no module docstring, so ``__doc__`` is None; give the CLI
    # an explicit description instead.
    PARSER = argparse.ArgumentParser(
        description=(
            "Keep lifted regions that lie on the same chromosome as, and\n"
            "overlap, the original region with the same identifier. Kept\n"
            "regions are clipped to the original region's bounds and written\n"
            "as tab-separated lines (chrom, start, end, unique_id)."),
        formatter_class=argparse.RawDescriptionHelpFormatter)

    PARSER.add_argument('-i', '--input', type=str, required=True,
                        help="tab-separated file of original regions "
                             "(chrom, start, end, unique_id; extra columns "
                             "are ignored)")

    PARSER.add_argument('-l', '--lifted', type=str, required=True,
                        help="file of lifted regions, read with load_bed4")

    PARSER.add_argument('-o', "--output", type=str, required=True,
                        help="output file for the filtered, clipped regions")

    # store_false: ARGS["no_multi"] is True by default and becomes False when
    # the flag is given; the value is forwarded to load_bed4 as ``multiple``.
    PARSER.add_argument("--no_multi", action='store_false',
                        help="call load_bed4 with multiple=False")

    ARGS = vars(PARSER.parse_args())

    ORI = load_summary(ARGS["input"])
    LIFTED = load_bed4(ARGS["lifted"], multiple=ARGS["no_multi"])

    MYREG = filter_reciprocal(LIFTED, ORI)

    write_bed(MYREG, ARGS["output"])