#!/usr/bin/env python
# coding: utf-8
###############################################################################
#
# Author: HeChen
#
###############################################################################

# In[ ]:


import sys
import os
import gc
import signal
import warnings

from tqdm.auto       import tqdm
from itertools       import repeat
from multiprocessing import Pool

import pandas as pd
import numpy  as np

import seaborn as sns
import matplotlib.pyplot as plt

from rpy2                   import __version__ as rpy2_version
from rpy2.robjects          import pandas2ri as pd2ri
from rpy2.robjects          import r         as robj
from rpy2.robjects.packages import importr   as libraryr
from rpy2.rinterface        import RRuntimeWarning

from pybedtools import BedTool as BedTools
from pybedtools import cleanup as tmpCleanUp


# In[ ]:


# Report the versions of the two libraries whose APIs this script depends on.
# Each line is flushed immediately so it appears promptly in redirected logs.
# (tqdm.write(...) is a drop-in alternative when a progress bar is active.)
for version_line in ('Pandas==' + pd.__version__, 'rpy2==' + rpy2_version):
    print(version_line, flush=True)


# In[ ]:


# The sample name is a mandatory first command-line argument.
# sys.argv always holds the script name at index 0, so "no argument given"
# means len(sys.argv) < 2; the previous `< 1` test could never be true and a
# missing argument crashed with IndexError on sys.argv[1] below.
if len(sys.argv) < 2:
    print ('USAGE: python ' + sys.argv[0] + ' <Sample_Name>')
    # Non-zero status: a missing argument is a usage error (and matches the
    # non-zero exit the former IndexError traceback produced).
    sys.exit(1)

# Prefix of the input/output pickle files (<sample>.data_seg.<window>.pkl, ...).
sample = sys.argv[1]
# Bin-width label; must be one of the keys of `window_size`.
window = '200k'


# In[ ]:


# Switch on rpy2's pandas conversion layer before any R call is made.
pd2ri.activate()

# R packages used by the segmentation workers.
HMMcopy = libraryr('HMMcopy')
DNAcopy = libraryr('DNAcopy')
IRanges = libraryr('IRanges')

# Handles on the R functions data.frame() and as.numeric().
r2df = robj['data.frame']
r2n  = robj['as.numeric']

# Bin width in base pairs for every supported window label.
window_size = {
    '50k':  50e3,
    '100k': 100e3,
    '200k': 200e3,
    '500k': 500e3,
    '1m':   1e6,
    '2m':   2e6,
}

# Smallest float32 increment; lower bound of the jitter applied in shuffle_CBS.
EPS = np.finfo(np.float32).eps
# Seconds passed to signal.alarm() for one HMM shuffle (CBS uses twice this).
RTIME = 30
# Number of shuffles run per cell by the driver loops.
redoN = 20

class TimeoutError(Exception):
    """Raised by the SIGALRM handler when a shuffle run exceeds its time budget.

    NOTE: this deliberately-named class shadows the builtin ``TimeoutError``
    inside this module; it derives directly from ``Exception``.
    """

def handler(signum, frame):
    """Signal handler that aborts the interrupted code by raising TimeoutError.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` machinery and are ignored.
    """
    raise TimeoutError

# Arguments forwarded to DNAcopy.segment() by shuffle_CBS:
#   ALPHA  -> alpha, NPERM -> nperm, UNDOSD -> undo_SD, MINWD -> min_width
ALPHA, NPERM, UNDOSD, MINWD = 0.0001, 10 ** 4, 1, 5

# ---------------------------------------------------------------------------
# Reference bin table (`ref`) and its surrogate single-chromosome twin
# (`non_idx`) used to feed shuffled bins to the segmenters.
# ---------------------------------------------------------------------------
refFile  = 'tilling' + window + '.bed'
# Load the fixed bin positions: tab-separated, no header, three columns.
ref = pd.read_csv(refFile, sep='\t', header=None, names=['space','start','end'])
# Drop chrM (and any other contig whose name does not match the pattern).
# The pattern is a raw string so that `\d` reaches the regex engine verbatim
# instead of being an invalid string escape.
# NOTE(review): `[\d+|X|Y]` is a character class (one char out of digit, '+',
# '|', 'X', 'Y') and the match is unanchored, so names such as chr1_random
# also pass -- confirm whether an anchored `chr(\d+|X|Y)$` was intended.
ref = ref[ref['space'].str.contains(r'chr[\d+|X|Y]')]
# Index every bin by its 'space:start:end' label.
ref = ref.set_index(ref.apply(lambda r: ':'.join(r[0:3].values.astype(str)), axis=1))
# Rough absolute position: bin ordinal times the nominal bin width.
ref['abs_pos'] = (np.arange(0,ref.shape[0])*window_size[window]).astype(np.int64)
# Precise alternative (kept for reference):
# ref['abs_pos'] = ref['end'].cumsum().shift().fillna(0).astype(np.int64)
# Chromosome names in file order.
chr_uidx = ref['space'].unique()

# `non_idx` shares ref's index but places every bin on one artificial
# chromosome 'chr99', with coordinates expressed in abs_pos/1000 units:
#   start = abs_pos/1000 + 1
#   end   = abs_pos/1000 + (abs_pos of the second bin)/1000
# so consecutive bins form contiguous, non-overlapping intervals.
# Requires at least two bins in `ref` (uses .iloc[1]).
non_idx = ref.copy()
non_idx['CHR'] = 'chr99'
non_idx['pos1'] = non_idx['abs_pos']/1000 + 1
non_idx['pos2'] = non_idx['abs_pos']/1000 + ref['abs_pos'].iloc[1]/1000
non_idx = non_idx[['CHR','pos1','pos2']].rename(columns={'CHR':'space','pos1':'start','pos2':'end'}).astype( {'start':np.int64, 'end':np.int64} )
# 'ID' is the surrogate 'chr99:start:end' label; the workers use it to map
# bedtools output rows back to the genome index.
non_idx['ID'] = non_idx.apply(lambda r: ':'.join(r[0:3].values.astype(str)), axis=1)


# ---

# In[ ]:


# pfcells = pd.read_pickle(sample+'.qc.' + window + '.pkl').xs(key='ALL',axis=0,level=1).query('MAPD<0.5').index
# data_df = pd.read_pickle(sample+'.data_seg.' + window + '.pkl')[pfcells]
# data_df = data_df.drop('ideal',axis=1,level=1).copy()
# del(pfcells); gc.collect();

# 'copy' sub-columns of the curated table; its index and columns define which
# bins and cells are kept below.
good_df = pd.read_pickle('good_df.' + window + '.pkl').xs(key='copy', axis=1, level=1)

# Load this sample's segmentation table and align it step by step.
data_df = pd.read_pickle(sample+'.data_seg.' + window + '.pkl')
# Remove the 'ideal' sub-column of every cell.
data_df = data_df.drop('ideal', axis=1, level=1)
# Restrict rows to the curated bins and columns to the curated cells.
data_df = data_df.reindex(index=good_df.index)
data_df = data_df.reindex(columns=good_df.columns, level=0)
# Discard bins and columns that ended up entirely empty.
data_df = data_df.dropna(how='all', axis=0)
data_df = data_df.dropna(how='all', axis=1)
data_df = data_df.copy()

# The curated table is no longer needed; release it.
del good_df
gc.collect()

# ---

# In[ ]:


def shuffle_HMM(cell_df):
    """Shuffle one cell's HMM segments and re-segment the result with HMMcopy.

    The bins of ``cell_df`` are grouped by their ``'HMMseg'`` value, the groups
    are concatenated in random order, and the shuffled ``'copy'`` values are
    laid out on the surrogate ``chr99`` coordinates of ``non_idx``.  HMMcopy is
    run on ``log2(copy / 2)`` and the per-bin segment value is mapped back to
    the original bin labels.

    Parameters
    ----------
    cell_df : pandas.DataFrame
        One cell's bins; must provide the columns ``'HMMseg'`` and ``'copy'``.
        ``cell_df.index.name`` is used only to label the timeout message.

    Returns
    -------
    pandas.Series
        The ``'median'`` column of the new segmentation, indexed like
        ``cell_df``.  An empty Series is returned when the run exceeds
        ``RTIME`` seconds or fails with any other exception.

    Notes
    -----
    Installs ``handler`` for ``SIGALRM`` and arms ``signal.alarm``; the alarm
    is always cancelled before returning.  pybedtools temporary files are
    removed and the warning filter is reset to ``'default'`` on every path.
    """
    # Result handed back on timeout / failure; overwritten on success.
    new_seg = pd.Series(dtype=np.float64)

    signal.signal(signal.SIGALRM, handler)
    signal.alarm(RTIME)

    try:
        # Get seg as group
        seg_group = cell_df.groupby('HMMseg',sort=False)

        # Shuffle bin based on seg, save shuffled index
        shuffle_df     = pd.concat([seg_group.get_group(seg) for seg in np.random.permutation( list(seg_group.groups.keys()) )])
        shuffle_df_idx = shuffle_df.index

        # Reset index for shuffled df, use genome index:
        # shuffled rows take over the original bin labels (positional), then
        # pick up the surrogate chr99 coordinates from non_idx.
        input_df = pd.concat(
            [
                shuffle_df.set_index(cell_df.index).reindex(non_idx.index),
                non_idx,
            ], axis=1,
        ).reindex(cell_df.index)[['space','start','end','copy']]

        # Pandas to R
        input_rdf = IRanges.RangedData(
            ranges = IRanges.IRanges(
                start = input_df['start'],
                end   = input_df['end'  ],
            ),
            space =     input_df['space'],
            copy  = r2n(np.log2(input_df['copy' ]/2).astype(str)),
        )

        # Run HMMcopy
        warnings.simplefilter(action='ignore', category=FutureWarning)
        normal_segments = HMMcopy.HMMsegment(input_rdf, verbose=False)
        df_copy = pd2ri.ri2py_dataframe(r2df(input_rdf)).dropna()
        df_segs = pd2ri.ri2py_dataframe(
            normal_segments[normal_segments.names.index('segs')]
            ).astype( {'start':np.int32, 'end':np.int32} )

        # Seg to bin: intersect every bin with the segment table and keep the
        # bin columns plus the last two segment columns.
        obj_bin = BedTools.from_dataframe(df_copy)
        obj_seg = BedTools.from_dataframe(df_segs)

        data = (
            BedTools.intersect(self=obj_bin, b=obj_seg, wao=True)
            .to_dataframe(
                names=(df_copy.columns.tolist() + df_segs.columns.tolist()[-2:]),
                usecols=np.arange(0,5).tolist()+[8,9]
            )
            .drop('width',axis=1)
        )

        # Add abs pos to result: translate the chr99 'space:start:end' labels
        # back to genome labels through non_idx['ID'].
        data = pd.concat(
                [
                    data.set_index(
                        data.apply(lambda r: ':'.join(r[0:3].values.astype(str)), axis=1)
                    ).rename(index=non_idx.reset_index().set_index('ID')['index']).reindex(ref.index),
                    ref['abs_pos'],
                ], axis=1,
            ).drop(
                ['space','start','end'],
                axis=1
            ).dropna()

        # Reset original index: give each row the label of the bin that was
        # shuffled into that position, then restore the original bin order.
        new_seg = data.set_index(shuffle_df_idx).reindex(cell_df.index)['median']

    except TimeoutError:
        # format() instead of '+' so a non-string index name cannot raise here.
        print ('TimeOut: {}'.format(cell_df.index.name), flush=True)
        # tqdm.write('TimeOut: ', cell_df.index.name)
    except Exception as err:
        # Best effort: report and fall through with the empty result.  The
        # former `'...' + sys.exc_info()[0]` raised TypeError (str + class),
        # which left new_seg unbound and crashed the worker.
        print ('Unexpected error: {}: {}'.format(type(err).__name__, err), flush=True)
        # tqdm.write("Unexpected error: ", sys.exc_info()[0])
    finally:
        # Cancel the pending alarm first so it cannot fire during cleanup.
        signal.alarm(0)
        warnings.simplefilter('default')
        tmpCleanUp()

    # Returned outside `finally` so that KeyboardInterrupt / SystemExit are
    # not silently discarded.
    return new_seg


# In[ ]:


def shuffle_CBS(cell_df):
    """Shuffle one cell's CBS segments and re-segment the result with DNAcopy.

    The bins of ``cell_df`` are grouped by their ``'CBSseg'`` value, the groups
    are concatenated in random order, and the shuffled ``'copy'`` values are
    laid out on the surrogate ``chr99`` coordinates of ``non_idx``.
    ``DNAcopy.segment`` is run on ``log2(copy / 2)`` (after ``smooth_CNA``)
    with ``ALPHA``, ``NPERM``, ``UNDOSD`` and ``MINWD``; the per-bin segment
    mean is then mapped back to the original bin labels.

    Parameters
    ----------
    cell_df : pandas.DataFrame
        One cell's bins; must provide the columns ``'CBSseg'`` and ``'copy'``.
        ``cell_df.index.name`` is used only to label the timeout message.

    Returns
    -------
    pandas.Series
        The ``'seg.mean'`` column of the new segmentation, indexed like
        ``cell_df``.  An empty Series is returned when the run exceeds
        ``2 * RTIME`` seconds or fails with any other exception.

    Notes
    -----
    Installs ``handler`` for ``SIGALRM`` and arms ``signal.alarm``; the alarm
    is always cancelled before returning.  pybedtools temporary files are
    removed and the warning filter is reset to ``'default'`` on every path.
    """
    # Result handed back on timeout / failure; overwritten on success.
    new_seg = pd.Series(dtype=np.float64)

    signal.signal(signal.SIGALRM, handler)
    signal.alarm(2*RTIME)

    try:
        seg_group = cell_df.groupby('CBSseg',sort=False)

        # Shuffle bin based on seg, save shuffled index
        shuffle_df     = pd.concat([seg_group.get_group(seg) for seg in np.random.permutation( list(seg_group.groups.keys()) )])
        shuffle_df_idx = shuffle_df.index

        # Reset index for shuffled df, use genome index:
        # shuffled rows take over the original bin labels (positional), then
        # pick up the surrogate chr99 coordinates from non_idx.
        input_df = pd.concat(
            [
                shuffle_df.set_index(cell_df.index).reindex(non_idx.index),
                non_idx,
            ], axis=1,
        ).reindex(cell_df.index)[['space','start','end','copy']]

        # Every bin sits on the artificial chromosome 99; positions are the
        # surrogate start coordinates.
        dChr = pd.Series([99] * input_df.shape[0])
        pos1 = pd.Series(input_df['start'].values)

        warnings.simplefilter(action='ignore', category=FutureWarning)

        # Run CBScopy
        obj_CNA = DNAcopy.smooth_CNA( DNAcopy.CNA(np.log2(input_df['copy']/2), dChr, pos1, sampleid='ID') )
        obj_seg = DNAcopy.segment(obj_CNA, alpha=ALPHA, nperm=NPERM, undo_splits='sdundo', undo_SD=UNDOSD, min_width=MINWD, verbose=0)

        # Second element of the converted result is the segment table.
        pd_segs = pd2ri.ri2py(pd2ri.ri2py(obj_seg)[1])
        # 'chr' + 99 -> 'chr99', matching the 'space' column of non_idx.
        pd_segs['chr'] = 'chr'+pd_segs['chrom'].map(str)

        # Make all segment means distinct by adding a tiny positive jitter
        # (between EPS and 1e-5), retrying at most 10 times.
        # presumably so that neighbouring segments with equal means stay
        # distinguishable downstream -- TODO confirm.
        # TODO: ties that survive 10 attempts are passed through unchanged.
        for _ in range(10):
            if pd_segs['seg.mean'].unique().shape[0] == pd_segs.shape[0]:
                break
            pd_segs['seg.mean'] += ( np.random.random(size=pd_segs.shape[0])/1e5 ).clip(EPS,1e-5)

        # Get Segment value for each bin by BedTools, and re-index
        obj_seg = BedTools.from_dataframe(pd_segs[['chr','loc.start','loc.end','seg.mean']])
        # NOTE(review): bin_log2 is log2(copy)/2 here whereas the segmenter
        # input above is log2(copy/2); this column is not part of the returned
        # value, but confirm which form was intended.
        obj_bin = BedTools.from_dataframe(
            input_df.reset_index().iloc[:,1:4].rename(columns={'space':'Chr'}).assign(bin_log2=np.log2(input_df['copy']).values/2)
        )

        # Keep the four bin columns plus the segment mean (output column 7).
        data = BedTools.intersect(self=obj_bin, b=obj_seg, wao=True).to_dataframe(
            names=( input_df.columns.tolist() + [pd_segs.columns[5]] ),
            usecols=np.arange(0,4).tolist()+[7]
        )
        # Non-numeric placeholders (bins without an overlapping segment) -> NaN.
        data['seg.mean'] = pd.to_numeric(data['seg.mean'], errors='coerce')

        # Fill missing bins with the average of the available segment means.
        data['seg.mean'] = data['seg.mean'].fillna(data['seg.mean'].mean())

        # Translate the chr99 'space:start:end' labels back to genome labels
        # through non_idx['ID'] and attach abs_pos.
        data = pd.concat(
            [
                data.set_index(
                    data.apply(lambda r: ':'.join(r[0:3].values.astype(str)), axis=1)
                ).rename(index=non_idx.reset_index().set_index('ID')['index']).reindex(ref.index),
                ref['abs_pos'],
            ], axis=1,
        ).drop(
            ['space','start','end'],
            axis=1
        ).dropna()

        # Reset original index: give each row the label of the bin that was
        # shuffled into that position, then restore the original bin order.
        new_seg = data.set_index(shuffle_df_idx).reindex(cell_df.index)['seg.mean']

    except TimeoutError:
        # format() instead of '+' so a non-string index name cannot raise here.
        print ('TimeOut: {}'.format(cell_df.index.name), flush=True)
        # tqdm.write('TimeOut! ', cell_df.index.name)
    except Exception as err:
        # Best effort: report and fall through with the empty result.  The
        # former `'...' + sys.exc_info()[0]` raised TypeError (str + class),
        # which left new_seg unbound and crashed the worker.
        print ('Unexpected error: {}: {}'.format(type(err).__name__, err), flush=True)
        # tqdm.write("Unexpected error: ", sys.exc_info()[0])
    finally:
        # Cancel the pending alarm first so it cannot fire during cleanup.
        signal.alarm(0)
        warnings.simplefilter('default')
        tmpCleanUp()

    # Returned outside `finally` so that KeyboardInterrupt / SystemExit are
    # not silently discarded.
    return new_seg


# In[ ]:


# Banner (tqdm.write(...) is a drop-in alternative when a bar is active).
for banner_line in ('', '', 'Starting shuffling HMMseg ...', ''):
    print (banner_line, flush=True)

# cell name -> DataFrame of per-bin shuffle statistics
shuffle_results = {}

for cell in tqdm(data_df.xs(key='copy',axis=1,level=1).columns, desc='ForCells'):
    # This cell's bins, without any row that has a missing value; the index
    # name carries the cell name into the worker's timeout message.
    cell_df = data_df[cell].dropna(how='any').copy()
    cell_df.index.name = cell

    # Fresh pool per cell: redoN independent shuffles of the same frame.
    # (The iterator may be wrapped in list(tqdm(..., total=redoN,
    #  leave=False, desc='ForShuffle', mininterval=1)) for an inner bar.)
    pool   = Pool(min(32,redoN))
    result = pool.imap_unordered( shuffle_HMM, repeat(cell_df, redoN) )
    pool.close()
    pool.join()

    # One column per shuffle that produced a result; 2**(x+1) turns the
    # workers' log2(copy/2) values back into copy numbers.
    cell_shuffle_res = np.power(
        2,
        pd.concat(result,axis=1,sort=False)
        .dropna(how='all',axis=1)
        .reindex(cell_df.index) + 1,
    )

    # Keep the cell only when at least half of the shuffles succeeded.
    if cell_shuffle_res.shape[1] >= redoN/2:
        save = pd.DataFrame()
        save['mean_shuffle']   = cell_shuffle_res.mean(axis=1)
        save['var_shuffle']    = cell_shuffle_res.var(axis=1)
        save['r_mean_shuffle'] = cell_shuffle_res.round().mean(axis=1)
        shuffle_results[cell] = save.copy()

    # Release the per-cell objects before moving on.
    del pool, result, cell_shuffle_res
    gc.collect()

# Assemble all cells side by side on the reference bin order and persist.
data_shuffle_df = pd.concat(shuffle_results,axis=1,sort=False).reindex(ref.index).dropna(how='all').copy()
data_shuffle_df.to_pickle(sample+ '.shuffle.hmm.' + window + '.pkl')
del shuffle_results, data_shuffle_df
gc.collect()

tmpCleanUp()


# In[ ]:


# Banner (tqdm.write(...) is a drop-in alternative when a bar is active).
for banner_line in ('', '', 'Starting shuffling CBSseg ...', ''):
    print (banner_line, flush=True)

# cell name -> DataFrame of per-bin shuffle statistics
shuffle_results = {}

for cell in tqdm(data_df.xs(key='copy',axis=1,level=1).columns, desc='ForCells'):
    # This cell's bins, without any row that has a missing value; the index
    # name carries the cell name into the worker's timeout message.
    cell_df = data_df[cell].dropna(how='any').copy()
    cell_df.index.name = cell

    # Fresh pool per cell: redoN independent shuffles of the same frame.
    # (The iterator may be wrapped in list(tqdm(..., total=redoN,
    #  leave=False, desc='ForShuffle', mininterval=1)) for an inner bar.)
    pool   = Pool(min(32,redoN))
    result = pool.imap_unordered( shuffle_CBS, repeat(cell_df, redoN) )
    pool.close()
    pool.join()

    # One column per shuffle that produced a result; 2**(x+1) turns the
    # workers' log2(copy/2) values back into copy numbers.
    cell_shuffle_res = np.power(
        2,
        pd.concat(result,axis=1,sort=False)
        .dropna(how='all',axis=1)
        .reindex(cell_df.index) + 1,
    )

    # Keep the cell only when at least half of the shuffles succeeded.
    if cell_shuffle_res.shape[1] >= redoN/2:
        save = pd.DataFrame()
        save['mean_shuffle']   = cell_shuffle_res.mean(axis=1)
        save['var_shuffle']    = cell_shuffle_res.var(axis=1)
        save['r_mean_shuffle'] = cell_shuffle_res.round().mean(axis=1)
        shuffle_results[cell] = save.copy()

    # Release the per-cell objects before moving on.
    del pool, result, cell_shuffle_res
    gc.collect()

# Assemble all cells side by side on the reference bin order and persist.
data_shuffle_df = pd.concat(shuffle_results,axis=1,sort=False).reindex(ref.index).dropna(how='all').copy()
data_shuffle_df.to_pickle(sample+ '.shuffle.cbs.' + window + '.pkl')
del shuffle_results, data_shuffle_df
gc.collect()

tmpCleanUp()


# In[ ]:

# Final marker: two blank lines followed by the completion message, each
# flushed immediately (tqdm.write('DONE!') is a drop-in alternative).
for banner_line in ('', '', 'DONE!'):
    print (banner_line, flush=True)


