#!/usr/bin/Rscript

# Load required packages
library(readxl)
library(utils)
library(ggplot2)
library(gprofiler2)

### Code to perform recalibration
#' Recalibrate log fold changes by the genetic expression variance V^G
#'
#' Divides the log fold change of every gene by the square root of that gene's
#' V^G estimate (plus an optional offset) and stores the result in a new
#' column `recalibratedFC`. Genes are matched between `df_de` and `vg` by
#' exact row name; genes without a V^G estimate get `NA`.
#'
#' @param df_de data.frame: contains differential expression per gene. Row names
#' are Ensembl gene ids and one column contains the log fold change.
#' @param vg data.frame or char: contains the V^G estimate per gene (rows) for
#' one or multiple tissues (columns). May also be the name of one of the
#' supported tables ("vg_h", "vg_hi", "vg_ae"), which is then looked up by
#' name. Default = "vg_h", the haplotype expression based V^G estimates
#' calculated from GTEx v8.
#' @param tissue char: V^G tissue that is recalibrated against. Default V^G are
#' generated for the GTEx tissues (in GTEx 6-letter code) or MEAN (weighted
#' harmonic mean across tissues). Default = "MEAN".
#' @param remove_NA bool: whether genes for which no V^G estimate exist should
#' be removed from the final data.frame. Default = FALSE.
#' @param sort_by char: sort result data.frame (ascending) by one particular
#' column. Ignored when it does not name a column of the result.
#' Default = NA
#' @param add_vg bool: adds the V^G estimates used for recalibration to the
#' result data.frame (column `vg`, without the offset). Default = FALSE.
#' @param variance_offset numeric: add an offset to all V^G estimates. Default
#' = 0.
#' @param FC_col_name char: column of df_de that contains the log fold change
#' values that are recalibrated. Default = "log2FoldChange".
#' @returns A modified data.frame with added recalibrated fold changes.
#' @examples
#' df <- data.frame(
#'   log2FoldChange = c(-2.95, 1.03, 4.34),
#'   padj = c(0, 1e-2, 1e-5),
#'   row.names = c("ENSG00000000003", "ENSG00000000419", "ENSG00000000457")
#' )
#'
#' recalibrateFoldChange(df)
#' recalibrateFoldChange(df, tissue = "NERVET", vg = vg_hi)
#' recalibrateFoldChange(df, sort_by = "padj", add_vg = TRUE)
#' @export
recalibrateFoldChange <- function(df_de, vg = "vg_h", tissue = "MEAN", remove_NA = FALSE,
                                  sort_by = NA, add_vg = FALSE, variance_offset = 0,
                                  FC_col_name = "log2FoldChange") {
  if (is.character(vg)) {
    # A string selects one of the supported V^G tables by name.
    if (length(vg) != 1L || !vg %in% c("vg_h", "vg_hi", "vg_ae")) {
      stop("Unknown VG. Currently only 'vg_h', 'vg_hi' & 'vg_ae' are supported")
    }
    # Resolved through the enclosing environments, e.g. a lazy-loaded dataset
    # of an attached package or an object defined by the calling script.
    vg <- get(vg)
  }

  if (length(tissue) != 1L || !is.element(tissue, colnames(vg))) {
    stop("Unknown tissue. You have to specify one GTEx tissue in 6-letter code or use 'MEAN'.")
  }

  # Look genes up with match() rather than vg[row.names(df_de), ]: character
  # row indexing of a data.frame partially matches row names, so an id that is
  # a unique prefix of another id would silently receive that gene's V^G.
  gene_idx <- match(row.names(df_de), row.names(vg))
  vg_select <- vg[gene_idx, tissue]
  sdg_select <- sqrt(vg_select + variance_offset)

  df_de$recalibratedFC <- df_de[, FC_col_name] / sdg_select

  if (add_vg) {
    df_de$vg <- vg_select
  }

  if (remove_NA) {
    df_de <- df_de[!is.na(df_de$recalibratedFC), , drop = FALSE]
  }

  # The default sort_by = NA never names a column, so no sorting happens.
  if (is.element(sort_by, colnames(df_de))) {
    df_de <- df_de[order(df_de[, sort_by]), , drop = FALSE]
  }

  df_de
}

### load V^G_H from supplementary table S3
# Read the V^G table from the Excel supplement and turn it into a numeric
# data.frame whose row names are the gene ids.
file_tblS3 <- "Supplemental_Table_S3.xlsx"
vgh_tbl <- read_excel(file_tblS3, sheet = "Supplemental Table S3")
vgh <- as.data.frame(vgh_tbl)
rownames(vgh) <- vgh$gene_id
vgh$gene_id <- NULL
# Coerce every remaining column to numeric. Cells that do not parse as a
# number become NA; the coercion warnings this produces are silenced on purpose.
vgh[] <- lapply(vgh, function(x) {
  suppressWarnings(as.numeric(as.character(x)))
})
# recalibrateFoldChange() resolves its default `vg = "vg_h"` via get("vg_h")
# and the background gene set is built from rownames(vg_h), so the table has
# to exist under that name; `vgh` is kept as well.
vg_h <- vgh

### Analysis example
# Fetch the example fold-change table into a throwaway file, read it with the
# gene_id column as row names, and remove the file again right away.
url <- "https://zenodo.org/records/839011/files/naive_vs_IFNg_DESeq2_fold_change.txt.gz"
temp_file_path <- tempfile(fileext = ".txt.gz")
download.file(url = url, destfile = temp_file_path, mode = "wb")
df <- read.table(file = temp_file_path, header = TRUE, row.names = "gene_id")
unlink(temp_file_path)

# Recalibrate with the default V^G table and tissue, keeping only the genes
# for which a recalibrated value could be computed.
df <- recalibrateFoldChange(df_de = df, remove_NA = TRUE)

# Nominal (x) against recalibrated (y) fold change, one point per gene.
ggplot(df, mapping = aes(x = log2FoldChange, y = recalibratedFC)) +
  geom_point(size = 1.2, alpha = 0.5) +
  labs(
    x = "nominal fold change (log FC)",
    y = bquote("recalibrated fold change (log FC /" ~ sqrt(V^G) ~ ")")
  ) +
  theme_bw()

### GO term enrichment comparison
# Upper bound on the number of top-ranked genes passed to each enrichment query.
top_xgenes <- 2000

# select a background set: genes present in both the DE table and the V^G table
bg_genelist <- intersect(rownames(df), rownames(vg_h))

# only select significant genes
df_sig <- subset(df, padj < 0.05)

# select top genes by nominal and recalibrated FC (largest absolute value first)
nfc_order <- order(abs(df_sig$log2FoldChange), decreasing = TRUE)
rfc_order <- order(abs(df_sig$recalibratedFC), decreasing = TRUE)
# head() caps the selection at the number of significant genes. Indexing with
# 1:top_xgenes instead would pad both gene lists with NA whenever fewer than
# top_xgenes genes are significant.
nfc_genes <- rownames(df_sig)[head(nfc_order, top_xgenes)]
rfc_genes <- rownames(df_sig)[head(rfc_order, top_xgenes)]

# perform GO enrichment
# Runs a g:Profiler GO enrichment (g:SCS correction, threshold 0.05) for one
# gene list against a custom background.
#   gene.list: gene ids to test.
#   bglist:    gene ids forming the custom background.
# Returns a data.frame with columns `p_value` and `term_name`, with the term
# names as row names. It has zero rows when no term is significant.
enrichment <- function(gene.list, bglist) {
  gostres <- gost(
    query = gene.list, organism = "hsapiens", ordered_query = FALSE,
    multi_query = FALSE, significant = TRUE, exclude_iea = FALSE,
    measure_underrepresentation = FALSE, evcodes = FALSE,
    user_threshold = 0.05, correction_method = "g_SCS",
    domain_scope = "custom", custom_bg = bglist,
    numeric_ns = "", sources = "GO", as_short_link = FALSE
  )

  # gost() yields NULL instead of an empty table when nothing is significant;
  # indexing that with [, cols] would fail, so hand back an empty frame with
  # the same columns.
  if (is.null(gostres) || is.null(gostres$result) || nrow(gostres$result) == 0) {
    return(data.frame(
      p_value = numeric(0), term_name = character(0),
      stringsAsFactors = FALSE
    ))
  }

  # Select by column name rather than by position (3 and 11) so the result
  # does not depend on the column layout of the gost() table.
  gores <- as.data.frame(gostres$result[, c("p_value", "term_name")])

  # Term names become row names, which must be unique. Should a name occur
  # more than once, keep its most significant row; the order of the remaining
  # rows is left untouched.
  by_significance <- order(gores$p_value)
  keep <- by_significance[!duplicated(gores$term_name[by_significance])]
  gores <- gores[sort(keep), , drop = FALSE]

  rownames(gores) <- gores$term_name
  gores
}
gores <- enrichment(nfc_genes, bg_genelist)
rfc_gores <- enrichment(rfc_genes, bg_genelist)

# join enrichments
# Outer join of both results on the term name, done with match() instead of
# growing the data.frame one row at a time. Terms of the nominal result come
# first, followed by the terms found only with recalibrated fold changes.
# This also works when either result has zero rows.
all_terms <- union(row.names(gores), row.names(rfc_gores))
nfc_idx <- match(all_terms, row.names(gores))
rfc_idx <- match(all_terms, row.names(rfc_gores))
p_nfc <- gores$p_value[nfc_idx]
p_rfc <- rfc_gores$p_value[rfc_idx]
p_nfc[is.na(nfc_idx)] <- 1 # impute all non-hits from nfc
p_rfc[is.na(rfc_idx)] <- 1 # impute all non-hits from rfc
gores <- data.frame(
  p_value = p_nfc, term_name = all_terms, p_value_rfc = p_rfc,
  row.names = all_terms, stringsAsFactors = FALSE
)

# match GO terms based on strings
# Every term is labelled with the keywords its name contains (literal
# substring match), joined by " & " in the order of termAssociations; terms
# containing none of them stay "other".
termAssociations <- c("regulation", "response")
gores$association <- rep("other", nrow(gores))
for (keyword in termAssociations) {
  has_keyword <- grepl(keyword, row.names(gores), fixed = TRUE)
  # Both masks are computed before any label changes, so a term is either
  # labelled for the first time or extended, never both in one pass.
  is_first <- has_keyword & gores$association == "other"
  is_additional <- has_keyword & !is_first
  gores$association[is_additional] <- paste(gores$association[is_additional],
                                            keyword, sep = " & ")
  gores$association[is_first] <- keyword
}

# plot GO term associations by method of gene selection
# One fixed colour per association label produced above.
plot.colors <- c(
  "regulation" = "#0000FF",
  "regulation & response" = "#990099",
  "response" = "#FF0000",
  "other" = "#999999"
)

# Each point is a GO term: -log10 p-value from the nominal selection on x and
# from the recalibrated selection on y. The dashed identity line marks terms
# that are equally enriched in both.
ggplot(
  data = gores,
  mapping = aes(x = -log10(p_value), y = -log10(p_value_rfc), colour = association)
) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0, linetype = 2) +
  scale_x_continuous(name = "selected by nominal FC" ~ -log[10] ~ "p-value") +
  scale_y_continuous(name = "selected by recalibrated FC" ~ -log[10] ~ "p-value") +
  scale_color_manual(name = "GO Term association", values = plot.colors) +
  labs(title = paste("GO enrichment of the top", top_xgenes, "genes")) +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", hjust = 0.5))