#!/usr/bin/env Rscript

suppressPackageStartupMessages(library("argparse"))
# Build the command-line parser.
parser <- ArgumentParser()

# Required arguments. They are declared as optional flags, so their presence
# is verified explicitly after parsing.
parser$add_argument("-i", "--mat",
                    help = "Input matrix")
parser$add_argument("-s", "--sv",
                    help = "Input surrogate variables")
parser$add_argument("-o", "--out_mat",
                    help = "Output matrix")

args <- parser$parse_args()

# Look each argument up with [[ ]] (exact matching). `$` partially matches
# list names, so e.g. `args$out` would silently resolve to `out_mat` and a
# missing --sv would go unnoticed.
required_args <- c("mat", "sv", "out_mat")
missing_args <- Filter(function(nm) is.null(args[[nm]]), required_args)
if (length(missing_args) > 0) {
    # stop() aborts an Rscript run with a non-zero exit status; R has no
    # built-in exit() function.
    stop("Please provide ", paste0("--", missing_args, collapse = ", "),
         call. = FALSE)
}

##################################################

# Read the input matrix (header row present, first column used as row names)
# and the surrogate-variable table (no header). Both are tab-separated.
mat <- read.table(args$mat, header = TRUE, row.names = 1, sep = "\t")
svs <- read.table(args$sv, header = FALSE, sep = "\t")

# The surrogate variables are column-bound to the design matrix, which has one
# row per column of `mat`, so the counts must agree. Fail early with a clear
# message instead of an opaque cbind() error later on.
if (nrow(svs) != ncol(mat)) {
    stop("Surrogate variable table has ", nrow(svs),
         " rows but the input matrix has ", ncol(mat),
         " columns; these must match", call. = FALSE)
}

# filter out sites with 0 counts
# mat <- mat[rowMeans(mat) > 0,]
# mat = mat[apply(mat, 1, function(x) length(x[x != 0])/length(x) > 0.10),]

# set up models
# The group label of each column is the part of its name before the first "_".
# vapply() guarantees a character vector regardless of input size.
sample_names <- as.character(colnames(mat))
group <- factor(vapply(strsplit(sample_names, "_"),
                       function(x) x[[1]],
                       character(1)))
mod <- model.matrix(~as.factor(group))
# Intercept-only model; not referenced again in the visible part of the script.
mod0 <- cbind(mod[, 1])
# log(x + 1) transform keeps zero entries finite.
log_mat <- log(mat + 1)

#' Regress surrogate variables out of a data matrix.
#'
#' Fits, by ordinary least squares, the linear model with design
#' `cbind(mod, svs)` to every row of `y`, then subtracts only the fitted
#' contribution of the `svs` columns. The effects captured by `mod` are left
#' in the data.
#'
#' @param y Numeric matrix or all-numeric data frame whose columns are the
#'   observations (one column per row of `mod`).
#' @param mod Design matrix of the effects to keep.
#' @param svs Covariates to remove, one row per observation. Anything
#'   `as.matrix()` can turn into a numeric matrix is accepted.
#' @return `y` with the fitted `svs` component subtracted; same dimensions as
#'   `y` (a data frame stays a data frame, a matrix stays a matrix).
cleanY <- function(y, mod, svs) {
    mod <- as.matrix(mod)
    svs <- as.matrix(svs)
    if (nrow(mod) != nrow(svs) || nrow(mod) != ncol(y)) {
        stop("`mod` and `svs` must each have one row per column of `y`",
             call. = FALSE)
    }

    X <- cbind(mod, svs)
    y_mat <- as.matrix(y)

    # Solve the normal equations directly rather than forming an explicit
    # inverse: cheaper and numerically better behaved.
    beta <- solve(crossprod(X), crossprod(X, t(y_mat)))

    # Positions of the surrogate-variable columns in X / rows in beta.
    # drop = FALSE keeps the matrix shape when there is a single SV or row.
    sv_idx <- ncol(mod) + seq_len(ncol(svs))
    sv_fit <- X[, sv_idx, drop = FALSE] %*% beta[sv_idx, , drop = FALSE]

    y - t(sv_fit)
}

# Remove the surrogate-variable contribution from the log-scale matrix.
Y <- cleanY(y = log_mat, mod = mod, svs = as.matrix(svs))

# Save the result as tab-separated text with row names; col.names = NA writes
# an empty header cell above the row-name column.
write.table(
    Y,
    file = args$out_mat,
    sep = "\t",
    quote = FALSE,
    row.names = TRUE,
    col.names = NA
)