library(tidyverse)
library(sva)
library(DESeq2)
library(WGCNA)
library(magrittr)
library(genefilter)
library(reshape2)
#devtools::install_github("kevinblighe/CorLevelPlot")
#library(CorLevelPlot)
library(plyr)
library(corrplot)
# library(tidyr)
# library(dplyr)
# library(ggplot2)
# require(ggtree)
# library(devtools)
# #install_github("jokergoo/ComplexHeatmap")
# library(ComplexHeatmap)
# library(circlize)
# library(ape)
# library(stringr)
# library(Hmisc)
# library(phylogram)
# library(RColorBrewer)


################################
# CLUSTER PROFILER FOR NETWORK FIGURES
################################

# Gene Ontology over-representation analysis with clusterProfiler.
#
# Maps gene SYMBOLs to Entrez IDs with the org.Mm.eg.db OrgDb. Runs
# enrichGO() on all ontologies ("ALL", BH adjustment, pvalueCutoff 0.1,
# qvalueCutoff 0.05, readable gene names). Writes the result as a
# tab-separated table.
#
# Arguments:
#   results_DE    table whose row names are the SYMBOLs of the genes to test.
#   Alltableok    table whose row names are the background (universe) SYMBOLs.
#   table_CP_name output path of the enrichment table.
# Returns the enrichGO() result invisibly. The table is also written to disk.
# The previous version returned NULL.
# NOTE: org.Mm.eg.db and clusterProfiler (enrichGO) must be attached by the
# caller. They are not loaded in this file.
GOwithCP <- function(results_DE, Alltableok, table_CP_name) {

  # SYMBOL -> unique, non-missing ENTREZID.
  # AnnotationDbi:: is explicit because dplyr::select() (attached with
  # tidyverse) can mask the OrgDb select() method depending on load order.
  # Unmapped symbols come back as NA, and one symbol may map to several IDs.
  symbols_to_entrez <- function(symbols) {
    mapping <- AnnotationDbi::select(org.Mm.eg.db, keys = symbols,
                                     columns = "ENTREZID", keytype = "SYMBOL")
    unique(mapping$ENTREZID[!is.na(mapping$ENTREZID)])
  }

  # Background/universe genes and the genes to test
  universe_gene <- symbols_to_entrez(row.names(Alltableok))
  DE_gene <- symbols_to_entrez(row.names(results_DE))

  # Run enrichGO from clusterProfiler
  DE_tot_ego <- enrichGO(gene          = DE_gene,
                         universe      = universe_gene,
                         OrgDb         = org.Mm.eg.db,
                         ont           = "ALL",
                         pAdjustMethod = "BH",
                         pvalueCutoff  = 0.1,
                         qvalueCutoff  = 0.05,
                         readable      = TRUE)

  write.table(DE_tot_ego, file = table_CP_name,
              sep = "\t", quote = FALSE)

  invisible(DE_tot_ego)
}



################################
# HEATMAP FIGURE FOR THE PAPER
################################


# Heatmaps of per-species mean module eigengenes (paper figure).
#
# For each module eigengene column of MEs0, a one-way model
# eigengene ~ cond_season is fitted. The grey module and the
# *_familydeseq2 columns are excluded.
# The column annotation shows each model's adjusted R^2.
# Modules with an F-test p < 0.05 are marked:
#   "u*" - positive coefficient for the non-reference cond_season level;
#   "o*" - negative coefficient.
# A bar annotation shows the number of genes per module.
# Two PDFs are written:
#   *_byphyl.pdf - rows ordered by the species tree;
#   *_bycond.pdf - rows split by cond_season.
#
# Arguments:
#   dset        dataset label used in the output file names.
#   sptree      species tree. Its tip labels are used to index rows, so they
#               must match the sp_short_name values.
#   output_dir  prefix for the output file names ("" or missing = working dir).
#   hei         PDF height in inches (width is fixed at 5).
#   MEs0        data frame with "samples", the eigengene columns ("ME...")
#               and "sp_short_name" / "cond_season". Assumed: the eigengenes
#               occupy columns 2..(n_mod + 1), and every other column is one
#               of those three (all others are pivoted as numeric values).
#   genes       data frame with a "colors" column (module of each gene).
#   pc_df2      per-sample table with a "cond_season" column. Presumably its
#               rows are in the same order as MEs0 (TODO confirm at call site).
# Returns the two PDF paths invisibly.
# NOTE: needs ComplexHeatmap (Heatmap, HeatmapAnnotation, rowAnnotation,
# anno_simple, anno_barplot, draw) and circlize (colorRamp2) attached by the
# caller.
fheatmap_wgcna <- function(dset, sptree, output_dir, hei = 5, MEs0, genes,
                           pc_df2) {

  # output_dir used to be ignored, so older calls may omit it.
  if (missing(output_dir) || is.null(output_dir)) {
    output_dir <- ""
  }

  sporder <- sptree$tip.label

  # Grey = genes not clustered into any module. The *_familydeseq2 columns
  # are not eigengenes.
  MEs0 <- MEs0[, !names(MEs0) %in% c("MEgrey", "total_familydeseq2",
                                     "ancient_trans_familydeseq2",
                                     "recent_trans_familydeseq2")]
  genes <- genes[genes$colors != "grey", ]

  n_mod <- length(grep("ME", names(MEs0)))
  mod_cols <- seq_len(n_mod) + 1L
  mod_names <- names(MEs0)[mod_cols]

  # --- Association of each eigengene with cond_season: one fit per module ---
  season_fits <- lapply(mod_cols, function(x) {
    fit <- lm(MEs0[, x] ~ pc_df2$cond_season)
    coefs <- summary(fit)$coefficients
    list(
      p  = anova(fit)[["Pr(>F)"]][1],
      r2 = summary(fit)$adj.r.squared,
      # Row 2 is the non-reference cond_season level, whatever it is called.
      # The former hard-coded "pc_df2$cond_seasonxeric" row does not exist
      # when the levels are arid/mesic.
      up = if (nrow(coefs) >= 2) coefs[2, "Estimate"] > 0 else NA
    )
  })
  panovaXM <- setNames(vapply(season_fits, `[[`, numeric(1), "p"), mod_names)
  ranovaXM <- setNames(vapply(season_fits, `[[`, numeric(1), "r2"), mod_names)
  sanovaXM <- setNames(vapply(season_fits, `[[`, logical(1), "up"), mod_names)

  # --- Mean eigengene per species/condition, one column per module ---
  MEs0wide_means <- MEs0 %>%
    pivot_longer(cols = -c("samples", "sp_short_name", "cond_season"),
                 names_to = "module", values_to = "value") %>%
    as_tibble() %>%
    dplyr::select(-samples) %>%
    pivot_wider(
      names_from = c(module),
      values_from = value,
      values_fn = mean)

  # Cluster modules on their species means to order the heatmap columns.
  # The first two columns are the id columns.
  hc <- hclust(dist(t(MEs0wide_means[, -c(1, 2)])))
  col_idx <- hc$order + 2

  names(MEs0wide_means) <- sub("ME", "", names(MEs0wide_means))
  shortmodulename <- names(MEs0wide_means)[col_idx]
  MEs0wide_means <- data.frame(MEs0wide_means)
  rownames(MEs0wide_means) <- MEs0wide_means$sp_short_name

  # Number of genes per module, in heatmap column order
  nbgenes_order <- table(genes$colors)[shortmodulename]

  # Rows in reversed tree-tip order
  MEs0wide_means <- MEs0wide_means[rev(sporder), ]

  # --- Column annotation: adj. R^2 shading plus significance/direction ---
  # Every per-module vector must be reordered by hc$order, including the
  # direction flags (previously left unordered).
  is_sig <- panovaXM[hc$order] < 0.05
  pch <- unname(ifelse(sanovaXM[hc$order], "u*", "o*"))
  pch[is.na(is_sig) | !is_sig] <- NA

  rvalue_col_fun <- colorRamp2(c(0, 1), c("white", "black"))

  column_ha <- HeatmapAnnotation(
    cor = anno_simple(ranovaXM[hc$order], pch = pch, col = rvalue_col_fun),
    n_genes = anno_barplot(as.numeric(nbgenes_order))
  )
  row_ha <- rowAnnotation(
    status = MEs0wide_means$cond_season,
    col = list(status = c("arid" = "orange", "mesic" = "darkgreen"))
  )

  # Open the PDF, draw the heatmap (its promise is forced only here, i.e.
  # with the device open), and close the device even if drawing fails.
  write_heatmap_pdf <- function(file, ht) {
    pdf(file = file, height = hei, width = 5)
    on.exit(dev.off(), add = TRUE)
    draw(ht)
  }

  phyl_file <- paste0(output_dir, "WGCNA_heatmap_", dset, "_byphyl.pdf")
  write_heatmap_pdf(phyl_file, Heatmap(
    MEs0wide_means[, col_idx],
    show_row_dend = TRUE,
    show_column_dend = FALSE,
    name = "eigen genes",
    cluster_rows = as.dendrogram(sptree),
    top_annotation = column_ha
  ))

  cond_file <- paste0(output_dir, "WGCNA_heatmap_", dset, "_bycond.pdf")
  write_heatmap_pdf(cond_file, Heatmap(
    MEs0wide_means[, col_idx],
    show_row_dend = TRUE,
    show_column_dend = FALSE,
    name = "eigen genes",
    row_order = length(sporder):1,
    top_annotation = column_ha,
    right_annotation = row_ha,
    row_split = MEs0wide_means$cond_season
  ))

  invisible(c(byphyl = phyl_file, bycond = cond_file))
}




################################################################################################
# FIRST FUNCTION: PREPROCESSING UP TO CHOOSING THE WGCNA SOFT-THRESHOLD POWER
################################################################################################

# Preprocess counts and explore WGCNA soft-thresholding powers.
#
# Steps:
#   1. Subset the count table to the samples listed in coldata$ID.
#   2. Drop genes whose name contains "mt-", and genes with any missing or
#      zero count.
#   3. ComBat-seq batch correction with batch = var1_design (skipped for
#      dataset "murinae").
#   4. DESeq2 with design ~ var2_design.
#   5. PCA on log2 normalised counts (written to Table_PC.tsv).
#   6. Variance-stabilised expression restricted to genes above the median
#      variance.
#   7. pickSoftThreshold() over powers 1-10 and 12-20 (step 2).
# Plots: sample clustering tree, and scale-free fit / mean connectivity
# (JPEG).
#
# data_expr is a list with elements:
#   Alltable, coldata, var1_design, var2_design,
#   output_dir (used as a file-name prefix), dataset, nCPU (thread count
#   for WGCNA).
# coldata must provide the columns ID, ID_final, Batch_number, cond_season
# ("arid"/"mesic"), species, sp_short_name and Bio17.without.
# Returns the pickSoftThreshold() result invisibly.
optimize_param_FUN <- function(data_expr) {

  print("Load data")

  All_table <- data_expr$Alltable
  coldata <- data_expr$coldata
  var1_design <- data_expr$var1_design
  var2_design <- data_expr$var2_design
  output_dir <- data_expr$output_dir
  dataset <- data_expr$dataset
  nCPU <- data_expr$nCPU

  # First rows/columns of a matrix or data frame, without failing when it is
  # smaller than requested (a plain [1:5, 1:10] errors with < 10 samples).
  preview <- function(m, n_row = 5, n_col = 10) {
    m[seq_len(min(n_row, nrow(m))), seq_len(min(n_col, ncol(m))), drop = FALSE]
  }

  #################################################################################
  ## Count table preparation and filtering

  print("Count Table preparation")

  All_table <- All_table[, colnames(All_table) %in% coldata$ID]

  # Remove mitochondrial genes (Co1, Co2, Cytb, Nd1, Nd2, Nd5, Nd6)
  mt_genes <- grep(pattern = "mt-", rownames(All_table))
  if (length(mt_genes) > 0) {
    Alltableok <- All_table[-mt_genes, ]
  } else {
    Alltableok <- All_table
  }

  # Drop genes with a missing value, then keep only genes with a non-zero
  # count in every sample.
  Alltableok1 <- round(Alltableok[complete.cases(Alltableok), ])
  Alltableok1 <- Alltableok1[rowSums(Alltableok1 == 0) == 0, ]

  #################################################################################
  ## Sample metadata, reordered to follow the count-table columns

  coldata <- coldata[which(coldata$ID %in% colnames(Alltableok1)), ]
  coldata <- coldata[order(match(coldata$ID, colnames(Alltableok1))), ]

  coldata$Batch_number <- as.factor(coldata$Batch_number)
  coldata$cond_season <- factor(coldata$cond_season,
                                levels = c("arid", "mesic"))

  if (dataset != "murinae") {
    coldata[[var1_design]] <- factor(coldata[[var1_design]])
  }

  rownames(coldata) <- coldata$ID_final
  colnames(Alltableok1) <- coldata$ID_final

  #################################################################################
  ## Batch correction with ComBat-seq

  print("Run correction for all sets except Murinae")

  if (dataset != "murinae") {
    adjusted <- ComBat_seq(as.matrix(Alltableok1),
                           batch = coldata[[var1_design]],
                           group = coldata$cond_season)
  } else {
    adjusted <- Alltableok1
  }

  #################################################################################
  ## DESeq2

  print("Run DEseq2")

  design <- as.formula(paste("~ ", var2_design))

  ddsInput <- DESeqDataSetFromMatrix(countData = as.matrix(round(adjusted)),
                                     colData = coldata,
                                     design = design)

  dds <- DESeq(ddsInput)

  print("DE done")

  #################################################################################
  ## PCA after correction

  print("PCA figures by colors")

  norm <- log2(counts(dds, normalized = TRUE) + 1)
  pca <- prcomp(t(norm), scale. = TRUE)

  intgroup.df <- as.data.frame(coldata[, c("species", "ID_final", "sp_short_name",
                                           var1_design, var2_design,
                                           "Bio17.without"), drop = FALSE])

  pcs <- c(1, 2, 3, 4, 5)
  pc_df <- data.frame(PC1 = pca$x[, pcs[1]], PC2 = pca$x[, pcs[2]],
                      PC3 = pca$x[, pcs[3]], PC4 = pca$x[, pcs[4]],
                      PC5 = pca$x[, pcs[5]],
                      group = "cond_season", intgroup.df)

  write.table(pc_df, file = paste0(output_dir, "Table_PC.tsv"),
              quote = FALSE, row.names = TRUE, sep = "\t")

  print("PCA done")

  #################################################################################
  ## Keep the most variable half of the genes

  print("variance step")

  wpn_vsd <- getVarianceStabilizedData(dds)
  rv_wpn <- rowVars(wpn_vsd)
  print(summary(rv_wpn))

  q50_wpn <- quantile(rv_wpn, 0.50)
  expr_normalized <- wpn_vsd[rv_wpn > q50_wpn, ]

  print("table expr_normalized after variance filtering:")
  print(dim(expr_normalized))
  print(preview(expr_normalized))

  #################################################################################
  ## WGCNA: soft-threshold exploration

  print("Start WGCNA")

  input_mat <- t(expr_normalized)   # samples x genes
  print(preview(input_mat))

  sampleTree <- hclust(dist(input_mat), method = "average")

  jpeg(paste0(output_dir, "sampleTree_for_outliers.jpeg"),
       width = 10, height = 8, units = "in", res = 300)
  plot(sampleTree, main = "Sample clustering to detect outliers", sub = "",
       xlab = "", cex.lab = 1.5, cex.axis = 1.5, cex.main = 2)
  dev.off()

  # nCPU used to be read but ignored. NULL (no nCPU in data_expr) is the
  # default value of allowWGCNAThreads()'s nThreads argument.
  allowWGCNAThreads(nThreads = nCPU)

  powers <- c(1:10, seq(from = 12, to = 20, by = 2))

  sft <- pickSoftThreshold(
    input_mat,
    powerVector = powers,
    verbose = 5
  )

  jpeg(paste0(output_dir, "corrected_wgcna_r2_determination_q50.jpeg"),
       width = 10, height = 8, units = "in", res = 300)
  par(mfrow = c(1, 2))
  cex1 <- 0.9
  signed_r2 <- -sign(sft$fitIndices[, 3]) * sft$fitIndices[, 2]

  plot(sft$fitIndices[, 1], signed_r2,
       xlab = "Soft Threshold (power)",
       ylab = "Scale Free Topology Model Fit, signed R^2",
       main = paste("Scale independence"))
  text(sft$fitIndices[, 1], signed_r2,
       labels = powers, cex = cex1, col = "red")
  abline(h = 0.90, col = "red")

  plot(sft$fitIndices[, 1], sft$fitIndices[, 5],
       xlab = "Soft Threshold (power)",
       ylab = "Mean Connectivity",
       type = "n",
       main = paste("Mean connectivity"))
  text(sft$fitIndices[, 1], sft$fitIndices[, 5],
       labels = powers, cex = cex1, col = "red")
  dev.off()

  print("THE END")
  invisible(sft)
}





################################################################################################
# FULL WGCNA FUNCTION
################################################################################################


# Full WGCNA pipeline.
#
# Steps:
#   1. Same preprocessing as optimize_param_FUN: sample subset, "mt-" and
#      zero/NA gene filtering, ComBat-seq (except "murinae"), DESeq2, PCA,
#      median-variance gene filter.
#   2. Signed blockwise network with picked_power.
#   3. Module eigengenes and per-module boxplots.
#   4. Module-trait correlation heatmap with PC1-PC5 and Bio17.without.
#   5. For modules whose eigengene correlates with Bio17.without at p < 0.05,
#      the gene symbols are mapped to Entrez IDs.
#
# data_expr is a list with elements:
#   Alltable, coldata, var1_design, var2_design,
#   output_dir (file-name prefix), dataset, nCPU, picked_power,
#   sptree (only used by the disabled paper heatmap).
# Returns invisibly a list with:
#   network, MEs0 (eigengenes merged with metadata), module_df,
#   moduleTraitCor, moduleTraitPvalue, WG_sig_id (NULL when no module is
#   significant).
# NOTE: org.Mm.eg.db must be attached by the caller (not loaded in this file).
wgcna_all_FUN <- function(data_expr) {

  print("Load data")

  All_table <- data_expr$Alltable
  coldata <- data_expr$coldata
  var1_design <- data_expr$var1_design
  var2_design <- data_expr$var2_design
  output_dir <- data_expr$output_dir
  dataset <- data_expr$dataset
  nCPU <- data_expr$nCPU
  picked_power <- data_expr$picked_power
  sptree <- data_expr$sptree  # only used by the disabled paper heatmap below

  # First rows/columns of a matrix or data frame, without failing when it is
  # smaller than requested (a plain [1:5, 1:10] errors on small inputs).
  preview <- function(m, n_row = 5, n_col = 10) {
    m[seq_len(min(n_row, nrow(m))), seq_len(min(n_col, ncol(m))), drop = FALSE]
  }

  #################################################################################
  ## Count table preparation and filtering

  print("Count Table preparation")

  All_table <- All_table[, colnames(All_table) %in% coldata$ID]

  # Remove mitochondrial genes (Co1, Co2, Cytb, Nd1, Nd2, Nd5, Nd6)
  mt_genes <- grep(pattern = "mt-", rownames(All_table))
  if (length(mt_genes) > 0) {
    Alltableok <- All_table[-mt_genes, ]
  } else {
    Alltableok <- All_table
  }

  # Drop genes with a missing value, then keep only genes with a non-zero
  # count in every sample.
  Alltableok1 <- round(Alltableok[complete.cases(Alltableok), ])
  Alltableok1 <- Alltableok1[rowSums(Alltableok1 == 0) == 0, ]

  #################################################################################
  ## Sample metadata, reordered to follow the count-table columns

  coldata <- coldata[which(coldata$ID %in% colnames(Alltableok1)), ]
  coldata <- coldata[order(match(coldata$ID, colnames(Alltableok1))), ]

  coldata$Batch_number <- as.factor(coldata$Batch_number)
  coldata$cond_season <- factor(coldata$cond_season,
                                levels = c("arid", "mesic"))

  if (dataset != "murinae") {
    coldata[[var1_design]] <- factor(coldata[[var1_design]])
  }

  rownames(coldata) <- coldata$ID_final
  colnames(Alltableok1) <- coldata$ID_final

  #################################################################################
  ## Batch correction with ComBat-seq

  print("Run correction for all sets except Murinae")

  if (dataset != "murinae") {
    adjusted <- ComBat_seq(as.matrix(Alltableok1),
                           batch = coldata[[var1_design]],
                           group = coldata$cond_season)
  } else {
    adjusted <- Alltableok1
  }

  #################################################################################
  ## DESeq2

  print("Run DEseq2")

  design <- as.formula(paste("~ ", var2_design))

  ddsInput <- DESeqDataSetFromMatrix(countData = as.matrix(round(adjusted)),
                                     colData = coldata,
                                     design = design)

  dds <- DESeq(ddsInput)

  print("DE done")

  #################################################################################
  ## PCA after correction

  print("PCA")

  norm <- log2(counts(dds, normalized = TRUE) + 1)
  pca <- prcomp(t(norm), scale. = TRUE)

  intgroup.df <- as.data.frame(coldata[, c("species", "ID_final", "sp_short_name",
                                           var1_design, var2_design,
                                           "Bio17.without"), drop = FALSE])

  pcs <- c(1, 2, 3, 4, 5)
  pc_df <- data.frame(PC1 = pca$x[, pcs[1]], PC2 = pca$x[, pcs[2]],
                      PC3 = pca$x[, pcs[3]], PC4 = pca$x[, pcs[4]],
                      PC5 = pca$x[, pcs[5]],
                      group = "cond_season", intgroup.df)

  print("head(pc_df)")
  print(head(pc_df))
  print(colnames(pc_df))

  print("PCA done")

  #################################################################################
  ## Keep the most variable half of the genes

  print("variance step")

  wpn_vsd <- getVarianceStabilizedData(dds)
  rv_wpn <- rowVars(wpn_vsd)
  print(summary(rv_wpn))

  q50_wpn <- quantile(rv_wpn, 0.50)
  expr_normalized <- wpn_vsd[rv_wpn > q50_wpn, ]

  print("table expr_normalized after variance filtering:")
  print(dim(expr_normalized))
  print(preview(expr_normalized))

  #################################################################################
  ## WGCNA network construction

  print("Start WGCNA")

  input_mat <- t(expr_normalized)   # samples x genes
  print(preview(input_mat))

  # nCPU used to be read but ignored. NULL (no nCPU in data_expr) is the
  # default value of allowWGCNAThreads()'s nThreads argument.
  allowWGCNAThreads(nThreads = nCPU)

  # Kept from the original namespace-conflict workaround.
  # NOTE(review): rebinding `cor` locally does not change which cor()
  # functions inside the WGCNA namespace resolve to.
  temp_cor <- cor
  cor <- WGCNA::cor

  netwk <- blockwiseModules(input_mat,
                            # == Adjacency function ==
                            power = picked_power,
                            networkType = "signed",
                            # == Tree and block options ==
                            deepSplit = 2,
                            pamRespectsDendro = FALSE,
                            minModuleSize = 30,
                            maxBlockSize = 4000,
                            # == Module adjustments ==
                            reassignThreshold = 0,
                            mergeCutHeight = 0.25,
                            # == TOM: archive the run in "ER*" files ==
                            saveTOMs = TRUE,
                            saveTOMFileBase = "ER",
                            # == Output options ==
                            numericLabels = TRUE,
                            verbose = 3)

  cor <- temp_cor

  #################################################################################
  ## Modules

  print("Check modules")

  # Numeric module labels -> colour names
  mergedColors <- labels2colors(netwk$colors)

  print("module assignment to treatment groups")
  print(table(netwk$colors))

  module_df <- data.frame(
    gene_id = names(netwk$colors),
    colors = mergedColors
  )

  print("number of modules : ")
  print(length(unique(module_df$colors)))

  write_delim(module_df,
              file = paste0(output_dir, "Table_gene_modules_q50.txt"),
              delim = "\t")

  # Dendrogram of the first block with module colours underneath
  jpeg(paste0(output_dir, "plot_dendro_colorq50.jpeg"),
       width = 15, height = 15, units = "in", res = 300)
  plotDendroAndColors(
    netwk$dendrograms[[1]],
    mergedColors[netwk$blockGenes[[1]]],
    "Module colors",
    dendroLabels = FALSE,
    hang = 0.03,
    addGuide = TRUE,
    guideHang = 0.05)
  dev.off()

  #################################################################################
  ## Module eigengenes (MEs0)

  print("Prepare MEs0")

  MEs0 <- moduleEigengenes(input_mat, mergedColors)$eigengenes
  # Reorder modules so similar modules are next to each other
  MEs0 <- orderMEs(MEs0)
  # Eigengenes of real modules: grey collects the unassigned genes.
  # Selecting by name also avoids the old positional 2:n_mod slice, which
  # dropped a genuine module whenever no grey module existed.
  me_cols <- setdiff(names(MEs0), "MEgrey")

  MEs0$samples <- row.names(MEs0)
  # merge() sorts the rows by "samples"; downstream trait tables are aligned
  # on MEs0$samples explicitly.
  MEs0 <- merge(x = MEs0,
                y = coldata[, c("ID_final", "sp_short_name", "species",
                                var1_design, var2_design)],
                by.x = "samples", by.y = "ID_final")

  write.table(MEs0, file = paste0(output_dir, "Table_MEs0.tsv"),
              quote = FALSE, row.names = TRUE, sep = "\t")

  print("head MEs0")
  print(preview(MEs0, 10, 10))
  print(colnames(MEs0))

  #################################################################################
  ## Boxplots of every (non-grey) module eigengene

  df_melt <- melt(MEs0[, colnames(MEs0) != "MEgrey"],
                  measure.vars = me_cols, variable.name = "modules")

  print("Generate plot for all modules")

  if (dataset != "murinae") {
    x_var <- var1_design
    box_file <- "WGCNA_total_all_modules_size.pdf"
  } else {
    x_var <- "species"
    box_file <- "WGCNA_all_modules_resize.pdf"
  }

  box_all_mod <- ggplot(df_melt, aes(x = factor(.data[[x_var]]), y = value,
                                     fill = cond_season)) +
    geom_boxplot() +
    scale_fill_manual(values = c("#FFA500", "#5EA336")) +
    theme_bw() +
    facet_wrap(modules ~ .) +
    labs(x = x_var) +
    theme(axis.text = element_text(size = 25),   # axis labels
          axis.title = element_text(size = 20),  # axis titles
          text = element_text(size = 30))

  ggsave(box_all_mod, filename = paste0(output_dir, box_file),
         width = 18, height = 15, bg = "white")

  #################################################################################
  ## Pearson correlation of module eigengenes with PCs and Bio17.without

  print("Pearson correlation")

  trait_cols <- c("PC1", "PC2", "PC3", "PC4", "PC5", "Bio17.without")
  # pc_df follows coldata order while MEs0 was re-sorted by merge(): align
  # the traits to MEs0's samples so each pair comes from the same individual.
  trait_df <- pc_df[match(MEs0$samples, pc_df$ID_final), trait_cols,
                    drop = FALSE]
  nSamples <- nrow(MEs0)

  moduleTraitCor <- cor(MEs0[me_cols], trait_df, use = "p")

  print("table moduleTraitCor")
  print(head(moduleTraitCor))

  moduleTraitPvalue <- corPvalueStudent(moduleTraitCor, nSamples)
  print("table moduleTraitPvalue")
  print(head(moduleTraitPvalue))

  # Correlations with their p-values in brackets, for the heatmap cells
  textMatrix <- paste0(signif(moduleTraitCor, 2), "\n(",
                       signif(moduleTraitPvalue, 1), ")")
  dim(textMatrix) <- dim(moduleTraitCor)

  pdf(paste0(output_dir, "WGCNA_heatmap_correlation_not_for_figure.pdf"),
      width = 10, height = 10)
  labeledHeatmap(Matrix = moduleTraitCor,
                 xLabels = trait_cols,
                 yLabels = rownames(moduleTraitCor),
                 ySymbols = rownames(moduleTraitCor),
                 colorLabels = FALSE,
                 colors = blueWhiteRed(50),
                 textMatrix = textMatrix,
                 cex.text = 0.5,
                 zlim = c(-1, 1),
                 main = paste("Module-trait relationships"))
  dev.off()

  # #################################################################################
  # ## Generate heatmap for paper
  #
  # print("Do heatmap")
  # fheatmap_wgcna(dataset, sptree, output_dir, hei = 5, MEs0, genes, pc_df)

  #################################################################################
  ## Genes of modules significantly correlated with Bio17.without

  print("extract significant cluster genes")

  # moduleTraitPvalue is a matrix: index its column (`$` does not work here).
  list_mod_for_CP <- rownames(moduleTraitPvalue)[
    which(moduleTraitPvalue[, "Bio17.without"] < 0.05)]
  print(list_mod_for_CP)

  # Eigengene names carry an "ME" prefix; module_df stores bare colour names.
  sig_colors <- sub("^ME", "", list_mod_for_CP)
  WG_sig <- module_df$gene_id[module_df$colors %in% sig_colors]
  print(head(WG_sig))

  # AnnotationDbi:: is explicit because dplyr::select() may mask it.
  WG_sig_id <- if (length(WG_sig) > 0) {
    AnnotationDbi::select(org.Mm.eg.db, keys = WG_sig,
                          columns = "ENTREZID", keytype = "SYMBOL")
  } else {
    NULL
  }
  print(head(WG_sig_id))

  #################################################################################
  ## Run clusterProfiler (disabled)
  #
  # CP_name = paste0(output_dir, "Table_ClusterProfiler_BH_allGO.tsv")
  # GO_fam_season <- GOwithCP(results_DE = res_season_padj01_onl,
  #                           Alltableok = Alltableok,
  #                           table_CP_name = CP_name)
  # # Background: count table before removing NA lines (genes expressed in
  # # all individuals)

  print("THE END")

  invisible(list(
    network = netwk,
    MEs0 = MEs0,
    module_df = module_df,
    moduleTraitCor = moduleTraitCor,
    moduleTraitPvalue = moduleTraitPvalue,
    WG_sig_id = WG_sig_id
  ))
}
 

