#!/usr/bin/env python
import pandas as pd
import numpy as np
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import pdist
import matplotlib
font = {'size'   : 8}
matplotlib.rc('font', **font)

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
from sklearn.preprocessing import scale,minmax_scale
import operator
from sys import argv
import argparse

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

Hierarchically cluster sites and sort sites in a descending manner
according to mean signal, respecting cluster structure.

""")

##################################################
# required args:
# NOTE: "-i" and "--mats" must be separate arguments; without the comma the
# two literals concatenate into a single, unusable "-i--mats" option.
parser.add_argument("-i", "--mats",  nargs='+',
                    help="""required, input data, e.g. logFC expression
/path/to/expression.txt
""", dest="mats", action='store', required=True)

parser.add_argument("-o", "--outplots",  nargs='+',
                    help="""

list of plot names to which to save plot.

""", dest="outplots", action='store', required=True)

##################################################
# optional args:

parser.add_argument("--titles",  nargs='+',
                    help="""optional, title of mats
""", action='store', default=False)

parser.add_argument("--vmin",
                    help="""
vmin of matshow.

""", dest="vmin", action='store', type=float)

parser.add_argument("--vmax",
                    help="""
vmax of matshow.

""", dest="vmax", action='store', type=float)

parser.add_argument("--cmap", nargs="+",
                    help="""
Colormap(s) for the heatmaps: either one name used for every
input matrix, or one name per input matrix. OB, YP and cubehelix
are built in; any other name is passed to matplotlib. [default=OB]

""", default=['OB'])

parser.add_argument("--colorbar",
                    help="""
If indicated, plot colorbar next beside figure.

""", action='store_true')

parser.add_argument("--clusters",
                    help="""
If indicated, "cut" the dendrogram that defines
the hierarchical clustering at the specified
number of clusters. By default (1), no cut is made.
Not yet supported together with --indices.

""", type=int, default=1)

parser.add_argument("--scale",
                    help="""
If indicated, standardize each row (mean=0, stdev=1).

""", action='store_true')

parser.add_argument("--minmax_scale",
                    help="""
If indicated, scale each row to have min=0, max=1.

""", action='store_true')

parser.add_argument("--indices", nargs='+',
                    help="""

A list of files to split the matrix into sub-matrices.

""", dest="indices", action='store')

parser.add_argument("--ascending",
                    help="""
To sort matrix in reverse order.

""", dest='ascending', action='store_true')

parser.add_argument("--figsize_y",
                    help="""
height of figure

""", type=float, default=3.)

parser.add_argument("--figsize_x",
                    help="""
width of figure

""", type=float, default=2.)

parser.add_argument("--weights", nargs='+',
                    help="""
weight applied to each input matrix for linkage calculation

""", dest="weights", action='store')

parser.add_argument("--metric", default="ward",
                    help="""
linkage method [default=ward], also, single|complete|average|weighted|centroid|median

""")

parser.add_argument("-m", "--motifs",  nargs='+',
                    help="""

One or more paths to motif files with 1,0 for
whether the site has or doesn't have the motif.

""", dest="motifs", action='store')
parser.add_argument("--motifs_width",  type=int, default=1,
                    help="""

The width of the representation of each motif file
as a factor of the width of an individual heatmap column.
""", dest="motifs_width", action='store')
parser.add_argument("--motif_colors",  nargs='+',
                    help="""

The colors of the motif columns in same order
as the --motifs argument.
""", dest="motif_colors", action='store')
parser.add_argument("--save_clustered_indices_to_file",
                    help="""
Optional path to which the plotted row labels are written,
top row first. Includes the cluster number when --clusters > 1,
or the index-list number when --indices is given.
""")


######################################################

args = parser.parse_args()

######################################################

# black = (0,0,0)
# orange = (230/255.0,159/255.0,0)
# skyblue = (86/255.0,180/255.0,233/255.0)
# bluishgreen = (0,158/255.0,115/255.0)
# yellow = (240/255.0,228/255.0,66/255.0)
# blue = (0,114/255.0,178/255.0)
# vermillion = (213/255.0,94/255.0,0)
# reddishpurple = (204/255.0,121/255.0,167/255.0)
# color_convert= {'black':black,
#                 'orange':orange,
#                 'skyblue':skyblue,
#                 'bluishgreen':bluishgreen,
#                 'yellow':yellow,
#                 'blue':blue,
#                 'vermillion':vermillion,
#                 'reddishpurple':reddishpurple}

# def color_converter(color):
#     if color in color_convert:
#         return color_convert[color]
#     else:
#         return color
    
# # #######################

if len(args.mats) == 1:
    mat = pd.read_csv(args.mats[0], sep="\t", index_col=0)
    multi_mat = False
    orig_index = list(mat.index)
else:
    mats = [pd.read_csv(path, sep="\t", index_col=0) for path in args.mats]
    orig_index = list(mats[0].index)

    # Warn about matrices containing NaNs. Each matrix is named by its
    # --titles entry when one exists, otherwise by its file path
    # (args.titles defaults to False, which cannot be zipped).
    given_titles = list(args.titles) if args.titles else []
    labels = given_titles + list(args.mats[len(given_titles):])
    for mat, label in zip(mats, labels):
        if np.isnan(mat.values).any():
            print("%s contains null values" % (label))

    # Per-matrix weights for the combined distance, normalized to sum to 1.
    if args.weights:
        if len(args.weights) != len(mats):
            parser.error("--weights needs one value per input matrix "
                         "(%d given for %d matrices)" % (len(args.weights), len(mats)))
        weights = np.array([float(weight) for weight in args.weights])
    else:
        weights = np.ones(len(mats))
    weights /= weights.sum()
    multi_mat = True


######################################################
# set up routines
######################################################

def recurse_to_leaf(k, N, Z, weight_dict):
    """Return the leaves below node *k* of linkage matrix *Z* as nested pairs.

    Node ids below N are leaves and are returned unchanged. A node id k >= N
    refers to row k - N of Z. At each merge, the child whose entry in
    *weight_dict* is strictly larger comes first; otherwise (including ties)
    the second child listed in Z comes first.
    """
    if k >= N:
        left, right = Z[int(k - N), :][:2]
        left_tree = recurse_to_leaf(int(left), N, Z, weight_dict)
        right_tree = recurse_to_leaf(int(right), N, Z, weight_dict)
        if not weight_dict[left] > weight_dict[right]:
            left_tree, right_tree = right_tree, left_tree
        return (left_tree, right_tree)
    return k

def flatten(x):
    """Flatten arbitrarily nested iterables into one flat list.

    Strings are treated as atoms rather than iterated character by character.
    Adapted from
    http://stackoverflow.com/questions/2158395/flatten-an-irregular-list-of-lists-in-python
    """
    flat = []
    for item in x:
        nested = hasattr(item, "__iter__") and not isinstance(item, basestring)
        if not nested:
            flat.append(item)
            continue
        flat.extend(flatten(item))
    return flat

def linkage_sort_by_mean(Z, X):
    """Return a leaf order for linkage *Z* that follows the tree and ranks by mean.

    Each leaf i is weighted by the mean of row X[i, :]. Each merged node is
    weighted by the size-weighted mean of its two children. Walking down from
    the root, the child with the strictly larger weight is visited first (on
    ties, the second child listed in Z comes first). Leaves are reported in
    visiting order.

    The walk uses an explicit stack instead of recursion, so deeply chained
    trees (e.g. from single linkage) do not hit Python's recursion limit.

    Parameters
    ----------
    Z : ndarray
        scipy linkage matrix built from the N rows of X.
    X : ndarray
        2-D array whose rows are the clustered observations.

    Returns
    -------
    list of int
        Row indices of X, higher-weight subtrees first.
    """
    N = len(X)
    if N == 0:
        return []
    weight_dict = {i: X[i, :].mean() for i in range(N)}
    size_dict = {i: 1 for i in range(N)}
    for i in range(len(Z)):
        k1 = int(Z[i, 0])
        k2 = int(Z[i, 1])
        size = size_dict[k1] + size_dict[k2]
        weight_dict[N + i] = ((weight_dict[k1] * size_dict[k1]) +
                              (weight_dict[k2] * size_dict[k2])) / size
        size_dict[N + i] = size

    indices = []
    stack = [2 * (N - 1)]  # id of the root node (a leaf when N == 1)
    while stack:
        node = stack.pop()
        if node < N:
            indices.append(node)
            continue
        child1 = int(Z[node - N, 0])
        child2 = int(Z[node - N, 1])
        if weight_dict[child1] > weight_dict[child2]:
            first, second = child1, child2
        else:
            first, second = child2, child1
        # push the later child first so the earlier one is popped next
        stack.append(second)
        stack.append(first)
    return indices

def pdist_wrapper(X, scale_bool=False, minmax_scale_bool=False):
    """Pairwise Euclidean distances between the rows of *X* (condensed form).

    If *scale_bool* is set, each row is first standardized. Otherwise, if
    *minmax_scale_bool* is set, each row is first min-max scaled. If neither
    is set, a copy of X is used unchanged. *scale_bool* wins when both are set.
    """
    if not (scale_bool or minmax_scale_bool):
        return pdist(X.copy())
    if scale_bool:
        return pdist(scale(X, axis=1))
    return pdist(minmax_scale(X, axis=1))

######################################################
# hierarchically cluster
######################################################

if not args.indices:
    # --- the whole matrix (or set of matrices) is clustered at once ---
    sorted_indices_overall = []
    if not multi_mat:
        # convert to numpy array
        X = np.array(mat.values)
        # compute Euclidean distance between rows (optionally row-scaled)
        D_all = pdist_wrapper(X, scale_bool=args.scale, minmax_scale_bool=args.minmax_scale)

        print("Distance computed.")
        # the linkage-sorting array and the visualization array are the same here
        X_weighted = X
        X_viz = X.copy()
    else:
        widths = []

        # process each matrix sequentially
        for i, mat in enumerate(mats):
            X = np.array(mat.values)
            widths.append(X.shape[1])
            D = pdist_wrapper(X, scale_bool=args.scale, minmax_scale_bool=args.minmax_scale)

            print("Distance computed for index list %s" % (i))
            # accumulate the weighted sum of the per-matrix distances
            if i == 0:
                D_all = D * weights[i]
            else:
                D_all += (D * weights[i])
            # the visualization array stays unweighted; the array used to
            # rank subtrees by mean is weighted like the distances
            if i == 0:
                X_viz = X.copy()
                X_weighted = X * weights[i]
            else:
                X_viz = np.hstack([X_viz, X])
                X_weighted = np.hstack([X_weighted, X * weights[i]])

    # cluster
    Z = sch.linkage(D_all, method=args.metric)
    print("Linkage computed")
    # Higher-mean subtrees come first from linkage_sort_by_mean. Reversing
    # them puts them last, which is the top row under matshow(origin='lower').
    sorted_indices = linkage_sort_by_mean(Z, X_weighted)[::-1]
    print("Linkage sorted")
    if args.ascending:
        sorted_indices = sorted_indices[::-1]
    # visualization array is unweighted but sorted in the same manner
    X_clustered = X_viz[sorted_indices, :]

    sorted_indices_overall.extend(sorted_indices)
    # Motif files (one 0/1 value per line, row-aligned with the matrices)
    # are sorted the same way as the heatmap rows.
    if args.motifs:
        all_motifs = []
        for motif_path in args.motifs:
            with open(motif_path, "r") as f:
                motif = np.array([int(float(line.strip())) for line in f])
            all_motifs.append(motif[sorted_indices])
        all_motifs = np.vstack(all_motifs).T
else:
    # --- each --indices file selects a sub-matrix that is clustered on its own ---
    from collections import defaultdict
    # index-list number -> site labels in plotted order (bottom row first)
    sorted_indices_overall = defaultdict(list)
    source_mats = mats if multi_mat else [mat]

    if args.motifs:
        # Motif files hold one 0/1 value per line and are assumed to be
        # row-aligned with the first input matrix. Read each file once.
        motif_arrays = []
        for motif_path in args.motifs:
            with open(motif_path, "r") as f:
                motif_arrays.append(np.array([int(float(line.strip())) for line in f]))
        row_position = dict((label, pos) for pos, label in enumerate(source_mats[0].index))

    y = []
    all_motifs = []
    if multi_mat:
        widths = []
        X_list_of_lists = []
    else:
        X_list = []

    for j, indices_path in enumerate(args.indices):
        with open(indices_path, "r") as f:
            requested = set(line.strip() for line in f)
        # Keep only sites present in every input matrix, in a fixed (sorted)
        # order, so all matrices contribute the same rows.
        for source_mat in source_mats:
            requested &= set(source_mat.index)
        indices = np.array(sorted(requested))
        # number of rows actually plotted for this list (drives row heights)
        y.append(len(indices))

        per_mat = [np.array(source_mat.loc[indices]) for source_mat in source_mats]
        if multi_mat:
            if j == 0:
                widths.extend(X.shape[1] for X in per_mat)
            # weighted sum of per-matrix distances (same as the no-indices path)
            D = sum(pdist_wrapper(X, scale_bool=args.scale, minmax_scale_bool=args.minmax_scale) * weights[i]
                    for i, X in enumerate(per_mat))
            X_weighted = np.hstack([X * weights[i] for i, X in enumerate(per_mat)])
        else:
            D = pdist_wrapper(per_mat[0], scale_bool=args.scale, minmax_scale_bool=args.minmax_scale)
            X_weighted = per_mat[0]

        print("Distance computed for index list %s" % (j))
        Z = sch.linkage(D, method=args.metric)
        print("Linkage computed for index list %s" % (j))
        order = np.array(linkage_sort_by_mean(Z, X_weighted)[::-1])
        print("Linkage sorted for index list %s" % (j))
        # One row order is used for heatmaps, motifs and saved labels alike,
        # so --ascending applies to all of them.
        if args.ascending:
            order = order[::-1]
        sorted_indices = indices[order]
        sorted_indices_overall[j].extend(sorted_indices)

        if multi_mat:
            X_list_of_lists.append([X[order, :] for X in per_mat])
        else:
            X_list.append(mat.loc[sorted_indices])

        if args.motifs:
            # look motif values up by each site's row position, in plotted order
            positions = np.array([row_position[label] for label in sorted_indices])
            all_motifs.append(np.vstack([motif[positions] for motif in motif_arrays]).T)

    if multi_mat:
        # one (sites x all columns) array per index list, so they can be stacked
        X_clustered = [np.hstack(X_group) for X_group in X_list_of_lists]
    else:
        X_clustered = [np.array(X) for X in X_list]
    ys = np.cumsum(y[::-1])

######################################################
# plot clustered heatmap(s)
######################################################

##########
# define mapping cmap
##########

# One colormap name per input matrix; a single --cmap value is reused for all.
if len(args.cmap) == 1:
    cmap_names = [args.cmap[0]] * len(args.mats)
elif len(args.cmap) == len(args.mats):
    cmap_names = args.cmap
else:
    # explicit error instead of an assert (asserts vanish under python -O)
    parser.error("--cmap needs one colormap or one per input matrix "
                 "(%d given for %d matrices)" % (len(args.cmap), len(args.mats)))


def _diverging_cubehelix(name, negative_start, positive_start):
    """Join two 128-color cubehelix ramps into one diverging colormap.

    The negative half is built with reverse=True and rotation -0.25, the
    positive half with rotation +0.25. Both use sat=1.5 and max_light=0.8.
    """
    from palettable.cubehelix import Cubehelix
    negative = Cubehelix.make(start=negative_start, rotation=-0.25, sat=1.5,
                              reverse=True, max_light=0.8)
    positive = Cubehelix.make(start=positive_start, rotation=0.25, sat=1.5,
                              max_light=0.8)
    negative_cols = negative.mpl_colormap(np.linspace(0., 1, 128))
    positive_cols = positive.mpl_colormap(np.linspace(0, 1, 128))
    # combine them and build a new colormap
    return mcolors.LinearSegmentedColormap.from_list(
        name, np.vstack((negative_cols, positive_cols)))


cmaps = []
for cmap_name in cmap_names:
    if cmap_name == "OB":
        cmap = _diverging_cubehelix('OB', negative_start=0.25, positive_start=0.75)
    elif cmap_name == "YP":
        cmap = _diverging_cubehelix('YP', negative_start=0.75, positive_start=1)
    elif cmap_name == "cubehelix":
        from palettable.cubehelix import jim_special_16
        cmap = jim_special_16.mpl_colormap
    else:
        # any other name is handed to matplotlib as-is
        cmap = cmap_name
    cmaps.append(cmap)

# Default color limits are symmetric about zero: +/- the larger of the
# |5th percentile| and the 95th percentile of all plotted values (NaNs are
# ignored). Explicit --vmin/--vmax values, including 0, always win.
if args.vmin is None or args.vmax is None:
    all_values = np.vstack(X_clustered).flatten()
    half_range = max(abs(np.nanpercentile(all_values, 5)),
                     np.nanpercentile(all_values, 95))
vmin = args.vmin if args.vmin is not None else -half_range
vmax = args.vmax if args.vmax is not None else half_range

##########
# Layout: xfactor widens the figure (args.figsize_x) to leave room for a
# colorbar; the gridspec gets one row per index list and one column per
# heatmap/motif file.
##########

xfactor = 1.2 if args.colorbar else 1.0

# one extra gridspec column per motif file, each args.motifs_width wide
n_motifs = len(args.motifs) if args.motifs else 0
motif_ratios = n_motifs * [args.motifs_width]
ncols = 1 + n_motifs

# one gridspec row per index list, sized by its number of plotted sites
if args.indices:
    nrows = len(args.indices)
    height_ratios = np.array(y) / float(sum(y))
else:
    nrows = 1
    height_ratios = [1]

if multi_mat:
    # one heatmap column per input matrix, as wide as that matrix
    ncols += len(args.mats) - 1
    width_ratios = widths + motif_ratios
elif args.motifs:
    # X_clustered is a single array without --indices, a list of arrays with it
    heatmap = X_clustered if hasattr(X_clustered, "shape") else X_clustered[0]
    width_ratios = [heatmap.shape[1]] + motif_ratios
else:
    width_ratios = [1]

##########
# define plotting functions
##########

def plot_mapping_ax(array, ax, cmap):
    """Draw *array* as a tick-free heatmap on *ax*.

    Color limits come from the module-level ``vmin``/``vmax``. Row 0 is
    drawn at the bottom (origin='lower'). Returns the image so a colorbar
    can be attached later.
    """
    image = ax.matshow(array, aspect='auto', origin='lower',
                       cmap=cmap, vmin=vmin, vmax=vmax)
    for axis, set_ticks in ((ax.xaxis, ax.set_xticks), (ax.yaxis, ax.set_yticks)):
        set_ticks([])
        axis.set_ticks_position('none')
    return image

# Named colors accepted by --motif_colors in addition to any matplotlib color
# specification (RGB tuples on a 0-1 scale).
_NAMED_MOTIF_COLORS = {
    'black': (0, 0, 0),
    'orange': (230/255.0, 159/255.0, 0),
    'skyblue': (86/255.0, 180/255.0, 233/255.0),
    'bluishgreen': (0, 158/255.0, 115/255.0),
    'yellow': (240/255.0, 228/255.0, 66/255.0),
    'blue': (0, 114/255.0, 178/255.0),
    'vermillion': (213/255.0, 94/255.0, 0),
    'reddishpurple': (204/255.0, 121/255.0, 167/255.0),
}


def color_converter(color):
    """Return the RGB tuple for a named palette color, or *color* unchanged."""
    return _NAMED_MOTIF_COLORS.get(color, color)


def plot_motif_ax(array, ax, color=None):
    """Draw one motif column on *ax* with a discrete colormap.

    *color* selects the value-to-color encoding:
      - None or empty: 0 -> black, 1 -> white
      - "c":           0 -> white, 1 -> c
      - "c1,c2":      -1 -> c1, 0 -> white, 1 -> c2
      - "c1,c2,c3":   -1 -> c1, 0 -> c2,    1 -> c3
    Color names may be palette names from _NAMED_MOTIF_COLORS or any
    matplotlib color.

    Returns the image so a colorbar can be attached later.
    Raises ValueError if more than three comma-separated colors are given.
    """
    if not color:
        # the default color=None must be handled before any string operation
        levels, colors = [0., 1., 1.5], ['black', 'white']
    else:
        parts = [color_converter(part.strip()) for part in color.split(",")]
        if len(parts) == 1:
            levels, colors = [0., 1., 1.5], ['white', parts[0]]
        elif len(parts) == 2:
            levels, colors = [-1.5, -0.5, 0.5, 1.5], [parts[0], 'white', parts[1]]
        elif len(parts) == 3:
            levels, colors = [-1.5, -0.5, 0.5, 1.5], parts
        else:
            raise ValueError("motif color must be 1-3 comma-separated colors, got %r" % (color,))
    motif_cmap, norm = mcolors.from_levels_and_colors(levels, colors)

    motif_im = ax.imshow(array,
                         aspect='auto',
                         origin='lower',
                         cmap=motif_cmap,
                         interpolation="none",
                         norm=norm)
    ax.set_xticks([])
    ax.xaxis.set_ticks_position('none')
    ax.set_yticks([])
    ax.yaxis.set_ticks_position('none')
    return motif_im

##########
# create figure
##########

# figure width is stretched by xfactor to leave room for a colorbar
fig = plt.figure(figsize=(args.figsize_x * xfactor,
                          args.figsize_y))

gs = gridspec.GridSpec(nrows,
                       ncols,
                       width_ratios=width_ratios,
                       height_ratios=height_ratios,
                       wspace=0.1,
                       hspace=0.1)

# colormap name -> last image drawn with it (used for colorbars)
cmap_name_to_im = {}


def _maybe_set_title(ax, position):
    """Title *ax* with args.titles[position] if --titles supplied that many."""
    if args.titles and position < len(args.titles):
        ax.set_title(args.titles[position])


def _plot_motif_column(ax, motif_matrix, k):
    """Plot column *k* of a (sites x motifs) array on *ax* and return the image."""
    column = np.vstack(motif_matrix[:, k].astype('float'))
    if args.motif_colors and k < len(args.motif_colors):
        return plot_motif_ax(column, ax, color=args.motif_colors[k])
    return plot_motif_ax(column, ax)


if args.motifs and args.indices:
    if not multi_mat:
        # one gridspec row per index list: the heatmap, then one column per motif
        for i in range(len(args.indices)):
            for j in range(1 + len(args.motifs)):
                ax = plt.subplot(gs[i, j])
                if i == 0:
                    _maybe_set_title(ax, j)
                if j == 0:
                    im = plot_mapping_ax(X_list[i], ax, cmaps[0])
                    cmap_name_to_im[cmap_names[0]] = im
                else:
                    motif_im = _plot_motif_column(ax, all_motifs[i], j - 1)
    else:
        # one row per index list: one heatmap per matrix, then motif columns
        for i in range(len(args.indices)):
            n_heatmaps = len(X_list_of_lists[i])
            for j in range(n_heatmaps + len(args.motifs)):
                ax = plt.subplot(gs[i, j])
                if i == 0:
                    _maybe_set_title(ax, j)
                if j < n_heatmaps:
                    # colormaps are per input matrix (column), not per index list (row)
                    im = plot_mapping_ax(X_list_of_lists[i][j], ax, cmaps[j])
                    cmap_name_to_im[cmap_names[j]] = im
                else:
                    motif_im = _plot_motif_column(ax, all_motifs[i], j - n_heatmaps)
elif args.motifs:
    if not multi_mat:
        ax = plt.subplot(gs[0])
        _maybe_set_title(ax, 0)
        im = plot_mapping_ax(X_clustered, ax, cmaps[0])
        cmap_name_to_im[cmap_names[0]] = im
        for j in range(len(args.motifs)):
            ax = plt.subplot(gs[j + 1])
            motif_im = _plot_motif_column(ax, all_motifs, j)
    else:
        # X_clustered holds the matrices side by side; split it back by width
        rolling_width = 0
        for i, width in enumerate(widths):
            ax = plt.subplot(gs[i])
            _maybe_set_title(ax, i)
            im = plot_mapping_ax(X_clustered[:, rolling_width:(rolling_width + width)], ax, cmaps[i])
            cmap_name_to_im[cmap_names[i]] = im
            rolling_width += width
        for j in range(len(args.motifs)):
            ax = plt.subplot(gs[len(widths) + j])
            motif_im = _plot_motif_column(ax, all_motifs, j)
elif args.indices:
    if not multi_mat:
        for i in range(len(args.indices)):
            ax = plt.subplot(gs[i])
            _maybe_set_title(ax, i)
            # a single matrix has a single colormap, whatever the number of index lists
            im = plot_mapping_ax(X_clustered[i], ax, cmaps[0])
            cmap_name_to_im[cmap_names[0]] = im
    else:
        for i in range(len(args.indices)):
            for j in range(len(X_list_of_lists[i])):
                ax = plt.subplot(gs[i, j])
                if i == 0:
                    _maybe_set_title(ax, j)
                im = plot_mapping_ax(X_list_of_lists[i][j], ax, cmaps[j])
                cmap_name_to_im[cmap_names[j]] = im
else:
    if not multi_mat:
        ax = fig.add_subplot(111)
        _maybe_set_title(ax, 0)
        im = plot_mapping_ax(X_clustered, ax, cmaps[0])
        cmap_name_to_im[cmap_names[0]] = im
    else:
        rolling_width = 0
        for i, width in enumerate(widths):
            ax = plt.subplot(gs[i])
            _maybe_set_title(ax, i)
            im = plot_mapping_ax(X_clustered[:, rolling_width:(rolling_width + width)], ax, cmaps[i])
            cmap_name_to_im[cmap_names[i]] = im
            rolling_width += width

if args.colorbar:
    fig.subplots_adjust(right=0.85, bottom=0.15)
    if args.motifs:
        # two stacked bars: last heatmap image on top, last motif image below
        top_cax = fig.add_axes([0.86, 0.53, 0.05, 0.37])
        bottom_cax = fig.add_axes([0.86, 0.15, 0.05, 0.37])
        for image, cax in ((im, top_cax), (motif_im, bottom_cax)):
            fig.colorbar(image, cax=cax).ax.tick_params(labelsize=10)
    else:
        # one bar per distinct colormap, stacked upward from y=0.15
        distinct_cmaps = sorted(set(cmap_names))
        bar_height = 0.75 / len(distinct_cmaps)
        for position, name in enumerate(distinct_cmaps):
            cax = fig.add_axes([0.86, 0.15 + position * bar_height, 0.05, bar_height])
            fig.colorbar(cmap_name_to_im[name], cax=cax).ax.tick_params(labelsize=10)

# Cluster side bar; not yet implemented when --indices is supplied.
if args.clusters > 1 and not args.indices:
    import matplotlib.patches as mpatches
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(fig.axes[0])
    cluster_bar_ax = divider.append_axes("left", size="5%", pad=0.05)
    # Cut the same (weighted) tree the rows were ordered by, then reorder the
    # labels to plotted row order. Every flat cluster is a subtree, so its
    # rows form one contiguous run in the leaf ordering.
    clust = sch.fcluster(Z, args.clusters, criterion="maxclust")[sorted_indices]
    cmap = plt.get_cmap('rainbow')
    # args.clusters evenly spaced colors, skipping the colormap's first entry
    bar_colors = [cmap(x) for x in np.linspace(0, 1, args.clusters + 1, endpoint=False)[1:]]
    rolling_height = 0
    _, first_rows = np.unique(clust, return_index=True)
    # one rectangle per cluster, stacked from the bottom in plotted row order
    for i, cluster in enumerate(clust[np.sort(first_rows)]):
        height = np.count_nonzero(clust == cluster) / float(len(clust))
        cluster_bar_ax.add_patch(mpatches.Rectangle(
                xy=(0, rolling_height),
                width=1,
                height=height,
                facecolor=bar_colors[i]))
        rolling_height += height

    cluster_bar_ax.xaxis.set_ticks_position('none')
    cluster_bar_ax.yaxis.set_ticks_position('none')
    cluster_bar_ax.set_xticks([])
    cluster_bar_ax.set_yticks([])


for outplot in args.outplots:
    plt.savefig(outplot)

# Row labels are written top row first; rows are stored bottom-first because
# the heatmaps use origin='lower', hence the [::-1] reversals.
if args.save_clustered_indices_to_file:
    with open(args.save_clustered_indices_to_file, "w") as f:
        if isinstance(sorted_indices_overall, dict):
            # --indices mode: values are already site labels, grouped by index list
            for k in sorted(sorted_indices_overall):
                for label in sorted_indices_overall[k][::-1]:
                    f.write("%s\t%s\n" % (k, label))
        elif args.clusters > 1:
            # number clusters 1, 2, ... in the order they appear from the top
            cluster_number = 0
            last_cluster = None
            lines = []
            for cluster, idx in zip(clust[::-1], sorted_indices_overall[::-1]):
                if cluster != last_cluster:
                    cluster_number += 1
                    last_cluster = cluster
                lines.append("cluster_%s\t%s" % (cluster_number, orig_index[idx]))
            f.write("\n".join(lines) + "\n")
        else:
            f.write("\n".join([str(orig_index[idx]) for idx in sorted_indices_overall[::-1]]) + "\n")