##
## This notebook reproduces the score-based cell type annotation benchmark on
## neuronal differentiation data and generates PHATE embeddings of ANS scores.
## To run it, first download the data from https://zenodo.org/records/4333872.
##

import scanpy as sc
import pandas as pd
import h5py
import anndata as ad
import numpy as np

import os

# Directory that all figures and cached intermediate results are written to.
outputPath = 'neuronal/'
# savefig()/to_csv()/write() do not create missing directories, so make sure the
# output directory exists before anything is written into it.
os.makedirs(outputPath, exist_ok=True)

# Input AnnData downloaded from Zenodo (see header).
data = sc.read_h5ad('all_timepoints_subsampled.h5')

# Subsample a further 40% of cells; the fixed seed keeps the subset reproducible.
fraction = 0.4
n = int(round(data.n_obs * fraction))
sampled = data.obs.sample(n=n, random_state=123).index
data = data[sampled].copy()

# Log-transform expression values in place.
# NOTE(review): no normalize_total step precedes this; presumably the Zenodo
# matrix is already library-size normalised - confirm before reusing.
sc.pp.log1p(data)

# Quick visual check of the subsample: PCA -> kNN graph (10 neighbours on
# 50 PCs) -> UMAP, coloured by time point, annotated cell type and pool.
umap_keys = ['time_point', 'celltype', 'pool_id']

sc.tl.pca(data)
sc.pp.neighbors(data, n_pcs=50, n_neighbors=10)
sc.tl.umap(data)
sc.pl.umap(data, color=umap_keys, ncols=3)


import os

# Marker genes: Wilcoxon rank test of each annotated cell type against the
# rest, keeping the top 100 genes per cell type.
sc.tl.rank_genes_groups(
    data,
    groupby="celltype",
    method="wilcoxon",
    n_genes=100,
)

result = data.uns['rank_genes_groups']
# `names` is a structured array with one field per cell type.
groups = result['names'].dtype.names

# Flatten the per-group structured arrays into one long table.
# The "pval" column holds the *adjusted* p-values; the column name is kept so
# the cached CSV keeps its existing schema.
markers_long = pd.concat(
    [
        pd.DataFrame({
            "gene": result["names"][group],
            "pval": result["pvals_adj"][group],
            "logFC": result["logfoldchanges"][group],
            "celltype": group,
        })
        for group in groups
    ],
    ignore_index=True,
)

# Reuse the cached marker table when it exists so the benchmark runs on the
# same signatures; otherwise save the freshly computed one, so a fresh run
# does not fail trying to read a file that was never written.
markers_path = os.path.join(outputPath, 'markers.csv')
if os.path.exists(markers_path):
    markers_long = pd.read_csv(markers_path, sep='\t')
else:
    markers_long.to_csv(markers_path, sep='\t', index=False)

# Signature per cell type: its marker genes ordered by adjusted p-value
# (most significant first).
marker_genes_dict = {
    celltype: df.sort_values("pval")["gene"].tolist()
    for celltype, df in markers_long.groupby("celltype")
}


#score using all methods:

import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, f1_score, jaccard_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import signaturescoring as ssc

sys.path.append('ANS_supplementary_information/data')

from constants import (DATASETS_WITH_ANNOTATIONS, BASE_PATH_RESULTS, SCORING_METHODS,
                       METHOD_WITH_GENE_POOL, VIOLIN_PLOT_CONFIG)

def score_signatures_with_all_methods(adata, signatures, use_gene_pool: bool, verbose=False):
    """Score every gene signature with every method listed in ``SCORING_METHODS``.

    For each ``(signature, method)`` pair, ``ssc.score_signature`` is called on
    ``adata`` with that method's ``sc_params``. The only change to the params is
    that ``score_name`` is rewritten to
    ``"<signature name, spaces -> underscores>_<score_name>_scores"``.

    Args:
        adata: object passed straight to ``ssc.score_signature``; presumably
            modified in place (it is also returned) - confirm against ssc.
        signatures: mapping of signature name -> gene list.
        use_gene_pool: if True, a gene pool is computed once and passed to the
            methods listed in ``METHOD_WITH_GENE_POOL``.
            NOTE(review): ``_get_gene_pool`` is not defined in this script, so
            ``True`` raises NameError unless it is provided elsewhere.
        verbose: print one progress line per (signature, method) pair.

    Returns:
        ``(added_cols, adata)``. ``added_cols`` maps each method's base
        ``score_name`` to the list of score names created for it, in
        signature order.
    """
    gene_pool = _get_gene_pool(adata, signatures) if use_gene_pool else None
    added_cols = defaultdict(list)

    for sig_name, genes in signatures.items():
        # Score-name prefix depends only on the signature, not on the method.
        col_prefix = sig_name.replace(' ', '_')
        for method_cfg in SCORING_METHODS:
            method_name = method_cfg['scoring_method']
            # Copy so the shared SCORING_METHODS config is never mutated.
            params = method_cfg['sc_params'].copy()
            if verbose:
                print(f"Scoring {sig_name} with {method_name}")

            base_name = params['score_name']
            params['score_name'] = f"{col_prefix}_{base_name}_scores"
            if use_gene_pool and method_name in METHOD_WITH_GENE_POOL:
                params['gene_pool'] = gene_pool

            ssc.score_signature(method=method_name, adata=adata, gene_list=genes, **params)
            added_cols[base_name].append(params['score_name'])

    return added_cols, adata
               

# Score every marker signature with every method, caching the scored object.
# Reuse the cached file when present; otherwise score now and save it. This
# avoids both redundant scoring and a crash on a fresh run, where the file
# read below would not exist yet.
scored_path = os.path.join(outputPath, 'scored.h5ad')
if os.path.exists(scored_path):
    data = sc.read_h5ad(scored_path)
else:
    score_signatures_with_all_methods(data, marker_genes_dict, use_gene_pool=False)
    data.write(scored_path)

# Scoring methods in plotting order; each name is matched against the method
# part of the "<celltype>_<method>_scores" column names.
style_order = ['ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG',
               'Scanpy', 'Jasmine_LH', 'Jasmine_OR', 'UCell']

import matplotlib.pyplot as plt
import seaborn as sns

# All per-signature score columns written by the scoring step end in "_scores".
score_cols = [c for c in data.obs.columns if c.endswith("_scores")]

# Cell-type labels that contain underscores get hyphenated display names, so
# "<celltype>_<method>_scores" column names can be split unambiguously below.
replacements = {
    'P_FPP': 'P-FPP',
    'P_Sert': 'P-Sert',
    'U_Neur1': 'U-Neur1',
    'U_Neur2': 'U-Neur2',
    'U_Neur3': 'U-Neur3',
}


def _with_display_prefix(label):
    """Return *label* with a leading `replacements` key swapped for its display name."""
    for old, new in replacements.items():
        if label.startswith(old):
            return new + label[len(old):]
    return label


new_list = [_with_display_prefix(c) for c in score_cols]

# Rename exactly the score columns by name instead of assuming they follow a
# fixed number of metadata columns.
data.obs.rename(columns=dict(zip(score_cols, new_list)), inplace=True)

celltype_col = data.obs['celltype']
if isinstance(celltype_col.dtype, pd.CategoricalDtype):
    # Series.replace on categoricals is deprecated in recent pandas; rename the
    # categories directly instead.
    data.obs['celltype'] = celltype_col.cat.rename_categories(
        lambda c: replacements.get(c, c)
    )
else:
    data.obs['celltype'] = celltype_col.replace(replacements)

df_long = data.obs.melt(
    id_vars=["celltype"],
    value_vars=new_list,
    var_name="score_type",
    value_name="score",
)

# Method name = text between the cell-type prefix and "_scores", e.g.
# "P-FPP_Seurat_AG_scores" -> "Seurat_AG". This assumes the cell-type prefix
# has no underscore, which is why the labels were hyphenated above.
df_long["tool"] = df_long["score_type"].str.extract(r'_(.*?)_scores', expand=False)


# One violin plot per scoring method: score distributions per ground-truth
# cell type, one violin per signature (hue), saved as SVG into outputPath.
for tool in style_order:
    subset = df_long[df_long["tool"] == tool]
    if subset.empty:
        # No score columns for this method - skip instead of writing an empty figure.
        continue

    plt.figure(figsize=(8, 5))
    sns.violinplot(
        data=subset,
        x="celltype",
        y="score",
        hue="score_type",
        inner="quartile",
        cut=0,
        fill=False,
    )
    plt.xticks(rotation=45, ha='right', fontsize=16)
    plt.yticks(fontsize=16)
    plt.title(f"{tool} scores per group", fontsize=16)
    plt.xlabel("")
    plt.ylabel("Score", fontsize=16)
    # One hue level per signature: too many entries for a useful legend.
    plt.legend().remove()
    plt.tight_layout()
    plt.savefig(os.path.join(outputPath, f"{tool}_violinplot.svg"), format='svg', bbox_inches='tight')
    plt.close()


# For each scoring method, predict every cell's label as the cell type whose
# signature scores highest; stored as obs["<method>_max_group"].
methods = ["ANS", "Seurat", "Seurat_AG", "Seurat_LVG", "Scanpy", "Jasmine_LH", "Jasmine_OR", "UCell"]
groups = data.obs['celltype'].unique().tolist()

for method in methods:
    # Map each existing score column straight back to its cell type, instead
    # of re-parsing the column name with a regex (labels containing regex
    # metacharacters would mis-parse).
    col_to_group = {
        f"{g}_{method}_scores": g
        for g in groups
        if f"{g}_{method}_scores" in data.obs.columns
    }
    cols = list(col_to_group)
    data.obs[f"{method}_max_group"] = data.obs[cols].idxmax(axis=1).map(col_to_group)


from sklearn.metrics import confusion_matrix

# Confusion matrix (rows: true cell type, columns: max-score prediction) for
# one example scoring method.
method = "ANS"  # example
ordered_labels = np.unique(data.obs.celltype)
pred = data.obs[f"{method}_max_group"]
cm = confusion_matrix(data.obs.celltype, pred, labels=ordered_labels)

sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    xticklabels=ordered_labels,
    yticklabels=ordered_labels,
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title(f"Confusion matrix for {method}")
plt.show()

# Balanced accuracy of the max-score predictions for each scoring method.
true_labels = data.obs["celltype"]

balanced_acc_results = {
    method: balanced_accuracy_score(true_labels, data.obs[f"{method}_max_group"])
    for method in methods
}

# One row per method, for plotting.
balanced_acc_df = pd.DataFrame.from_dict(balanced_acc_results, orient='index', columns=["balanced_accuracy"])

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(5, 5))
sns.barplot(x=balanced_acc_df.index, y="balanced_accuracy", data=balanced_acc_df)
plt.xticks(rotation=45, size=16)
plt.yticks(size=16)
plt.ylabel("Balanced Accuracy", size=16)
plt.title("Balanced Accuracy per scoring method", size=16)
plt.ylim(0, 1)
plt.tight_layout()
plt.savefig(os.path.join(outputPath, "neuronal_balanced_accuracy_all.svg"), format='svg', bbox_inches='tight')
plt.show()


# Ground-truth cell types.
groups = true_labels.unique()


from sklearn.metrics import recall_score

# Per-group recall (correctly predicted cells / all cells of that group) for
# each method. NOTE: the figure title says "Balanced Accuracy", but the values
# are per-class recalls (balanced accuracy is their mean).
per_group_results = {}
for method in methods:
    pred_labels = data.obs[f"{method}_max_group"]
    # average=None returns one recall per entry of `labels`, in that order.
    recalls = recall_score(true_labels, pred_labels, labels=groups, average=None)
    per_group_results[method] = dict(zip(groups, recalls))

# Rows: methods, columns: groups.
per_group_df = pd.DataFrame(per_group_results).T

plt.figure(figsize=(10, 6))

ax = sns.heatmap(
    per_group_df,
    annot=True,
    fmt=".2f",
    cmap="Blues",
    annot_kws={"fontsize": 14},
)

plt.ylabel("Method", fontsize=16)
plt.xlabel("Group", fontsize=16)
plt.title("Balanced Accuracy per group per scoring method", fontsize=16)

plt.xticks(rotation=45, fontsize=16, ha='right')
plt.yticks(fontsize=16)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)

plt.tight_layout()
plt.savefig(os.path.join(outputPath, "neuronal_balanced_accuracy_pergroup.svg"), format='svg', bbox_inches='tight')
plt.show()

# Trajectory inference on the ANS score space (and, for comparison, on gene
# expression) with PHATE.

import phate

# ANS score columns, one per cell-type signature.
cols = [f"{g}_ANS_scores" for g in groups if f"{g}_ANS_scores" in data.obs.columns]

# Exclude treated cells. `.copy()` so that the obsm/obs assignments below act
# on an independent AnnData rather than implicitly materialising a view.
data2 = data[data.obs["treatment"] == "NONE"].copy()
data2_emb = data2.obs[cols]

phate_operator = phate.PHATE(n_jobs=-1, random_state=123)
phate_embedding = phate_operator.fit_transform(data2_emb)
data2.obsm['X_phate'] = phate_embedding
sc.pl.embedding(data2, basis='X_phate', color=['celltype', 'time_point'], frameon=False, show=False)

fig = plt.gcf()
fig.savefig(os.path.join(outputPath, "ANS_phate_embedding.svg"), format='svg', bbox_inches='tight')
plt.close(fig)

# FPP / P-FPP are the progenitors at time 0; the first such cell is the
# pseudotime root. Labels were hyphenated earlier (P_FPP -> P-FPP), so accept
# both spellings.
root_candidates = np.flatnonzero(data2.obs['celltype'].isin(['FPP', 'P-FPP', 'P_FPP']))
if root_candidates.size == 0:
    raise ValueError("No FPP/P-FPP cells among untreated cells; cannot choose a pseudotime root.")
root_idx = root_candidates[0]

from scipy.spatial.distance import cdist

# Pseudotime = Euclidean distance of every cell to the root in PHATE space.
pseudotime = cdist(phate_embedding, phate_embedding[[root_idx]], metric='euclidean').flatten()
data2.obs['pseudotime'] = pseudotime

sc.pl.embedding(data2, basis='X_phate', color=['celltype', 'time_point', 'pseudotime'], frameon=False)

# Same procedure on the input gene expression matrix.
phate_embedding_gex = phate_operator.fit_transform(data2.X)
data2.obsm['X_phate_gex'] = phate_embedding_gex
pseudotime_gex = cdist(phate_embedding_gex, phate_embedding_gex[[root_idx]], metric='euclidean').flatten()
data2.obs['pseudotime_gex'] = pseudotime_gex

sc.pl.embedding(data2, basis='X_phate_gex', color=['celltype', 'time_point', 'pseudotime_gex'], frameon=False)

# Combine the ANS and GEX PHATE plots into one two-panel figure.

import matplotlib.image as mpimg

# Human-readable legend titles for the panels.
data2.obs.rename(columns={
    "celltype": "Ground truth cell type",
    "time_point": "Time point",
    "pseudotime": "Pseudotime from ANS score embedding",
    "pseudotime_gex": "Pseudotime from GEX embedding",
}, inplace=True)

# High-resolution PNG per embedding, written by scanpy into its figure
# directory; the file names are read back below.
sc.settings.set_figure_params(dpi=300, dpi_save=300, format='png')
sc.pl.embedding(
    data2,
    basis='X_phate',
    color=['Ground truth cell type', 'Time point', 'Pseudotime from ANS score embedding'],
    frameon=False,
    show=False,
    save="_phate.png",
)
sc.pl.embedding(
    data2,
    basis='X_phate_gex',
    color=['Ground truth cell type', 'Time point', 'Pseudotime from GEX embedding'],
    frameon=False,
    show=False,
    save="_phate_gex.png",
)

fig, axes = plt.subplots(2, 1, figsize=(10, 9), constrained_layout=True)

panels = [
    ("X_phate_phate.png", "a PHATE ANS embedding"),
    ("X_phate_gex_phate_gex.png", "b PHATE GEX embedding"),
]
for ax, (fname, title) in zip(axes, panels):
    # Read from scanpy's configured figure directory rather than a hard-coded "figures/".
    ax.imshow(mpimg.imread(os.path.join(sc.settings.figdir, fname)))
    ax.axis("off")
    ax.set_title(title, fontsize=14, fontweight="bold", loc='left')

# No tight_layout(): the figure already uses constrained layout, and calling
# tight_layout() would switch it to the tight layout engine.
fig.savefig(os.path.join(outputPath, "combined_PHATE_panels.svg"), format='svg', bbox_inches='tight')
fig.savefig(os.path.join(outputPath, "combined_PHATE_panels.png"), format='png', bbox_inches='tight', dpi=300)
plt.close(fig)


# Donut chart of the proportion of cells per ground-truth cell type.
prop = (
    data.obs['celltype']
    .value_counts(normalize=True)
    .sort_values(ascending=False)
)

fig, ax = plt.subplots(figsize=(6, 6))

wedges, _ = ax.pie(
    prop,
    labels=None,
    startangle=90,
    counterclock=False,
    wedgeprops=dict(width=0.4),  # ring width < 1 gives the donut shape
)

# Cell-type names and percentages go into a legend instead of wedge labels.
legend_labels = [f'{ct} ({p*100:.1f}%)' for ct, p in prop.items()]

ax.legend(
    wedges,
    legend_labels,
    title='Cell type',
    loc='center left',
    bbox_to_anchor=(1, 0, 0.5, 1),
)

ax.set_title('Cell type composition')
plt.tight_layout()
plt.savefig(os.path.join(outputPath, "neuronal_cell_type_composition.svg"), format='svg', bbox_inches='tight')
plt.show()


# Visualise similarities between cell types via their mean expression profiles.

from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram

# Mean expression per cell type (rows: cell types, columns: genes).
# observed=True drops unused categorical levels, which would otherwise produce
# all-NaN rows and NaN distances.
mean_profiles = data.to_df().groupby(data.obs['celltype'], observed=True).mean()

# Pairwise distance = 1 - Pearson correlation of the mean profiles
# (scipy's built-in "correlation" metric).
dist_matrix = pdist(mean_profiles, metric='correlation')
dist_square = pd.DataFrame(
    squareform(dist_matrix),
    index=mean_profiles.index,
    columns=mean_profiles.index,
)

# Average-linkage clustering of these distances, shared by both plots below.
Z = linkage(dist_matrix, method='average')

# Passing the precomputed linkage makes the clustermap order rows/columns by
# the distances themselves, instead of re-clustering rows of the distance
# matrix by their mutual correlation.
g = sns.clustermap(
    dist_square,
    row_linkage=Z,
    col_linkage=Z,
    cmap='vlag',
    linewidths=0.5,
    figsize=(8, 8),
)
g.figure.suptitle("Hierarchical similarity of cell types")
plt.show()

plt.figure(figsize=(8, 4))
dendrogram(Z, labels=mean_profiles.index.tolist(), leaf_rotation=90)
plt.title('Hierarchical clustering of cell types', size=16)
plt.ylabel('1 - correlation (Pearson)', size=16)
plt.xticks(size=16)
plt.yticks(size=16)
plt.savefig(os.path.join(outputPath, "neuronal_cell_similarity.svg"), format='svg', bbox_inches='tight')
plt.show()
