#!/usr/bin/env python
# coding: utf-8

###############################################################################
#
# Author: HeChen
#
###############################################################################

# In[ ]:

###############################################################################
# Import
###############################################################################
import os
import sys
import warnings

from io              import StringIO
from tqdm            import tqdm
from multiprocessing import Pool

import pandas as pd
import numpy  as np

import seaborn as sns
import matplotlib.pyplot as plt

from rpy2                   import __version__ as rpy2_version
from rpy2.robjects          import pandas2ri   as pd2ri
from rpy2.robjects          import r           as robj
from rpy2.robjects.packages import importr     as libraryr
from rpy2.rinterface        import RRuntimeWarning

from pybedtools import BedTool as BedTools
from pybedtools import cleanup as tmpCleanUp
###############################################################################
#assert pd.__version__ == '0.23.4', 'Pandas==0.23.4, we got ' + pd.__version__ 
#assert rpy2_version   == '2.9.4',  'rpy2==2.9.4, we got ' + rpy2_version
###############################################################################
# Report the versions of the two libraries whose APIs this script is most
# sensitive to (see the commented-out version asserts above).
print('Pandas=={}'.format(pd.__version__))
print('rpy2=={}'.format(rpy2_version))

# In[ ]:


# A sample name is mandatory.  sys.argv[0] is always the script path, so a
# missing argument shows up as len(sys.argv) == 1 (never < 1).
if len(sys.argv) < 2:
    print ('USAGE: python ' + sys.argv[0] + ' <Sample_Name>')
    # Non-zero status: the run did not happen.
    sys.exit(1)

# Text used further down to select this sample's cell folders.
sample = sys.argv[1]
# Bin-size label; used as a key of `window_size` and in the '.<window>.wig'
# input file suffix.
window = '200k'

# Switch on rpy2's pandas conversion layer (pandas2ri); the pd2ri.* conversion
# calls further down rely on this module.
pd2ri.activate()

# Load the R packages through rpy2's importr; their R functions are then used
# as attributes of these objects (e.g. HMMcopy.correctReadcount).
HMMcopy = libraryr('HMMcopy')
DNAcopy = libraryr('DNAcopy')

# Handles on the R functions `data.frame` and `as.numeric`.
# NOTE(review): r2n is not referenced in the visible part of this file.
r2df = robj['data.frame']
r2n  = robj['as.numeric']

# Reference inputs: two wig tracks handed to HMMcopy.wigsToRangedData
# (presumably GC content and mappability, judging by the file names) and the
# BED file that lists the fixed bins.
gc_file  = '/PATH/TO/REFERENCE/hg38.gc200k.wig'
map_file = '/PATH/TO/REFERENCE/hg38.umap200k.wig'
refFile  = 'tilling200k.bed'

# Abort early when any reference input is missing.  Exit with a non-zero
# status so that calling shells / pipelines can see that nothing was produced.
if not ( os.path.isfile(gc_file) and os.path.isfile(map_file) and os.path.isfile(refFile) ):
    print ('Not exists <gc_file> or <map_file> or refFile: ' + gc_file + ' ' + map_file + ' ' + refFile)
    sys.exit(1)

# Bin width in base pairs for every supported window label (stored as floats).
window_size = {
    label: float(n_bases)
    for label, n_bases in [
        ('50k',     50000),
        ('100k',   100000),
        ('200k',   200000),
        ('500k',   500000),
        ('1m',    1000000),
        ('2m',    2000000),
    ]
}

# Machine epsilon of float32; used as the lower bound of the random jitter
# added to segment values.
EPS = np.finfo(np.float32).eps


# In[ ]:


def _bin_label(row):
    # Join the first three fields of a row into a 'space:start:end' label.
    return ':'.join(row[0:3].values.astype(str))


# Fixed-bin coordinates: headerless, tab-separated space / start / end.
ref = pd.read_csv(refFile, sep='\t', names=['space', 'start', 'end'], header=None)
# Use the 'space:start:end' label of every bin as its row index.
ref = ref.set_index(ref.apply(_bin_label, axis=1))
# Rough absolute position: bin rank times the window width.
# (A more precise alternative would be
#  ref['end'].cumsum().shift().fillna(0).astype(np.int64).)
ref['abs_pos'] = (np.arange(len(ref)) * window_size[window]).astype(np.int64)
# Distinct values of the 'space' column, in order of first appearance.
chr_uidx = ref['space'].unique()

# Copy-number state table, one state per line: "<CODE> <description>".
state = '''HOMD Homozygous deletion, ≤ 0 copies
HETD Heterozygous deletion, 1 copy 
NEUT Neutral change, 2 copies
GAIN Gain of chromosome, 3 copies 
AMPL Amplification event, 4 copies
HLAMP High level amplification, ≥ 5 copies'''

# Keep only the state code (first word of every line) in a Series named
# 'State'.
state = pd.Series(
    [line.split(None, 1)[0] for line in state.splitlines() if line.strip()],
    name='State',
)
# Label the codes 1..6 in the listed order (1 = HOMD ... 6 = HLAMP).
# The labels are assigned rather than reindexed: reindexing the default
# 0..5 index to 1..6 would shift every code by one and leave label 6 empty.
state.index = np.arange(1, len(state) + 1)


# ### HMMcopy pipeline in R
# 
# ```R
# library(HMMcopy)
# rfile <- system.file("extdata", "normal.wig", package = "HMMcopy")
# gfile <- system.file("extdata", "gc.wig",     package = "HMMcopy")
# mfile <- system.file("extdata", "map.wig",    package = "HMMcopy")
# 
# normal_reads    <- wigsToRangedData(rfile, gfile, mfile)
# normal_copy     <- correctReadcount(normal_reads)
# normal_segments <- HMMsegment(normal_copy)
# ```

# In[ ]:


#def call_HMM(sample_file, e=0.995, sth=1e7):
def call_HMM(sample_file, e=0.9999, sth=1e7):
    """Run the HMMcopy pipeline on one read-count wig and return per-bin values.

    The reads are corrected with ``HMMcopy.correctReadcount`` (using the
    module-level ``gc_file`` and ``map_file``), segmented with
    ``HMMcopy.HMMsegment``, and every bin is paired with the segment it
    overlaps through a bedtools intersect.

    Parameters
    ----------
    sample_file : str
        Path of the sample's read-count wig file.
    e : float
        Value written to the ``e`` column of the HMMsegment parameter table:
        probability of extending a segment; increase to lengthen segments,
        decrease to shorten them.  Range: (0, 1).
    sth : float
        Value written to the ``strength`` column: strength of the initial
        ``e`` suggestion; reducing it allows ``e`` to change, increasing it
        pins ``e``.  Range: [0, Inf).

    Returns
    -------
    pandas.DataFrame
        Indexed by 'space:start:end' bin labels (the labels of ``ref`` whose
        row has no missing value), with columns ``copy``, ``median`` and
        ``ideal``.

    Notes
    -----
    Exceptions from R or bedtools are not caught here.  Temporary files made
    by pybedtools are removed and the warning filters are restored even when
    a step fails.
    """
    # Silence the chatty warning categories only for the duration of this
    # call; catch_warnings() puts the previous filters back on exit.
    with warnings.catch_warnings():
        for noisy in (FutureWarning, RRuntimeWarning, ResourceWarning):
            warnings.simplefilter(action='ignore', category=noisy)

        try:
            # Normalize reads
            normal_copy = HMMcopy.correctReadcount(HMMcopy.wigsToRangedData(sample_file, gc_file, map_file))

            # Tune the parameters: start from the table HMMsegment proposes
            # and overwrite the two columns controlled by the arguments.
            param = pd2ri.ri2py_dataframe(HMMcopy.HMMsegment(normal_copy, getparam=True))
            param['e'] = e
            param['strength'] = sth
            param = pd2ri.py2ri(param)

            # Call segments
            normal_segments = HMMcopy.HMMsegment(normal_copy, param, verbose=False)

            # Transform the R objects to pandas DataFrames
            df_copy = pd2ri.ri2py_dataframe(r2df(normal_copy)).dropna()
            df_segs = pd2ri.ri2py_dataframe(
                normal_segments[normal_segments.names.index('segs')]
                ).astype( {'start':np.int32, 'end':np.int32} )

            # Make every segment median distinct by adding a tiny random
            # offset (between EPS and 1e-5) to all of them, at most 10 times.
            # NOTE(review): presumably so that segments can later be told
            # apart by value alone -- confirm with downstream users.
            for _ in range(10):
                if df_segs['median'].unique().shape[0] == df_segs.shape[0]:
                    break
                df_segs['median'] += ( np.random.random(size=df_segs.shape[0])/1e5 ).clip(EPS,1e-5)
            else:
                if df_segs['median'].unique().shape[0] != df_segs.shape[0]:
                    ### TODO ###
                    print ('This sample require more than 10 shift at seg values: ' + sample_file)

            # Transform the pandas DataFrames to BED objects
            obj_bin = BedTools.from_dataframe(df_copy)
            obj_seg = BedTools.from_dataframe(df_segs)

            # Run intersectBed to assign a segment to every bin.  Only the
            # first 12 output columns plus columns 15 and 16 are read back;
            # they are named after the bin columns and the last two segment
            # columns (presumably 12 bin columns followed by the segment
            # record -- confirm against the HMMcopy output layout).
            data = (
                BedTools.intersect(self=obj_bin, b=obj_seg, wao=True)
                .to_dataframe(
                    names=(df_copy.columns.tolist() + df_segs.columns.tolist()[-2:]),
                    usecols=np.arange(0,12).tolist()+[15,16]
                )
                .drop(['width','valid'],axis=1)
            )

            # Align to the reference bins, attach abs_pos, then drop the
            # coordinate / covariate columns and every incomplete row.
            data = pd.concat(
                [
                    data.set_index(
                        data.apply(lambda r: ':'.join(r[0:3].values.astype(str)), axis=1)
                    ).reindex(ref.index),
                    ref['abs_pos'],
                ], axis=1
            ).drop(
                ['space','start','end','gc','map'],
                axis=1
            ).dropna().astype( {'reads':np.int32, 'ideal':np.int8, 'state':np.int8} )

            del obj_bin, obj_seg
        finally:
            # Remove pybedtools' temporary files on success and on failure.
            tmpCleanUp()

    # Only these three columns are passed on; return `data` itself to inspect
    # the remaining ones (reads, cor.*, state, abs_pos, ...).
    return data[['copy','median','ideal']]


# In[ ]:


def bin2seg_DNAcopy(binSeries, ALPHA=1e-4, NPERM=10000, UNDOSD=1, MINWD=5):
    """Segment per-bin values with DNAcopy (CBS) and return one value per bin.

    Parameters
    ----------
    binSeries : pandas.Series
        Per-bin values indexed by 'chr:start:end' labels.  The chromosome
        part must be 'chr' followed by a number, 'X' or 'Y' (X and Y are
        passed to DNAcopy as 23 and 24).
    ALPHA, NPERM, UNDOSD, MINWD
        Forwarded to ``DNAcopy.segment`` as ``alpha``, ``nperm``, ``undo_SD``
        and ``min_width``.

    Returns
    -------
    pandas.Series
        The ``seg`` value of every bin, indexed like ``binSeries``.  Bins
        without an overlapping segment receive the mean ``seg`` of their
        group (autosomes, chrX, chrY).

    Notes
    -----
    Exceptions from R or bedtools are not caught here.  Temporary files made
    by pybedtools are removed and the warning filters are restored even when
    a step fails.
    """
    # Silence the chatty warning categories only for the duration of this
    # call; catch_warnings() puts the previous filters back on exit.
    with warnings.catch_warnings():
        for noisy in (FutureWarning, RRuntimeWarning, ResourceWarning):
            warnings.simplefilter(action='ignore', category=noisy)

        try:
            # Split the index labels once: column 0 = chromosome,
            # 1 = start, 2 = end.
            coords = pd.Series(binSeries.index).str.split(':',expand=True)
            dChr = coords[0].str.replace('chr','').replace('X','23').replace('Y','24').astype(np.int8)
            pos1 = coords[1].astype(np.int32)

            # Call CBS
            obj_CNA = DNAcopy.smooth_CNA( DNAcopy.CNA(binSeries, dChr, pos1, sampleid='ID') )
            obj_seg = DNAcopy.segment(obj_CNA, alpha=ALPHA, nperm=NPERM, undo_splits='sdundo', undo_SD=UNDOSD, min_width=MINWD, verbose=0)

            # Segment table as a pandas DataFrame (second element of the
            # converted result), with chromosome names mapped back to
            # 'chrN' / 'chrX' / 'chrY'.
            pd_segs = pd2ri.ri2py(pd2ri.ri2py(obj_seg)[1])
            pd_segs['chr'] = 'chr'+pd_segs['chrom'].map(str).replace('23','X').replace('24','Y')

            # Make every seg.mean distinct by adding a tiny random offset
            # (between EPS and 1e-5) to all of them, at most 10 times.
            for _ in range(10):
                if pd_segs['seg.mean'].unique().shape[0] == pd_segs.shape[0]:
                    break
                pd_segs['seg.mean'] += ( np.random.random(size=pd_segs.shape[0])/1e5 ).clip(EPS,1e-5)

            # Get the segment value of each bin through bedtools, then put
            # the original bin labels back as index.
            obj_seg = BedTools.from_dataframe(pd_segs[['chr','loc.start','loc.end','seg.mean']])
            obj_bin = BedTools.from_dataframe(coords.rename(columns={0:'Chr',1:'start',2:'end'}).assign(bin_log2=binSeries.values))

            # NOTE(review): set_index needs exactly one output row per bin,
            # i.e. no bin may overlap two segments -- confirm.
            data = BedTools.intersect(self=obj_bin, b=obj_seg, wao=True).to_dataframe(names=['bin_log2','seg'],usecols=[3,7]).set_index(binSeries.index)
            # Anything that is not a number becomes NaN.
            data['seg'] = pd.to_numeric(data['seg'], errors='coerce')

            # Fill missing bins with the mean of their group.  Autosomes are
            # the labels with a digit right after 'chr'.
            for pattern in (r'chr\d', r'chrX', r'chrY'):
                in_group = data.index.str.contains(pattern)
                group_seg = data.loc[in_group, 'seg']
                data.loc[in_group, 'seg'] = group_seg.fillna(group_seg.mean())
        finally:
            # Remove pybedtools' temporary files on success and on failure.
            tmpCleanUp()

    return data['seg']


# In[ ]:


def worker(sample_file):
    """Process one cell: HMMcopy segmentation followed by CBS on its 'copy' column.

    Parameters
    ----------
    sample_file : str
        Path of the cell's read-count wig file.

    Returns
    -------
    pandas.DataFrame
        The ``call_HMM`` result with the ``bin2seg_DNAcopy`` series appended
        as an extra column, or an empty DataFrame when processing failed, so
        that one bad cell does not abort the whole run.
    """
    try:
        result_df = call_HMM(sample_file)
        cbs_seg   = bin2seg_DNAcopy(result_df['copy'])
        result_df = pd.concat( [ result_df, cbs_seg ], axis=1 )
    except Exception:
        # Only ordinary errors are absorbed; KeyboardInterrupt / SystemExit
        # still propagate.  The traceback is printed so that the cause of
        # the failure is not lost.
        import traceback
        print ('Error in Sample: '+ sample_file.replace('.' + window + '.wig','').replace('/','_'))
        traceback.print_exc()
        result_df = pd.DataFrame()
    return result_df


# In[ ]:


# Candidate cell folders: directories of the working directory whose name
# contains '_m' and does not contain 'link'.
cell_folders = sorted(
    folder for folder in os.listdir('.')
    if os.path.isdir(folder) and ('_m' in folder) and ('link' not in folder)
)

# ### Detect samples
# Sample name = text before the first '_' of the folder name.  Built with
# plain Python so that an empty directory yields an empty array instead of
# an error.
samples = np.array(sorted({folder.split('_')[0] for folder in cell_folders}), dtype=object)

# ### Select sample & Detect cells
# NOTE(review): substring match -- sample 'A1' also selects 'A10_...'.
fd = [folder for folder in cell_folders if sample in folder]

files = []
names = []

wig_suffix = '.' + window + '.wig'
for f in fd:
    # List each folder once so that files[k] and names[k] always refer to
    # the same file.
    for i in sorted(os.listdir(f)):
        if wig_suffix not in i:
            continue
        wig_path = os.path.join(f, i)
        files.append(wig_path)
        # Cell name: the path without its trailing '.<window>.wig'
        # (len(window)+5 characters), with '/' turned into '_'.
        # NOTE(review): assumes the file name ends with that suffix.
        names.append(wig_path[:-(len(window)+5)].replace('/','_'))


# In[ ]:



###############################################################################

# ### Start RUN !!!

# Without input there is nothing to run (and pd.concat would fail on an
# empty list with an unhelpful message).
if not files:
    print ('No *.' + window + '.wig file found for sample: ' + sample)
    sys.exit(1)

# One worker call per cell, at most 32 processes.  imap yields results in
# input order, so results[k] belongs to names[k].  The with-block terminates
# the pool if anything goes wrong; on success it is closed and joined first.
with Pool(min(32, len(files))) as pool:
    results = list( tqdm( pool.imap(worker, files), total=len(files) ) )
    pool.close()
    pool.join()

# worker() returns an empty frame for a failed cell; stop if every cell failed.
if all(res.empty for res in results):
    print ('All cells failed for sample: ' + sample)
    sys.exit(1)

# Side-by-side table with (cell name, value type) column pairs, aligned to
# the reference bins.
data_df = pd.concat(results,axis=1,keys=names,sort=False).reindex(ref.index)

###############################################################################

# In[ ]:

# Transform log2 to linear (2 ** (value + 1)) and drop the bins that are
# missing in every cell.
linear_df = np.power(2, data_df + 1).dropna(how='all')

# Transpose so that the two column levels become ordinary columns, give them
# proper names and relabel the segment types.
per_cell = linear_df.T.reset_index()
per_cell = per_cell.rename(columns={'level_0': 'Cell_ID', 'level_1': 'Type'})
per_cell = per_cell.replace({'Type': {'median': 'HMMseg', 'seg': 'CBSseg'}})

# Back to bins-by-(cell, type): restore the MultiIndex, transpose, name the
# row index and sort the columns.
out_df = per_cell.set_index(['Cell_ID', 'Type']).T
out_df = out_df.rename_axis('Pos_bin')
out_df = out_df.sort_index(axis=1, level=0)
del linear_df, per_cell

# Reset ideal flag: it went through 2 ** (x + 1) with everything else, so
# v / 2 - 1 maps 2 -> 0 and 4 -> 1.
ideal_idx = pd.IndexSlice[:, 'ideal']
out_df.loc[:, ideal_idx] = out_df.loc[:, ideal_idx] / 2 - 1

# Save
out_file = sample + '.data_seg.' + window + '.pkl'
out_df.to_pickle(out_file)

print ('Save to file: ' + out_file)


# ---

