#!/usr/bin/env python

##        DESCRIPTION: This analysis integrates the magnitude of enhancer transcription (GRO-seq),
##        enrichment of marks associated with enhancers (H3K4me1 and H3K27ac ChIP-seq), TF mRNA
##        expression levels (RNA-seq), and TF motif p-values (MEME and TOMTOM) to determine
##        key breast cancer subtype-specific TFs.
##
##        Author: Venkat S. Malladi

import Bio.motifs
import pandas as pd
import numpy as np
import csv
import re
import string
from sklearn import preprocessing
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
from scipy.stats import pearsonr, spearmanr, cumfreq


# Order of Cell Lines: every per-cell-line matrix below is reordered to this
# column order before being combined.
reorder = ['76NF2V','MCF10A','MCF7','ZR751','MB361','UACC812','SKBR3','AU565','HCC1954','MB468','HCC1937','MB231','MB436']


### Load and Process the ChIP-seq Data

# Load the matrix of Input data.
# pd.DataFrame.from_csv is deprecated (pandas 0.21) and removed (1.0); read_csv
# with index_col=0 and parse_dates=True reproduces its defaults.
enhancers_universe_Input = pd.read_csv(
    "MatrixAnalysis/5650_sspsunp_memecoordinates_memeid_Input_normreadcountsIn14Celllines_updated.txt",
    sep="\t", header=0, index_col=0, parse_dates=True)

# Keep the status column plus the per-cell-line Input counts, indexed by enhancer ID
Input_columns = ['status', 'MCF7_Input_normreadcounts', 'MB231_Input_normreadcounts', 'MB361_Input_normreadcounts', 'MDA76NF2V_Input_normreadcounts', 'UACC812_Input_normreadcounts', 'HCC1954_Input_normreadcounts', 'AU565_Input_normreadcounts', 'MCF10A_Input_normreadcounts', 'SKBR3_Input_normreadcounts', 'MB436_Input_normreadcounts', 'MB468_Input_normreadcounts', 'HCC1937_Input_normreadcounts', 'ZR751_Input_normreadcounts']
Input_index = enhancers_universe_Input.ID.values
Input_tmp = pd.DataFrame(enhancers_universe_Input, columns=Input_columns)
Input_values = Input_tmp.set_index(Input_index)

# Keep SSP rows followed by SUNP rows (status contains "SSP" / "SUNP"),
# then drop the status column.
Input_ssp = Input_values[Input_values['status'].str.contains("SSP")]
Input_sunp = Input_values[Input_values['status'].str.contains("SUNP")]
Input_columns = ['MCF7_Input_normreadcounts', 'MB231_Input_normreadcounts', 'MB361_Input_normreadcounts', 'MDA76NF2V_Input_normreadcounts', 'UACC812_Input_normreadcounts', 'HCC1954_Input_normreadcounts', 'AU565_Input_normreadcounts', 'MCF10A_Input_normreadcounts', 'SKBR3_Input_normreadcounts', 'MB436_Input_normreadcounts', 'MB468_Input_normreadcounts', 'HCC1937_Input_normreadcounts', 'ZR751_Input_normreadcounts']
filter_Input_values = pd.concat([Input_ssp, Input_sunp])
only_Input_values = pd.DataFrame(filter_Input_values, columns=Input_columns)


# Rename columns to bare cell-line names (same positional order as Input_columns) and reorder
only_Input_values.columns = ['MCF7', 'MB231', 'MB361', '76NF2V', 'UACC812', 'HCC1954', 'AU565', 'MCF10A', 'SKBR3', 'MB436', 'MB468', 'HCC1937', 'ZR751']
only_Input_values = only_Input_values[reorder]

# Pseudocount = smallest non-zero Input value, added to every entry so zero
# counts do not cause a division by zero when histone marks are divided by
# Input (assumes counts are non-negative).
x = only_Input_values.stack()
y = x[x != 0]
if y.empty:
    raise ValueError("Input matrix has no non-zero values; cannot derive a pseudocount")
Input_factor = y.min()
Input_values_std_robust = only_Input_values + Input_factor

# Load the matrix of H3K4me1 ChIP-seq normalized read counts.
# read_csv(index_col=0, parse_dates=True) replaces the removed pd.DataFrame.from_csv.
enhancers_universe_H3K4me1 = pd.read_csv(
    "MatrixAnalysis/5650_sspsunp_memecoordinates_memeid_H3K4me1_normreadcountsIn14Celllines.txt",
    sep="\t", header=0, index_col=0, parse_dates=True)

# Keep the status column plus the per-cell-line counts, indexed by enhancer ID
H3K4me1_columns = ['status', 'MCF7_normreadcounts', 'MB231_normreadcounts', 'MB361_normreadcounts', 'MDA76NF2V_normreadcounts', 'UACC812_normreadcounts', 'HCC1954_normreadcounts', 'AU565_normreadcounts', 'MCF10A_normreadcounts', 'SKBR3_normreadcounts', 'MB436_normreadcounts', 'MB468_normreadcounts', 'HCC1937_normreadcounts', 'ZR751_normreadcounts']
H3K4me1_index = enhancers_universe_H3K4me1.ID.values
H3K4me1_tmp = pd.DataFrame(enhancers_universe_H3K4me1, columns=H3K4me1_columns)
H3K4me1_values = H3K4me1_tmp.set_index(H3K4me1_index)

# Keep SSP rows followed by SUNP rows, then drop the status column
H3K4me1_ssp = H3K4me1_values[H3K4me1_values['status'].str.contains("SSP")]
H3K4me1_sunp = H3K4me1_values[H3K4me1_values['status'].str.contains("SUNP")]
H3K4me1_columns = ['MCF7_normreadcounts', 'MB231_normreadcounts', 'MB361_normreadcounts', 'MDA76NF2V_normreadcounts', 'UACC812_normreadcounts', 'HCC1954_normreadcounts', 'AU565_normreadcounts', 'MCF10A_normreadcounts', 'SKBR3_normreadcounts', 'MB436_normreadcounts', 'MB468_normreadcounts', 'HCC1937_normreadcounts', 'ZR751_normreadcounts']
filter_H3K4me1_values = pd.concat([H3K4me1_ssp, H3K4me1_sunp])
only_H3K4me1_values = pd.DataFrame(filter_H3K4me1_values, columns=H3K4me1_columns)

# Rename columns to bare cell-line names (same positional order as H3K4me1_columns) and reorder
only_H3K4me1_values.columns = ['MCF7', 'MB231', 'MB361', '76NF2V', 'UACC812', 'HCC1954', 'AU565', 'MCF10A', 'SKBR3', 'MB436', 'MB468', 'HCC1937', 'ZR751']
only_H3K4me1_values = only_H3K4me1_values[reorder]

# Pseudocount = smallest non-zero H3K4me1 value, added to every entry
x = only_H3K4me1_values.stack()
y = x[x != 0]
if y.empty:
    raise ValueError("H3K4me1 matrix has no non-zero values; cannot derive a pseudocount")
H3K4me1_factor = y.min()
H3K4me1_values_std_robust = only_H3K4me1_values + H3K4me1_factor


# Load the matrix of H3K27ac ChIP-seq normalized read counts.
# read_csv(index_col=0, parse_dates=True) replaces the removed pd.DataFrame.from_csv.
enhancers_universe_H3K27ac = pd.read_csv(
    "MatrixAnalysis/5650_sspsunp_memecoordinates_memeid_H3K27ac_normreadcountsIn14Celllines_updated.txt",
    sep="\t", header=0, index_col=0, parse_dates=True)

# Keep the status column plus the per-cell-line counts, indexed by enhancer ID
H3K27ac_columns = ['status', 'MCF7_normreadcounts', 'MB231_normreadcounts', 'MB361_normreadcounts', 'MDA76NF2V_normreadcounts', 'UACC812_normreadcounts', 'HCC1954_normreadcounts', 'AU565_normreadcounts', 'MCF10A_normreadcounts', 'SKBR3_normreadcounts', 'MB436_normreadcounts', 'MB468_normreadcounts', 'HCC1937_normreadcounts', 'ZR751_normreadcounts']
H3K27ac_index = enhancers_universe_H3K27ac.ID.values
H3K27ac_tmp = pd.DataFrame(enhancers_universe_H3K27ac, columns=H3K27ac_columns)
H3K27ac_values = H3K27ac_tmp.set_index(H3K27ac_index)

# Keep SSP rows followed by SUNP rows, then drop the status column
H3K27ac_ssp = H3K27ac_values[H3K27ac_values['status'].str.contains("SSP")]
H3K27ac_sunp = H3K27ac_values[H3K27ac_values['status'].str.contains("SUNP")]
H3K27ac_columns = ['MCF7_normreadcounts', 'MB231_normreadcounts', 'MB361_normreadcounts', 'MDA76NF2V_normreadcounts', 'UACC812_normreadcounts', 'HCC1954_normreadcounts', 'AU565_normreadcounts', 'MCF10A_normreadcounts', 'SKBR3_normreadcounts', 'MB436_normreadcounts', 'MB468_normreadcounts', 'HCC1937_normreadcounts', 'ZR751_normreadcounts']
filter_H3K27ac_values = pd.concat([H3K27ac_ssp, H3K27ac_sunp])
only_H3K27ac_values = pd.DataFrame(filter_H3K27ac_values, columns=H3K27ac_columns)

# Rename columns to bare cell-line names (same positional order as H3K27ac_columns) and reorder
only_H3K27ac_values.columns = ['MCF7', 'MB231', 'MB361', '76NF2V', 'UACC812', 'HCC1954', 'AU565', 'MCF10A', 'SKBR3', 'MB436', 'MB468', 'HCC1937', 'ZR751']
only_H3K27ac_values = only_H3K27ac_values[reorder]

# Pseudocount = smallest non-zero H3K27ac value, added to every entry
x = only_H3K27ac_values.stack()
y = x[x != 0]
if y.empty:
    raise ValueError("H3K27ac matrix has no non-zero values; cannot derive a pseudocount")
H3K27ac_factor = y.min()
H3K27ac_values_std_robust = only_H3K27ac_values + H3K27ac_factor


# Divide each histone mark by Input. `divide` aligns rows/columns by label, so
# the result's own index/columns (not the numerator's) must label the data.
H3K4me1_values_std_input = H3K4me1_values_std_robust.divide(Input_values_std_robust)
H3K27ac_values_std_input = H3K27ac_values_std_robust.divide(Input_values_std_robust)

# Scale from 0-1: MinMaxScaler is fit on the transpose, so each enhancer (row)
# is scaled to [0, 1] across cell lines.
# H3K4me1
scaler = preprocessing.MinMaxScaler()
norm = scaler.fit_transform(H3K4me1_values_std_input.T.values)
H3K4me1_scaled = pd.DataFrame(data=norm.T, columns=list(H3K4me1_values_std_input.columns.values), index=H3K4me1_values_std_input.index)

# H3K27ac
scaler = preprocessing.MinMaxScaler()
norm = scaler.fit_transform(H3K27ac_values_std_input.T.values)
H3K27ac_scaled = pd.DataFrame(data=norm.T, columns=list(H3K27ac_values_std_input.columns.values), index=H3K27ac_values_std_input.index)

### Load and Process the GRO-seq Enhancer data

# Load the matrix of location and RPKM values.
# read_csv(index_col=0, parse_dates=True) replaces the removed pd.DataFrame.from_csv.
enhancers_universe = pd.read_csv(
    "MatrixAnalysis/5650ssp_sunp_universe_RPKM_withIDtoGetRPKMforMEMEop.txt",
    sep="\t", header=0, index_col=0, parse_dates=True)

# Keep Universal_ID plus the per-cell-line summed RPKM columns
enhancers_columns = ['Universal_ID', 'MCF7_sumRPKMval', 'MB231_sumRPKMval', 'MB361_sumRPKMval', 'MDA76NF2V_sumRPKMval', 'UACC812_sumRPKMval', 'HCC1954_sumRPKMval', 'AU565_sumRPKMval', 'MCF10A_sumRPKMval', 'SKBR3_sumRPKMval', 'MB436_sumRPKMval', 'MB468_sumRPKMval', 'HCC1937_sumRPKMval', 'ZR751_sumRPKMval']
rpkm_values = pd.DataFrame(enhancers_universe, columns=enhancers_columns)

# Keep rows whose Universal_ID contains "ssp_" followed by those containing
# "sunp_", then drop the ID column.
rpkm_ssp = rpkm_values[rpkm_values['Universal_ID'].str.contains("ssp_")]
rpkm_sunp = rpkm_values[rpkm_values['Universal_ID'].str.contains("sunp_")]
rpkm_columns = ['MCF7_sumRPKMval', 'MB231_sumRPKMval', 'MB361_sumRPKMval', 'MDA76NF2V_sumRPKMval', 'UACC812_sumRPKMval', 'HCC1954_sumRPKMval', 'AU565_sumRPKMval', 'MCF10A_sumRPKMval', 'SKBR3_sumRPKMval', 'MB436_sumRPKMval', 'MB468_sumRPKMval', 'HCC1937_sumRPKMval', 'ZR751_sumRPKMval']
filter_rpkm_values = pd.concat([rpkm_ssp, rpkm_sunp])
only_rpkm_values = pd.DataFrame(filter_rpkm_values, columns=rpkm_columns)

# Log2-transform RPKM. Zeros (log2 -> -inf) are mapped to log2(0.0005), which
# lies below log2 of the smallest non-zero RPKM (~0.00119 in this dataset).
# NOTE(review): zeros are force-mapped rather than shifted by a pseudocount;
# confirm this is the intended treatment.
force_zero = np.log2(0.0005)
rpkm_values_std = only_rpkm_values.apply(np.log2).replace(-np.inf, force_zero)

# Robust-scale (median / interquartile range) each cell-line column across enhancers
scaler = preprocessing.RobustScaler()
norm = scaler.fit_transform(rpkm_values_std.values)
rpkm_values_std_robust = pd.DataFrame(data=norm, columns=list(rpkm_values_std.columns.values), index=rpkm_values_std.index)

# Rename to bare cell-line names (same positional order as rpkm_columns) and reorder
rpkm_values_std_robust.columns = ['MCF7', 'MB231', 'MB361', '76NF2V', 'UACC812', 'HCC1954', 'AU565', 'MCF10A', 'SKBR3', 'MB436', 'MB468', 'HCC1937', 'ZR751']
rpkm_values_std_robust = rpkm_values_std_robust[reorder]

# Scale from 0-1: each enhancer (row) scaled across cell lines
scaler = preprocessing.MinMaxScaler()
norm = scaler.fit_transform(rpkm_values_std_robust.T.values)
rpkm_scaled = pd.DataFrame(data=norm.T, columns=list(rpkm_values_std_robust.columns.values), index=rpkm_values_std_robust.index)


### Parse MEME and TOMTOM Motif data

# MEME output directory (per cell line) -> TOMTOM result file for that cell line.
meme_cell_dict = dict([
    ("MEME_op_H2_AU565_1kb_zoops", "au565_tomtomop.txt"),
    ("MEME_op_H2_HCC1954_1kb_zoops", "hcc1954_tomtomop.txt"),
    ("MEME_op_H2_SKBR3_1kb_zoops", "skbr3_tomtomop.txt"),
    ("MEME_op_LA_MCF7_1kb_zoops", "mcf7_tomtomop.txt"),
    ("MEME_op_LA_ZR751_1kb_zoops", "zr751_tomtomop.txt"),
    ("MEME_op_LB_MB361_1kb_zoops", "mb361_tomtomop.txt"),
    ("MEME_op_LB_UACC812_1kb_zoops", "uacc812_tomtomop.txt"),
    ("MEME_op_NM_MCF10A_1kb_zoops", "mcf10A_tomtomop.txt"),
    ("MEME_op_NM_MDA76NF2V_1kb_zoops", "76nf2v_tomtomop.txt"),
    ("MEME_op_TB_HCC1937_1kb_zoops", "hcc1937_tomtomop.txt"),
    ("MEME_op_TB_MB468_1kb_zoops", "mb468_tomtomop.txt"),
    ("MEME_op_TC_MB231_1kb_zoops", "mb231_tomtomop.txt"),
    ("MEME_op_TC_MB436_1kb_zoops", "mb436_tomtomop.txt"),
])

# Lookup from motif ID (the 'ID' column) to motif name (the 'NAME' column) of a
# tab-separated file; later rows win on duplicate IDs.
with open("NEW_motifSearch_onExpandedUniv_of_eRNA/motif_Ids_name.txt", "rb") as id_handle:
    motif_id_dict = dict(
        (row['ID'], row['NAME'])
        for row in csv.DictReader(id_handle, delimiter="\t")
    )

# Accumulated (enhancer x TF) motif score matrix, summed over all cell lines.
meme_tomtom = pd.DataFrame()
for meme, tom in meme_cell_dict.items():
    # --- MEME: indicator matrix (enhancer sequence x MEME motif); 1 where the
    # motif has an instance in that sequence, 0 otherwise.
    meme_file = 'NEW_motifSearch_onExpandedUniv_of_eRNA/cellline_specific_motifSearch/MEME_op/%s/meme.txt' % (meme)
    meme_positions = pd.DataFrame()
    with open(meme_file) as meme_handle:
        record = Bio.motifs.parse(meme_handle, 'meme')
        for motif in record:
            # The second space-separated token of the motif name is its key.
            name = motif.name.split(" ")[1]
            names = [instance.sequence_name for instance in motif.instances]
            new = pd.DataFrame({name: [1] * len(names)}, index=names)
            meme_positions = pd.concat([meme_positions, new], axis=1).fillna(0)

    # --- TOMTOM: tomtom_dict[TF][query motif] -> list of p-values.
    # setdefault gives every (TF, motif) pair its own dict/list, so p-values
    # recorded for one TF never leak into another TF of the same target.
    tomtom_file = "NEW_motifSearch_onExpandedUniv_of_eRNA/cellline_specific_motifSearch/tomtom_output/%s" % (tom)
    tomtom_dict = {}
    with open(tomtom_file, "rb") as data:
        for line in csv.DictReader(data, delimiter="\t"):
            target = line['Target ID']
            query = line['#Query ID']
            pval = float(line['p-value'])
            # A target name may list several TFs joined by '::'; credit each one.
            for tf in motif_id_dict[target].upper().split("::"):
                # Reduce split form splice to single value [ID]_#
                single_isoform = tf.split("_")[0]
                tomtom_dict.setdefault(single_isoform, {}).setdefault(query, []).append(pval)

    # --- (MEME motif x TF) p-value matrix; motifs with several p-values for
    # the same TF are combined with Stouffer's method.
    tomtom_motif = pd.DataFrame()
    for key, motif_pvals in tomtom_dict.items():
        pvalue_dict = {}
        for m, p in motif_pvals.items():
            if len(p) > 1:
                stouffer_statistic, stouffer_pval = scipy.stats.combine_pvalues(p, method='stouffer', weights=None)
                pvalue_dict[m] = stouffer_pval
            else:
                pvalue_dict[m] = p[0]
        # list(...) keeps keys/values usable where dict views are not sequences.
        new = pd.DataFrame({key: list(pvalue_dict.values())}, index=list(pvalue_dict.keys()))
        tomtom_motif = pd.concat([tomtom_motif, new], axis=1).fillna(0).sort_index()

    # Rows in MEME-motif order. A MEME motif without any TOMTOM match gets an
    # all-zero row; a NaN row would make every entry of the dot product NaN.
    tomtom_motif_reorder = tomtom_motif.reindex(list(meme_positions.columns.values), fill_value=0)
    # (enhancer x motif) . (motif x TF) -> (enhancer x TF) score.
    # NOTE(review): raw p-values are the weights, so a stronger (smaller p)
    # match contributes less than a weaker one; confirm this is intended.
    meme_tomtom_cell = meme_positions.dot(tomtom_motif_reorder)
    # Min-max scale each enhancer's row of TF scores to [0, 1] (the scaler is
    # fit on the transpose, i.e. per enhancer across TFs).
    scaler = preprocessing.MinMaxScaler()
    norm = scaler.fit_transform(meme_tomtom_cell.T.values)
    meme_tomtom_cell_std = pd.DataFrame(data=norm.T, columns=list(meme_tomtom_cell.columns.values), index=meme_tomtom_cell.index)
    # Add to previous cell lines; entries missing on either side count as 0.
    meme_tomtom = meme_tomtom.add(meme_tomtom_cell_std, fill_value=0).fillna(0).sort_index()

# Transpose to (TF x enhancer) and shorten each enhancer label to the text
# before its first '-', with ':' replaced by '_'.
motif_enhancers = meme_tomtom.T.rename(
    columns=lambda label: label.split('-')[0].replace(':', "_"))

# Min-max scale every TF row to [0, 1] across enhancers (the scaler is fit on
# the transpose, so each TF is one scaled column).
motif_unit = preprocessing.MinMaxScaler().fit_transform(motif_enhancers.T.values).T
motif_enhancers_scaled = pd.DataFrame(
    motif_unit,
    index=motif_enhancers.index,
    columns=list(motif_enhancers.columns.values),
)



### Load and Parse FPKM data from RNA-seq

# Per-gene FPKM table indexed by gene short name.
fpkm = pd.read_table("RNA-seq-analysis/genes.fpkm_tracking").set_index(['gene_short_name'])

# TF names that carry motif scores (rows of motif_enhancers).
all_motifs = list(motif_enhancers.index)

# A gene_short_name may be a comma-joined list of names. For such entries, use
# the last listed name found in all_motifs; otherwise keep the entry unchanged.
fpkm_tfs = []
for gene_name in fpkm.index:
    aliases = gene_name.split(',')
    hits = [alias for alias in aliases if alias in all_motifs]
    fpkm_tfs.append(hits[-1] if len(aliases) > 1 and hits else gene_name)

fpkm = fpkm.set_index([fpkm_tfs])
tf_fpkm = fpkm.loc[fpkm.index.isin(all_motifs)]

# Keep only the columns whose header mentions 'FPKM'.
subset = [column for column in tf_fpkm.columns.values if re.search('FPKM', column)]
tf_cell_lines = tf_fpkm[subset]

# For fusion/hyphenated motifs (a '-' followed by a letter, e.g. 'EWSR1-FLI1'),
# use the per-sample minimum FPKM of the component genes and append it to
# tf_cell_lines as an extra row named after the motif.
# NOTE(review): if none of the components is present in `fpkm`, the row is all
# NaN; confirm every such motif resolves.
heterodimer_rows = []
for motif in all_motifs:
    if not re.search("-[a-zA-Z]", motif):
        continue
    components = motif.split('-')
    component_fpkm = fpkm.loc[fpkm.index.isin(components), subset]
    hd_row = component_fpkm.min(axis=0).to_frame().T
    hd_row.index = [motif]
    heterodimer_rows.append(hd_row)

# One concat instead of one per motif.
if heterodimer_rows:
    tf_cell_lines = pd.concat([tf_cell_lines] + heterodimer_rows, axis=0)

# Rename FPKM columns to the second '_'-separated field of each header.
# Note 5 TFs not represented ['TCFE2A', 'RAR', 'ZFP423', 'RXR', 'TCFCP2L1']
tf_cell_lines.columns = [header.split('_')[1] for header in tf_cell_lines.columns.values]

# Smallest non-zero FPKM in the matrix (about 5.25e-05 here); computed for
# reference only, the transform below does not use it.
tf_factor = min(value for value in tf_cell_lines.stack() if value != 0)

# Log2 FPKM; zeros (log2 -> -inf) are replaced by this floor value.
force_zero = np.log2(0.0000005)
tf_cell_lines_std = tf_cell_lines.apply(np.log2).replace(-np.inf, force_zero)

# Robust-scale (median / interquartile range) each cell-line column across TFs.
robust_matrix = preprocessing.RobustScaler().fit_transform(tf_cell_lines_std.values)
tf_cell_lines_std_robust = pd.DataFrame(
    robust_matrix,
    index=tf_cell_lines_std.index,
    columns=list(tf_cell_lines_std.columns.values),
)

# Min-max scale each TF row to [0, 1] across cell lines.
unit_matrix = preprocessing.MinMaxScaler().fit_transform(tf_cell_lines_std_robust.T.values).T
tf_scaled_tmp = pd.DataFrame(
    unit_matrix,
    index=tf_cell_lines_std_robust.index,
    columns=list(tf_cell_lines_std_robust.columns.values),
)

# Expression gate on the raw (untransformed) FPKM: 1 where FPKM > 0.4, else 0.
threshold_1q = .4
expressed = preprocessing.Binarizer(threshold=threshold_1q).fit_transform(tf_cell_lines.values)
tf_scaled_binarize = pd.DataFrame(
    expressed,
    index=tf_cell_lines.index,
    columns=list(tf_cell_lines.columns.values),
)

# Scaled expression, zeroed wherever the TF falls below the FPKM cutoff.
tf_scaled = tf_scaled_tmp.multiply(tf_scaled_binarize)


### Start Integration Calculations (TFSEE)

# 0. Filtration step: keep enhancers that also have motif scores.
#    Membership sets are built once instead of rebuilding a list per row.
motif_enhancer_ids = set(motif_enhancers_scaled.columns.values)
needed_rows = [row for row in rpkm_scaled.index if row in motif_enhancer_ids]
rpkm_robust_filtered = rpkm_scaled.loc[needed_rows]
H3K27ac_robust_filtered = H3K27ac_scaled.loc[needed_rows]
H3K4me1_robust_filtered = H3K4me1_scaled.loc[needed_rows]


# 1. add H3K27ac and H3K4me1 signal
H3K27ac_H3K4me1 = H3K4me1_robust_filtered.add(H3K27ac_robust_filtered)

# 2. add the GRO-seq (RPKM) signal
rpkm_H3K27ac_H3K4me1 = H3K27ac_H3K4me1.add(rpkm_robust_filtered)

# 3. Make Score Matrix: (TF x enhancer) motif scores . (enhancer x cell line)
#    enhancer activity -> (TF x cell line).
motif_cell_line = motif_enhancers_scaled.dot(rpkm_H3K27ac_H3K4me1)
tf_names = set(tf_scaled.index)
needed_rows = [row for row in motif_cell_line.index if row in tf_names]
# The enhancer matrices are already in `reorder` column order with correct
# labels, so select columns by label. Relabeling them positionally with the raw
# file order ('MCF7', 'MB231', ...) would attach each column to the wrong cell line.
motif_cell_line_filtered_tfs = motif_cell_line.loc[needed_rows, reorder]

# reindex TF expression to the same TF rows and cell-line columns
tf_scaled_ordered = tf_scaled.reindex(list(motif_cell_line_filtered_tfs.index))
tf_scaled_ordered = tf_scaled_ordered[reorder]

# 4. Element-by-element multiplication: enhancer-derived score x TF expression
cell_tf_values = motif_cell_line_filtered_tfs.multiply(tf_scaled_ordered)
cell_tf_values = cell_tf_values[reorder]
# One colour per column of `reorder`, grouping cell lines of the same subtype.
cell_tf_values_colors = ["#D0D0D0","#D0D0D0","#B9CFED","#B9CFED", "#EDD3D2", "#EDD3D2", "#E6EFD7", "#E6EFD7", "#E6EFD7", "#DDD6E8", "#DDD6E8", "#C0B2D1", "#C0B2D1"]

# 5. Z-score each cell-line column across TFs to highlight important TFs
scaler = preprocessing.StandardScaler()
norm = scaler.fit_transform(cell_tf_values.values)
cell_tf_values_std = pd.DataFrame(data=norm, columns=list(cell_tf_values.columns.values), index=cell_tf_values.index)


# Cluster Heatmap of the per-cell-line z-scores (rows: TFs, columns: cell lines)
sns.set_context("paper")
hmap = sns.clustermap(cell_tf_values_std, xticklabels=True, yticklabels=False, cmap="RdBu_r", method="complete", metric="euclidean", figsize=(20, 20), col_colors=sns.color_palette(cell_tf_values_colors))
plt.setp(hmap.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
plt.savefig('final_full_cluster_heatmap.png')

# 6. Reorder rows and columns to the dendrogram (clustered) order
reorder_clustering = cell_tf_values_std.columns.values[hmap.dendrogram_col.reordered_ind]
cell_tf_values_std_ordered = cell_tf_values_std[reorder_clustering]
reindex_cluserting = cell_tf_values_std.index.values[hmap.dendrogram_row.reordered_ind]
cell_tf_values_std_ordered = cell_tf_values_std_ordered.reindex(reindex_cluserting)
cell_tf_values_std_ordered.to_csv("final_full_cluster_z_score.csv", encoding='utf-8')

# Write the clustered TF (row) order, one quoted name per line. The order is
# taken from the row dendrogram, because y tick labels are disabled above
# (yticklabels=False) and would yield an empty list.
labels = list(reindex_cluserting)
with open("final_full_cluster_heatmap.csv", 'wb') as csv_file:
    wr = csv.writer(csv_file, dialect='excel', quoting=csv.QUOTE_ALL)
    for tf in labels:
        wr.writerow([tf,])

# 7. Make means
# Groups are contiguous label slices of the clustered column order, so they
# rely on the dendrogram placing these cell lines next to each other.
# (.loc replaces .ix, which was removed from pandas.)
cell_tf_values_std_luminal_her = cell_tf_values_std_ordered.loc[:, 'MCF7':'MB361']
cell_tf_values_std_normal_tnbc = cell_tf_values_std_ordered.loc[:, 'MB231':'MB468']

lum_her_mean = cell_tf_values_std_luminal_her.mean(axis=1)
normal_tnbc_mean = cell_tf_values_std_normal_tnbc.mean(axis=1)
nt_vs_lh_mean = normal_tnbc_mean - lum_her_mean
lh_vs_nt_mean = lum_her_mean - normal_tnbc_mean

# Append group means and their differences as extra columns
cell_tf_values_std_ordered_means = pd.concat([cell_tf_values_std_ordered, lum_her_mean, normal_tnbc_mean, nt_vs_lh_mean, lh_vs_nt_mean], axis=1)
cell_tf_values_std_ordered_means.columns = list(cell_tf_values_std_ordered.columns.values) + ["Luminal/Her2", "TNBC/Normal", "Diff-TNBC", "Diff-LH"]

# 8. Subset and order based on top clade.
# The TF label ranges ("TP63":"NFIL3", "ESR1":"SOX3") and the column range
# 'MB231':'MB361' are slices of the clustered order and depend on it.
cell_tf_values_std_ordered_means_top = cell_tf_values_std_ordered_means.loc["TP63":"NFIL3", :]
cell_tf_values_std_sort = cell_tf_values_std_ordered_means_top.sort_values(['Diff-TNBC'], ascending=False)
# New figure so the heatmap is not drawn into the clustermap's current axes.
plt.figure()
hmap_sorted = sns.heatmap(cell_tf_values_std_sort.loc[:, 'MB231':'MB361'], xticklabels=True, yticklabels=True, cmap="RdBu_r")
plt.yticks(rotation=0)
# Save before show(): closing the shown window discards the figure.
plt.savefig('triple_negative_cluster_heatmap.png')
plt.show()

cell_tf_values_std_ordered_means_botom = cell_tf_values_std_ordered_means.loc["ESR1":"SOX3", :]
cell_tf_values_std_sort = cell_tf_values_std_ordered_means_botom.sort_values(['Diff-LH'], ascending=False)
plt.figure()
hmap_sorted = sns.heatmap(cell_tf_values_std_sort.loc[:, 'MB231':'MB361'], xticklabels=True, yticklabels=True, cmap="RdBu_r")
plt.yticks(rotation=0)
plt.savefig('luminal_her2_cluster_heatmap.png')
plt.show()
