# Finds uniquely mapped read pairs from R1, RCR1_R1, R2, and RCR2_R2 alignments.
# Written by Stephen Tran (@ Grace Xiao's lab, UCLA)

# Largest allowed distance (in bp) between the POS values of two mates; candidate
# pairings that are farther apart are rejected.
MAX_INSERT = 500000

# Number of read pairs found to have exactly one valid pairing (those pairs are
# written to stdout); reported on stderr when the script finishes.
number_pairs_uniquely_mapped = 0 #keep track of number uniquely paired reads found by the hyperediting protocol

import sys
import re
#import subprocess as sub

if len(sys.argv) != 3:
	sys.stderr.write( "--help requires arguments <R1> <R2>"+\
	"\nR1 : R1+RCR1 combined sam file"+\
	"\nR2 : R2+RCR2 combined sam file\n")
	sys.exit(1)

R1_file = sys.argv[1]
R2_file = sys.argv[2]

# A read pair is kept only if its two mates map to the same chromosome, in
# converging orientations, and within [0, MAX_INSERT] (500,000) bp of each other.

sys.stderr.write( "pairing reads from R1 and R2\n")


def _open_input_or_exit(path):
	"""Open ``path`` for reading.

	On failure, report which file could not be opened and the OS error on
	stderr, then exit with status 1 (instead of a generic message that leaves
	the user guessing which of the two inputs is the problem).
	"""
	try:
		return open(path,'r')
	except IOError as err:
		sys.stderr.write( "cannot open input file "+path+": "+str(err)+"\n")
		sys.exit(1)


# now let's get the uniquely mapped reads from paired end information
R1_fh = _open_input_or_exit(R1_file)
R2_fh = _open_input_or_exit(R2_file)


def get_all_same_reads(readname,fh,line):
	"""Collect every consecutive mapping of ``readname``, starting at ``line``.

	Lines are read from ``fh`` until the QNAME (first tab-separated field)
	changes or the file ends. Mappings that repeat an earlier location
	(same chromosome AND same position) are dropped, so e.g. R1 and RCR1
	aligning to the same place are not counted as two candidate alignments
	and cannot cancel each other out. De-duplicating on position alone would
	also merge alignments at equal coordinates on different chromosomes.

	Args:
		readname: QNAME of ``line``.
		fh: file handle positioned just after ``line``.
		line: first line of the group, already read by the caller.

	Returns:
		[lines, next_line, next_readname]: the de-duplicated lines of this
		group (first occurrence of each location kept, original order); the
		first line of the following group ('' at end of file); and that
		line's QNAME ('' at end of file).
	"""
	lines = [line]
	while True:
		next_line = fh.readline()
		next_readname = next_line.split('\t')[0]
		# '' means end of file; stop even if readname were somehow ''
		if next_line == '' or next_readname != readname:
			break
		lines.append(next_line)
	unique_lines = list()
	seen_locations = set()
	for entry in lines:
		fields = entry.split('\t')
		location = (fields[2], fields[3]) #(RNAME, POS)
		if location not in seen_locations:
			seen_locations.add(location)
			unique_lines.append(entry)
	return [unique_lines, next_line, next_readname]

def _decode_flag(flag):
	"""Return the (unmapped, reverse) bits of a SAM FLAG as '0'/'1' characters.

	``unmapped`` is bit 0x4 ('1' = the read did not align); ``reverse`` is
	bit 0x10 ('1' = aligned to the minus strand, '0' = plus strand).
	"""
	bits = format(int(flag),'#014b')[2:]
	return bits[-3], bits[-5]


def _parse_mappings(sam_lines, readname_pattern):
	"""Return (readname, chromosome, position, unmapped_bit, strand_bit) per SAM line.

	The /1 or /2 suffix matched by ``readname_pattern`` is removed from the
	QNAME so R1 and R2 names can be compared directly. Exits with status 1
	when a QNAME does not match the pattern.
	"""
	records = list()
	for line in sam_lines:
		fields = line.split('\t')
		position = int(fields[3])
		match = readname_pattern.search(fields[0])
		if not match:
			sys.stderr.write( "something wrong with regular expression for readname\n")
			sys.exit(1)
		mapped, strand = _decode_flag(fields[1])
		records.append((match.group(1), fields[2], position, mapped, strand))
	return records


def _set_mate_fields(plus_info, minus_info):
	"""Fill RNEXT, PNEXT and TLEN (SAM columns 7-9) of a convergent pair in place.

	TLEN spans from the plus-strand read's POS to the end of the minus-strand
	read, whose aligned length is approximated by len(SEQ) (CIGAR is not
	consulted). The minus-strand read receives the negated TLEN.
	"""
	plus_info[6] = minus_info[2] #mate ref
	plus_info[7] = minus_info[3] #mate position
	plus_info[8] = str(int(minus_info[3])+len(minus_info[9])-int(plus_info[3])) #insert dist
	minus_info[6] = plus_info[2]
	minus_info[7] = plus_info[3]
	minus_info[8] = "-"+plus_info[8]


def record_unique_paired_end_events(R1_lines,R2_lines):
	"""Write a read pair to stdout if exactly one R1/R2 pairing is valid.

	R1_lines and R2_lines are the SAM lines of every alignment of one read's
	mate 1 and mate 2. Pairing one R1 line with one R2 line is valid when both
	mates are mapped; they are on the same chromosome and on opposite strands;
	they converge (the plus-strand mate's POS is not greater than the
	minus-strand mate's POS); and their POS values differ by at most
	MAX_INSERT.

	If exactly one valid pairing exists, the global
	number_pairs_uniquely_mapped is incremented and both lines are written to
	stdout. FLAG is set to 99/147 (R1 on plus strand) or 83/163 (R1 on minus
	strand), and RNEXT/PNEXT/TLEN are filled in. Each written line is
	guaranteed to end in a newline, so an input file whose last line lacks
	one cannot produce two SAM records on the same output line.

	Exits with status 1 if a QNAME lacks the /1 or /2 suffix, or if the two
	mates' names (without suffix) differ.
	"""
	global number_pairs_uniquely_mapped
	readname_pattern = re.compile(r'(.*)/[12]') #we want to remove the /1 or /2 from readnames to directly compare them
	R1_records = _parse_mappings(R1_lines, readname_pattern)
	R2_records = _parse_mappings(R2_lines, readname_pattern)
	valid_pairs = list() #(i, j) indices into R1_lines / R2_lines of valid conformations
	for i, (R1_readname, R1_chromosome, R1_position, R1_mapped, R1_strand) in enumerate(R1_records):
		for j, (R2_readname, R2_chromosome, R2_position, R2_mapped, R2_strand) in enumerate(R2_records):
			if R1_readname != R2_readname:
				sys.stderr.write( "something is very wrong with your extraction of names\n")
				sys.stderr.write( R1_readname+' does not equal '+R2_readname+'\n')
				sys.exit(1)
			if R1_mapped == "1" or R2_mapped == "1": #either mate unaligned
				continue
			if R1_chromosome != R2_chromosome: #also rejects '*' (unaligned) against a real reference
				continue
			if R1_strand == R2_strand: #mates must be on opposite strands
				continue
			# 'FR' orientation: the plus-strand mate must not lie downstream of the minus-strand mate
			if R1_strand == '0' and R1_position > R2_position:
				continue
			if R2_strand == '0' and R2_position > R1_position:
				continue
			if abs(R1_position - R2_position) > MAX_INSERT:
				continue
			valid_pairs.append((i, j))
	if len(valid_pairs) != 1:
		return #no valid conformation, or several (multi-mapped): not a unique pair
	number_pairs_uniquely_mapped += 1  #found one more uniquely paired read
	i, j = valid_pairs[0]
	R1_info = R1_lines[i].split('\t')
	R2_info = R2_lines[j].split('\t')
	R1_strand = R1_records[i][4]
	R2_strand = R2_records[j][4]
	if R1_strand == "0" and R2_strand == "1": #R1 is + strand
		R1_info[1] = "99"; R2_info[1] = '147' #fixing flags for paired endedness
		_set_mate_fields(R1_info, R2_info)
	elif R2_strand == "0" and R1_strand == "1": #R2 is + strand
		R1_info[1] = '83'; R2_info[1] = '163' #fixing flags for paired endedness
		_set_mate_fields(R2_info, R1_info)
	else:
		sys.stderr.write( "cannot determine FR strand conformation of\n")
		sys.stderr.write( R1_lines[i]+'\n')
		sys.stderr.write( R2_lines[j]+'\n')
		sys.exit(1)
	for info in (R1_info, R2_info):
		read = "\t".join(info)
		if not read.endswith('\n'):
			read += '\n' #last line of an input file may lack its newline
		sys.stdout.write(read)
				

# Merge-join the two QNAME-sorted SAM files. Each pass holds one group of R1
# mappings and one group of R2 mappings (a group = consecutive lines sharing a
# QNAME). Groups with matching names (ignoring the /1 or /2 suffix) are handed
# to record_unique_paired_end_events; on a mismatch only the file whose name
# sorts lower is advanced, and the other group is held for the next pass.
# NOTE(review): assumes both files are sorted so that comparing the
# suffix-stripped names with Python's '<' reproduces their line order --
# confirm against the sort used upstream.
readname_pattern = re.compile(r'(.*)/[12]') #we want to remove the /1 or /2 from readnames to directly compare them


def _read_next_group(line, fh):
	"""Read the group of mappings that starts at ``line`` from ``fh``.

	Returns (readname, lines, next_line): the QNAME with its /1 or /2 suffix
	removed, the group's de-duplicated lines, and the first line of the
	following group ('' at end of file). Exits with status 1 if the QNAME
	has no /1 or /2 suffix.
	"""
	full_readname = line.split('\t')[0]
	match = readname_pattern.search(full_readname)
	if not match:
		sys.stderr.write( "something wrong with regular expression for readname\n")
		sys.exit(1)
	group = get_all_same_reads(full_readname,fh,line)
	return match.group(1), group[0], group[1]


still_going_through_files=True
get_next_set_R1_reads = True
get_next_set_R2_reads = True
R1_line = R1_fh.readline()
R2_line = R2_fh.readline()
counter = 0
while still_going_through_files:
	counter += 1
	if counter % 1000000 == 0:
		sys.stderr.write( str(counter)+'\n')
	# Stop only when a file we need a NEW group from is exhausted. A held
	# group stays valid even if its file has no further lines, and may still
	# match a later group of the other file (e.g. R1=[a,b], R2=[b]: R2 reaches
	# EOF while R1 still has to advance to b).
	if (get_next_set_R1_reads and R1_line == '') or (get_next_set_R2_reads and R2_line == ''):
		still_going_through_files = False
		continue
	if get_next_set_R1_reads:
		R1_readname, R1_lines, R1_line = _read_next_group(R1_line, R1_fh)
	if get_next_set_R2_reads:
		R2_readname, R2_lines, R2_line = _read_next_group(R2_line, R2_fh)
	if R1_readname == R2_readname: #same read: look for a unique paired-end conformation
		record_unique_paired_end_events(R1_lines,R2_lines)
		get_next_set_R1_reads = True
		get_next_set_R2_reads = True
	elif R1_readname < R2_readname: #R2 has moved past this name: discard the R1 group
		get_next_set_R1_reads = True
		get_next_set_R2_reads = False
	elif R1_readname > R2_readname: #R1 has moved past this name: discard the R2 group
		get_next_set_R1_reads = False
		get_next_set_R2_reads = True
	else:
		sys.stderr.write( "cannot figure out which is alphabetically greater "+R1_readname+' or '+R2_readname+'\n')
		sys.exit(1)

R1_fh.close()
R2_fh.close()

sys.stderr.write( 'number of read pairs uniquely mapped: '+str(number_pairs_uniquely_mapped)+'\n')
sys.stderr.write( 'job completed\n')
