#!/usr/bin/env python
# coding: utf-8
###############################################################################
#
# Author: HeChen
#
###############################################################################

# In[1]:


import os, sys

from tqdm.auto import tqdm

import pandas as pd
import numpy  as np

import seaborn as sns
import matplotlib.pyplot as plt

# Print the version of pandas that this kernel imported.
print('Pandas==%s' % pd.__version__)


# In[2]:


# QC thresholds referenced as @f1..@f4 by the filtering queries further down:
#   f1 -- meanBin_all:              cells are kept when the value is below f1
#   f2 -- HMM/CBS_Seg_n_all:        genome-wide segment-count threshold
#   f3 -- HMM/CBS_Seg_n_chr:        per-chromosome segment-count threshold
#   f4 -- MAPD_chr:                 cells are kept when the value is below f4
f1, f2, f3, f4 = 2.5, 75, 6, 0.6


# In[3]:


# Bin-width label; it selects the reference BED file and is embedded in every
# cache / input file name used by the cells below.
window = '200k'

# Bin width in base pairs for each supported label
window_size = {
    '100k':1e5,
    '200k':2e5,
    '500k':5e5,
    '1m'  :1e6,
    '2m'  :2e6,
}

refFile  = 'tilling' + window + '.bed'
# Load the fixed-bin positions (tab separated, no header row)
ref = pd.read_csv(refFile, sep='\t', header=None, names=['space','start','end'])
# Drop chrM: keep names containing "chr" followed by a digit, X or Y.
# The pattern is a raw string because '\d' in a plain literal is an invalid
# escape sequence; the previous class "[\d+|X|Y]" additionally accepted the
# literal characters '+' and '|'.
ref = ref[ref['space'].str.contains(r'chr[\dXY]')]
# Set index to "space:start:end", built column-wise rather than with a
# row-by-row apply (same labels, and it also works when no row is left)
bin_label = ref['space'].astype(str) + ':' + ref['start'].astype(str) + ':' + ref['end'].astype(str)
ref = ref.set_index(bin_label)
# Add abs pos, rough / precise: the active line places bin k at k * bin width
ref['abs_pos'] = (np.arange(0,ref.shape[0])*window_size[window]).astype(np.int64)
# ref['abs_pos'] = ref['end'].cumsum().shift().fillna(0).astype(np.int64)
# ref.iloc[2489:2492]
# Chromosome names in order of first appearance
chr_uidx = ref['space'].unique()


# In[4]:


# Load the merged per-cell table from its cache, or build the cache from the
# individual '*.data_seg.<window>*' pickles found in the working directory.
try:
    data_df = pd.read_pickle('data_df.' + window + '.pkl')
except FileNotFoundError:
    # Only a missing cache triggers a rebuild; any other error (interrupt,
    # unreadable pickle, ...) is raised instead of being silently swallowed.
    # Sorted so the column order does not depend on directory listing order.
    df_files = sorted(f for f in os.listdir('.') if '.data_seg.' + window in f)
    if not df_files:
        # Fail here rather than caching an empty frame for later runs
        raise FileNotFoundError(
            "no '.data_seg." + window + "' pickle found in the working directory"
        )
    # Single column-wise concat instead of growing the frame file by file
    data_df = pd.concat([pd.read_pickle(f) for f in df_files], axis=1, sort=False)
    data_df.to_pickle('data_df.' + window + '.pkl')


# In[5]:


# Load the per-cell QC summary from its cache, or rebuild it from the
# individual '*.qc.<window>*' pickles found in the working directory.
try:
    mmqc = pd.read_pickle('mmqc.' + window + '.pkl')
except FileNotFoundError:
    # Only a missing cache triggers a rebuild; other errors are raised.
    # Sorted so the row order does not depend on directory listing order.
    qc_files = sorted(f for f in os.listdir('.') if '.qc.' + window in f)
    if not qc_files:
        raise FileNotFoundError(
            "no '.qc." + window + "' pickle found in the working directory"
        )

    # Single row-wise concat instead of growing the frame file by file
    mqc_df = pd.concat([pd.read_pickle(f) for f in qc_files], axis=0, sort=False)

    # QC metrics carried into the summary
    key_f = [
        'MAPD','HMM_Seg_var','CBS_Seg_var',
        'HMM_residual_sn','CBS_residual_sn',
        'HMM_residual_bs','CBS_residual_bs',
        'meanBin','varBin',
        'HMM_Seg_n','CBS_Seg_n'
    ]

    # Rows labelled 'ALL' on index level 1 -> one row per cell
    mqc_all = mqc_df.xs(key='ALL',axis=0,level=1)[key_f]
    # Per-chromosome rows, without the 'ALL' summary and without chrX
    mqc_chr = mqc_df.reset_index().query('chroms!="ALL" and chroms!="chrX"').set_index(['Cell','chroms'])[key_f]

    # mqc_all = mqc_all.clip(upper=4)
    # mqc_chr = mqc_chr.clip(upper=4)
    # Cap every metric at its own 98th percentile
    mqc_all = mqc_all.clip(upper=mqc_all.apply(lambda col: np.percentile(col, 98)).to_dict(), axis=1)
    mqc_chr = mqc_chr.clip(upper=mqc_chr.apply(lambda col: np.percentile(col, 98)).to_dict(), axis=1)

    # Same index/columns as mqc_all; every column is overwritten below
    mqc_mchr = 0*mqc_all.copy()

    g_cell_mqc_chr = mqc_chr.reset_index().groupby('Cell',sort=False)

    # Per cell and metric: the smallest of the three largest per-chromosome
    # values (i.e. the third largest when at least three are present)
    for feat in key_f:
        mqc_mchr[feat] = g_cell_mqc_chr[feat].apply(lambda col: col.nlargest(3).min())

    # Side-by-side table: '<metric>_all' and '<metric>_chr' columns per cell
    mmqc = pd.merge(
        mqc_all, mqc_mchr,
        left_index=True,right_index=True,
        suffixes=('_all','_chr')
    )

    mmqc.to_pickle('mmqc.' + window + '.pkl')


# ---

# In[6]:


# from scipy.stats import norm, lognorm, skewnorm, gaussian_kde


# In[7]:


# sns.kdeplot( mmqc['MAPD_all'], lw=2 )

# x = np.linspace(0, 1, 100)

# m,v = norm.fit(mmqc['MAPD_all'])
# sns.kdeplot( norm.rvs(m, v, size=100), label='norm_fit', lw=2 )

# a,b,c = lognorm.fit(mmqc['MAPD_all'])
# sns.kdeplot( lognorm.rvs(s=a, loc=b, scale=c, size=100), label='lognorm_fit', lw=2 )

# mean, var, skew = skewnorm.fit(mmqc['MAPD_all'])
# sns.kdeplot( skewnorm.rvs(mean, var, skew, size=100), label='skewnorm_fit', lw=2 )

# plt.xlim(0.2,0.8);
# plt.title('Fit all', fontsize=20);


# In[8]:


# sns.kdeplot( mmqc['MAPD_all'], lw=2 )

# x = np.linspace(0, 1, 100)

# m,v = norm.fit(mmqc.query('0.8>MAPD_all>0.2')['MAPD_all'])
# sns.kdeplot( norm.rvs(m, v, size=100), label='norm_fit', lw=2 )

# a,b,c = lognorm.fit(mmqc.query('0.8>MAPD_all>0.2')['MAPD_all'])
# sns.kdeplot( lognorm.rvs(s=a, loc=b, scale=c, size=100), label='lognorm_fit', lw=2 )

# mean, var, skew = skewnorm.fit(mmqc.query('0.8>MAPD_all>0.2')['MAPD_all'])
# sns.kdeplot( skewnorm.rvs(mean, var, skew, size=100), label='skewnorm_fit', lw=2 )


# plt.xlim(0.2,0.8);
# plt.title('Fit 0.2~0.8', fontsize=20);


# In[9]:


# sns.kdeplot( mmqc['MAPD_all'], lw=2 )

# x = np.linspace(0, 1, 100)

# m,v = norm.fit(mmqc.query('0.6>MAPD_all>0.2')['MAPD_all'])
# sns.kdeplot( norm.rvs(m, v, size=100), label='norm_fit', lw=2 )

# a,b,c = lognorm.fit(mmqc.query('0.6>MAPD_all>0.2')['MAPD_all'])
# sns.kdeplot( lognorm.rvs(s=a, loc=b, scale=c, size=100), label='lognorm_fit', lw=2 )

# mean, var, skew = skewnorm.fit(mmqc.query('0.6>MAPD_all>0.2')['MAPD_all'])
# sns.kdeplot( skewnorm.rvs(mean, var, skew, size=100), label='skewnorm_fit', lw=2 )


# plt.xlim(0.2,0.8);
# plt.title('Fit 0.2~0.6', fontsize=20);


# In[10]:


# sns.kdeplot( mmqc['MAPD_all'], lw=2)
# x = np.linspace(0, 1, 100)
# kde = gaussian_kde(mmqc['MAPD_all'])
# plt.plot( x, kde.evaluate(x), label='kde_fit' )

# kde_max_x = x[kde.evaluate(x).argmax()]
# ml = mmqc.query('MAPD_all<=@kde_max_x')['MAPD_all']
# fake = pd.concat([ml,(-1*ml+2*kde_max_x)])
# sns.kdeplot(fake, label='norm_fit', lw=2);
# plt.xlim(0.2,0.8)
# plt.title('Fit kde + norm', fontsize=20);


# ---

# ## 200k bin

# In[11]:


# MAPD_all against meanBin_all for every cell, with a 2-D density overlay.
fig = plt.figure(figsize=(5,5))
ax  = fig.add_subplot(111)

mmqc.plot.scatter('MAPD_all','meanBin_all', alpha=0.2, s=2, ax=ax);
# NOTE(review): positional x/y and `n_levels` follow the call style of older
# seaborn releases -- confirm against the pinned seaborn version.
sns.kdeplot( mmqc['MAPD_all'], mmqc['meanBin_all'], gridsize=100, n_levels=15, ax=ax)

ax.vlines(0.5, 0, 5, linestyles='dashed')
# Horizontal guide drawn at f1, the same meanBin_all threshold the filter
# cells use, so the figure cannot drift from the applied cut-off
ax.hlines(f1, 0, 1, linestyles='dashed')

# Re-apply the automatically chosen labels with a larger font
ax.set_xlabel(ax.get_xlabel(), fontsize=16)
ax.set_ylabel(ax.get_ylabel(), fontsize=16)
ax.set_xlim(0.1,1);
ax.set_ylim(1.8,3.2);

plt.show()


# In[12]:


# Cells whose meanBin_all reaches the f1 threshold, and their share of all cells
check = mmqc.query('meanBin_all>=@f1').index
dropped_fraction = check.shape[0] / mmqc.shape[0]
print('Drop %.2f%% Cell (by bin mean)' % (100 * dropped_fraction))


# ## 2M bin

# In[13]:


# fig = plt.figure(figsize=(5,5))
# ax  = fig.add_subplot(111)

# mmqc.plot.scatter('MAPD_all','meanBin_all', alpha=0.2, s=2, ax=ax);
# sns.kdeplot( mmqc['MAPD_all'], mmqc['meanBin_all'], gridsize=100, n_levels=15, ax=ax)

# ax.vlines(0.35, 0, 5, linestyles='dashed')
# ax.hlines(2.0, 0, 1, linestyles='dashed')

# ax.set_xlabel(ax.get_xlabel(), fontsize=16)
# ax.set_ylabel(ax.get_ylabel(), fontsize=16)
# ax.set_xlim(0.1,1);
# ax.set_ylim(1.8,2.2);

# plt.show()


# In[14]:


# check = mmqc.query('meanBin_all>=2').index
# print ('Drop %.2f%% Cell (by bin mean)' % (100*(check.shape[0]/mmqc.shape[0])) )


# ---

# In[15]:


# Genome-wide tracks for eight cells drawn at random from `check`
# (np.random.choice samples with replacement by default), one row per cell.
fig, axs = plt.subplots(nrows=8, figsize=(20,8), sharex=True)

for ax, cell in zip(axs, np.random.choice(check, 8)):
    # Columns of this cell next to the x coordinate of each bin; bins with
    # any missing value are removed
    track_df = pd.concat([data_df[cell], ref['abs_pos']], axis=1, sort=False).dropna()
    xs = track_df['abs_pos']

    ax.plot(xs, track_df['copy'],   '.', alpha=0.5, markersize=3, color='k')
    ax.plot(xs, track_df['HMMseg'], '-', alpha=0.7, lw=3)
    ax.plot(xs, track_df['CBSseg'], '-', alpha=0.7, lw=3)
    ax.set_title(cell)
    ax.set_ylim(-0.5, 6.5)

plt.show()


# ## 200k bin

# In[16]:


# Segment counts against MAPD_all for the cells below the meanBin threshold.
df_next = mmqc.query('meanBin_all<@f1')
fig, axs = plt.subplots(ncols=2, figsize=(12,5))

# Left panel: HMM_Seg_n_all (scatter plus density)
ax = axs[0]
df_next.plot.scatter('MAPD_all','HMM_Seg_n_all', alpha=0.2, s=2, ax=ax);
sns.kdeplot( df_next['MAPD_all'], df_next['HMM_Seg_n_all'], n_levels=15, ax=ax)

ax.vlines(0.5, 0, 400, linestyles='dashed')
# Guide at f2, the Seg_n_all threshold used by the filter cells
ax.hlines(f2,  0, 1,   linestyles='dashed')

ax.set_xlabel(ax.get_xlabel(), fontsize=16)
ax.set_ylabel(ax.get_ylabel(), fontsize=16)
ax.set_xlim(0,1);
ax.set_ylim(0,200);

# Right panel: HMM_Seg_n_chr (density only)
ax = axs[1]
sns.kdeplot( df_next['MAPD_all'], df_next['HMM_Seg_n_chr'], n_levels=15, ax=ax)

ax.vlines(0.5, 0, 30, linestyles='dashed')
# Guide at f3, the Seg_n_chr threshold used by the filter cells
ax.hlines(f3,  0, 1, linestyles='dashed')

ax.set_xlabel(ax.get_xlabel(), fontsize=16)
ax.set_ylabel(ax.get_ylabel(), fontsize=16)
ax.set_xlim(0,1);
ax.set_ylim(0,15);

plt.show()


# In[17]:


# Among cells below the meanBin threshold, flag those where both callers reach
# f2 segments genome-wide, or both reach f3 on the per-chromosome metric.
passed_meanbin = mmqc.query('meanBin_all<@f1')
too_many_segments = passed_meanbin.query(
    '(HMM_Seg_n_all >= @f2 & CBS_Seg_n_all >= @f2) | (HMM_Seg_n_chr >= @f3 & CBS_Seg_n_chr >= @f3)'
)
check = too_many_segments.index

dropped_fraction = check.shape[0] / mmqc.shape[0]
print('Drop %.2f%% Cell (by number of Segment)' % (100 * dropped_fraction))


# ---

# In[18]:


# df_next = mmqc.query('meanBin_all<2')
# fig, axs = plt.subplots(ncols=2, figsize=(12,5))

# ax = axs[0]
# df_next.plot.scatter('MAPD_all','HMM_Seg_n_all', alpha=0.2, s=2, ax=ax);
# sns.kdeplot( df_next['MAPD_all'], df_next['HMM_Seg_n_all'], n_levels=15, ax=ax)

# ax.vlines(0.35, 0, 60, linestyles='dashed')
# ax.hlines(40, 0, 1, linestyles='dashed')

# ax.set_xlabel(ax.get_xlabel(), fontsize=16)
# ax.set_ylabel(ax.get_ylabel(), fontsize=16)
# ax.set_xlim(0,0.5);
# ax.set_ylim(0,60);

# ax = axs[1]
# sns.kdeplot( df_next['MAPD_all'], df_next['HMM_Seg_n_chr'], n_levels=15, ax=ax)

# ax.vlines(0.35, 0, 3, linestyles='dashed')
# ax.hlines(3, 0, 1, linestyles='dashed')

# ax.set_xlabel(ax.get_xlabel(), fontsize=16)
# ax.set_ylabel(ax.get_ylabel(), fontsize=16)
# ax.set_xlim(0,0.5);
# ax.set_ylim(0,3);

# plt.show()


# In[19]:


# check = (
#     mmqc.query('meanBin_all<2')
#     .query('(HMM_Seg_n_all >= 40 | CBS_Seg_n_all >= 40)').index
# )

# print ('Drop %.2f%% Cell (by number of Segment)' % (100*(check.shape[0]/mmqc.shape[0])) )


# ---

# In[20]:


# Genome-wide tracks for eight cells drawn at random from `check`
# (np.random.choice samples with replacement by default), one row per cell.
fig, axs = plt.subplots(nrows=8, figsize=(20,8), sharex=True)

for ax, cell in zip(axs, np.random.choice(check, 8)):
    # Columns of this cell next to the x coordinate of each bin; bins with
    # any missing value are removed
    track_df = pd.concat([data_df[cell], ref['abs_pos']], axis=1, sort=False).dropna()
    xs = track_df['abs_pos']

    ax.plot(xs, track_df['copy'],   '.', alpha=0.5, markersize=3, color='k')
    ax.plot(xs, track_df['HMMseg'], '-', alpha=0.7, lw=3)
    ax.plot(xs, track_df['CBSseg'], '-', alpha=0.7, lw=3)
    ax.set_title(cell)
    ax.set_ylim(-0.5, 6.5)

plt.show()


# ## 200k bin

# In[21]:


# Cells that pass the meanBin and segment-count filters
df_next = (
    mmqc.query('meanBin_all<@f1')
    .query('(HMM_Seg_n_all<@f2 | CBS_Seg_n_all<@f2) & (HMM_Seg_n_chr<@f3 | CBS_Seg_n_chr<@f3)')
)

# MAPD_all against MAPD_chr for those cells
fig = plt.figure(figsize=(5,5))
ax  = fig.add_subplot(111)

df_next.plot.scatter('MAPD_all','MAPD_chr', alpha=0.2, s=2, ax=ax);

ax.vlines(0.5, 0, 1, linestyles='dashed', alpha=0.5, color='gray')
# Guide at f4, the MAPD_chr threshold used by the filter cells
ax.hlines(f4, 0, 1, linestyles='dashed')

ax.set_xlabel(ax.get_xlabel(), fontsize=16)
ax.set_ylabel(ax.get_ylabel(), fontsize=16)
ax.set_xlim(0.2,1);
ax.set_ylim(0.2,1);

plt.show()


# In[22]:


# Rows of mmqc that pass the meanBin and segment-count filters but have
# MAPD_chr at or above f4. Here `check` is the filtered frame itself.
passed_meanbin = mmqc.query('meanBin_all<@f1')
passed_segments = passed_meanbin.query(
    '(HMM_Seg_n_all<@f2 | CBS_Seg_n_all<@f2) & (HMM_Seg_n_chr<@f3 | CBS_Seg_n_chr<@f3)'
)
check = passed_segments.query('MAPD_chr>=@f4')

dropped_fraction = check.shape[0] / mmqc.shape[0]
print('Drop %.2f%% Cell (by MAPD)' % (100 * dropped_fraction))


# In[23]:


# Cells with MAPD_all above 0.5 that still satisfy the MAPD_chr threshold f4
high_mapd_all = df_next.query('MAPD_all>0.5 & MAPD_chr<@f4')
check = high_mapd_all.index

# Genome-wide tracks for eight of them drawn at random
# (np.random.choice samples with replacement by default), one row per cell.
fig, axs = plt.subplots(nrows=8, figsize=(20,8), sharex=True)

for ax, cell in zip(axs, np.random.choice(check, 8)):
    # Columns of this cell next to the x coordinate of each bin; bins with
    # any missing value are removed
    track_df = pd.concat([data_df[cell], ref['abs_pos']], axis=1, sort=False).dropna()
    xs = track_df['abs_pos']

    ax.plot(xs, track_df['copy'],   '.', alpha=0.5, markersize=3, color='k')
    ax.plot(xs, track_df['HMMseg'], '-', alpha=0.7, lw=3)
    ax.plot(xs, track_df['CBSseg'], '-', alpha=0.7, lw=3)
    ax.set_title(cell)
    ax.set_ylim(-0.5, 6.5)

plt.show()


# In[24]:


# Apply the three QC filters in sequence; despite its name, `good_idx` is the
# filtered mmqc frame and its index holds the retained cells.
passed_meanbin = mmqc.query('meanBin_all<@f1')
passed_segments = passed_meanbin.query(
    '( (HMM_Seg_n_all<@f2) | (CBS_Seg_n_all<@f2) ) & ( (HMM_Seg_n_chr<@f3) | (CBS_Seg_n_chr<@f3) )'
)
good_idx = passed_segments.query('MAPD_chr<@f4')

# Columns of data_df belonging to the retained cells
good_cells = good_idx.index
good_df = data_df[good_cells]


# In[25]:


# Share of all cells in mmqc that passed every filter
kept_fraction = good_idx.shape[0] / mmqc.shape[0]
print('Keep %.2f%% passing filter Cell' % (100 * kept_fraction))


# ## 2M bin

# In[26]:


# df_next = (
#     mmqc.query('meanBin_all<2')
#     .query('(HMM_Seg_n_all<40 & CBS_Seg_n_all<40)')
# )

# fig = plt.figure(figsize=(5,5))
# ax  = fig.add_subplot(111)

# df_next.plot.scatter('MAPD_all','MAPD_chr', alpha=0.2, s=2, ax=ax);

# ax.vlines(0.3, 0, 1, linestyles='dashed', alpha=0.5, color='gray')
# ax.hlines(0.4, 0, 1, linestyles='dashed')

# ax.set_xlabel(ax.get_xlabel(), fontsize=16)
# ax.set_ylabel(ax.get_ylabel(), fontsize=16)
# ax.set_xlim(0,0.5);
# ax.set_ylim(0,0.5);

# plt.show()


# In[27]:


# check = (
#     mmqc.query('meanBin_all<2')
#     .query('(HMM_Seg_n_all<40 & CBS_Seg_n_all<40)')
#     .query('MAPD_chr>=0.4')
# )

# print ('Drop %.2f%% Cell (by MAPD)' % (100*(check.shape[0]/mmqc.shape[0])) )


# In[28]:


# check = df_next.query('MAPD_all>0.3 & MAPD_chr<0.4').index

# fig,axs = plt.subplots(nrows=8, figsize=(20,8), sharex=True, )

# for i,cell in enumerate( np.random.choice(check, 8) ):
#     fig_df = pd.concat(
#         [
#             data_df[cell],
#             ref['abs_pos'],
#         ], axis=1, sort=False
#     ).dropna(how='any')

#     ax = axs[i]
#     ax.plot(fig_df['abs_pos'],fig_df['copy'],  '.', alpha=0.5, markersize=3, color='k')
#     ax.plot(fig_df['abs_pos'],fig_df['HMMseg'],'-', alpha=0.7, lw=3,)
#     ax.plot(fig_df['abs_pos'],fig_df['CBSseg'],'-', alpha=0.7, lw=3,)
#     ax.set_title(cell)
#     ax.set_ylim(-0.5,6.5)

# plt.show()


# ---

# In[29]:


# good_idx = (
#     mmqc
#     .query('meanBin_all<2')
#     .query('(HMM_Seg_n_all<40) & (CBS_Seg_n_all<40)')
#     .query('MAPD_chr<0.4')
# )
# good_df = data_df[good_idx.index]


# In[30]:


# print ('Keep %.2f%% passing filter Cell' % (100*(good_idx.shape[0]/mmqc.shape[0])) )


# ---

# # Pos

# ## 200k

# In[31]:


# Number of bins per chromosome (first ':'-separated field of the bin label)
nbin_chr = good_df.index.str.split(':',expand=True).to_frame()[0].value_counts()

# Bins whose label contains "chr" followed by digits.
# Raw string: '\d' in a plain literal is an invalid escape sequence.
aidx = good_df.index[good_df.index.str.contains(r'chr\d+')]
# Per-bin mean of the 'copy' values over the retained cells; missing values
# count as 0 and individual values are capped at 4 before averaging
copy_mean = good_df.xs(key='copy',level=1,axis=1).loc[aidx].fillna(0).clip(upper=4).mean(axis=1)

fig = plt.figure(figsize=(6,4))
ax  = fig.add_subplot(111)

g = sns.kdeplot( copy_mean, lw=2, ax=ax )
ys = ax.get_ylim()

# Guides at the 1.2 / 2.8 limits applied to copy_mean in the next cell
ax.vlines(1.2, ys[0], ys[1], linestyles='dashed')
ax.vlines(2.8, ys[0], ys[1], linestyles='dashed')

ax.set_title('Copy Mean');

plt.show()


# In[32]:


# Bins whose mean copy value lies at or outside the 1.2 / 2.8 limits
bad_pos1 = copy_mean.to_frame(name='copy_M').query('copy_M <= 1.2 or copy_M  >= 2.8').index

print ('Drop %.2f%% bins (with BadMean)' % (100*bad_pos1.shape[0]/data_df.shape[0]) )

# Flagged bins per chromosome, in reference order; chromosomes without a
# flagged bin are reported as 0
bad_per_chr = (
    bad_pos1.str.split(':',expand=True)
    .to_frame()[0].value_counts()
    .reindex(chr_uidx, fill_value=0)
)

# Series.items() replaces Series.iteritems(), which pandas 2.0 removed;
# .items() is available on older releases as well.
for x,y in bad_per_chr.items():
    print ( '%s: %.1f%%' % (x, 100*y/nbin_chr[x] ), end='; ')

print ()


# ---

# In[33]:


# Per-bin mean on chrX computed separately for male and female cells, shown
# as a joint scatter / marginal density plot.

# Bins whose label contains 'chrX'
xidx = good_df.index[good_df.index.str.contains('chrX')]

# Sample sheet -> Series mapping sample name to GENDER
xlsx_tab = pd.read_excel('Fu2Samples.xlsx')[['SAMPLE','AGE','GENDER']]
# Keep the part of SAMPLE before the first '-' and lower-case it
xlsx_tab['SAMPLE'] = xlsx_tab['SAMPLE'].str.split('-',expand=True)[0]
xlsx_tab['SAMPLE'] = xlsx_tab['SAMPLE'].str.lower()
xlsx_tab = xlsx_tab.drop_duplicates()
# Restore the capitalisation of these five names after the lower-casing above
# (presumably to match the prefixes of the cell ids -- TODO confirm)
xlsx_tab = xlsx_tab.replace({'huang':'Huang','wang':'Wang','shen':'Shen','pang':'Pang','he':'HE'})
xlsx_tab = xlsx_tab.set_index('SAMPLE')['GENDER']

# Cell id -> gender: the text before the first '_' of each cell id is looked
# up in the sample sheet; prefixes not found there are left unchanged
cell_df = pd.Series(data=good_df.xs(key='copy',axis=1,level=1).columns)
cell_df.index = cell_df
cell_df = cell_df.str.split('_',expand=True)[0].replace(xlsx_tab.to_dict())
# NOTE(review): the 'Cell_ID' column used below presumably comes from the name
# of the cell level of good_df's columns -- confirm.
cell_df = cell_df.reset_index().rename(columns={0:'GENDER'})

# NOTE(review): unlike the autosome cell, no .xs(key='copy', ...) is applied
# here, so if each cell carries several sub-columns the mean runs over all of
# them -- confirm this is intended.
male_copy_mean   = good_df.loc[xidx,cell_df.query('GENDER=="male"'  )['Cell_ID']].fillna(0).clip(upper=4).mean(axis=1)
female_copy_mean = good_df.loc[xidx,cell_df.query('GENDER=="female"')['Cell_ID']].fillna(0).clip(upper=4).mean(axis=1)

male_copy_mean.name   ='male'
female_copy_mean.name ='female'

# Two-column frame ('male', 'female') indexed by chrX bin; this rebinds
# copy_mean, which held the autosome means before
copy_mean = pd.concat(
    [
        male_copy_mean,
        female_copy_mean
    ], axis=1
)

g = sns.JointGrid( data=copy_mean,x='male', y='female',space=0, xlim=(0,1.5), ylim=(0,3))
g.plot_joint(sns.scatterplot, size=3, edgecolor='', alpha=0.7,legend=None)
g.plot_marginals(sns.kdeplot, lw=2)
# Guides at the 0.8 (male) / 1.4 (female) limits used by the next cell
g.ax_joint.vlines(0.8, 0, 3,   linestyles='dashed')
g.ax_joint.hlines(1.4, 0, 1.5, linestyles='dashed')

plt.show()


# In[34]:


# chrX bins where the male mean is below 0.8 and the female mean below 1.4
low_in_both = copy_mean.query('male<0.8 & female<1.4')
bad_pos2 = low_in_both.index

dropped_pct = 100 * bad_pos2.shape[0] / xidx.shape[0]
print('Drop %.2f%% bins (for chrX)' % dropped_pct)

print()


# ---
# ---
# ---

# ## 2M

# In[35]:


# aidx = good_df.index[good_df.index.str.contains('chr[\d+|X]')]
# nbin_chr = good_df.index.str.split(':',expand=True).to_frame()[0].value_counts()

# aidx = good_df.index[good_df.index.str.contains('chr\d+')]
# copy_mean = good_df.xs(key='copy',level=1,axis=1).loc[aidx].fillna(0).mean(axis=1).clip(upper=4)

# fig = plt.figure(figsize=(6,4))
# ax  = fig.add_subplot(111)

# g = sns.kdeplot( copy_mean, lw=2, ax=ax )
# ys = ax.get_ylim()
# ax.vlines(1.4, ys[0], ys[1], linestyles='dashed')
# ax.vlines(2.6, ys[0], ys[1], linestyles='dashed')
# ax.set_title('Copy Mean');

# plt.show()


# In[36]:


# bad_pos = copy_mean.to_frame(name='copy_M').query('copy_M <= 1.4 or copy_M  >= 2.6').index

# print ('Drop %.2f%% bins (with BadMean)' % (100*bad_pos.shape[0]/data_df.shape[0]) )
# for x,y in (
#     bad_pos.str.split(':',expand=True)
#     .to_frame()[0].value_counts()
#     .reindex(chr_uidx, fill_value=0)
#     .iteritems()
# ): print ( '%s: %.1f%%' % (x, 100*y/nbin_chr[x] ), end='; ')

# print ()


# ---

# In[37]:


# Remove the bins flagged in bad_pos1 and bad_pos2, then drop the 'ideal'
# sub-column (column level 1) of every cell.
# NOTE(review): this rebinds good_df, so running the cell a second time is
# expected to fail because the labels are already gone -- confirm.
good_df = (
    good_df
    .drop(index=bad_pos1)
    .drop(index=bad_pos2)
    .drop(columns='ideal', level=1)
)


# In[38]:


# Clustered heat map of the HMMseg values for a random subset of cells
# (rows = cells, columns = bins kept in their original order).
hmm_by_cell = good_df.xs(key='HMMseg',level=1,axis=1).T
# sample() raises when asked for more rows than exist, so cap the request at
# the number of available cells; with 100 or more cells nothing changes
n_shown = min(100, hmm_by_cell.shape[0])

sns.clustermap(
    hmm_by_cell.sample(n_shown).fillna(0),
    cmap='RdBu_r',center=2, vmax=4, vmin=0,
    xticklabels='',
    col_cluster=False,
    cbar_kws={'ticks': [0, 1, 2, 3, 4]},
    figsize=(15, 5)
);

plt.show()


# In[39]:


# Persist the filtered table under a name that carries the bin-width label
good_df.to_pickle('good_df.%s.pkl' % window)
# good_df = pd.read_pickle('good_df.' + window + '.pkl')


# ---
# ---
# ---

# In[40]:


# good_order = (    
#     mmqc
#     .query('meanBin_all<@f1')
#     .query('(HMM_Seg_n_all<@f2 | CBS_Seg_n_all<@f2) & (HMM_Seg_n_chr<@f3 | CBS_Seg_n_chr<@f3)')
#     .query('MAPD_chr<@f4')
# ).sort_values('MAPD_all').index

# g_dict = {}
# for g,df in pd.DataFrame(good_order)['Cell'].str.split('_',expand=True).groupby(0, sort=False):
#     tmp_df = good_df.xs(axis=1, level=1, key='copy')[good_order[df.head(200).index]]
#     g_dict[g] = tmp_df[tmp_df<tmp_df.quantile(0.99)].mean(axis=1)
# g_mean_df = pd.DataFrame(g_dict)


# In[41]:


# fig, axs = plt.subplots(nrows=23, ncols=1, figsize=(20,20))

# for n, (g, df) in enumerate(ref.query('space!="chrY"').groupby('space',sort=False)):
#     sns.heatmap(
#         g_mean_df.reindex(index=df.index).dropna(how='all').T,
#         vmax=3, vmin=1, center=2,
#         xticklabels=[],
#         yticklabels=[],
#         ax=axs[n],
#         cmap='RdBu_r',
#         cbar_kws={'ticks':[1,2,3]},
#     )
#     axs[n].set_ylabel(g, rotation=0, ha='right', va='center')


# ---

# In[42]:


# aidx = good_df.index[good_df.index.str.contains('chr\d+')]
# xidx = good_df.index[good_df.index.str.contains('chrX')]

# a_norm_df = pd.DataFrame()
# b_norm_df = good_df.xs(key='copy',level=1,axis=1).copy()


# In[43]:


# for cell in tqdm(b_norm_df.columns):
#     sample = cell.split('_')[0]
#     cell_df = b_norm_df[cell].dropna()
#     norm_df = g_mean_df[sample].dropna()
#     a_cell_df = (cell_df[aidx].dropna() / norm_df[aidx].dropna() * round(norm_df[aidx].median()))[aidx]
#     x_cell_df = (cell_df[xidx].dropna() / norm_df[xidx].dropna() * round(norm_df[xidx].median()))[xidx]
#     norm_df = pd.concat([a_cell_df, x_cell_df], sort=False)
#     norm_df.name = cell
#     a_norm_df = a_norm_df.append(norm_df)


# In[ ]:





# In[ ]:




