#!/usr/bin/env python

"""
    Script producing customized barplots for the crossregmap workflow.

    Example:
        $ python -m scripts.plot_overlaps_solving -i Heart_vote_results_overlaps.bed [-o out.svg]
"""

import argparse

import seaborn as sns

import matplotlib as mpl
# Select the 'Agg' backend before pyplot is imported, so figures are rendered
# straight to files (see plt.savefig in mybarplot) rather than to a display.
mpl.use('Agg')
import matplotlib.pyplot as plt

import pandas as pd

# Use seaborn's "muted" color palette as the default for the plots drawn below.
sns.set_palette("muted")

def load_input(input_file):

    """
    Loads overlap vote summary file into a pandas dataframe.

    Every non-blank line is split on tabs. Column 4 (index 3) is read as the winning
    set, column 6 (index 5) as the second set involved, column 8 (index 7) as a third
    set when the line has more than 8 fields, and the last column as the decision rule.
    Set names are sorted alphabetically, except "Promoters" which is always listed first.

    Args:
        input_file (str) : path to the input file.

    Returns:
        pd.DataFrame : three-column table with the sets involved in the overlap
                       (comma-joined, column "sets"), the winner (column "winner")
                       and the decision rule (column "Decision rule").
    """

    records = []
    with open(input_file, 'r') as infile:
        for raw_line in infile:
            stripped = raw_line.strip()
            # Blank lines (e.g. a trailing newline) carry no record to parse.
            if not stripped:
                continue
            fields = stripped.split('\t')
            winner = fields[3]
            rule = fields[-1]
            sets = [winner, fields[5]]
            if len(fields) > 8:
                sets.append(fields[7])
            sets.sort()
            # Move "Promoters" to the front wherever it landed after sorting, so that
            # no other set is dropped when its name sorts after "Promoters".
            if "Promoters" in sets:
                sets.remove("Promoters")
                sets.insert(0, "Promoters")
            records.append((','.join(sets), winner, rule))
    return pd.DataFrame.from_records(records, columns=["sets", "winner", "Decision rule"])


def mybarplot(df, output):

    """
    Plots overlap resolution as barplots.

    Draws one count panel per decision rule (two panels per row), with the set
    combinations on the x axis and one bar per winner. Each bar is annotated with
    its height, and the figure is saved to `output`.

    Args:
        df (pd.DataFrame) : overlap summary table with "sets", "winner" and
                            "Decision rule" columns.
        output (str) : path to output figure
    """

    order = sorted(df['sets'].unique())
    # The last two categories are shown in swapped order; with a single category
    # there is nothing to swap (and indexing order[-2] would fail).
    if len(order) > 1:
        order[-1], order[-2] = order[-2], order[-1]

    plot = sns.catplot(x="sets", hue="winner", col="Decision rule", data=df, kind="count",
                       col_wrap=2, order=order, legend_out=False,
                       col_order=sorted(df["Decision rule"].unique()))

    # Write the count above each bar of every panel.
    for ax in plot.axes.flat:
        for patch in ax.patches:
            height = patch.get_height()
            try:
                count = int(height)
            except ValueError:
                # int() rejects NaN heights: such a bar has no count to display.
                continue
            ax.annotate(count, (patch.get_x() + patch.get_width() / 2., height),
                        ha='center', va='center', fontsize=8, color='gray',
                        xytext=(0, 5), textcoords='offset points')

    sns.despine()

    # Slant the category labels so long set combinations stay readable.
    for ax in plot.axes.flat:
        for label in ax.get_xticklabels():
            label.set_rotation(45)
            label.set_ha('right')

    plt.tight_layout()
    plt.savefig(output)
    plt.close("all")


if __name__ == '__main__':

    # Command-line interface: the module docstring is reused verbatim as help text.
    PARSER = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    PARSER.add_argument('-i', '--input', required=True, type=str)
    PARSER.add_argument('-o', "--output", default="out.svg", required=False, type=str)

    ARGS = PARSER.parse_args()

    # Load the overlap summary, then draw and save the barplots.
    mybarplot(load_input(ARGS.input), ARGS.output)
