# NOTE(review): machine-specific, home-relative path. The relative paths used
# later in this script are resolved against this directory.
setwd("~/Documents/DEGage_stuff/DEGage_testing_2/review_round_2")

### WRANGLING OUTPUTS

# Parameter grid that get_averages() iterates over when summarising trials.
# Dispersion values, matched against the `disp` column of the trial table.
disps <- c(0.1, 0.5, 1, 5,  10 )
# Effect sizes, matched against the `effect` column of the trial table.
effects <- c(1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5)
# Group sizes, paired by position: ngroup1[i] is compared with ngroup2[i],
# so both vectors must keep the same length and order.
ngroup1 <- c(100, 100, 100, 100, 50, 50, 50, 25, 25, 10)
ngroup2 <- c(100, 50,  25,  10,  50, 25, 10, 25, 10, 10)

# Scores one simulated differential-expression result file.
#
# Arguments:
#   f    - file name laid out as <label>_<disp>_<effect>_<g1>_<g2>_...<ext>,
#          where <ext> is four characters long (e.g. ".csv").
#   path - directory prefix pasted directly in front of `f`, so it must end
#          with a path separator.
#
# The last column of the file is used as the significance value; genes with a
# value <= 0.05 are called differentially expressed. Gene names come from the
# `gene` column when present, otherwise from `X`.
#
# Returns a one-row data.frame with columns Trial, DEGs, Sensitivity,
# Specificity, Precision, Accuracy, F1, g1, g2, disp and effect. A file that
# cannot be read (or has a single column) yields zeros for every metric; a
# warning is raised for unreadable files so the zero row is not mistaken for a
# real result.
read_trials <- function(f, path) {
  res <- tryCatch(
    read.csv(paste0(path, f)),
    error = function(e) {
      warning("Could not read ", path, f, ": ", conditionMessage(e),
              call. = FALSE)
      NA
    }
  )

  # Simulation parameters are encoded in the file name.
  labels <- strsplit(f, "_")[[1]]
  disp <- as.numeric(labels[2])
  effect <- as.numeric(labels[3])
  g1 <- as.numeric(labels[4])
  g2 <- as.numeric(labels[5])
  id <- substr(f, 1, nchar(f) - 4)

  # Defaults used when the file could not be scored.
  stats <- list(DEGs = 0, Sensitivity = 0, Specificity = 0,
                Precision = 0, Accuracy = 0, F1 = 0)

  if (is.data.frame(res) && length(res) > 1) {
    # Take the significance column BEFORE touching gene names: adding an `X`
    # column to a file that lacks one would otherwise make `X` the last column.
    pvals <- as.numeric(res[[ncol(res)]])
    genes <- if ("gene" %in% colnames(res)) res$gene else res$X

    # which() drops NA values, i.e. genes without a usable significance value
    # are never called differentially expressed.
    expDEs <- genes[which(pvals <= 0.05)]

    trueDEs <- paste("DEG", 1:1000, sep = "")
    trueEEs <- paste("GENE", 1001:10000, sep = "")

    tp <- sum(trueDEs %in% expDEs)
    fp <- sum(trueEEs %in% expDEs)
    tn <- length(trueEEs) - fp
    fn <- length(trueDEs) - tp

    stats <- list(
      DEGs = length(expDEs),
      Sensitivity = tp / (tp + fn),
      Specificity = tn / (tn + fp),
      # NaN when no gene is called at all (0 / 0).
      Precision = tp / (tp + fp),
      # Same gene universe as the counts above, so the value stays in [0, 1]
      # even when rows with missing significance values are present.
      Accuracy = (tp + tn) / (tp + tn + fp + fn),
      F1 = (2 * tp) / (2 * tp + fp + fn)
    )
  }

  data.frame("Trial" = id,
             DEGs = stats$DEGs,
             "Sensitivity" = stats$Sensitivity,
             "Specificity" = stats$Specificity,
             "Precision" = stats$Precision,
             "Accuracy" = stats$Accuracy,
             "F1" = stats$F1,
             g1 = g1,
             g2 = g2,
             disp = disp,
             effect = effect)
}

# Averages the per-trial metrics for every parameter combination.
#
# Argument:
#   x - per-trial table with columns Sensitivity, Specificity, Precision,
#       Accuracy, F1, g1, g2, disp and effect.
#
# Reads the script-level vectors `disps`, `effects`, `ngroup1` and `ngroup2`;
# `ngroup1[i]` is paired with `ngroup2[i]`.
#
# Returns one row per (disp, effect, group-size pair), ordered with the
# group-size pair varying fastest, then effect, then disp. A combination with
# no matching trials gets NaN metrics.
get_averages <- function(x) {
  # expand.grid varies its first factor fastest, which reproduces the order of
  # nested loops over disp (outer), effect (middle) and pair (inner).
  grid <- expand.grid(pos = seq_along(ngroup1),
                      effect = effects,
                      disp = disps,
                      KEEP.OUT.ATTRS = FALSE)

  rows <- lapply(seq_len(nrow(grid)), function(k) {
    g1 <- ngroup1[grid$pos[k]]
    g2 <- ngroup2[grid$pos[k]]
    disp <- grid$disp[k]
    effect <- grid$effect[k]
    temp <- x[x$disp == disp &
                x$effect == effect &
                x$g1 == g1 &
                x$g2 == g2, ]
    data.frame(
      Sensitivity = mean(temp$Sensitivity),
      Specificity = mean(temp$Specificity),
      Precision = mean(temp$Precision),
      Accuracy = mean(temp$Accuracy),
      F1 = mean(temp$F1),
      g1 = g1,
      g2 = g2,
      disp = disp,
      effect = effect
    )
  })

  # An empty parameter grid yields an empty table rather than NULL.
  if (length(rows) == 0) {
    return(data.frame())
  }
  # Bind once instead of growing the table inside the loop.
  do.call(rbind, rows)
}

# Adds plotting metadata to a table that has `g1` and `g2` columns.
#
# Returns `x` sorted by total sample size (g1 + g2) with two extra columns:
#   imbalance - factor "<g1> v <g2>", levels in order of first appearance
#               after sorting;
#   color     - hex colour assigned per imbalance level, taken in order from a
#               fixed 10-colour palette (NA beyond the tenth level).
# A zero-row input is returned with the two columns added and no rows.
get_colors <- function(x) {
  colors <- c("#208eb7", "#8de4d3", "#0a4f4e",
              "#a2e67c", "#638123", "#44f270",
              "#2aa63a", "#abc8f5", "#5858e1",
              "#615887")
  x$imbalance <- paste(x$g1, x$g2, sep = " v ")
  x <- x[order(x$g1 + x$g2), ]
  x$imbalance <- factor(x$imbalance, levels = unique(x$imbalance))
  # The integer code of each factor value is its level index, so a single
  # vectorised lookup assigns the i-th colour to the i-th level.
  x$color <- colors[as.integer(x$imbalance)]
  x
}

# Each entry: result directory, per-trial output CSV, averaged output CSV.
# Every method goes through the same steps: score each trial file, save the
# per-trial table, then average per parameter combination, attach plotting
# colours and save the averaged table.
trial_sets <- list(
  c("degage_imbalance_res_3/",
    "degage_non_averaged_imbalance_stats.csv",
    "degage_imbalance_stats_3.csv"),
  c("degage_imbalance_res_3_random/",
    "degage_random_non_averaged_imbalance_stats.csv",
    "degage_random_imbalance_stats_3.csv"),
  c("deseq2_imbalance_res_3/",
    "deseq2_non_averaged_imbalance_stats.csv",
    "deseq2_imbalance_stats_3.csv"),
  c("wilcoxon_imbalance_res_3/",
    "wilcoxon_non_averaged_imbalance_stats.csv",
    "wilcoxon_imbalance_stats_3.csv"),
  c("glm_imbalance_res_3/",
    "glm_non_averaged_imbalance_stats.csv",
    "glm_imbalance_stats_3.csv"),
  c("edger_imbalance_res_3/",
    "edger_non_averaged_imbalance_stats.csv",
    "edger_imbalance_stats_3.csv")
)

for (trial_set in trial_sets) {
  files <- list.files(trial_set[1])
  x <- lapply(X = files, FUN = read_trials, path = trial_set[1])
  x <- do.call(rbind, x)
  write.csv(x, trial_set[2])
  x <- get_averages(x)
  x <- get_colors(x)
  write.csv(x, trial_set[3])
}

## PLOTTING

library(latticeExtra)
library(gridExtra)
library(dplyr)
library(lattice)
library(tidyverse)
library(rstatix)
library(ggplot2)
library(ggpubr)
# Draws a 3D bar chart of F1 against effect size and sample-size imbalance for
# one dispersion value. To plot another metric, change `F1` in the formula and
# the `zlab` text.
#
# Arguments:
#   x     - averaged stats table with columns F1, effect, imbalance, disp and
#           color (one bar colour per row).
#   disp  - dispersion value selecting the rows of `x` to plot.
#   title - optional main title; defaults to "Dispersion = <disp>".
#
# Returns the trellis object produced by cloud().
plot_data <- function(x, disp, title) {
  panel_rows <- x[x$disp == disp, ]
  # `title` has no default in the signature, so an omitted argument falls back
  # to the dispersion label.
  main_title <- if (missing(title)) paste0("Dispersion = ", disp) else title
  p <- cloud(F1 ~ effect + imbalance, panel_rows,
             panel.3d.cloud = panel.3dbars,
             col.facet = panel_rows$color,
             xbase = 0.2, ybase = 0.2, scales = list(arrows = FALSE, col = 1),
             par.settings = list(axis.line = list(col = "transparent")),
             xlab = "Effect Size", ylab = "Sample Sizes", zlab = "F1",
             zlim = c(0, 1), pretty = TRUE,
             screen = list(z = 45, x = -70, y = 0),
             main = main_title)
  p
}


# Loads one averaged-stats CSV and prepares it for plot_data(): the first
# column is dropped (presumably the row names written by write.csv), and
# `imbalance` and `effect` become factors, with `imbalance` levels kept in file
# order.
load_averaged_stats <- function(csv_path) {
  stats <- read.csv(csv_path)
  stats <- stats[, -1]
  stats$imbalance <- factor(stats$imbalance, levels = unique(stats$imbalance))
  stats$effect <- factor(stats$effect)
  stats
}

# Draws one row of five panels, one per dispersion value, and returns the
# arranged grob.
dispersion_row <- function(stats) {
  panels <- lapply(c(0.1, 0.5, 1, 5, 10),
                   function(d) plot_data(stats, disp = d))
  do.call(grid.arrange, c(panels, list(ncol = 5)))
}

#degage - subsampling
degage_results <- load_averaged_stats("degage_imbalance_stats_3.csv")
degage <- dispersion_row(degage_results)

#degage - random assignment
degager_results <- load_averaged_stats("degage_random_imbalance_stats_3.csv")
degager <- dispersion_row(degager_results)

#deseq2
deseq2_results <- load_averaged_stats("deseq2_imbalance_stats_3.csv")
deseq2 <- dispersion_row(deseq2_results)

#edger
edger_results <- load_averaged_stats("edger_imbalance_stats_3.csv")
edger <- dispersion_row(edger_results)

#glm
glm_results <- load_averaged_stats("glm_imbalance_stats_3.csv")
glm <- dispersion_row(glm_results)

#wilcoxon
wilcoxon_results <- load_averaged_stats("wilcoxon_imbalance_stats_3.csv")
wilcoxon <- dispersion_row(wilcoxon_results)

# Stack the six method rows into one figure.
grid.arrange(degage, degager, deseq2, edger, glm, wilcoxon, ncol = 1)

#BOXPLOTS

library(ggplot2)
# Reload the per-trial (non-averaged) statistics written by the wrangling
# section. These assignments replace the averaged tables of the same names
# that were loaded for the 3D bar plots.
degage_results <- read.csv("degage_non_averaged_imbalance_stats.csv")
degager_results <- read.csv("degage_random_non_averaged_imbalance_stats.csv")
deseq2_results <- read.csv("deseq2_non_averaged_imbalance_stats.csv")
edger_results <- read.csv("edger_non_averaged_imbalance_stats.csv")
glm_results <- read.csv("glm_non_averaged_imbalance_stats.csv")
wilcoxon_results <- read.csv("wilcoxon_non_averaged_imbalance_stats.csv")

# Returns the sixth underscore-separated field of a trial id as a string
# (NA when the id has fewer than six fields). Only the first element of `X`
# is inspected.
get_replicate <- function(X) {
  fields <- strsplit(X, "_")[[1]]
  fields[6]
}

# Averages Sensitivity, Specificity and F1 within each (replicate, imbalance)
# cell of a per-trial table.
#
# Arguments:
#   x        - per-trial table with columns Trial, g1, g2, Sensitivity,
#              Specificity and F1; the replicate number is parsed from `Trial`
#              with get_replicate().
#   testname - label stored in the `test` column of every output row.
#
# Returns a data.frame with columns Sensitivity, Specificity, F1, imbalance,
# g1, g2 and test (in that order), one row per cell, with imbalance varying
# fastest within each replicate. A zero-row input yields a zero-row result
# with the same columns.
average_over_replicate <- function(x, testname) {
  # vapply pins the result type even when `x` has no rows.
  x$rep <- as.numeric(vapply(x$Trial, get_replicate, character(1)))
  x$imbalance <- paste(x$g1, x$g2, sep = " v ")

  # First factor varies fastest: imbalance is the inner loop, rep the outer.
  cells <- expand.grid(imbalance = unique(x$imbalance),
                       rep = unique(x$rep),
                       stringsAsFactors = FALSE,
                       KEEP.OUT.ATTRS = FALSE)

  rows <- lapply(seq_len(nrow(cells)), function(k) {
    i <- cells$imbalance[k]
    # Subset once per cell and reuse it for all three metrics.
    cell <- x[x$rep == cells$rep[k] & x$imbalance == i, ]
    sizes <- as.numeric(strsplit(i, " v ")[[1]])
    data.frame(
      Sensitivity = mean(cell$Sensitivity),
      Specificity = mean(cell$Specificity),
      F1 = mean(cell$F1),
      imbalance = i,
      g1 = sizes[1],
      g2 = sizes[2]
    )
  })

  if (length(rows) == 0) {
    outdf <- data.frame(Sensitivity = numeric(0),
                        Specificity = numeric(0),
                        F1 = numeric(0),
                        imbalance = character(0),
                        g1 = numeric(0),
                        g2 = numeric(0))
  } else {
    outdf <- do.call(rbind, rows)
  }
  # rep_len keeps the assignment valid for a zero-row table.
  outdf$test <- rep_len(testname, nrow(outdf))
  outdf
}

# Per-replicate averages for every method, stacked into one long table. The
# tables and their display labels are paired by position.
method_tables <- list(degage_results, degager_results, deseq2_results,
                      edger_results, glm_results, wilcoxon_results)
method_labels <- c("DEGage - Subsampling", "DEGage - Random Assignment",
                   "DESeq2", "edgeR", "glm.nb", "Wilcoxon")
boxdf <- do.call(
  rbind,
  Map(average_over_replicate, method_tables, method_labels)
)

# Display order of the sample-size comparisons on the x axis.
order <- c("100 v 100", "100 v 50", "50 v 50", "100 v 25",
           "50 v 25", "25 v 25", "100 v 10", "50 v 10",
           "25 v 10", "10 v 10")
total_size <- boxdf$g1 + boxdf$g2
boxdf <- boxdf[order(total_size), ]
boxdf$imbalance <- factor(boxdf$imbalance, levels = order)

# Builds a one-row figure with three boxplot panels (F1, Sensitivity,
# Specificity) across sample-size comparisons for a single method.
#
# Arguments:
#   boxdf - long table with columns test, imbalance, F1, Sensitivity and
#           Specificity.
#   test  - value of the `test` column selecting the method to plot.
#
# Returns the arranged figure from ggarrange().
make_boxplot <- function(boxdf, test) {
  # Everything except the y variable is shared by the three panels.
  panel_base <- ggplot(boxdf[boxdf$test == test, ], aes(x = imbalance)) +
    geom_boxplot(position = "identity") +
    theme_classic() +
    ylim(0, 1) +
    xlab("")

  ggarrange(
    panel_base + aes(y = F1),
    panel_base + aes(y = Sensitivity),
    panel_base + aes(y = Specificity),
    ncol = 3, nrow = 1, common.legend = TRUE
  )
}

# Rows for both DEGage variants (labels containing "DEGage").
degage.all <- boxdf[grep("DEGage", boxdf$test), ]

# Layers shared by the three DEGage panels: boxplots coloured by variant, a
# t-test p-value label, and rotated x tick labels. Only the y variable differs.
degage_panel_base <- ggplot(degage.all, aes(x = imbalance, color = test)) +
  geom_boxplot(fill = "white") +
  theme_classic() +
  stat_compare_means(method = "t.test", label = "p.format", size = 3,
                     label.x = 1.5, label.y = 1.05) +
  theme(axis.ticks.x = element_blank(),
        axis.text.x = element_text(angle = 45, hjust = 1),
        axis.title.x = element_blank(),
        plot.title = element_text(hjust = 0.5)) +
  ylim(0, 1.05) +
  scale_color_manual(name = "DEGage Version",
                     values = c("darkseagreen", "thistle4"))

f1 <- degage_panel_base + aes(y = F1)
sens <- degage_panel_base + aes(y = Sensitivity)
spec <- degage_panel_base + aes(y = Specificity)

degage <- ggarrange(f1, sens, spec, ncol = 3, common.legend = TRUE)

# Final figure: the DEGage comparison row on top, then one row per other
# method.
ggarrange(
  degage,
  make_boxplot(boxdf, "DESeq2"),
  make_boxplot(boxdf, "edgeR"),
  make_boxplot(boxdf, "glm.nb"),
  make_boxplot(boxdf, "Wilcoxon"),
  ncol = 1, nrow = 5
)


#Supplementary Table

# Mean and standard deviation of each metric per method and sample-size
# comparison. Columns are selected by name so the table does not depend on the
# column order of `boxdf`.
summary_cols <- c("test", "imbalance", "Sensitivity", "Specificity", "F1")
result <- boxdf[, summary_cols] %>%
  group_by(test, imbalance) %>%
  summarize(
    mean_Sensitivity = mean(Sensitivity, na.rm = TRUE),
    sd_Sensitivity = sd(Sensitivity, na.rm = TRUE),
    mean_Specificity = mean(Specificity, na.rm = TRUE),
    sd_Specificity = sd(Specificity, na.rm = TRUE),
    mean_F1 = mean(F1, na.rm = TRUE),
    sd_F1 = sd(F1, na.rm = TRUE),
    # Return an ungrouped table and silence the regrouping message.
    .groups = "drop"
  )


write.csv(data.frame(result), "imbalance_figure_supplementary_data.csv")












