########################################################################
# ExoPlotter
# 
# ExoPlotter generates a set of PDF plots from the output files of 5PrimeCounter. This R script takes as input the output prefix that was given to 5PrimeCounter. Two optional flags can be added to get additional plots, also written as PDF files.
# 
# A typical command line is:
# 
# > Rscript exoPlotter.R \<OUTPUT_PREFIX> genome_seq perm
# 
# where:
# 
# * \<OUTPUT_PREFIX> is the output prefix used in the 5PrimeCounter command line.
# * 'genome_seq' is a fixed label. Add it only if 5PrimeCounter was run with option \--genome_seq; the corresponding output files are then found using \<OUTPUT_PREFIX>.
# * 'perm' is a fixed label. Add it only if 5PrimeCounter was run with option \--perm; the corresponding output files are then found using \<OUTPUT_PREFIX>.
#
#
# 15.10.13 Tool developed in R by Jonas Ibn-Salem, with code from Mike Love, Morgane Thomas-Chollier, Samuel Collombet and Celine Hernandez.
# 
# Developed with:
# 
# * R (2.14.1).
# 
# Package dependencies (available on CRAN):
# 
# * gdata (NB: gtools will also be installed).
# * ape 
# 
# 
# Last update: 26.09.2014
# 
########################################################################


########################################################################
# load libraries
########################################################################

suppressMessages(library(gdata))

########################################################################
# supporting functions
########################################################################

# converts a sequence into a vector of colors
# Map nucleotide characters to plotting colours.
#
# s: character vector with one character per element (e.g. a consensus
#    sequence split into letters). A/C/G/T/N are matched in either case;
#    any other value is mapped to "darkgray".
# Returns a character vector of colours, one per element of `s`, named by the
#    input characters. vapply() guarantees a character result even for empty
#    input (sapply() would return list() in that case).
seq2color <- function(s) {
    vapply(s, switch, FUN.VALUE = character(1),
        "A" = "green", "a" = "green",
        "C" = "blue", "c" = "blue",
        "G" = "gold", "g" = "gold",
        "T" = "red", "t" = "red",
        "N" = "gray", "n" = "gray",
        "darkgray")
}

# Mismatch distance between two vectors compared position by position.
#
# a, b: vectors of equal length (e.g. two encoded sequences).
# m:    number used to normalise the count of identical positions.
# Returns 1 - (number of positions where a == b) / m.
equal_dist <- function(a, b, m) {
    n_identical <- sum(a == b)
    1 - n_identical / m
}

# Build a "dist" object from an arbitrary pairwise distance function.
#
# input.matrix: matrix whose rows are the objects to compare.
# my.function:  function(a, b, m) giving the distance between two rows a and b,
#               where m is the number of columns of input.matrix.
# Returns a "dist" object labelled with the row names of input.matrix.
#
# as.dist() keeps only the lower triangle (row > column) of the square matrix,
# so only those entries are computed: half the calls of a full n x n loop,
# identical result. seq_len() keeps the loops valid for 0 or 1 rows.
custom.dist <- function(input.matrix, my.function) {
    n <- nrow(input.matrix)
    m <- ncol(input.matrix)
    mat <- matrix(0, ncol = n, nrow = n)
    colnames(mat) <- rownames(mat) <- rownames(input.matrix)
    for (i in seq_len(n)[-1L]) {
        for (j in seq_len(i - 1L)) {
            mat[i, j] <- my.function(input.matrix[i, ], input.matrix[j, ], m)
        }
    }
    as.dist(mat)
}

########################################################################
# Parameters
########################################################################

# Largest 5' coverage count shown with its own colour in the heatmaps;
# counts above it are drawn with the colour of MAX_COUNT.
MAX_COUNT <- 100

# Pseudocount added to 5' coverage values so that zero counts can be
# log-transformed in the heatmaps.
PSEUDOCOUNT <- 1

# Graphics device functions (looked up by name) used by the plotting loops.
plot_formats <- "pdf"


########################################################################
# main code
########################################################################

# read arguments from the command line:
#   args[1]       output prefix used in the 5PrimeCounter command line (required)
#   'genome_seq'  optional label: also draw the sequence-based plots
#   'perm'        optional label: also draw the permutation-based plots
args <- commandArgs(TRUE)
if (length(args) < 1 || is.na(args[1]) || !nzchar(args[1])) {
    stop("Usage: Rscript exoPlotter.R <OUTPUT_PREFIX> [genome_seq] [perm]", call. = FALSE)
}
prefix <- args[1]

do_plot_QC <- any(args == 'genome_seq')
do_plot_perm <- any(args == 'perm')

print(args)
print(do_plot_QC)
print(do_plot_perm)

# read 5'coverage count data: one row per site (first column = row names),
# one column per position
up_counts <- as.matrix(read.table(paste0(prefix, ".up_counts.tab"), header = FALSE, row.names = 1))
down_counts <- as.matrix(read.table(paste0(prefix, ".down_counts.tab"), header = FALSE, row.names = 1))

# keep only sites with more than 5 reads over both strands; drop = FALSE keeps
# the matrices two-dimensional even when a single site passes the filter
has_enough_counts <- rowSums(up_counts + down_counts) > 5
if (!any(has_enough_counts, na.rm = TRUE)) {
    stop("No site has more than 5 reads in ", prefix, ".up_counts.tab and ",
         prefix, ".down_counts.tab; nothing to plot.", call. = FALSE)
}
up_counts <- up_counts[has_enough_counts, , drop = FALSE]
down_counts <- down_counts[has_enough_counts, , drop = FALSE]

# read consensus sequence (first line of the file) and split it into a
# vector of single characters
consensus_str <- readLines(paste0(prefix, ".consensus.txt"))
consensus <- strsplit(consensus_str[1], "")[[1]]

# get dimensions
n <- nrow(up_counts) # number of regions/sites
m <- ncol(up_counts) # number of positions

# x-axis values for plotting: offset of each column from the motif center
xs <- (-m/2):(m/2-1)

########################################################################
# Heatmap like plot with max of forward and reverse counts per cell:
########################################################################

# MAX_COUNT <- ifelse(max(up_counts,down_counts)<100, 10, 100)

# scale data logarithmically for better visualisation of low count sites:
# a count c becomes log10(c + PSEUDOCOUNT) / log10(MAX_COUNT + PSEUDOCOUNT),
# so 0 reads -> 0 and MAX_COUNT reads -> 1.
to_heat_scale <- function(counts) {
    log10(counts + PSEUDOCOUNT) / log10(MAX_COUNT + PSEUDOCOUNT)
}
heat.val.up <- to_heat_scale(up_counts)
heat.val.down <- to_heat_scale(down_counts)

# per cell keep the strand with the larger value: forward as a positive value,
# reverse as a negative one (ties, including 0 vs 0, go to forward)
heat.val <- ifelse(heat.val.up >= heat.val.down, heat.val.up, -heat.val.down)

# force heat map values between -1 and 1, so that counts above MAX_COUNT get
# the colour intensity corresponding to MAX_COUNT
heat.val[heat.val > 1] <- 1
heat.val[heat.val < -1] <- -1

# red (reverse) -> white (no reads) -> blue (forward); with MAX_COUNT + 1
# colours (101 for MAX_COUNT = 100) the middle colour is pure white
heat_colors <- colorRampPalette(c("red", "white", "blue"))(MAX_COUNT + 1)

for (f in plot_formats) {
    get(f)(file = paste0(prefix, ".maxstrand_heatmap.", f))

    # two panels side by side: heatmap (90% of the width) and colour legend
    # (10%). par(mar = ...) returns only the previous 'mar', which is what
    # par(def.par) restores at the end.
    def.par <- par(no.readonly = TRUE, mar = c(5, 2, 4, 1))
    nf <- layout(matrix(c(1, 2), 1, 2), widths = c(.9, .1), heights = 1)

    # heatmap: one row per site, one column per position
    image(x = xs, z = t(heat.val), xlim = c(-m/2, m/2), zlim = c(-1, 1),
        col = heat_colors,
        xaxt = "n", yaxt = "n", xlab = "Distance from motif center (bp)")

    # add axis, axis labels and consensus seq
    axis(1, line = 1, xpd = TRUE, xaxs = "i")
    mtext(paste("Binding sites ( n =", n, ")"), 2, line = 1)
    mtext(strsplit(consensus, ""), side = 1, line = 0, at = xs, col = seq2color(consensus), cex = 0.8)

    #===============================================================
    # Legend with color scales from white to blue/red
    #===============================================================
    # left column: values 0..1 (forward, white -> blue);
    # right column: values 0..-1 (reverse, white -> red)
    r <- cbind(seq(0, 1, 0.01), -seq(0, 1, 0.01))
    op <- par(mar = c(12, 0, 4, 2))
    image(x = c(1, 2), y = seq(0, 1, 0.01), z = t(r), zlim = c(-1, 1), col = heat_colors,
        ylab = "", xlab = "", xaxt = "n", yaxt = "n")
    # NOTE(review): 'mar' is restored before the side-4 annotations below, which
    # changes the plot region they are drawn in; confirm the ticks line up with
    # the legend bar.
    par(op)

    mtext(c("Forward", "Reverse", "5' coverage"), side = 4, line = c(-3, -2, 0), at = -0.075, adj = 1)

    # Tick positions use the same transform as heat.val (pseudocount included);
    # plain log10(c) / log10(MAX_COUNT) would put e.g. the "10" label where
    # ~9 reads are drawn.
    label_counts <- 10^(0:log10(MAX_COUNT))                                # 1, 10, 100
    minor_counts <- as.vector(outer(1:10, 10^(0:(log10(MAX_COUNT) - 1))))  # 1..10, 10..100
    axis(side = 4, at = to_heat_scale(label_counts), labels = label_counts, line = -1)
    # add ticks:
    axis(side = 4, at = to_heat_scale(minor_counts), labels = NA, tcl = -0.25, lwd = 0, lwd.ticks = 1, line = -1)

    par(def.par)
    dev.off()
}

# MAX_COUNT <- 100

########################################################################
# 5'coverage profile plot
########################################################################

# Per-position 5' coverage summed over all sites, one curve per strand.
# The sums do not depend on the output format, so they are computed once.
up_sum <- colSums(up_counts)
down_sum <- colSums(down_counts)
ymax <- max(c(up_sum, down_sum))

for (f in plot_formats) {
    get(f)(file = paste0(prefix, ".profile.", f))
    par(mar = c(5, 4, 3, 2) + 0.1)

    # empty frame; the x axis is drawn separately, one line lower
    plot(0, 0, type = "n", xlim = c(-m / 2, m / 2), ylim = c(0, ymax),
         xlab = "Distance from motif center (bp)", xaxt = "n",
         ylab = paste("5' coverage (", n, "sites )"))
    axis(1, line = 1)
    # coloured consensus base under each position
    mtext(strsplit(consensus, ""), side = 1, line = 0, at = xs,
          col = seq2color(consensus), cex = 0.8)

    # grid: dotted line at every position, dashed line every 5 bp
    abline(v = xs, col = "gray", lty = "dotted")
    nlines <- ceiling(m / 5 / 2)
    abline(v = seq(-nlines, nlines) * 5, col = "black", lty = 2)

    # forward strand in blue, reverse strand in red
    points(xs, up_sum, col = "blue", type = "o", lwd = 3, cex = .6, pch = 19)
    points(xs, down_sum, col = "red", type = "o", lwd = 3, cex = .6, pch = 19)

    # overlay an invisible full-figure plot so the legend sits at the top edge
    opar <- par(fig = c(0, 1, 0, 1), oma = c(0, 0, 0, 0),
                mar = c(0, 0, 0, 0), new = TRUE)
    plot(0, 0, type = "n", bty = "n", xaxt = "n", yaxt = "n")
    legend(x = "top", c("Forward", "Reverse"), bg = "white", horiz = TRUE, bty = "n",
           lty = c(1, 1), pt.cex = .6, pch = 19, col = c("blue", "red"), lwd = 3)
    par(opar)

    graphics.off()
}

########################################################################


########################################################################
# hierarchical clustering of sites
########################################################################


# Hierarchical clustering of sites by their 5' coverage profiles, plotted next
# to the heatmap rows in dendrogram order (hclust needs at least two sites).
if (n > 1) {
    # per strand: log10(count + 1) scaled by the strand's own maximum; a strand
    # without any read is kept at 0 instead of becoming NaN (0 / 0)
    scale_by_max <- function(counts) {
        denom <- log10(max(counts) + 1)
        if (!is.na(denom) && denom == 0) counts * 0 else log10(counts + 1) / denom
    }
    mat <- cbind(scale_by_max(up_counts), scale_by_max(down_counts))

    pdf(file = paste0(prefix, ".cluster.pdf"))

    def.par <- par(no.readonly = TRUE, mar = c(5, 2, 4, 1))
    # two panels side by side: dendrogram (30%) and reordered heatmap (70%)
    nf <- layout(matrix(c(1, 2), 1, 2), widths = c(.3, .7), heights = 1)

    dpar <- par(mar = c(4, 4, 3, 0))

    # smooth each site's concatenated forward+reverse profile before clustering
    # NOTE(review): smooth() also runs across the junction between the two halves
    mat <- t(apply(mat, 1, smooth))
    rownames(mat) <- rep("", n)

    hc <- hclust(dist(mat))

    if (n < 500) {
        # If we don't have too many leaves, we transform to a dendrogram and plot it
        den <- as.dendrogram(hc)
        plot(den, horiz = TRUE, ylab = paste0("Hierarchical clustering (", length(hc$labels), " sites)"), leaflab = "none")
    } else {
        # With 500 or more leaves, use the ape library to avoid 'node stack overflow'
        library(ape)
        plot.phylo(as.phylo(hc), direction = "rightwards")
        axisPhylo(side = 1)
        mtext(side = 2, text = paste0("Hierarchical clustering (", length(hc$labels), " sites)"))
    }
    mtext("Distance", side = 1, line = 2)
    par(dpar)

    # heatmap rows in dendrogram order, on the same fixed [-1, 1] colour scale
    # as the max-strand heatmap (without zlim the colours would be stretched to
    # the data range and zero would no longer be white)
    op <- par(mar = c(5, 1, 4, 2))
    image(x = xs, z = t(heat.val[hc$order, , drop = FALSE]), xlim = c(-m/2, m/2), zlim = c(-1, 1),
        col = colorRampPalette(c("red", "white", "blue"))(MAX_COUNT + 1),
        xaxt = "n", yaxt = "n", xlab = "Distance from motif center (bp)")
    axis(1, line = 1, xpd = TRUE, xaxs = "i")
    mtext(strsplit(consensus, ""), side = 1, line = 0, at = xs, col = seq2color(consensus), cex = 0.8)
    par(op)

    par(def.par)

    dev.off()
}

########################################################################
# color map of centered sequences and clustering of sites by sequence
########################################################################

if (do_plot_QC) {
    # sequence of each site: one row per site, one column per position.
    # NOTE(review): rows are assumed to be in the same order as the count
    # tables, since the same has_enough_counts filter is applied to them.
    seq_matrix <- read.table(paste0(prefix, ".seq_matrix.tab"), header = FALSE, row.names = 1)[has_enough_counts, , drop = FALSE]

    # plot color map
    for (f in plot_formats) {
        get(f)(file = paste0(prefix, "_seq.", f))
        image(x = xs, y = seq_len(n), z = t(seq_matrix), xlim = c(-m/2, m/2),
              col = c("green", "blue", "yellow", "red"),
              xaxt = "n", yaxt = "n", xlab = "Distance from motif center (bp)", ylab = "")
        axis(1, line = 1)
        mtext(strsplit(consensus, ""), side = 1, line = 0, at = xs, col = seq2color(consensus), cex = 0.8)
        mtext(paste("Binding sites ( n =", n, ")"), 2, line = 1)
        graphics.off()
    }

    # remove labels
    seq_matrix <- as.matrix(seq_matrix)
    rownames(seq_matrix) <- rep("", nrow(seq_matrix))

    # cluster only the first (at most 200) sites: custom.dist compares every
    # pair of sites
    top.sites <- min(n, 200)

    if (top.sites < 2) {
        warning("Fewer than two sites: skipping sequence clustering (",
                prefix, ".seq_cluster.pdf)", call. = FALSE)
    } else {
        # consider only the inner +- BASES positions around the motif center:
        # position 0 is column floor(m/2) + 1 (see xs), so this selects
        # positions -BASES .. BASES-1, clamped to the available columns
        BASES <- 10
        center <- floor(m / 2) + 1
        consider <- max(1, center - BASES):min(m, center + BASES - 1)

        # calculate pair-wise distances (fraction of mismatching positions)
        d <- custom.dist(seq_matrix[seq_len(top.sites), consider, drop = FALSE], equal_dist)
        # cluster:
        hc <- hclust(d)
        den <- as.dendrogram(hc)

        pdf(file = paste0(prefix, ".seq_cluster.pdf"), width = 14, height = 7)

        def.par <- par(no.readonly = TRUE, mar = c(5, 2, 4, 1))
        # four panels: dendrogram, sequences, coverage heatmap, colour legend
        nf <- layout(matrix(1:4, 1, 4), widths = c(.15, .425, 0.3825, 0.0425), heights = 1)

        # panel 1: dendrogram
        dpar <- par(mar = c(4, 4, 3, 0))
        plot(den, horiz = TRUE, cex.axis = 1)
        mtext("Distance", side = 1, line = 2)
        mtext(paste("Clustering by sequence (", length(hc$labels), "sites )"), side = 2, line = 2)

        # panel 2: sequences in dendrogram order
        op <- par(mar = c(5.5, 1, 4.5, 1))
        image(x = xs, y = seq_len(top.sites), z = t(seq_matrix[hc$order, , drop = FALSE]), xlim = c(-m/2, m/2),
              col = c("green", "blue", "yellow", "red"),
              xaxt = "n", yaxt = "n", ylab = "", xlab = "")
        axis(1, line = 1.5, xpd = TRUE, xaxs = "i", cex.axis = 1)
        mtext(strsplit(consensus, ""), side = 1, line = .5, at = xs, col = seq2color(consensus), cex = 0.8)
        mtext("Distance from motif center (bp)", side = 1, line = 4)
        par(op)

        # panel 3: max-strand coverage in dendrogram order, on the fixed
        # [-1, 1] colour scale used by the legend
        heat_colors <- colorRampPalette(c("red", "white", "blue"))(MAX_COUNT + 1)
        op <- par(mar = c(5.5, 1, 4.5, 1))
        image(x = xs, z = t(heat.val[hc$order, , drop = FALSE]), xlim = c(-m/2, m/2), zlim = c(-1, 1),
              col = heat_colors, xaxt = "n", yaxt = "n", xlab = "")
        axis(1, line = 1.5, xpd = TRUE, xaxs = "i", cex.axis = 1)
        mtext(strsplit(consensus, ""), side = 1, line = .5, at = xs, col = seq2color(consensus), cex = 0.8)
        mtext("Distance from motif center (bp)", side = 1, line = 4)
        par(op)

        # panel 4: colour legend (left column forward, right column reverse)
        r <- cbind(seq(0, 1, 0.01), -seq(0, 1, 0.01))
        op <- par(mar = c(12, 0, 4.5, 3))
        image(x = c(1, 2), y = seq(0, 1, 0.01), z = t(r), zlim = c(-1, 1), col = heat_colors,
              ylab = "", xlab = "", xaxt = "n", yaxt = "n")
        par(op)
        mtext(c("Forward", "Reverse", "5' coverage"), side = 4, line = c(-5, -4, -2), at = -0.075, adj = 1)
        # tick positions use the same transform as heat.val (pseudocount included)
        to_heat_scale <- function(counts) {
            log10(counts + PSEUDOCOUNT) / log10(MAX_COUNT + PSEUDOCOUNT)
        }
        label_counts <- 10^(0:log10(MAX_COUNT))                                # 1, 10, 100
        minor_counts <- as.vector(outer(1:10, 10^(0:(log10(MAX_COUNT) - 1))))  # 1..10, 10..100
        axis(side = 4, at = to_heat_scale(label_counts), labels = label_counts, line = -3, cex.axis = 1.5)
        axis(side = 4, at = to_heat_scale(minor_counts), labels = NA, tcl = -0.25, lwd = 0, lwd.ticks = 1, line = -3, cex.axis = 1.5)

        par(def.par)

        dev.off()
    }
}

########################################################################

if (do_plot_perm) {

    ######################################################################
    ## find all perm files
    ##
    ## Expected names: <prefix>_perm<id>.up_counts.tab and
    ## <prefix>_perm<id>.down_counts.tab, in the directory of <prefix>.

    prefix_perm <- paste0(prefix, '_perm')
    basename_perm <- basename(prefix_perm)
    dirname_perm <- dirname(prefix_perm)
    up_suffix <- '.up_counts.tab'
    down_suffix <- '.down_counts.tab'

    # Full paths of files in dirname_perm named <basename_perm><1+ chars><suffix>,
    # sorted by name. Plain string comparison instead of a regex, so '.' and
    # other special characters in the prefix are matched literally.
    find_perm_files <- function(suffix) {
        candidates <- list.files(path = dirname_perm)
        n_chars <- nchar(candidates)
        is_match <- n_chars > nchar(basename_perm) + nchar(suffix) &
            substr(candidates, 1, nchar(basename_perm)) == basename_perm &
            substr(candidates, n_chars - nchar(suffix) + 1, n_chars) == suffix
        file.path(dirname_perm, sort(candidates[is_match]))
    }
    # File names with the suffix removed, used to pair up/down outputs.
    perm_id <- function(files, suffix) {
        nm <- basename(files)
        substr(nm, 1, nchar(nm) - nchar(suffix))
    }

    all_perm_files_up <- find_perm_files(up_suffix)
    all_perm_files_down <- find_perm_files(down_suffix)

    rep_up <- length(all_perm_files_up)
    rep_down <- length(all_perm_files_down)
    if (rep_up == 0 || rep_down == 0) {
        stop("No permutation outputs found. Found ",
             rep_up, " files with pattern ", prefix_perm, '*.up_counts.tab', ' and ',
             rep_down, " files with pattern ", prefix_perm, '*.down_counts.tab', call. = FALSE)
    }
    if (!identical(perm_id(all_perm_files_up, up_suffix), perm_id(all_perm_files_down, down_suffix))) {
        stop("Up and down permutation outputs do not pair up. Found ",
             rep_up, " files with pattern ", prefix_perm, '*.up_counts.tab', ' and ',
             rep_down, " files with pattern ", prefix_perm, '*.down_counts.tab', call. = FALSE)
    }
    if (rep_up < 10) {
        warning("Didn't find enough permutation outputs. Found ",
                rep_up, " files with pattern ", prefix_perm, '*.up_counts.tab', ' and ',
                rep_down, " files with pattern ", prefix_perm, '*.down_counts.tab', call. = FALSE)
    }

    # Read one permutation's up/down count tables and keep the sites with more
    # than 5 reads over both strands (the filter applied to the observed data).
    read_perm_pair <- function(file_up, file_down) {
        read_counts <- function(file) {
            as.matrix(read.table(file, header = FALSE, sep = "\t", as.is = TRUE,
                                 na.strings = "N/A", row.names = 1))
        }
        counts_up <- read_counts(file_up)
        counts_down <- read_counts(file_down)
        keep <- rowSums(counts_up + counts_down) > 5
        list(up = counts_up[keep, , drop = FALSE],
             down = counts_down[keep, , drop = FALSE])
    }
    # each pair is read once and reused for the profiles and the test
    perm_counts <- Map(read_perm_pair, all_perm_files_up, all_perm_files_down)

    ######################################################################
    ## Perm profiles

    # observed data: sum of each column divided by the number of sites
    up_A_norm <- colSums(up_counts) / nrow(up_counts)
    down_A_norm <- colSums(down_counts) / nrow(down_counts)

    # permutations: one row per permutation, one column per position
    normalized_up_counts <- t(vapply(perm_counts, function(p) colSums(p$up) / nrow(p$up), numeric(m)))
    normalized_down_counts <- t(vapply(perm_counts, function(p) colSums(p$down) / nrow(p$down), numeric(m)))

    # get 10, 25, 50, 75 and 90th quantiles across permutations, per position
    quant_up <- apply(normalized_up_counts, 2, quantile, probs = c(0.1, 0.25, 0.5, 0.75, 0.9))
    quant_down <- apply(normalized_down_counts, 2, quantile, probs = c(0.1, 0.25, 0.5, 0.75, 0.9))

    # find overall max value
    ymax <- max(quant_up, quant_down, up_A_norm, down_A_norm)

    pdf(file = paste0(prefix, ".profile-perm.pdf"))

    plot(0, 0, type = "n", las = 1,
         xlim = c(-m/2, m/2), ylim = c(0, ymax),
         ylab = "Normalized 5' coverage", xlab = "Distance from motif center (bp)", xaxt = "n")
    axis(1, line = 1)
    mtext(strsplit(consensus, ""), side = 1, line = 0, at = xs, col = seq2color(consensus), cex = 0.8)

    # permuted data: interquartile range as a shaded band, median as a dotted line
    polygon(c(xs, rev(xs)), c(quant_up["75%", ], rev(quant_up["25%", ])), col = rgb(0, 0, 1, alpha = 0.1), lwd = 1, lty = 1, border = NA)
    points(xs, quant_up["50%", ], col = rgb(0, 0, 1, alpha = 0.5), type = "l", lty = 3, lwd = 2)
    polygon(c(xs, rev(xs)), c(quant_down["75%", ], rev(quant_down["25%", ])), col = rgb(1, 0, 0, alpha = 0.1), lwd = 1, lty = 1, border = NA)
    points(xs, quant_down["50%", ], col = rgb(1, 0, 0, alpha = 0.5), type = "l", lty = 3, lwd = 2)

    # add grid
    abline(v = xs, col = "gray", lty = "dotted")
    nlines <- ceiling(m / 5 / 2)
    abline(v = -nlines:nlines * 5, col = "darkgray", lty = 2)

    # plot the actual signal
    points(xs, up_A_norm, col = "blue", type = "l", lwd = 3, cex = .6, pch = 19)
    points(xs, down_A_norm, col = "red", type = "l", lwd = 3, cex = .6, pch = 19)

    opar <- par(fig = c(0, 1, 0, 1), oma = c(0, 0, 0, 0),
                mar = c(0, 0, 0, 0), new = TRUE)
    plot(0, 0, type = 'n', bty = 'n', xaxt = 'n', yaxt = 'n')
    legend(x = "top", c("Forward", "Reverse", "Permuted Fwd", "Permuted Rev"), bg = "white", ncol = 2, bty = 'n',
           pt.cex = .6,
           col = c("blue", "red", 0, 0), lty = c(1, 1, 0, 0), lwd = c(3, 3, 0, 0),
           fill = c(0, 0, rgb(0, 0, 1, alpha = 0.1), rgb(1, 0, 0, alpha = 0.1)),
           border = c(0, 0, rgb(0, 0, 1, alpha = 0.5), rgb(1, 0, 0, alpha = 0.5)))
    par(opar)

    dev.off()

    ######################################################################
    ## Wilcoxon rank sum test: are the observed per-site read totals greater
    ## than those of the permuted sites?

    pseudo_count <- 1
    read_sum <- rowSums(up_counts) + rowSums(down_counts) + pseudo_count
    random_read_sum <- unlist(lapply(perm_counts, function(p) {
        rowSums(p$up) + rowSums(p$down) + pseudo_count
    }), use.names = FALSE)

    wilcox <- wilcox.test(read_sum, random_read_sum, alternative = "greater")

    # Print result in a file: one tab-separated header line and one value line
    # (seven columns each)
    header <- paste("Nb observed", "Mean observed", "SD observed", "Nb random", "Mean random", "SD random",
                    paste0(wilcox$method, " (p-value)"), sep = "\t")
    out_line <- paste(length(read_sum), mean(read_sum), sd(read_sum), length(random_read_sum),
                      mean(random_read_sum), sd(random_read_sum), wilcox$p.value, sep = "\t")
    cat(header, "\n", out_line, "\n", sep = "",
        file = paste0(prefix, ".permut_matrix_", rep_up, ".values.tab"))
}
