#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Feb  3 23:26:32 2022

@author: exx
"""

import sys

# sys.path.insert(0,'/cncm/Lab/Kevin_public/merfishTest/analysis_local_round1')
# sys.path.insert(0,'/cncm/Lab/Kevin_public/merfishTest/SingleCellRoutines')

# sys.path.insert(0, '../SingleCellRoutines')
import stlearn as st
import os
#from dca.api import dca
import scanpy as sc
import tqdm
import time
import pandas as pd
import numpy as np
from scipy import stats
from scipy.io import mmwrite
from scipy import sparse
from __init__scr import *
from __init__plots import *
import utils
import powerplot
import h5py
import json
import tempfile
from scipy.sparse import csc_matrix
from PIL import Image
from matplotlib import path


def _load_sample_tables(data_dir, samples):
    """Read the count matrix and cell metadata of every sample from ``data_dir``.

    For each sample, ``{sample}_cell_by_gene.csv`` and
    ``{sample}_cell_metadata.csv`` are read with their first column as the
    index. Blank barcode columns (names starting with ``Blank-`` or
    ``Blank_``) are removed from the counts, and the per-cell sum of the
    remaining counts is stored as ``meta['cell_cov']``.

    Returns
    -------
    metas, allcounts : dict
        Metadata and count DataFrames keyed by sample name.

    Raises
    ------
    ValueError
        If a sample's gene columns differ from those of the first sample read.
    """
    ti = time.time()
    metas = {}
    allcounts = {}
    genes = None
    for sample in samples:
        print(sample)
        fdata = os.path.join(data_dir, '{}_cell_by_gene.csv'.format(sample))
        fmeta = os.path.join(data_dir, '{}_cell_metadata.csv'.format(sample))

        meta = pd.read_csv(fmeta, index_col=0)
        data = pd.read_csv(fdata, index_col=0)
        # Blank barcodes are named "Blank-N", or "Blank_N" after MATLAB processing.
        data = data.drop(columns=data.filter(regex='^Blank[-_]').columns)
        meta['cell_cov'] = data.sum(axis=1)

        # All samples must share the same gene panel, in the same column order.
        if genes is None:
            genes = data.columns
        elif not data.columns.equals(genes):
            raise ValueError(
                'gene columns of sample {} do not match those of the first sample'.format(sample)
            )

        metas[sample] = meta
        allcounts[sample] = data
        print(time.time() - ti)
    return metas, allcounts


def _filter_and_normalize(meta, counts, minVol, maxVol, minTranscript, maxTranscript):
    """Apply volume/coverage QC and normalization to one sample (inputs are not modified).

    Steps:

    1. Drop cells whose ``volume`` is below ``minVol`` or above ``maxVol``.
    2. Divide counts and ``cell_cov`` by each cell's volume relative to the
       median volume; the latter is stored as ``cell_cov_norm``.
    3. Drop cells whose ``cell_cov_norm`` is below/above the low/high threshold.
    4. Rescale each cell's counts to the median non-zero ``cell_cov_norm``.
    5. Drop cells with any NaN in their counts or metadata.

    A threshold of ``-1`` selects a per-sample default: ``0.1 *`` / ``3 *``
    the median volume for ``minVol`` / ``maxVol``, and the 2nd / 98th
    percentile of ``cell_cov_norm`` for ``minTranscript`` / ``maxTranscript``.
    An explicit ``minTranscript`` / ``maxTranscript`` is divided by the same
    per-cell volume factor as ``cell_cov_norm``, so (for positive volumes) it
    effectively thresholds the raw ``cell_cov``.

    Returns
    -------
    meta_p, counts_p : pandas.DataFrame
        Filtered metadata (with a ``cell_cov_norm`` column) and normalized
        counts, sharing the same index.
    """
    meta_p = meta.copy()
    counts_p = counts.copy()
    print(meta_p.shape, counts_p.shape)

    # step1 - filter cells by volume. Defaults are derived from THIS sample's
    # median; the caller's -1 sentinels are never overwritten, so later samples
    # do not inherit thresholds computed from an earlier sample.
    med_vol = meta_p['volume'].median()
    min_vol = 0.1 * med_vol if minVol == -1 else minVol
    max_vol = 3 * med_vol if maxVol == -1 else maxVol
    too_small = meta_p['volume'] < min_vol
    too_large = meta_p['volume'] > max_vol
    print("remove {} cells".format(too_small.sum()))
    print("remove {} cells".format(too_large.sum()))
    # .copy(): a new column is added below, so avoid writing into a slice.
    meta_p = meta_p[~(too_small | too_large)].copy()
    counts_p = counts_p.loc[meta_p.index]
    print(meta_p.shape, counts_p.shape)

    # step2 - normalize counts by cell volume relative to the median volume
    vol_factor = meta_p['volume'] / meta_p['volume'].median()
    meta_p['cell_cov_norm'] = meta_p['cell_cov'] / vol_factor
    counts_p = counts_p.divide(vol_factor, axis=0)

    # step3 - remove coverage outliers
    if minTranscript == -1:
        low = np.percentile(meta_p['cell_cov_norm'], 2)
    else:
        low = minTranscript / vol_factor
    if maxTranscript == -1:
        high = np.percentile(meta_p['cell_cov_norm'], 98)
    else:
        high = maxTranscript / vol_factor
    too_low = meta_p['cell_cov_norm'] < low
    too_high = meta_p['cell_cov_norm'] > high
    print("remove {} cells".format(too_low.sum()))
    print("remove {} cells".format(too_high.sum()))
    meta_p = meta_p[~(too_low | too_high)]
    counts_p = counts_p.loc[meta_p.index]
    print(meta_p.shape, counts_p.shape)

    # step4 - rescale to the median coverage. Zero-coverage cells are excluded
    # from the median (050322: sometimes more than half of the cells have no
    # expression, which would make the median 0).
    cell_cov_norm = meta_p['cell_cov_norm']
    med_cov = cell_cov_norm[cell_cov_norm > 0].median()
    counts_p = counts_p.divide(cell_cov_norm / med_cov, axis=0)
    print("median norm cov {}".format(med_cov))

    # step5 - drop rows containing NaN in either table (e.g. a zero-coverage
    # cell presumably yields 0/0 = NaN in step4).
    has_nan = counts_p.isnull().any(axis=1) | meta_p.isnull().any(axis=1)
    counts_p = counts_p[~has_nan]
    meta_p = meta_p[~has_nan]

    print(" ")
    print("Final cell counts, # of genes:")
    print(meta_p.shape[0], ",", counts_p.shape[1])
    return meta_p, counts_p


def _build_anndata(sample, meta_p, counts_p):
    """Wrap one sample's normalized counts in an AnnData prepared for stlearn.

    Cell positions come from ``meta_p['center_x']`` / ``meta_p['center_y']``
    (``imagerow`` / ``imagecol``; ``obsm['spatial']`` is ``[center_y, center_x]``).
    Genes with fewer than 3 counts and cells with fewer than 3 genes are
    removed, the filtered matrix is kept in ``adata.raw``, and ``X`` is then
    log1p-transformed and scaled with stlearn. ``uns['spatial']['MERSEQ']``
    receives a blank white square image large enough to hold all coordinates.

    Raises
    ------
    ValueError
        If no valid (non-NaN) spatial coordinate remains.
    """
    adata = sc.AnnData(
        counts_p.values,
        obs=pd.DataFrame(index=[str(name) for name in counts_p.index]),
        var=pd.DataFrame(index=[str(gene) for gene in counts_p.columns]),
    )
    adata.obs['sample'] = [sample] * adata.n_obs

    center_x = meta_p['center_x'].to_numpy()
    center_y = meta_p['center_y'].to_numpy()
    adata.obs['imagecol'] = center_y
    adata.obs['imagerow'] = center_x
    adata.obsm['spatial'] = np.column_stack((center_y, center_x))

    # Assign positionally: obs_names were stringified above, so assigning the
    # Series directly would align on index and yield NaN whenever meta_p's
    # index is not already made of strings.
    adata.obs['fov'] = meta_p['fov'].to_numpy()
    adata.obs['volume'] = meta_p['volume'].to_numpy()
    adata.obs['cell_cov'] = meta_p['cell_cov'].to_numpy()
    adata.obs['cell_cov_norm'] = meta_p['cell_cov_norm'].to_numpy()

    st.pp.filter_genes(adata, min_counts=3)
    sc.pp.filter_cells(adata, min_genes=3)

    adata.raw = adata
    st.pp.log1p(adata)
    st.pp.scale(adata)

    # Blank background image sized to the largest coordinate plus a 10% margin.
    max_size = np.nanmax([adata.obs["imagecol"].max(), adata.obs["imagerow"].max()])
    if np.isnan(max_size):
        print([adata.obs["imagecol"].max(), adata.obs["imagerow"].max()])
        print('error {}'.format(sample))
        raise ValueError('no valid spatial coordinates for sample {}'.format(sample))
    max_size = int(max_size + 0.1 * max_size)
    imgarr = np.array(Image.new("RGB", (max_size, max_size), (255, 255, 255)))

    library_id = "MERSEQ"
    quality = 'hires'
    scale = 2000 / np.max(adata.obsm["spatial"])
    adata.uns["spatial"] = {
        library_id: {
            "images": {quality: imgarr},
            "use_quality": quality,
            "scalefactors": {
                "tissue_" + quality + "_scalef": scale,
                "spot_diameter_fullres": 50,
            },
        }
    }
    return adata


def explore_and_preproc(DATA_DIR, samples, use_dca=True, minVol=-1, maxVol=-1, minTranscript=-1, maxTranscript=-1):
    """Load, QC-filter and normalize MERFISH samples and build one AnnData per sample.

    Parameters
    ----------
    DATA_DIR : str
        Directory containing ``{sample}_cell_by_gene.csv`` and
        ``{sample}_cell_metadata.csv`` for every sample. The metadata must
        provide ``volume``, ``center_x``, ``center_y`` and ``fov`` columns.
        ``data/`` and ``figures/`` subdirectories are created if missing.
    samples : iterable of str
        Sample name prefixes of the CSV files.
    use_dca : bool
        Currently ignored.
    minVol, maxVol : float
        Cell-volume bounds; ``-1`` means ``0.1 *`` / ``3 *`` the median volume
        of each sample.
    minTranscript, maxTranscript : float
        Coverage bounds; ``-1`` means the 2nd / 98th percentile of the
        volume-normalized coverage of each sample.

    Returns
    -------
    scanpy_output : list of AnnData
        One processed AnnData per sample (see ``_build_anndata``).
    meta_output, count_output : list of pandas.DataFrame
        Filtered metadata and normalized counts per sample, before the
        AnnData-level gene/cell filtering.

    Side effects: deletes ``data/processed_merfish_ad_mouse.hdf5`` if it exists
    and writes ``palette_admouse.json`` (currently an empty JSON object).
    """
    output_dir = os.path.join(DATA_DIR, 'data')
    figure_dir = os.path.join(DATA_DIR, 'figures')
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(figure_dir, exist_ok=True)

    metas, allcounts = _load_sample_tables(DATA_DIR, samples)

    # NOTE(review): a stale output file is removed, but nothing in this
    # function writes it back; confirm whether a caller is expected to.
    output = os.path.join(output_dir, 'processed_merfish_ad_mouse.hdf5')
    if os.path.isfile(output):
        os.remove(output)

    scanpy_output = []
    meta_output = []
    count_output = []
    for sample, meta in metas.items():
        meta_p, counts_p = _filter_and_normalize(
            meta, allcounts[sample], minVol, maxVol, minTranscript, maxTranscript
        )
        meta_output.append(meta_p)
        count_output.append(counts_p)
        scanpy_output.append(_build_anndata(sample, meta_p, counts_p))

    # Per-sample colours are not assigned yet, so the saved palette is empty.
    palette = {}
    with open(os.path.join(DATA_DIR, 'palette_admouse.json'), 'w') as f:
        json.dump(palette, f)
    return scanpy_output, meta_output, count_output

def assign_cluster_labels_ind_samples(sample, concat, batch_num):
    """Copy embeddings and clustering results from a combined AnnData back onto one sample.

    ``concat`` is matched to ``sample`` through its obs_names, which are
    expected to look like ``<cell name>-<batch_num>`` (NOTE(review):
    presumably the suffix added by ``AnnData.concatenate``; confirm against
    the caller). Only cells of batch ``batch_num`` that also exist in
    ``sample`` are kept.

    Parameters
    ----------
    sample : AnnData
        Per-sample object; it is not modified (a subset copy is returned).
    concat : AnnData
        Combined object holding ``uns['pca']``, ``uns['umap']``,
        ``uns['louvain']``, ``obsm['X_pca']`` and ``obsm['X_umap']``.
    batch_num : str or int
        Batch suffix of this sample's cells in ``concat``.

    Returns
    -------
    AnnData
        Copy of ``sample`` restricted to the shared cells (in ``sample``
        order), carrying those ``uns``/``obsm`` entries and, when ``concat``
        has it, ``obs['louvain']``.

    Raises
    ------
    KeyError
        If one of the required ``uns``/``obsm`` entries is missing from ``concat``.
    """
    suffix = '-' + str(batch_num)
    # Strip only this batch's trailing suffix, so cell names that themselves
    # contain '-' and cells belonging to other batches are handled correctly.
    concat_name_of = {
        name[:-len(suffix)]: name
        for name in concat.obs_names
        if name.endswith(suffix)
    }
    # Deterministic order (that of `sample`), unlike a set intersection.
    shared = [name for name in sample.obs_names if name in concat_name_of]

    # .copy() materializes the subsets, so writing obsm below does not act on a view.
    sample = sample[shared, :].copy()
    # Copy only the matching rows: cheaper than copying all of `concat`, and the
    # transferred uns entries stay independent of the caller's object.
    concat_sub = concat[[concat_name_of[name] for name in shared], :].copy()

    for key in ('pca', 'umap', 'louvain'):
        sample.uns[key] = concat_sub.uns[key]
    for key in ('X_pca', 'X_umap'):
        sample.obsm[key] = concat_sub.obsm[key]
    # The per-cell cluster labels live in obs; uns['louvain'] presumably holds
    # only the run parameters written by the clustering call.
    if 'louvain' in concat_sub.obs:
        sample.obs['louvain'] = concat_sub.obs['louvain'].values
    return sample
    
"""
DATA_DIR = '/cncm/Lab/Kevin_public/merfishTest/background_subtract_test/'

samples = [
    
    '5x12Mo300Gp_region_0'
   
    
]
explore_and_preproc(DATA_DIR=DATA_DIR,use_dca=False,samples=samples,minVol=100,maxVol=5000)
    """