"""
Script to format the read coverage data for normalization and EVE analysis.
Also plots distributions.
"""
import numpy as np

import pandas as pd
import seaborn as sns

import matplotlib as mpl
mpl.use('Agg')

import matplotlib.pyplot as plt
import os

import argparse

PARSER = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# Required: each option takes one or more values (used later as input paths).
PARSER.add_argument('-s', '--signal', nargs='+', required=True)
PARSER.add_argument('-c', '--control', nargs='+', required=True)

# Optional: string prepended to the name of every output file.
PARSER.add_argument(
    '-p', '--prefix',
    default='',
    required=False,
    help='prefix for output files',
)

# Parse the command line once and expose the options as a plain dict.
ARGS = vars(PARSER.parse_args())

# Load a bed file into a pandas DataFrame
def load_normed3(input_file, sp, sp_rep, norm_len=False, invert=False):
    """
    Read a tab-separated file into a two-column DataFrame.

    Every non-blank line contributes one row built from its 4th and 5th
    tab-separated fields. Fields are kept exactly as read (strings), except
    for the last field of the line when `norm_len` is set. Blank lines (for
    instance a trailing empty line) are skipped.

    Args:
        input_file: path of the tab-separated file to read.
        sp: unused; kept so that existing callers keep working.
        sp_rep: sample name; its lower-cased form names the first column.
        norm_len: if True, the last field of each line is replaced by its
            float value divided by (3rd field - 2nd field).
        invert: if True, the 5th field fills the first column and the 4th
            field fills `unique_id`; otherwise it is the other way round.

    Returns:
        A pandas DataFrame with columns ``[sp_rep.lower(), "unique_id"]`` and
        one row per non-blank input line.

    Raises:
        IndexError: if a non-blank line has fewer than 5 fields.
        ValueError: if `norm_len` is set and a needed field is not a number.
        ZeroDivisionError: if `norm_len` is set and the 2nd and 3rd fields
            are equal.
    """
    rec = []
    with open(input_file, 'r') as infile:
        for raw_line in infile:
            stripped = raw_line.strip()
            if not stripped:
                # A blank line carries no record; indexing it would fail.
                continue
            fields = stripped.split('\t')
            if norm_len:
                span = float(fields[2]) - float(fields[1])
                fields[-1] = float(fields[-1]) / span
            if invert:
                rec.append((fields[4], fields[3]))
            else:
                rec.append((fields[3], fields[4]))

    return pd.DataFrame.from_records(rec, columns=[sp_rep.lower(), "unique_id"])


# Output-name suffix for each input group, in the order the groups are processed.
LABELS = ["signal", "input"]

# For each group: merge all of its files on 'unique_id', write the merged
# table as a tab-separated file, and save a log2-scaled box plot.
for j, inputs in enumerate([ARGS["signal"], ARGS["control"]]):
    col_to_plot = []
    # None (not an empty frame) marks "nothing merged yet": an empty first file
    # or an empty intersection must not be mistaken for a fresh start.
    master_file = None
    for myfile in sorted(inputs):

        # The column name is derived from the dot-separated file name fields.
        name_parts = os.path.basename(myfile).split('.')
        sp = name_parts[0]

        if j == 0:
            # Signal files: <field 0>_<field 1>
            sp_rep = '_'.join(name_parts[:2])
        else:
            # Control files: <field 0>_<last '_'-separated token of field 2>
            sp_rep = name_parts[0] + '_' + name_parts[2].split('_')[-1]

        col_to_plot.append(sp_rep.lower())
        tmp_df = load_normed3(myfile, sp, sp_rep)
        if master_file is None:
            master_file = tmp_df
            continue

        master_file = pd.merge(master_file, tmp_df, on='unique_id', how='inner')
        # keep=False removes every row whose unique_id occurs more than once,
        # so ids duplicated by the merge are dropped entirely.
        master_file.drop_duplicates(inplace=True, subset=['unique_id'], keep=False)
        master_file.dropna(inplace=True)

    # Short column names: upper-cased first letter, first three letters of the
    # second '_'-separated token, then '_' and the last token.
    d = {
        col: col[0].upper() + col.split('_')[1][:3] + '_' + col.split('_')[-1]
        for col in master_file.columns
        if col in col_to_plot
    }
    master_file = master_file.rename(columns=d)

    col_to_plot2 = [d[i] for i in col_to_plot]
    df = master_file

    # Move 'unique_id' to the first column of the output table.
    df.insert(0, 'unique_id', df.pop('unique_id'))

    output_file = ARGS["prefix"] + 'fpkm_' + LABELS[j]
    df.to_csv(output_file+'.csv', sep='\t', index=False)

    # The table holds the values as read from the files (strings), so plot a
    # numeric copy restricted to the per-sample columns; 'unique_id' is left out.
    plot_df = df[col_to_plot2].apply(pd.to_numeric, errors='coerce')

    plt.figure(figsize=(7, 4))
    g = sns.boxplot(data=plot_df, showfliers=True)
    try:
        g.set_yscale("log", base=2)
    except (TypeError, ValueError):
        # Older matplotlib releases only know the axis-specific keyword.
        g.set_yscale("log", basey=2)
    plt.xticks(rotation=45, horizontalalignment="right")
    sns.despine()
    plt.tight_layout()
    plt.savefig(output_file+'.svg')
    plt.close('all')