#!/usr/bin/env Rscript
suppressPackageStartupMessages(library("argparse"))
# create parser object

# Help text shown by --help.
# NOTE(review): the literal "\\n\\" at the end of each line appears to exist
# because the description is forwarded to Python argparse (see
# formatter_class below), where "\n" becomes a newline and a trailing "\"
# continues the line. Keep the same escaping on any line added here.
docstring <- "The following script computes surrogate variables (SVs) from\\n\\
a tab-separated dataframe (N x P) and returns the SVs (P x K, where K = num. SVs).\\n\\
\\n\\
Input matrix is of the following format:\\n\\
  - tab-separated, with a header row of the P sample names;\\n\\
  - first column holds the N feature IDs (e.g. genes), used as row names;\\n\\
  - remaining columns hold the raw counts of each sample.\\n\\
Samples are grouped by their name with the last dot-separated field\\n\\
removed, e.g. treatA.rep1 and treatA.rep2 both belong to group treatA.\\n\\
\\n\\
Surrogate variables account for artifacts such as batch effects. For more information see:\\n\\
Leek 2014 (https://doi.org/10.1093/nar/gku864) and\\n\\
Leek and Story 2007 (https://doi.org/10.1371/journal.pgen.0030161)"

# RawTextHelpFormatter keeps the line breaks of the description as written.
parser <- ArgumentParser(description = docstring,
                         formatter_class = "argparse.RawTextHelpFormatter")

# Required args. argparse does not enforce these; they are checked right
# after parsing.
parser$add_argument("-i", "--mat",
                    help="Input matrix")
parser$add_argument("-s", "--sv",
                    help="Output surrogate variables")

# Optional args.
# --constant is the pseudocount svaseq adds before taking logs (per the sva
# documentation), so any positive value is meaningful. It is parsed as a
# double so that non-integer pseudocounts (e.g. 0.5) are accepted too.
parser$add_argument("--constant", type="double", default=1,
                    help="constant in svaseq, [default \"%(default)s\"]")

# --lib_sizes: headerless TSV, first column = sample name, second column =
# library size. It is consumed in the normalization step.
parser$add_argument("-l", "--lib_sizes", default=NULL,
                    help=paste("Optional headerless tab-separated file with",
                               "one row per sample: sample name, library",
                               "size. These sizes replace the DGEList",
                               "defaults and no normalization factors are",
                               "computed (method none)."))

args <- parser$parse_args()

# Abort with a non-zero exit status when a required argument is missing.
# R has no exit() function, so this uses stop(). The message goes to stderr,
# after the usage text has been printed.
for (required in c("mat", "sv")) {
    if (is.null(args[[required]])) {
        parser$print_help()
        stop("Please provide --", required, call. = FALSE)
    }
}

##################################################
suppressPackageStartupMessages(library("sva"))
suppressPackageStartupMessages(library("edgeR"))

# Read the count matrix: a header row of sample names, with the first column
# used as row names (feature IDs).
# NOTE(review): read.table's default check.names = TRUE rewrites sample names
# that are not syntactic R names (e.g. "-" becomes "."). That affects both the
# grouping below and the --lib_sizes lookup by sample name; confirm this is
# intended.
mat <- read.table(args$mat, header = TRUE, row.names = 1, sep = "\t")

# Group samples by column name with the last "."-separated field removed,
# e.g. "treatA.rep1" and "treatA.rep2" both map to group "treatA".
# A name containing no "." maps to the empty group "".
name_parts <- strsplit(colnames(mat), ".", fixed = TRUE)
group <- factor(vapply(name_parts,
                       function(parts) paste(head(parts, -1), collapse = "."),
                       character(1)))

mat <- DGEList(counts = mat, group = group)

# Normalization.
# With --lib_sizes: library sizes are read from a headerless TSV (sample name
# as row name, size in the first value column). They replace the DGEList
# library sizes, and normalization factors are left at 1 (method = "none").
# Without it: calcNormFactors runs with edgeR's default method.
if (!is.null(args$lib_sizes)) {
    lib_sizes <- read.table(args$lib_sizes, header = FALSE, row.names = 1,
                            sep = "\t")
    # Look up by sample name; samples absent from the file come back as NA.
    sizes <- lib_sizes[colnames(mat), 1]
    absent <- colnames(mat)[is.na(sizes)]
    if (length(absent) > 0) {
        stop("No library size in ", args$lib_sizes, " for sample(s): ",
             paste(absent, collapse = ", "), call. = FALSE)
    }
    mat$samples$lib.size <- sizes
    mat <- calcNormFactors(mat, method = "none")
} else {
    mat <- calcNormFactors(mat)
}

# Full model: intercept plus group effects (R's default contrasts).
# Null model: intercept only.
mod <- model.matrix(~group)
mod0 <- cbind(mod[, 1])

# Remove features with zero variance across samples, and features with
# cpm <= 1 in more than half of the samples. Equivalently, keep features with
# cpm > 1 in at least half of the samples.
# keep.lib.sizes = TRUE retains the library sizes of the full matrix instead
# of recomputing them from the retained features only.
variance <- apply(as.matrix(mat), 1, var)
expressed <- rowSums(cpm(mat) > 1) >= ncol(mat) / 2
sub_mat <- mat[variance > 0 & expressed, , keep.lib.sizes = TRUE]

# Fit surrogate variables with svaseq on the cpm values of the filtered
# features. With n.sv left as NULL, svaseq chooses the number of SVs itself.
sub_cpm <- cpm(sub_mat)  # computed once, reused by every fit attempt
max_sv <- 10L            # largest number of SVs the fallback search tries

# Run svaseq for a given number of SVs (NULL = let svaseq choose).
# Returns the svaseq result. If the fit errors, returns NULL and reports the
# error on stderr.
fit_sv <- function(n_sv = NULL) {
    tryCatch(
        svaseq(sub_cpm, mod, mod0, constant = args$constant, n.sv = n_sv),
        error = function(err) {
            message("svaseq failed (n.sv = ",
                    if (is.null(n_sv)) "auto" else n_sv, "): ",
                    conditionMessage(err))
            NULL
        }
    )
}

res <- fit_sv()

# One SV may fail to fit because the corrected matrix becomes singular.
# This is more likely with fewer replicates per group (time point).
# In that case, fit 1, 2, ... SVs and keep the largest number that still fits,
# never exceeding max_sv. The last successful fit is kept, not recomputed.
if (is.null(res)) {
    n_sv <- 0L
    while (n_sv < max_sv) {
        attempt <- fit_sv(n_sv + 1L)
        if (is.null(attempt)) {
            break
        }
        res <- attempt
        n_sv <- n_sv + 1L
    }
    if (is.null(res)) {
        # Not even a single SV could be fit, so request zero SVs from svaseq.
        # NOTE(review): what res$sv holds for n.sv = 0 depends on the
        # installed sva version; confirm consumers of the --sv file accept it.
        res <- svaseq(sub_cpm, mod, mod0, constant = args$constant, n.sv = 0)
    }
}

# Save the fitted SVs (per the help text: P samples x K SVs) to the --sv path
# as a bare tab-separated table: no quoting, no row names, no header.
write.table(res$sv, file = args$sv, sep = "\t", quote = FALSE,
            row.names = FALSE, col.names = FALSE)