#!/usr/bin/env python3

"""
Filters liftover multiple match .bed output to keep single-match above a given coverage threshold.

Example
"""

import argparse

def load_bed4(input_file, multiple=False, lg_only=False):
    """Load a 4-column BED file, or liftOver multiple-match output, into a dict.

    Each line is split on tabs. If that yields a single field, the line is
    re-split on any whitespace. Blank lines and BED header lines (starting
    with ``#``, ``track`` or ``browser``) are skipped.

    Args:
        input_file: path to the BED file.
        multiple: if True, expect exactly 5 columns per line (chrom, start,
            end, name, plus one extra column that is ignored) and collect
            every line sharing a name. If False, expect exactly 4 columns
            (chrom, start, end, name).
        lg_only: only honoured when ``multiple`` is False. If True, map each
            name to its interval length (``end - start``) instead of a list.

    Returns:
        dict keyed by the name column. Values are an int length when
        ``lg_only`` is set and ``multiple`` is not. Otherwise they are lists
        of ``(length, chrom, start, end)`` tuples, where chrom/start/end keep
        their original string form. In non-multiple mode a repeated name keeps
        only its last line.

    Raises:
        ValueError: if a data line has the wrong number of columns.
    """
    expected_cols = 5 if multiple else 4
    res = {}
    with open(input_file, 'r') as f:
        for line_no, raw_line in enumerate(f, start=1):
            stripped = raw_line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            fields = stripped.split('\t')
            if len(fields) == 1:
                # Not tab-delimited: fall back to splitting on any whitespace.
                fields = stripped.split()
            if fields[0].startswith('#') or fields[0] in ('track', 'browser'):
                # Standard BED header/comment lines are not data records.
                continue
            if len(fields) != expected_cols:
                # Raise instead of assert so validation survives `python -O`.
                raise ValueError(
                    "{}, line {}: expected {} columns for {} format, found {}".format(
                        input_file, line_no, expected_cols,
                        'liftover multiple match' if multiple else 'bed4',
                        len(fields)))

            c, start, end, unique_id = fields[:4]
            lg = int(end) - int(start)
            if multiple:
                res.setdefault(unique_id, []).append((lg, c, start, end))
            elif lg_only:
                res[unique_id] = lg
            else:
                res[unique_id] = [(lg, c, start, end)]

    return res

def filter_matches(query, mapped, threshold=0.5):
    """Keep, for each query region, its longest mapped hit when it is long enough.

    Args:
        query: dict mapping a region name to its original length.
        mapped: dict mapping a region name to a list of
            ``(length, chrom, start, end)`` hits.
        threshold: minimum fraction of the original length that the longest
            hit must reach (inclusive, i.e. ``>=``).

    Returns:
        dict mapping each retained name to a ``(chrom, start, end, name)``
        tuple, in the iteration order of ``query``. Names missing from
        ``mapped`` are skipped. When several hits tie for longest, the
        earliest one in the list wins.
    """
    kept = {}
    for unique_id, original_lg in query.items():
        if unique_id not in mapped:
            continue
        # max() returns the first maximal element, so ties go to the earliest hit.
        best_lg, chrom, start, end = max(mapped[unique_id], key=lambda hit: hit[0])
        if best_lg >= original_lg * threshold:
            kept[unique_id] = (chrom, start, end, unique_id)
    return kept

def write_bed(bed_dict, output_file):
    """Write each value of ``bed_dict`` as one tab-separated line.

    Args:
        bed_dict: dict whose values are sequences of strings (one BED
            record each). Keys are not written.
        output_file: path of the file to create or overwrite.
    """
    with open(output_file, 'w') as handle:
        # One record per line, in the dict's iteration order.
        handle.writelines('\t'.join(fields) + '\n' for fields in bed_dict.values())

if __name__ == '__main__':
    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    PARSER.add_argument('-q', '--query', type=str, required=True,
                        help='BED4 file of the original (pre-liftover) regions; '
                             'its name column is matched against the mapped file.')

    PARSER.add_argument('-m', '--mapped', type=str, required=True,
                        help='liftOver output: 5-column multiple-match BED, '
                             'or plain BED4 when --no_multi is given.')

    PARSER.add_argument('-o', "--output", type=str, required=True,
                        help='output BED4 file of the retained best matches.')

    PARSER.add_argument('-c', "--cutoff", type=float, default=0.50,
                        help='minimum ratio of best-match length to original '
                             'region length needed to keep a region '
                             '(default: %(default)s).')

    # The flag clears `multiple`, so the mapped file is parsed as plain BED4.
    PARSER.add_argument("--no_multi", dest='multiple', action='store_false',
                        help='treat the mapped file as plain BED4 instead of '
                             'liftOver multiple-match output.')

    ARGS = vars(PARSER.parse_args())

    # Query regions only need their lengths; mapped regions need every hit.
    query = load_bed4(ARGS["query"], lg_only=True)
    mapped = load_bed4(ARGS["mapped"], multiple=ARGS["multiple"])

    single_match = filter_matches(query, mapped, ARGS["cutoff"])

    write_bed(single_match, ARGS["output"])
