# plot weighted codon usage
library(tidyverse)
library(ggplot2)
library(parallel)

# Worker count for parallel::mclapply(), which takes its default mc.cores
# from this option (mclapply is used by pond_freq_on_random_list()).
options(mc.cores = 8)
# functions ---------------------------------------------------------------

#' Keep the rows of a codon-count table that belong to the given genes
#'
#' @param x vector of Gene.stable.ID values to keep.
#' @param database table with a Gene.stable.ID column (defaults to the
#'   global `df_codon_count`).
#' @return the filtered table.
select_genes <- function(x, database = df_codon_count) {
  filter(database, Gene.stable.ID %in% x)
}

#' Pooled codon frequencies over a set of genes
#'
#' Sums each sense-codon column over the selected rows and divides by the
#' grand total, so the result sums to 1.
#'
#' @param data table with one count column per codon (non-codon columns are
#'   ignored) and, when `universe` is a gene list, a Gene.stable.ID column.
#' @param universe `"transcriptome"` (default) to pool every row of `data`,
#'   or a vector of Gene.stable.ID values to pool only those genes.
#' @return named numeric vector of codon frequencies; NaN everywhere when no
#'   row is selected (e.g. an empty or non-matching gene list).
calc_freq_codon_in_transcriptome <- function(data, universe = "transcriptome") {
  # The 61 sense codons: all 64 triplets minus the stop codons. Defined here
  # for both branches (previously only the "transcriptome" branch had it).
  bases <- c("A", "C", "G", "T")
  all_triplets <- do.call(paste0, expand.grid(bases, bases, bases, stringsAsFactors = FALSE))
  codon <- setdiff(all_triplets, c("TAA", "TAG", "TGA"))

  # Only the exact sentinel selects the whole table; an empty gene list must
  # not silently fall back to the whole transcriptome.
  use_all_genes <- length(universe) == 1L && isTRUE(universe == "transcriptome")
  keep_rows <- if (use_all_genes) {
    rep(TRUE, nrow(data))
  } else {
    data$Gene.stable.ID %in% universe
  }
  counts <- data[keep_rows, colnames(data) %in% codon, drop = FALSE]

  colSums(counts) / sum(counts)
}

#' Codon demand of one gene (apply() helper)
#'
#' @param x numeric vector: the codon counts followed by the gene's
#'   expression value as the LAST element.
#' @return the codon counts multiplied by that expression value.
calc_codon_demand_in_apply <- function(x) {
  # x[-length(x)] drops the last element explicitly; the former
  # x[1:length(x)-1] parsed as x[0:(n-1)] and only worked because index 0
  # is silently ignored.
  x[-length(x)] * x[length(x)]
}

#' Within-gene codon frequencies (apply() helper)
#'
#' @param x numeric vector of codon counts for one gene.
#' @return `x` divided by its total (NaN when the total is 0); an empty
#'   input gives an empty result. The former x[1:length(x)] indexed
#'   x[c(1, 0)] on empty input and returned a spurious NA.
calc_pond_freq_in_apply <- function(x) {
  x / sum(x)
}

#' Join codon counts with expression values
#'
#' Inner-joins `codon_table` and `expr_table` on gene ID and keeps only the
#' ID column, the "control"/"treated" expression columns and the sense-codon
#' count columns.
#'
#' @param codon_table table of codon counts with a gene-ID column `x`.
#' @param expr_table expression table with a gene-ID column `y`.
#' @param x,y names of the gene-ID columns in `codon_table` / `expr_table`.
#' @return the merged, column-filtered data frame (the ID column keeps the
#'   name `x`).
merge_data <- function(codon_table, expr_table, x = "Gene.stable.ID", y = "gene_id"){
  # The 61 sense codons, defined locally. The function previously read a
  # free variable `codon` that only exists if something else created it
  # globally.
  bases <- c("A", "C", "G", "T")
  all_triplets <- do.call(paste0, expand.grid(bases, bases, bases, stringsAsFactors = FALSE))
  sense_codons <- setdiff(all_triplets, c("TAA", "TAG", "TGA"))

  merged_data <- merge(codon_table, expr_table, by.x = x, by.y = y)
  keep <- colnames(merged_data) %in% c(x, "control", "treated", sense_codons)
  merged_data[, keep, drop = FALSE]
}

#' Calculate weighted codon frequencies
#'
#' Two modes, selected by `colname_RPKM`:
#' * `"none"`: each gene's codon counts are converted to within-gene
#'   frequencies (row / row total) and averaged over genes, so every gene
#'   has equal weight.
#' * any other value: the name of an expression column of `data`; each
#'   gene's codon counts are multiplied by that gene's expression value and
#'   summed over genes (codon demand).
#'
#' @param data data frame created with merge_data(): one row per gene, one
#'   count column per codon, plus expression columns.
#' @param colname_RPKM name of the expression column used as weight, or
#'   `"none"` for unweighted per-gene frequencies.
#'
#' @return named numeric vector, one element per codon column of `data`.
#' @export
pond_freq <- function(data, colname_RPKM) {
  stopifnot(is.character(colname_RPKM), length(colname_RPKM) == 1L)

  # The 61 sense codons: all 64 triplets minus the stop codons.
  bases <- c("A", "C", "G", "T")
  all_triplets <- do.call(paste0, expand.grid(bases, bases, bases, stringsAsFactors = FALSE))
  sense_codons <- setdiff(all_triplets, c("TAA", "TAG", "TGA"))

  codon_cols <- colnames(data)[colnames(data) %in% sense_codons]
  counts <- as.matrix(data[, codon_cols, drop = FALSE])
  # Double storage avoids integer overflow when counts and weights are both
  # integer.
  storage.mode(counts) <- "double"

  if (colname_RPKM == "none") {
    # Row i divided by its own total, then averaged per codon.
    return(colMeans(counts / rowSums(counts)))
  }

  if (!colname_RPKM %in% colnames(data)) {
    stop(sprintf("column '%s' not found in data", colname_RPKM), call. = FALSE)
  }
  # The weight column is taken by name; the old apply() version assumed it
  # was the last selected column and silently used a codon column otherwise.
  # Multiplying the matrix by a length-nrow vector scales row i by weight i.
  colSums(counts * data[[colname_RPKM]])
}

#' Fold change of expression-weighted codon usage, treated vs control
#'
#' Drops every row containing a missing value, computes pond_freq() with the
#' "control" column and with the "treated" column as weights, and returns
#' their per-codon ratio.
#'
#' @param data data frame with codon-count columns plus "control" and
#'   "treated" expression columns, e.g. the output of merge_data().
#' @return named numeric vector of per-codon fold changes (treated / control).
#' @examples
#' # treated_vs_control_pond_freq(merge_data(df_codon_count, RPKM_values))
treated_vs_control_pond_freq <- function(data) {
  complete_rows <- drop_na(data)
  baseline <- pond_freq(complete_rows, "control")
  stimulated <- pond_freq(complete_rows, "treated")
  stimulated / baseline
}

#' Draw random gene sets for bootstrapping
#'
#' @param n number of genes in each random set (drawn without replacement
#'   within a set).
#' @param universe vector of gene identifiers to sample from (defaults to
#'   the global `expr_diff$gene_id`; only evaluated when `n_times > 0`).
#' @param n_times number of random sets to draw; 0 gives an empty list.
#'
#' @return list of length `n_times`, each element a vector of `n` elements
#'   of `universe`.
#' @export
create_random_genes_list <- function (n, universe = expr_diff$gene_id, n_times = 10000) {
  # Index-based draw: sample(universe, n) would sample from 1:universe when
  # `universe` is a single number. For longer vectors this consumes the RNG
  # exactly like sample(universe, n), so seeded results are unchanged.
  # seq_len() also makes n_times = 0 valid (1:0 used to crash on [[0]]).
  lapply(seq_len(n_times), function(i) {
    universe[sample.int(length(universe), n, replace = FALSE)]
  })
}

#' Rows of a merged table that belong to the given genes
#'
#' @param x vector of Gene.stable.ID values to keep.
#' @param ref_table table with a Gene.stable.ID column (defaults to the
#'   global `merged_data`).
#' @return the matching rows, in `ref_table` order.
subset_ref_table <- function(x, ref_table = merged_data) {
  in_selection <- ref_table$Gene.stable.ID %in% x
  ref_table[in_selection, ]
}

#' Element-wise fold change
#'
#' @param treated,control numeric vectors (recycled by R's usual rules).
#' @return treated / control.
fold_change_in_bootstrap <- function(treated, control) {
  ratio <- treated / control
  ratio
}

#' Bootstrap the weighted codon usage of a gene set against random sets
#'
#' Computes pond_freq() for the selected genes, then divides it by
#' pond_freq() of `n_bootstrap` random gene sets of the same size drawn from
#' `ref_table`. The divisions run through parallel::mclapply().
#'
#' @param selected_genes vector of Gene.stable.ID values. Genes absent from
#'   `ref_table` are ignored, and a message reports them.
#' @param n_bootstrap number of random gene sets.
#' @param ref_table table from merge_data(): codon counts plus expression.
#' @param condition weight column (or "none") passed to pond_freq().
#' @return list of length `n_bootstrap`; each element is the named vector
#'   observed / random for one random set.
pond_freq_on_random_list <- function(selected_genes, n_bootstrap, ref_table, condition = "treated") {
  observed <- pond_freq(subset_ref_table(selected_genes, ref_table), condition)

  # With an expression weight, pond_freq() sums over genes. Each random set
  # must therefore hold as many genes as actually contributed to `observed`.
  # Sampling length(selected_genes) over-sizes the random sets whenever some
  # selected genes have no row in ref_table, which biases the ratios
  # downward. Sampling from unique IDs keeps one draw = one gene.
  universe <- unique(ref_table$Gene.stable.ID)
  wanted <- unique(selected_genes)
  n_found <- sum(wanted %in% universe)
  if (n_found == 0L) {
    stop("none of selected_genes are present in ref_table", call. = FALSE)
  }
  if (n_found < length(wanted)) {
    message(sprintf("%d of %d selected genes are absent from ref_table; random sets use %d genes",
                    length(wanted) - n_found, length(wanted), n_found))
  }

  random_sets <- create_random_genes_list(n = n_found, universe = universe, n_times = n_bootstrap)

  mclapply(random_sets, function(genes) {
    observed / pond_freq(subset_ref_table(genes, ref_table), condition)
  })
}

#' Per-codon mean of a list of fold-change vectors
#'
#' @param c non-empty list of equal-length numeric vectors, all in the same
#'   codon order as the first one (e.g. from pond_freq_on_random_list()).
#' @return numeric vector of element-wise means, named after `c[[1]]`.
mean_FC_from_list <- function(c){
  if (length(c) == 0L) {
    stop("mean_FC_from_list() needs at least one fold-change vector", call. = FALSE)
  }
  n_codon <- length(c[[1]])
  if (any(lengths(c) != n_codon)) {
    stop("all fold-change vectors must have the same length", call. = FALSE)
  }
  # Build the codon x replicate matrix in one step instead of growing a
  # data.frame column by column (slow for thousands of bootstrap replicates).
  fc_matrix <- matrix(unlist(c, use.names = FALSE), nrow = n_codon)
  means <- rowMeans(fc_matrix)
  names(means) <- names(c[[1]])
  means
}

# fold_change_weight_freq <- function(weighted_freq, random_freq_list){
#   FC_weighted_freq <- lapply(random_freq_list, FUN = function(x) weighted_freq / x)
#   return(FC_weighted_freq)
# }

#' Load per-gene codon counts for APPRIS principal transcripts
#'
#' Reads CDS sequences, counts codons per transcript, joins the APPRIS
#' annotation and keeps transcripts whose APPRIS.annotation contains
#' "principal". The transcript ID is then dropped and identical rows are
#' removed.
#'
#' @param fasta path to the CDS sequence file.
#' @return data frame of codon counts with a Gene.stable.ID column. A
#'   warning is raised when some genes still have more than one row.
load_codon_counts <- function(fasta = "~/RMI2/gitlab/tdd/data/cds_seq_mm10.txt"){
  # source() evaluates in the global environment by default, so whatever the
  # script defines stays available after this call. count_codon() and
  # list_to_df() below come from it; readDNAStringSet() is presumably from
  # Biostrings loaded there (NOTE(review): confirm).
  suppressMessages(source(file = "src/CSC_functions.R"))
  message("Loading fasta")
  fastaFile <- readDNAStringSet(fasta)
  message("counting codon in each cds")
  df_codon_count <- as.data.frame(list_to_df(list_codon = count_codon(fastaFile)))
  # Row names are joined against Transcript.stable.ID below, so they are
  # assumed to be transcript IDs.
  df_codon_count$transcript_id <- rownames(df_codon_count)

  message("collapsing transcript counts into gene counts")
  app <- as_tibble(read.csv(file = "data/ApprisAnnot_for_transcripts.txt",
                            sep = "\t"))
  df_codon_count <- merge(df_codon_count, app, by.x = "transcript_id", by.y = "Transcript.stable.ID")
  df_codon_count <- df_codon_count[grep("principal", df_codon_count$APPRIS.annotation),]
  df_codon_count <- df_codon_count %>% select(-transcript_id) %>% unique

  # A gene keeps several rows when its principal transcripts differ. The old
  # code computed these duplicates and discarded the result; report them
  # instead, because per-gene sums downstream would count such genes more
  # than once.
  dup_ids <- df_codon_count$Gene.stable.ID[duplicated(df_codon_count$Gene.stable.ID)]
  n_dup_genes <- length(unique(dup_ids))
  if (n_dup_genes > 0L) {
    warning(sprintf("%d gene(s) have more than one principal-transcript row after collapsing",
                    n_dup_genes), call. = FALSE)
  }
  return(df_codon_count)
}

#' Read a differential-expression CSV and normalise its gene IDs
#'
#' The first column is renamed to "gene_id". Every "." followed by digits is
#' removed from it, e.g. version suffixes such as ".4".
#'
#' @param x path to the CSV file.
#' @return the table as a data frame.
prepare_diff_expr_table <- function(x = "results/DiffExpr_lympho_activation3h.csv") {
  de_table <- read.csv(x)
  names(de_table)[1] <- "gene_id"
  de_table[["gene_id"]] <- gsub("\\.[[:digit:]]*", "", de_table[["gene_id"]])
  de_table
}

#' Extract control / treated RPKM columns for one cell type
#'
#' @param cell cell-type prefix of the RPKM column names.
#' @param db_path path to the expression database CSV.
#' @param late FALSE uses "<cell>_Activated_RPKM" as treated; TRUE uses
#'   "<cell>_Activated_6h_RPKM". "<cell>_Resting_RPKM" is always control.
#' @return unique rows of a data frame with columns gene_id, control and
#'   treated.
extract_RPKM_from_table <- function(cell = "Lympho", db_path = "data/databases/2020-03-26_07-39-50_Subset_Data_processed.csv", late = FALSE) {
  # A single TRUE/FALSE is required; other values used to leave the result
  # undefined.
  stopifnot(length(late) == 1L, !is.na(late))
  db <- read.csv(db_path)
  treated_label <- if (late) "Activated_6h" else "Activated"

  find_column <- function(pattern) {
    hit <- grep(pattern, colnames(db), value = TRUE)
    if (length(hit) != 1L) {
      stop(sprintf("expected exactly one column matching '%s' in %s, found %d",
                   pattern, db_path, length(hit)), call. = FALSE)
    }
    hit
  }
  # Select columns by name so control/treated are labelled correctly
  # whatever their order in the file. The old positional relabelling swapped
  # them whenever Activated came before Resting.
  control_col <- find_column(paste0(cell, "_Resting_RPKM"))
  treated_col <- find_column(paste0(cell, "_", treated_label, "_RPKM"))

  RPKM_values <- unique(db[, c("gene_id", control_col, treated_col)])
  colnames(RPKM_values) <- c("gene_id", "control", "treated")
  RPKM_values
}

#' Single-column heat map of a named per-codon vector
#'
#' Codons are placed on the y axis in increasing order of their value. The
#' fill uses a green-black-red diverging scale centred on the vector's mean.
#'
#' @param x named numeric vector (names are codons).
#' @return a ggplot object (not printed).
make_plot <- function(x){
  plot_df <- as.data.frame(x)
  codon_levels <- names(x)[order(x)]
  plot_df$codon <- factor(rownames(plot_df), levels = codon_levels)
  plot_df$variable <- "FoldChange treated vs control"
  fill_mid <- mean(plot_df$x)

  ggplot(data = plot_df, aes(x = variable, y = codon, fill = x)) +
    geom_tile() +
    scale_fill_gradient2(low = "green", high = "red", mid = "black", midpoint = fill_mid)
}

#' Two-column heat map comparing 3h and 6h per-codon values
#'
#' Codon order on the y axis follows increasing values of `x`. The fill
#' scale is centred on the mean of both vectors combined.
#'
#' @param x named numeric vector for "3h post activation".
#' @param y named numeric vector for "6h post activation".
#' @return a ggplot object (not printed).
make_plot_dual <- function(x,y){
  codon_levels <- names(x)[order(x)]

  as_long <- function(values, label) {
    long <- as.data.frame(values)
    colnames(long) <- "value"
    long$codon <- rownames(long)
    long$variable <- label
    long
  }

  plot_df <- rbind(as_long(x, "3h post activation"),
                   as_long(y, "6h post activation"))
  plot_df$codon <- factor(plot_df$codon, levels = codon_levels)
  fill_mid <- mean(plot_df$value)

  ggplot(data = plot_df, aes(x = variable, y = codon, fill = value)) +
    geom_tile() +
    scale_fill_gradient2(low = "green", high = "red", mid = "black", midpoint = fill_mid)
}

# Main -------------------------------------------------------------
#' Significant differentially expressed genes in one direction
#'
#' @param expr_diff table from prepare_diff_expr_table().
#' @param direction "up" keeps log2FoldChange > 1, "down" keeps
#'   log2FoldChange < -1; both also require padj < 0.05.
#' @return vector of gene_id values.
.select_de_genes <- function(expr_diff, direction = c("up", "down")) {
  direction <- match.arg(direction)
  significant <- if (direction == "up") {
    expr_diff %>% filter(padj < 0.05 & log2FoldChange > 1)
  } else {
    expr_diff %>% filter(padj < 0.05 & log2FoldChange < -1)
  }
  significant %>%
    select(gene_id) %>%
    unlist
}

#' The two plots main() draws for one gene set
#'
#' @return list with `bootstrap` (mean fold change of the "treated"-weighted
#'   codon usage vs random gene sets) and `vs_transcriptome` (codon
#'   frequencies of the gene set relative to the whole transcriptome).
.de_gene_plots <- function(genes, merged_data, df_codon_count, transcriptome, n_bootstrap) {
  boot_fc <- pond_freq_on_random_list(selected_genes = genes, n_bootstrap = n_bootstrap,
                                      condition = "treated", ref_table = merged_data)
  gene_freq <- calc_freq_codon_in_transcriptome(data = df_codon_count, universe = genes)
  gene_freq <- gene_freq[names(gene_freq) %in% names(transcriptome)]
  list(
    bootstrap = make_plot(mean_FC_from_list(boot_fc)),
    vs_transcriptome = make_plot(gene_freq / transcriptome)
  )
}

#' Run the full codon-usage analysis for activated lymphocytes
#'
#' Prints, in this order:
#' 1. the treated-vs-control fold change of weighted codon usage over all
#'    genes, 3h and 6h side by side;
#' 2. for genes up-regulated at 3h, up at 6h, down at 3h, then down at 6h,
#'    the bootstrap plot and the vs-transcriptome plot (see .de_gene_plots()).
#'
#' @param n_bootstrap number of random gene sets per bootstrap (defaults to
#'   10000, the value previously hard-coded).
#' @return invisibly, a named list of the ggplot objects that were printed.
main <- function(n_bootstrap = 10000){
  # ggplot objects are not auto-printed inside a function, so each plot is
  # printed explicitly; otherwise nothing is drawn.
  show <- function(p) {
    print(p)
    p
  }

  df_codon_count <- load_codon_counts()
  # "3h" pairs Resting with Activated RPKM; "6h" pairs it with Activated_6h.
  merged_by_time <- list(
    "3h" = merge_data(codon_table = df_codon_count,
                      expr_table = extract_RPKM_from_table(cell = "Lympho", late = FALSE)),
    "6h" = merge_data(codon_table = df_codon_count,
                      expr_table = extract_RPKM_from_table(cell = "Lympho", late = TRUE))
  )
  de_files <- c("3h" = "results/DiffExpr_lympho_activation3h.csv",
                "6h" = "results/DiffExpr_lympho_activation6h.csv")
  expr_diff_by_time <- lapply(de_files, function(f) prepare_diff_expr_table(x = f))

  plots <- list()

  # Whole population
  plots$population <- show(make_plot_dual(
    treated_vs_control_pond_freq(merged_by_time[["3h"]]),
    treated_vs_control_pond_freq(merged_by_time[["6h"]])
  ))

  transcriptome <- calc_freq_codon_in_transcriptome(data = df_codon_count, universe = "transcriptome")

  # Same order as before (up 3h, up 6h, down 3h, down 6h), so seeded
  # bootstrap draws happen in the same sequence.
  for (direction in c("up", "down")) {
    for (timepoint in names(de_files)) {
      genes <- .select_de_genes(expr_diff_by_time[[timepoint]], direction)
      gene_set_plots <- .de_gene_plots(genes,
                                       merged_data = merged_by_time[[timepoint]],
                                       df_codon_count = df_codon_count,
                                       transcriptome = transcriptome,
                                       n_bootstrap = n_bootstrap)
      plots[[paste(direction, timepoint, sep = "_")]] <- lapply(gene_set_plots, show)
    }
  }

  invisible(plots)
}
