# Performs enrichment analysis on the DESEQ output using the library FGSEA.
# 
# AUTHOR:	C.Barker and I. Kamzolas
# INPUT: 
#          gene set file.gmt ,          .GMT file containing terms from https://www.gsea-msigdb.org/gsea/msigdb/genesets.jsp?collection=H
#          DEGs.txt,                    DESEQ output with differentially expressed genes. 
#
# OUTPUT:  list of enriched terms   


# Packages attached up front. fgsea provides gmtPathways() and fgsea() used below.
library(fgsea)
# NOTE(review): no gprofiler2 function is called in the visible part of this script - confirm it is still needed.
library(gprofiler2)
# Directory with the differential-expression tables; the script later calls setwd(path),
# so the relative "./..._DEGs.txt" paths resolve against it.
path<-"~/phenotype_networks/data/differential_exp"
# Defines get_wgcna(), which is used further down to load the WGCNA modules.
source("~/phenotype_networks/scripts/get_wgcna.R") #get function to extract wgcna modules 
#' Run FGSEA on a differential-expression result table.
#'
#' Reads the table at `deseq_res_path`, maps the Ensembl gene IDs found in its
#' first (unnamed) column to gene symbols with org.Hs.eg.db, averages the
#' `stat` column per symbol and feeds those means to `fgsea()` as the ranking
#' statistic.
#'
#' @param deseq_res_path Path to a comma-separated file readable by
#'   `read_csv()`. It must contain a `stat` column and the gene IDs in an
#'   unnamed first column (named `...1` by readr) or in a column named `X1`.
#' @param pathways_to_enrich Gene sets handed to `fgsea()` as `pathways`
#'   (e.g. the list returned by `gmtPathways()`).
#' @return The table returned by `fgsea()`, one row per tested gene set.
get_enrich <- function(deseq_res_path, pathways_to_enrich) {
  # Attached here (as in the original) because the script does not load it at
  # the top. Note that attaching it masks dplyr::select, hence the explicit
  # dplyr:: prefix below.
  library(org.Hs.eg.db)

  res <- read_csv(deseq_res_path)

  # readr names an unnamed first column "...1"; copy it to X1. Assigning a
  # missing column would silently drop X1, so only copy when it exists and
  # fail loudly when no ID column is available at all.
  if ("...1" %in% names(res)) {
    res$X1 <- res[["...1"]]
  }
  if (!"X1" %in% names(res)) {
    stop("No gene ID column ('...1' or 'X1') found in ", deseq_res_path,
         call. = FALSE)
  }

  # Ensembl ID -> gene symbol lookup.
  ens2symbol <- AnnotationDbi::select(org.Hs.eg.db,
                                      keys = res$X1,
                                      columns = "SYMBOL",
                                      keytype = "ENSEMBL")
  ens2symbol <- as_tibble(ens2symbol)
  res <- inner_join(res, ens2symbol, by = c("X1" = "ENSEMBL"))

  # One test statistic per symbol: drop incomplete and duplicated
  # (SYMBOL, stat) pairs, then average whatever is left for each symbol.
  res2 <- res %>%
    dplyr::select(SYMBOL, stat) %>%
    na.omit() %>%
    distinct()
  res3 <- aggregate(res2$stat, by = list(SYMBOL = res2$SYMBOL), FUN = mean)

  # Named numeric vector (names = symbols) as expected by fgsea(stats = ).
  ranks <- deframe(res3)

  # NOTE(review): recent fgsea releases warn that passing `nperm` selects the
  # older fgseaSimple procedure; kept as-is so results do not change - confirm.
  fgsea(pathways = pathways_to_enrich, stats = ranks, nperm = 10000)
}
# ---- Enrichment of the basal DEGs against the GO pathway collection ----
# The relative "./..._DEGs.txt" paths used later in the script resolve
# against `path`.
setwd(path)

# tidyverse supplies read_csv/as_tibble/inner_join/deframe/%>% used inside
# get_enrich(); it must be attached before get_enrich() is first called.
library(tidyverse)
pathways.hallmark <- gmtPathways("~/phenotype_networks/data/gene_set_files/Human_GO_AllPathways_with_GO_iea_April_01_2020_symbol.gmt")
fgseaRes <- get_enrich("~/phenotype_networks/data/differential_exp/basal_DEGs.txt", pathways.hallmark)

# plot
# Keep gene sets with adjusted p-value below 0.1; which() drops NA padj
# regardless of whether the result is a data.table or a plain data.frame.
sig.fgsea <- fgseaRes[which(fgseaRes$padj < 0.1), ]
# Pathway names in the .gmt are "%"-separated; keep only the text before the
# first "%" as the display label.
sig.fgsea$pathway <- sub("%.*$", "", as.character(sig.fgsea$pathway))

# Rank by normalised enrichment score and keep the 20 highest. head() is
# used so that rows (never columns) are selected and no NA rows are padded
# in when fewer than 20 gene sets pass the filter.
sig.fsea <- sig.fgsea[order(sig.fgsea$NES, decreasing = TRUE), ]
sig.fsea <- head(sig.fsea, 20) # take top 20
sig.fsea$NES <- as.numeric(sig.fsea$NES)
# NOTE(review): isUnique() is not defined in this script; presumably it is
# provided by a package attached as a side effect of get_enrich() or by the
# sourced get_wgcna.R - confirm where it comes from.
sig.fsea <- sig.fsea[isUnique(sig.fsea$pathway), ]
library(ggplot2)
ggplot(sig.fsea, aes(reorder(pathway, NES), NES)) +
  geom_col(aes(fill = padj < 0.1)) +
  coord_flip() +
  labs(x = "Regulon", y = "Normalized Enrichment Score",
       title = "Regulons") +
  theme_minimal()


# ---- Enrichment of each DEG table against the WGCNA modules ----
wgcna.split <- get_wgcna(path_wgcna = "~/cell_shapes/genesprmodule/ALLgenesprmodule.tab",
                         path_correlations = "~/cell_shapes/genesprmodule/correlations.txt",
                         is.full = TRUE)
cluster_files <- c("./hetero_DEGs.txt", "./luminal_DEGs.txt", "./basal_DEGs.txt")

# Run the enrichment once per DEG table, collect the per-file results in a
# preallocated list and bind them once at the end (no rm()/exists() dance and
# no repeated rbind inside the loop).
results_list <- vector("list", length(cluster_files))
for (i in seq_along(cluster_files)) {
  file <- cluster_files[i]
  fgseaRes <- suppressMessages(get_enrich(file, wgcna.split))
  # Text before the first "_" of the path, e.g. "./hetero"; the x-axis limits
  # of the plot below rely on exactly this form.
  fgseaRes$type <- strsplit(file, "_")[[1]][1]
  # -log10 of the adjusted p-value, matching the column name.
  fgseaRes$minuslog10pvalue <- -log10(fgseaRes$padj)
  results_list[[i]] <- fgseaRes
}
results_all <- do.call(rbind, results_list)
results_all$cluster <- "all"

# Modules whose correlation (column V2) is stronger than 0.6 in absolute value
# are labelled separately. Per the original note, wgcna.split.shape (used in
# the commented-out line below) is not created here but by running
# crispr_validation.R.
correlations_table <- read.delim(file = "~/cell_shapes/genesprmodule/correlations.txt", header = FALSE) # ../../genesprmodule/correlations.txt
correlations_table <- correlations_table[correlations_table$V1 != 1, ]
correlations_table$V2 <- as.numeric(correlations_table$V2)
# which() drops rows whose correlation could not be parsed (NA), so no NA
# entries end up in the module index.
correlations <- correlations_table$V1[which(abs(correlations_table$V2) > 0.6)]
wgcna.split.correlation <- wgcna.split[correlations] # subset modules that are only correlated with a shape

results_all$cluster[results_all$pathway %in% names(wgcna.split.correlation)] <- "corelated"
#results_all$cluster[results_all$pathway %in% names(wgcna.split.shape)] <- "network"
results_all$sig <- results_all$padj < 0.01 # significant if under 0.01
conversions <- read.csv("~/cell_shapes/data/module_names.csv", header = TRUE)
new_results <- merge(x = conversions, y = results_all, by.x = "ME_names", by.y = "pathway")

# plot
# Shapes are assigned by name so that significant modules stay filled circles
# (19) and non-significant ones open squares (0), as in the original figure,
# while the legend now labels them correctly.
dots <- ggplot(data = results_all,
               aes(x = type, y = reorder(pathway, -minuslog10pvalue),
                   color = NES, size = minuslog10pvalue, shape = sig)) +
  geom_point() +
  scale_shape_manual(values = c("TRUE" = 19, "FALSE" = 0)) +
  scale_x_discrete(limits = c("./hetero", "./luminal", "./basal"))
dots +
  facet_grid(cluster ~ ., scales = "free", space = "free") +
  scale_color_continuous(type = "viridis") +
  theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)) +
  theme(text = element_text(size = 20))
