#!/usr/bin/env Rscript

library(data.table)
library(dplyr)
library(argparse)
library(car)
library(mediation)

# Parse Arguments
# The command line takes six mandatory string options (variable names and file
# paths), two boolean switches controlling the transformations, and one
# optional integer giving the number of simulations passed to mediate().
parser <- ArgumentParser(
  description = 'Run Causal Mediation Analysis with different covariates in mediator vs outcome models'
)

# Mandatory options, registered in the order they appear in the help output
required_options <- list(
  list(flag = "--dependent",   help = "Dependent variable name (Y)"),
  list(flag = "--independent", help = "Independent variable name (X)"),
  list(flag = "--mediator",    help = "Mediator variable name (M)"),
  list(flag = "--data_file",   help = "Path to the input data file (TSV)"),
  list(flag = "--cov_file",    help = "Path to the covariate file (TSV)"),
  list(flag = "--output_file", help = "Path to save the output results (prefix)")
)
for (opt in required_options) {
  parser$add_argument(opt$flag, required = TRUE, help = opt$help)
}

# Optional switches: both default to off
parser$add_argument(
  "--normalize",
  action = "store_true",
  default = FALSE,
  help = "Use inverse normal transform on X, M, Y"
)
parser$add_argument(
  "--center",
  action = "store_true",
  default = FALSE,
  help = "Center X, M, Y around the mean (after optional normalization)"
)

# Optional simulation count forwarded to mediate() as `sims`
parser$add_argument(
  "--boot_sims",
  type = "integer",
  default = 1000,
  help = "Number of bootstrap simulations for mediate() (default = 1000)"
)

args <- parser$parse_args()

# Unpack the parsed options into the script-level names used further down
dependent_var   <- args[["dependent"]]     # Y
independent_var <- args[["independent"]]   # X
mediator_name   <- args[["mediator"]]      # M
data_file       <- args[["data_file"]]
cov_file        <- args[["cov_file"]]
output_file     <- args[["output_file"]]
normalize_flag  <- args[["normalize"]]
center_flag     <- args[["center"]]
boot_sims       <- args[["boot_sims"]]

# Load and Merge Data
# Columns the data file must provide: the sample identifier plus Y, X and M.
# unique() collapses the list if the same name was passed for several roles.
required_cols <- unique(c("Sample", dependent_var, independent_var, mediator_name))

data_raw <- fread(data_file, header = TRUE, sep = "\t")

# Fail early with a readable message instead of a cryptic subsetting error
missing_cols <- setdiff(required_cols, names(data_raw))
if (length(missing_cols) > 0) {
  stop("Column(s) not found in ", data_file, ": ",
       paste(missing_cols, collapse = ", "), call. = FALSE)
}

# Keep only required columns
data_raw <- data_raw[, required_cols, with = FALSE]
# Convert numeric columns (values that cannot be parsed become NA and the
# corresponding rows are dropped by na.omit() below)
data_raw <- data_raw %>% mutate(across(-Sample, ~as.numeric(as.character(.))))

# Read covariate file
covariate_data <- fread(cov_file, header = TRUE, sep = "\t")
if (!("Sample" %in% names(covariate_data))) {
  stop("Column 'Sample' not found in ", cov_file, call. = FALSE)
}

# A covariate sharing its name with Y, X or M would be renamed with .x/.y
# suffixes by merge(), leaving the analysis variables unreachable by name.
clashing_cols <- intersect(setdiff(names(covariate_data), "Sample"), required_cols)
if (length(clashing_cols) > 0) {
  stop("Column(s) present in both ", data_file, " and ", cov_file, ": ",
       paste(clashing_cols, collapse = ", "), call. = FALSE)
}

# Merge and remove missing values
combined_data <- merge(data_raw, covariate_data, by = "Sample") %>% na.omit()

# Without any complete sample the model fits below cannot run
if (nrow(combined_data) == 0) {
  stop("No complete samples remain after merging ", data_file, " and ",
       cov_file, call. = FALSE)
}

# Variable Transformation Function
#
# Optionally applies a rank-based inverse normal transform and/or mean
# centering to a numeric vector.
#
# Arguments:
#   var       numeric vector to transform
#   normalize if TRUE, replace values by qnorm((rank - 0.5) / n), where n is
#             the number of non-missing values; ties are broken at random, so
#             tied inputs give results that depend on the RNG state
#   center    if TRUE, subtract the mean (applied after normalization)
#
# Returns a plain numeric vector of the same length as `var`. Missing values
# stay missing: they are neither ranked nor counted in n.
transform_variable <- function(var, normalize=FALSE, center=FALSE) {
  if (normalize) {
    # rank-based inverse normal; na.last = "keep" stops NAs from being ranked
    # last and turned into finite quantiles
    n_obs <- sum(!is.na(var))
    ranks <- rank(var, na.last = "keep", ties.method = "random")
    var <- qnorm((ranks - 0.5) / n_obs)
  }
  if (center) {
    var <- scale(var, center = TRUE, scale = FALSE)
  }
  # scale() returns a one-column matrix; flatten back to a vector
  as.numeric(var)
}

# Optionally transform Y, M and X (in that order), overwriting each column
# with its transformed version. With both flags off the values pass through
# transform_variable() unchanged apart from coercion to numeric.
for (target_col in c(dependent_var, mediator_name, independent_var)) {
  combined_data[[target_col]] <- transform_variable(
    combined_data[[target_col]],
    normalize = normalize_flag,
    center = center_flag
  )
}

# Covariates
# Every merged column that is neither the sample identifier nor one of the
# three analysis variables counts as a covariate.
exclude_vars <- c("Sample", dependent_var, independent_var, mediator_name)
merged_cols <- names(combined_data)
all_covariates <- unique(merged_cols[!(merged_cols %in% exclude_vars)])

# Subset of covariates whose names start with "PC"
pc_covariates <- all_covariates[grepl("^PC", all_covariates)]

# Remove Highly Collinear Covariates
#
# Iteratively prunes covariates until every variance inflation factor is below
# `thresh`. On each pass a dummy linear model (row index ~ covariates) is
# fitted, VIFs are computed with vif(), and the covariate with the largest VIF
# is dropped.
#
# Arguments:
#   data   data frame / data.table holding the covariate columns
#   vars   character vector of covariate column names
#   thresh VIF cut-off; pruning stops once max(VIF) < thresh
#
# Returns the character vector of retained covariate names. Inputs with fewer
# than two names are returned as they are, because VIF needs two predictors.
remove_collinear_vars <- function(data, vars, thresh=5) {
  # One formula term per name keeps term positions aligned with `vars`
  vars <- unique(vars)

  repeat {
    if (length(vars) < 2) break

    # Back-quote names so non-syntactic column names survive formula parsing
    rhs <- paste(paste0("`", vars, "`"), collapse = " + ")
    fml <- as.formula(paste("seq_len(nrow(data)) ~", rhs))
    vif_model <- lm(fml, data = data)

    # vif() aborts on aliased (perfectly collinear or constant) predictors;
    # drop the last aliased term and refit instead of failing.
    aliased <- is.na(coef(vif_model))
    if (any(aliased)) {
      term_idx <- attr(model.matrix(vif_model), "assign")[aliased]
      term_idx <- term_idx[term_idx > 0]
      if (length(term_idx) > 0) {
        drop_idx <- max(term_idx)
        cat("Removing due to perfect collinearity:", vars[drop_idx], "\n")
        vars <- vars[-drop_idx]
        next
      }
    }

    vif_values <- vif(vif_model)
    if (is.matrix(vif_values)) {
      # With multi-df terms vif() returns a GVIF matrix; the squared last
      # column, (GVIF^(1/(2*Df)))^2, is on the same scale as an ordinary VIF.
      vif_values <- vif_values[, ncol(vif_values)]^2
    }

    if (max(vif_values) < thresh) break

    # Select by position: vif() labels keep the back-quotes for non-syntactic
    # names, so they cannot be matched against `vars` by name.
    worst <- which.max(vif_values)
    cat("Removing due to high VIF:", vars[worst], "(VIF=", max(vif_values), ")\n")
    vars <- vars[-worst]
  }
  return(vars)
}

# Perform VIF check separately for each model: the mediator model starts from
# all covariates, the outcome model only from the "PC"-prefixed ones.
mediator_covs <- remove_collinear_vars(combined_data, all_covariates, thresh=5)
outcome_covs  <- remove_collinear_vars(combined_data, pc_covariates, thresh=5)

# Wrap a variable name in back-quotes so it can be pasted into a formula even
# when it is not a syntactic R name. Call it with exactly one name at a time:
# paste0() turns a zero-length input into the single string "``".
quote_var <- function(v) { paste0("`", v, "`") }

X_quoted <- quote_var(independent_var)
Y_quoted <- quote_var(dependent_var)
M_quoted <- quote_var(mediator_name)

# vapply() always yields a character vector, including character(0) when no
# covariate is left (sapply() would return an empty list in that case).
mediator_covs_quoted <- vapply(mediator_covs, quote_var, character(1), USE.NAMES = FALSE)
outcome_covs_quoted  <- vapply(outcome_covs,  quote_var, character(1), USE.NAMES = FALSE)

# Build Formulas for Mediator and Outcome Models
# Each right-hand side lists the fixed predictors first and then any
# covariates that were retained for that model; the terms are joined with
# " + " in a single step.

# Mediator model: M ~ X [+ mediator covariates]
mediator_terms <- X_quoted
if (length(mediator_covs_quoted) > 0) {
  mediator_terms <- c(mediator_terms, unlist(mediator_covs_quoted))
}
mediatorFormula <- paste0(M_quoted, " ~ ", paste(mediator_terms, collapse = " + "))

# Outcome model: Y ~ X + M [+ outcome covariates]
outcome_terms <- c(X_quoted, M_quoted)
if (length(outcome_covs_quoted) > 0) {
  outcome_terms <- c(outcome_terms, unlist(outcome_covs_quoted))
}
outcomeFormula <- paste0(Y_quoted, " ~ ", paste(outcome_terms, collapse = " + "))

# Echo both formulas to the console before fitting
cat("Mediator model:\n", mediatorFormula, "\n")
cat("Outcome model:\n", outcomeFormula, "\n\n")

# Fit Mediator and Outcome Models
# Both models are linear regressions fitted with lm() on the merged,
# NA-free data; the formulas are the strings assembled above, converted with
# as.formula().
#   model.m: mediator regressed on X plus the mediator-model covariates
#   model.y: outcome regressed on X, M and the outcome-model covariates
model.m <- lm(as.formula(mediatorFormula), data = combined_data)
model.y <- lm(as.formula(outcomeFormula),  data = combined_data)

# Run Causal Mediation Analysis
# treat / mediator are passed as the raw (un-back-quoted) variable names and
# sims takes the value of --boot_sims.
# NOTE(review): no `boot` argument is passed, so mediate() uses its package
# default estimation method, while the --boot_sims help text speaks of
# bootstrap simulations -- confirm whether boot = TRUE was intended.
# NOTE(review): no RNG seed is set in this script; if mediate() relies on
# random draws the estimates will presumably differ between runs -- confirm.
fit.med <- mediate(model.m, model.y, 
                   treat    = independent_var, 
                   mediator = mediator_name,
                   sims     = boot_sims)

# Output Results
# The report goes to "<output_file>_CausalMediation.txt" and is echoed to the
# console afterwards.
res_file <- paste0(output_file, "_CausalMediation.txt")

# Summarise once and reuse for both the file and the console
med_summary <- summary(fit.med)

# Redirect console output into the result file. print() is called explicitly
# so the summary is written even when the script is run through source(),
# where top-level values are not auto-printed. `finally` restores the console
# even if printing fails.
sink(res_file)
tryCatch({
  cat("===== Causal Mediation Analysis =====\n")
  cat("Mediator formula:\n", mediatorFormula, "\n")
  cat("Outcome formula:\n", outcomeFormula, "\n\n")
  print(med_summary)
}, finally = sink())

cat("\nSummary of mediate() result:\n")
print(med_summary)

cat("\nCausal mediation analysis result saved to:", res_file, "\n")
