##
## This notebook reproduces the experiment on the impact of signature gene filtering on score-based cell annotation.
## To run the script, please use the preprocessed dataset available via "download_preprocessed_datasets.py".
##

# Run experiment on PBMC data composed of CD4 T cells: 
#CD4 CTL
#CD4 Naive 
#CD4 Proliferating 
#CD4 TCM
#CD4 TEM 
#Treg

import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the preprocessed CD4 T-cell subtype dataset.
types = sc.read_h5ad(
    'ANS_supplementary_information/reproduce_project/data/preprocessed/pp_pbmc_cd4_subtypes.h5ad'
)

# Marker genes are taken from the table of differentially expressed genes (DEGs).
m = pd.read_csv(
    "ANS_supplementary_information/reproduce_project/data/annotations/citeseq_pbmc/DE_by_celltype.csv",
    sep=',',
)

# Signature name -> 'Cell Type' labels of the DEG table that are pooled into it.
de_labels_by_signature = {
    'CD4 CTL': ['CD4 CTL'],
    'CD4 Naive': ['CD4 Naive'],
    'CD4 Proliferating': ['CD4 Proliferating'],
    'CD4 TCM': ['CD4 TCM_1', 'CD4 TCM_2'],
    'CD4 TEM': ['CD4 TEM_4', 'CD4 TEM_3', 'CD4 TEM_2', 'CD4 TEM_1'],
    'Treg': ['Treg Memory', 'Treg Naive'],
}

# Unique genes (in order of first appearance) listed for any of the pooled labels.
markers_dict = {
    signature: m.loc[m['Cell Type'].isin(de_labels), 'Gene'].unique().tolist()
    for signature, de_labels in de_labels_by_signature.items()
}


# Column-wise mean of the expression matrix, flattened to one value per entry
# of `types.var_names`.
gene_means = np.asarray(types.X.mean(axis=0)).ravel()
gene_df = pd.DataFrame({'gene': types.var_names, 'mean_expr': gene_means})

# Dense rank by mean expression: rank 1 is the highest mean, ties share a rank.
dense_rank = gene_df['mean_expr'].rank(ascending=False, method='dense')
gene_df['rank'] = dense_rank
gene_df['percentile'] = 100 * (1 - dense_rank / dense_rank.max())

# One boolean column per signature flagging the genes that belong to it.
for signature in ('CD4 CTL', 'CD4 Naive', 'CD4 Proliferating', 'CD4 TCM', 'CD4 TEM', 'Treg'):
    gene_df[f'is_marker_{signature}'] = gene_df['gene'].isin(markers_dict[signature])

# Quantile bins of the mean expression (integer labels, lowest expression = 0),
# once with 25 and once with 50 requested bins.
for bin_column, n_bins in (('expr_bin', 25), ('expr_bin50', 50)):
    gene_df[bin_column] = pd.qcut(
        gene_df['mean_expr'], q=n_bins, labels=False, duplicates='drop'
    )


# Rank positions at which cutoff lines are annotated on the rank plots.
n_genes = len(gene_df)

# `pd.qcut(..., duplicates='drop')` may return fewer bins than requested, in
# which case the highest bin is not labelled 24 / 49. Look up the label of the
# highest bin that is actually present instead of hard-coding it.
top_bin_25 = gene_df['expr_bin'].max()
top_bin_50 = gene_df['expr_bin50'].max()

cutoffs = {
    # Number of genes that fall into the highest-expression bin.
    'Top bin (25 bins)': int((gene_df['expr_bin'] == top_bin_25).sum()),
    'Top bin (50 bins)': int((gene_df['expr_bin50'] == top_bin_50).sum()),
    'Top 50 genes': 50,
    'Top 100 genes': 100,
}

colors = plt.get_cmap('tab10').colors

# (subplot title, boolean marker column of gene_df, line colour) per cell type.
celltypes = [
    ('CD4 CTL cell marker gene positions among ranked genes', 'is_marker_CD4 CTL', colors[0]),
    ('CD4 Naive cell marker gene positions among ranked genes', 'is_marker_CD4 Naive', colors[1]),
    ('Treg cell marker gene positions among ranked genes', 'is_marker_Treg', colors[2]),
    ('CD4 Proliferating cell marker gene positions among ranked genes', 'is_marker_CD4 Proliferating', colors[3]),
    ('CD4 TCM cell marker gene positions among ranked genes', 'is_marker_CD4 TCM', colors[4]),
    ('CD4 TEM cell marker gene positions among ranked genes', 'is_marker_CD4 TEM', colors[5]),
]

# Create a single figure with 6 vertically stacked subplots (one per cell type).
fig, axes = plt.subplots(nrows=6, ncols=1, figsize=(7, 5), sharex=True)

for ax, (title, marker_col, color) in zip(axes, celltypes):
    # Draw all genes with ONE vlines call instead of one call per DataFrame row.
    # Lines keep the row order of gene_df and non-marker genes are still drawn
    # in white, so the rendering matches the per-row version while creating a
    # single artist per subplot instead of one per gene.
    marker_rgb = tuple(color[:3])
    white_rgb = (1.0, 1.0, 1.0)
    line_colors = [marker_rgb if is_marker else white_rgb for is_marker in gene_df[marker_col]]
    ax.vlines(gene_df['rank'], ymin=0, ymax=1, colors=line_colors, alpha=0.7, linewidth=0.8)

    # Annotate the rank cutoffs.
    for label, x in cutoffs.items():
        ax.axvline(x=x, color='black', linestyle='--', alpha=0.6)
        ax.text(
            x, 1.02, label, rotation=90, va='bottom', ha='center',
            fontsize=8, fontweight='bold'
        )

    ax.set_xlim(0, n_genes)
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=11)
    ax.set_yticks([])

axes[-1].set_xlabel('Gene rank (by mean expression, highest to lowest)')
fig.tight_layout()
# NOTE(review): the leading "b" in the file name looks like a typo (the second
# figure is saved as "cd4_..."); kept unchanged so existing outputs keep their name.
plt.savefig("bcd4_marker_gene_rank_positions.svg", format='svg', bbox_inches='tight')
plt.show()


# Same rank plot, restricted to the genes with rank <= top_n.
top_n = 2000
gene_df_top = gene_df[gene_df['rank'] <= top_n]

fig, axes = plt.subplots(nrows=6, ncols=1, figsize=(7, 5), sharex=True)

# `celltypes` already holds exactly the (title, marker column, colour) triples
# needed here, so it is reused instead of repeating the list.
for ax, (title, marker_col, color) in zip(axes, celltypes):
    # One vectorised vlines call per subplot (same row order and colours as a
    # per-row loop; non-marker genes are drawn in white).
    marker_rgb = tuple(color[:3])
    white_rgb = (1.0, 1.0, 1.0)
    line_colors = [marker_rgb if is_marker else white_rgb for is_marker in gene_df_top[marker_col]]
    ax.vlines(gene_df_top['rank'], ymin=0, ymax=1, colors=line_colors, alpha=0.7, linewidth=0.8)

    # Annotate the rank cutoffs.
    for label, x in cutoffs.items():
        if x <= top_n:  # only draw a cutoff if it lies within the top N range
            ax.axvline(x=x, color='black', linestyle='--', alpha=0.6)
            ax.text(
                x, 1.02, label, rotation=90, va='bottom', ha='center',
                fontsize=8, fontweight='bold'
            )

    ax.set_xlim(0, top_n)
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=11)
    ax.set_yticks([])

axes[-1].set_xlabel('Gene rank (by mean expression, highest to lowest)')
plt.tight_layout()
plt.savefig("cd4_marker_gene_rank_positions_top2000.svg", format='svg', bbox_inches='tight')
plt.show()


# Calculate the number of marker genes removed by each filtering cutoff.
#
# top1 / top5 are the genes in the highest-expression bin of the 25-bin and
# 50-bin quantile binning. The highest bin label is looked up instead of being
# hard-coded as 24 / 49, because `pd.qcut(..., duplicates='drop')` may return
# fewer bins than requested.
top1 = gene_df[gene_df['expr_bin'] == gene_df['expr_bin'].max()]
top5 = gene_df[gene_df['expr_bin50'] == gene_df['expr_bin50'].max()]
top50genes = gene_df[gene_df['rank'] <= 50]
top100genes = gene_df[gene_df['rank'] <= 100]

# Collect the counts and percentages in one table and print it, so that the
# numbers are also visible when the file is run as a plain script (bare
# expression statements are only displayed in an interactive notebook).
removed_marker_rows = []
for cutoff_name, top_genes in (
    ('top1', top1),
    ('top5', top5),
    ('top50genes', top50genes),
    ('top100genes', top100genes),
):
    for signature in ('CD4 CTL', 'CD4 Naive', 'Treg', 'CD4 Proliferating', 'CD4 TCM', 'CD4 TEM'):
        marker_col = f'is_marker_{signature}'
        n_removed = int(top_genes[marker_col].sum())
        n_markers = int(gene_df[marker_col].sum())
        removed_marker_rows.append({
            'cutoff': cutoff_name,
            'signature': signature,
            'n_removed': n_removed,
            'n_markers': n_markers,
            # Percentage of the signature's genes (present in gene_df) that fall
            # above the cutoff; NaN if the signature has no gene in gene_df.
            'pct_removed': round(n_removed / n_markers * 100, 1) if n_markers else float('nan'),
        })

removed_markers_df = pd.DataFrame(removed_marker_rows)
print(removed_markers_df.to_string(index=False))


# Prepare the filtered signatures: every signature with the genes above the
# respective expression cutoff removed.

def _signatures_without(top_genes, suffix):
    """Return every signature of `markers_dict` without the genes listed in `top_genes`.

    Parameters
    ----------
    top_genes : pandas.DataFrame
        Subset of `gene_df`; the genes in its 'gene' column are removed.
    suffix : str
        Appended to each signature name as ``f"{name}_{suffix}"``.

    Returns
    -------
    dict
        Filtered gene lists keyed by suffixed signature name; the original
        gene order within each signature is preserved.
    """
    # Build the lookup set once instead of re-filtering the DataFrame for every
    # gene. Every gene of a signature is flagged in that signature's
    # `is_marker_*` column, so membership in `top_genes['gene']` is the same
    # condition as membership in the per-signature subset of `top_genes`.
    removed = set(top_genes['gene'])
    # Fixed key order: it determines the order in which signatures are scored
    # and therefore the order of the resulting score columns.
    signature_order = ('CD4 CTL', 'CD4 Naive', 'Treg', 'CD4 Proliferating', 'CD4 TCM', 'CD4 TEM')
    return {
        f"{name}_{suffix}": [gene for gene in markers_dict[name] if gene not in removed]
        for name in signature_order
    }


sig1 = _signatures_without(top1, '1')
sig5 = _signatures_without(top5, '5')
sig50g = _signatures_without(top50genes, '50g')
sig100g = _signatures_without(top100genes, '100g')


#score cells with respective signatures:

import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, f1_score, jaccard_score
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import signaturescoring as ssc

sys.path.append('ANS_supplementary_information/data')

from constants import (DATASETS_WITH_ANNOTATIONS, BASE_PATH_RESULTS, SCORING_METHODS,
                       METHOD_WITH_GENE_POOL, VIOLIN_PLOT_CONFIG)

# Override the first entry of the imported method list with this ANS
# configuration, then keep only the first five entries.
# NOTE(review): presumably these five are ANS, Seurat, Seurat_AG, Seurat_LVG and
# Scanpy (the score names used further below) -- confirm against constants.py.
ans_config = {
    'scoring_method': 'adjusted_neighborhood_scoring',
    'sc_params': {
        'ctrl_size': 100,
        'score_name': 'ANS',
        'remove_genes_with_invalid_control_set': False,
    },
}
SCORING_METHODS[0] = ans_config
SCORING_METHODS = SCORING_METHODS[:5]

def score_signatures_with_all_methods(adata, signatures, use_gene_pool: bool, verbose=False):
    """Score every signature with every method listed in ``SCORING_METHODS``.

    For each (signature, method) pair ``ssc.score_signature`` is called with the
    method's ``sc_params``, where ``score_name`` is replaced by
    ``"<signature name with spaces as hyphens>_<score_name>_scores"``.

    Parameters
    ----------
    adata
        Data object handed to ``ssc.score_signature`` (presumably the scores are
        stored in ``adata.obs`` -- the rest of this script reads them from there).
    signatures : dict
        Signature name -> gene list.
    use_gene_pool : bool
        If true, a gene pool is computed once and passed as ``gene_pool`` to the
        methods listed in ``METHOD_WITH_GENE_POOL``.
    verbose : bool
        Print one line per (signature, method) pair.

    Returns
    -------
    tuple
        ``(added_cols, adata)`` where ``added_cols`` maps each method's original
        ``score_name`` to the list of score names used for it.
    """
    # NOTE(review): `_get_gene_pool` is neither defined nor imported in the
    # visible part of this file, so `use_gene_pool=True` would raise a
    # NameError unless it is provided elsewhere -- confirm.
    gene_pool = _get_gene_pool(adata, signatures) if use_gene_pool else None

    added_cols = defaultdict(list)
    for sig_name, sig in signatures.items():
        for method_cfg in SCORING_METHODS:
            method_name = method_cfg['scoring_method']
            # Copy so that the shared configuration is not modified.
            params = method_cfg['sc_params'].copy()
            if verbose:
                print(f"Scoring {sig_name} with {method_name}")

            base_score_name = params['score_name']
            column_prefix = sig_name.replace(' ', '-')
            params['score_name'] = f"{column_prefix}_{base_score_name}_scores"

            if use_gene_pool and method_name in METHOD_WITH_GENE_POOL:
                params['gene_pool'] = gene_pool

            ssc.score_signature(
                method=method_name,
                adata=adata,
                gene_list=sig,
                **params
            )
            added_cols[base_score_name].append(params['score_name'])

    return added_cols, adata

               

# Score the cells with the full signatures first, then with each filtered
# variant (gene pool disabled), in this fixed order.
for signature_set in (markers_dict, sig1, sig5, sig50g, sig100g):
    score_signatures_with_all_methods(types, signature_set, False)

# Export the cell annotations together with all score columns.
types.obs.to_csv('cd4_scores.csv')


# Score names to plot, in display order (Jasmine_LH, Jasmine_OR and UCell are
# not included).
style_order = [
    'ANS',
    'Seurat',
    'Seurat_AG',
    'Seurat_LVG',
    'Scanpy'
]

import matplotlib.pyplot as plt
import seaborn as sns

# All score columns present in the cell annotations.
score_cols = [col for col in types.obs.columns if col.endswith("_scores")]

# Long format: one row per (cell, score column).
df_long = types.obs.melt(
    id_vars=["celltype.l2"],
    value_vars=score_cols,
    var_name="score_type",
    value_name="score"
)

# Everything between the first underscore and "_scores", e.g. "ANS" for the
# full signatures or "1_ANS" for the "_1" signature variants.
df_long["tool"] = df_long["score_type"].str.extract(r'_(.*?)_scores')

# One figure per signature variant ("" = full signatures), one subplot per
# scoring method. The "100g_" variant is not plotted.
n_tools = len(style_order)
for variant_prefix in ("", "1_", "5_", "50g_"):
    fig, axes = plt.subplots(n_tools, 1, figsize=(10, 5 * n_tools), sharex=True)

    for ax, tool in zip(axes, [variant_prefix + x for x in style_order]):
        subset = df_long[df_long["tool"] == tool]
        sns.violinplot(
            data=subset,
            x="celltype.l2",
            y="score",
            hue="score_type",
            inner="box",
            cut=0,
            ax=ax,
        )
        ax.set_title(f"{tool} scores per group", fontsize=14)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()


# For each method and signature variant, label every cell with the group whose
# signature obtained the highest score.
methods = ["ANS", "Seurat", "Seurat_AG", "Seurat_LVG", "Scanpy"]
groups = types.obs['celltype.l2'].unique().tolist()
# Score column names use hyphens instead of spaces.
groups = [x.replace(" ", "-") for x in groups]

# "" = full signatures; the other prefixes select the filtered variants.
for variant_prefix in ("", "1_", "5_", "50g_", "100g_"):
    for method in [variant_prefix + x for x in methods]:
        candidate_cols = (f"{g}_{method}_scores" for g in groups)
        cols = [c for c in candidate_cols if c in types.obs.columns]

        # Name of the best-scoring column per cell ...
        best_col = types.obs[cols].idxmax(axis=1)
        # ... reduced to the group part of that column name.
        types.obs[f"{method}_max_group"] = best_col.str.extract(fr"^(.*?)_{method}_scores")[0]

# Calculate the balanced accuracy of the max-score labels against the
# annotated cell types (spaces replaced by hyphens to match the predictions).
true_labels = [x.replace(" ", "-") for x in types.obs["celltype.l2"]]

balanced_acc_results = {}

# Same order as the label assignment: full signatures first, then each variant.
for variant_prefix in ("", "1_", "5_", "50g_", "100g_"):
    for method in [variant_prefix + x for x in methods]:
        pred_labels = types.obs[f"{method}_max_group"]
        balanced_acc_results[method] = balanced_accuracy_score(true_labels, pred_labels)

# Convert to a DataFrame (one row per method / variant) for plotting and export.
balanced_acc_df = pd.DataFrame.from_dict(
    balanced_acc_results, orient='index', columns=["balanced_accuracy"]
)
balanced_acc_df

balanced_acc_df.to_csv('cd4_balanced_accuracy.csv')
