import numpy as np
import stlearn as st
import scanpy as sc
from PIL import Image
def _broadcast_labels(labels, n_obs):
    """Return *labels* in a form that can be assigned as one obs column.

    A string, or a sequence holding exactly one entry, is repeated ``n_obs``
    times. Anything else is returned untouched and is expected to already
    hold one entry per observation.
    """
    if isinstance(labels, str):
        return [labels] * n_obs
    if len(labels) == 1:
        # Unwrap the single entry so every cell receives the label itself
        # rather than a one-element container.
        return [list(labels)[0]] * n_obs
    return labels


def Ann_data_gen(counts_p, meta_p, sample, samples_annot=None):
    """Build an AnnData object with spatial metadata from counts and cell metadata.

    Args:
        counts_p: Count table exposing ``.values``, ``.index`` and ``.keys()``
            (presumably a pandas DataFrame with cells as rows and genes as
            columns -- confirm against the caller).
        meta_p: Per-cell metadata providing the keys ``center_x``,
            ``center_y``, ``fov``, ``volume`` and ``cell_cov_norm``. Values
            are assigned positionally, so rows must be in the same order as
            ``counts_p``.
        sample: Sample label. A string or a one-element sequence is applied
            to every cell; a longer sequence is used as per-cell labels.
        samples_annot: Optional annotation stored in ``obs['gen']``, with the
            same broadcasting rules as ``sample``. When omitted or empty,
            ``obs['gen']`` is a copy of ``obs['sample']``.

    Returns:
        A tuple ``(adata, idx_del, idx_del_num)``: the AnnData object, the
        obs labels that disappeared between the pre- and post-filter
        snapshots, and their positions in the pre-filter index. With the
        filtering calls disabled both lists are empty.

    Raises:
        ValueError: If no finite ``center_x`` / ``center_y`` value exists, so
            the placeholder image cannot be sized.
    """
    # NOTE(review): obs/var are passed as plain lists, exactly as before;
    # confirm the installed anndata version accepts that form.
    adata = sc.AnnData(
        counts_p.values,
        obs=list(counts_p.index),
        var=list(counts_p.keys()),
    )
    adata.obs_names = [str(label) for label in counts_p.index]
    adata.var_names = list(counts_p.keys())

    n_obs = len(adata.obs_names)
    adata.obs['sample'] = _broadcast_labels(sample, n_obs)

    # Convert the coordinates once and reuse them for obs and obsm.
    center_x = np.array(meta_p['center_x'])
    center_y = np.array(meta_p['center_y'])
    adata.obs['imagecol'] = center_x
    adata.obs['imagerow'] = center_y
    adata.obsm['spatial'] = np.hstack(
        (center_x.reshape((-1, 1)), center_y.reshape((-1, 1)))
    )
    adata.obs['fov'] = np.array(meta_p['fov'])
    adata.obs['volume'] = np.array(meta_p['volume'])
    adata.obs['cell_cov'] = np.array(meta_p['cell_cov_norm'])

    if samples_annot is None or len(samples_annot) == 0:
        adata.obs['gen'] = adata.obs['sample']
    else:
        adata.obs['gen'] = _broadcast_labels(samples_annot, n_obs)

    idx_prefilter = adata.obs.index

    # Filtering is currently disabled; the bookkeeping below is kept so the
    # return value stays meaningful if these calls are re-enabled.
    #st.pp.filter_genes(adata,min_counts=3)
    #sc.pp.filter_cells(adata,min_genes=3)

    idx_postfilter = adata.obs.index
    idx_del = list(idx_prefilter.difference(idx_postfilter))

    idx_del_num = []
    if idx_del:
        # Map each label to its first position once instead of scanning the
        # whole index for every removed label.
        first_position = {}
        for position, label in enumerate(idx_prefilter):
            first_position.setdefault(label, position)
        idx_del_num = [first_position[label] for label in idx_del]

    # adata.raw is assigned before the log1p / scale steps are applied.
    adata.raw = adata

    st.pp.log1p(adata)
    st.pp.scale(adata)

    max_col = adata.obs["imagecol"].dropna().max()
    max_row = adata.obs["imagerow"].dropna().max()
    max_size = np.nanmax([max_col, max_row])
    if np.isnan(max_size):
        # Previously this only printed and then failed on int(NaN) below.
        raise ValueError(
            f"no finite spatial coordinates for sample {sample!r}: "
            f"max imagecol={max_col}, max imagerow={max_row}"
        )
    # Square canvas 10% larger than the largest coordinate.
    max_size = int(max_size + 0.1 * max_size)
    image = Image.new("RGB", (max_size, max_size), (255, 255, 255, 255))
    imgarr = np.array(image)

    library_id = "MERSEQ"
    quality = 'hires'
    spot_diameter_fullres = 50
    max_coor = np.max(adata.obsm["spatial"])
    scale = 2000 / max_coor

    adata.uns["spatial"] = {
        library_id: {
            "images": {quality: imgarr},
            "use_quality": quality,
            "scalefactors": {
                "tissue_" + quality + "_scalef": scale,
                "spot_diameter_fullres": spot_diameter_fullres,
            },
        }
    }

    return adata, idx_del, idx_del_num