#!/usr/bin/env python

"""
TODO script to polish

"""
import sys
import argparse


def load_regions(input_file):
    """Load the set of regions listed in a tab-separated file.

    Every non-blank line is stripped, split on tabs and reduced to its first
    three columns, which are re-joined with tabs to form the region key.
    Additional columns are ignored and duplicated regions collapse into one.

    Args:
        input_file: path of the tab-separated region file to read.

    Returns:
        set of str: one tab-joined key (first three columns) per region.
    """
    regions = set()
    with open(input_file, 'r') as infile:
        for raw in infile:
            stripped = raw.strip()
            # A blank line would otherwise be registered as a bogus '' region
            # and end up as an empty line in the output.
            if not stripped:
                continue
            regions.add('\t'.join(stripped.split('\t')[:3]))
    return regions

def _iter_region_keys(chrom, field):
    """Yield a tab-joined (chrom, start, end) key per 'start-end' item of *field*.

    Items are separated by '_'. Parsing is lazy: an item is only validated
    when the caller actually reaches it.
    """
    for item in field.split('_'):
        start, end = item.split('-')
        yield '\t'.join((chrom, start, end))


def _add_first_active(target, chrom, label, field, actives):
    """Add to *target* the first region of *field* accepted by the activity filter."""
    for key in _iter_region_keys(chrom, field):
        # Without an activity filter the very first region is accepted.
        if not actives or key in actives[label]:
            target.add(key)
            break


# Load the regions reassigned between the prom, enh and primed sets (regions
# with overlap) as two sets: regions to add and regions to remove.
def load_reassigned(input_file, reg_element, actives):
    """Read the vote results and collect regions to add to / remove from a set.

    Each non-blank line is tab-separated and read as:

    * column 0: first field of every region key (named ``ch`` originally,
      presumably the chromosome);
    * columns 1-2: not used here;
    * column 3: label of the winning element;
    * column 4: regions of the winner, as ``start-end`` items joined by ``_``;
    * columns 5-6: label and regions (same format) of the first loser;
    * columns 7-8: optional label and regions of a second loser, only
      considered when the line has more than 8 columns.

    If the winner is ``reg_element``, the first region of each loser that
    passes the activity filter is put in the "add" set. Otherwise, for every
    winner region that passes the filter, all regions of the loser whose
    label equals ``reg_element`` (first loser checked before the second) are
    put in the "remove" set.

    Args:
        input_file: path of the tab-separated vote results file.
        reg_element: label of the element whose region set is being updated.
        actives: mapping label -> set of tab-joined region keys used as the
            activity filter; any falsy value (``None``, ``{}``) disables the
            filter so every region is accepted.

    Returns:
        tuple: ``(to_add, to_remove)``, two sets of tab-joined region keys.

    Raises:
        ValueError: a data line has fewer than 5 columns, or a region item
            that is reached is not exactly ``start-end``.
        IndexError: a loser column that is needed is missing from the line.
        KeyError: a label that is looked up is absent from ``actives``.
    """
    to_add = set()
    to_remove = set()
    with open(input_file, 'r') as infile:
        for raw in infile:
            stripped = raw.strip()
            # Blank lines carry no vote; skip them instead of failing on unpack.
            if not stripped:
                continue
            fields = stripped.split('\t')
            chrom, _, _, winner, winner_field = fields[:5]

            for reg in _iter_region_keys(chrom, winner_field):
                if winner == reg_element:
                    # This branch does not depend on `reg`; re-running it for
                    # each winner region is harmless because set.add is
                    # idempotent, and it keeps every winner region validated.
                    _add_first_active(to_add, chrom, fields[5], fields[6], actives)
                    if len(fields) > 8:
                        _add_first_active(to_add, chrom, fields[7], fields[8], actives)

                elif not actives or reg in actives[winner]:
                    # The winner region passes the filter: every region of
                    # reg_element (when it is one of the losers) is dropped.
                    if reg_element == fields[5]:
                        to_remove.update(_iter_region_keys(chrom, fields[6]))
                    elif len(fields) > 8 and reg_element == fields[7]:
                        to_remove.update(_iter_region_keys(chrom, fields[8]))

    return to_add, to_remove

def update_regions(regions, to_add, to_remove, out_target):
    """Write the updated region set, one region per line.

    Regions of ``regions`` that appear in neither ``to_add`` nor
    ``to_remove`` are written first, followed by every region of ``to_add``
    (so a region present in both ``to_add`` and ``to_remove`` is written).
    Lines follow set iteration order, which is not sorted.

    Args:
        regions: iterable of region keys currently in the set.
        to_add: set of region keys to write in addition.
        to_remove: set of region keys to leave out.
        out_target: path of the output file; any falsy value writes to
            ``sys.stdout`` instead.
    """
    # Built before opening so a bad argument fails before the target file is
    # truncated.
    reassigned = to_remove.union(to_add)

    handle = open(out_target, 'w') if out_target else sys.stdout
    try:
        handle.writelines(reg + '\n' for reg in regions if reg not in reassigned)
        handle.writelines(reg + '\n' for reg in to_add)
    finally:
        # Close the file even if a write fails, but never close sys.stdout.
        if handle is not sys.stdout:
            handle.close()


if __name__ == '__main__':

    # The module docstring doubles as the --help description.
    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    PARSER.add_argument('-i', '--infiles', nargs='+', required=True)
    PARSER.add_argument('-v', '--vote_results', required=True)
    PARSER.add_argument('-o', '--output', required=False, default=None)
    PARSER.add_argument('-e', '--reg_element', required=True)
    PARSER.add_argument('-l', '--labels', nargs='+', required=False)
    PARSER.add_argument('--regions_only', action='store_true', required=False, help='only report regions set, not activity in species')
    ARGS = vars(PARSER.parse_args())

    # Argument consistency is reported through PARSER.error (usage message and
    # exit status 2) rather than assert, which is stripped when running with -O.
    if ARGS["regions_only"]:
        if len(ARGS["infiles"]) != 1:
            PARSER.error('--regions_only expects exactly one file in --infiles')
        ORI = load_regions(ARGS["infiles"][0])
        # No activity filter in this mode: every reassigned region is accepted.
        TO_ADD, TO_REMOVE = load_reassigned(ARGS["vote_results"], ARGS["reg_element"], None)
        update_regions(ORI, TO_ADD, TO_REMOVE, ARGS["output"])

    else:
        # --labels is optional for argparse, so it is None when omitted.
        if not ARGS["labels"] or len(ARGS["infiles"]) != len(ARGS["labels"]):
            PARSER.error('--labels must give exactly one label per file in --infiles')
        # The region set to update is looked up by this label at the end, so
        # fail before doing any work if it is not among the labels.
        if ARGS["reg_element"] not in ARGS["labels"]:
            PARSER.error('--reg_element must be one of the --labels')

        INPUT_REG = {}
        for INPUT, LABEL in zip(ARGS["infiles"], ARGS["labels"]):
            INPUT_REG[LABEL] = load_regions(INPUT)
        TO_ADD, TO_REMOVE = load_reassigned(ARGS["vote_results"], ARGS["reg_element"], INPUT_REG)

        update_regions(INPUT_REG[ARGS["reg_element"]], TO_ADD, TO_REMOVE, ARGS["output"])
