################################################################
## Generate heatmap of read density around the peaks.
##

library(RColorBrewer)

## Load some libraries
source('http://www.bigre.ulb.ac.be/Users/jvanheld/statistics_bioinformatics/R-files/config.R')
source(file.path(dir.util, "util.R"))
source(file.path(dir.util, "util_plots.R"))

## Main directory: initialised to the current working directory (this should
## be adapted to local configuration). Note that the testing section below
## assigns dir.main a second time.
dir.main <- getwd()

## Default parameters
## Heatmap colours: 20 shades interpolated from the 9-class "BuPu" palette.
hm.col <- colorRampPalette(
  brewer.pal(n = 9, name = "BuPu")
)(20)
## Figure formats (presumably read by export.plot() -- confirm in util_plots.R).
plot.formats <- c("png", "pdf")

## FOR TESTING ONLY
## Hard-coded main directory and input file; alternative values are kept
## below as commented-out lines.
##dir.main <- '/Users/jvanheld/mechali'
dir.main <- "~/mechali"
#infile <- 'analysis/final_heatmap/concat_MACS_merged300bp_ES_100_5000.bins2'
#infile <- 'analysis/final_heatmap/concat_MACS_merged300bp_ES_100_25000.bins2'
#infile <- 'analysis/final_es/peaks/ori_es_swembl_100_100000.bins'
infile <- "analysis/final_es/peaks/ori_es_swembl_100_5000.bins"


################################################################
## Read arguments from the command line.
##
## Arguments passed on the command line will over-write the default
## arguments specified above. Each argument must be a valid R assignment
## (e.g. infile='path/to/file.bins'); it is evaluated as R code in the
## global environment.
args <- commandArgs(trailingOnly=TRUE)

if (length(args) == 0) {
  stop("No arguments supplied. Mandatory: infile=[input_file] ")
} else {
  print("Parsing command-line arguments")
  print(args)
  ## SECURITY NOTE: eval(parse()) executes arbitrary R code taken from the
  ## command line. This is kept for backward compatibility with existing
  ## invocations, but the script must never be run with untrusted arguments.
  for (arg in args) {
    eval(parse(text=arg))
  }
}

## Check that input file has been specified
if (!exists("infile")) {
  stop("Missing mandatory argument: infile=[input_file] ")
}
verbose(paste("Input file", infile), 1)

## Set files and directories
## The input file path is interpreted relative to dir.main; results are
## written next to the input file, figures in a "figures" sub-directory.
setwd(dir.main)
dir.results <- dirname(infile)
dir.figures <- file.path(dir.results, "figures")
## showWarnings=FALSE hides the warning issued when the directory already
## exists; a real failure to create it is reported explicitly below.
dir.create(dir.figures, showWarnings=FALSE, recursive=TRUE)
if (!dir.exists(dir.figures)) {
  stop(paste("Cannot create figure directory", dir.figures))
}

## Optional arguments
## By default, output files are prefixed with the input file name.
if (!exists("prefix")) {
  prefix <- basename(infile)
}

################################################################
## Load table of reads density
##
## Beware: input files contain tab between the peak coordinates, and spaces
## between bin counts. With sep="", read.delim splits on any white space,
## which handles both separators.
density.table <- read.delim(infile, header=FALSE, sep="")
dim(density.table)

## TEMP: select a small number of peaks (at most the first 10000).
## The selection is capped at the number of rows actually read: indexing
## beyond the end of the table would append rows filled with NA, which
## would corrupt all downstream statistics and clustering.
max.peaks <- 10000
selected.peaks <- seq_len(min(max.peaks, nrow(density.table)))
density.table <- density.table[selected.peaks,]

## Re-format density table
## The first 4 columns describe the peak, all following columns are bin counts.
bins <- ncol(density.table) - 4
if (bins < 1) {
  stop(paste("Input file", infile, "must contain 4 peak description columns followed by at least one bin column"))
}
names(density.table) <- c("chrom", "left", "right", "dataset", paste("bin", seq_len(bins), sep="."))
density.columns <- 5:ncol(density.table)
## drop=FALSE keeps a data frame even if there is a single bin column
densities <- density.table[, density.columns, drop=FALSE]
## Peak center = midpoint between the left and right coordinates
## (the difference (right - left)/2 would be the half-width, not the center).
density.table[,"peak.center"] <- round((density.table$left + density.table$right)/2)

## Preview: first rows (at most 20), showing the 5 first and 4 last columns
preview.columns <- unique(c(1:5, (ncol(density.table)-3:0)))
print(head(density.table[, preview.columns], n=20))

################################################################
## Global statistics over all density values (negative values included)

## Convert the densities once to a matrix, then flatten it to a vector
density.matrix <- as.matrix(densities)
density.vect <- as.vector(density.matrix)

## Mean, median, min and max over all bins of all peaks
density.stats <- vapply(
  list("mean"=mean, "median"=median, "min"=min, "max"=max),
  function(stat) stat(density.vect),
  numeric(1))
print(density.stats)

## Percentiles (0%, 1%, ..., 100%) over all bins of all peaks
percentile.probs <- (0:100)/100
density.percentiles <- quantile(density.vect, probs=percentile.probs)
print("density.percentiles")
print(density.percentiles)

## Summary statistics per bin (one row per bin, one column per statistic)
bin.mean <- apply(density.matrix, 2, mean)
bin.median <- apply(density.matrix, 2, median)
bin.max <- apply(density.matrix, 2, max)
bin.min <- apply(density.matrix, 2, min)
stats.per.bin <- data.frame("mean"=bin.mean,
                            "median"=bin.median,
                            "max"=bin.max,
                            "min"=bin.min)

## Quartiles and percentiles per bin (one column per bin)
quartiles.per.bin <- sapply(densities, quantile, probs=(0:4)/4)
percentiles.per.bin <- sapply(densities, quantile, probs=percentile.probs)

## Export the percentiles per bin to the result directory
percentiles.file <- file.path(dir.results, paste(prefix, "percentiles_per_bin", sep="_"))
export.object(percentiles.per.bin, file=percentiles.file, export.format='table')

################################################################
## Truncate negative values and outliers (above the 99th percentile)

## Lower bound: negative values are reset to 0 (negative values come from
## the subtraction of the input).
## Upper bound: values above the 99th percentile of all densities are reset
## to that percentile.
threshold.lower <- 0
threshold.upper <-  density.percentiles["99%"]
below.lower <- densities < threshold.lower
above.upper <- densities > threshold.upper
densities.truncated <- densities
densities.truncated[below.lower] <- threshold.lower
densities.truncated[above.upper] <- threshold.upper

## Matrix version of the truncated densities, used for the summaries below
truncated.matrix <- as.matrix(densities.truncated)

## Summary statistics per bin (one row per bin)
truncated.bin.mean <- apply(truncated.matrix, 2, mean)
truncated.bin.median <- apply(truncated.matrix, 2, median)
truncated.bin.max <- apply(truncated.matrix, 2, max)
truncated.bin.min <- apply(truncated.matrix, 2, min)
stats.per.bin.truncated <- data.frame("mean"=truncated.bin.mean,
                                      "median"=truncated.bin.median,
                                      "max"=truncated.bin.max,
                                      "min"=truncated.bin.min)

## Quartiles and percentiles per bin (one column per bin)
quartiles.per.bin.truncated <- sapply(densities.truncated, quantile, probs=(0:4)/4)
percentiles.per.bin.truncated <- sapply(densities.truncated, quantile, probs=(0:100)/100)
truncated.percentiles.file <- file.path(dir.results, paste(prefix, "percentiles_per_bin_truncated", sep="_"))
export.object(percentiles.per.bin.truncated, file=truncated.percentiles.file, export.format='table')

## Summary statistics per peak (one row per peak), preceded by the peak
## description columns
truncated.peak.mean <- round(apply(truncated.matrix, 1, mean), digits=3)
truncated.peak.median <- apply(truncated.matrix, 1, median)
truncated.peak.max <- apply(truncated.matrix, 1, max)
truncated.peak.min <- apply(truncated.matrix, 1, min)
stats.per.peak.truncated <- data.frame("chrom" = density.table[,"chrom"],
                                       "left" = density.table[,"left"],
                                       "right" = density.table[,"right"],
                                       "dataset" = density.table[,"dataset"],
                                       "mean" = truncated.peak.mean,
                                       "median" = truncated.peak.median,
                                       "max" = truncated.peak.max,
                                       "min" = truncated.peak.min)
truncated.peak.stats.file <- file.path(dir.results, paste(prefix, "stats_per_peak_truncated", sep="_"))
export.object(stats.per.peak.truncated, file=truncated.peak.stats.file, export.format='table')

## ## ref.bin <- floor (bins/2) ## Use central bin as reference position
## ref.bin <- which.max(stats.per.bin$mean) ## Use max density bin as reference position
## bin.pos <- (1:bins) - ref.bin

## ## Compute dot product of densities relative to peak center
## stats.per.peak.truncated["dp"] <- as.vector(bin.pos %*% as.matrix(t(densities.truncated)))

## ## Compute moment of inertia relative to leftmost bin
## stats.per.peak.truncated["inertia"] <- as.vector((1:bins)^2 %*% as.matrix(t(densities.truncated))) 


################################################################
## Plot summary profiles

## Quick test: boxplot of the truncated densities (one box per bin)
x11(width=20,height=6)
boxplot(densities.truncated, main=prefix)
export.plot(file=file.path(dir.figures, paste(prefix, "boxplot", sep="_")), width=20, height=6)


## Histogram of density values: all values (top) and truncated values (bottom)
x11(width=12,height=8)
par(mfrow=c(2,1))
hist(as.vector(as.matrix(densities)), breaks=1000, xlab="nb reads", ylab="number of bins", main="Numbers of reads per bin")
## Unit-width breaks for the truncated values. The upper threshold is a
## percentile and may be non-integer, so the last break is rounded up to make
## sure the breaks span the whole range of values (hist() fails otherwise).
## At least one interval is always defined.
hist.breaks <- seq(threshold.lower, max(ceiling(threshold.upper), threshold.lower + 1))
hist(as.vector(as.matrix(densities.truncated)), breaks=hist.breaks, xlab="nb reads", ylab="number of bins", main=paste("Numbers of reads per bin (truncated ",threshold.lower, ":", threshold.upper,")", sep=""), col="grey")
par(mfrow=c(1,1))
export.plot(file=file.path(dir.figures, paste(prefix, "histo", threshold.lower, threshold.upper, sep="_")), width=12, height=8)

## Compute mean profiles per chromosome (one mean value per bin and per
## chromosome). colMeans is used because mean() is not defined for the
## data frame chunks passed by by() and would return NA.
## Note: the profiles are computed but not drawn yet.
chromosomes <- unique(sort(as.vector(density.table$chrom)))
bin.per.chrom.mean <- by(densities, density.table$chrom, colMeans)

## Draw heatmap before sorting (for test)
## NOTE(review): the exported size (12x8) differs from the x11 device size
## (7x12) -- confirm how export.plot() uses width and height.
x11(width=7, height=12)
image(t(as.matrix(densities.truncated)), col = hm.col, main=paste(prefix, "densities"), xlab="position", ylab="peaks")
export.plot(file=file.path(dir.figures, paste(prefix, "heatmap", "unsorted", sep="_")), width=12, height=8)

## ## Draw heatmap of peaks sorted by dot product
## image(t(as.matrix(densities.truncated[order(stats.per.peak.truncated["dp"]),])), col = hm.col, main=paste(prefix, "sorted by dot product"))
## export.plot(file=file.path(dir.figures, paste(prefix, "heatmap", "sorted_dp", sep="_")), width=12, height=8)

## ## Draw heatmap of peaks sorted by inertia
## x11(width=7, height=12)
## image(t(as.matrix(densities.truncated[order(stats.per.peak.truncated["inertia"]),])), col = hm.col, main=paste(prefix, "sorted by inertia"))
## export.plot(file=file.path(dir.figures, paste(prefix, "heatmap", "sorted_inertia", sep="_")), width=12, height=8)


## K-means clustering of the peaks on the basis of their truncated density
## profiles. Note: kmeans() starts from random centers, so cluster numbers
## and compositions can differ between runs.
k <- 20
clusters <- kmeans(densities.truncated, centers=k, iter.max=20)

## Heatmap of the truncated densities, with peaks sorted by cluster number.
## The title reflects the actual sorting criterion (k-means clusters).
x11(width=7, height=12)
image(t(as.matrix(densities.truncated[order(clusters$cluster),])), col = hm.col, main=paste(prefix, "sorted by k-means clusters"))
export.plot(file=file.path(dir.figures, paste(prefix, "_heatmap", "_k", k, sep="")), width=12, height=8)

## ## hierarchical clustering is not convincing, because of tree isomorphism
## dist <- 1 - cor(t(densities.truncated))
## tree <- hclust(as.dist(dist), method="complete") 
## x11(width=12, height=12)
## heatmap(as.matrix(densities.truncated), Colv=NA)

## TO BE DONE
## - smooth profiles by grouping bins by 5
##     or more elaborate smoothing with deconvolution / kernel
## - draw profiles to see if this grouping reveals 1, 2 or many peaks per region
## - kmeans followed by hclust of the clusters on the basis of their average profiles

## Smooth each density profile with a normalised convolution filter
## (weights sum to 1, so the scale of the densities is preserved).
convolution.filter <- c(1,2,2,2,1)
#convolution.filter <- c(2,2,2,2,2)
convolution.filter <- convolution.filter/sum(convolution.filter)

## Convert the truncated densities once to a matrix (much faster than
## converting each data frame row inside the loop).
densities.truncated.matrix <- as.matrix(densities.truncated)

## With type="filter", convolve() only returns the positions where the filter
## fully overlaps the profile, so each smoothed profile is shorter than the
## original one by (filter length - 1) bins.
smoothed.bins <- ncol(densities.truncated.matrix) - length(convolution.filter) + 1
if (smoothed.bins < 1) {
  stop("The convolution filter is longer than the density profiles")
}
## Number of bins lost on the left side of each profile
filter.offset <- (length(convolution.filter) - 1) %/% 2

densities.smoothed <- matrix(nrow=nrow(densities.truncated.matrix), ncol=smoothed.bins)
for (i in seq_len(nrow(densities.truncated.matrix))) {
  x <- as.vector(densities.truncated.matrix[i,])
  densities.smoothed[i,] <- convolve(x, convolution.filter, type="filter")

  ## Illustrate the effect of smoothing with one particular profile:
  ## raw profile (blue), restricted to the bins covered by the smoothed
  ## profile, and smoothed profile (black).
  if (i == 10) {
    x11(width=12,height=5)
    plot(x[filter.offset + seq_len(smoothed.bins)], type="l", col="blue")
    lines(densities.smoothed[i,])
  }
}

## Draw heatmap of smoothed densities
## The file name differs from the one used for the unsorted heatmap of raw
## densities, which would otherwise be overwritten.
x11(width=7, height=12)
image(t(as.matrix(densities.smoothed)), col = hm.col, main=paste(prefix, "Smoothed densities"), xlab="position", ylab="peaks")
export.plot(file=file.path(dir.figures, paste(prefix, "heatmap", "smoothed", "unsorted", sep="_")), width=12, height=8)



## Cluster the peaks by k-means, this time on the smoothed profiles
k <- 13
clusters.smoothed <- kmeans(densities.smoothed, centers=k, iter.max=20)

## Contingency table: clusters from raw data (rows) versus clusters from
## smoothed data (columns)
table(clusters$cluster, clusters.smoothed$cluster)

## Row order shared by the two heatmaps below: peaks sorted by the number of
## their smoothed-data cluster
smoothed.cluster.rows <- order(clusters.smoothed$cluster)
smoothed.cluster.title <- paste(prefix, "sorted by smoothed densities")

## Heatmap 1: smoothed densities, sorted by smoothed-data clusters
x11(width=7, height=12)
smoothed.sorted <- as.matrix(densities.smoothed[smoothed.cluster.rows,])
image(t(smoothed.sorted), col = hm.col, main=smoothed.cluster.title)
smoothed.heatmap.file <- file.path(dir.figures, paste0(prefix, "_smoothed_heatmap", "_k", k))
export.plot(file=smoothed.heatmap.file, width=12, height=8)

## Heatmap 2: raw (truncated) densities, sorted by smoothed-data clusters
x11(width=7, height=12)
raw.sorted <- as.matrix(densities.truncated[smoothed.cluster.rows,])
image(t(raw.sorted), col = hm.col, main=smoothed.cluster.title)
raw.heatmap.file <- file.path(dir.figures, paste0(prefix, "_raw_heatmap", "_k", k, "smoothed_clusters"))
export.plot(file=raw.heatmap.file, width=12, height=8)


################################################################
## Compute mean profile per cluster

## One row per cluster, one column per smoothed bin
profiles.per.cluster <- data.frame(matrix(NA_real_, nrow=k, ncol=ncol(densities.smoothed)))
for (cl in seq_len(k)) {
  ## drop=FALSE keeps a matrix even when the cluster contains a single peak
  cluster.members <- densities.smoothed[clusters.smoothed$cluster == cl, , drop=FALSE]
  profiles.per.cluster[cl,] <- colMeans(cluster.members)
}

## Principal component analysis of the cluster profiles
profiles.prcomp <- prcomp(profiles.per.cluster)
x11()
biplot(profiles.prcomp)

## Order the clusters along the first principal component.
## cluster.order[j] is the cluster found at position j along PC1, whereas
## cluster.rank[cl] is the position of cluster cl. The rank (not the order)
## is what must be looked up for each peak to sort the peaks by the position
## of their cluster.
cluster.order <- order(profiles.prcomp$x[,1])
cluster.rank <- order(cluster.order)
peak.cluster.order <- cluster.rank[clusters.smoothed$cluster]

## Heatmap of raw (truncated) densities, with peaks sorted by the PC1 rank of
## their cluster. The file name differs from the one used for the heatmap
## sorted by cluster number, which would otherwise be overwritten.
x11(width=7, height=12)
image(t(as.matrix(densities.truncated[order(peak.cluster.order),])), col = hm.col, main=paste(prefix, "sorted by smoothed clusters ordered along PC1"))
export.plot(file=file.path(dir.figures, paste(prefix, "_raw_heatmap", "_k", k, "_smoothed_clusters_sorted_pc1", sep="")), width=12, height=8)
