# Every relative path used below (the .RData inputs that are loaded and the
# PDF / .RData outputs that are written) is resolved against this directory.
setwd("~/brainmap/seurat")
library(Seurat)
library(squash)
library(dplyr)
library(clusterProfiler)
library(pheatmap)
library(viridis)

# Load the per-region integrated objects and split each of them by the value
# of `orig.ident`.
# `load()` is expected to create `rlist`, a list indexed by region name whose
# elements expose `@active.ident` and `$orig.ident` (presumably Seurat
# objects -- TODO confirm against the script that wrote this .RData file).
load("integrated.srong.RData")
rlist.all <- rlist

regions <- c("acc", "cn", "cer")
# Values of `orig.ident` used for the split; presumably species codes
# (human / chimpanzee / bonobo / macaque) -- TODO confirm.
species <- c("H", "C", "B", "M")

# After this section rlist[[species]][[region]] holds the subset of that
# region's object restricted to the cells whose orig.ident equals the species.
rlist <- list()
for (s in species) {
  rlist[[s]] <- list()
}
for (r in regions) {
  region.obj <- rlist.all[[r]]
  for (s in species) {
    # which() discards NA comparisons, so a missing orig.ident can never turn
    # into an NA cell name in the selection.
    sel <- names(region.obj@active.ident)[which(region.obj$orig.ident == s)]
    print(length(sel))
    # subset() is the supported replacement for the deprecated SubsetData().
    rlist[[s]][[r]] <- subset(x = region.obj, cells = sel)
  }
}
rm(rlist.all, region.obj)


# cells[[species]][[region]][[cluster]] is the character vector of cell names
# whose active identity equals that cluster label. The cluster labels are the
# names of table() over the active identities, so a level with no cells yields
# an empty vector.
cells <- list()
for (s in species) {
  cells[[s]] <- list()
}
for (r in regions) {
  for (s in species) {
    print(s)
    cells[[s]][[r]] <- list()
    idents <- rlist[[s]][[r]]@active.ident
    tab <- table(rlist[[s]][[r]]@active.ident)
    print(tab)
    ident.labels <- as.character(idents)
    for (i in names(tab)) {
      # Compare labels as strings: coercing the label with as.numeric() turns
      # any non-numeric label into NA and mismatches zero-padded ones.
      # which() also keeps NA identities out of the result.
      cells[[s]][[r]][[i]] <- names(idents)[which(ident.labels == i)]
    }
  }
}

# Cell-type labels per region. `load()` is expected to provide `ct.ids`, a
# list indexed by region whose elements are named by cluster id.
load("ct.ids.srong.RData")

# Step 1: order each region's labels by the numeric value of the cluster id.
ct.names <- list()
for (r in regions) {
  region.ids <- ct.ids[[r]]
  id.order <- order(as.numeric(names(region.ids)))
  ct.names[[r]] <- region.ids[id.order]
}

# Step 2: prepend the two-letter region tag to every label.
region.tag <- c(acc = "AC", cn = "CN", cer = "CB")
for (r in names(region.tag)) {
  ct.names[[r]] <- paste(region.tag[[r]], ct.names[[r]])
}

# Persist the cell lists together with the labels.
save(cells, ct.names, file = "cells.srong.RData")

# Number of cells in every species / region / cluster combination, printed as
# a nested list, then reduced to the size of the smallest cluster overall.
ncells <- lapply(cells, function(by.region) {
  lapply(by.region, function(by.cluster) {
    lapply(by.cluster, length)
  })
})
ncells
all.sizes <- unlist(ncells)
ncells <- min(all.sizes)
ncells


# Per-cluster expression summaries, written to non-zero.genes.srong.pdf:
#   1. violin plot of mean expression per region / species / cluster
#   2. bar plot of the proportion of genes with mean expression > 0
#   3. + 4. the same two plots with the species pooled per region / cluster
pdf("non-zero.genes.srong.pdf", width = 9, height = 4)

# Fill colours per cluster, hard-coded for 6 (acc), 6 (cn) and 4 (cer)
# clusters. NOTE(review): confirm these counts match the clusters in `cells`.
pal.acc <- c("#CE6928", "#B02325", "#208C43", "#9F5DA5", "#F16EAA", "#4A87C6")
pal.cn <- c("#CE6928", "#9F5DA5", "#208C43", "#B02325", "#F16EAA", "#4A87C6")
pal.cer <- c("#CE6928", "#B02325", "#208C43", "#9F5DA5")
# One colour per (species, cluster) entry: the region palette repeats once per
# species because entries are ordered region > species > cluster.
cols.by.species <- c(
  rep(pal.acc, times = length(species)),
  rep(pal.cn, times = length(species)),
  rep(pal.cer, times = length(species))
)
# One colour per (region, cluster) entry.
cols.by.cluster <- c(pal.acc, pal.cn, pal.cer)

# mean.expr[["<region> <species> <cluster>"]]: per-gene mean of the RNA assay's
# `data` slot over the cells of that cluster. Cluster labels are taken from
# species "H" for every species.
mean.expr <- list()
for (r in regions) {
  for (s in species) {
    expr <- rlist[[s]][[r]]@assays$RNA@data
    for (i in names(cells[["H"]][[r]])) {
      cells.i <- cells[[s]][[r]][[i]]
      # drop = FALSE keeps a one-cell cluster as a one-column matrix; without
      # it the selection collapses to a vector and rowMeans() fails.
      x <- Matrix::rowMeans(expr[, cells.i, drop = FALSE])
      key <- paste(r, s, i)
      mean.expr[[key]] <- x
      # smallest non-zero mean, entry label, number of genes with mean > 0
      print(c(min(x[x > 0]), key, sum(x > 0)))
    }
  }
}
# NOTE(review): vioplot() is not defined in this file -- confirm which
# attached package (or loaded workspace) provides it.
vioplot(mean.expr, las = 2, names = names(mean.expr), col = cols.by.species,
        ylab = "Mean expression of a gene among cells in cluster")
nonzero <- vapply(mean.expr, function(x) sum(x > 0) / length(x), numeric(1))
barplot(nonzero, las = 2, names.arg = names(mean.expr), col = cols.by.species,
        ylab = "Proportion of genes with average expression > 0")

# mean.expr.perCluster[["<region> <cluster>"]]: the per-gene means of all
# species concatenated into one vector (one value per gene per species).
mean.expr.perCluster <- list()
for (r in regions) {
  for (i in names(cells[["H"]][[r]])) {
    per.species <- lapply(species, function(s) mean.expr[[paste(r, s, i)]])
    mean.expr.perCluster[[paste(r, i)]] <- unlist(per.species)
  }
}
vioplot(mean.expr.perCluster, las = 2, names = names(mean.expr.perCluster),
        col = cols.by.cluster)
nonzero.perCluster <- vapply(
  mean.expr.perCluster,
  function(x) sum(x > 0) / length(x),
  numeric(1)
)
barplot(nonzero.perCluster, las = 2, names.arg = names(mean.expr.perCluster),
        col = cols.by.cluster,
        ylab = "Proportion of genes with average expression > 0")
dev.off()


#for (balance in c(25,15,5)){ # number of cells per cluster (minimal real number=25)
  # Number of cells drawn per species for each pseudo-bulk replicate.
  balance <- 3
  # Number of resampling replicates per cluster.
  n.boot <- 1000

  # For every region and cluster (labels taken from species "H"), repeat
  # n.boot times: draw `balance` cells per species, average them per gene, and
  # take the absolute per-gene difference of H, C and B to M. Each replicate
  # then stores, against C, against B, and against both ("CB"):
  #   HS: share of genes where |H - M| is more than twice the other difference
  #   PS: share of genes where the other difference is more than twice |H - M|
  # Shares are taken over genes where at least one compared difference is
  # non-zero; if there is no such gene the entry is NaN.
  # (HS / PS presumably stand for human- / pan-specific, as in the plot labels
  # further down -- TODO confirm.)
  # NOTE(review): no seed is set in this file, so replicates differ between
  # runs unless the RNG state comes from elsewhere.
  HS <- list()
  PS <- list()
  for (r in regions) {
    HS[[r]] <- list()
    PS[[r]] <- list()
    for (i in names(cells[["H"]][[r]])) {
      print(i)
      # Preallocate instead of growing six vectors by c() on every replicate.
      hs <- list(C = numeric(n.boot), B = numeric(n.boot), CB = numeric(n.boot))
      ps <- list(C = numeric(n.boot), B = numeric(n.boot), CB = numeric(n.boot))
      for (b in seq_len(n.boot)) {
        bulk <- list()
        for (s in species) {
          cells.i <- cells[[s]][[r]][[i]]
          # Sample without replacement whenever the cluster is large enough;
          # fall back to sampling with replacement when balance exceeds the
          # number of available cells.
          boot <- sample(cells.i, balance, replace = length(cells.i) < balance)
          # drop = FALSE keeps a matrix even for a single sampled cell, so the
          # columns do not have to be duplicated to protect rowMeans().
          bulk[[s]] <- Matrix::rowMeans(
            rlist[[s]][[r]]@assays$RNA@data[, boot, drop = FALSE]
          )
        }
        bulk.HM <- abs(bulk[["H"]] - bulk[["M"]])
        bulk.CM <- abs(bulk[["C"]] - bulk[["M"]])
        bulk.BM <- abs(bulk[["B"]] - bulk[["M"]])

        # Genes with a non-zero difference in at least one compared pair.
        diff.C <- bulk.HM > 0 | bulk.CM > 0
        diff.B <- bulk.HM > 0 | bulk.BM > 0
        diff.CB <- diff.C & diff.B

        # To record raw gene counts instead of shares, drop the denominators.
        hs[["C"]][b] <- sum(bulk.HM > 2 * bulk.CM) / sum(diff.C)
        ps[["C"]][b] <- sum(bulk.HM * 2 < bulk.CM) / sum(diff.C)
        hs[["B"]][b] <- sum(bulk.HM > 2 * bulk.BM) / sum(diff.B)
        ps[["B"]][b] <- sum(bulk.HM * 2 < bulk.BM) / sum(diff.B)
        hs[["CB"]][b] <- sum(bulk.HM > 2 * bulk.BM & bulk.HM > 2 * bulk.CM) /
          sum(diff.CB)
        ps[["CB"]][b] <- sum(bulk.HM * 2 < bulk.BM & bulk.HM * 2 < bulk.CM) /
          sum(diff.CB)
      }
      HS[[r]][[i]] <- hs
      PS[[r]][[i]] <- ps
    }
  }

  save(HS, PS, file = paste0("srong.HS.PS.", balance, ".RData"))
  
  # Bar plots of the per-cluster rate, written to
  # srong.relativeBarplot.integrated.<balance>.pdf. For each cluster the "CB"
  # replicates of HS and PS are pooled into one vector; bars show its mean and
  # the whiskers +/- one standard deviation. After the three region panels,
  # all clusters are drawn in one bar plot, followed by a scatter plot of the
  # rate against nonzero.perCluster with their correlation as the title.
  pdf(paste0("srong.relativeBarplot.integrated.", balance, ".pdf"),
      width = 9, height = 3.5)
  par(mfrow = c(1, 3), las = 2)
  full.rate <- c()
  avg.rate <- c()
  for (r in regions) {
    cluster.ids <- names(HS[[r]])
    # Only the "CB" replicates are used; the per-species "C" / "B" vectors are
    # not needed for these plots.
    pooled <- setNames(lapply(cluster.ids, function(k) {
      c(HS[[r]][[k]][["CB"]], PS[[r]][[k]][["CB"]])
    }), cluster.ids)

    #  pooled <- sapply(names(pooled), simplify=F, function (x) pooled[[x]]/nonzero.perCluster[paste(r,x)]*100)
    #  norm <- mean(unlist(pooled))
    #  pooled <- lapply(pooled,function (x) log2(x/norm))
    means <- vapply(pooled, mean, numeric(1))
    sds <- vapply(pooled, sd, numeric(1))
    # Appended while the means are still named by cluster id.
    full.rate <- c(full.rate, means)

    # Label the bars with the second space-separated token of each
    # cell-type label.
    names(means) <- vapply(
      strsplit(ct.names[[r]], " "),
      function(x) x[[2]],
      character(1)
    )
    # Two-column character matrix: full label, mean rate.
    avg.rate <- rbind(avg.rate, cbind(ct.names[[r]], means))
    bars <- barplot(means, border = "#4000C0", col = "#4000C048", main = r,
                    ylab = "Evolutionary rate", ylim = c(0, 0.4))
    arrows(bars, means + sds, bars, means - sds, length = 0, col = "#4000C0")
  }

  barplot(full.rate, border = "#4000C0", col = "#4000C048",
          ylab = "Evolutionary rate rev.")

  plot(nonzero.perCluster, full.rate, main = cor(nonzero.perCluster, full.rate))
  dev.off()
  
  
  # Box plots, one panel per region, written to
  # srong.ratio.integrated.<balance>.pdf. Each box is the distribution over
  # replicates of log2(HS "CB") - log2(PS "CB") for one cluster; the dashed
  # line marks a ratio of 1.
  pdf(paste0("srong.ratio.integrated.", balance, ".pdf"),
      width = 9, height = 3.5)
  par(mfcol = c(1, 3))
  for (r in regions) {
    cluster.ids <- names(HS[[r]])
    # Only the "CB" replicates are plotted; the per-species "C" and "B"
    # ratios are not used here.
    ratio <- setNames(lapply(cluster.ids, function(k) {
      log2(HS[[r]][[k]][["CB"]]) - log2(PS[[r]][[k]][["CB"]])
    }), cluster.ids)
    boxplot(ratio, names = ct.names[[r]], las = 2, notch = TRUE, outpch = ".",
            ylim = c(-1, 3.5), main = r,
            ylab = "log2( #hum-spec / #pan-spec )")
    abline(h = 0, lty = 2)
  }
  dev.off()
  
  
#}
