
# coding: utf-8
###############################################################################
#
# Author: HeChen
#
###############################################################################
import sys

import pandas as pd
import numpy  as np

#from scipy.stats import zscore, gaussian_kde

#from sklearn.decomposition import PCA
#from MulticoreTSNE         import MulticoreTSNE as TSNE
#from sklearn.manifold      import MDS
#from umap                  import UMAP 

import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
#from tqdm import tqdm_notebook as tqdm

# Render figures off-screen; the script only writes image files.
plt.switch_backend('Agg')

###############################################################################

# Command line: the sample name is the only argument; the bin window is fixed.
if len(sys.argv) <= 1:
    print ('USAGE: python ' + sys.argv[0] + ' <Sample_Name>')
    # A missing argument is an error: exit non-zero so callers/pipelines notice.
    sys.exit(1)

sample = sys.argv[1]
window = '200k'

# Per-bin segmentation table for this sample and window.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
data_file = sample + '.data_seg.' + window + '.pkl'
try:
    data_df = pd.read_pickle(data_file)
except FileNotFoundError:
    # Only a missing file matches this message. Other failures (corrupt
    # pickle, permissions, ...) now propagate with their own traceback
    # instead of being misreported as a missing sample.
    print ( 'This sample / window isn\'t exist! '+ data_file )
    sys.exit(1)

# Bin width in base pairs for each supported window label.
window_size = {
    '50k' :5e4,
    '100k':1e5,
    '200k':2e5,
    '500k':5e5,
    '1m'  :1e6,
    '2m'  :2e6,
}

# Load the fixed-bin tiling for this window (columns: chrom, start, end).
ref_File = 'tilling' + window + '.bed'
ref = pd.read_csv(ref_File, sep='\t', header=None, names=['space','start','end'])

# Keep chromosomes whose name contains "chr" followed by a digit or "X"; this
# drops chrY and chrM (chrY would otherwise raise a divide-by-zero later).
# Raw string: "\d" in a plain literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The brackets form a character class, so
# '+' and '|' are also accepted after "chr"; harmless for real chromosome
# names and kept unchanged for compatibility.
ref = ref[ref['space'].str.contains(r'chr[\d+|X]')]
# Index each bin as "chrom:start:end" (vectorized instead of a per-row
# apply with positional row slicing).
bin_id = ref['space'] + ':' + ref['start'].astype(str) + ':' + ref['end'].astype(str)
ref = ref.set_index(bin_id.rename(None))
# Rough genome-wide position: bin ordinal times the nominal bin width
# (not the BED start coordinate).
ref['abs_pos'] = (np.arange(0,ref.shape[0])*window_size[window]).astype(np.int64)

###############################################################################

# Re-key the rows as (Pos_bin, chroms): Pos_bin is the original row label and
# chroms is the part of that label before the first ':'.
_pos_labels = data_df.index
_chrom_labels = _pos_labels.str.split(':').str[0]
data_df_new = data_df.set_index(
    pd.MultiIndex.from_arrays(
        [_pos_labels, _chrom_labels],
        names=['Pos_bin', 'chroms'],
    )
)

# Empty QC table: one row per (cell, chromosome) pair, plus a genome-wide
# 'ALL' row per cell. Metric columns are created as they are first assigned.
_qc_rows = pd.MultiIndex.from_product(
    [data_df_new.columns.levels[0], np.append('ALL', ref['space'].unique())],
    names=['Cell', 'chroms'],
)
qc_df = pd.DataFrame(index=_qc_rows, data=[])

###############################################################################

# https://planspace.org/2013/06/21/how-to-calculate-gini-coefficient-from-raw-data-in-python/
def gini(list_of_values):
    """Return the Gini coefficient of *list_of_values*.

    Discrete Lorenz-curve formulation: sort the values, accumulate them, and
    compare the area under that running-sum curve,
    ``A = sum(cumsum_i - x_i / 2)``, with the area of perfect equality,
    ``A_fair = total * n / 2``. The result is ``(A_fair - A) / A_fair``.

    Parameters
    ----------
    list_of_values : sequence or array of numbers
        Values to measure (here: per-bin copy numbers of one cell).

    Returns
    -------
    float
        The Gini coefficient. If the input is empty or its values sum to
        zero, ``0.0`` is returned. The previous loop divided by zero in that
        case (ZeroDivisionError for Python numbers, NaN plus a RuntimeWarning
        for NumPy input, which the caller then filled with 0).
    """
    values = np.sort(np.asarray(list_of_values, dtype=float))
    total = values.sum()
    if values.size == 0 or total == 0:
        return 0.0
    # Area under the Lorenz curve, trapezoid per step.
    area = np.sum(np.cumsum(values) - values / 2.0)
    fair_area = total * values.size / 2.0
    return float((fair_area - area) / fair_area)


###############################################################################
# Start calculate QC Metrics

def _mean_abs_dev(series):
    """Return the mean absolute deviation of *series* around its mean.

    Same result as ``Series.mad()``, which pandas deprecated in 1.5 and
    removed in 2.0.
    """
    return (series - series.mean()).abs().mean()


def _qc_metrics(frame, trim_quantile=None):
    """Compute the QC metrics of one cell over the bins in *frame*.

    Parameters
    ----------
    frame : pandas.DataFrame
        One cell's per-bin table with 'copy', 'HMMseg' and 'CBSseg' columns.
    trim_quantile : float or None
        If given, the segment mean/variance use only the distinct segment
        values strictly below this quantile. None uses all distinct values.

    Returns
    -------
    dict
        Metric name -> value, in the column order used for qc_df.
    """
    copy_num = frame['copy']
    metrics = {
        'GINI':       gini(copy_num.values),
        'autocorBin': copy_num.autocorr(),               # lag-1 autocorrelation
        'meanBin':    copy_num.mean(),
        'varBin':     copy_num.var(),
        'MAD':        _mean_abs_dev(copy_num),
        'MAPD':       copy_num.diff().abs().median(),    # median |bin_i - bin_(i-1)|
    }
    for prefix, column in (('HMM', 'HMMseg'), ('CBS', 'CBSseg')):
        seg_fit = frame[column]
        # Distinct segment values (not segment runs).
        levels = pd.Series(seg_fit.unique())
        if trim_quantile is None:
            kept = levels
        else:
            kept = levels[levels < levels.quantile(trim_quantile)]
        metrics[prefix + '_Seg_n'] = levels.shape[0]
        metrics[prefix + '_Seg_mean'] = kept.mean()
        metrics[prefix + '_Seg_var'] = kept.var()
        # Residuals: bin vs. segment value, bin vs. rounded segment value,
        # and distinct segment values vs. their nearest integer.
        metrics[prefix + '_residual_bs'] = (copy_num - seg_fit).abs().median()
        metrics[prefix + '_residual_bn'] = (copy_num - seg_fit.round()).abs().median()
        metrics[prefix + '_residual_sn'] = (levels - levels.round()).abs().median()
    return metrics


chrom_list = ref['space'].unique()
for cell in tqdm(data_df_new.columns.levels[0]):
    # One cell's bins: drop all-NaN bins, cap values at 10, treat NaN as 0.
    tmp_df  = data_df_new[cell].dropna(how='all').clip(upper=10).fillna(0).copy()

    # Genome-wide row. The segment mean/variance drop distinct segment values
    # at or above the 99th percentile. This now applies to HMM as well: the
    # old code computed HMM_Seg_mean from the CBS segments by mistake.
    for name, value in _qc_metrics(tmp_df, trim_quantile=0.99).items():
        qc_df.loc[(cell, 'ALL'), name] = value

    # Per-chromosome rows (untrimmed). A chromosome with no bins left for
    # this cell is skipped (its row stays NaN and becomes 0 below) instead
    # of raising KeyError from xs().
    present = set(tmp_df.index.get_level_values(1))
    for chrom in chrom_list:
        if chrom not in present:
            continue
        tmp_cdf = tmp_df.xs(level=1,key=chrom)
        for name, value in _qc_metrics(tmp_cdf).items():
            qc_df.loc[(cell, chrom), name] = value

qc_df = qc_df.fillna(0)
qc_df.to_pickle( sample + '.qc.' + window + '.pkl' )
sys.exit(0)

###############################################################################
###############################################################################
###############################################################################
###############################################################################
###############################################################################
###############################################################################
###############################################################################
###############################################################################
###############################################################################
###############################################################################

# qc_df[['meanBin','HMM_Seg_mean','CBS_Seg_mean']] = (
#     qc_df[['meanBin','HMM_Seg_mean','CBS_Seg_mean']]
#     -
#     qc_df[['meanBin','HMM_Seg_mean','CBS_Seg_mean']].reset_index().drop('Cell',axis=1).groupby('chroms',sort=False).mean()
# ).abs()

###############################################################################

# fig = plt.figure(figsize=(8,6))
# ax = fig.add_subplot(111)
# sns.heatmap(qc_df.corr(),cmap='coolwarm', ax=ax) 
# ax.set_title('QC metrics corr.')
# plt.show()

# NOTE(review): everything from here to the end of the file is unreachable:
# the script calls sys.exit(0) right after writing '<sample>.qc.<window>.pkl'.
# It also uses names whose imports are commented out at the top of the file
# (zscore, gaussian_kde, PCA, TSNE, MDS, UMAP); restore those imports before
# re-enabling this section.

# Clustered heatmap of the pairwise correlation between QC metrics.
fig = sns.clustermap(qc_df.corr(),cmap='coolwarm',figsize=(8,8)) 
# plt.show()
fig.savefig( sample + '.' + window + '.QCheatmap.png' )

###############################################################################

# Represent each cell by its "bad" chromosomes: for every metric, take the
# smallest of the 3 largest values across the cell's rows ('ALL' included),
# i.e. the 3rd-largest value when the cell has at least 3 rows. transform()
# broadcasts that value to all of the cell's rows, so the final per-cell
# mean() just returns it.
qc_represent = (
    qc_df
    .reset_index().drop('chroms',axis=1).groupby('Cell',sort=False)
    .transform(lambda col: col.nlargest(3).min()).set_index(qc_df.index)
    .reset_index().drop('chroms',axis=1).groupby('Cell',sort=False)
    .mean()
)

###############################################################################

# Z-score transform. Per metric, values at or above the 99.5th percentile
# become NaN and dropna removes the whole cell. The remaining metrics are
# z-scored column-wise, and the same trimming is applied once more.
qc_z = (
#     qc_df.reset_index().drop('chroms',axis=1).groupby('Cell',sort=False).max()
#     qc_df.reset_index().drop('chroms',axis=1).groupby('Cell',sort=False).quantile(0.95)
    qc_represent
    .apply(lambda col: col[(col < col.quantile(0.995))], axis=0) # Drop values >= 99.5th pct
    .dropna(how='any')
    .apply(zscore)                                                # Z-score normalization
    .apply(lambda col: col[(col < col.quantile(0.995))], axis=0) # Drop values >= 99.5th pct
    .dropna(how='any')
)

###############################################################################

# Dimension reduction (unreachable, see the sys.exit(0) above): embed the
# z-scored cells (rows of qc_z) in 2-D with four methods, using fixed seeds.
model1 = PCA(  n_components=2, random_state=0)
model2 = TSNE( n_components=2, random_state=0, method='exact', n_jobs=24)
model3 = MDS(  n_components=2, random_state=0)
model4 = UMAP( n_components=2, random_state=0)

# model1 = PCA(n_components=2)
# model2 = TSNE(n_components=2, method='exact', n_jobs=24)
# model3 = MDS(n_components=2)
# model4 = UMAP(n_components=2)

a1 = pd.DataFrame( model1.fit_transform(qc_z) )
a2 = pd.DataFrame( model2.fit_transform(qc_z) )
a3 = pd.DataFrame( model3.fit_transform(qc_z) )
a4 = pd.DataFrame( model4.fit_transform(qc_z) )

# Embeddings side by side. The frames carry a fresh 0..n-1 index (qc_z's
# cell labels are not kept), and reset_index adds it as an 'index' column.
# Row i corresponds to qc_z.index[i].
fig_data = pd.concat(
    [
        a1.rename(columns={0:'PCA1',  1:'PCA2'}),
        a2.rename(columns={0:'tSNE1', 1:'tSNE2'}),
        a3.rename(columns={0:'MDS1',  1:'MDS2'}),
        a4.rename(columns={0:'UMAP1', 1:'UMAP2'})
    ],axis=1).reset_index()

###############################################################################

# PCA components: heatmap of the PC1/PC2 loadings of each QC metric.
# NOTE(review): tick labels come from qc_df.columns; this assumes qc_z has
# the same columns in the same order. Confirm before re-enabling.
fig = plt.figure(figsize=(8,4))
ax  = fig.add_subplot(111)
cax = ax.matshow(model1.components_, cmap='viridis')
cax.set_clim(vmin=-0.5, vmax=0.5)
fig.colorbar(cax, shrink=0.3, ticks=[-0.4, 0, 0.4], location='bottom', use_gridspec=False, pad=0.1)

ax.set_xticks(range(len(qc_df.columns)),)
ax.set_xticklabels( qc_df.columns.tolist(), rotation=35, ha='left')
ax.set_yticks([0,1])
ax.set_yticklabels(['PC1','PC2'],fontsize=10)

# plt.show()
fig.savefig( sample + '.' + window + '.components.png' )

# plt.matshow(model1.components_, cmap='viridis')
# plt.yticks([0,1],['PC1','PC2'],fontsize=10)
# plt.colorbar()
# plt.clim(-0.5,0.5)
# plt.xticks( range(len(qc_df.columns)), qc_df.columns,rotation=35, ha='left')
# plt.show()

###############################################################################

# Plot Dimension reduction: one scatter panel per embedding. All panels
# share one ColumnDataSource, and the grid is exported to a PNG.
# NOTE(review): Bokeh 3.x renamed figure(plot_width=, plot_height=) to
# width=/height=. Confirm against the installed Bokeh version.
from bokeh.io       import output_notebook, show, export_png
from bokeh.layouts  import gridplot
from bokeh.models   import ColumnDataSource
from bokeh.plotting import figure

# output_notebook()
# color = ['aquamarine', 'blanchedalmond', 'burlywood', 'cadetblue', 'chocolate', 'coral', 'cornflowerblue']

source = ColumnDataSource(data=fig_data)

TOOLS = "reset,box_select,save,help"

f1 = figure(tools=TOOLS, plot_width=200, plot_height=200, title='PCA')
f2 = figure(tools=TOOLS, plot_width=200, plot_height=200, title='tSNE')
f3 = figure(tools=TOOLS, plot_width=200, plot_height=200, title='MDS')
f4 = figure(tools=TOOLS, plot_width=200, plot_height=200, title='UMAP')

f1.circle('PCA1',  'PCA2',  alpha=0.7, selection_alpha=0.5, color='cornflowerblue', selection_color="coral", source=source)
f2.circle('tSNE1', 'tSNE2', alpha=0.7, selection_alpha=0.5, color='cornflowerblue', selection_color="coral", source=source)
f3.circle('MDS1',  'MDS2',  alpha=0.7, selection_alpha=0.5, color='cornflowerblue', selection_color="coral", source=source)
f4.circle('UMAP1', 'UMAP2', alpha=0.7, selection_alpha=0.5, color='cornflowerblue', selection_color="coral", source=source)

p = gridplot([[f1, f2, f3, f4]])

#show(p)
export_png(p, filename=sample+'.'+window+'.dr4.png')

###############################################################################

# Plot PCA and filter (unreachable, see the sys.exit(0) above). Fit a 2-D
# Gaussian KDE to the cells' PC coordinates, draw 50 contour lines of it on a
# 100x100 grid, and keep the cells whose density is >= the threshold `th`
# read from the contour set.
f = gaussian_kde(np.vstack([fig_data['PCA1'],fig_data['PCA2']]))

# Grid bounds: data min/max multiplied by 1.15. This pads the range outward
# only when min < 0 < max.
X, Y = np.mgrid[
    fig_data['PCA1'].min()*1.15 : fig_data['PCA1'].max()*1.15 : 100j, 
    fig_data['PCA2'].min()*1.15 : fig_data['PCA2'].max()*1.15 : 100j
]

positions = np.vstack([X.ravel(), Y.ravel()])

Z = np.reshape(f(positions), X.shape)

fig = plt.figure(figsize=(6,5))
ax  = fig.add_subplot(111)
cs  = ax.tricontour(X.ravel(), Y.ravel(), Z.ravel(), levels=50, cmap="RdBu_r", linewidths=1)

ax.plot(fig_data['PCA1'], fig_data['PCA2'], 'k.', markersize=5, alpha=0.5)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')

ax.set_xlim( [ fig_data['PCA1'].min()*1.2, fig_data['PCA1'].max()*1.2 ] )
ax.set_ylim( [ fig_data['PCA2'].min()*1.2, fig_data['PCA2'].max()*1.2 ] )

# NOTE(review): presumably the second-lowest contour level, so only the
# lowest-density fringe is discarded. Confirm with the matplotlib in use.
th = cs.get_array()[1]

# KDE density at each cell. pick.index holds row positions into fig_data.
pick = pd.concat(
    [
        fig_data[['PCA1','PCA2']],
        pd.Series( f( np.vstack([ fig_data['PCA1'], fig_data['PCA2'] ]) ) ),
    ],axis=1
).rename(columns={0:'Density'}).query('Density>=@th')

# Highlight the kept cells in red.
ax.plot(pick['PCA1'], pick['PCA2'], '.', alpha=0.3, color='red')

fig.savefig( sample + '.' + window + '.pca.png' )
# plt.show()

# Map the kept row positions back to cell labels.
pca_pick = set( qc_z.index[pick.index].tolist() )

###############################################################################

# Plot MDS and filter: the same KDE density filter as above, applied to the
# MDS embedding.
f = gaussian_kde(np.vstack([fig_data['MDS1'],fig_data['MDS2']]))

X, Y = np.mgrid[
    fig_data['MDS1'].min()*1.15 : fig_data['MDS1'].max()*1.15 : 100j, 
    fig_data['MDS2'].min()*1.15 : fig_data['MDS2'].max()*1.15 : 100j
]

positions = np.vstack([X.ravel(), Y.ravel()])

Z = np.reshape(f(positions), X.shape)

fig = plt.figure(figsize=(6,5))
ax  = fig.add_subplot(111)
cs  = ax.tricontour(X.ravel(), Y.ravel(), Z.ravel(), levels=50, cmap="RdBu_r", linewidths=1)

ax.plot(fig_data['MDS1'], fig_data['MDS2'], 'k.', markersize=5, alpha=0.5)
ax.set_xlabel('MDS1')
ax.set_ylabel('MDS2')

ax.set_xlim( [ fig_data['MDS1'].min()*1.2, fig_data['MDS1'].max()*1.2 ] )
ax.set_ylim( [ fig_data['MDS2'].min()*1.2, fig_data['MDS2'].max()*1.2 ] )

th = cs.get_array()[1]

pick = pd.concat(
    [
        fig_data[['MDS1','MDS2']],
        pd.Series( f( np.vstack([ fig_data['MDS1'], fig_data['MDS2'] ]) ) ),
    ],axis=1
).rename(columns={0:'Density'}).query('Density>=@th')

ax.plot(pick['MDS1'], pick['MDS2'], '.', alpha=0.3, color='red')

fig.savefig( sample + '.' + window + '.mds.png' )
# plt.show()

mds_pick = set( qc_z.index[pick.index].tolist() )

###############################################################################

# Merge two result (unreachable, see the sys.exit(0) above): keep the cells
# that pass both the PCA and the MDS density filter. They come from a set, so
# the column order of both_pick and the saved list is arbitrary.
both_pick = data_df[list(pca_pick & mds_pick)]
pd.Series(list(pca_pick & mds_pick)).to_pickle( sample + '.' + window + '.bpcell.pkl' )

# Heatmaps of the HMM segment values (cells as rows) restricted to bins whose
# 'ideal' value, averaged over cells with NaN as 0, exceeds 0.8. Missing
# segment values are drawn as 2.
try:
    fig = sns.clustermap(
        data_df.loc[
            data_df.index[data_df.xs(level=1,key='ideal', axis=1).fillna(0).mean(axis=1)>0.8]
        ].xs(level=1,key='HMMseg', axis=1).T.fillna(2),
        vmin=0, vmax=4, xticklabels='', cmap='RdBu_r', col_cluster=False,
        cbar_kws={'ticks': [0, 1, 2, 3, 4]} )

    fig.savefig( sample + '.' + window + '.raw.png' )

    fig = sns.clustermap(
        both_pick.loc[
            both_pick.index[both_pick.xs(level=1,key='ideal', axis=1).fillna(0).mean(axis=1)>0.8]
        ].xs(level=1,key='HMMseg', axis=1).T.fillna(2),
        vmin=0, vmax=4, xticklabels='', cmap='RdBu_r', col_cluster=False,
        cbar_kws={'ticks': [0, 1, 2, 3, 4]} )

    fig.savefig( sample + '.' + window + '.both_pick.png' )

    # Cell row positions in dendrogram order, minus the first
    # (n_cells // 20 - 1) of them.
    # NOTE(review): with fewer than 20 cells the slice start is -1, which
    # keeps only the last cell. This is likely unintended.
    keep_cells = fig.dendrogram_row.reordered_ind[int(both_pick.xs(level=1,key='HMMseg', axis=1).shape[1]/20)-1:]
except:
    # Fallback when clustering fails (the bare except catches any error):
    # order cells by mean * std of their segment values, descending, then drop
    # the first (n_cells // 20 - 1). The same n < 20 caveat applies.
    # NOTE(review): assumes the cell label level is named 'Cell_ID'. Confirm.
    print ('Heatmap scipy bug! Try not cluster')
    tmp = both_pick.loc[
                both_pick.index[both_pick.xs(level=1,key='ideal', axis=1).fillna(0).mean(axis=1)>0.8]
            ].xs(level=1,key='HMMseg', axis=1).T.fillna(2)

    keep_cells = tmp.reset_index().reset_index().set_index('Cell_ID').loc[(tmp.mean(axis=1)*tmp.std(axis=1)).sort_values(ascending=False).index,'index'].values
    keep_cells = keep_cells[int(keep_cells.shape[0]/20)-1:]
    pass
###############################################################################

# Plot Heatmap of the kept cells (rows, in keep_cells order) over the same
# bins as above.
good_cell = both_pick.xs(level=1,key='HMMseg', axis=1).loc[
        both_pick.index[both_pick.xs(level=1,key='ideal', axis=1).fillna(0).mean(axis=1)>0.8]
    ].T.fillna(2).iloc[keep_cells]

# Height grows by one inch per 30 cells (floor), plus one.
figsize_w = 16
figsize_h = int(int(good_cell.index.shape[0]/10)/3)+1

fig = plt.figure(figsize=(figsize_w,figsize_h))
ax  = fig.add_subplot(111)

sns.heatmap(
    good_cell, vmin=-0.01, vmax=4.025, ax=ax,
    xticklabels='', yticklabels='', cmap='RdBu_r', 
    cbar_kws={'shrink':0.3, 'location':'top', "use_gridspec":False, 'ticks': [0, 1, 2, 3, 4], 'pad':0.01} )

ax.set_title(  'Sample Name: ' + sample )
ax.set_xlabel( 'Bin size: '    + str(window).replace('m','Mb') )
ax.set_ylabel( '# of Cells: '  + str(good_cell.shape[0]) ) 

# plt.show()
fig.savefig( sample + '.' + window + '.final.png',  bbox_inches='tight' )

###############################################################################

# Summary of how many cells survive each filtering stage.
print ('# of cells input        : ' + str(qc_df.reset_index().drop('chroms',axis=1).groupby('Cell',sort=False).max().shape[0]) )
print ('# of cells pass z-score : ' + str(qc_z.shape[0]) )
print ('# of cells pass PCA     : ' + str(len(pca_pick)) )
print ('# of cells pass MDS     : ' + str(len(mds_pick)) )
print ('# of cells for NEXT     : ' + str(good_cell.index.shape[0]) )


# Save the labels of the finally kept cells.
pd.Series(good_cell.index).to_pickle( sample + '.' + window + '.pfcell.pkl' )




# ---

