#!/usr/bin/env python 

################################################################################
### COPYRIGHT ##################################################################

# New York Genome Center

# SOFTWARE COPYRIGHT NOTICE AGREEMENT
# This software and its documentation are copyright (2014) by the New York
# Genome Center. All rights are reserved. This software is supplied without
# any warranty or guaranteed support whatsoever. The New York Genome Center
# cannot be responsible for its use, misuse, or functionality.

# Version: 0.2
# Author: Andre Corvelo

################################################################# /COPYRIGHT ###
################################################################################



################################################################################
### MODULES ####################################################################

from optparse import OptionParser
from sys import stdin, stdout, stderr, exit
from re import findall
import numpy as np

################################################################### /MODULES ###
################################################################################



################################################################################
### FUNCTIONS ##################################################################

def get_best_n(strata_s, n_strata):
	"""Locate the first non-empty stratum in a match-counter string.

	strata_s is a list of per-stratum hit counts separated by ':' or '+'
	(both separators are treated alike). Only the first n_strata counters
	are inspected.

	Returns a tuple (index, count) for the first counter that is not '0',
	or (None, 0) when every inspected counter is '0'.

	Raises ValueError if the first non-'0' counter is not an integer.
	"""
	counters = strata_s.replace('+', ':').split(':')[:n_strata]
	for stratum, count in enumerate(counters):
		if count != '0':
			return (stratum, int(count))
	return (None, 0)

def get_reflen(mstring):
	"""Return the length computed from a match/mismatch string.

	The result is the sum of every integer in mstring that is not
	immediately followed by '-', plus one for every ASCII letter.
	(Presumably integers followed by '-' are read bases absent from the
	reference in the GEM mismatch notation - confirm against the format.)

	The lookahead also rejects a following digit so that a multi-digit run
	such as '12-' cannot be partially matched as '1' through backtracking.
	"""
	matched = sum(int(s) for s in findall(r'\d+(?![\d\-])', mstring))
	# [A-Za-z] rather than [A-z]: the latter also matches '[', '\', ']', '^', '_' and '`'.
	mismatched = len(findall(r'[A-Za-z]', mstring))
	return matched + mismatched

def _new_rep_entry():
	"""Return a fresh, zeroed per-taxon counter record for rep_d."""
	return [0, 0, 0, 0, 0, np.zeros(1, dtype=np.uint), 0, 0]

def p_rec(l, tax_d, rep_d, max_edit, array_d, len_d, excl_taxa, comp_cov):
	"""Fold one tab-separated mapping record into the running summaries.

	l         -- raw record line; fields 1, 3, 4 and 5 are always read,
	             fields 8 and 9 only when more than 8 fields are present
	tax_d     -- not used by this function
	rep_d     -- taxon -> counter record, updated in place (layout below)
	max_edit  -- multiplied by the length of field 1 to get the number of
	             strata inspected
	array_d   -- taxon -> {sequence id -> uint16 per-position depth array},
	             filled in place when comp_cov is true
	len_d     -- taxon -> {sequence id -> length}, used to size new arrays
	excl_taxa -- taxa for which no depth array is maintained
	comp_cov  -- when false, no depth arrays are touched

	Counter record layout:
	  [0] records whose field 5 is this taxon
	  [1] records whose best hits all fall in this single taxon
	  [2] records whose best hits span this and at least one other taxon
	  [3] records whose field 9 is this taxon and field 8 is FF, UU or MM
	  [4] records whose field 9 is this taxon with any other field 8
	  [5] histogram of records by best-stratum index
	  [6] number of best hits on this taxon
	  [7] summed get_reflen() of those hits
	"""
	fields = l.strip().split('\t')

	seq_len = len(fields[1])
	strata_limit = int(max_edit * seq_len) + 1

	dist, n_hits = get_best_n(fields[3], strata_limit)
	rec_tax = fields[5]
	if rec_tax not in rep_d:
		rep_d[rec_tax] = _new_rep_entry()
	rep_d[rec_tax][0] += 1

	if len(fields) > 8:
		pair_tag = fields[8]
		pair_tax = fields[9]

		if pair_tax not in rep_d:
			rep_d[pair_tax] = _new_rep_entry()
		pair_slot = 3 if pair_tag in ('FF', 'UU', 'MM') else 4
		rep_d[pair_tax][pair_slot] += 1

	# nothing more to record when no stratum within the limit has hits
	if not n_hits:
		return

	hit_taxa = set()
	for hit_s in fields[4].split(',')[:n_hits]:
		hit = hit_s.split(':')
		# the hit name is '<sequence id>_<taxon>'
		seq_id, hit_tax = hit[0].split('_')
		if hit_tax not in rep_d:
			rep_d[hit_tax] = _new_rep_entry()

		span = get_reflen(hit[3])
		start = int(hit[2]) - 1
		entry = rep_d[hit_tax]
		entry[6] += 1
		entry[7] += span

		hit_taxa.add(hit_tax)

		if comp_cov and hit_tax not in excl_taxa:
			if hit_tax not in array_d:
				array_d[hit_tax] = {}
			if seq_id not in array_d[hit_tax]:
				array_d[hit_tax][seq_id] = np.zeros(len_d[hit_tax][seq_id], dtype=np.uint16)
			depth = array_d[hit_tax][seq_id]
			window = slice(start, start + span)
			# positions already at 65535 are skipped, so the uint16
			# counters saturate instead of wrapping around
			depth[window] += (depth[window] < 65535)

	# slot 1 when every best hit is in one taxon, slot 2 otherwise
	share_slot = 1 if len(hit_taxa) == 1 else 2
	for hit_tax in hit_taxa:
		entry = rep_d[hit_tax]
		entry[share_slot] += 1
		if dist >= entry[5].size:
			entry[5].resize(dist + 1, refcheck=False)
		entry[5][dist] += 1

################################################################# /FUNCTIONS ###
################################################################################



################################################################################
### ARGUMENTS,OPTIONS ##########################################################

# Command-line interface. Options are parsed at module level and the
# resulting `opt` object is read by the code further down the file.
parser = OptionParser(usage="\n%prog [options]", version="%prog v0.2")

parser.add_option(
	"-i",
	metavar = "FILE",
	type = "string",
	dest = "map_file",
	default = 'stdin',
	help = "Input GEM mapping file including LCA assignation (default = 'stdin')"
	)

parser.add_option(
	"-t",
	metavar = "FILE",
	type = "string",
	dest = "tax_file",
	default = None,
	help = "Taxonomic table file (Mandatory)"
	)

parser.add_option(
	"-l",
	metavar = "FILE",
	type = "string",
	dest = "len_file",
	default = None,
	help = "Length files - CSV file list containing the length files for every index (Mandatory if -c)"
	)

parser.add_option(
	"-e",
	metavar = "FLOAT",
	type = "float",
	dest = "max_edit",
	default = None,
	help = "Maximum edit distance. Between 0 and 1 (Mandatory)"
	)

parser.add_option(
	"-x",
	metavar = "STR",
	type = "string",
	dest = "excluded",
	default = None,
	help = "Excluded branches - list of comma-separated taxonomic ids (default = None)"
	)

parser.add_option(
	"-c",
	metavar = "FLAG",
	dest = "comp_cov",
	action = 'store_true',
	default = False,
	help = "Compute coverage histograms (default = False)"
	)

(opt, args) = parser.parse_args()

# -t and -e are always required; -l is required only together with -c.
if opt.tax_file is None or opt.max_edit is None or (opt.len_file is None and opt.comp_cov):
	parser.print_help()
	exit(-1)

# Enforce the range documented for -e. A negative value would make the
# derived strata count negative and silently drop trailing strata; the
# chained comparison also rejects NaN.
if not 0.0 <= opt.max_edit <= 1.0:
	parser.error("-e must be between 0 and 1 (got %s)" % opt.max_edit)
     
######################################################### /ARGUMENTS,OPTIONS ###
################################################################################


################################################################################
### CONSTANTS ##################################################################

# Taxonomic ids passed with -x, as a set of strings; empty when -x is absent.
excl_nodes = set(opt.excluded.split(',')) if opt.excluded else set()

################################################################# /CONSTANTS ###
################################################################################


################################################################################
### MAIN #######################################################################

if __name__ == '__main__':
	# read tax info
	# tax_d maps a taxonomic id (column 0) to its tab-split table row; column 3
	# holds a ':'-separated list of ids. Rows for the ids '0' and '-' are
	# pre-seeded. excl_taxa collects every id whose column-3 list contains
	# one of the ids given with -x.
	excl_taxa = set()
	tax_d = {}
	tax_d['0'] = ['0', 'no rank', 'unmmaped', '0']
	tax_d['-'] = ['-', 'no rank', 'filtered out', '-']
	with open(opt.tax_file, 'r') as tax_file:
		for l in tax_file:
			la = l.strip().split('\t')
			tax_d[la[0]] = la
			if set(la[3].split(':')).intersection(excl_nodes):
				excl_taxa.add(la[0])

	# read len info
	# len_d maps taxon -> {sequence id -> length}; column 0 of each length
	# file is '<sequence id>_<taxon>' and column 1 the length.
	len_d = {}
	if opt.comp_cov:
		for f in opt.len_file.split(','):
			with open(f, 'r') as len_file:
				for l in len_file:
					la = l.strip().split('\t')
					sid, stax = la[0].split('_')
					if stax not in len_d:
						len_d[stax] = {}
					len_d[stax][sid] = int(la[1])

	# read map file
	array_d = {}
	rep_d = {}
	if opt.map_file != 'stdin':
		map_file = open(opt.map_file, 'r')
	else:
		map_file = stdin
	try:
		for l in map_file:
			p_rec(l, tax_d, rep_d, opt.max_edit, array_d, len_d, excl_taxa, opt.comp_cov)
	finally:
		# close only a handle opened here, never the process's stdin
		if map_file is not stdin:
			map_file.close()

	# output
	# Depth bin edges 0, 1, 2, 4, ..., 65536: 17 bins, the first of which
	# counts positions with depth 0.
	cov_bins = np.concatenate((np.array([0]), 2**np.arange(17)))

	def tax_sort_key(k):
		# Order taxa by their column-3 id list compared numerically, with
		# '-' read as 0. A real list keeps the key comparable on Python 3.
		return [int(x.replace('-', '0')) for x in tax_d[k][3].split(':')]

	for tax in sorted(rep_d.keys(), key = tax_sort_key):
		# copy the row so that building the output does not modify tax_d
		print_l = list(tax_d[tax])
		print_l += [str(x) for x in rep_d[tax][:3]]
		# The two pair counters are halved with floor division (presumably
		# because each pair is seen once per mate - confirm).
		print_l.append(str(rep_d[tax][3] // 2))
		print_l.append(str(rep_d[tax][4] // 2))
		# str() per element: no fixed-width truncation, text on Python 2 and 3
		print_l.append(':'.join([str(x) for x in rep_d[tax][5].tolist()]))
		print_l.append(str(rep_d[tax][6]))
		print_l.append(str(rep_d[tax][7]))

		cov_array = np.zeros(17, dtype=np.uint64)
		tot_tax_len = 0
		tot_tax_cov = 0

		if tax not in excl_taxa:
			if tax in len_d:
				tot_tax_len = sum(len_d[tax].values())
			if tax in array_d:
				for sid in array_d[tax]:
					# histogram counts are signed; cast them so the in-place
					# add into the uint64 accumulator is a same-kind operation
					cov_array += np.histogram(array_d[tax][sid], cov_bins)[0].astype(np.uint64)
			# plain int avoids mixing a Python int with a numpy uint64 below
			tot_tax_cov = int(cov_array[1:].sum())
			# depth-0 bin covers every known position of the taxon, including
			# sequences that received no hit at all
			cov_array[0] = tot_tax_len - tot_tax_cov

		print_l.append(str(tot_tax_cov))
		print_l.append(str(tot_tax_len))
		print_l.append(':'.join([str(x) for x in cov_array.tolist()]))

		# each row is written to both stdout and stderr
		stdout.write('\t'.join(print_l) + '\n')
		stderr.write('\t'.join(print_l) + '\n')
		
###################################################################### /MAIN ###
################################################################################
