##
## The following notebook reproduces the experiment on the impact of signature gene filtering on score-based cell annotation
## and plots the final balanced accuracies across all PBMC subsets used for the gene filtering benchmark.
## To run the script, please use the preprocessed dataset available via "download_preprocessed_datasets.py".
## 

# Run experiment on PBMC data composed of CD8 T cells:
#CD8 Naive
#CD8 Proliferating 
#CD8 TCM 
#CD8 TEM

import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


# Load the preprocessed PBMC subset restricted to the CD8 T-cell subtypes.
types = sc.read_h5ad(
    'ANS_supplementary_information/reproduce_project/data/preprocessed/pp_pbmc_cd8_subtypes.h5ad'
)

# Marker genes are taken from the table of differentially expressed genes (DEGs).
m = pd.read_csv(
    "ANS_supplementary_information/reproduce_project/data/annotations/citeseq_pbmc/DE_by_celltype.csv",
    sep=',',
)

# 'Cell Type' values of the DEG table that are pooled into each CD8 subtype signature.
_deg_clusters_per_subtype = {
    'CD8 Naive': ['CD8 Naive', 'CD8 Naive_2'],
    'CD8 Proliferating': ['CD8 Proliferating'],
    'CD8 TCM': ['CD8 TCM_1', 'CD8 TCM_2', 'CD8 TCM_3'],
    'CD8 TEM': ['CD8 TEM_1', 'CD8 TEM_2', 'CD8 TEM_3', 'CD8 TEM_4', 'CD8 TEM_5', 'CD8 TEM_6'],
}

# One de-duplicated gene list per subtype, in order of first appearance in the table.
markers_dict = {
    subtype: m.loc[m['Cell Type'].isin(clusters), 'Gene'].unique().tolist()
    for subtype, clusters in _deg_clusters_per_subtype.items()
}


# Mean expression of every gene across all cells, as a flat 1-D array.
gene_means = np.asarray(types.X.mean(axis=0)).ravel()
gene_df = pd.DataFrame({'gene': types.var_names, 'mean_expr': gene_means})

# Rank 1 is the most highly expressed gene. method='dense' gives tied means the
# same rank, so `rank <= N` can select more than N genes when ties occur.
gene_df['rank'] = gene_df['mean_expr'].rank(ascending=False, method='dense')
gene_df['percentile'] = 100 * (1 - gene_df['rank'] / gene_df['rank'].max())

# Flag, per subtype, which genes belong to that subtype's marker signature.
for subtype in ('CD8 Naive', 'CD8 Proliferating', 'CD8 TCM', 'CD8 TEM'):
    gene_df[f'is_marker_{subtype}'] = gene_df['gene'].isin(markers_dict[subtype])

# Quantile bins of mean expression (integer labels, higher label = higher expression).
# duplicates='drop' merges bins with identical edges, so fewer bins than requested may result.
for n_bins, bin_col in ((25, 'expr_bin'), (50, 'expr_bin50')):
    gene_df[bin_col] = pd.qcut(gene_df['mean_expr'], q=n_bins, labels=False, duplicates='drop')


# Add annotation on the top-x genes.
n_genes = len(gene_df)

# The bins were built with duplicates='drop', so fewer than 25/50 bins may exist.
# Use the highest label actually present instead of assuming it is 24/49;
# with the full number of bins this is the same value as before.
_top_label_25 = gene_df['expr_bin'].max()
_top_label_50 = gene_df['expr_bin50'].max()

# Each cutoff is a gene count; it is used below as an x position on the rank axis.
cutoffs = {
    'Top bin (25 bins)': int((gene_df['expr_bin'] == _top_label_25).sum()),
    'Top bin (50 bins)': int((gene_df['expr_bin50'] == _top_label_50).sum()),
    'Top 50 genes': 50,
    'Top 100 genes': 100,
}

colors = plt.get_cmap('tab10').colors

# (panel title, boolean marker column in gene_df, line colour) for each CD8 subtype.
celltypes = [
    ('CD8 Naive cell marker gene positions among ranked genes', 'is_marker_CD8 Naive', colors[0]),
    ('CD8 Proliferating cell marker gene positions among ranked genes', 'is_marker_CD8 Proliferating', colors[1]),
    ('CD8 TCM cell marker gene positions among ranked genes', 'is_marker_CD8 TCM', colors[2]),
    ('CD8 TEM cell marker gene positions among ranked genes', 'is_marker_CD8 TEM', colors[3]),
]

# Create a single figure with 4 vertical subplots (one per CD8 subtype).
fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(7, 5), sharex=True)

for ax, (title, marker_col, color) in zip(axes, celltypes):
    # Draw only the marker genes, in one vectorized call. The previous per-gene
    # loop also drew a semi-opaque white line for every non-marker gene, which
    # was slow and painted over marker lines at neighbouring ranks.
    marker_ranks = gene_df.loc[gene_df[marker_col], 'rank']
    if not marker_ranks.empty:
        ax.vlines(marker_ranks, ymin=0, ymax=1, color=color, alpha=0.7, linewidth=0.8)

    # Add the gene-count cutoffs as dashed vertical lines with rotated labels.
    for cutoff_label, cutoff_x in cutoffs.items():
        ax.axvline(x=cutoff_x, color='black', linestyle='--', alpha=0.6)
        ax.text(
            cutoff_x, 1.02, cutoff_label, rotation=90, va='bottom', ha='center',
            fontsize=8, fontweight='bold'
        )

    ax.set_xlim(0, n_genes)
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=11)
    ax.set_yticks([])

axes[-1].set_xlabel('Gene rank (by mean expression, highest to lowest)')
fig.tight_layout()
plt.savefig("cd8_marker_gene_rank_positions.svg", format='svg', bbox_inches='tight')
plt.show()

# Same rank plot, zoomed in on the top_n most highly expressed genes.
top_n = 2000
gene_df_top = gene_df[gene_df['rank'] <= top_n]

fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(7, 5), sharex=True)

# `celltypes` holds the same (title, marker column, colour) triples used for the full plot.
for ax, (title, marker_col, color) in zip(axes, celltypes):
    # Draw only the marker genes, in one vectorized call. The previous per-gene
    # loop also drew a semi-opaque white line for every non-marker gene, which
    # was slow and painted over marker lines at neighbouring ranks.
    marker_ranks = gene_df_top.loc[gene_df_top[marker_col], 'rank']
    if not marker_ranks.empty:
        ax.vlines(marker_ranks, ymin=0, ymax=1, color=color, alpha=0.7, linewidth=0.8)

    # Add the gene-count cutoffs that fall inside the displayed rank range.
    for cutoff_label, cutoff_x in cutoffs.items():
        if cutoff_x <= top_n:
            ax.axvline(x=cutoff_x, color='black', linestyle='--', alpha=0.6)
            ax.text(
                cutoff_x, 1.02, cutoff_label, rotation=90, va='bottom', ha='center',
                fontsize=8, fontweight='bold'
            )

    ax.set_xlim(0, top_n)
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=11)
    ax.set_yticks([])

axes[-1].set_xlabel('Gene rank (by mean expression, highest to lowest)')
plt.tight_layout()
plt.savefig("cd8_marker_gene_rank_positions_top2000.svg", format='svg', bbox_inches='tight')
plt.show()


# Calculate the number of markers removed by each filtering strategy.
# top1 / top5 hold the genes of the highest 25-bin / 50-bin expression bin.
# The bins were built with duplicates='drop', so the highest label is read from
# the data rather than assumed to be 24 / 49 (identical when no bin was dropped).
top1 = gene_df[gene_df['expr_bin'] == gene_df['expr_bin'].max()]
top5 = gene_df[gene_df['expr_bin50'] == gene_df['expr_bin50'].max()]
top50genes = gene_df[gene_df['rank'] <= 50]
top100genes = gene_df[gene_df['rank'] <= 100]

_marker_cols = [
    'is_marker_CD8 Naive',
    'is_marker_CD8 Proliferating',
    'is_marker_CD8 TCM',
    'is_marker_CD8 TEM',
]
# Number of signature genes per subtype that are present in the dataset.
_total_markers = gene_df[_marker_cols].sum()

# Per filter: count of markers that fall in the filtered set and their share (in %)
# of all markers of that subtype. These values used to be bare expressions whose
# results were discarded when run as a script; collect and print them instead.
_summary_columns = {}
for _filter_name, _filtered in (
    ('top1', top1),
    ('top5', top5),
    ('top50genes', top50genes),
    ('top100genes', top100genes),
):
    _n_removed = _filtered[_marker_cols].sum()
    _summary_columns[f'{_filter_name}_n'] = _n_removed
    _summary_columns[f'{_filter_name}_pct'] = (_n_removed / _total_markers * 100).round(1)

removed_marker_summary = pd.DataFrame(_summary_columns)
print(removed_marker_summary)

# Prepare the filtered subsets of the signatures.


def _signatures_without(top_genes, suffix):
    """Return the marker signatures with the genes listed in `top_genes` removed.

    `top_genes` is a subset of `gene_df` (it must carry the 'gene' and
    'is_marker_<subtype>' columns). The result maps '<subtype>_<suffix>' to the
    subtype's marker list minus the markers that occur in `top_genes`, keeping
    the original gene order.
    """
    filtered = {}
    for subtype, genes in markers_dict.items():
        # Build the exclusion set once per subtype instead of re-filtering the
        # DataFrame for every single marker gene.
        excluded = set(top_genes.loc[top_genes[f'is_marker_{subtype}'], 'gene'])
        filtered[f'{subtype}_{suffix}'] = [gene for gene in genes if gene not in excluded]
    return filtered


sig1 = _signatures_without(top1, '1')
sig5 = _signatures_without(top5, '5')
sig50g = _signatures_without(top50genes, '50g')
sig100g = _signatures_without(top100genes, '100g')


#score cells with respective signatures:

import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, f1_score, jaccard_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import signaturescoring as ssc

sys.path.append('ANS_supplementary_information/data')

from constants import (DATASETS_WITH_ANNOTATIONS, BASE_PATH_RESULTS, SCORING_METHODS,
                       METHOD_WITH_GENE_POOL, VIOLIN_PLOT_CONFIG)

# Override the first configured scoring method with the ANS settings used in this
# experiment, then keep only the first five methods.
# NOTE(review): the remaining four entries come from `constants`; presumably they
# are the Seurat variants and Scanpy referenced further below - confirm.
_ans_config = {
    'scoring_method': 'adjusted_neighborhood_scoring',
    'sc_params': {
        'ctrl_size': 100,
        'score_name': 'ANS',
        'remove_genes_with_invalid_control_set': False,
    },
}
SCORING_METHODS[0] = _ans_config
SCORING_METHODS = SCORING_METHODS[:5]

def score_signatures_with_all_methods(adata, signatures, use_gene_pool: bool, verbose=False):
    """Score every signature with every method listed in SCORING_METHODS.

    Args:
        adata: Object handed to ``ssc.score_signature`` as ``adata``.
        signatures: Mapping of signature name to its gene list.
        use_gene_pool: When true, a gene pool is computed once and passed to
            the methods listed in METHOD_WITH_GENE_POOL.
        verbose: When true, print one line per (signature, method) pair.

    Returns:
        A tuple ``(added_cols, adata)``. ``added_cols`` maps each method's
        configured ``score_name`` to the list of score names requested for it,
        each of the form ``<signature-with-dashes>_<score_name>_scores``.
    """
    # NOTE(review): `_get_gene_pool` is not defined or imported in the visible part
    # of this file, so use_gene_pool=True presumably raises NameError - confirm.
    gene_pool = _get_gene_pool(adata, signatures) if use_gene_pool else None

    added_cols = defaultdict(list)
    for sig_name, sig in signatures.items():
        # Spaces in the signature name are replaced so the score name has none.
        col_prefix = sig_name.replace(' ', '-')
        for method_cfg in SCORING_METHODS:
            method_name = method_cfg['scoring_method']
            # Work on a copy so the shared configuration is not modified.
            params = dict(method_cfg['sc_params'])
            if verbose:
                print(f"Scoring {sig_name} with {method_name}")
            base_score_name = params['score_name']
            params['score_name'] = f"{col_prefix}_{base_score_name}_scores"
            if use_gene_pool and method_name in METHOD_WITH_GENE_POOL:
                params['gene_pool'] = gene_pool
            ssc.score_signature(
                method=method_name,
                adata=adata,
                gene_list=sig,
                **params
            )
            added_cols[base_score_name].append(params['score_name'])
    return added_cols, adata


# Score the cells with the unfiltered signatures first, then with each filtered
# variant; no gene pool is used for any of them.
for _signature_set in (markers_dict, sig1, sig5, sig50g, sig100g):
    score_signatures_with_all_methods(types, _signature_set, False)

# Persist the per-cell annotations (including the score columns read back below).
types.obs.to_csv('cd8_scores.csv')


# Score names to plot, in panel order.
# Jasmine_LH, Jasmine_OR and UCell are deliberately left out of this list.
style_order = ['ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG', 'Scanpy']

import matplotlib.pyplot as plt
import seaborn as sns

# Every per-cell column whose name ends in "_scores".
score_cols = [col for col in types.obs.columns if col.endswith("_scores")]

# Long format: one row per (cell, score column), keeping the annotated subtype.
df_long = types.obs.melt(
    id_vars=["celltype.l2"],
    value_vars=score_cols,
    var_name="score_type",
    value_name="score",
)

# The text between the first underscore and "_scores" identifies the scoring
# variant, e.g. "CD8-Naive_1_ANS_scores" -> "1_ANS".
df_long["tool"] = df_long["score_type"].str.extract(r'_(.*?)_scores')


import seaborn as sns

# One figure per signature variant: unfiltered (""), then the "1_", "5_" and
# "50g_" filtered variants. The "100g_" variant is not plotted here.
# Each figure has one violin panel per scoring method.
n_tools = len(style_order)
for _variant_prefix in ("", "1_", "5_", "50g_"):
    fig, axes = plt.subplots(n_tools, 1, figsize=(10, 5 * n_tools), sharex=True)

    for ax, tool in zip(axes, [_variant_prefix + name for name in style_order]):
        # Scores of this method/variant for all signatures, split by annotated subtype.
        subset = df_long[df_long["tool"] == tool]
        sns.violinplot(
            data=subset,
            x="celltype.l2",
            y="score",
            hue="score_type",
            inner="box",
            cut=0,
            ax=ax,
        )
        ax.set_title(f"{tool} scores per group", fontsize=14)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()


# For each method and signature variant, label every cell with the subtype whose
# signature obtained the highest score.
methods = ["ANS", "Seurat", "Seurat_AG", "Seurat_LVG", "Scanpy"]
# Subtype names with spaces replaced by dashes, as used in the score column names.
groups = [name.replace(" ", "-") for name in types.obs['celltype.l2'].unique().tolist()]

# "" is the unfiltered signature set; the other prefixes are the filtered variants.
for _variant_prefix in ("", "1_", "5_", "50g_", "100g_"):
    for method in [_variant_prefix + base for base in methods]:
        # Score columns of this method/variant that actually exist, one per subtype.
        _candidates = [f"{g}_{method}_scores" for g in groups]
        cols = [col for col in _candidates if col in types.obs.columns]

        # Name of the best-scoring column per cell, reduced to its subtype prefix.
        _best_col = types.obs[cols].idxmax(axis=1)
        types.obs[f"{method}_max_group"] = _best_col.str.extract(fr"^(.*?)_{method}_scores")[0]

# Calculate the balanced accuracy of every method / signature variant.
# The annotated labels get the same dash formatting as the predicted labels.
true_labels = [label.replace(" ", "-") for label in types.obs["celltype.l2"]]

balanced_acc_results = {}

# Same order as before: unfiltered first, then the "1_", "5_", "50g_", "100g_" variants.
for _variant_prefix in ("", "1_", "5_", "50g_", "100g_"):
    for method in [_variant_prefix + base for base in methods]:
        pred_labels = types.obs[f"{method}_max_group"]
        acc = balanced_accuracy_score(true_labels, pred_labels)
        balanced_acc_results[method] = acc

# Convert to a DataFrame (one row per method variant) for easy plotting.
balanced_acc_df = pd.DataFrame.from_dict(
    balanced_acc_results, orient='index', columns=["balanced_accuracy"]
)
balanced_acc_df

balanced_acc_df.to_csv('cd8_balanced_accuracy.csv')


# Plot all results across cell types.

# Balanced-accuracy tables of every benchmarked PBMC subset; all but the last
# one are expected to already exist in the working directory.
files = ['b_nk_mono_balanced_accuracy.csv',
'b_balanced_accuracy.csv',
'b6_balanced_accuracy.csv',
'cd4_balanced_accuracy.csv',
'cd8_balanced_accuracy.csv']

# One column per subset, aligned on the shared method-variant index.
# (Named `all_results` rather than `all` so the builtin is not shadowed.)
all_results = pd.concat([pd.read_csv(path, index_col=0) for path in files], axis=1)

all_results.columns = ['B_NK_Mono', 'B_3', 'B_6', 'CD4', 'CD8']

# Rows are labelled by position: five methods for each of the five filtering
# settings, in the order unfiltered, "1_", "5_", "50g_", "100g_".
group =['\nNo filtering']*5 + ['\nTop bin (25 bins)']*5+ ['\nTop bin (50 bins)']*5+ ['\nTop 50 genes']*5+ ['\nTop 100 genes']*5
all_results['group'] = group

tools = ['ANS', 'Seurat', 'Seurat_AG', 'Seurat_LVG', 'Scanpy']*5
all_results['tools'] = tools

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

df = all_results.reset_index(drop=True)

# Long format: one row per (filtering setting, method, subset).
plot_df = df.melt(
    id_vars=['group', 'tools'],
    value_vars=['B_NK_Mono', 'B_3', 'B_6', 'CD4', 'CD8'],
    var_name='Cell_type',
    value_name='Balanced accuracy'
)

# Human-readable row titles for the facet grid.
plot_df['Cell_type'] = plot_df['Cell_type'].replace({
    "B_NK_Mono": "B NK Monocytes",
    "B_3": "B cells: 3 subtypes",
    "B_6": "B cells: 6 subtypes",
    "CD4": "CD4 T cells: 6 subtypes",
    "CD8": "CD8 T cells: 4 subtypes",
})

# Column order of the facet grid (differs from the row order of the tables).
group_order = [
    '\nNo filtering',
    '\nTop 50 genes',
    '\nTop 100 genes',
    '\nTop bin (50 bins)',
    '\nTop bin (25 bins)'
]
sns.set_context("talk", font_scale=0.9)

g = sns.catplot(
    data=plot_df,
    x='tools',
    y='Balanced accuracy',
    hue='tools',
    col='group',
    row='Cell_type',
    kind='bar',
    height=3,
    aspect=1.2,
    palette='Set2',
    col_order=group_order,
    width=0.6
)


g.set_titles(row_template="{row_name}", col_template="{col_name}")
g.set_axis_labels("", "Balanced\naccuracy")
for ax in g.axes.flat:
    ax.tick_params(axis='x', rotation=45)

g.tight_layout()
g.fig.set_size_inches(10, 10)
plt.savefig("gene_filtering_balanced_accuracy.svg", format='svg', bbox_inches='tight')
plt.show()
