#!/usr/bin/env python
import pandas as pd
import numpy as np
from copy import copy
import matplotlib.pyplot as plt
from GGR import utils
from mpl_toolkits.axes_grid1 import make_axes_locatable
import argparse
from reddylab_utils.reddylab_plotting_utils import make_colormap

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                 description="""

This script plots a dataframe of values (with index and header by default) as a heatmap.
With the default colormap (--cmap OB) the heatmap ranges from (color blind-friendly)
orange to black to blue for high to low values, respectively. --cmap VS gives
vermillion to white to sky blue instead, and other colormaps can be plotted as well
(see --cmap).

""")

# Pull the default "optional arguments" group out so the "required arguments"
# group is listed first in --help; it is re-appended once all arguments exist.
optional = parser._action_groups.pop()
required = parser.add_argument_group('required arguments')

##################################################
# required args:

required.add_argument("--vals", help="required, file path to dataframe to plot as heatmap", required=True)
required.add_argument("--outplot", type=str, help="path to outplot", required=True)

##################################################
# optional args:

optional.add_argument("--vmax", help="optional, vmax of heatmap", type=float)
optional.add_argument("--vmin", help="optional, vmin of heatmap", type=float)

optional.add_argument("--cmap", help="""optional, colormap of heatmap (default: OB, orange-black-blue).
Built-in codes: OB, O, OW, B, BW, VS (vermillion-white-sky blue),
cubehelix, cubehelix_red; any other value is used as a matplotlib colormap name.""")

optional.add_argument("--no_header", help="optional, if --vals and --pvals have no header", action="store_true")
optional.add_argument("--no_index", help="optional, if --vals and --pvals have no index", action="store_true")

optional.add_argument("--no_xticks", help="optional, do not draw x-axis tick labels", action="store_true")
optional.add_argument("--no_yticks", help="optional, do not draw y-axis tick labels", action="store_true")

optional.add_argument("--imshow", help="use imshow instead of pcolor", action="store_true")

optional.add_argument("--not_equal_aspect_ratio", help="do not enforce equal aspect ratio", action="store_true")

optional.add_argument("--figsize_x", type=float,
                      help="optional, figure width in inches (default: number of columns if fewer than 20, else 8)")
optional.add_argument("--figsize_y", type=float,
                      help="optional, figure height in inches (default: number of rows if fewer than 20, else 8)")

optional.add_argument("--nan_color", help="optional, color of NaN cells (default: %(default)s)", default="white")

optional.add_argument("--pvals", help="optional, file path to p-value dataframe corresponding to --vals df in dimension and order")
optional.add_argument("--p_cutoff", help="optional, p-value cut off for plotting 'N.S.' (non-significant) (default: %(default)s)",
                      type=float, default=0.001)
optional.add_argument("-x", '--xlabel', type=str, help="optional, x-axis label.")

optional.add_argument("--enforce_zero_balance",
                      help="""optional, if --enforce_zero_balance (and no --vmin is given),
then enforce that vmax == -vmin. For log2 odds ratios, this may be helpful
to make sure zero values sit at the midpoint of the colormap.""", action="store_true")

##################################################
parser._action_groups.append(optional)
args = parser.parse_args()

# Read the tab-separated values table (and, if given, the matching p-value
# table) with the same parsing options: --no_header / --no_index switch off
# header-row and index-column parsing respectively.
read_kwargs = dict(sep="\t",
                   header=None if args.no_header else 'infer',
                   index_col=None if args.no_index else 0)
vals = pd.read_csv(args.vals, **read_kwargs)
if args.pvals:
    pvals = pd.read_csv(args.pvals, **read_kwargs)

# Replace infinite values with large finite stand-ins (-inf -> -9999,
# +inf -> 9999) so they can still be colored.
vals = vals.mask(vals == -np.inf, -9999).mask(vals == np.inf, 9999)

# set up colormap
# Short --cmap codes select the custom palettes below; any other value is
# passed on unchanged. palettable is imported only inside the branches that
# use it.
def _cubehelix_colormap(**shape):
    """Build a palettable Cubehelix palette with sat=1.5 and return its matplotlib colormap."""
    from palettable.cubehelix import Cubehelix
    return Cubehelix.make(sat=1.5, **shape).mpl_colormap


# Single-ramp presets. "O" and "B" use the same start/rotation as the upper
# and lower halves of the default "OB" map; a trailing "W" raises max_light
# from 0.8 to 1.
_CUBEHELIX_RAMPS = {
    "O": dict(start=0.75, rotation=0.25, max_light=0.8),
    "OW": dict(start=0.75, rotation=0.25, max_light=1),
    "B": dict(start=0.25, rotation=-0.25, max_light=0.8),
    "BW": dict(start=0.25, rotation=-0.25, max_light=1),
}

cmap_code = args.cmap or "OB"
if cmap_code == "OB":
    import matplotlib.colors as mcolors
    # Sample 128 colors from each ramp: the reversed "B" ramp fills the lower
    # half of the combined map and the "O" ramp fills the upper half.
    half_steps = np.linspace(0, 1, 128)
    lower_colors = _cubehelix_colormap(start=0.25, rotation=-0.25, reverse=True, max_light=0.8)(half_steps)
    upper_colors = _cubehelix_colormap(start=0.75, rotation=0.25, max_light=0.8)(half_steps)
    cmap = mcolors.LinearSegmentedColormap.from_list('OB', np.vstack((lower_colors, upper_colors)))
elif cmap_code in _CUBEHELIX_RAMPS:
    cmap = _cubehelix_colormap(**_CUBEHELIX_RAMPS[cmap_code])
elif cmap_code == "VS":
    # sky blue -> white, then white -> vermillion (0.5 is presumably the
    # breakpoint in make_colormap's sequence format -- confirm in reddylab_utils)
    vermillion = (213 / 255., 94 / 255., 0)
    sky_blue = (86 / 255., 180 / 255., 233 / 255.)
    cmap = make_colormap([sky_blue, (1, 1, 1), 0.5, (1, 1, 1), vermillion])
elif cmap_code == "cubehelix":
    from palettable.cubehelix import jim_special_16
    cmap = jim_special_16.mpl_colormap
elif cmap_code == "cubehelix_red":
    from palettable.cubehelix import red_16
    cmap = red_16.mpl_colormap
else:
    cmap = cmap_code

# Keep the row/column labels now: vals is converted to a bare masked array
# below, which drops the DataFrame index and header.
ytick_labels = vals.index
xtick_labels = vals.columns

# make sure heatmap can handle nan values.
# A --cmap that is not one of the built-in codes arrives here as a plain
# string (e.g. "viridis"), which has no set_bad; resolve it to a Colormap
# object first. The copy keeps set_bad from touching the shared colormap.
if isinstance(cmap, str):
    cmap = plt.get_cmap(cmap)
cmap = copy(cmap)
cmap.set_bad(color=args.nan_color, alpha=1.)
# Mask NaN cells so they are drawn with the "bad" color set above.
vals_array = np.array(vals)
vals = np.ma.array(vals_array, mask=np.isnan(vals_array))

# set up colormap limits.
# Unless given explicitly, vmax/vmin default to the 95th/5th percentile of the
# values (np.nanpercentile ignores NaN cells).
vmax = args.vmax if args.vmax is not None else np.nanpercentile(vals.flatten(), 95)
if args.enforce_zero_balance and args.vmin is None:
    # Symmetric limits around zero. Test `is None` (not truthiness) so an
    # explicit `--vmin 0` is respected instead of being replaced by -vmax.
    vmin = -vmax
else:
    vmin = args.vmin if args.vmin is not None else np.nanpercentile(vals.flatten(), 5)

# Figure size in inches: an explicit (non-zero) --figsize_x/--figsize_y wins;
# otherwise use one inch per column/row, capped at 8 once an axis has 20 or
# more cells.
n_rows, n_cols = vals.shape
x_size = args.figsize_x or (n_cols if n_cols < 20 else 8)
y_size = args.figsize_y or (n_rows if n_rows < 20 else 8)

fig, ax = plt.subplots(figsize=(x_size, y_size))

# plot heatmap; for_cb is the mappable the colorbar is built from
if args.imshow:
    # "auto" lets cells stretch to fill the axes; None defers to matplotlib's
    # default aspect (rcParams["image.aspect"])
    imshow_aspect = None
    if args.not_equal_aspect_ratio:
        imshow_aspect = "auto"
    for_cb = ax.imshow(vals, cmap=cmap, vmin=vmin, vmax=vmax,
                       interpolation="nearest", origin='lower',
                       aspect=imshow_aspect)
else:
    for_cb = ax.pcolor(vals, vmin=vmin, vmax=vmax, cmap=cmap)

# create labels and limit plot boundaries.
# pcolor draws cell (i, j) over [j, j+1] x [i, i+1], so cell centres are at
# half-integers; imshow (default extent) centres cell (i, j) on the integer
# point (j, i), so its ticks must not be shifted by 0.5.
tick_offset = 0 if args.imshow else 0.5
if args.no_xticks:
    ax.set_xticks([])
else:
    ax.set_xticks(np.arange(len(xtick_labels)) + tick_offset)
    ax.set_xticklabels(xtick_labels, ha="right", rotation=45, fontsize=8)
if args.no_yticks:
    ax.set_yticks([])
else:
    ax.set_yticks(np.arange(len(ytick_labels)) + tick_offset)
    ax.set_yticklabels(ytick_labels, fontsize=8)

# apply the optional --xlabel (parsed on the command line, otherwise unused)
if args.xlabel:
    ax.set_xlabel(args.xlabel)

# hide the tick marks themselves; only the labels are shown
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')

if not args.imshow:
    # pin the pcolor axes to exactly the cell grid
    ax.set_xlim((0, vals.shape[1]))
    ax.set_ylim((0, vals.shape[0]))

# create colorbar in a slim axis appended to the right of the heatmap
divider = make_axes_locatable(ax)
cbar_ax = divider.append_axes("right", size="5%", pad=0.05)

# Upper bound on integer ticks: beyond this (e.g. vmax = 9999, the stand-in
# for +inf) the labels would pile on top of each other.
max_integer_ticks = 25

# Default: label vmin and vmax, plus zero when it lies strictly between them.
ticks = [vmin, 0, vmax] if vmin < 0 < vmax else [vmin, vmax]
# float(...).is_integer() is False for NaN/inf instead of raising like int().
if float(vmax).is_integer() and np.isfinite(vmin):
    # vmax is an integer: tick every integer in [vmin, vmax], starting at
    # ceil(vmin) so the ticks stay integral even when vmin is fractional.
    integer_ticks = np.arange(np.ceil(vmin), vmax + 1)
    if 0 < len(integer_ticks) <= max_integer_ticks:
        ticks = integer_ticks

cb = fig.colorbar(for_cb, cax=cbar_ax, ticks=ticks)
cb.ax.tick_params(labelsize=8)

# Square heatmap cells unless the user passed --not_equal_aspect_ratio.
enforce_square_cells = not args.not_equal_aspect_ratio
if enforce_square_cells:
    ax.set_aspect('equal')

# label results with p-value > p_cutoff as non-significant or "N.S."
if args.pvals:
    p_array = np.array(pvals)
    if p_array.shape != vals.shape:
        # a mismatched table would put labels outside, or only on part of, the grid
        parser.error("--pvals shape %s does not match --vals shape %s"
                     % (p_array.shape, vals.shape))
    # Each label sits at offset (0.2, 0.7) from the cell's starting corner.
    # pcolor cells start at (j, i); imshow cells (default extent) start at
    # (j - 0.5, i - 0.5).
    cell_start = -0.5 if args.imshow else 0.0
    # np.argwhere yields (row, col) pairs in row-major order; NaN p-values
    # compare False and are never labelled.
    for i, j in np.argwhere(p_array > args.p_cutoff):
        xy = (j + cell_start + 0.2, i + cell_start + 0.7)
        ax.annotate('N.S.', xy=xy, xytext=xy, fontsize=8)

# Flip the y-axis so the first table row ends up at the top, then tighten the
# layout and write the figure to --outplot.
ax.invert_yaxis()
fig.tight_layout()
fig.savefig(args.outplot)
