#!/usr/bin/Rscript

# AUTHOR: Charlie Barker 
#
# Uses the DESeq2 library to run differential expression analysis on expression data, grouped by cell shape
# cluster. Please note: the assignment `cluster.of.interest <- "TWO"` further down selects which cell shape group
# differential expression is calculated for. The variable "col.data" shows how the groups relate to the named cell lines.
#
# INPUT: 
#          MTAB_COUNTS.csv          expression data. If it is not present in the directory specified by the "path"
#                                   variable, the script downloads the counts from Expression Atlas instead.
#          pk_subtypes.csv          csv describing the prior knowledge for cell line subtypes 
#          shape_clusters.csv       csv describing our groups identified from cell morphology along 
#
# OUTPUT:   
#	 
#   hetero_DEGs.txt             file describing differentially expressed genes (the write.csv call that
#                               produces it is currently commented out in the script).


library(gplots)
library(org.Hs.eg.db)
require(plyr)
library(plotly)
library(DESeq2)

suppressMessages( library( ExpressionAtlas ) )

# Working directory for every relative path used below.
# NOTE(review): machine-specific absolute path; the "../" paths later in the
# script are resolved against it, so it is deliberately left unchanged.
path <- "/home/charlie/phenotype_networks/data/expression"
setwd(path)

## GET DATA

# FPKM values cannot be used with edgeR, so the raw integer counts are loaded
# here instead of relying on Eirini's preprocessed and filtered table.
count.file <- "MTAB_COUNTS.csv"
if (file.exists(count.file)) {
  # Cached table. Downstream code expects the gene IDs in a column called "X"
  # and one column per cell line.
  all.mtab <- read.csv(count.file)
} else {
  datasets <- c("E-MTAB-2706", "E-MTAB-2770")
  allExps <- getAtlasData(datasets)
  # Look the experiments up by accession rather than by position, so a
  # failed or reordered download cannot silently swap them.
  E.MTAB.2706 <- allExps[["E-MTAB-2706"]]$rnaseq
  E.MTAB.2770 <- allExps[["E-MTAB-2770"]]$rnaseq
  counts.2706 <- data.frame(assays(E.MTAB.2706)$counts)
  counts.2770 <- data.frame(assays(E.MTAB.2770)$counts)

  # Label every sample column with its cell line. This has to happen while
  # the data frames contain only sample columns: assigning the cell-line
  # names after appending an extra column leaves that column named NA.
  colnames(counts.2706) <- as.character(E.MTAB.2706$cell_line)
  colnames(counts.2770) <- as.character(E.MTAB.2770$cell_line)

  # Column-binding pairs rows by position, so both experiments must list the
  # same genes in the same order.
  stopifnot(identical(rownames(counts.2706), rownames(counts.2770)))

  # merge
  # data.frame() makes the column names syntactic and unique (a repeated
  # cell line gets a ".1" suffix). The explicit "X" gene-ID column gives this
  # branch the same layout as the cached CSV read above.
  all.mtab <- data.frame(X = rownames(counts.2706),
                         cbind.data.frame(counts.2706, counts.2770))
  #write.csv(all.mtab, file = "./MTAB_COUNTS.csv", row.names = FALSE)
}

## SELECTED BREAST CANCER CELL LINES:

# Maps the short cell-line label used from here on (names) to the column of
# all.mtab that holds its counts (values). Several cell lines occur twice in
# all.mtab; a trailing ".1" picks the second occurrence
# (presumably one per source experiment -- TODO confirm which is which).
cell.line.columns <- c(
  BT474    = "BT.474",
  CAMA1    = "CAMA.1.1",
  T47D     = "T.47D.1",
  ZR75.1   = "ZR.75.1.1",
  # MCDMB453 missing
  SKBR3    = "SK.BR.3.1",
  MCF7     = "MCF.7",
  HCC1143  = "HCC1143.1",
  HCC1954  = "HCC1954.1",
  HCC70    = "HCC70.1",
  hs578T   = "Hs.578T.1",
  JIMT1    = "JIMT.1.1",
  MCF10A   = "MCF.10A",
  # MCF12A missing
  MDAMB157 = "MDA.MB.157.1",
  MDAMB231 = "MDA.MB.231"
  # SUM149 missing
  # SUM159 missing
)

# Fail loudly if a column is absent. `all.mtab$name` would return NULL (or a
# partial match) and data.frame() silently drops NULL arguments, so a missing
# cell line would otherwise just vanish from the analysis.
missing.columns <- setdiff(c("X", cell.line.columns), colnames(all.mtab))
if (length(missing.columns) > 0) {
  stop("all.mtab is missing expected column(s): ",
       paste(missing.columns, collapse = ", "),
       call. = FALSE)
}

# One column per selected cell line, renamed to the short labels.
count.data <- all.mtab[, unname(cell.line.columns), drop = FALSE]
colnames(count.data) <- names(cell.line.columns)
# Gene IDs are kept both as row names and as a GeneID column; the column is
# the key for the gene-type merge that follows.
row.names(count.data) <- all.mtab[["X"]]
count.data$GeneID <- all.mtab[["X"]]
# Data contains many genes that are not protein encoding. I extract the gene ids and input them into
# ensembl biomart: Dataset: Ensembl Genes 89, Human genes (GRCh38.p10)
#                  Filters: None Selected
#                  Attributes: Gene stable ID, Gene type
# output saved in mart_export.csv

#### GeneXData <- read_csv("~/GeneXData.csv",col_types = cols(X1 = col_skip()))
library(readr)
# Gene ID -> gene type lookup table (two columns, in that order).
mart_export <- read_csv("../geneID-type.csv")
colnames(mart_export) <- c("GeneID", "GeneType")
# No row names are set on mart_export: it is a tibble (setting row names on a
# tibble is deprecated) and merge() joins on the GeneID column, not row names.

# Inner join: attach the gene type to each count row. Genes that are absent
# from the lookup table are dropped here.
GeneXDType <- merge(count.data, mart_export, by = "GeneID")

# Keep protein-coding genes only.
count.data <- subset(GeneXDType, GeneType == "protein_coding")



##PREP RNASEQ DATA

#my analysis 
#seq.data <- read.csv("../data/GeneXDataNorm.csv", stringsAsFactors = FALSE)
#row.names(seq.data)<-seq.data$GeneID
#seq.data$X<-NULL
#get counts per million 
# Index the table by gene ID, keep the annotation columns aside in ID.DF, and
# strip them so that count.data contains count columns only.
row.names(count.data) <- count.data$GeneID
ID.DF <- data.frame(count.data$GeneID, count.data$GeneType)
count.data$GeneID <- NULL
count.data$GeneType <- NULL

library(edgeR)

# Low-expression filter: a gene is kept when its counts-per-million exceed
# min.cpm in at least min.samples of the cell lines.
min.cpm <- 0.5
min.samples <- 8

count.data.cpm <- cpm(count.data)
thresh <- count.data.cpm > min.cpm
keep <- rowSums(thresh) >= min.samples

# Table dimensions before and after filtering.
dim(count.data)
dim(count.data[keep, ])
message(sum(!keep), " genes removed by the CPM filter (CPM > ", min.cpm,
        " required in at least ", min.samples, " samples)")
count.data <- count.data[keep, ]

# count.data now holds the filtered raw counts that are passed to DESeq2.

##DESIGN MATRIX 

# Build col.data, the per-cell-line sample annotation used for the DESeq2
# design. Sample information comes from PMID: 29158785 and Eirini's earlier work.
# Prior-knowledge subtype table; the columns Cell.Line, Authors.shape.cluster
# and my.shape.cluster are read from it below.
various.cluster <- read.csv("../phenotype_features/pk_subtypes.csv",header = TRUE,stringsAsFactors = FALSE)
# Shape-cluster table; the columns Cell.Line, X and k2.cluster are read from it below.
shape.cluster <- read.csv("../phenotype_features/shape_clusters.csv",header = TRUE,stringsAsFactors = FALSE)
# Rename "ZR751" to "ZR75.1" (presumably to agree with the label used in count.data).
# NOTE(review): the row is selected through the Cell.Line column, but the new
# value is written into column 1 and the match() below looks names up in
# column X. This only lines up if X is column 1 of the CSV -- confirm the schema.
shape.cluster[shape.cluster$Cell.Line == "ZR751",1] <- "ZR75.1"
row.names(shape.cluster)<-shape.cluster$Cell.Line
# Reorder the shape-cluster rows to follow the cell-line order of
# various.cluster. A cell line without a match yields an all-NA row.
shape.cluster<-shape.cluster[match(various.cluster$Cell.Line, shape.cluster$X),]
# Three alternative grouping variables, one row per cell line.
col.data<-data.frame(various.cluster$Authors.shape.cluster, various.cluster$my.shape.cluster, shape.cluster$k2.cluster)
row.names(col.data)<-various.cluster$Cell.Line
# Select the rows for the cell lines in count.data, in count.data's column
# order. A column name that is not a row name of col.data yields an NA row.
col.data<-col.data[colnames(count.data),]
# a = Authors.shape.cluster, b = my.shape.cluster, c = k2.cluster
colnames(col.data)<-c("a", "b", "c")
# Recode the cluster ids in b and c as words. If the column was read as
# numeric, the first assignment turns it into character, so the remaining
# comparisons are string comparisons ("1" == 1 is TRUE in R).
col.data[col.data$b == 2,2]<-"TWO"
col.data[col.data$b == 1,2]<-"ONE"
col.data[col.data$c == 3,3]<-"THREE"
col.data[col.data$c == 2,3]<-"TWO"
col.data[col.data$c == 1,3]<-"ONE"



# The sample annotation must describe exactly the columns of count.data, in
# the same order. Enforce it instead of only printing the comparison: a cell
# line missing from col.data shows up as an "NA" row name and stops here.
stopifnot(identical(colnames(count.data), rownames(col.data)))

# Choose which shape cluster to study; all other clusters are pooled into a
# single "control" group to compare against.
#   ONE   basal
#   THREE luminal
#   TWO   hetero
cluster.of.interest <- "TWO"
# set the other clusters to "control"
col.data[col.data$c != cluster.of.interest, 3] <- "control"
# Fix the level order explicitly so that "control" is the reference level
# regardless of how the current locale sorts "TWO" against "control".
col.data$c <- factor(col.data$c, levels = c("control", cluster.of.interest))

# Build the DESeq2 data set. The design lists the col.data columns to model:
# c is the group under test, a is included as an additional covariate.
# NOTE(review): if column a holds numeric cluster ids it enters the model as a
# continuous covariate rather than as a factor -- confirm this is intended.
deseq2Data <- DESeqDataSetFromMatrix(countData = count.data,
                                     colData = col.data,
                                     design = ~ c + a)

## SET UP MULTI CORES

# Register the number of worker processes used by parallel = TRUE below.
library(BiocParallel)
register(MulticoreParam(4))

## RUN

# Run the differential expression pipeline on the registered workers.
deseq2Data <- DESeq(deseq2Data, parallel = TRUE)

# Extract the results for the selected cluster versus the pooled control.
deseq2Results <- results(deseq2Data,
                         contrast = c("c", cluster.of.interest, "control"))
summary(deseq2Results)

# MA plot using DESeq2's built-in method.
plotMA(deseq2Results)

# Load libraries
# install.packages(c("ggplot2", "scales", "viridis"))
library(ggplot2)
library(scales) # needed for oob parameter
library(viridis)

# Coerce the DESeq2 results to a plain data frame for plotting.
deseq2ResDF <- as.data.frame(deseq2Results)
#write results
#write.csv(deseq2ResDF, "../differential_exp/hetero_DEGs.txt")

# Flag genes with an adjusted p-value below 0.1. All other genes, including
# those whose padj is NA, stay NA and are drawn in the na.value colour.
deseq2ResDF$significant <- ifelse(deseq2ResDF$padj < .1, "Significant", NA)

# MA plot similar to DESeq2's own. Fold changes beyond +/-10 are squished onto
# the axis limits instead of being dropped.
# `values` must be a named vector built with c(); the previous
# ("Significant"="red") was a parenthesised assignment that created a global
# variable called `Significant` and passed an unnamed colour.
ggplot(deseq2ResDF, aes(baseMean, log2FoldChange, colour = significant)) +
  geom_point(size = 1) +
  scale_y_continuous(limits = c(-10, 10), oob = squish) +
  scale_x_log10() +
  geom_hline(yintercept = 0, colour = "tomato1", size = 2) +
  labs(x = "mean of normalized counts", y = "log fold change") +
  scale_colour_manual(name = "q-value",
                      values = c("Significant" = "red"),
                      na.value = "grey50") +
  theme_bw()

# More detailed version: points coloured by adjusted p-value with 2D density
# contours on top. The theme() call that used to precede theme_bw() is gone;
# theme_bw() is a complete theme and replaced it anyway.
plot.1 <- ggplot(deseq2ResDF, aes(baseMean, log2FoldChange, colour = padj)) +
  geom_point(size = 1.5) +
  scale_y_continuous(limits = c(-10, 10), oob = squish) +
  scale_x_log10() +
  geom_hline(yintercept = 0, colour = "darkorchid4", size = 1,
             linetype = "longdash") +
  labs(x = "mean of normalized counts", y = "log fold change") +
  scale_colour_viridis(direction = -1, trans = "sqrt") +
  theme_bw() +
  geom_density_2d(colour = "black", size = 2)
# Print it without grid lines or panel background, with black axis lines.
plot.1 + theme(panel.grid.major = element_blank(),
               panel.grid.minor = element_blank(),
               panel.background = element_blank(),
               axis.line = element_line(colour = "black"))

# Rows of the significant genes. which() drops the NA comparisons, so rows
# with an NA flag are excluded.
sig.gene.IDs <- deseq2ResDF[which(deseq2ResDF$significant == "Significant"), ]
# Top of the table ordered by adjusted p-value.
head(deseq2ResDF[order(deseq2ResDF$padj), ])



# my_rplotly writes an interactive volcano plot for one results table.
#
# Args:
#   deg_filename: path to a CSV file with the columns log2FoldChange, pvalue,
#                 padj and baseMean. Gene labels are taken from column "X1"
#                 when present, otherwise from the first column (an unnamed
#                 first column is called "X1" or "...1" depending on the readr
#                 version).
#
# Side effect: saves the widget as "<deg_filename>_Volcano_padj.html". The
# name is built with paste0(), so it no longer contains the space that
# paste()'s default separator used to insert before "_Volcano_padj.html".
my_rplotly <- function(deg_filename)
{
  degs <- read_csv(deg_filename)

  required_cols <- c("log2FoldChange", "pvalue", "padj", "baseMean")
  missing_cols <- setdiff(required_cols, names(degs))
  if (length(missing_cols) > 0) {
    stop("Column(s) missing from ", deg_filename, ": ",
         paste(missing_cols, collapse = ", "),
         call. = FALSE)
  }

  gene_label <- if ("X1" %in% names(degs)) degs[["X1"]] else degs[[1]]

  # Point classes, later assignments taking precedence:
  #   padj > 0.05        -> "adj_pval>0.05"
  #   else |log2FC| > 2  -> "|LogFC|>2"
  #   else               -> "adj_pval<0.05"
  # which() skips NA comparisons, so a gene with an NA padj keeps whichever
  # class it had before that step.
  point_class <- rep("adj_pval<0.05", nrow(degs))
  point_class[which(abs(degs$log2FoldChange) > 2)] <- "|LogFC|>2"
  point_class[which(degs$padj > .05)] <- "adj_pval>0.05"

  # Hover text shown for each point.
  hover_text <- paste("Gene:", gene_label,
                      "\nLogFC:", formatC(degs$log2FoldChange, digits = 2),
                      "\nPvalue:", formatC(degs$pvalue, digits = 2),
                      "\nAdj.Pvalue:", formatC(degs$padj, digits = 2),
                      "\nBasemean:", round(degs$baseMean))

  plot_df <- data.frame(log2FoldChange = degs$log2FoldChange,
                        padj = degs$padj,
                        group = point_class,
                        hover = hover_text,
                        stringsAsFactors = FALSE)

  p <- plot_ly(plot_df,
               x = ~log2FoldChange, y = ~-log10(padj),
               type = "scatter", mode = "markers",
               color = ~group, text = ~hover)
  htmlwidgets::saveWidget(as_widget(p),
                          paste0(deg_filename, "_Volcano_padj.html")) #Save as html
}

#file <- "./deseqResults/luminal_results.txt"
#library(org.Hs.eg.db)
#ens2symbol <- AnnotationDbi::select(org.Hs.eg.db,
#                                    key=rownames(deseq2Results), 
#                                    columns="SYMBOL",
#                                    keytype="ENSEMBL")
#get symbol 
#ens2symbol <- as_tibble(ens2symbol)
#ens2symbol <- ens2symbol[ens2symbol$ENSEMBL %in% rownames(deseq2Results),]
#ens2symbol<-ens2symbol[!duplicated(ens2symbol$ENSEMBL),]
#library(EnhancedVolcano)
#EnhancedVolcano(deseq2Results,
#                lab = ens2symbol$SYMBOL,
#                x = 'log2FoldChange',
#                y = 'pvalue')
