options(stringsAsFactors = FALSE)
library('caret')
library('gbm')
library('AUC')

####################
# 2018-09-12
# New caret version has different accessors
# need to correct column access

# my.tissue <- "Liver"
# my.gbm.sampled.fits <- my.liver.gbm.samp.fit
# my.testing <- my.liver.features.v2[-my.liver.training.idx,]

### Sampled models
# Evaluate an ensemble of GBM models fit on resampled data against a held-out
# test set, then write the summary plots and tables to the working directory.
#
# Arguments:
#   my.tissue            label pasted into every output file name and plot title
#   my.gbm.sampled.fits  list of fitted models; passed to predict_sampled() and
#                        get_varImps()
#   my.testing           test data; its `age_change` column is the reference
#
# Files written (each prefixed with Sys.Date() and my.tissue):
#   *_Barplots_of_accuracy_measures_perClass_withSampling.pdf
#   *_features_importance_GBM_over_samples.pdf
#   *_median_features_importance_classification_over_samples.txt
#   *_GBM_metrics_classification_withSampling.txt
#
# The hard-coded axis labels and the "ModelAverage" row (row 4) assume three
# outcome classes. Called for its side effects only.
get_gbm_metrics_and_summarize_withSAMP <- function(my.tissue, my.gbm.sampled.fits, my.testing) {

  my.gbm.preds <- predict_sampled(my.gbm.sampled.fits, my.testing)

  # confusionMatrix() takes the predictions first, then the reference labels
  my.confus.mat <- confusionMatrix(my.gbm.preds, my.testing$age_change)

  # per-class metrics, looked up by column name
  my.bal.acc <- my.confus.mat$byClass[, "Balanced Accuracy"]
  my.sens    <- my.confus.mat$byClass[, "Sensitivity"]
  my.spe     <- my.confus.mat$byClass[, "Specificity"]
  my.prec    <- my.confus.mat$byClass[, "Pos Pred Value"]

  # one barplot panel per entry; my.names[i] is the title of my.data.list[[i]]
  my.data.list <- list(my.bal.acc, my.sens, my.spe)
  my.names <- c("Balanced Accuracy", "Sensitivity", "Specificity")

  my.bar.pdfname <- paste(Sys.Date(), my.tissue, "Barplots_of_accuracy_measures_perClass_withSampling.pdf", sep = "_")

  pdf(my.bar.pdfname)
  par(mfrow = c(1, 3))
  par(cex = 0.6)
  par(mar = c(0, 0, 0, 0), oma = c(10, 4, 4, 1))
  par(tcl = -0.25)
  par(mgp = c(2, 0.6, 0))

  for (i in seq_along(my.data.list)) {

    # plot the i-th metric (previously every panel re-plotted balanced accuracy,
    # so the "Sensitivity" and "Specificity" panels showed the wrong values)
    bla <- barplot(my.data.list[[i]],
                   ylab = my.names[i],
                   xlim = c(0.45, 1.05), width = 2,
                   col = c("grey", "firebrick", "dodgerblue3"),
                   horiz = TRUE,
                   axes = FALSE, axisnames = FALSE)
    abline(v = 0.5, col = "red", lty = "dashed")
    box()

    mtext(my.names[i], side = 3, line = -1, adj = 0.1, cex = 0.6, col = "black")

    axis(1, col = "black", col.axis = "black", at = seq(0.5, 1, 0.1))

    # class labels only on the left-most panel; `bla` holds the bar midpoints
    if (i == 1) {
      axis(2, labels = c("Unchanged with aging", "Down with aging", "Up with aging"), las = 2, at = bla)
    }
    box(col = "black")
  }
  mtext("Class Label", side = 2, outer = TRUE, line = 2.2, col = "black")
  mtext("Value", side = 1, outer = TRUE, line = 2.2, col = "black")
  par(mfrow = c(1, 1))
  dev.off()


  # per-feature importance values, one value per sampled model
  my.varimps <- get_varImps(my.gbm.sampled.fits)

  my.feats.pdfname <- paste(Sys.Date(), my.tissue, "features_importance_GBM_over_samples.pdf", sep = "_")

  #### plot all features, sorted by median importance

  pdf(my.feats.pdfname, width = 10, height = 10)
  par(oma = c(0.5, 5, 0.5, 0.5))

  # ascending order, so the most important feature ends up at the top of the
  # horizontal boxplot
  my.1.sort <- sort(as.numeric(lapply(my.varimps, median)), index.return = TRUE, decreasing = FALSE)
  boxplot(my.varimps[my.1.sort$ix], las = 2,
          horizontal = TRUE, outline = FALSE, xlab = "Relative variable importance (%)",
          ylim = c(0, 1.3 * max(as.numeric(lapply(my.varimps, median)))),
          main = paste(my.tissue, "All features"), col = "gold")
  abline(v = 0, col = "grey", lty = "dashed")

  dev.off()

  # write importances to output file for future handling
  write.table(sort(unlist(lapply(my.varimps, median)), decreasing = TRUE),
              file = paste(Sys.Date(), my.tissue, "median_features_importance_classification_over_samples.txt", sep = "_"))

  my.gbm.stats <- data.frame(balacc = my.bal.acc,
                             specificity = my.spe,
                             sensitivity = my.sens,
                             precision = my.prec)

  # append the across-class mean of every metric as a fourth row
  my.gbm.stats[4, ] <- apply(my.gbm.stats, 2, mean)
  rownames(my.gbm.stats)[4] <- "ModelAverage"

  write.table(my.gbm.stats,
              file = paste(Sys.Date(), my.tissue, "GBM_metrics_classification_withSampling.txt", sep = "_"))

}


########
# my.tissue <- "Liver"
# my.gbm.fit <- my.liver.gbm.noCST.fit
# my.testing <- my.no.cst.testing

# Evaluate a single GBM model on a test set and write the ROC curve, the
# feature-importance barplot and the metric tables to the working directory.
#
# Arguments:
#   my.tissue   label pasted into every output file name and plot title
#   my.gbm.fit  fitted model; passed to predict(), varImp() and get_roc()
#   my.testing  test data; its `age_change` column is the reference
#
# `byClass` is indexed as a named vector here (not a matrix), which is the
# shape used for a two-class problem -- NOTE(review): confirm callers always
# pass data with the constant class removed.
#
# Files written (each prefixed with Sys.Date() and my.tissue):
#   *_ROC_GBM_noCST_classification.pdf
#   *_features_importance_GBM_noCST.pdf
#   *_GBM_features_importance_classification_noCST.txt
#   *_GBM_metrics_classification_noCST.txt
get_gbm_metrics_and_summarize_noCST <- function(my.tissue, my.gbm.fit, my.testing) {

  my.gbm.preds <- predict(my.gbm.fit, my.testing)

  # confusionMatrix() takes the predictions first, then the reference labels
  my.confus.mat <- confusionMatrix(my.gbm.preds, my.testing$age_change)

  my.bal.acc <- my.confus.mat$byClass["Balanced Accuracy"]
  my.sens    <- my.confus.mat$byClass["Sensitivity"]
  my.spe     <- my.confus.mat$byClass["Specificity"]
  my.prec    <- my.confus.mat$byClass["Precision"]

  # unscaled importances; the names attribute keeps the feature labels attached
  # when the matrix is later indexed as a plain vector
  my.imp <- as.matrix(varImp(my.gbm.fit, scale = FALSE)$importance)
  names(my.imp) <- rownames(my.imp)

  # get ROC for GBM prediction (2-class)
  my.roc <- get_roc(my.gbm.fit, my.testing)
  my.auc <- auc(my.roc)

  my.filename <- paste(Sys.Date(), my.tissue, "ROC_GBM_noCST_classification.pdf", sep = "_")

  pdf(my.filename)
  plot(my.roc$fpr, my.roc$tpr, type = 'l', col = "firebrick",
       xlab = "1-specificity (False positive rate)", ylab = "sensitivity (True positive rate)")
  abline(0, 1, col = "grey", lty = 'dashed')
  text(0.8, 0.2, paste("AUC = ", round(my.auc, 3), sep = ""))
  dev.off()


  my.feats.pdfname <- paste(Sys.Date(), my.tissue, "features_importance_GBM_noCST.pdf", sep = "_")

  #### plot features sorted by importance (top 30 only when there are many)

  pdf(my.feats.pdfname, width = 6, height = 8)
  par(oma = c(0.5, 7, 0.5, 0.5))

  # ascending order, so the most important feature is drawn at the top
  my.1.sort <- sort(as.numeric(my.imp), index.return = TRUE, decreasing = FALSE)
  my.sorted.imp <- my.imp[my.1.sort$ix]
  my.n.imp <- length(my.sorted.imp)

  if (my.n.imp < 30) {
    barplot(my.sorted.imp, horiz = TRUE, las = 2,
            xlab = "Relative variable importance (%)", xlim = c(0, 1.3 * max(my.imp)),
            main = paste(my.tissue, "All features"), col = "gold")
    abline(v = 0.01, col = "red", lty = "dashed")
    abline(v = 0, col = "grey", lty = "dashed")
    box()
  } else {
    # keep exactly the 30 most important features; the previous index range
    # (n - 30):n selected 31 of them
    barplot(my.sorted.imp[(my.n.imp - 29):my.n.imp], horiz = TRUE, las = 2,
            xlab = "Relative variable importance (%)", xlim = c(0, 1.3 * max(my.imp)),
            main = paste(my.tissue, "All features"), col = "gold")
    abline(v = 0, col = "grey", lty = "dashed")
    box()
  }

  dev.off()

  # write importances to output file for future handling
  write.table(sort(my.imp, decreasing = TRUE),
              file = paste(Sys.Date(), my.tissue, "GBM_features_importance_classification_noCST.txt", sep = "_"))

  write.table(data.frame(balacc = my.bal.acc,
                         AUC = my.auc,
                         specificity = my.spe,
                         sensitivity = my.sens,
                         precision = my.prec),
              row.names = FALSE,
              file = paste(Sys.Date(), my.tissue, "GBM_metrics_classification_noCST.txt", sep = "_"))

}



#######################################################################################################################################

# my.tissue <- "Liver"
# my.liver.data <- my.liver.features.v2
# my.heart.data <- my.heart.features.v2
# my.cereb.data <- my.cereb.features.v2
# my.ob.data <- my.ob.features.v3
# my.gbm.sampled.fits <- my.liver.gbm.samp.fit

# Apply an ensemble of sampled GBM models to the data of four tissues and
# report the class-averaged balanced accuracy, specificity and sensitivity
# per tissue as a barplot (PDF) and a table (text file).
#
# Arguments:
#   my.tissue            tissue the models were trained on; used in the file
#                        names and to colour the matching bars grey
#   my.gbm.sampled.fits  list of fitted models, passed to predict_sampled()
#   my.liver.data, my.heart.data, my.cereb.data, my.ob.data
#                        per-tissue data, each with an `age_change` column
#
# Note: the "accuracy" column of the result holds the balanced accuracy.
# Called for its side effects only.
CT_acc_withSAMP <- function(my.tissue, my.gbm.sampled.fits, my.liver.data, my.heart.data, my.cereb.data, my.ob.data) {

  # one entry per tissue, in the row order of the results table
  my.tissue.data <- list(Heart          = my.heart.data,
                         Liver          = my.liver.data,
                         Cerebellum     = my.cereb.data,
                         Olfactory_Bulb = my.ob.data)

  # result column -> confusionMatrix()$byClass column. Names are used instead
  # of positions because the column layout differs between caret versions.
  my.metric.cols <- c(accuracy    = "Balanced Accuracy",
                      specificity = "Specificity",
                      sensitivity = "Sensitivity")

  my.results <- data.frame(accuracy    = rep(0, 4),
                           specificity = rep(0, 4),
                           sensitivity = rep(0, 4))
  rownames(my.results) <- names(my.tissue.data)

  for (my.name in names(my.tissue.data)) {
    my.data <- my.tissue.data[[my.name]]
    my.preds <- predict_sampled(my.gbm.sampled.fits, my.data)

    # prediction first, then reference
    my.by.class <- confusionMatrix(my.preds, my.data$age_change)$byClass

    # average each per-class metric over the classes
    my.results[my.name, ] <- vapply(my.metric.cols,
                                    function(my.col) mean(my.by.class[, my.col]),
                                    numeric(1), USE.NAMES = FALSE)
  }

  # highlight the tissue the models were trained on
  my.cols <- rep("tomato", 4)
  my.cols[rownames(my.results) %in% my.tissue] <- "darkgrey"


  my.pdfname <- paste(Sys.Date(), my.tissue, "chromatin_model_withSampling_cross_tissue_accuracy.pdf", sep = "_")

  pdf(my.pdfname)
  par(mar = c(0, 0, 0, 0), oma = c(10, 4, 4, 1))
  par(tcl = -0.25)
  par(mgp = c(2, 0.6, 0))
  barplot(as.matrix(my.results), las = 2, col = my.cols, beside = TRUE,
          ylim = c(0.45, 1.05), axes = FALSE,
          ylab = "Measure")
  axis(2, col = "black", col.axis = "black", at = seq(0.5, 1, 0.1))
  abline(h = 0.5, col = "red", lty = "dashed")
  box()
  legend("topright", c("Cross Tissue", "On Training tissue"), fill = c("tomato", "darkgrey"), bty = "n")
  dev.off()

  write.table(my.results,
              file = paste(Sys.Date(), my.tissue, "chromatin_model_cross_tissue_AVERAGE_GBM_metrics_classification_withSampling.txt", sep = "_"))

}


#####
# Apply a single GBM model to the data of four tissues and report balanced
# accuracy, specificity and sensitivity per tissue as a barplot (PDF) and a
# table (text file).
#
# Arguments:
#   my.tissue    tissue the model was trained on; used in the file names and
#                to colour the matching bars grey
#   my.gbm.fits  fitted model, passed to predict()
#   my.liver.data, my.heart.data, my.cereb.data, my.ob.data
#                per-tissue data, each with an `age_change` column
#
# `byClass` is indexed as a named vector here (not a matrix), which is the
# shape used for a two-class problem -- NOTE(review): confirm callers always
# pass data with the constant class removed.
# Note: the "accuracy" column of the result holds the balanced accuracy.
# Called for its side effects only.
CT_acc_noCST <- function(my.tissue, my.gbm.fits, my.liver.data, my.heart.data, my.cereb.data, my.ob.data) {

  # one entry per tissue, in the row order of the results table
  my.tissue.data <- list(Heart          = my.heart.data,
                         Liver          = my.liver.data,
                         Cerebellum     = my.cereb.data,
                         Olfactory_Bulb = my.ob.data)

  # byClass entries for the accuracy / specificity / sensitivity columns.
  # Names are used instead of positions because the layout of byClass
  # differs between caret versions.
  my.metric.cols <- c("Balanced Accuracy", "Specificity", "Sensitivity")

  my.results <- data.frame(accuracy    = rep(0, 4),
                           specificity = rep(0, 4),
                           sensitivity = rep(0, 4))
  rownames(my.results) <- names(my.tissue.data)

  for (my.name in names(my.tissue.data)) {
    my.data <- my.tissue.data[[my.name]]
    my.preds <- predict(my.gbm.fits, my.data)

    # prediction first, then reference
    my.by.class <- confusionMatrix(my.preds, my.data$age_change)$byClass

    my.results[my.name, ] <- unname(my.by.class[my.metric.cols])
  }

  # highlight the tissue the model was trained on
  my.cols <- rep("tomato", 4)
  my.cols[rownames(my.results) %in% my.tissue] <- "darkgrey"


  my.pdfname <- paste(Sys.Date(), my.tissue, "chromatin_model_noCST_cross_tissue_accuracy.pdf", sep = "_")

  pdf(my.pdfname)
  par(mar = c(0, 0, 0, 0), oma = c(10, 4, 4, 1))
  par(tcl = -0.25)
  par(mgp = c(2, 0.6, 0))
  barplot(as.matrix(my.results), las = 2, col = my.cols, beside = TRUE,
          ylim = c(0.45, 1.05), axes = FALSE,
          ylab = "Measure")
  axis(2, col = "black", col.axis = "black", at = seq(0.5, 1, 0.1))
  abline(h = 0.5, col = "red", lty = "dashed")
  box()
  legend("topright", c("Cross Tissue", "On Training tissue"), fill = c("tomato", "darkgrey"), bty = "n")
  dev.off()


  write.table(my.results,
              file = paste(Sys.Date(), my.tissue, "chromatin_model_cross_tissue_GBM_metrics_classification_noCST.txt", sep = "_"))

}
#######################################################################################################################################
### get classification accuracy measures

# get balanced accuracy
# use caret predict to avoid issues, 2016-09-20
# Mean balanced accuracy over all classes for one fitted model.
#
# Arguments:
#   my.gbm.fit  fitted caret model, passed to predict.train()
#   my.testing  test data; its `age_change` column is the reference
# Returns: a single number, the mean of the per-class "Balanced Accuracy".
get_bal_acc <- function(my.gbm.fit, my.testing) {
  my.gbm.preds <- predict.train(my.gbm.fit, my.testing)

  # prediction first, then reference
  my.confus.mat <- confusionMatrix(my.gbm.preds, my.testing$age_change)

  # Look the metric up by name: the former positional index (8) is only the
  # balanced accuracy under the old caret column layout; under the layout the
  # rest of this file relies on (balanced accuracy in column 11) it is
  # "Prevalence".
  my.balanced.acc <- mean(my.confus.mat$byClass[, "Balanced Accuracy"])

  return(my.balanced.acc)
}

# Per-class balanced accuracy for one fitted model.
#
# Arguments:
#   my.gbm.fit  fitted caret model, passed to predict.train()
#   my.testing  test data; its `age_change` column is the reference
# Returns: the "Balanced Accuracy" column of confusionMatrix()$byClass
#          (one value per class).
get_perclass_acc <- function(my.gbm.fit, my.testing) {
  my.gbm.preds <- predict.train(my.gbm.fit, my.testing)

  # prediction first, then reference
  my.confus.mat <- confusionMatrix(my.gbm.preds, my.testing$age_change)

  # Indexed by name: the former positional index (8) is only the balanced
  # accuracy under the old caret column layout; under the layout the rest of
  # this file relies on (balanced accuracy in column 11) it is "Prevalence".
  return(my.confus.mat$byClass[, "Balanced Accuracy"])
}

# Per-class specificity for one fitted model.
#
# Arguments:
#   my.gbm.fit  fitted caret model, passed to predict.train()
#   my.testing  test data; its `age_change` column is the reference
# Returns: the "Specificity" column of confusionMatrix()$byClass
#          (one value per class).
get_perclass_spe <- function(my.gbm.fit, my.testing) {
  my.gbm.preds <- predict.train(my.gbm.fit, my.testing)

  # prediction first, then reference
  my.confus.mat <- confusionMatrix(my.gbm.preds, my.testing$age_change)

  # indexed by name rather than position, so the lookup does not depend on
  # the column order of byClass
  return(my.confus.mat$byClass[, "Specificity"])
}

# Per-class sensitivity for one fitted model.
#
# Arguments:
#   my.gbm.fit  fitted caret model, passed to predict.train()
#   my.testing  test data; its `age_change` column is the reference
# Returns: the "Sensitivity" column of confusionMatrix()$byClass
#          (one value per class).
get_perclass_sens <- function(my.gbm.fit, my.testing) {
  my.gbm.preds <- predict.train(my.gbm.fit, my.testing)

  # prediction first, then reference
  my.confus.mat <- confusionMatrix(my.gbm.preds, my.testing$age_change)

  # indexed by name rather than position, so the lookup does not depend on
  # the column order of byClass
  return(my.confus.mat$byClass[, "Sensitivity"])
}

### for non sampled:
# Drop the rows labelled 'CONSTANT' and rebuild `age_change` as a factor, so
# that the removed class no longer appears among its levels.
#
# Arguments:
#   my.testing  data frame with an `age_change` column
# Returns: a data frame holding the remaining rows.
remove_cst <- function(my.testing) {
  my.keep.rows <- my.testing$age_change != 'CONSTANT'
  my.non.cst <- data.frame(my.testing[my.keep.rows, ])

  # factor() keeps only the levels that are still present
  my.non.cst$age_change <- factor(my.non.cst$age_change)
  my.non.cst
}

# get ROC for GBM prediction
# Build the ROC object of a two-class model, with "DOWN" as the positive class.
#
# Arguments:
#   my.gbm.fit  fitted model, passed to predict(..., type = "prob")
#   my.testing  test data; rows whose `age_change` is "DOWN" are labelled 1,
#               all other rows 0
# Returns: the object returned by roc() for the "DOWN" probabilities.
get_roc <- function(my.gbm.fit, my.testing) {
  my.testing.preds.probs <- predict(my.gbm.fit, my.testing, type = "prob")

  my.test.labels <- rep(0, nrow(my.testing.preds.probs))
  my.test.labels[my.testing$age_change %in% "DOWN"] <- 1

  # The score has to be the probability of the positive class ("DOWN").
  # Prefer the column of that name; fall back to the first column (the
  # previous behaviour) when the probabilities carry no such column name.
  if ("DOWN" %in% colnames(my.testing.preds.probs)) {
    my.down.col <- "DOWN"
  } else {
    my.down.col <- 1
  }

  my.roc <- roc(as.numeric(my.testing.preds.probs[, my.down.col]), as.factor(my.test.labels))

  return(my.roc)

}

####### DEBUG ####### 
# my.gbm.sampled.fits <- my.Liver.gbm.samp.fit.RNA
# my.testing <- my.liver.features.v2[-my.liver.training.idx,]
####### DEBUG ####### 
# Combine the predictions of several models by vote.
#
# Every model predicts a class for every row of my.testing. A row is called
# "UP" (or "DOWN") only when strictly more than 75% of the models predict that
# class; every other row is called "CONSTANT".
#
# Arguments:
#   my.gbm.sampled.fits  non-empty list of fitted caret models
#   my.testing           data to predict on
# Returns: a factor with one prediction per row of my.testing. Its levels are
#          the classes that were actually assigned.
predict_sampled <- function(my.gbm.sampled.fits, my.testing) {

  my.samples <- length(my.gbm.sampled.fits)
  if (my.samples == 0) {
    # 1:0 would iterate over c(1, 0) and fail with a subscript error
    stop("predict_sampled() needs at least one fitted model", call. = FALSE)
  }

  my.n.obs <- nrow(my.testing)

  # one column of class predictions per model
  my.predictions <- matrix('CONSTANT', my.n.obs, my.samples)
  for (i in seq_len(my.samples)) {
    my.predictions[, i] <- as.character(predict.train(my.gbm.sampled.fits[[i]], my.testing))
  }

  # number of models voting for each direction, per row
  # (na.rm = TRUE: a missing prediction counts as no vote)
  my.up  <- rowSums(my.predictions == "UP", na.rm = TRUE)
  my.dwn <- rowSums(my.predictions == "DOWN", na.rm = TRUE)

  my.final.predictions <- rep('CONSTANT', my.n.obs)
  my.final.predictions[my.up > 0.75 * my.samples] <- "UP"
  my.final.predictions[my.dwn > 0.75 * my.samples] <- "DOWN"

  factor(my.final.predictions)
}

# my.confus.mat <- confusionMatrix(my.final.predictions,my.testing$age_change) # prediction, then ref
# my.balanced.acc <- mean(my.confus.mat$byClass[,8])

#####
# Collect the unscaled variable importances of every model in a list of fits.
#
# Arguments:
#   my.gbm.sampled.fits  non-empty list of fitted models accepted by varImp()
# Returns: a named list with one element per feature of the first model; each
#          element is a numeric vector holding that feature's importance in
#          model 1, 2, ... (same order as my.gbm.sampled.fits).
get_varImps <- function(my.gbm.sampled.fits) {

  my.samples <- length(my.gbm.sampled.fits)
  if (my.samples == 0) {
    stop("get_varImps() needs at least one fitted model", call. = FALSE)
  }

  # one importance table per model, computed once
  my.imp.tabs <- lapply(my.gbm.sampled.fits, function(my.fit) {
    as.matrix(varImp(my.fit, scale = FALSE)$importance)
  })

  # the first model defines the feature set and its order
  my.features <- rownames(my.imp.tabs[[1]])
  my.n.features <- length(my.features)

  # rows = features, columns = models
  my.imp.mat <- matrix(NA_real_, nrow = my.n.features, ncol = my.samples)

  for (i in seq_len(my.samples)) {
    my.i.imp <- my.imp.tabs[[i]]

    # Match by feature name, not by row position, so that a model whose table
    # is ordered differently cannot be misaligned silently.
    my.missing <- setdiff(my.features, rownames(my.i.imp))
    if (length(my.missing) > 0) {
      stop("model ", i, " has no importance for feature(s): ",
           paste(my.missing, collapse = ", "), call. = FALSE)
    }
    my.imp.mat[, i] <- my.i.imp[my.features, 1]
  }

  my.importance <- lapply(seq_len(my.n.features), function(j) my.imp.mat[j, ])
  names(my.importance) <- my.features

  return(my.importance)
}

####
#varImp(my.liver.gbm.noCST.fit)$importance
