"""
Make the per-species heatmap of element overlap across tissues.
First, you should run the intervene.smk snakemake to obtain overlap values.
"""
import sys

import argparse

import matplotlib as mpl
mpl.use('Agg')

import matplotlib.pyplot as plt

import seaborn as sns
import pandas as pd
import numpy as np

# Global seaborn theme used by the figure drawn at the end of the script.
sns.set(style="whitegrid", font_scale=1, font='Arial')

# The module docstring doubles as the --help description; the raw formatter
# keeps its line breaks as written.
PARSER = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# Required arguments.
PARSER.add_argument('-i', '--input', required=True)
PARSER.add_argument('-r', '--reg_labels', required=True)
PARSER.add_argument('-t', '--t_labels', required=True)

# Optional argument.
PARSER.add_argument(
    '-o', '--output',
    required=False,
    default='',
    help='prefix for output files',
)

# Parsed command line as a plain dict keyed by the long option names.
ARGS = vars(PARSER.parse_args())

# Both label options are comma-separated lists: REG holds the region-set
# labels, TISS the tissue labels (the rest of the script reads TISS[0] and
# TISS[1]).
REG = ARGS["reg_labels"].split(',')
TISS = ARGS["t_labels"].split(',')

def get_index(tissue1, tissue2, set1, set2):
    """Return the (row, column) cell of the overlap matrix for a pair of sets.

    The row is the position in REG of the set that belongs to the first
    tissue of TISS, the column the position of the other set.

    Args:
        tissue1: tissue label of ``set1``; decides the orientation.
        tissue2: tissue label of ``set2`` (not used by the lookup).
        set1: region label of the first set, must be in REG.
        set2: region label of the second set, must be in REG.

    Returns:
        A ``(row, column)`` tuple of indices into REG.

    Raises:
        ValueError: if ``tissue1`` is not in TISS or a set label is not in REG.
    """
    # When the first set is not from TISS[0], flip the pair so that the
    # TISS[0] set always ends up on the row axis.
    if TISS.index(tissue1) != 0:
        set1, set2 = set2, set1

    return REG.index(set1), REG.index(set2)


def load_intervene_overlaps(input_file):
    """Read pairwise overlap counts and per-set totals from an overlap file.

    Only the third line of ``input_file`` is parsed. It is split on commas
    into entries of the form ``name[&name...]=count`` where each name is
    ``<region>_<tissue>`` (the tissue is the text after the last underscore).

    * Entries naming exactly two sets with a non-zero count fill one cell of
      the region-by-region matrix (cell chosen by ``get_index``).
    * Every entry adds its count to the running total of each set it names;
      totals are keyed by the set name with underscores turned into spaces.
      Presumably this total is the full size of the set -- confirm against
      the format of the overlap file.

    Args:
        input_file: path of the overlap summary file.

    Returns:
        A tuple ``(df, dnorm, mat)``: ``mat`` is a ``len(REG)`` x ``len(REG)``
        float array of pairwise counts, ``df`` the same data as a DataFrame
        labelled with REG on both axes, and ``dnorm`` the dict of totals.

    Raises:
        SystemExit: with code 1 when two sets of the same tissue have a
            non-zero overlap.
        ValueError: when an entry is malformed or names an unknown
            region/tissue.
    """
    n_reg = len(REG)
    mat = np.zeros((n_reg, n_reg))
    dnorm = {}

    with open(input_file, 'r') as infile:

        for k, line in enumerate(infile):

            # Lines before the third one are not parsed.
            if k < 2:
                continue

            for entry in line.strip().split(','):

                # Tolerate a trailing comma / empty field.
                if not entry.strip():
                    continue

                members = entry.split('&')
                # The count is attached to the last set name of the entry.
                members[-1], count = members[-1].split('=')
                count = int(count)

                if len(members) == 2 and count != 0:
                    set1, set2 = members
                    reg1, _, tissue1 = set1.rpartition('_')
                    reg2, _, tissue2 = set2.rpartition('_')

                    if tissue1 == tissue2:
                        print("Elements shouldn't overlap between sets of the same "
                              "tissue (see the nmrdmr pipeline).", file=sys.stderr)
                        sys.exit(1)

                    i, j = get_index(tissue1, tissue2,
                                     reg1.replace('_', ' '), reg2.replace('_', ' '))
                    mat[i, j] = count

                for member in members:
                    key = member.replace('_', ' ')
                    dnorm[key] = dnorm.get(key, 0) + count

            # Nothing after the third line is used.
            break

    df = pd.DataFrame(mat, columns=REG, index=REG)
    return df, dnorm, mat



def norm_df(mat, dnorm):
    """Normalise overlap counts and build the matching heatmap annotations.

    Cell ``(i, j)`` of ``mat`` is divided by the smaller of
    ``dnorm["<REG[i]> <TISS[0]>"]`` and ``dnorm["<REG[j]> <TISS[1]>"]``, i.e.
    rows are matched with the first tissue and columns with the second.

    Args:
        mat: 2-D numpy array of overlap counts whose axes follow REG order.
        dnorm: dict mapping ``"<region> <tissue>"`` to the normalising total.

    Returns:
        A tuple ``(new_d, annot_mat)``: ``new_d`` is a DataFrame of the
        ratios labelled with REG on both axes; ``annot_mat`` is a string
        array where each cell holds the integer count, a newline, and the
        ratio rounded to two decimals in parentheses.

    Raises:
        KeyError: if ``dnorm`` has no entry for a region/tissue combination.
    """
    n_rows, n_cols = mat.shape
    new_mat = np.zeros((n_rows, n_cols))
    annot_mat = np.empty((n_rows, n_cols), dtype='<U20')

    for i in range(n_rows):
        row_total = dnorm[REG[i] + ' ' + TISS[0]]
        for j in range(n_cols):
            norm = min(row_total, dnorm[REG[j] + ' ' + TISS[1]])
            ratio = mat[i, j] / norm
            new_mat[i, j] = ratio
            # int() gives the count without a decimal part for float and
            # integer matrices alike (slicing str() only worked for floats).
            annot_mat[i, j] = f"{int(mat[i, j])}\n({round(ratio, 2)})"

    new_d = pd.DataFrame(new_mat, columns=REG, index=REG)
    return new_d, annot_mat

    
INFILE = ARGS["input"]
# NOTE: the value of --output is passed to savefig as the full figure path.
OUTFIG = ARGS["output"]

# Parse the overlap counts, then scale them for the colour map.
df, dnorm, mat = load_intervene_overlaps(INFILE)
normed_df, annot = norm_df(mat, dnorm)

# Colours encode the normalised overlap (fixed 0-1 scale); each cell is
# annotated with the pre-formatted strings, hence the empty fmt.
g = sns.heatmap(normed_df, annot=annot, annot_kws={"size": 14}, square=True,
                cmap="Greys", fmt='', vmin=0, vmax=1)
g.set_yticklabels(g.get_yticklabels(), rotation=0)

# Show the column tick labels above the matrix instead of below it.
plt.tick_params(axis='both', which='major', labelsize=14, labelbottom=False,
                bottom=False, top=False, labeltop=True)
cbar = g.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)

# Rows of the matrix hold the sets of the first tissue and columns those of
# the second one (see get_index and norm_df), so the first tissue labels the
# y axis and the second the x axis.
plt.ylabel(TISS[0])
plt.xlabel(TISS[1])
plt.tight_layout()
plt.savefig(OUTFIG, dpi=300)
plt.close("all")