#!/usr/bin/env python

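"""
Assign each merged regulatory region to a single element type by majority vote.

For every merged region (--overlaps), the script counts, for each element label
(e.g. Promoters, Enhancers, Primed_enhancers), how many species have an active
element of that label inside the region. It picks the label supported by the
most species and resolves ties with the fixed rules in TIE_BREAKS.
One tab-separated line per region is written to --output.
"""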

import argparse
from itertools import permutations


# Tie-break rules for the majority vote. When several labels share the top
# score, the comma-joined tied labels (matched in any order) map to the label
# that wins. A tied label set missing from this table has no rule.
TIE_BREAKS = {
    "Promoters,Enhancers": "Promoters",
    "Promoters,Primed_enhancers": "Promoters",
    "Promoters,Enhancers,Primed_enhancers": "Promoters",
    "Enhancers,Primed_enhancers": "Enhancers",
}


def load_active(input_file):
    """Return the set of regions listed in a tab-separated (BED-style) file.

    Each region is the first three tab-separated columns of a line
    (chrom, start, end) re-joined with tabs. Any further columns are ignored.
    """
    with open(input_file, 'r') as handle:
        return {'\t'.join(row.strip().split('\t')[:3]) for row in handle}


# Map each species' active promoter/enhancer/primed-enhancer regions onto the overlap regions they fall in.
def map_overlaps_active(actives, label, over_reg):
    """Collect, for each overlap region, the active elements that fall in it.

    Args:
        actives: dict species -> set of active region strings
            (tab-joined chrom/start/end), as built from ``load_active``.
        label: element label of these regions. It is currently unused and is
            kept for interface compatibility.
        over_reg: dict region -> list of overlap regions it intersects,
            as returned by ``load_over_reg``.

    Returns:
        ``(res, multiple_overlaps)``:

        - ``res`` maps an overlap region to ``[set_of_species, reg, reg, ...]``.
          One region entry is appended per (species, region) hit, so a region
          active in several species appears repeatedly.
        - ``multiple_overlaps`` maps every overlap region touched by an active
          element that intersects more than one overlap region to that
          element's overlap count. If several such elements touch the same
          region, the last one written wins.
    """
    res = {}
    multiple_overlaps = {}
    for species in actives:
        # Sort so the per-region element lists (and hence the output columns)
        # are reproducible. Set iteration order of strings varies between runs.
        for reg in sorted(actives[species]):
            hits = over_reg.get(reg)
            if not hits:
                continue
            for coord_in_over in hits:
                if coord_in_over not in res:
                    res[coord_in_over] = [{species}, reg]
                else:
                    res[coord_in_over][0].add(species)
                    res[coord_in_over].append(reg)
                # Record every overlap region this element straddles, not only
                # the last one visited.
                if len(hits) > 1:
                    multiple_overlaps[coord_in_over] = len(hits)

    return res, multiple_overlaps


def load_over_reg(over_file):
    """Map each region to the overlap regions it intersects.

    Every line must have exactly eight tab-separated columns. Columns 1-3 give
    the region. Columns 6-7 give the start/end of the overlapping region, which
    is keyed using column 1's chromosome. A start of -1 marks "no overlap";
    such lines are skipped, so regions without any overlap are absent from the
    result.

    Returns:
        dict: region -> list of overlap regions, in file order. Both are
        tab-joined chrom/start/end strings.
    """
    mapping = {}
    with open(over_file, 'r') as handle:
        for raw in handle:
            chrom, start, stop, _, _, o_start, o_stop, _ = raw.strip().split('\t')
            if int(o_start) == -1:
                continue
            region = '\t'.join((chrom, start, stop))
            mapping.setdefault(region, []).append('\t'.join((chrom, o_start, o_stop)))
    return mapping


def load_overlaps(input_file):
    """Load merged regions and the labels of the elements merged into each.

    Expects four tab-separated columns per line: chrom, start, end and a
    comma-separated label list (the output of ``merge -o collapse``).
    Duplicate labels are removed and the rest are sorted, so a region's
    label string is identical on every run. Blank lines are skipped.

    Returns:
        dict: tab-joined chrom/start/end -> comma-joined unique labels.
    """
    over = {}
    with open(input_file, 'r') as infile:
        for line in infile:
            if not line.strip():
                continue
            ch, st, end, labels = line.strip().split('\t')
            # collapse repeats a label once per merged element. Dedupe, and
            # sort so the order does not depend on string hash randomisation.
            over['\t'.join((ch, st, end))] = ','.join(sorted(set(labels.split(','))))
    return over


def decision(score_dict, decision_type, deuce, tie_breaks=None):
    """Return the winning label of a vote and the rule that decided it.

    Args:
        score_dict: label -> score (e.g. number of supporting species).
        decision_type: rule name returned when one label has the top score.
        deuce: rule name returned when a tie is settled via ``tie_breaks``.
        tie_breaks: mapping of comma-joined tied labels -> winning label.
            Keys match regardless of the order the labels are written in.
            Defaults to the module-level ``TIE_BREAKS``.

    Returns:
        ``(winner, rule)`` tuple.

    Raises:
        ValueError: if ``score_dict`` is empty, or if the top score is shared
            by a set of labels that no tie-break rule covers.
    """
    if not score_dict:
        raise ValueError("decision() needs at least one scored label")
    if tie_breaks is None:
        tie_breaks = TIE_BREAKS

    max_score = max(score_dict.values())
    leaders = [label for label, score in score_dict.items() if score == max_score]
    if len(leaders) == 1:
        return leaders[0], decision_type

    # Compare label sets rather than strings, so a rule written "A,B" also
    # settles a B/A tie. This is equivalent to trying every permutation.
    tied = frozenset(leaders)
    for combo, winner in tie_breaks.items():
        if frozenset(combo.split(',')) == tied:
            return winner, deuce

    raise ValueError("no tie-break rule for tied labels: " + ','.join(sorted(tied)))

def _format_regs(regs):
    """Render tab-joined chrom/start/end regions as 'start-end' pieces joined by '_'."""
    return '_'.join('-'.join(reg.split('\t')[1:3]) for reg in regs)


def majority_vote(delements, doverlap, dmultiple, output_file):
    """Assign each overlap region to the label supported by the most species.

    For every region in ``doverlap``, the species supporting each listed label
    are counted. ``decision`` then picks the winner: a majority vote, with its
    tie-break rules resolving ties. One tab-separated line is written per
    region::

        coord  winner  winner_regs  [other_label  other_regs]...  rule

    ``coord`` is the tab-joined chrom/start/end. Each ``*_regs`` field lists
    that label's contributing elements as ``start-end`` joined by ``_``.

    Args:
        delements: label -> overlap region -> ``[species_set, region, ...]``
            (the first value returned by ``map_overlaps_active``).
        doverlap: overlap region -> comma-joined labels (``load_overlaps``).
        dmultiple: label -> multi-overlap counts. It is not used in the output
            and is kept for interface compatibility and future statistics.
        output_file: path of the tab-separated file to write.

    Raises:
        KeyError: if a label listed for a region has no entry for that region
            in ``delements`` (presumably when none of that label's elements
            there were active — confirm against the inputs).
    """
    with open(output_file, 'w') as out:
        for coord, label_str in doverlap.items():
            labels = label_str.split(',')
            scores = {elem: len(delements[elem][coord][0]) for elem in labels}
            winner, rule = decision(scores, "vote (majority)", "vote_tie (higher_level)")

            fields = [coord, winner, _format_regs(delements[winner][coord][1:])]
            for elem in labels:
                if elem != winner:
                    fields.extend((elem, _format_regs(delements[elem][coord][1:])))
            fields.append(rule)
            out.write('\t'.join(fields) + '\n')


if __name__ == '__main__':

    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    PARSER.add_argument('-i', '--infiles', nargs='+', required=True,
                        help="active-region files, label-major: one per species "
                             "for the first label, then one per species for the next, ...")
    PARSER.add_argument('-oi', '--overlap_infiles', nargs='+', required=True,
                        help="per-label 8-column overlap files, in --labels order")
    PARSER.add_argument('-over', '--overlaps', required=True,
                        help="merged regions with a collapsed label column "
                             "(4 columns: chrom, start, end, labels)")
    PARSER.add_argument('-o', '--output', required=True,
                        help="tab-separated output file")
    PARSER.add_argument('-names', '--labels', nargs='+', required=True,
                        help="element labels, e.g. Promoters Enhancers Primed_enhancers")
    PARSER.add_argument('-sp', '--species', nargs='+', required=True,
                        help="species names, in the order used for --infiles")

    ARGS = vars(PARSER.parse_args())

    # Validate file counts up front. Otherwise a short list raises IndexError
    # mid-run, and extra files are silently ignored.
    N_ACTIVE = len(ARGS["labels"]) * len(ARGS["species"])
    if len(ARGS["infiles"]) != N_ACTIVE:
        PARSER.error(f"--infiles expects {N_ACTIVE} files (labels x species), "
                     f"got {len(ARGS['infiles'])}")
    if len(ARGS["overlap_infiles"]) != len(ARGS["labels"]):
        PARSER.error(f"--overlap_infiles expects one file per label "
                     f"({len(ARGS['labels'])}), got {len(ARGS['overlap_infiles'])}")

    OVER = load_overlaps(ARGS["overlaps"])

    # --infiles are consumed label-major: for each label, one file per species
    # in --species order. The input must be supplied in exactly that order.
    INFILES = iter(ARGS["infiles"])
    ACTIVES = {}
    for LAB in ARGS["labels"]:
        ACTIVES[LAB] = ACTIVES.get(LAB, {})
        for SP in ARGS["species"]:
            ACTIVES[LAB][SP] = load_active(next(INFILES))

    ELEM, MULTI = {}, {}
    for INFILE, LABEL in zip(ARGS["overlap_infiles"], ARGS["labels"]):
        TMP = load_over_reg(INFILE)
        ELEM[LABEL], MULTI[LABEL] = map_overlaps_active(ACTIVES[LABEL], LABEL, TMP)

    majority_vote(ELEM, OVER, MULTI, ARGS["output"])
