source("src/tdd_analysis_load.R")

# Plot a caret tuning profile with a log10 y axis.
#
# Args:
#   x: object accepted by ggplot() (here a caret `train` fit).
#
# Returns:
#   A ggplot object: the base plot with a log10 y scale, log tick marks on
#   the left axis and the black-and-white theme.
tuneplot <- function(x) {
  base_plot <- ggplot(x)
  base_plot +
    scale_y_log10() +
    annotation_logticks(sides = "l") +
    theme_bw()
}

# Root mean squared error of a vector of residuals.
#
# Args:
#   error: numeric vector of residuals (observed - predicted).
#
# Returns:
#   sqrt(mean(error^2)); NA if `error` contains NA, NaN for an empty vector.
rmse <- function(error) {
  mean_squared_error <- mean(error^2)
  sqrt(mean_squared_error)
}

# Show the tuning results of one saved model-construction stage.
#
# Loads `<outdir>bt_train_<name>.Rdata` (written by fit_bt_model(), which
# saves `fit` and `grid`) into this function's environment, prints the best
# tuning parameters and then prints the tuning plot.
#
# Args:
#   name: stage name used in the file name (e.g. "r1").
#   outdir: directory prefix; concatenated as-is, so include the trailing "/".
#
# Returns:
#   The tuning plot, invisibly (the value of print()).
tuneplot_model_construction <- function(name, outdir) {
  rdata_path <- str_c(outdir, "bt_train_", name, ".Rdata")
  load(file = rdata_path)
  print(fit$bestTune)
  tuning_plot <- tuneplot(fit)
  print(tuning_plot)
}

# Fit an xgboost regression of `tdd_index` on all remaining columns with caret.
#
# Args:
#   data: data frame containing the response column `tdd_index`.
#   factor_rm: names of columns removed before fitting (e.g. identifiers).
#   transcript_selection: logical row filter passed to dplyr::filter(); one
#     entry per row of `data`.
#   grid: tuning grid passed to caret::train(tuneGrid = ).
#   outdir, name: `fit` and `grid` are saved to
#     `<outdir>bt_train_<name>.Rdata` (plain concatenation, so `outdir`
#     should end with "/").
#   cpus: number of xgboost threads (`nthread`).
#
# Returns:
#   The fitted caret `train` object; its best tuning parameters are printed.
fit_bt_model <- function(data, factor_rm, transcript_selection, grid, outdir,
                         name, cpus = 10) {
  # Use the `data` argument (previously the global `tdd_index` was used and
  # `data` was silently ignored).
  training_data <- data %>%
    dplyr::select(-all_of(factor_rm)) %>%
    dplyr::filter(transcript_selection)
  fit <- train(
    tdd_index ~ .,
    data = training_data,
    method = "xgbTree",
    # 3-fold CV run sequentially by caret; parallelism comes from xgboost's
    # own threads via `nthread`.
    trControl = trainControl(
      method = "cv",
      number = 3,
      allowParallel = FALSE
    ),
    tuneGrid = grid,
    verbose = TRUE,
    nthread = cpus
  )
  save(fit, grid, file = str_c(outdir, "bt_train_", name, ".Rdata"))
  print(fit$bestTune)
  fit
}

# Tune an xgboost model on the selected transcripts with a staged grid search.
#
# Each stage fixes the best values found by earlier stages and searches the
# next group of hyper-parameters. Every stage's fit is saved by
# fit_bt_model() to `<outdir>bt_train_<stage>.Rdata`, with stage names
# "base", "r1".."r5" and "final".
#
# Args:
#   tdd_index: data frame with the response column `tdd_index` and the
#     predictors.
#   factor_rm: names of columns dropped before modelling.
#   transcript_selection: logical row filter applied to `tdd_index`.
#   outdir: output directory, created if missing. File names are built by
#     plain concatenation, so it should end with "/".
#   cpus: number of xgboost threads used by every stage, including "final".
#
# Returns:
#   Named numeric vector with the in-sample RMSE of the baseline model
#   (`base`) and of the final tuned model (`final`).
model_construction <- function(tdd_index, factor_rm, transcript_selection,
                               outdir, cpus = 10) {
  if (!dir.exists(outdir)) {
    dir.create(outdir, recursive = TRUE)
  }

  # Run one tuning stage. All stages share the data, row selection, output
  # directory and thread count; only the grid and the stage name differ.
  fit_stage <- function(grid, name) {
    fit_bt_model(
      data = tdd_index,
      factor_rm = factor_rm,
      transcript_selection = transcript_selection,
      grid = grid,
      outdir = outdir,
      name = name,
      cpus = cpus
    )
  }

  # RMSE of `fit` on the rows it was trained on (in-sample error).
  training_rmse <- function(fit) {
    selected <- tdd_index %>%
      dplyr::filter(transcript_selection)
    predictors <- selected %>%
      dplyr::select(-all_of(factor_rm)) %>%
      dplyr::select(-tdd_index)
    observed <- selected %>%
      dplyr::pull(tdd_index)
    rmse(observed - predict(fit, predictors))
  }

  # base: xgboost default parameters, used as a reference.
  bt_train_base <- fit_stage(
    expand.grid(
      nrounds = 100,
      max_depth = 6,
      eta = 0.3,
      gamma = 0,
      colsample_bytree = 1,
      min_child_weight = 1,
      subsample = 1
    ),
    "base"
  )
  # r1: grid search for learning rate (eta), tree depth and nrounds.
  bt_train_r1 <- fit_stage(
    expand.grid(
      nrounds = seq(from = 200, to = 1000, by = 50),
      eta = c(0.025, 0.05, 0.1, 0.3, 0.4, 0.5, 0.6),
      max_depth = 1:20,
      gamma = 0,
      colsample_bytree = 1,
      min_child_weight = 1,
      subsample = 1
    ),
    "r1"
  )
  # r2: grid search for minimum child weight (eta and max depth fixed from
  # r1) over a wider nrounds range.
  bt_train_r2 <- fit_stage(
    expand.grid(
      nrounds = seq(from = 100, to = 3000, by = 100),
      eta = bt_train_r1$bestTune$eta,
      max_depth = bt_train_r1$bestTune$max_depth,
      gamma = 0,
      colsample_bytree = 1,
      min_child_weight = 1:10,
      subsample = 1
    ),
    "r2"
  )
  # r3: grid search for column and row sampling.
  bt_train_r3 <- fit_stage(
    expand.grid(
      nrounds = seq(from = 50, to = 3000, by = 50),
      eta = bt_train_r1$bestTune$eta,
      max_depth = bt_train_r2$bestTune$max_depth,
      gamma = 0,
      colsample_bytree = c(0.4, 0.6, 0.8, 1.0),
      min_child_weight = bt_train_r2$bestTune$min_child_weight,
      subsample = c(0.5, 0.75, 1.0)
    ),
    "r3"
  )
  # r4: grid search for gamma.
  bt_train_r4 <- fit_stage(
    expand.grid(
      nrounds = seq(from = 50, to = 3000, by = 50),
      eta = bt_train_r1$bestTune$eta,
      max_depth = bt_train_r2$bestTune$max_depth,
      gamma = c(0, 0.05, 0.1, 0.5, 0.7, 0.9, 1.0),
      colsample_bytree = bt_train_r3$bestTune$colsample_bytree,
      min_child_weight = bt_train_r2$bestTune$min_child_weight,
      subsample = bt_train_r3$bestTune$subsample
    ),
    "r4"
  )
  # r5: re-search a (smaller) learning rate together with nrounds.
  bt_train_r5 <- fit_stage(
    expand.grid(
      nrounds = seq(from = 100, to = 5000, by = 100),
      eta = seq(from = 0.01, to = 0.2, by = 0.01),
      max_depth = bt_train_r2$bestTune$max_depth,
      gamma = bt_train_r4$bestTune$gamma,
      colsample_bytree = bt_train_r3$bestTune$colsample_bytree,
      min_child_weight = bt_train_r2$bestTune$min_child_weight,
      subsample = bt_train_r3$bestTune$subsample
    ),
    "r5"
  )
  # final: refit with the single best parameter set from r5.
  bt_train_final <- fit_stage(
    expand.grid(
      nrounds = bt_train_r5$bestTune$nrounds,
      eta = bt_train_r5$bestTune$eta,
      max_depth = bt_train_r5$bestTune$max_depth,
      gamma = bt_train_r5$bestTune$gamma,
      colsample_bytree = bt_train_r5$bestTune$colsample_bytree,
      min_child_weight = bt_train_r5$bestTune$min_child_weight,
      subsample = bt_train_r5$bestTune$subsample
    ),
    "final"
  )

  c(
    base = training_rmse(bt_train_base),
    final = training_rmse(bt_train_final)
  )
}

# Explain the final tuned model saved by model_construction() with iml.
#
# Loads `<outdir>bt_train_final.Rdata` and computes, caches (as .Rdata in
# `outdir`) and plots:
#   - permutation feature importance (MSE and MAE loss),
#   - ALE feature effects, falling back to PDP when ALE fails,
#   - ICE curves, for the features ranked by the MSE importance,
#   - overall and per-feature two-way interaction strength.
# Cached files are reused when present.
# NOTE(review): caches written by older versions of this function were
# computed on all rows and columns, not on the training subset; delete
# them to recompute.
#
# Args:
#   tdd_index: data frame with the response column `tdd_index`.
#   factor_rm: columns dropped before modelling (same value as used for
#     model_construction()).
#   transcript_selection: logical row filter used when training.
#   outdir: directory prefix ending with "/".
#   cpus: number of PSOCK workers for the parallel interaction computation.
#
# Returns:
#   NULL, invisibly; called for its plots and cached files.
tdd_analysis <- function(tdd_index, factor_rm, transcript_selection, outdir,
                         cpus = 10) {
  # One worker pool for the whole analysis. It is always shut down on exit and
  # the sequential foreach backend restored, so no later %dopar% is sent to a
  # stopped cluster.
  cl <- makePSOCKcluster(cpus)
  registerDoParallel(cl)
  on.exit({
    stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)

  # Explain the model on the same rows and columns it was trained on (see
  # fit_bt_model): drop `factor_rm`, keep the `transcript_selection` rows.
  model_data <- tdd_index %>%
    dplyr::select(-all_of(factor_rm)) %>%
    dplyr::filter(transcript_selection)
  features <- model_data %>%
    dplyr::select(-tdd_index) %>%
    names()

  # Assigning NULL to a list element drops it, so a feature whose explanation
  # failed is absent from `results`; print the names of those features.
  print_missing <- function(results, expected) {
    for (missing_feature in setdiff(expected, names(results))) {
      print(missing_feature)
    }
  }

  # Loads `fit` (and `grid`) saved by fit_bt_model().
  load(file = str_c(outdir, "bt_train_final.Rdata"))
  bt_predictor <- Predictor$new(
    fit,
    data = model_data %>% dplyr::select(-tdd_index),
    y = model_data %>% dplyr::pull(tdd_index)
  )

  # Feature importance ----
  # Permutation feature importance measures the increase in the prediction
  # error of the model after permuting a feature's values, which breaks the
  # relationship between the feature and the true outcome.
  imp_path <- str_c(outdir, "bt_imp.Rdata")
  if (!file.exists(imp_path)) {
    bt_imp_mse <- FeatureImp$new(bt_predictor, loss = "mse")
    bt_imp_mae <- FeatureImp$new(bt_predictor, loss = "mae")
    save(bt_imp_mse, bt_imp_mae, file = imp_path)
  }
  load(file = imp_path)
  # Plots must be printed explicitly: there is no auto-printing in a function.
  print(plot(bt_imp_mae))
  print(plot(bt_imp_mse))

  # Feature effects ----
  # Accumulated local effects (ALE); when ALE fails for a feature, fall back to
  # partial dependence (the average prediction when every observation is
  # forced to take a given feature value). Features where both fail are left
  # out of `bt_ale`.
  ale_path <- str_c(outdir, "bt_ale.Rdata")
  if (!file.exists(ale_path)) {
    bt_ale <- list()
    for (feature in features) {
      # Handlers must be functions: tryCatch() evaluates its handler arguments
      # before running `expr`, so a bare expression would compute the PDP for
      # every feature and then fail when called as a handler.
      bt_ale[[feature]] <- tryCatch(
        FeatureEffect$new(bt_predictor, feature = feature, method = "ale"),
        error = function(e) {
          tryCatch(
            FeatureEffect$new(bt_predictor, feature = feature, method = "pdp"),
            error = function(e) NULL
          )
        }
      )
    }
    save(bt_ale, file = ale_path)
  }
  load(file = ale_path)

  print_missing(bt_ale, features)
  for (feature in names(bt_ale)) {
    print(feature)
    print(plot(bt_ale[[feature]]))
  }

  # Faceted effect curves for numeric features. The NA seed row is removed
  # by drop_na().
  bt_ale_numeric <- tibble(ale = NA, type = NA, value = NA, feature = NA)
  numeric_features <- model_data %>%
    keep(is.numeric) %>%
    dplyr::select(-tdd_index) %>%
    names()
  for (feature in intersect(numeric_features, names(bt_ale))) {
    tmp <- bt_ale[[feature]]$results
    names(tmp) <- c("ale", "type", "value")
    tmp$feature <- feature
    bt_ale_numeric <- rbind(bt_ale_numeric, tmp)
  }
  numeric_plot <- bt_ale_numeric %>%
    drop_na() %>%
    ggplot(aes(x = value, y = ale)) +
    geom_line() +
    geom_rug() +
    facet_wrap(~feature, scales = "free", ncol = 4)
  print(numeric_plot)

  # Faceted effect boxplots for factor features.
  bt_ale_fct <- tibble(ale = NA, type = NA, value = NA, feature = NA)
  factor_features <- model_data %>%
    keep(is.factor) %>%
    names()
  for (feature in intersect(factor_features, names(bt_ale))) {
    tmp <- bt_ale[[feature]]$results
    names(tmp) <- c("ale", "type", "value")
    # NOTE(review): PDP results appear to use a different column order than
    # ALE results, detected here by the "pdp" marker; confirm against the
    # installed iml version.
    if (isTRUE(tmp$value[1] == "pdp")) {
      names(tmp) <- c("type", "ale", "value")
    }
    tmp$feature <- feature
    bt_ale_fct <-
      bt_ale_fct %>% bind_rows(tmp %>% mutate(value = as.character(value)))
  }
  factor_plot <- bt_ale_fct %>%
    drop_na() %>%
    mutate(type = factor(type)) %>%
    ggplot(aes(x = type, y = ale)) +
    geom_boxplot() +
    facet_wrap(~feature, scales = "free", ncol = 3)
  print(factor_plot)

  # ICE plots ----
  # Individual Conditional Expectation (ICE) plots display one line per
  # instance, showing how that instance's prediction changes when a feature
  # changes.
  ice_path <- str_c(outdir, "bt_ice.Rdata")
  ice_features <- as.character(bt_imp_mse$results$feature)
  if (!file.exists(ice_path)) {
    bt_ice <- list()
    for (feature in ice_features) {
      # Any warning or error abandons the ICE curves for this feature.
      bt_ice[[feature]] <- tryCatch(
        FeatureEffect$new(bt_predictor, feature = feature, method = "pdp+ice"),
        warning = function(w) NULL,
        error = function(e) NULL
      )
    }
    save(bt_ice, file = ice_path)
  }
  load(file = ice_path)

  print_missing(bt_ice, ice_features)
  for (feature in names(bt_ice)) {
    print(feature)
    print(plot(bt_ice[[feature]]))
  }

  # Interactions ----
  interact_path <- str_c(outdir, "bt_interact.Rdata")
  if (!file.exists(interact_path)) {
    # Runs on the cluster registered at the top of this function.
    bt_interact <- Interaction$new(bt_predictor, parallel = TRUE)
    save(bt_interact, file = interact_path)
  }

  interact_2w_path <- str_c(outdir, "bt_interact_2w.Rdata")
  if (!file.exists(interact_2w_path)) {
    bt_interact_2w <- list()
    for (feature in features) {
      bt_interact_2w[[feature]] <- tryCatch(
        Interaction$new(bt_predictor, feature = feature),
        error = function(e) NULL
      )
      # Checkpoint after each feature. NOTE: a partial file from an interrupted
      # run is treated as complete on the next call; missing features are then
      # reported by print_missing() below.
      save(bt_interact_2w, file = interact_2w_path)
    }
    # Also covers an empty feature list, where the loop never saves.
    save(bt_interact_2w, file = interact_2w_path)
  }

  load(file = interact_path)
  load(file = interact_2w_path)
  print(plot(bt_interact))

  print_missing(bt_interact_2w, features)
  for (feature in names(bt_interact_2w)) {
    print(feature)
    print(plot(bt_interact_2w[[feature]]))
  }
  invisible(NULL)
}

# transcript_selection <- tdd_index$tdd_cell == "lympho"
# factor_rm <- c("transcript_id", "tdd_cell")
# outdir <- "results/bt_lympho/"

# model_construction(tdd_index, factor_rm, transcript_selection, outdir)
# tdd_analysis(tdd_index, factor_rm, transcript_selection, outdir)


# tuneplot_model_construction("r1", outdir)
# tuneplot_model_construction("r2", outdir)
# tuneplot_model_construction("r3", outdir)
# tuneplot_model_construction("r4", outdir)
# tuneplot_model_construction("r5", outdir)
# 
# transcript_selection <- tdd_index$tdd_cell == "macro"
# factor_rm <- c("transcript_id", "tdd_cell")
# outdir <- "results/bt_macro/"
# 
# model_construction(tdd_index, factor_rm, transcript_selection, outdir)
# 
# tuneplot_model_construction("r1", outdir)
# tuneplot_model_construction("r2", outdir)
# tuneplot_model_construction("r3", outdir)
# tuneplot_model_construction("r4", outdir)
# tuneplot_model_construction("r5", outdir)
# 
# transcript_selection <- tdd_index$tdd_cell == "lympho" & tdd_index$tdd_state == "activated"
# factor_rm <- c("transcript_id", "tdd_cell", "tdd_state")
# outdir <- "results/bt_lympho/activated/"
# 
# model_construction(tdd_index, factor_rm, transcript_selection, outdir)
# 
# tuneplot_model_construction("r1", outdir)
# tuneplot_model_construction("r2", outdir)
# tuneplot_model_construction("r3", outdir)
# tuneplot_model_construction("r4", outdir)
# tuneplot_model_construction("r5", outdir)
# 
# transcript_selection <- tdd_index$tdd_cell == "lympho" & tdd_index$tdd_state == "resting"
# factor_rm <- c("transcript_id", "tdd_cell", "tdd_state")
# outdir <- "results/bt_lympho/resting/"
# 
# model_construction(tdd_index, factor_rm, transcript_selection, outdir)
# 
# tuneplot_model_construction("r1", outdir)
# tuneplot_model_construction("r2", outdir)
# tuneplot_model_construction("r3", outdir)
# tuneplot_model_construction("r4", outdir)
# tuneplot_model_construction("r5", outdir)
# 
# transcript_selection <- tdd_index$tdd_cell == "macro" & tdd_index$tdd_state == "activated"
# factor_rm <- c("transcript_id", "tdd_cell", "tdd_state")
# outdir <- "results/bt_macro/activated/"
# 
# model_construction(tdd_index, factor_rm, transcript_selection, outdir)
# 
# tuneplot_model_construction("r1", outdir)
# tuneplot_model_construction("r2", outdir)
# tuneplot_model_construction("r3", outdir)
# tuneplot_model_construction("r4", outdir)
# tuneplot_model_construction("r5", outdir)
# 
# transcript_selection <- tdd_index$tdd_cell == "macro" & tdd_index$tdd_state == "resting"
# factor_rm <- c("transcript_id", "tdd_cell", "tdd_state")
# outdir <- "results/bt_macro/resting/"
# 
# model_construction(tdd_index, factor_rm, transcript_selection, outdir)
# 
# tuneplot_model_construction("r1", outdir)
# tuneplot_model_construction("r2", outdir)
# tuneplot_model_construction("r3", outdir)
# tuneplot_model_construction("r4", outdir)
# tuneplot_model_construction("r5", outdir)
