#!/usr/bin/env python

"""
Script producing customized quality control plots for the crossregmap workflow.
"""

import argparse
import os

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import scipy as sp

# Global plot theme. ``sns.set`` applies a colour palette of its own (seaborn's
# default one when no ``palette`` argument is given), so the "muted" palette
# must be selected *after* it, otherwise the selection is overwritten.
sns.set(style="ticks", font_scale=0.6)
sns.set_palette("muted")

def load_bed(input_file, peaks=False):

    """
    Loads a bed file in a python dict.

    Each line is split on tabs; a line containing no tab is split on any
    whitespace instead. Blank (or whitespace-only) lines are skipped. When
    the same region id occurs several times, the last occurrence wins.

    Args:
        input_file (str): path to the bed file. The path is printed to stdout.
        peaks (bool): selects the expected line layout.
            False (default): exactly five columns
            ``chr, start, stop, unique_id, score``.
            True: the region id is the last column, the peak intensity the
            fifth column and the coverage the second-to-last column.

    Returns:
        dict: region id -> dict of values.
            peaks=False: keys 'chr', 'start', 'stop', 'length' (stop - start)
            and 'coverage' (the score column as a float).
            peaks=True: keys 'log2 peak intensity' and 'log2 coverage (FPKM)',
            each computed as log2(value + 0.1).

    Raises:
        ValueError: if a non-blank line does not have five columns
            (peaks=False) or a numeric column cannot be converted.
        IndexError: if a line has fewer than five columns (peaks=True).
    """

    print(input_file)
    d = {}
    with open(input_file, "r") as infile:
        for line in infile:
            fields = line.strip().split('\t')
            if len(fields) == 1:
                # No tab found: fall back to splitting on any whitespace.
                fields = line.split()

            if not fields:
                # Blank line (typically a trailing newline at end of file).
                continue

            if not peaks:
                ch, start, stop, unique_id, score = fields
                d[unique_id] = {'chr': ch, 'start': int(start), 'stop': int(stop),
                                'length': int(stop) - int(start), "coverage": float(score)}
            else:
                unique_id = fields[-1]
                # 0.1 is added before taking the log so that zeros stay finite.
                d[unique_id] = {'log2 peak intensity': np.log2(float(fields[4]) + 0.1),
                                'log2 coverage (FPKM)': np.log2(float(fields[-2]) + 0.1)}
    return d


def boxplot_coverage(dataset, fliers=True, xlab='', ylab='', out=None, title='', gb=None):

    """
    Draws one box per library showing the distribution of region coverage.

    Args:
        dataset (list): (bed file path, library name) pairs. Each file is read
            with load_bed in its default (non-peaks) layout.
        fliers (bool): whether outlier points are drawn.
        xlab (str): x axis label.
        ylab (str): y axis label.
        out (str): output path handed to plt.savefig; when empty or None the
            figure is shown with plt.show() instead.
        title (str): figure title.
        gb (dict): optional grouping, mapping a key to a list of substrings.
            A library belongs to the first substring found in its name.
            Groups are laid out in sorted order, one colour per group.
            Libraries matching no substring are left out of that ordering.
    """

    def load_df(data, name):
        # Two-column frame: region id and the coverage of library `name`.
        d = load_bed(data)
        rec = [(i, d[i]['coverage']) for i in d]
        return pd.DataFrame.from_records(rec, columns=["region_id", name])

    # The grouping values are taken from `gb` itself, so the function does
    # not rely on any name defined by the caller's module.
    group_values = []
    cl_dict = {}
    if gb:
        group_values = [val for values in gb.values() for val in values]
        for _, name in dataset:
            for val in group_values:
                if val in name:
                    members = cl_dict.setdefault(val, [])
                    if name not in members:
                        members.append(name)
                    # First matching value wins.
                    break

    df = load_df(dataset[0][0], dataset[0][1])

    # Left merge: only regions present in the first file are kept; regions
    # missing from a later file get NaN for that library.
    for (data, name) in dataset[1:]:
        df_tmp = load_df(data, name)
        df = df.merge(df_tmp, on='region_id', how='left')

    df.set_index("region_id", inplace=True)

    if gb:
        order = []
        colors = []
        palette = sns.color_palette("muted").as_hex()
        for i, val in enumerate(sorted(cl_dict)):
            order += cl_dict[val]
            # Wrap around so that more groups than palette colours is not an error.
            colors += [palette[i % len(palette)]] * len(cl_dict[val])

        # When "mapped" is one of the group values the boxes keep the column
        # (input) order; otherwise they are reordered group by group.
        if "mapped" not in group_values:
            sns.boxplot(data=df, showfliers=fliers, order=order, palette=colors)
        else:
            sns.boxplot(data=df, showfliers=fliers, palette=colors)

    else:
        sns.boxplot(data=df, showfliers=fliers)

    plt.xticks(rotation=45, horizontalalignment="right")
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    plt.title(title)
    sns.despine()
    plt.tight_layout()

    if out:
        plt.savefig(out)
    else:
        plt.show()

    plt.close('all')


def scatterplot_coverage(dataset, xlab='', ylab='', out=None, title='', peaks=False):

    """
    Draws one scatter plot with a regression line per library and prints,
    for each library, the number of regions, the library name and the
    Pearson correlation between the two plotted variables.

    Args:
        dataset (list): (bed file path, library name) pairs, read with load_bed.
        xlab (str): x axis label; when empty the plotted column name is kept.
        ylab (str): y axis label; when empty the plotted column name is kept.
        out (str): output path handed to plt.savefig; when empty or None the
            figure is shown with plt.show() instead.
        title (str): figure title.
        peaks (bool): False (default) plots 'coverage' against 'length' with
            4 panels per row; True plots 'log2 peak intensity' against
            'log2 coverage (FPKM)' with 6 panels per row.
    """

    # Imported explicitly: `import scipy` alone does not guarantee that the
    # `scipy.stats` submodule is loaded on every scipy version.
    from scipy import stats

    x = 'coverage'
    y = 'length'
    wrap = 4

    if peaks:
        x = 'log2 peak intensity'
        y = 'log2 coverage (FPKM)'
        wrap = 6

    rec = []
    for (data, name) in dataset:
        d = load_bed(data, peaks)
        rec += [(i, name, d[i][x], d[i][y]) for i in d]

    df = pd.DataFrame.from_records(rec, columns=["region_id", "library", x, y])

    df.set_index("region_id", inplace=True)

    grid = sns.lmplot(data=df, x=x, y=y, hue="library", fit_reg=True, col="library",
                      scatter_kws={"s": 5, "alpha": 0.3}, col_wrap=wrap, height=3)
    grid.fig.suptitle(title)

    # Custom axis labels are applied only when given; a None argument leaves
    # the corresponding default label (the column name) untouched.
    if xlab or ylab:
        grid.set_axis_labels(xlab or None, ylab or None)

    for lib in df["library"].unique():
        small_df = df[df["library"] == lib]
        print(len(small_df), lib, stats.pearsonr(small_df[x], small_df[y]))

    sns.despine()
    plt.tight_layout()

    if out:
        plt.savefig(out)
    else:
        plt.show()

    plt.close('all')


if __name__ == '__main__':

    # Command-line entry point: parse the options, pair every input file with
    # a label and draw the requested plots.
    PARSER = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    PARSER.add_argument('-i', '--input', nargs='+', required=True,
                        help="input bed files, one per library")

    PARSER.add_argument('-l', '--labels', nargs='+', required=False,
                        help="one label per input file (default: input file names without extension)")

    PARSER.add_argument('-plot_types', "--plot_types", nargs='+', required=False, default=["scatter", "box"])

    PARSER.add_argument('-o', "--output", nargs='+', required=False, default="",
                        help="one output file per plot type, in the same order as --plot_types")

    PARSER.add_argument('-t', '--title', type=str, required=False, default='')

    PARSER.add_argument('-gb', '--groupby', type=str, required=False, default='',
                        help="grouping of the boxplot libraries, given as KEY=VALUE1,VALUE2,...")

    PARSER.add_argument('-y', '--ylab', type=str, required=False, default='Coverage (RPKM)')

    PARSER.add_argument('--peaks', action='store_true')

    PARSER.add_argument('--verbose', action='store_true')

    ARGS = vars(PARSER.parse_args())

    if not ARGS["output"]:
        # No output given: each plot is saved under the name of its plot type.
        ARGS["output"] = ARGS["plot_types"]

    # Explicit checks instead of assert (asserts are removed under `python -O`)
    # so that bad arguments always end in a regular usage error.
    if len(ARGS["output"]) < len(ARGS["plot_types"]):
        PARSER.error("fewer output files than plot types")

    if not ARGS["labels"]:
        ARGS["labels"] = [os.path.splitext(os.path.basename(i))[0] for i in ARGS["input"]]

    if len(ARGS["labels"]) != len(ARGS["input"]):
        PARSER.error("different numbers of inputs and labels")

    DATA = list(zip(ARGS["input"], ARGS["labels"]))

    # Plot type -> output file and title; surplus output names are ignored.
    TO_PLOT = {plot: {"out": out, "title": ARGS["title"]}
               for plot, out in zip(ARGS["plot_types"], ARGS["output"])}

    gb = None
    if ARGS["groupby"]:
        if "=" not in ARGS["groupby"]:
            PARSER.error("--groupby expects KEY=VALUE1,VALUE2,...")
        # Split on the first '=' only, so that values may contain '='.
        key, values = ARGS["groupby"].split("=", 1)
        gb = {key: values.split(',')}

    if "box" in TO_PLOT:
        # The grouping is handed over to the boxplot; without it --groupby
        # would be parsed but have no effect.
        boxplot_coverage(DATA, fliers=False, ylab=ARGS["ylab"], out=TO_PLOT["box"]["out"],
                         title=TO_PLOT["box"]["title"], gb=gb)

    if "scatter" in TO_PLOT:
        scatterplot_coverage(DATA, out=TO_PLOT["scatter"]["out"], title=TO_PLOT["scatter"]["title"],
                             peaks=ARGS["peaks"])

    plt.close("all")