################################################################################
################################# random forest ################################
################################################################################
source("src/tdd_analysis_load.R")

# Register a parallel backend with PSOCK worker processes for the model-fitting
# steps below.
n_workers <- 16
cl <- makePSOCKcluster(n_workers)
registerDoParallel(cl)

# 10-fold cross-validation, used both to tune mtry and to estimate the
# performance of the final model.
fit_control <- trainControl(
  method = "cv",
  number = 10
)

# Candidate values of mtry (number of predictors sampled at each split).
fit_grid <- expand.grid(
  mtry = seq(from = 60, to = 200, by = 10)
)

load(file = "results/tdd_index.Rdata")
# caret's formula interface fails on missing values by default (na.fail), and
# the final model below is fit on complete cases. Use the same rows everywhere
# so tuning, final fit, evaluation and interpretation all see identical data.
tdd_index <- tdd_index %>%
  drop_na()

# Tune mtry over fit_grid with cross-validation.
# No scale/center arguments: randomForest() has no such parameters (they were
# silently swallowed by `...`), and tree splits do not depend on feature scale.
rf_mtry <- tdd_index %>%
  train(
    tdd_index ~ .,
    data = .,
    method = "rf",
    trControl = fit_control,
    tuneGrid = fit_grid,
    ntree = 1000,
    importance = TRUE
  )
save(
  rf_mtry,
  file = "results/rf_mtry.Rdata")
load(file = "results/rf_mtry.Rdata")
rf_mtry$finalModel$mtry
plot(rf_mtry)
# The randomForest object refit by caret with the selected mtry; plotting it
# shows error against the number of trees.
rf_mtry_final <- rf_mtry$finalModel
plot(rf_mtry_final)
# The selected mtry is read from rf_mtry$finalModel$mtry; it always lies within
# fit_grid (60..200), so do not hard-code it here.


# Final model: the tuned mtry with many more trees, still cross-validated for a
# performance estimate.
# mtry must be supplied through tuneGrid. caret's "rf" method already passes
# mtry to randomForest(), so also passing it via `...` fails with
# "formal argument 'mtry' matched by multiple actual arguments".
rf_fit <- tdd_index %>%
  drop_na() %>%
  train(
    tdd_index ~ .,
    data = .,
    method = "rf",
    trControl = fit_control,
    tuneGrid = data.frame(mtry = rf_mtry$finalModel$mtry),
    ntree = 5000,
    importance = TRUE
    )
print(rf_fit)
save(rf_fit,
     file = "results/rf_fit.Rdata")
load(file = "results/rf_fit.Rdata")


# compute model rmse

#' Root mean squared error of a vector of residuals.
#'
#' @param error Numeric vector of residuals (observed minus predicted).
#' @param na.rm Drop missing residuals before averaging? Defaults to FALSE,
#'   in which case any NA makes the result NA.
#' @return A single number; NaN for an empty vector.
rmse <- function(error, na.rm = FALSE) {
  sqrt(mean(error^2, na.rm = na.rm))
}

# In-sample RMSE on the complete-case rows the model was trained on.
# predict() on a caret model drops rows with missing predictors by default, so
# using raw tdd_index could misalign predictions and observations (R would then
# silently recycle the shorter vector). This estimate is optimistic; the
# cross-validated RMSE is reported in rf_fit$results.
eval_data <- tdd_index %>%
  drop_na()
tb_pred <- predict(
  rf_fit,
  eval_data %>%
    dplyr::select(-tdd_index)
  )
rmse(eval_data %>% dplyr::pull(tdd_index) - tb_pred)

################################################################################
# Wrap the fitted model for iml, using the same complete-case rows the model
# was trained on (drop_na() is a no-op if tdd_index is already complete).
tdd_index <- tdd_index %>%
  drop_na()
# y must be the target column itself: pull(-tdd_index) would count columns from
# the right and hand iml a predictor instead of the outcome.
rf_predictor <- Predictor$new(
  rf_fit,
  data = tdd_index %>%
    dplyr::select(-tdd_index),
  y = tdd_index %>%
    dplyr::pull(tdd_index)
  )
save(rf_predictor,
     file = "results/rf_predictor.Rdata")
save(tdd_index, rf_fit, rf_predictor,
     file = "results/rf_model.Rdata")


# compute Feature importance
# Permutation feature importance measures the increase in the model's
# prediction error after a feature's values are permuted, which breaks the
# relationship between that feature and the true outcome.

# Stop any cluster left over from the training step before starting a new one;
# otherwise its worker processes stay alive for the rest of the session.
if (exists("cl")) {
  try(parallel::stopCluster(cl), silent = TRUE)
}
cl <- makePSOCKcluster(16)
registerDoParallel(cl)

load(file = "results/rf_model.Rdata")
# NOTE(review): `parallel` is only accepted by older iml releases; newer ones
# parallelize via the future package instead. Verify against the installed iml.
rf_imp_mse <- FeatureImp$new(rf_predictor, loss = "mse", parallel = TRUE)
rf_imp_mae <- FeatureImp$new(rf_predictor, loss = "mae", parallel = TRUE)
save(rf_imp_mse, rf_imp_mae,
     file = "results/rf_imp.Rdata")
load(file = "results/rf_imp.Rdata")
plot(rf_imp_mae)
plot(rf_imp_mse)

# compute Feature effects
# Accumulated local effects (ALE) show how the model's prediction changes, on
# average, as one feature varies, accumulating small local changes around the
# observed data rather than forcing every row to the same value (as a partial
# dependence plot does).
load(file = "results/rf_model.Rdata")
# rf_imp_mse (from the importance step) supplies the features, in its order.
load(file = "results/rf_imp.Rdata")

rf_features <- as.character(rf_imp_mse$results$feature)
# One named slot per feature. A failed feature keeps a NULL slot so it can be
# reported below; assigning NULL with `[[<-` would delete the element instead.
rf_ale <- setNames(vector("list", length(rf_features)), rf_features)
for (feature in rf_features) {
  # Only errors mark a feature as failed; warnings no longer discard a result.
  rf_ale[feature] <- list(tryCatch(
    FeatureEffect$new(rf_predictor, feature = feature, method = "ale"),
    error = function(e) {
      message("ALE failed for ", feature, ": ", conditionMessage(e))
      NULL
    }
  ))
}

save(rf_ale,
     file = "results/rf_ale.Rdata")
load(file = "results/rf_ale.Rdata")

# features whose effect could not be computed
for (feature in names(rf_ale)) {
  if (is.null(rf_ale[[feature]])) {
    print(feature)
  }
}
for (feature in names(rf_ale)) {
  if (!is.null(rf_ale[[feature]])) {
    print(feature)
    print(plot(rf_ale[[feature]]))
  }
}

# ICE plot
# Individual Conditional Expectation (ICE) plots display one line per instance
# showing how that instance's prediction changes when a feature changes;
# method "pdp+ice" also draws their average (the partial dependence curve).

load(file = "results/rf_model.Rdata")
# rf_imp_mse (from the importance step) supplies the features, in its order.
load(file = "results/rf_imp.Rdata")

rf_features <- as.character(rf_imp_mse$results$feature)
# One named slot per feature. A failed feature keeps a NULL slot so it can be
# reported below; assigning NULL with `[[<-` would delete the element instead.
rf_ice <- setNames(vector("list", length(rf_features)), rf_features)
for (feature in rf_features) {
  # Only errors mark a feature as failed; warnings no longer discard a result.
  rf_ice[feature] <- list(tryCatch(
    FeatureEffect$new(rf_predictor, feature = feature, method = "pdp+ice"),
    error = function(e) {
      message("PDP/ICE failed for ", feature, ": ", conditionMessage(e))
      NULL
    }
  ))
}
save(rf_ice,
     file = "results/rf_ice.Rdata")
load(file = "results/rf_ice.Rdata")

# features whose effect could not be computed
for (feature in names(rf_ice)) {
  if (is.null(rf_ice[[feature]])) {
    print(feature)
  }
}
for (feature in names(rf_ice)) {
  if (!is.null(rf_ice[[feature]])) {
    print(feature)
    print(plot(rf_ice[[feature]]))
  }
}

# compute interaction of any feature with others
load(file = "results/rf_model.Rdata")
# rf_imp_mse (from the importance step) supplies the features, in its order.
load(file = "results/rf_imp.Rdata")

# Interaction strength of each feature with all other features.
# NOTE(review): `parallel` is only accepted by older iml releases; verify
# against the installed iml.
rf_interact <- Interaction$new(rf_predictor, parallel = TRUE)
save(rf_interact,
     file = "results/rf_interact.Rdata")

# Two-way interaction strength of each feature with every other feature.
rf_features <- as.character(rf_imp_mse$results$feature)
# One named slot per feature. A failed feature keeps a NULL slot so it can be
# reported below; assigning NULL with `[[<-` would delete the element instead.
rf_interact_2w <- setNames(vector("list", length(rf_features)), rf_features)
for (feature in rf_features) {
  # Only errors mark a feature as failed; warnings no longer discard a result.
  rf_interact_2w[feature] <- list(tryCatch(
    Interaction$new(rf_predictor, feature = feature, parallel = TRUE),
    error = function(e) {
      message("Interaction failed for ", feature, ": ", conditionMessage(e))
      NULL
    }
  ))
}
save(rf_interact_2w,
     file = "results/rf_interact_2w.Rdata")

load(file = "results/rf_interact.Rdata")
load(file = "results/rf_interact_2w.Rdata")
plot(rf_interact)

# features whose interactions could not be computed
for (feature in names(rf_interact_2w)) {
  if (is.null(rf_interact_2w[[feature]])) {
    print(feature)
  }
}
for (feature in names(rf_interact_2w)) {
  if (!is.null(rf_interact_2w[[feature]])) {
    print(feature)
    print(plot(rf_interact_2w[[feature]]))
  }
}

# feature effects for all features at once:
load(file = "results/rf_model.Rdata")

# NOTE(review): `parallel` is only accepted by older iml releases; verify
# against the installed iml.
rf_effs <- FeatureEffects$new(rf_predictor, parallel = TRUE)
save(rf_effs,
     file = "results/rf_effs.Rdata")
load(file = "results/rf_effs.Rdata")
plot(rf_effs)