#!/usr/bin/env python 

################################################################################
### COPYRIGHT ##################################################################

# New York Genome Center

# SOFTWARE COPYRIGHT NOTICE AGREEMENT
# This software and its documentation are copyright (2014) by the New York
# Genome Center. All rights are reserved. This software is supplied without
# any warranty or guaranteed support whatsoever. The New York Genome Center
# cannot be responsible for its use, misuse, or functionality.

# Version: 0.2
# Author: Andre Corvelo

################################################################# /COPYRIGHT ###
################################################################################



################################################################################
### MODULES ####################################################################

from optparse import OptionParser, OptionGroup
import sys, os, errno, commands, uuid

################################################################### /MODULES ###
################################################################################



################################################################################
### FUNCTIONS ##################################################################

def mkdir_p(path):
    """Create directory *path* together with any missing parents.

    A directory that already exists is silently accepted. Every other
    OSError raised by os.makedirs (including the case where *path* exists
    but is not a directory) is re-raised unchanged.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise

def rp(path,start):
	"""Return *path* expressed relative to the directory *start*."""
	relative_path = os.path.relpath(path, start)
	return relative_path

def bn(path):
	"""Return the last component of *path* (empty when *path* ends with a separator)."""
	_, tail = os.path.split(path)
	return tail

def missing_files(csv_file_list):
	"""Count the entries of a comma-separated path list that are not existing regular files.

	The string is split on ',' without any trimming, so an empty string
	yields one (missing) empty entry. Returns 0 when every entry is a file.
	"""
	n_missing = 0
	for entry in csv_file_list.split(','):
		if not os.path.isfile(entry):
			n_missing += 1
	return n_missing

################################################################# /FUNCTIONS ###
################################################################################



################################################################################
### ARGUMENTS,OPTIONS ##########################################################

# Command-line interface. Options are declared in thematic groups; most values
# are kept as strings because they are later spliced verbatim into the
# generated shell pipeline. Only -C and -N are parsed as integers.
parser = OptionParser(usage = "\n%prog options", version = "%prog v0.2")

# input
parser_input = OptionGroup(parser, "Input (MANDATORY one of the following: -i, -f, -b or (-1 and -2))")

parser_input.add_option(
	"-i",
	metavar = "STR",
	type = "string",
	dest = "i_comm",
	default = None,
	help = "Input command (use absolute paths!). Quoted, with fq on stdout. Interleaved for paired modes (MANDATORY)"
	)

parser_input.add_option(
	"-f",
	metavar = "STR",
	type = "string",
	dest = "i_fq",
	default = None,
	help = "Input .fq, .fastq, .fq.gz or .fastq.gz files. Interleaved for paired modes (MANDATORY)"
	)

parser_input.add_option(
	"-b",
	metavar = "STR",
	type = "string",
	dest = "i_bam",
	default = None,
	help = "Input .bam file. Sorted for paired modes. Requires .bam.bai in the same folder (MANDATORY)"
	)

parser_input.add_option(
	"-1",
	metavar = "STR",
	type = "string",
	dest = "i_fq_1",
	default = None,
	help = "Input .fq, .fastq, .fq.gz, .fastq.gz read 1 file. In sync with, and of the same sort as, -2 input file for paired modes (MANDATORY)"
	)

parser_input.add_option(
	"-2",
	metavar = "STR",
	type = "string",
	dest = "i_fq_2",
	default = None,
	help = "Input .fq, .fastq, .fq.gz, .fastq.gz read 2 file. In sync with, and of the same sort as, -1 input file for paired modes (MANDATORY)"
	)

parser_input.add_option(
	"-D",
	metavar = "FLOAT",
	type = "string",
	dest = "i_sprob",
	default = "1",
	help = "Downsampling probability (default = 1 = No sampling)"
	)

parser.add_option_group(parser_input)

# preprocessing
parser_preproc = OptionGroup(parser, "Preprocessing")

parser_preproc.add_option(
	"-q",
	metavar = "INT",
	type = "string",
	dest = "p_qual",
	default = None,
	help = "Cutadapt quality cutoff value (default = None)"
	)

parser_preproc.add_option(
	"-a",
	metavar="STR",
	type = "string",
	dest = "p_adapt",
	default = None,
	help = "Adapters string. Quoted.  e.g. \"-a AGATCGGAAGAGC -g AGATCGGAAGAGC -n2 \"  (default = None)"
	)

parser_preproc.add_option(
	"-l",
	metavar = "INT",
	type = "string",
	dest = "p_minlen",
	default = '50',
	help = "Minimum read length for mapping (default = 50)"
	)

parser_preproc.add_option(
	"-L",
	metavar = "INT",
	type = "string",
	dest = "p_hardt",
	default = None,
	help = "Maximum read length (hard trimming; default = None)"
	)

parser_preproc.add_option(
	"-w",
	metavar = "INT",
	type = "string",
	dest = "p_leftt",
	default = None,
	help = "5p trim length (hard trimming; default = None)"
	)

parser_preproc.add_option(
	"-C",
	metavar = "INT",
	type = "int",
	dest = "p_lowc",
	default = 70,
	help = "Entropy cutoff for low complexity filtering. Use 0 for no filtering (default = 70)"
	)

parser_preproc.add_option(
	"-N",
	metavar = "INT",
	type = "int",
	dest = "p_nrem",
	default = 4,
	help = "Filter reads with more than N% of 'N' characters. Use 100 for no filtering (default = 4)"
	)

parser_preproc.add_option(
	"--phred64",
	metavar = "FLAG",
	dest = "p_qoff",
	action = 'store_true',
	default = False,
	help = "Quality scores in Phred+64 format (default = False)"
	)

parser.add_option_group(parser_preproc)

# mapping
parser_mapping = OptionGroup(parser, "Mapping")

parser_mapping.add_option(
	"-d",
	metavar = "FILE_LIST",
	type = "string",
	dest = "m_indices",
	default = None,
	help = "Index files. CSV (MANDATORY)"
	)

parser_mapping.add_option(
	"-e",
	metavar = "INT/FLOAT_LIST",
	type = "string",
	dest = "m_edits",
	default = '0.2',
	help = "Edit distances. CSV (INT or FLOAT<1; default = 0.2)"
	)

parser_mapping.add_option(
	"-c",
	metavar = "INT_LIST",
	type = "string",
	dest = "m_cpus",
	default = '1',
	help = "Number of CPUs. CSV (default = 1)"
	)

parser.add_option_group(parser_mapping)

# taxonomy
parser_tax = OptionGroup(parser, "Taxonomy")

parser_tax.add_option(
	"-t",
	metavar = "FILE",
	type = "string",
	dest = "t_table",
	default = None,
	help = "Taxonomic table file (MANDATORY)"
	)

parser_tax.add_option(
	"-m",
	metavar = "STR",
	type = "string",
	dest = "t_mode",
	default = 's',
	help = "Taxa determination mode ('s' single-end; 'p' paired-end; 'P' paired-end strict; default = 's')"
	)

parser.add_option_group(parser_tax)

# reporting
parser_rep = OptionGroup(parser, "Reporting")

parser_rep.add_option(
	"--cov",
	metavar = "FLAG",
	dest = "r_cov",
	action = 'store_true',
	default = False,
	help = "Compute coverage histograms (default = False)"
	)

parser_rep.add_option(
	"-x",
	metavar = "STR_LIST",
	type = "string",
	dest = "r_exc",
	default = None,
	help = "Excluded taxids. CSV (only if --cov flag ON)"
	)

parser_rep.add_option(
	"-z",
	metavar = "FLOAT",
	type = "string",
	dest = "r_cutoff",
	default = '1',
	help = "Reporting cutoff (INT or FLOAT<1; default = 1)"
	)

parser.add_option_group(parser_rep)

# output and misc
parser_misc = OptionGroup(parser, "Output and miscellaneous")

parser_misc.add_option(
	"-p",
	metavar = "STR",
	type = "string",
	dest = "m_prefix",
	# evaluated once at start-up, so each invocation gets its own default prefix
	default = str(uuid.uuid4())[:8],
	help = "Sample prefix (default = random 8-char UUID)"
	)

parser_misc.add_option(
	"-o",
	metavar = "DIR",
	type = "string",
	dest = "m_outdir",
	default = ".",
	help = "Output directory (default = '.')"
	)

parser_misc.add_option(
	"-Q",
	metavar = "STR",
	type = "string",
	dest = "m_queue",
	default = None,
	help = "Cluster queue/partition (default = None)"
	)

parser_misc.add_option(
	"-S",
	metavar = "INT",
	type = "string",
	dest = "m_slots",
	default = "20",
	help = "Cluster slots (default = 20)"
	)

parser_misc.add_option(
	"--dry",
	metavar = "FLAG",
	dest = "m_dry",
	action = 'store_true',
	default = False,
	help = "Dry run (default = False)"
	)

parser.add_option_group(parser_misc)

# Parse the command line. Positional arguments (args) are not used below.
(opt, args) = parser.parse_args()


### check options
# Problems are collected instead of aborting on the first one, so that all of
# them can be reported together before the help text is printed.
error_list = []
warning_list = []

# input
# input_list records every input source that was given and is usable;
# more than one usable source is an error.
input_list = []
if not opt.i_comm and not opt.i_fq and not opt.i_bam and (not opt.i_fq_1 or not opt.i_fq_2):
	error_list.append('ERROR: Missing valid input [-i | -b | -f | (-1 & -2)]')
if opt.i_comm:
	input_list.append('-i')
if opt.i_fq:
	if missing_files(opt.i_fq):
		error_list.append('ERROR: File(s) not found in [-f ' + opt.i_fq + ']')
	else:
		input_list.append('-f')
if opt.i_bam:
	if missing_files(opt.i_bam):
		error_list.append('ERROR: File(s) not found in [-b ' + opt.i_bam + ']')
	else:
		input_list.append('-b')
if opt.i_fq_1 and opt.i_fq_2:
	if len(opt.i_fq_1.split(',')) != len(opt.i_fq_2.split(',')):
		error_list.append('ERROR: Different number of R1 and R2 files [-1 ' + opt.i_fq_1 + ' -2 ' + opt.i_fq_2 + ']')
	else:
		if missing_files(opt.i_fq_1) or missing_files(opt.i_fq_2):
			error_list.append('ERROR: File(s) not found in [-1 ' + opt.i_fq_1 + ' -2 ' + opt.i_fq_2 + ']')
		else:
			input_list.append('(-1 & -2)')
if len(input_list) > 1:
	error_list.append('ERROR: Too many inputs [' + ' & '.join(input_list) + ']')


# mapping
# n_indexes stays 0 when -d is absent or any listed index file is missing.
n_indexes = 0
if opt.m_indices:
	if missing_files(opt.m_indices):
		error_list.append('ERROR: File(s) not found in [-d ' + opt.m_indices + ']')
	else:
		n_indexes = len(opt.m_indices.split(','))
if not n_indexes:
	error_list.append('ERROR: Missing valid index(es) [-d]')

# A single -e / -c value is broadcast to every index (with a warning);
# any other count mismatch is an error.
n_dist = len(opt.m_edits.split(','))
if n_dist != n_indexes and n_indexes:
	if n_dist == 1 :
		warning_list.append('WARNING: Setting the maximum edit distance for all indexes [-e ' + ','.join([opt.m_edits]*n_indexes) + ']')
	else:
		error_list.append('ERROR: ' + str(n_dist) + ' distances specified [-e ' + opt.m_edits + '], incompatible with ' + str(n_indexes) + ' indexes [-d ' + opt.m_indices + ']')

n_cpus = len(opt.m_cpus.split(','))
if n_cpus != n_indexes and n_indexes:
	if n_cpus == 1 :
		warning_list.append('WARNING: Setting the number of CPUs for all indexes [-c ' + ','.join([opt.m_cpus]*n_indexes) + ']')
	else:
		error_list.append('ERROR: ' + str(n_cpus) + ' CPU specifications [-c ' + opt.m_cpus + '], incompatible with ' + str(n_indexes) + ' indexes [-d ' + opt.m_indices + ']')


# taxonomy
if opt.t_table:
	if missing_files(opt.t_table):
		error_list.append('ERROR: File not found [-t ' + opt.t_table + ']')
else:
	error_list.append('ERROR: Missing taxonomic table [-t]')


if opt.t_mode not in ['s', 'p', 'P']:
	error_list.append('ERROR: Invalid mode [-m ' + opt.t_mode + ']')


# misc
# The queue name must match the first column of one of the lines left after
# filtering the `qstat -g c` output.
if opt.m_queue:
	if opt.m_queue not in commands.getoutput('qstat -g c | egrep "CLUSTER|\--" -v | awk \'{print $1}\'').split('\n'):
		error_list.append('ERROR: Specified queue does not exist [-Q ' + opt.m_queue + ']')


# Warnings are informational only; any error prints the help and aborts.
sys.stderr.write('\n'.join(warning_list) + '\n')
if error_list:
	sys.stderr.write('\n'.join(error_list) + '\n\n')
	parser.print_help()
	sys.exit(-1)
  
######################################################### /ARGUMENTS,OPTIONS ###
################################################################################



################################################################################
### CONSTANTS ##################################################################

################################################################# /CONSTANTS ###
################################################################################



################################################################################
### MAIN #######################################################################

if __name__ == '__main__':
	# Builds the shell pipeline txM.<prefix>.sh inside the output directory
	# from the validated options and, unless --dry was given, either runs it
	# with sh or (when a queue was specified) hands it to bin/txM_sge.

	# base dirs
	module_dir = os.path.abspath(os.path.dirname(__file__))
	start_dir = os.getcwd()

	prefix = opt.m_prefix

	run_dir = os.path.abspath(opt.m_outdir)
	mkdir_p(run_dir)

	res_dir = run_dir + '/txM.' + prefix + '.base'
	mkdir_p(res_dir)

	log_dir = run_dir + '/txM.' + prefix + '.log'
	mkdir_p(log_dir)

	map_dir = run_dir + '/txM.' + prefix + '.map'
	mkdir_p(map_dir)

	out_dir = run_dir + '/txM.' + prefix + '.out'
	mkdir_p(out_dir)


	# input
	# Input files are symlinked into res_dir; an input command (-i) is
	# saved there as input_stream.sh instead.
	if opt.i_fq:
		fq_files = [os.path.abspath(fq_file) for fq_file in opt.i_fq.split(',')]
		fq_lnks = [res_dir + '/' + bn(fq_file) for fq_file in fq_files]
		for fq_file, fq_lnk in zip(fq_files, fq_lnks):
			os.system(' '.join(['ln', '-s', fq_file, fq_lnk]))
	elif opt.i_bam:
		bam_file = os.path.abspath(opt.i_bam)
		bam_lnk = res_dir + '/' + bn(bam_file)
		os.system(' '.join(['ln', '-s', bam_file, bam_lnk]))
	elif opt.i_fq_1 and opt.i_fq_2:
		fq1_files = [os.path.abspath(fq1_file) for fq1_file in opt.i_fq_1.split(',')]
		fq1_lnks = [res_dir + '/' + bn(fq1_file) for fq1_file in fq1_files]
		for fq1_file, fq1_lnk in zip(fq1_files, fq1_lnks):
			os.system(' '.join(['ln', '-s', fq1_file, fq1_lnk]))
		fq2_files = [os.path.abspath(fq2_file) for fq2_file in opt.i_fq_2.split(',')]
		fq2_lnks = [res_dir + '/' + bn(fq2_file) for fq2_file in fq2_files]
		for fq2_file, fq2_lnk in zip(fq2_files, fq2_lnks):
			os.system(' '.join(['ln', '-s', fq2_file, fq2_lnk]))
	elif opt.i_comm:
		comm_str = opt.i_comm
		with open(res_dir + '/input_stream.sh', 'w') as comm_file:
			comm_file.write('#!/bin/sh\n' + comm_str + ' > input_stream.fq\n')

	# preprocessing
	cutadapt_qual = opt.p_qual
	cutadapt_adapts = opt.p_adapt
	min_rlen = opt.p_minlen
	max_rlen = opt.p_hardt
	leftt_len = opt.p_leftt
	q_offset = opt.p_qoff
	# prinseq-lite is part of the pipeline whenever low-complexity filtering
	# (-C != 0) or N filtering (-N != 100) is active. The same flag must
	# gate the post-processing of its rejected reads further down.
	run_prinseq = bool(opt.p_lowc) or opt.p_nrem != 100

	# mapping
	index_files = opt.m_indices.split(',')
	# each index is paired with a .len file of the same base name
	len_files = ['.'.join(f.split('.')[:-1]) + '.len' for f in index_files]

	index_names = ['.'.join(os.path.basename(f).split('.')[:-1]) for f in index_files]

	# a single -e / -c value applies to every index
	index_edits = opt.m_edits.split(',')
	if len(index_edits) == 1:
		index_edits = [index_edits[0]]*len(index_names)
	max_edit = str(max([float(x) for x in index_edits]))

	index_cpus = opt.m_cpus.split(',')
	if len(index_cpus) == 1:
		index_cpus = [index_cpus[0]]*len(index_names)

	index_lnks = []
	len_lnks = []
	for index_file, len_file in zip(index_files, len_files):
		index_lnk = res_dir + '/' + bn(index_file)
		len_lnk = res_dir + '/' + bn(len_file)
		index_lnks.append(index_lnk)
		len_lnks.append(len_lnk)
		os.system(' '.join(['ln', '-s', os.path.abspath(index_file), index_lnk]))
		os.system(' '.join(['ln', '-s', os.path.abspath(len_file), len_lnk]))

	# taxonomy
	tax_file = os.path.abspath(opt.t_table)
	tax_lnk = res_dir + '/' + bn(tax_file)
	os.system(' '.join(['ln', '-s', tax_file, tax_lnk]))

	det_mode = opt.t_mode

	comp_coverage =  opt.r_cov
	exc_taxa = opt.r_exc

	# report
	rep_cutoff = opt.r_cutoff

	# miscellaneous
	sge_queue = opt.m_queue
	sge_slots = opt.m_slots
	dry_run = opt.m_dry


	#sh script and pipe
	sh_filename = run_dir + '/txM.' + prefix + '.sh'
	pipe_filename = run_dir + '/txM.' + prefix + '.sge'


	# Files concatenated by the taxonomy step, in the order they are produced.
	out_files = []
	# input command
	# Paths inside the generated script are relative to run_dir, which is
	# the working directory when the script is launched below.
	if opt.i_comm:
		input_command = comm_str
	elif opt.i_fq:
		# the .gz test looks at the end of the whole -f string, i.e. the last file
		if opt.i_fq[-3:] == '.gz':
			input_command = 'zcat ' + ' '.join([rp(fq_lnk, run_dir) for fq_lnk in fq_lnks])
		else:
			input_command = 'cat ' + ' '.join([rp(fq_lnk, run_dir) for fq_lnk in fq_lnks])
	elif opt.i_bam:
		stview_log = rp(log_dir + '/' + prefix + '.st_view.err', run_dir)
		stbam2fq_log = rp(log_dir + '/' + prefix + '.st_bam2fq.err', run_dir)
		bam = rp(bam_lnk, run_dir)
		if opt.t_mode == 's':
			input_command = 'samtools view -f4 -F256 ' + bam + ' -b 2> ' + stview_log + ' \\\n\t| samtools bam2fq - 2> ' + stbam2fq_log
		else:
			input_command = 'samtools view -f12 -F256 ' + bam + ' -b 2> ' + stview_log + ' \\\n\t| samtools bam2fq - 2> ' + stbam2fq_log
			input_command += ' \\\n\t| txM_fqmate'
	elif opt.i_fq_1 and opt.i_fq_2:
		input_command = 'txM_fqintlv -1 ' + ','.join([rp(fq1_lnk, run_dir) for fq1_lnk in fq1_lnks]) + ' -2 ' + ','.join([rp(fq2_lnk, run_dir) for fq2_lnk in fq2_lnks])
	input_command += ' \\\n\t| '

	if opt.i_sprob != '1':
		if opt.t_mode == 's':
			input_command += 'txM_fqsample -s ' + opt.i_sprob + ' \\\n\t| '
		else:
			input_command += 'txM_fqsample -p -s ' + opt.i_sprob + ' \\\n\t| '

	# preprocessing
	filt_out = rp(map_dir + '/' + prefix + '.filtout.map', run_dir)
	out_files.append(filt_out)
	prep_command = ''
	if cutadapt_qual or cutadapt_adapts:
		cutadapt_log = rp(log_dir + '/' + prefix + '.cutadapt.err', run_dir)
		prep_command += 'cutadapt - '
		if cutadapt_adapts:
			prep_command += cutadapt_adapts + ' '
		if cutadapt_qual:
			prep_command += '-q ' + cutadapt_qual + ' '
		if q_offset:
			prep_command += '--quality-base=64 '
		prep_command += '2> ' + cutadapt_log + ' \\\n\t| '
	if leftt_len:
		prep_command += 'txM_fqltrim -l ' + leftt_len + ' \\\n\t| '
	if max_rlen:
		prep_command += 'txM_fqhtrim -l ' + max_rlen + ' \\\n\t| '
	if min_rlen:
		prep_command += 'txM_fqminlen -m ' + min_rlen + ' 2> ' + filt_out + ' \\\n\t| '
	#low complexity and Ns
	if run_prinseq:
		prinseq_log = rp(log_dir + '/' + prefix + '.prins.err', run_dir)
		prinseq_out = rp(map_dir + '/' + prefix + '.prinsout', run_dir)
		prinseq_out2 = rp(map_dir + '/' + prefix + '.prinsout.map', run_dir)
		out_files.append(prinseq_out2)
		prep_command += 'prinseq-lite.pl -no_qual_header -fastq stdin -out_good stdout '
		prep_command += '-out_bad ' + prinseq_out + ' '
		if opt.p_lowc:
			prep_command += '-lc_method entropy -lc_threshold ' + str(opt.p_lowc)+ ' '
		if opt.p_nrem != 100:
			prep_command += '-ns_max_p ' + str(opt.p_nrem) + ' '
		if q_offset:
			prep_command += '-phred64 '
		prep_command += '2> ' + prinseq_log + ' \\\n\t| '

	# mapping
	# One gem-mapper stage per index; stages are chained through txM_gem2fq
	# and whatever leaves the last stage is written to <prefix>.unmapped.map.
	gem_command = 'gem-mapper -q ignore --fast-mapping'
	map_commands = []
	for index_name, index_lnk, map_dist, map_cpus in zip(index_names, index_lnks, index_edits, index_cpus):
		map_log = rp(log_dir + '/' + prefix + '.' + index_name + '.gem.err', run_dir)
		map_index = rp(index_lnk, run_dir)
		map_out = rp(map_dir + '/' + prefix + '.' + index_name + '.map', run_dir)
		out_files.append(map_out)
		map_str = gem_command + ' '.join([' -I', map_index, '-m', map_dist, '-e', map_dist, '-T', map_cpus, '2>', map_log, ' \\\n\t| txM_rescore \\\n\t| '])
		map_str += 'txM_mapout 2> ' + map_out
		map_commands.append(map_str)
	map_lnker = ' \\\n\t| ' + 'txM_gem2fq \\\n\t| '
	map_command = map_lnker.join(map_commands)
	map_out = rp(map_dir + '/' + prefix + '.unmapped.map', run_dir)
	out_files.append(map_out)
	map_command += ' > ' + map_out

	#parse low complexity and Ns
	# Converts the reads rejected by prinseq (<prinseq_out>.fastq) into the
	# .prinsout.map file listed in out_files, then removes the fastq.
	if run_prinseq:
		lc_command = 'cat ' + prinseq_out + '.fastq \\\n\t| txM_fq2gem > ' + prinseq_out2 + ' \n\n'
		lc_command += 'rm ' + prinseq_out + '.fastq'

	# taxonomy
	if det_mode == 's':
		tax_command = 'cat '
	else:
		tax_command = 'txM_mergintlv '
	tax_command += ' '.join(out_files)
	tax_command += ' 2> /dev/null \\\n\t| '
	tax_command +=  'txM_lca '
	tax_command += '-t ' + rp(tax_lnk, run_dir) + ' '
	tax_command += '-m ' + det_mode + ' '
	tax_command += '2> ' + rp(map_dir + '/' + prefix + '.merged.map.lca', run_dir)
	tax_command += ' \\\n\t| '
	tax_command += 'txM_summary '
	tax_command += '-t ' + rp(tax_lnk, run_dir) + ' '
	tax_command += '-e ' + max_edit + ' '
	if comp_coverage:
		tax_command += '-c '
		tax_command += '-l ' + ','.join([rp(x, run_dir) for x in len_lnks]) + ' '
		if exc_taxa:
			tax_command += '-x ' + exc_taxa + ' '
	tax_command += '2> /dev/null > ' + rp(out_dir + '/' + prefix + '.merged.map.lca.summary', run_dir)

	# reporting
	# These paths are relative to out_dir because the script changes into
	# out_dir right before the report command.
	rep_command = 'cat ' + rp(out_dir + '/' + prefix + '.merged.map.lca.summary', out_dir) + ' \\\n\t| '
	rep_command += 'txM_report '
	rep_command += '-t ' + rp(tax_lnk, out_dir) + ' '
	if det_mode == 'p' or det_mode == 'P':
		rep_command += '-p '
	rep_command += '-c ' + rep_cutoff + ' '
	rep_command += '-o ' + prefix + ' '

	# write the pipeline script
	with open(sh_filename, 'w') as sh_file:
		sh_file.write('#!/bin/sh\n\n')
		sh_file.write('export PATH=$PATH:' + module_dir + '/bin\n\n')
		sh_file.write(input_command + prep_command + map_command + '\n\n')
		# Must follow the same condition under which prinseq was added to
		# the pipeline, otherwise reads rejected by the N filter alone
		# (-C 0) would never reach the taxonomy step.
		if run_prinseq:
			sh_file.write(lc_command + '\n\n')
		sh_file.write(tax_command + '\n\n')
		sh_file.write('cd ' + rp(out_dir, run_dir) +'\n')
		sh_file.write(rep_command + '\n')

	if sge_queue:
		# The one-line .sge job description is written even on a dry run.
		with open(pipe_filename, 'w') as pipe_file:
			pipe_file.write('\t'.join(['txM' + prefix, sge_queue, '0', '-', sge_slots, '.', 'sh ' + rp(sh_filename, run_dir)]) + '\n')
		if not dry_run:
			os.chdir(run_dir)
			os.system(' '.join([module_dir + '/bin/txM_sge', '-s', rp(pipe_filename, run_dir)]))
			os.chdir(start_dir)
	else:
		if not dry_run:
			os.chdir(run_dir)
			os.system(' '.join(['sh', rp(sh_filename, run_dir)]))
			os.chdir(start_dir)
		
###################################################################### /MAIN ###
################################################################################
