##
## This script reproduces the experiment on the impact of signature-gene filtering on score-based cell annotation.
## To run it, use the preprocessed dataset available via "download_preprocessed_datasets.py".
##

# Run the experiment on PBMC data composed of six B-cell subtypes:
#   B intermediate lambda
#   B intermediate kappa
#   B memory lambda
#   B memory kappa
#   B naive lambda
#   B naive kappa

import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


# Preprocessed PBMC B-cell subset used throughout the experiment.
types = sc.read_h5ad('ANS_supplementary_information/reproduce_project/data/preprocessed/pp_pbmc_b_subtypes.h5ad')


# Marker genes per cell type, taken from the differential-expression table
# (one row per (Cell Type, Gene) pair).
m = pd.read_csv("ANS_supplementary_information/reproduce_project/data/annotations/citeseq_pbmc/DE_by_celltype.csv",
                sep=',')

# Insertion order of markers_dict follows this list.
_b_cell_types = [
    'B intermediate lambda',
    'B intermediate kappa',
    'B memory lambda',
    'B memory kappa',
    'B naive lambda',
    'B naive kappa',
]
markers_dict = {
    cell_type: m.loc[m['Cell Type'] == cell_type, 'Gene'].unique().tolist()
    for cell_type in _b_cell_types
}


# Per-gene mean expression over all cells. np.asarray(...).ravel() flattens
# both a dense 1-row result and the np.matrix returned by scipy-sparse .mean().
gene_means = np.asarray(types.X.mean(axis=0)).ravel()
gene_df = pd.DataFrame({
    'gene': types.var_names,
    'mean_expr': gene_means
})

# Rank 1 = highest mean expression. method='first' assigns every gene a unique
# rank 1..n_genes, so "rank <= N" selects exactly N genes and the ranks span the
# full x-axis used in the rank plots. (A 'dense' rank collapses ties, which can
# select more than N genes and leaves the maximum rank below n_genes.)
gene_df['rank'] = gene_df['mean_expr'].rank(ascending=False, method='first')
gene_df['percentile'] = 100 * (1 - gene_df['rank'] / gene_df['rank'].max())

# One boolean column per cell type flagging that type's marker genes.
_marker_flag_columns = {
    'is_marker_int_kappa': 'B intermediate kappa',
    'is_marker_memory_kappa': 'B memory kappa',
    'is_marker_naive_kappa': 'B naive kappa',
    'is_marker_int_lambda': 'B intermediate lambda',
    'is_marker_memory_lambda': 'B memory lambda',
    'is_marker_naive_lambda': 'B naive lambda',
}
for _flag_col, _cell_type in _marker_flag_columns.items():
    gene_df[_flag_col] = gene_df['gene'].isin(markers_dict[_cell_type])


# Quantile bins of mean expression (25 and 50 bins; label 0 = lowest).
# duplicates='drop' merges bins whose quantile edges coincide, which lowers
# the highest label below q-1.
gene_df['expr_bin'] = pd.qcut(gene_df['mean_expr'], q=25, labels=False, duplicates='drop')

gene_df['expr_bin50'] = pd.qcut(gene_df['mean_expr'], q=50, labels=False, duplicates='drop')


# Rank cutoffs annotated on the rank plots: the number of genes in the highest
# 25-quantile and 50-quantile expression bins, plus fixed top-N gene counts.
n_genes = len(gene_df)

# qcut(..., duplicates='drop') renumbers bins when quantile edges coincide, so
# the highest bin label may be below 24 / 49. Use the actual maximum label
# instead of hard-coding it, otherwise the "top bin" can silently be empty.
cutoffs = {
    'Top bin (25 bins)': int((gene_df['expr_bin'] == gene_df['expr_bin'].max()).sum()),
    'Top bin (50 bins)': int((gene_df['expr_bin50'] == gene_df['expr_bin50'].max()).sum()),
    'Top 50 genes': 50,
    'Top 100 genes': 100}

colors = plt.get_cmap('tab10').colors

# (panel title, boolean marker column in gene_df, tick colour) per cell type.
celltypes = [
    ('B intermediate kappa cell marker gene positions among ranked genes', 'is_marker_int_kappa', colors[0]),
    ('B memory kappa cell marker gene positions among ranked genes', 'is_marker_memory_kappa', colors[1]),
    ('B naive kappa cell marker gene positions among ranked genes', 'is_marker_naive_kappa', colors[2]),
    ('B intermediate lambda cell marker gene positions among ranked genes', 'is_marker_int_lambda', colors[3]),
    ('B memory lambda cell marker gene positions among ranked genes', 'is_marker_memory_lambda', colors[4]),
    ('B naive lambda cell marker gene positions among ranked genes', 'is_marker_naive_lambda', colors[5])
]

# One vertical strip per cell type (all six, not just the first three); each
# marker gene is drawn as a tick at its expression rank. Non-marker genes are
# not drawn at all (drawing them white was invisible and could overpaint
# marker ticks at the same x position).
fig, axes = plt.subplots(nrows=len(celltypes), ncols=1,
                         figsize=(7, 5 * len(celltypes) / 3),
                         sharex=True, squeeze=False)

for ax, (title, marker_col, color) in zip(axes[:, 0], celltypes):
    # Single vectorized call instead of one vlines() call per gene.
    marker_ranks = gene_df.loc[gene_df[marker_col], 'rank']
    ax.vlines(marker_ranks, ymin=0, ymax=1, color=color, alpha=0.7, linewidth=0.8)

    # Rank cutoffs used for signature filtering.
    for label, x in cutoffs.items():
        ax.axvline(x=x, color='black', linestyle='--', alpha=0.6)
        ax.text(
            x, 1.02, label, rotation=90, va='bottom', ha='center',
            fontsize=8, fontweight='bold'
        )

    ax.set_xlim(0, n_genes)
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=11)
    ax.set_yticks([])

axes[-1, 0].set_xlabel('Gene rank (by mean expression, highest to lowest)')
fig.tight_layout()
fig.savefig("b6_marker_gene_rank_positions.svg", format='svg', bbox_inches='tight')
plt.show()


# Zoomed view of the same rank plot restricted to the top_n most expressed genes.
top_n = 2000
gene_df_top = gene_df[gene_df['rank'] <= top_n]

fig, axes = plt.subplots(nrows=len(celltypes), ncols=1, figsize=(7, 5),
                         sharex=True, squeeze=False)

# Reuse the (title, marker column, colour) table instead of repeating it.
for ax, (title, marker_col, color) in zip(axes[:, 0], celltypes):
    # Only marker genes are drawn; one vectorized call per panel.
    marker_ranks = gene_df_top.loc[gene_df_top[marker_col], 'rank']
    ax.vlines(marker_ranks, ymin=0, ymax=1, color=color, alpha=0.7, linewidth=0.8)

    # Rank cutoffs, drawn only when they fall inside the zoomed range.
    for label, x in cutoffs.items():
        if x <= top_n:
            ax.axvline(x=x, color='black', linestyle='--', alpha=0.6)
            ax.text(
                x, 1.02, label, rotation=90, va='bottom', ha='center',
                fontsize=8, fontweight='bold'
            )

    ax.set_xlim(0, top_n)
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=11)
    ax.set_yticks([])

axes[-1, 0].set_xlabel('Gene rank (by mean expression, highest to lowest)')
fig.tight_layout()
fig.savefig("b6_marker_gene_rank_positions_top2000.svg", format='svg', bbox_inches='tight')
plt.show()

# Calculate how many marker genes of each cell type fall into each "top" gene
# set, i.e. would be removed by that filter.
# NOTE: despite the names, top1 is the highest of 25 quantile bins (~4% of
# genes) and top5 the highest of 50 quantile bins (~2% of genes). The maximum
# label is used because qcut(duplicates='drop') can renumber bins.
top1 = gene_df[gene_df['expr_bin'] == gene_df['expr_bin'].max()]
top5 = gene_df[gene_df['expr_bin50'] == gene_df['expr_bin50'].max()]
top50genes = gene_df[gene_df['rank'] <= 50]
top100genes = gene_df[gene_df['rank'] <= 100]

_marker_flag_cols = [
    'is_marker_int_kappa', 'is_marker_memory_kappa', 'is_marker_naive_kappa',
    'is_marker_int_lambda', 'is_marker_memory_lambda', 'is_marker_naive_lambda',
]
_top_sets = {
    'Top bin (25 bins)': top1,
    'Top bin (50 bins)': top5,
    'Top 50 genes': top50genes,
    'Top 100 genes': top100genes,
}

_rows = []
for _filter_name, _subset in _top_sets.items():
    for _col in _marker_flag_cols:
        _n_markers = int(gene_df[_col].sum())
        _n_removed = int(_subset[_col].sum())
        # Guard against a cell type with no markers present in the dataset.
        _pct = round(_n_removed / _n_markers * 100, 1) if _n_markers else float('nan')
        _rows.append({
            'filter': _filter_name,
            'marker_set': _col,
            'n_removed': _n_removed,
            'n_markers': _n_markers,
            'pct_removed': _pct,
        })

# Bare expressions display nothing when run as a script; print the table.
removed_markers_df = pd.DataFrame(_rows)
print(removed_markers_df.to_string(index=False))



# Prepare filtered variants of each signature: the original marker list minus
# any gene contained in the corresponding "top expressed" gene set.


def _drop_top_genes(markers, top_genes_df):
    """Return *markers* (order kept) without the genes listed in top_genes_df['gene']."""
    # Build the exclusion set once; the previous per-element .tolist() was O(n*m).
    excluded = set(top_genes_df['gene'])
    return [gene for gene in markers if gene not in excluded]


# Key order of each signature dict follows this list.
_signature_cell_types = [
    'B intermediate kappa',
    'B memory kappa',
    'B naive kappa',
    'B intermediate lambda',
    'B memory lambda',
    'B naive lambda',
]

sig1 = {f"{ct}_1": _drop_top_genes(markers_dict[ct], top1) for ct in _signature_cell_types}
sig5 = {f"{ct}_5": _drop_top_genes(markers_dict[ct], top5) for ct in _signature_cell_types}
sig50g = {f"{ct}_50g": _drop_top_genes(markers_dict[ct], top50genes) for ct in _signature_cell_types}
sig100g = {f"{ct}_100g": _drop_top_genes(markers_dict[ct], top100genes) for ct in _signature_cell_types}


#score cells with respective signatures:


import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, f1_score, jaccard_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import signaturescoring as ssc

# Make the project's `constants` module importable.
sys.path.append('ANS_supplementary_information/data')

from constants import (DATASETS_WITH_ANNOTATIONS, BASE_PATH_RESULTS, SCORING_METHODS,
                       METHOD_WITH_GENE_POOL, VIOLIN_PLOT_CONFIG)

# Build a local list instead of assigning into the list owned by `constants`
# (which would mutate shared module state). The first entry is replaced by ANS
# with ctrl_size=100 and remove_genes_with_invalid_control_set=False; only the
# first five methods are kept.
SCORING_METHODS = [
    {'scoring_method': 'adjusted_neighborhood_scoring',
     'sc_params': {'ctrl_size': 100, 'score_name': 'ANS', 'remove_genes_with_invalid_control_set': False}},
    *SCORING_METHODS[1:5],
]

def score_signatures_with_all_methods(adata, signatures, use_gene_pool: bool, verbose=False):
    """Score every signature with every configured scoring method.

    For each (signature, method) pair, ``ssc.score_signature`` is called on
    *adata* with a score name of the form
    ``"<signature name, spaces -> '-'>_<method's base score_name>_scores"``.

    Args:
        adata: object passed to ``ssc.score_signature`` (presumably an AnnData
            whose ``.obs`` receives the score columns — confirm in signaturescoring).
        signatures: mapping of signature name -> gene list.
        use_gene_pool: if True, compute a gene pool and pass it to methods
            listed in ``METHOD_WITH_GENE_POOL``.
        verbose: print one line per (signature, method) pair.

    Returns:
        ``(added_cols, adata)`` where ``added_cols`` maps each method's base
        score name to the list of score names produced for it.
    """
    # NOTE(review): `_get_gene_pool` is not defined or imported in this file as
    # shown, so use_gene_pool=True would raise NameError; all calls here pass False.
    gene_pool = _get_gene_pool(adata, signatures) if use_gene_pool else None

    added_cols = defaultdict(list)
    for sig_name, sig in signatures.items():
        col_prefix = sig_name.replace(' ', '-')
        for method_cfg in SCORING_METHODS:
            method = method_cfg['scoring_method']
            # Shallow copy so the shared config is never modified.
            params = dict(method_cfg['sc_params'])
            if verbose:
                print(f"Scoring {sig_name} with {method}")
            base_name = params['score_name']
            params['score_name'] = f"{col_prefix}_{base_name}_scores"
            if use_gene_pool and method in METHOD_WITH_GENE_POOL:
                params['gene_pool'] = gene_pool
            ssc.score_signature(method=method, adata=adata, gene_list=sig, **params)
            added_cols[base_name].append(params['score_name'])
    return added_cols, adata


# Score the full marker signatures and every filtered variant with all methods,
# then export the cell metadata table (which holds the score columns).
for _signature_set in (markers_dict, sig1, sig5, sig50g, sig100g):
    score_signatures_with_all_methods(types, _signature_set, False)

types.obs.to_csv('b6_scores.csv')


# Base score names of the scoring methods, in panel order.
style_order = [
    'ANS',
    'Seurat',
    'Seurat_AG',
    'Seurat_LVG',
    'Scanpy'
]

import matplotlib.pyplot as plt
import seaborn as sns

# Long format: one row per (cell, score column).
score_cols = [c for c in types.obs.columns if c.endswith("_scores")]


df_long = types.obs.melt(
    id_vars=["celltype.l2"],
    value_vars=score_cols,
    var_name="score_type",
    value_name="score"
)

# "<signature>_<tool>_scores" -> "<tool>". Signature names have spaces replaced
# by '-', so the first '_' ends the signature part; filtered signatures yield a
# prefixed tool such as "1_ANS" or "50g_Seurat_AG".
df_long["tool"] = df_long["score_type"].str.extract(r'_(.*?)_scores')


def _plot_score_violins(tools):
    """Draw one violin panel per tool: scores per celltype.l2, one hue per score column."""
    fig, axes = plt.subplots(len(tools), 1, figsize=(10, 5 * len(tools)),
                             sharex=True, squeeze=False)
    for ax, tool in zip(axes[:, 0], tools):
        subset = df_long[df_long["tool"] == tool]
        sns.violinplot(
            data=subset,
            x="celltype.l2",
            y="score",
            hue="score_type",
            inner="box",
            cut=0,
            ax=ax,
        )
        ax.set_title(f"{tool} scores per group", fontsize=14)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()


# Unfiltered signatures first, then each filtered variant. The top-100-genes
# ("100g_") variant was scored but previously never plotted.
for _prefix in ("", "1_", "5_", "50g_", "100g_"):
    _plot_score_violins([_prefix + tool for tool in style_order])


# For every method and every filtered-signature variant, label each cell with
# the group whose signature score is highest ("<method>_max_group" column).
methods = ["ANS", "Seurat", "Seurat_AG", "Seurat_LVG", "Scanpy"]
groups = [label.replace(" ", "-") for label in types.obs['celltype.l3'].unique().tolist()]

for prefix in ["", "1_", "5_", "50g_", "100g_"]:
    for base in methods:
        method = prefix + base
        suffix = f"_{method}_scores"
        cols = [g + suffix for g in groups if g + suffix in types.obs.columns]
        best_col = types.obs[cols].idxmax(axis=1)
        # Strip the "_<method>_scores" suffix to recover the group label.
        types.obs[f"{method}_max_group"] = best_col.str.extract(fr"^(.*?)_{method}_scores")[0]

# Balanced accuracy of the max-score labels against the celltype.l3 annotation.
# Spaces are replaced by '-' to match the group labels recovered from column names.
true_labels = [x.replace(" ", "-") for x in types.obs["celltype.l3"]]

balanced_acc_results = {}
for prefix in ["", "1_", "5_", "50g_", "100g_"]:
    for base in methods:
        method = prefix + base
        pred_labels = types.obs[f"{method}_max_group"]
        balanced_acc_results[method] = balanced_accuracy_score(true_labels, pred_labels)


balanced_acc_df = pd.DataFrame.from_dict(balanced_acc_results, orient='index', columns=["balanced_accuracy"])
# A bare expression displays nothing when run as a script; print explicitly.
print(balanced_acc_df)

balanced_acc_df.to_csv('b6_balanced_accuracy.csv')