setwd("~/brainmap/seurat")
library(Seurat)
library(squash)
library(dplyr)
library(clusterProfiler)
library(pheatmap)
library(viridis)
library(vioplot)

# Load the integrated Seurat objects. The file is expected to define `rlist`,
# which is indexed by region below; it is kept under a new name so that `rlist`
# can be rebuilt as rlist[[species]][[region]].
load("integrated.nomito.RData")
rlist.all <- rlist

regions <- c("acc","cn","cer")
species <- c("H","C","B","M")

# Expected to define `pvals`, indexed by region below.
# NOTE(review): assumes pvals[[r]] holds one p-value per row of the RNA `data`
# matrix, in the same order -- confirm against the script that writes this file.
load("anova.v2.Rdata")

rlist <- list()
for (s in species){
  rlist[[s]] <- list()
}
for (r in regions){
  # Genes that survive the ANOVA filter. NA/NaN p-values are treated as
  # "not significant" so that the logical row index never contains NA.
  keep <- !is.na(pvals[[r]]) & pvals[[r]] < 0.05
  for (s in species){
    # cells of region r whose orig.ident equals species s
    sel <- names(rlist.all[[r]]@active.ident)[rlist.all[[r]]$orig.ident==s]
    print(length(sel))
    rlist[[s]][[r]] <- SubsetData(rlist.all[[r]], cells = sel)
    # remove genes that do not survive ANOVA (only the `data` slot is filtered)
    rlist[[s]][[r]]@assays$RNA@data <- rlist[[s]][[r]]@assays$RNA@data[keep,]
  }
}
rm(rlist.all)

# cells.nomito.RData is expected to define `cells`, used below as
# cells[[species]][[region]][[cluster]].
# NOTE(review): ct.ids.nomito.RData presumably provides `ct.names` -- confirm.
load("ct.ids.nomito.RData")
load("cells.nomito.RData")

# Size of the smallest entry of `cells` over all species/regions/clusters.
ncells <- lapply(cells, function (x) lapply(x, function (y) lapply(y,length)))
ncells <- min(unlist(ncells))
ncells

##### correlational heatmaps (in inferno colors) #####

ncells
# number of cells drawn for every pseudo-bulk profile
balance <- ncells

# genes[[region]][[cluster]] accumulates, over 1000 random draws of cells:
#   H    - the sum of the "H" pseudo-bulk profiles (per gene)
#   HCBM - per gene, the number of draws in which |H-M| exceeded twice |B-M|
#          and twice |C-M|
genes <- list()
for (region in regions) {
  print(region)
  genes[[region]] <- list()
  # bootstrapping cells
  for (cluster in names(cells[["H"]][[region]])) {
    print(cluster)
    h.sum <- 0
    hs.count <- 0
    for (draw in seq_len(1000)) {
      # one pseudo-bulk profile per species, drawn in the order of `species`
      pseudo <- lapply(species, function (sp) {
        # if using balance>25, set replace=TRUE
        picked <- sample(cells[[sp]][[region]][[cluster]], balance, replace = FALSE)
        # the sampled columns are passed twice; the row means are unaffected
        Matrix::rowMeans(rlist[[sp]][[region]]@assays$RNA@data[, c(picked, picked)])
      })
      names(pseudo) <- species

      # absolute differences of each profile to the "M" profile
      dist.HM <- abs(pseudo[["H"]] - pseudo[["M"]])
      dist.CM <- abs(pseudo[["C"]] - pseudo[["M"]])
      dist.BM <- abs(pseudo[["B"]] - pseudo[["M"]])

      is.hs <- dist.HM > 2 * dist.BM & dist.HM > 2 * dist.CM &
        (dist.HM > 0 | dist.CM > 0) & (dist.HM > 0 | dist.BM > 0)

      h.sum <- h.sum + pseudo[["H"]]
      hs.count <- hs.count + is.hs
    }
    genes[[region]][[cluster]] <- list(H = h.sum, HCBM = hs.count)
  }
}

# cache the result; the load() lets the bootstrap be skipped on later runs
save(genes, file = "nomito.genes.RData")
load("nomito.genes.RData")

hcol <- rev(inferno(90))
ct.names <- unlist(ct.names)

# For both statistics stored in `genes` ("H" and "HCBM"), correlate every
# region/cluster with every other one across their shared genes and draw the
# resulting matrix as a heatmap.
heatmaps <- list()
overlaps <- list()
for (type in c("H","HCBM")){
  # unlist() labels the clusters as <region><running index> (e.g. "acc1"); the
  # look-up below rebuilds the same label from the cluster id, which therefore
  # has to be numeric and zero-based.
  cl.ids <- names(unlist(lapply(genes, function (x) names(x))))
  overlap <- matrix(NA,length(cl.ids),length(cl.ids),dimnames=list(cl.ids,cl.ids))
  for (r1 in regions){
    for (cl1 in names(genes[[r1]])){
      for (r2 in regions){
        for(cl2 in names(genes[[r2]])){
          # the diagonal (a cluster with itself) is left as NA
          if(r1!=r2 || cl1!=cl2){
            t1 <- genes[[r1]][[cl1]][[type]]
            t2 <- genes[[r2]][[cl2]][[type]]

            # restrict both vectors to their shared genes, in the same order
            t1 <- t1[names(t1) %in% names(t2)]
            t2 <- t2[names(t1)]

            x1 <- as.vector(t1)
            x2 <- as.vector(t2)

            # use only genes that are finite in BOTH vectors, so the result
            # does not depend on which cluster comes first
            ok <- is.finite(x1) & is.finite(x2)
            corr <- cor(x1[ok],x2[ok],use="pairwise.complete.obs")
            overlap[paste0(r1,as.numeric(cl1)+1),paste0(r2,as.numeric(cl2)+1)] <- corr
          }
        }
      }
    }
  }
  # NOTE(review): assumes ct.names lists the cell-type labels in the same
  # order as cl.ids -- confirm against the file that provides ct.names.
  dimnames(overlap)[[1]] <- ct.names
  dimnames(overlap)[[2]] <- ct.names
  ord <- c("AC In 1","CN Pur 3","CB In 1","AC Ex 0","CN Pur 0","CB Gr 0",
           "AC Ast 2","CN Ast 2","CB Ast 2","AC OD 3","CN OD 1","CB OD 3",
           "AC OPC 4","CN OPC 4","AC MG 5","CN MG 5")
  # rows in `ord`, columns in reverse order
  overlap <- overlap[ord,rev(ord)]
  overlaps[[type]] <- overlap

  heatmaps[[type]] <- pheatmap(overlap,
                               color=hcol,
                               # breaks are equally spaced in 10^r and mapped back with log10
                               breaks=log10(seq(10^min(overlap,na.rm=TRUE),10^max(overlap,na.rm=TRUE),length.out=91)),
                               #breaks=10^seq(log10(min(overlap,na.rm=T)),log10(max(overlap,na.rm=T)),length.out=91),
                               #breaks=seq(min(overlap,na.rm=T),max(overlap,na.rm=T),length.out=91),
                               border_color=NA,
                               # NA cells (the self-comparisons) get the last palette colour
                               na_col=hcol[length(hcol)],
                               gaps_row=c(6,9,12,14),
                               gaps_col=c(2,4,7,10),
                               cluster_rows=FALSE,
                               cluster_cols=FALSE,
                               show_rownames=TRUE,
                               silent=TRUE,
                               main=type)
}

pdf("overlap.nomito.pdf",width=17,height=4,pointsize=24)
# element 4 of each pheatmap result is passed on as the grob to draw
gridExtra::grid.arrange(grobs=lapply(heatmaps, function (x) x[[4]]),ncol=4)
dev.off()

pdf("overlap.nomito.bxp.pdf", width = 14, height = 9, pointsize = 24)
par(mfrow = c(1, 2), las = 1)
# One panel per statistic: correlations within neuronal types, within each
# glial type, and all remaining ("between") pairs.
for (type in c("H", "HCBM")) {
  # working copy; entries are blanked as soon as they are assigned to a group
  overlap <- overlaps[[type]]

  # glia: only pairs of the same glial type count as "within"
  within.glia <- c()
  for (glia.type in c("Ast", "OD", "OPC", "MG")) {
    rows <- grep(glia.type, rownames(overlap))
    cols <- grep(glia.type, colnames(overlap))
    within.glia <- c(within.glia, as.vector(overlap[rows, cols]))
    overlap[rows, cols] <- NA
  }

  # neurons: all neuronal types are pooled, so cross-type pairs count as well
  neuron.types <- c("In", "Pur", "Ex", "Gr")
  rows <- unlist(lapply(neuron.types, function (pattern) grep(pattern, rownames(overlap))))
  cols <- unlist(lapply(neuron.types, function (pattern) grep(pattern, colnames(overlap))))
  within.neurons <- as.vector(overlap[rows, cols])
  overlap[rows, cols] <- NA

  # whatever has not been blanked is a between-group comparison
  between <- as.vector(overlap)

  # "P = <Wilcoxon p-value> FC = <ratio of means>" of a group versus `between`
  describe <- function (within) {
    p <- format(wilcox.test(within, between)$p.value, digits = 2)
    fc <- format(mean(within, na.rm = TRUE) / mean(between, na.rm = TRUE), digits = 3)
    paste("P =", p, "FC =", fc)
  }
  label.glia <- describe(within.glia)
  label.neu <- describe(within.neurons)

  l <- list(neurons = within.neurons, glia = within.glia, between = between)
  vioplot(l, names = names(l), ylab = "Pearson's R",
          main = c(type, label.neu, label.glia),
          ylim = c(0.2, 1), col = "lightblue")
}
dev.off()


##### Human-specificity in each cell cluster for human-specific genes from 33 regions #####
##### Subsampling cells

# Expected to define `hs.HB` and `hs.HC`, both indexed by bulk region below.
load("../bulk/fc.th.dynamic.RData")
# Per bulk region: genes whose hs.HB and hs.HC values both exceed log2(2).
# simplify=FALSE guarantees a named list; plain sapply() would collapse to a
# matrix if every region happened to yield the same number of genes, which
# would break the bulk.genes[[br]] look-ups below.
bulk.genes <- sapply(names(hs.HB), simplify=FALSE, function (x) names(hs.HB[[x]])[hs.HB[[x]]>log2(2) & hs.HC[[x]]>log2(2)] )
min(unlist(lapply(bulk.genes,length)))

# Bulk region names in the order given by order_EBnames.txt (second column of
# the file; the first column is consumed as row names).
o <- as.matrix(read.delim("order_EBnames.txt",header=FALSE,row.names=1))[,1]
o

load("cells.nomito.RData")

balance <- 25    # cells drawn per pseudo-bulk profile
n.boot <- 1000   # number of random draws per cluster

# HS[[bulk region]][[region]][[cluster]]: for every draw, the fraction of the
# bulk region's gene set that is also called in the pseudo-bulk comparison.
HS <- list()
for (br in o){
  HS[[br]] <- list()
  for (r in regions){
    HS[[br]][[r]] <- list()
  }
}
for (r in regions){
  print(r)
  # bootstrapping cells
  for (i in names(cells[["H"]][[r]])){
    print(i)
    # one preallocated slot per draw, for EVERY bulk region
    for (br in o){
      HS[[br]][[r]][[i]] <- rep(NA_real_, n.boot)
    }
    for (b in seq_len(n.boot)){
      bulk <- list()
      for (s in species){
        cells.i <- cells[[s]][[r]][[i]]
        boot <- sample(cells.i,balance,replace=FALSE) # if using balance>25, set replace=TRUE
        # the sampled columns are passed twice; the row means are unaffected
        x <- Matrix::rowMeans(rlist[[s]][[r]]@assays$RNA@data[,c(boot,boot)])
        bulk[[s]] <- x
      }
      # absolute differences of each profile to the "M" profile
      bulk.HM <- abs(bulk[["H"]]-bulk[["M"]])
      bulk.CM <- abs(bulk[["C"]]-bulk[["M"]])
      bulk.BM <- abs(bulk[["B"]]-bulk[["M"]])

      # genes whose H-vs-M difference is more than twice both other differences
      nucl <- names(bulk.HM)[bulk.HM > 2*bulk.BM & bulk.HM > 2*bulk.CM & (bulk.HM>0|bulk.CM>0) & (bulk.HM>0|bulk.BM>0)]
      #nucl <- names(bulk.HM)[bulk.HM > bulk.BM & bulk.HM > bulk.CM & (bulk.HM>0|bulk.CM>0) & (bulk.HM>0|bulk.BM>0)]

      for (br in o){
        bulk.set <- bulk.genes[[br]]
        HS[[br]][[r]][[i]][b] <- sum(nucl %in% bulk.set) / length(bulk.set)
      }
    }
  }
}

# Pool the per-cluster ratios of every bulk region into merged cell types
# (the cluster ids that make up each type differ between regions).
ratios.allreg <- list()
for (br in o) {
  ratios <- HS[[br]][regions]
  acc <- ratios[["acc"]]
  cn <- ratios[["cn"]]
  cer <- ratios[["cer"]]

  pooled <- list()
  pooled[["Neu"]] <- c(acc[["0"]], acc[["1"]], cn[["0"]], cn[["3"]], cer[["0"]], cer[["1"]])
  pooled[["Ast"]] <- c(acc[["2"]], cn[["2"]], cer[["2"]])
  pooled[["OD"]] <- c(acc[["3"]], cn[["1"]], cer[["3"]])
  pooled[["OPC"]] <- c(acc[["4"]], cn[["4"]])
  pooled[["MG"]] <- c(acc[["5"]], cn[["5"]])
  ratios.allreg[[br]] <- pooled
}

# one row per bulk region, one column per merged cell type
means <- do.call(rbind, lapply(ratios.allreg, function (ct) vapply(ct, mean, numeric(1))))
medians <- do.call(rbind, lapply(ratios.allreg, function (ct) vapply(ct, median, numeric(1))))
dimnames(means)[[1]] <- o
dimnames(means)[[2]] <- names(ratios.allreg[[br]])
dimnames(medians)[[1]] <- o
dimnames(medians)[[2]] <- names(ratios.allreg[[br]])

save(means, medians, file = "hspec.HS.33regions.integrated.mergedCT.subsampling.nomito.RData")
#save(means,medians,file="hspec.HS.33regions.integrated.mergedCT.subsampling.th1.nomito.RData")

# keep only the pooled "Neu" ratios of every bulk region (used by the boxplot)
ratios.allreg <- lapply(ratios.allreg, `[[`, "Neu")
#ratios.allreg <- sapply(rev(o), simplify=F, function (x) ratios.allreg[[x]]-mean(means[,"Neu"]))

# blue-white-red ramp of 94 colours with the four central ones removed
hcol <- colorRampPalette(c("#1565C0", "white", "#C62828"))(94)[-(46:49)]

#means <- apply(means,2, function (x) x-mean(x))

# bulk regions whose average is drawn as a dashed reference line
ref.regions <- c("30 Cerebellar Grey Matter", "32 Caudate", "11 Cingulate Anterior (BA24)")
# common x-range of the mean and median "Neu" values
neu.range <- range(c(means[, "Neu"], medians[, "Neu"]))

pdf("hspec.HS.33regions.integrated.mergedCT.subsampling.nomito.pdf", width = 5, height = 6)
#pdf(paste0("hspec.HS.33regions.integrated.mergedCT.subsampling.th1.nomito.pdf"),width=5,height=6)

# page 1: mean (black) and median (blue) per bulk region, regions on the right axis
par(mar = c(2, 0, 0, 13.3) + 0.5, las = 1)
# the x/y expressions are left as they are: plot() derives its default axis
# labels from them
plot(means[rev(o),"Neu"],1:33,pch=16,xlim=neu.range,yaxt="n")
points(medians[rev(o), "Neu"], 1:33, pch = 16, col = "blue")
v <- mean(means[ref.regions, "Neu"])
v
abline(v = v, lty = 2)
v <- mean(medians[ref.regions, "Neu"])
v
abline(v = v, lty = 2, col = "blue")
axis(4, at = 1:33, labels = rev(o))
legend("bottomleft", c("mean", "median"), col = c("black", "blue"), pch = 16, bty = "n")

# page 2: distribution of the bootstrap ratios per bulk region; `v` still holds
# the median-based reference value at this point
par(mar = c(2, 0, 0, 13.3) + 0.5, las = 1)
boxplot(ratios.allreg, outline = FALSE, range = 0.5, at = 33:1, horizontal = TRUE, yaxt = "n")
abline(v = v, lty = 2, col = "blue")
points(means[rev(o), "Neu"], 1:33, pch = 16, col = "blue")
axis(4, at = 1:33, labels = rev(o))

# only referenced by the commented-out `breaks` arguments below
br <- seq(78, 92, length.out = 91)

# page 3: all merged cell types
pheatmap(means,
         color = hcol,
         # breaks=br,
         na_col = "lightgray",
         cluster_rows = FALSE,
         cluster_cols = FALSE,
         show_rownames = TRUE)
# page 4: the "Neu" column on its own (passed twice)
pheatmap(means[, c("Neu", "Neu")],
         color = colorRampPalette(c("#1565C0", "white", "#C62828"))(90),
         # breaks=br,
         na_col = "lightgray",
         cluster_rows = FALSE,
         cluster_cols = FALSE,
         show_rownames = TRUE)
dev.off()

##########

##### Human-specificity in each cell cluster for human-specific genes from 33 regions #####

load("../bulk/fc.th.dynamic.RData")
# per bulk region: genes whose hs.HB and hs.HC values both exceed log2(2)
bulk.genes <- sapply(names(hs.HB), function (reg) {
  above.both <- hs.HB[[reg]] > log2(2) & hs.HC[[reg]] > log2(2)
  names(hs.HB[[reg]])[above.both]
})
min(lengths(bulk.genes))

# bulk region names in the order given by order_EBnames.txt
o <- as.matrix(read.delim("order_EBnames.txt", header = FALSE, row.names = 1))[, 1]
o

load("cells.nomito.RData")
load("nomito.genes.RData")

# HS[[bulk region]][[region]][[cluster]]: fraction of the bulk region's gene
# set whose HCBM count in that cluster is above 500
HS <- list()
for (br in o) {
  HS[[br]] <- list()
}
for (region in regions) {
  print(region)
  clusters <- names(cells[["H"]][[region]])
  for (br in o) {
    bulk.set <- bulk.genes[[br]]
    per.cluster <- list()
    for (cl in clusters) {
      hcbm <- genes[[region]][[cl]][["HCBM"]]
      nucl <- names(hcbm)[hcbm > 500]
      print(length(nucl))
      # if(length(grep("33",br))==1){
      #   hist(genes[[region]][[cl]][["HCBM"]],ylim=c(0,1000),main=c(region,cl))
      # }
      per.cluster[[cl]] <- sum(nucl %in% bulk.set) / length(bulk.set)
    }
    HS[[br]][[region]] <- per.cluster
  }
}

# Pool the per-cluster ratios of every bulk region into merged cell types
# (the cluster ids that make up each type differ between regions).
ratios.allreg <- list()
for (br in o) {
  ratios <- HS[[br]][regions]
  acc <- ratios[["acc"]]
  cn <- ratios[["cn"]]
  cer <- ratios[["cer"]]

  pooled <- list()
  pooled[["Neu"]] <- c(acc[["0"]], acc[["1"]], cn[["0"]], cn[["3"]], cer[["0"]], cer[["1"]])
  pooled[["Ast"]] <- c(acc[["2"]], cn[["2"]], cer[["2"]])
  pooled[["OD"]] <- c(acc[["3"]], cn[["1"]], cer[["3"]])
  pooled[["OPC"]] <- c(acc[["4"]], cn[["4"]])
  pooled[["MG"]] <- c(acc[["5"]], cn[["5"]])
  ratios.allreg[[br]] <- pooled
}

# one row per bulk region, one column per merged cell type
means <- do.call(rbind, lapply(ratios.allreg, function (ct) vapply(ct, mean, numeric(1))))
dimnames(means)[[1]] <- o
dimnames(means)[[2]] <- names(ratios.allreg[[br]])

# keep only the pooled "Neu" ratios of every bulk region
ratios.allreg <- lapply(ratios.allreg, `[[`, "Neu")

# blue-white-red ramp of 94 colours with the four central ones removed
hcol <- colorRampPalette(c("#1565C0", "white", "#C62828"))(94)[-(46:49)]

# bulk regions whose average is drawn as a dashed reference line
ref.regions <- c("30 Cerebellar Grey Matter", "32 Caudate", "11 Cingulate Anterior (BA24)")

pdf("hspec.HS.33regions.integrated.mergedCT.500.nomito.pdf", width = 5, height = 6)
par(mar = c(2, 0, 0, 13.3) + 0.5, las = 1)
# the x/y expressions are left as they are: plot() derives its default axis
# labels from them
plot(means[rev(o),"Neu"],1:33,pch=16,yaxt="n")
v <- mean(means[ref.regions, "Neu"])
v
abline(v = v, lty = 2)
axis(4, at = 1:33, labels = rev(o))
# all merged cell types
pheatmap(means,
         color = hcol,
         # breaks=br,
         na_col = "lightgray",
         cluster_rows = FALSE,
         cluster_cols = FALSE,
         show_rownames = TRUE)
# the "Neu" column on its own (passed twice)
pheatmap(means[, c("Neu", "Neu")],
         color = colorRampPalette(c("#1565C0", "white", "#C62828"))(90),
         # breaks=br,
         na_col = "lightgray",
         cluster_rows = FALSE,
         cluster_cols = FALSE,
         show_rownames = TRUE)
dev.off()

##########




##### Correlation with deconvolution #####

# estimated cell-type proportions, one row per bulk sample
#deconv <- as.matrix(readRDS("../bulk/bulk.deconv.human.tree.auto.rds")$prop.est)
deconv <- as.matrix(readRDS("../bulk/bulk.deconv.human.new.rds")$prop.est)


# sample annotation; rows are renamed to "X<column 2>" with the first "-"
# replaced by "." so that they can be looked up by rownames(deconv)
annot <- as.matrix(read.delim("../bulk/batches_34groups_4ctx_EBnames.txt",header=FALSE))
dimnames(annot)[[1]] <- paste0("X",sub("-",".",annot[,2]))
dim(annot)

# drop samples without an entry in annotation column 8
deconv <- deconv[annot[rownames(deconv),8]!="",]
head(deconv)

annot <- annot[annot[,8]!="",]
# append the first space-separated token of column 8 as a new last column
# NOTE(review): the code below reads it as column 10, i.e. the file is assumed
# to have nine columns -- confirm.
annot <- cbind(annot, unlist(lapply(strsplit(annot[,8]," "), function (x) x[[1]])))

pdf("deconvolution_Artemiy_v2.pdf",width=5,height=20)
par(mfrow=c(7,1))
# one panel per cell type (first six columns), samples grouped by region number
for (i in 1:6){
  boxplot(deconv[,i] ~ as.numeric(annot[rownames(deconv),10]), las=2, col="red", main=colnames(deconv)[i])
  stripchart(deconv[,i] ~ as.numeric(annot[rownames(deconv),10]), pch=21, cex=0.8, col="dimgray", bg=rgb(0.5,0.5,0.5,0.3), add=TRUE, vertical=TRUE)
}
# mean proportion per region number; rows are named by that number and come
# out in increasing numeric order
deconv.means <- apply(deconv, 2, function (x) tapply(x, as.numeric(annot[rownames(deconv),10]), mean))
barplot(t(deconv.means), beside=FALSE, las=2, col=c("#CE6928","#B02325","#208C43", "#9F5DA5","#F16EAA","#4A87C6"))
dev.off()

head(deconv.means)

col <- as.matrix(read.delim("order_EBnames.txt",header=FALSE,row.names=1))[,2]

# provides `means` (and `medians`); its rows are named by bulk region
load("hspec.HS.33regions.integrated.mergedCT.subsampling.nomito.RData")

# Pair the two tables by region number instead of by row position: the rows of
# `means` follow the order of the saved file, whereas deconv.means is sorted by
# region number.
# NOTE(review): assumes the row names of `means` start with the same region
# number that is used in annotation column 8 (e.g. "32 Caudate" -> 32) -- confirm.
region.id <- as.character(as.numeric(sub(" .*", "", rownames(means))))
idx <- match(region.id, rownames(deconv.means))
if (anyNA(idx)) {
  warning("no deconvolution estimate for: ",
          paste(rownames(means)[is.na(idx)], collapse = ", "), call. = FALSE)
}
neu.hs <- means[, "Neu"]
neu.prop <- (deconv.means[idx, "Ex"] + deconv.means[idx, "In"]) / 2

pdf("deconvolution_Artemiy_v2.corr.pdf",width=4.5,height=5)
# xlab/ylab spell out the labels that plot() derived from the un-aligned
# expressions, so the figure keeps its annotation
plot(neu.hs*100, neu.prop*100, pch=21, bg="gray", cex=2, las=1,
     xlab="means[, \"Neu\"] * 100",
     ylab="(deconv.means[, \"Ex\"] + deconv.means[, \"In\"])/2 * 100")
cor.test(neu.hs, neu.prop)
dev.off()

##########
