#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
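Subtract a local background estimate from each cell of a MERFISH
cell-by-gene table and write the result to updated_cells_parallel.csv.

Created on Fri Feb 25 23:06:59 2022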

@author: exx
"""

import h5py
import pandas as pd
from shapely.geometry import Point
from shapely.geometry import Polygon
import numpy as np
from os import listdir
import os
from os.path import isfile, join
from ismember import ismember
import multiprocessing as mp
from sklearn.mixture import GaussianMixture

pd.options.mode.chained_assignment = None
# Every input of this run lives under one region directory of the MERFISH export.
region_dir = '/cncm/Lab/MERFISH_Imaging data/Kim2_202112171955_12172021TREM2-5x12Mo300GP_VMSC00101/region_0/'
mypath = region_dir + 'cell_boundaries/'
files = [f for f in listdir(mypath) if isfile(join(mypath, f))]

mask = np.load('/home/exx/Downloads/mask.npy')
transcripts = pd.read_csv(region_dir + 'detected_transcripts.csv')
cells = pd.read_csv(region_dir + '5x12Mo300Gp_region_0_cell_by_gene.csv')
metadata = pd.read_csv(region_dir + '5x12Mo300Gp_region_0_cell_metadata.csv')
cell_names = list(cells['Unnamed: 0'])
metadata_names = list(metadata['Unnamed: 0'])

# Integer pixel (round-to-nearest) that each transcript falls on.
transcripts['global_x_round'] = np.round(transcripts['global_x']).astype('int')
transcripts['global_y_round'] = np.round(transcripts['global_y']).astype('int')

# Number of transcripts on each occupied pixel (MultiIndex (x, y) -> count).
pixel_counts = transcripts.groupby(['global_x_round', 'global_y_round']).size()

# Fit a 2-component Gaussian mixture to the per-pixel counts. The density
# threshold is the smallest count in [0, 1000) that the mixture assigns to a
# different component than a count of 0. random_state makes the fit, and
# therefore the resulting mask, reproducible between runs.
gm = GaussianMixture(2, random_state=0).fit(pixel_counts.to_numpy().reshape(-1, 1))
search = np.arange(1000).reshape(-1, 1)
labels = gm.predict(search)
switch = np.flatnonzero(labels != labels[0])
if switch.size:
    threshold = int(switch[0])
else:
    # Same fallback the original loop-with-break produced (the last value
    # searched), but reported instead of applied silently.
    threshold = int(search[-1, 0])
    print('warning: mixture never switched component below ' + str(len(search))
          + ' counts; using threshold ' + str(threshold))

# Pixels at or above the threshold are added to the mask. column_stack keeps
# the (n, 2) shape even when n == 0, so the concatenation below cannot fail.
dense = pixel_counts.index[pixel_counts.to_numpy() >= threshold]
dense_pixels = np.column_stack([dense.get_level_values(0), dense.get_level_values(1)])
mask = np.unique(np.concatenate((mask, dense_pixels), axis=0), axis=0)
del pixel_counts
del dense
del dense_pixels










def _cell_pixels(cell_group):
    """Rasterise one cell's boundary polygons to unique integer pixels.

    Parameters
    ----------
    cell_group : h5py.Group
        ``featuredata/<cell_id>`` of a cell-boundary file. Each child
        (presumably one per z-plane -- TODO confirm) is read at
        ``['p_0']['coordinates'][0]`` as a ring of vertices; columns 0 and 1
        are taken as x and y.

    Returns
    -------
    numpy.ndarray
        Unique ``(x, y)`` integer pixels, shape ``(n, 2)``. ``n`` is 0 when no
        child could be read. A pixel is kept if ``Polygon.contains`` is true
        for it (within the vertex bounding box padded by 1), or if it is one of
        the four floor/ceil corners around a vertex.
    """
    pieces = []
    for child in cell_group.keys():
        try:
            # Read the ring once; indexing the h5py dataset re-reads the file.
            ring = np.asarray(cell_group[child]['p_0']['coordinates'][0])
            poly = Polygon(ring)
            lo = np.min(ring, 0)
            hi = np.max(ring, 0)
            xs = range(int(np.floor(lo[0])) - 1, int(np.ceil(hi[0])) + 1)
            ys = range(int(np.floor(lo[1])) - 1, int(np.ceil(hi[1])) + 1)
            inside = [(px, py) for px in xs for py in ys if poly.contains(Point(px, py))]
            fx = np.floor(ring[:, 0]).astype(int)
            fy = np.floor(ring[:, 1]).astype(int)
            cx = np.ceil(ring[:, 0]).astype(int)
            cy = np.ceil(ring[:, 1]).astype(int)
            # (floor, floor), (floor, ceil), (ceil, floor), (ceil, ceil) per vertex.
            corners = np.column_stack([
                np.concatenate([fx, fx, cx, cx]),
                np.concatenate([fy, cy, fy, cy]),
            ])
        except Exception:
            # Best effort: an unreadable or degenerate plane contributes no pixels.
            continue
        pieces.append(np.array(inside, dtype=int).reshape(-1, 2))
        pieces.append(corners)
    if not pieces:
        return np.empty((0, 2), dtype=int)
    return np.unique(np.concatenate(pieces, axis=0), axis=0)


def _subtract_cell(cells, row, pixels, tx, ty, genes, mask, margin, min_background_pixels):
    """Subtract one cell's background estimate from ``cells.loc[row]`` in place.

    ``pixels`` is the cell's ``(n, 2)`` pixel array. ``tx``, ``ty`` and
    ``genes`` are parallel arrays of transcript x, y and gene name.
    """
    lo = pixels.min(axis=0) - margin
    hi = pixels.max(axis=0) + margin
    near = (tx >= lo[0]) & (tx <= hi[0]) & (ty >= lo[1]) & (ty <= hi[1])
    if not near.any():
        return
    # Integer pixel each nearby transcript lies on (floor, matching the mask grid).
    coords = np.column_stack([np.floor(tx[near]).astype(int), np.floor(ty[near]).astype(int)])

    in_window = ((mask[:, 0] >= lo[0]) & (mask[:, 0] <= hi[0])
                 & (mask[:, 1] >= lo[1]) & (mask[:, 1] <= hi[1]))
    local_mask = mask[in_window]
    if local_mask.shape[0]:
        on_mask = np.asarray(ismember(coords, local_mask, 'rows')[0], dtype=bool)
    else:
        on_mask = np.zeros(coords.shape[0], dtype=bool)
    background = ~on_mask

    # Distinct pixels carrying unmasked transcripts; too few -> no estimate.
    n_background_pixels = np.unique(coords[background], axis=0).shape[0]
    if n_background_pixels <= min_background_pixels:
        return

    gene_counts = pd.Series(genes[near][background]).value_counts()
    # Scale each gene's unmasked count by (cell pixels / background pixels).
    expected = np.round(pixels.shape[0] * gene_counts / n_background_pixels, decimals=0)
    for gene, amount in expected.items():
        # Genes with no column in the cell table are skipped individually
        # instead of aborting the remaining genes for this cell.
        if gene in cells.columns:
            cells.loc[row, gene] = max(0, cells.loc[row, gene] - amount)


def background_subtract(filename, mask, transcripts, cells, margin=10, min_background_pixels=100):
    """Subtract a local background estimate from every cell in one boundary file.

    For each cell id under ``featuredata`` in ``filename``, the cell's pixels
    are computed from its boundary polygons. Transcripts inside the cell's
    pixel bounding box (padded by ``margin``) are split into those on a
    ``mask`` pixel and the rest. If the rest occupy more than
    ``min_background_pixels`` distinct pixels, the following is subtracted
    from that cell's row in ``cells``, clamped at 0:

        per-gene unmasked count * cell pixels / background pixels (rounded)

    Parameters
    ----------
    filename : str
        Path to an HDF5 cell-boundary file with a ``featuredata`` group keyed
        by cell id.
    mask : numpy.ndarray
        ``(N, 2)`` integer ``(x, y)`` pixels excluded from the background.
    transcripts : pandas.DataFrame
        Needs ``global_x``, ``global_y`` and ``gene`` columns.
    cells : pandas.DataFrame
        Cell id in ``'Unnamed: 0'`` plus one count column per gene. It is
        modified in place.
    margin : int, default 10
        Padding (pixels) around the cell's bounding box.
    min_background_pixels : int, default 100
        The estimate is applied only above this many distinct background
        pixels.

    Returns
    -------
    pandas.DataFrame
        ``cells`` (the same object, updated).

    Cells with no readable boundary, or whose id is not in ``cells``, are
    skipped. Any other per-cell error is printed and that cell is skipped.
    """
    with h5py.File(filename, 'r') as h5:
        features = h5['featuredata']
        cell_pixels = {cell_id: _cell_pixels(features[cell_id]) for cell_id in features.keys()}

    # Label of the first row for each cell id. Using the label (not the row
    # position) keeps .loc correct for any index.
    row_of = {}
    for label, cell_id in zip(cells.index, cells['Unnamed: 0']):
        row_of.setdefault(cell_id, label)

    tx = transcripts['global_x'].to_numpy()
    ty = transcripts['global_y'].to_numpy()
    genes = transcripts['gene'].to_numpy()

    for cell_id, pixels in cell_pixels.items():
        if pixels.shape[0] == 0 or cell_id not in row_of:
            continue
        try:
            _subtract_cell(cells, row_of[cell_id], pixels, tx, ty, genes, mask,
                           margin, min_background_pixels)
        except Exception as err:
            print('background_subtract: skipped cell {} in {}: {!r}'.format(cell_id, filename, err))
    return cells
        
        


    
                



# Keep only the cell-boundary files (names ending in 'hdf5').
files = [f for f in files if os.path.isfile(mypath + f) and f.endswith('hdf5')]
if not files:
    raise FileNotFoundError('No hdf5 cell-boundary files found in ' + mypath)

# Pool size and batch size are the same: each batch submits one file per worker.
n_workers = 40
new_cells = []

# NOTE(review): this module-level code is not under an
# `if __name__ == "__main__":` guard, so it relies on the 'fork' start method
# (the Linux default). Under 'spawn', workers would re-run the script.
with mp.Pool(n_workers) as pool:
    for batch_idx, start in enumerate(range(0, len(files), n_workers)):
        print(str(batch_idx / (len(files) / n_workers) * 100) + '% complete')

        input_vars = []
        for name in files[start:start + n_workers]:
            filename = mypath + name
            # Only the cell ids are needed here; close the file straight away.
            with h5py.File(filename, 'r') as h5:
                keys = list(h5['featuredata'].keys())

            # Vectorised membership instead of scanning the key list per row.
            file_cells = cells[cells['Unnamed: 0'].isin(keys)].reset_index(drop=True)
            file_meta = metadata[metadata['Unnamed: 0'].isin(keys)]

            # Transcripts within 30 units of these cells' metadata bounding box.
            maxs = np.array([np.max(file_meta['max_x']), np.max(file_meta['max_y'])]) + 30
            mins = np.array([np.min(file_meta['min_x']), np.min(file_meta['min_y'])]) - 30
            in_box = ((transcripts['global_x'] < maxs[0]) & (transcripts['global_x'] > mins[0])
                      & (transcripts['global_y'] < maxs[1]) & (transcripts['global_y'] > mins[1]))

            input_vars.append((filename, mask, transcripts[in_box], file_cells))

        results = pool.starmap_async(background_subtract, input_vars).get()
        new_cells.append(pd.concat(results, axis=0))

cells = pd.concat(new_cells, axis=0)



"""
for k in range(len(files)):
    if files[k][-4:]=='hdf5':
        if k%5==0:
            print(str(k/len(files)*100)+'% done')
        try:
            cells=background_subtract(mypath+files[k],mask,transcripts,cells)
        except:
            print('failure')
            print(files[k])
"""
cells.index=cells['Unnamed: 0']
cells=cells.drop(columns='Unnamed: 0')
cells.to_csv('updated_cells_parallel.csv')

