import numpy as np
import utils
import seaborn as sns
import matplotlib.pyplot as plt
import colorcet
def expressionHeatMap(output, gmat_meansz,
                      heatmap_row_order=(1, 0, 2, 3, 4, 6, 7, 8, 5)):
    """Plot a z-scored expression heatmap, save it to ``output`` and show it.

    Parameters
    ----------
    output : path-like
        Destination passed to ``utils.savefig``; it is also printed after
        the figure is saved.
    gmat_meansz : pandas.DataFrame
        Table of values (labelled as z-scores on the colour bar). It is
        transposed before plotting, so its index becomes the heatmap columns.
        NOTE(review): the index is presumably a MultiIndex whose level 0 is
        the cluster and whose deeper levels are slices/samples -- TODO confirm;
        the cluster-level summary below averages within level-0 groups.
    heatmap_row_order : sequence of int, optional
        Positional order (into the transposed table) of the rows drawn in the
        heatmap. Defaults to the curated order that used to be hard-coded
        here, so existing callers get identical output. Every entry must be a
        valid row position, otherwise ``DataFrame.iloc`` raises ``IndexError``.
    """
    heatmat = gmat_meansz.T

    # Average the heatmap columns within each level-0 group (rows of the
    # untransposed table). ``DataFrame.mean(level=...)`` was removed in
    # pandas 2.0; ``groupby(level=0, sort=False)`` is the equivalent and keeps
    # the groups in first-appearance order, as the old implementation did.
    heatmat_reduced = gmat_meansz.groupby(level=0, sort=False).mean().T
    nrow, ncol = heatmat.shape

    # order1: diagonalised ordering of the reduced matrix, expanded back to
    # the full columns. NOTE(review): the ``4*col + [0, 1, 2, 3]`` expansion
    # assumes exactly 4 sub-columns per level-0 group -- TODO confirm.
    _, rowidx, colidx_reduced = utils.diag_matrix(heatmat_reduced.values, threshold=1)
    colidx = np.hstack([4*col+np.array([0,1,2,3]) for col in colidx_reduced])

    # order2: hierarchical-clustering orders for rows and for reduced columns.
    row_order = utils.get_order_from_hierarchy(heatmat)
    col_order_reduced = utils.get_order_from_hierarchy(heatmat_reduced.T) # by cluster; collapse slices
    col_order_heatmap = col_order_reduced

    # order3: row order derived from the reduced matrix with its columns
    # already arranged by col_order_reduced.
    _, rowidx_v3_heatmap_all, _ = utils.diag_matrix_rows(heatmat_reduced.iloc[:,col_order_reduced].values)

    fig, ax = plt.subplots(figsize=(10, 15))
    sns.heatmap(
        # The plot uses the curated ``heatmap_row_order`` and all columns in
        # their original order. The computed alternative
        # (rows=rowidx_v3_heatmap_all, cols=col_order_heatmap) is not used.
        heatmat.iloc[list(heatmap_row_order), :],
        vmax=2, vmin=-2, center=0,
        xticklabels=True,
        yticklabels=True,
        cmap=colorcet.coolwarm,
        ax=ax,
        cbar_kws=dict(shrink=0.2, label='gene expression in zscore across cell types and samples'),
       )
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=10)
    ax.tick_params(axis='both', which='both', length=0)

    # Keep only the text before the first '-' in each x tick label, then let
    # utils.dedup_array_elements simplify the repeats (presumably blanking
    # consecutive duplicates so each group is labelled once -- TODO confirm).
    xtl = [tl.get_text() for tl in ax.get_xticklabels()]
    xtl_part1 = [tl.split('-')[0] for tl in xtl]
    xtl_part2 = ['-'.join(tl.split('-')[1:]) for tl in xtl]
    xtl_simp = list(utils.dedup_array_elements(xtl_part1))
    ax.set_xticklabels(xtl_simp)

    utils.savefig(fig, output)
    print(output)
    plt.show()

    cluster_orders = col_order_reduced