#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# correct_indels.py - Sebastian Krautwurst
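'''
Correct indels in mapped reads against the reference and output the corrected reads
in FASTA format (substitutions are kept intact). Works on any SAM/BAM/CRAM mapping.
Error statistics are written to the log on stderr.

Usage: correct_indels.py <mapping.sam|bam|cram> <reference.fasta> [output.fasta]
'''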

# TODOS:
# stuff at cigar correction
# what to do with long deletions

import sys
import os
import re
import time
import pysam
from Bio import SeqIO
from statistics import median

# Command line: <mapping> <reference> [output]
args = sys.argv

# both positional arguments are mandatory
input_mapping, input_reference = args[1], args[2]

# optional third argument is the output path; False selects stdout
output_file = args[3] if len(args) > 3 else False

def log(logmsg):
	'''Write logmsg to stderr as a single timestamped LOG line.'''
	timestamp = time.asctime()
	line = ''.join(('LOG (', timestamp, ') :: ', logmsg, '\n'))
	sys.stderr.write(line)

def error(string, code=1):
	'''
	Report a fatal error on stderr and terminate the program.

	string: description of the error (str)
	code: process exit status (default 1)

	Never returns; raises SystemExit(code).
	'''
	sys.stderr.write('ERROR: ' + string + '\nExiting.\n')
	# sys.exit instead of the bare exit(): the latter is injected by the site
	# module for interactive use and is missing under `python -S` or in frozen builds
	sys.exit(code)

###############

def read_reference(reference):
	'''
	Parse the FASTA file `reference`.

	Returns a dict that maps every record id to its sequence as a plain str.
	If an id occurs more than once, the last record wins.
	'''
	return {record.id: str(record.seq) for record in SeqIO.parse(reference, 'fasta')}


def read_mapping(mapping):
	'''
	Generator over all alignment records of a SAM/BAM/CRAM file.

	mapping: path to the mapping; the format is chosen by the file name suffix
	         ('sam', 'bam' or 'cram'), any other suffix ends the program via error()

	Yields (read, all_reads) tuples, where read is the pysam alignment record and
	all_reads is the total number of records as counted by `samtools view -c`.
	The file is closed once the generator is exhausted or closed.
	'''

	# TODO: handle multiple alignments

	# validate the format first, so an unsupported file is rejected before samtools is run on it
	if mapping.endswith('bam'):
		mode = "rb"
	elif mapping.endswith('sam'):
		mode = "r"
	elif mapping.endswith('cram'):
		mode = "rc"
	else:
		error("No valid mapping format detected. File has to be SAM/BAM/CRAM.")

	all_reads = int(pysam.view('-c', mapping))

	# context manager: the handle used to stay open for the rest of the process
	with pysam.AlignmentFile(mapping, mode) as in_file:
		# until_eof=True reads the file sequentially (no index needed) and includes unmapped reads
		for read in in_file.fetch(until_eof=True):
			yield (read, all_reads)


def correct_read_cigar(read, reference_sequence, regex, min_quality='median', keep_softclipped=False, min_length_indels_to_keep=0, long_deletion_threshold=50):
	'''
	Remove all indels from a read by walking its CIGAR string along the reference,
	keeping aligned bases (matches and substitutions) of sufficient quality.

	read: alignment record; the attributes query_sequence, query_qualities, cigarstring,
	      reference_start, query_alignment_length and query_name are used
	reference_sequence: sequence (str) of the reference the read is aligned to
	regex: compiled pattern that yields (length, operation) for every CIGAR token,
	       e.g. re.compile(r'(\d+)([A-Z=])'); compile it once outside of the read loop
	min_quality: aligned bases with quality >= min_quality are taken from the read, all
	             others from the reference; 'median' uses the median quality of the read
	keep_softclipped: if True, soft-clipped bases are copied to the output (not quality checked)
	min_length_indels_to_keep: not implemented yet (TODO), currently ignored
	long_deletion_threshold: deletions of at least this length are assumed to be mapping
	                         errors; they are skipped like splice junctions and not counted

	TODO: can fill splice junctions

	Returns (corrected_read, new_mask_string, stats):
	corrected_read: read sequence without insertions and with short deletions filled from the reference
	new_mask_string: one phred+33 character per output base; '!' marks bases that were
	                 soft-clipped (unchecked) or taken from the reference
	stats: [num_ins, len_ins, num_del, len_del, avg_qual, median_qual, alignment_length, len_sub]
	'''

	cigar_string = read.cigarstring
	query_sequence = read.query_sequence

	if cigar_string is None or query_sequence is None:
		# unmapped record or alignment stored without SEQ: there is nothing to correct
		# (str(None) would otherwise be processed as the literal sequence 'None')
		log(f'No alignment or sequence for read {read.query_name}')
		return '', '', [0, 0, 0, 0, 0, 0, 0, 0]

	read_sequence = str(query_sequence)

	# alignment length: like from blast identity - total bases in the alignment (including gaps (indels))
	# start with read.query_alignment_length (which is qend - qstart) and then add the deletions
	# this way we also ignore splice junctions
	alignment_length = read.query_alignment_length

	if read.query_qualities is not None:
		quality_values = list(read.query_qualities)
	else:
		log(f'{read}')
		log(f'No quality for read {read.query_name}')
		quality_values = [0]*len(read_sequence)

	pos_in_read = 0
	pos_in_ref = read.reference_start

	# collected in lists and joined once at the end instead of repeated string concatenation
	corrected_parts = []
	mask_parts = []

	# qual stats
	if quality_values:
		avg_qual = sum(quality_values) / len(quality_values)
		median_qual = median(quality_values)
	else:
		avg_qual = median_qual = 0

	if min_quality == 'median':
		min_quality = median_qual

	# indel stats
	num_ins = 0
	num_del = 0
	len_ins = 0
	len_del = 0
	len_sub = 0

	for match in regex.finditer(cigar_string):
		length, cigar_operation = match.groups()
		length = int(length)

		if cigar_operation in ('H', 'P'):
			# hard clip / padding: consumes neither read nor reference bases
			continue

		elif cigar_operation == 'S':
			# softclipped
			if keep_softclipped:
				# note that this is not quality checked
				corrected_parts.append(read_sequence[pos_in_read : pos_in_read+length])
				mask_parts.append('!'*length)
			pos_in_read += length

		elif cigar_operation in ('M', '=', 'X'):
			# aligned bases: 'M' (match or substitution), '=' (match) and 'X' (substitution)
			read_bases = read_sequence[pos_in_read : pos_in_read+length]
			ref_bases = reference_sequence[pos_in_ref : pos_in_ref+length]
			base_quals = quality_values[pos_in_read : pos_in_read+length]

			# slices truncate silently, so check that the CIGAR fits read and reference
			if min(len(read_bases), len(ref_bases), len(base_quals)) < length:
				error(f'CIGAR of read {read.query_name} exceeds the read or reference sequence.')

			for base, ref_base, qual in zip(read_bases, ref_bases, base_quals):
				# one decision for both sequence and mask: min_quality is inclusive
				if qual >= min_quality:
					corrected_parts.append(base)
					mask_parts.append(chr(qual+33))
				else:
					corrected_parts.append(ref_base)
					mask_parts.append('!')
				# substitutions are counted on the raw read, independent of quality
				if base != ref_base:
					len_sub += 1

			pos_in_read += length
			pos_in_ref += length

		elif cigar_operation == 'I':
			# insertion, skip
			pos_in_read += length

			num_ins += 1
			len_ins += length

		elif cigar_operation == 'D':
			if length >= long_deletion_threshold:
				# LONG DELETION: assume mapping error
				# nothing is written to the output and it is not counted as a deletion
				pos_in_ref += length

			else:
				# deletion, fill from reference
				corrected_parts.append(reference_sequence[pos_in_ref : pos_in_ref+length])
				mask_parts.append('!'*length)
				pos_in_ref += length

				num_del += 1
				len_del += length

				# increase alignment length
				alignment_length += length

		elif cigar_operation == 'N':
			# splice junction, no correction
			pos_in_ref += length

		else:
			error(f'Unknown cigar operation: {cigar_operation}')

	corrected_read = ''.join(corrected_parts)
	new_mask_string = ''.join(mask_parts)

	return corrected_read, new_mask_string, [num_ins, len_ins, num_del, len_del, avg_qual, median_qual, alignment_length, len_sub]


##################

if __name__ == '__main__':
	# Script entry point: correct every read of the mapping once, write the corrected
	# reads as FASTA and log the accumulated error statistics.

	log('Started correct_indels.py')
	log('Reading in reference.')
	chromos = read_reference(input_reference)

	# One CIGAR token is a run length followed by an operation character.
	# '=' (sequence match) is not a letter and has to be listed explicitly, otherwise
	# such runs are skipped silently and all following positions are shifted.
	regex = re.compile(r'(\d+)([A-Z=])')

	if output_file:
		out_fh = open(output_file,'w')
		log('Output goes to '+output_file+'.')
	else:
		out_fh = sys.stdout
		log('Output goes to stdout.')

	indelsums = [0,0,0,0] # [num_ins, len_ins, num_del, len_del]
	len_sub_total = 0
	avg_qual_list = []
	median_qual_list = []
	read_len_sum = 0

	log('Minimum quality: Read median quality')
	log('Iterating over alignments/reads.')

	ctr = 0
	readset = set()
	try:
		for read, all_reads in read_mapping(input_mapping):

			ctr += 1
			if ctr == int(all_reads) or ctr % 1000 == 0:
				log(str(ctr)+' of '+str(all_reads)+' alignments corrected.')

			# unmapped reads have no reference or CIGAR; secondary and supplementary
			# records are additional alignments of a read and must not replace the
			# primary one just because they come first in the file
			if read.is_unmapped or read.is_secondary or read.is_supplementary:
				continue

			read_name = read.query_name

			# skip multiple alignments
			if read_name in readset:
				continue
			readset.add(read_name)

			# get reference sequence
			ref_sequence = chromos.get(read.reference_name)
			if ref_sequence is None:
				error(f'Reference sequence {read.reference_name} of read {read_name} not found in {input_reference}.')

			corrected_read, new_mask_string, stats = correct_read_cigar(read, ref_sequence, regex)

			for j in range(4):
				indelsums[j] += stats[j]

			avg_qual_list.append(stats[4])
			median_qual_list.append(stats[5])

			read_len_sum += stats[6]

			len_sub_total += stats[7]

			out_fh.write('>' + read_name + '\n' + corrected_read + '\n')
	finally:
		# flush and close the output file even if the run is aborted; never close stdout
		if out_fh is not sys.stdout:
			out_fh.close()

	# average over the reads that were actually corrected, not over all alignment records
	reads_corrected = len(avg_qual_list)

	log(f'Total nucleotides:\t{read_len_sum}')
	if reads_corrected == 0 or read_len_sum == 0:
		# avoids a division by zero for empty mappings
		log('No aligned nucleotides found, skipping error statistics.')
	else:
		log(f'Mean quality:\t{sum(avg_qual_list)/reads_corrected:.6f}')
		log(f'Insertions:\t{indelsums[1]/read_len_sum*100:.6f}%')
		log(f'Deletions:\t{indelsums[3]/read_len_sum*100:.6f}%')
		log(f'Substitut-s:\t{len_sub_total/read_len_sum*100:.6f}%')
		log(f'Total errors:\t{(indelsums[1] + indelsums[3] + len_sub_total)/read_len_sum*100:.6f}%')

	log('All done!')

