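#this script changes G's back to A's in mapped sam files (sam records are read from stdin and written to stdout)
#this script assumes the index file and the sam input are both sorted lexicographically by readname, in the same order (the index file was sorted at an earlier step)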
# Written by Stephen Tran (@ Grace Xiao's lab, UCLA)


import sys
import re

#exactly one command-line argument is expected: the index file. sam records themselves arrive on stdin.
if len(sys.argv) == 2:
	input_index_file = sys.argv[1]
else:
	sys.stderr.write(
		"--help : requires arguments  <input_index_file>"
		"\n\tinput_index_file : name of index file (containing key of what As were converted to Gs)\n"
	)
	sys.exit(1)

#complement table for sam SEQ characters: A/C/G/T/N, lowercase bases, the IUPAC ambiguity codes
#(R/Y, K/M, S, W, B/V, D/H), '.', '=' (base identical to the reference) and '*' (SEQ not stored).
#every entry is its own inverse, so reverse complementing twice gives back the original sequence.
_SEQ_COMPLEMENT = dict(zip(
	"ACGTRYKMSWBVDHN" "acgtrykmswbvdhn" ".=*",
	"TGCAYRMKSWVBHDN" "tgcayrmkswvbhdn" ".=*",
))

def ReverseComplement1(seq):
	"""Return the reverse complement of seq.

	seq -- a sequence string; case is preserved (a lowercase base complements to a lowercase base).

	Raises KeyError for a character that is not in the complement table above.
	"""
	return "".join([_SEQ_COMPLEMENT[base] for base in reversed(seq)])

#open the index file once up front; the handle stays open while get_positions_changed scans it forward
try:
	input_index_fh = open(input_index_file, 'r')
except IOError:
	sys.stderr.write("cannot open file %s\n" % input_index_file)
	sys.exit(1)


#Instead of loading the whole index file into memory, the main loop asks get_positions_changed for one readname at a time, and it returns the positions in which As were converted to Gs in that read. This works because the index file and the sam file are both sorted by readname at this point, so the index file can be scanned forward in step with the sam records.
#(an earlier design read the whole index into a dict: index_info[readname] -> [list of positions changed])
def get_positions_changed(index_fh,readname):
	"""Return the 0-based positions in readname where an A was converted to a G.

	The index file is read forward from the handle's current position, one line per read:
	<readname>\t<comma-separated positions>. A leading '@' on the index readname is ignored,
	and a line with no positions column (or an empty one) means the read had no As.
	When the matching line is found, the handle is rewound to the START of that line, so the
	next call for the same readname (the sam file can hold several alignments of one read)
	finds it again immediately.

	index_fh -- open text-mode file handle for the index file (must support tell/seek/readline)
	readname -- the sam readname to look up

	Returns a list of ints (empty if the read contained no As).
	If the end of the index file is reached without finding readname, writes an error to
	stderr and exits with status 1: every readname in the sam file must exist in the index.
	"""
	while True:
		line_start = index_fh.tell()
		line = index_fh.readline()
		if line == '': #end of file reached without finding the readname, which is bad
			sys.stderr.write('index file does not contain readname\n')
			sys.stderr.write(readname+'\n')
			sys.exit(1)
		#drop only the line terminator so a line without a tab (or a blank line) still parses
		fields = line.rstrip('\r\n').split('\t')
		#remove the leading @ from the index readname (harmless if there is none)
		index_readname = re.sub(r'^@','',fields[0])
		if index_readname != readname:
			continue
		index_fh.seek(line_start) #rewind to this line: the sam has many multi mappings of the same read
		positions = fields[1].strip() if len(fields) > 1 else ''
		if positions == '': #the readname never contained As in it
			return []
		return [int(p) for p in positions.split(',')]
		
#next change the G's back to A's, one sam record (read from stdin) at a time
sys.stderr.write( "changing back G's to A's in sam file\n")
for line in sys.stdin:
	line = line.strip()
	if line == '': #blank lines carry no record
		continue
	#header lines (@HD, @SQ, @PG, ...) have no read sequence; a sam readname can never start with '@'
	if line.startswith('@'):
		sys.stdout.write(line+'\n')
		continue
	line_info = line.split('\t')
	readname = line_info[0]
	flag = int(line_info[1])
	sequence = line_info[9]
	#look the read up even if its sequence is absent, so a readname missing from the index is still reported
	converted_positions = get_positions_changed(input_index_fh,readname)
	if sequence == '*': #SEQ is not stored for this record, so there is nothing to change back
		sys.stdout.write(line+'\n')
		continue
	#flag bit 16 = SEQ is reverse complemented; testing the bit allows other flags (e.g. 256 secondary) to be set too
	is_reverse = bool(flag & 16)
	#the positions are applied to the read in its original orientation, so undo the reverse complement first
	if is_reverse:
		sequence = ReverseComplement1(sequence)
	#NOTE(review): positions index SEQ as stored; if the aligner hard-clipped this read (CIGAR 'H'), SEQ is shorter than the original read and positions would be shifted -- TODO confirm upstream never hard-clips
	bases = list(sequence)
	for i in converted_positions:
		if not 0 <= i < len(bases): #a negative index would otherwise silently edit from the end
			sys.stderr.write( "".join(bases)+' has no position '+str(i)+' (0-base, sequence length '+str(len(bases))+'). readname: '+readname+'\n')
			sys.exit(1)
		if bases[i] != 'G':
			sys.stderr.write( "".join(bases)+' at position '+str(i)+' is not G (0-base). readname: '+readname+'\n')
			sys.exit(1)
		bases[i] = 'A'
	sequence = "".join(bases)
	if is_reverse:
		sequence = ReverseComplement1(sequence)
	line_info[9] = sequence
	sys.stdout.write("\t".join(line_info)+'\n')

input_index_fh.close()

sys.stderr.write( "job completed converting G's back to A's\n")
