#!/usr/bin/env python
# coding: utf-8
###############################################################################
#
# Author: HeChen
#
###############################################################################

# In[ ]:


import numpy  as np
import pandas as pd

from tqdm.auto import tqdm


# In[ ]:


print(pd.__version__)


# In[ ]:


window = '200k'

good_df = pd.read_pickle('good_df.' + window + '.pkl')

# Split good_df's bin labels ("chrom:start:end") by chromosome class:
# numbered chromosomes (called against a baseline of 2), chrX and chrY.
# Raw string: '\d' in a plain literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
aidx = good_df.index[good_df.index.str.contains(r'chr\d+:')]
xidx = good_df.index[good_df.index.str.contains('chrX:')]
yidx = good_df.index[good_df.index.str.contains('chrY:')]


# In[ ]:


xlsx_tab = pd.read_excel('/PATH/TO/SAMPLE/INFO/info.xlsx')[['SAMPLE','AGE','GENDER']]
# Keep the part of the sample ID before the first '-' and lower-case it; the
# resulting ID is used as the file-name prefix in get_shuffle_df.
xlsx_tab['SAMPLE'] = xlsx_tab['SAMPLE'].str.split('-',expand=True)[0]
xlsx_tab['SAMPLE'] = xlsx_tab['SAMPLE'].str.lower()
xlsx_tab = xlsx_tab.drop_duplicates()
# NOTE(review): rows that share a SAMPLE but differ in AGE/GENDER survive
# drop_duplicates and would give a non-unique index -- confirm this cannot happen.
xlsx_tab = xlsx_tab.set_index('SAMPLE')

# Manual AGE corrections. Keys are lower-cased to match the index built above:
# assigning through .loc with a label that is not in the index silently
# appends a new, mostly-NaN row, which the main loop would then try to load.
age_overrides = {
    'SAMPLE1': 44,
    'SAMPLE2': 43,
    'SAMPLE3': 36,
    'SAMPLE4': 13,
}
for override_sample, override_age in age_overrides.items():
    override_sample = override_sample.lower()
    if override_sample in xlsx_tab.index:
        xlsx_tab.loc[override_sample, 'AGE'] = override_age
    else:
        print('WARNING: AGE override for unknown sample', override_sample, flush=True)

# Baseline copy numbers for chrX / chrY derived from GENDER; these are passed
# to get_cnv as `normal` for the sex chromosomes.
xlsx_tab['X_default'] = xlsx_tab['GENDER'].replace({'male':1,'female':2})
xlsx_tab['Y_default'] = xlsx_tab['GENDER'].replace({'male':1,'female':0})


# In[ ]:


# Nominal bin width (bp) for each supported window label.
window_size = {
    '100k':1e5,
    '200k':2e5,
    '500k':5e5,
    '1m'  :1e6,
    '2m'  :2e6,
}

refFile  = 'tilling' + window + '.bed'
# Load the fixed-size bin coordinates (three tab-separated columns, no header).
ref = pd.read_csv(refFile, sep='\t', header=None, names=['space','start','end'])
# Keep numbered chromosomes, chrX and chrY (this drops chrM). An alternation
# group is needed here: a bracketed class such as [\d+|X|Y] matches only one
# character. The group is non-capturing so str.contains does not warn.
ref = ref[ref['space'].str.contains(r'chr(?:\d+|X|Y)')]
# Label each bin "space:start:end" (the labels get_cnv looks up in ref),
# built column-wise instead of with a slow per-row apply.
bin_labels = ref['space'] + ':' + ref['start'].astype(str) + ':' + ref['end'].astype(str)
ref = ref.set_index(bin_labels.rename(None))
# Rough absolute position: bin ordinal * nominal window width. Bins narrower
# than the window are not accounted for, so this is only approximate.
ref['abs_pos'] = (np.arange(0, ref.shape[0]) * window_size[window]).astype(np.int64)
chr_uidx = ref['space'].unique()


# ---

# In[ ]:


def get_shuffle_df(sample, window):
    """Load one sample's shuffled HMM and CBS tables and merge them.

    Reads ``<sample>.shuffle.hmm.<window>.pkl`` and
    ``<sample>.shuffle.cbs.<window>.pkl`` and inner-joins them on their row
    index. Overlapping columns get ``_hmm`` / ``_cbs`` suffixes.

    Returns:
        A DataFrame whose two column levels are named ``Cell`` and ``Type``,
        sorted by the first level.
    """
    print('Loading ', sample, ' DataFrame ...', flush=True)

    shuffle_hmm_df = pd.read_pickle(sample + '.shuffle.hmm.' + window + '.pkl')
    shuffle_cbs_df = pd.read_pickle(sample + '.shuffle.cbs.' + window + '.pkl')

    # Move column level 0 into a row (.T / reset_index / .T) so it survives the
    # index-on-index merge. Then rebuild a two-level column index from that row
    # ('level_0') and the merged column labels ('index').
    # NOTE(review): assumes both inputs have an unnamed two-level column
    # MultiIndex -- TODO confirm.
    shuffle_df = pd.merge(
        shuffle_hmm_df.T.reset_index(level=0).T,
        shuffle_cbs_df.T.reset_index(level=0).T,
        left_index=True, right_index=True,
        suffixes=('_hmm', '_cbs'),
    ).T.reset_index().set_index(['level_0','index']).T.sort_index(axis=1,level=0)

    # Assigning to .levels[i].name raises RuntimeError on pandas >= 1.0;
    # set_names is the supported way to label MultiIndex levels.
    shuffle_df.columns = shuffle_df.columns.set_names(['Cell', 'Type'])

    # Return a fresh, consolidated copy rather than the product of the
    # chained transposes.
    return shuffle_df.copy()


# In[ ]:


def get_cnv(chr_df, normal=2, ref_df=None, min_bins=10):
    """Call CNV segments from per-bin segmentation values.

    Bins whose rounded ``HMMseg`` equals ``normal`` are dropped. The rest are
    grouped by their ``HMMseg`` value. A group is reported only if all of
    these hold:

    - the rounded means of ``CBSseg``, ``HMMseg``, ``mean_shuffle_cbs`` and
      ``mean_shuffle_hmm`` all differ from ``normal``;
    - all its bins lie on one chromosome;
    - it spans at least ``min_bins`` bins.

    Parameters
    ----------
    chr_df : pandas.DataFrame
        Per-bin values indexed by "chrom:start:end" labels. Must contain the
        columns ``copy``, ``HMMseg``, ``CBSseg``, ``mean_shuffle_cbs`` and
        ``mean_shuffle_hmm``.
    normal : number
        Baseline copy number (the caller passes 2 for numbered chromosomes).
    ref_df : pandas.DataFrame, optional
        Bin table with an ``abs_pos`` column, indexed like ``chr_df``.
        Defaults to the module-level ``ref``.
    min_bins : int
        Groups with fewer bins than this are ignored.

    Returns
    -------
    pandas.DataFrame
        One row per called CNV, or an empty DataFrame if there are none.

    Raises
    ------
    ValueError
        If a group's bins span more than one chromosome.
    """
    if ref_df is None:
        ref_df = ref
    seg_cols = ['CBSseg', 'HMMseg', 'mean_shuffle_cbs', 'mean_shuffle_hmm']
    records = []

    # Only bins whose rounded HMM segment value differs from the baseline.
    abnormal = chr_df.loc[chr_df['HMMseg'].round() != normal]

    # NOTE: grouping is by HMMseg *value*, not by adjacency. Presumably each
    # value identifies one segment; the single-chromosome check guards this.
    for g, df in abnormal.groupby('HMMseg', sort=False):
        # Keep the group only if every segmentation and its shuffled control
        # round (on average) to something other than the baseline.
        if (df[seg_cols].mean().round() == normal).any():
            continue

        g_chr = df.index.to_series().str.split(':', expand=True)[0].unique()
        if g_chr.shape != (1,):
            # Explicit raise instead of assert, so the check survives `python -O`.
            raise ValueError('Bug: Segment not right!!!')

        binSize = df.index.shape[0]
        if binSize < min_bins:
            continue

        cnv_tag = 'Loss' if g < normal else 'Gain'

        # Drop bins whose raw copy value is at or above the 95th percentile
        # before computing the summary statistics.
        dfcopy = df['copy']
        copy_cut = dfcopy.quantile(0.95)
        trimmed_copy = dfcopy[dfcopy < copy_cut]
        dfgood = df.where(dfcopy < copy_cut).dropna()
        # Distance of each kept bin's raw copy value from its HMM segment value.
        disBfS = (dfgood['copy'] - dfgood['HMMseg']).abs()
        abs_pos = ref_df.loc[df.index, 'abs_pos']

        records.append({
            'Pos'      : ':'.join(df.index[0].split(':')[0:2] + [df.index[-1].split(':')[-1]]),
            'binSize'  : binSize,
            # .iloc: Series[-1] on a label index is a deprecated positional fallback.
            'absSize'  : abs_pos.iloc[-1] - abs_pos.iloc[0],
            'HMMseg'   : df['HMMseg'].mean(),
            'CBSseg'   : df['CBSseg'].mean(),
            'CBSshf'   : df['mean_shuffle_cbs'].median(),
            'HMMshf'   : df['mean_shuffle_hmm'].median(),
            'Space'    : g_chr[0],
            'cnvTag'   : cnv_tag,
            'cnvMAPD'  : df['copy'].diff().abs().median(),
            'disBfS_M' : disBfS.median(),
            'disBfS_m' : disBfS[disBfS < disBfS.quantile(0.95)].mean(),
            'cnv_bM'   : trimmed_copy.mean(),
            'cnv_bSTD' : trimmed_copy.std(),
        })

    # Build the frame once at the end: DataFrame.append was removed in pandas
    # 2.0 and appending row by row is quadratic.
    return pd.DataFrame(records)


# In[ ]:


# def cell_worker(cell):
#     cell_df  = sample_df[cell].astype(np.float)
#     cell_adf = cell_df.reindex(aidx).dropna(how='any')
#     cell_xdf = cell_df.reindex(xidx).dropna(how='any')
#     cell_ydf = cell_df.reindex(yidx).dropna(how='any')

#     cell_cnv = pd.DataFrame()

#     # Call CNV
#     acnv = get_cnv(cell_adf, 2)
#     xcnv = get_cnv(cell_xdf, x_default)
#     ycnv = get_cnv(cell_ydf, y_default)

#     cell_cnv = pd.concat([acnv,xcnv,ycnv])
#     cell_cnv['Cell'] = cell

#     return cell_cnv


# --- 

# In[ ]:


# Call CNVs for every cell of every sample in the sample sheet. Per-cell and
# per-sample results are collected in lists and concatenated once:
# DataFrame.append was removed in pandas 2.0 and is quadratic in a loop.
sample_cnv_frames = []

for sample in xlsx_tab.index:
    shuffle_df = get_shuffle_df(sample, window)

    # Put good_df's columns for this sample's cells next to the shuffle
    # columns, sorted by cell and restricted to good_df's bins.
    sample_df = pd.concat(
        [
            good_df[shuffle_df.columns.levels[0]],
            shuffle_df,
        ], axis=1, sort=False,
    ).sort_index(axis=1, level=0).loc[good_df.index]

    # GENDER-derived baseline copy numbers for chrX / chrY.
    x_default = xlsx_tab.loc[sample, 'X_default']
    y_default = xlsx_tab.loc[sample, 'Y_default']

    cell_cnv_frames = []
    print('Start processing cells ...', flush=True)

    for cell in tqdm(sample_df.columns.levels[0]):
        # Builtin float rather than np.float, which was removed in NumPy 1.24;
        # both give float64.
        cell_df  = sample_df[cell].astype(float)
        cell_adf = cell_df.reindex(aidx).dropna(how='any')
        cell_xdf = cell_df.reindex(xidx).dropna(how='any')
        cell_ydf = cell_df.reindex(yidx).dropna(how='any')

        # Call CNV: numbered chromosomes against 2, sex chromosomes against
        # the sample's GENDER-derived baseline.
        acnv = get_cnv(cell_adf, 2)
        xcnv = get_cnv(cell_xdf, x_default)
        ycnv = get_cnv(cell_ydf, y_default)

        cell_cnv = pd.concat([acnv, xcnv, ycnv], sort=False, ignore_index=True)
        cell_cnv['Cell'] = cell
        cell_cnv_frames.append(cell_cnv)

    # pd.concat rejects an empty list, so fall back to an empty frame.
    if cell_cnv_frames:
        sample_cnv_df = pd.concat(cell_cnv_frames, sort=False, ignore_index=True)
    else:
        sample_cnv_df = pd.DataFrame()
    sample_cnv_df['Sample'] = sample
    sample_cnv_frames.append(sample_cnv_df)
    print('', flush=True)
    print('Finshing ', sample, ' !!!', flush=True)
    print('', flush=True)

if sample_cnv_frames:
    cnv_summary_df = pd.concat(sample_cnv_frames, sort=False, ignore_index=True)
else:
    cnv_summary_df = pd.DataFrame()


# ---

# In[ ]:


# Persist the combined per-cell CNV calls, tagged with the window label.
cnv_summary_df.to_pickle('.'.join(['summary_df', window, 'pkl']))


# ---

# In[ ]:




