#!/usr/bin/env Rscript

# Command line: <sample_name> <output_plot_file>.
argv <- commandArgs(TRUE)
if (length(argv) != 2) {
  cat("Usage: rsem-plot-model sample_name output_plot_file\n")
  q(status = 1)
}

# The sample name may carry a leading path; the files inside the .stat
# directory are named after its last "/"-separated component only.
path_parts <- strsplit(argv[1], split = "/")[[1]]
token <- path_parts[length(path_parts)]

stat.dir <- paste0(argv[1], ".stat")
# The message below promises a directory, so require one. file.info() yields
# NA for a missing path, which isTRUE() treats as "not a directory"; a regular
# file that merely happens to be called <sample>.stat is rejected as well
# instead of failing later with an opaque connection error.
if (!isTRUE(file.info(stat.dir)$isdir)) {
  cat("Error: directory does not exist: ", stat.dir, "\n", sep = "")
  cat(strwrap("This version of rsem-plot-model only works with the output of RSEM versions >= 1.1.8"), sep="\n")
  q(status = 1)
}

# Inputs read by the rest of the script.
modelF <- paste0(stat.dir, "/", token, ".model")
cntF <- paste0(stat.dir, "/", token, ".cnt")

# Open the model file before creating the PDF device: if the model file is
# missing or unreadable the script stops here, and no empty plot file is left
# behind at the output path.
con <- file(modelF, open = "r")
pdf(argv[2])

# model type and forward probability
# Four header lines are consumed; only the first (the model type) is kept.
model_type <- as.numeric(readLines(con, n = 4)[1])

# fragment length distribution
# Three lines are consumed: the first holds two numbers (lower bound, upper
# bound), the second the per-length probabilities; the third is not used.
frag_lines <- readLines(con, n = 3)
frag_range <- as.numeric(strsplit(frag_lines[1], split = " ")[[1]])
maxL <- frag_range[2] # maxL used for Profile
frag_len <- (frag_range[1] + 1):frag_range[2]
frag_prob <- as.numeric(strsplit(frag_lines[2], split = " ")[[1]])
frag_mean <- weighted.mean(frag_len, frag_prob)
frag_sd <- sqrt(weighted.mean((frag_len - frag_mean)^2, frag_prob))
# paste0() keeps the subtitle free of the stray blanks that paste()'s default
# separator would insert around the numbers ("Mean =  12.3 , Std =  4.5").
plot(frag_len, frag_prob, type = "h",
     main = "Fragment Length Distribution",
     sub = paste0("Mean = ", round(frag_mean, 1), ", Std = ", round(frag_sd, 1)),
     xlab = "Fragment Length",
     ylab = "Probability")

# mate length distribution
# For model types 0 and 1 the file carries a flag line telling whether this
# section is present; for every other model type it is taken as present.
if (model_type == 0 || model_type == 1) {
  has_mate_dist <- as.numeric(readLines(con, n = 1)[1])
} else {
  has_mate_dist <- 1
}

if (has_mate_dist == 1) {
  # Two lines: "<lower bound> <upper bound>", then the per-length probabilities.
  mate_fields <- strsplit(readLines(con, n = 2), split = " ")
  mate_range <- as.numeric(mate_fields[[1]])
  maxL <- mate_range[2] # overrides the fragment-length bound used for Profile
  read_len <- (mate_range[1] + 1):mate_range[2]
  read_prob <- as.numeric(mate_fields[[2]])
  read_mean <- weighted.mean(read_len, read_prob)
  read_sd <- sqrt(weighted.mean((read_len - read_mean)^2, read_prob))
  # paste0() avoids the stray blanks that paste()'s default separator would
  # put around the numbers in the subtitle.
  plot(read_len, read_prob, type = "h",
       main = "Read Length Distribution",
       sub = paste0("Mean = ", round(read_mean, 1), ", Std = ", round(read_sd, 1)),
       xlab = "Read Length",
       ylab = "Probability")
}
# Consume one more line before the next section (presumably a blank
# separator; its content is never used).
strvec <- readLines(con, n = 1)

# RSPD (read start position distribution)
# A flag line says whether the distribution is stored in the model file.
has_rspd <- as.numeric(readLines(con, n = 1)[1]) == 1
if (has_rspd) {
  # First the number of bins, then one line with the per-bin probabilities.
  n_bins <- as.numeric(readLines(con, n = 1)[1])
  rspd_line <- readLines(con, n = 1)
  rspd_prob <- as.numeric(strsplit(rspd_line, split = " ")[[1]])

  # Shrink the axis annotation for this plot only, then put it back to 1.0.
  par(cex.axis = 0.7)
  barplot(rspd_prob,
          space = 0,
          names.arg = 1:n_bins,
          main = "Read Start Position Distribution",
          xlab = "Bin #",
          ylab = "Probability")
  par(cex.axis = 1.0)
}
# Consume one more line before the next section; its content is not used.
section_gap <- readLines(con, n = 1)

# plot sequencing errors

# Read `n_records` six-line records from `con`. In each record the first four
# lines are the rows for reference bases A, C, G and T (the last two lines are
# consumed but ignored). Returns a list with
#   index      - 1-based positions of the records that carry any probability
#                mass (records whose four rows sum to < 1e-8 are dropped), and
#   match_prob - one row per kept record, columns A/C/G/T, holding the diagonal
#                entry (row for base b, b-th value); NA where that base's own
#                row sums to < 1e-8.
read_match_probs <- function(con, n_records) {
  match_prob <- matrix(NA_real_, nrow = n_records, ncol = 4)
  has_data <- logical(n_records)
  # seq_len() performs zero iterations for n_records == 0, whereas 1:0 would
  # iterate over c(1, 0) and consume lines that do not belong to the profile.
  for (i in seq_len(n_records)) {
    rows <- lapply(strsplit(readLines(con, n = 6)[1:4], split = " "), as.numeric)
    if (sum(unlist(rows)) < 1e-8) next
    has_data[i] <- TRUE
    for (base in 1:4) {
      if (sum(rows[[base]]) >= 1e-8) match_prob[i, base] <- rows[[base]][base]
    }
  }
  list(index = which(has_data),
       match_prob = match_prob[has_data, , drop = FALSE])
}

# Draw one line per reference base (columns of `y`) against `x` and add the
# A/C/G/T legend. Extra arguments (titles, axis labels) go to matplot().
# Nothing is drawn, and a warning is issued, when there is no data to plot.
plot_base_errors <- function(x, y, ...) {
  if (length(x) == 0) {
    warning("no sequencing error data found in the model file; plot skipped",
            call. = FALSE)
    return(invisible(NULL))
  }
  matplot(x, y, type = "b", lty = 1:4, pch = 0:3, col = 1:4, ...)
  legend("topleft", c("A", "C", "G", "T"), lty = 1:4, pch = 0:3, col = 1:4)
  invisible(NULL)
}

if (model_type == 1 || model_type == 3) {
  # skip QD
  N <- as.numeric(readLines(con, n = 1)[1])
  readLines(con, n = N + 1)
  readLines(con, n = 1) # for the blank line

  # QProfile
  readLines(con, n = 1)

  qprofile <- read_match_probs(con, N)
  # Record i is plotted at x = i - 1; the y value is the match probability
  # converted to a Phred-scaled observed quality, -10 * log10(P(mismatch)).
  plot_base_errors(qprofile$index - 1,
                   -10 * log10(1.0 - qprofile$match_prob),
                   main = "Observed Quality vs. Phred Quality Score",
                   xlab = "Phred Quality Score",
                   ylab = "Observed Quality")
} else {
  # Profile
  readLines(con, n = 1)

  profile <- read_match_probs(con, maxL)
  # Record i is plotted at x = i; the y value is the mismatch probability
  # expressed as a percentage.
  plot_base_errors(profile$index,
                   (1.0 - profile$match_prob) * 100,
                   main = "Position vs. Percentage Sequence Error",
                   xlab = "Position",
                   ylab = "Percentage of Sequencing Error")
}

# The model file has been read completely.
close(con)

# Histogram of alignments per fragment: the first three lines of the .cnt file
# are skipped and the remainder is read as tab-separated columns, with the
# first column used as bar labels and the second as bar heights.
hits <- read.table(file = cntF, skip = 3, sep = "\t")
barplot(hits[, 2],
        names.arg = hits[, 1],
        main = "Histogram of Alignments per Fragment",
        xlab = "Alignments per fragment",
        ylab = "Number of fragments")

# Close the PDF device; invisible() keeps dev.off()'s return value from being
# echoed to the console.
invisible(dev.off())
