import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import linregress

# Load the table of paired variant allele frequencies. The statistics below
# read its 'WES_VAFs' and 'scRNA_VAFs' columns.
input_file = '/scratch/ucgd/lustre-work/marth/u0829237/btk/for_paper/VAF_comparison/vafs_for_plot.csv'
data_for_plot = pd.read_csv(input_file)

# Restrict the statistics to rows where both VAFs are present. Series.corr
# already excludes incomplete pairs, but linregress is given the raw values,
# so without this step the printed r and the R^2 / p-value could be computed
# from different sets of rows (or come out as NaN). data_for_plot itself is
# left unmodified.
complete_pairs = data_for_plot[['WES_VAFs', 'scRNA_VAFs']].dropna()
wes_vafs = complete_pairs['WES_VAFs']
scrna_vafs = complete_pairs['scRNA_VAFs']

# Pearson correlation coefficient between the two measurements.
correlation_coefficient = wes_vafs.corr(scrna_vafs)

# Fit the least-squares line once (x = WES, y = scRNA) and take the p-value
# from that same fit instead of running the regression a second time.
slope, intercept, r_value, p_value, std_err = linregress(wes_vafs, scrna_vafs)

# Print the results
print(f"Correlation Coefficient (r): {correlation_coefficient:.2f}")
print(f"R-squared (R^2): {r_value**2:.2f}")
print(f"P-value: {p_value:.2e}")

# Bivariate plot: one point per row coloured by 'Patient', with a single
# regression line fitted across all rows drawn on top in red.
vaf_columns = {'x': 'WES_VAFs', 'y': 'scRNA_VAFs'}
figure, axes = plt.subplots(figsize=(10, 8.5))
sns.scatterplot(
    data=data_for_plot,
    hue='Patient',
    palette='tab10',
    alpha=0.85,
    s=60,
    ax=axes,
    **vaf_columns,
)
sns.regplot(data=data_for_plot, scatter=False, color='red', ax=axes, **vaf_columns)

# The summary statistics are shown on a second line of the x-axis label.
x_axis_label = f'Whole-Exome Sequencing VAF\nCorrelation Coefficient: {correlation_coefficient:.2f}, R-squared: {r_value**2:.2f}'
axes.set_xlabel(x_axis_label, size=18)
axes.set_ylabel('MAS-Seq VAF', size=18)

# Enlarge the tick labels on both axes and switch the grid on.
plt.xticks(size=17)
plt.yticks(size=17)
axes.grid(True)

figure.savefig('/scratch/ucgd/lustre-work/marth/u0829237/btk/for_paper/VAF_comparison/VAFs_bivariate_plot.png')