# 22.08.30 
# Modified by LWH: the original script mishandled blocks that contain clipped reads.
# If the 5'-most position of a clipped read falls inside another block, the read is sorted into a different group by samtools sort.
# When this script is then run, that lone read erases the read information already collected for the original block.

import sys
import pysam
import re
from numpy import median
import time
from collections import OrderedDict

# Lookup table from a 0-based reference index (as a string) to a chromosome
# name; the main loop indexes it with str(read.reference_id).
# NOTE(review): this assumes the BAM header lists contigs in the order
# 1..22, X, Y -- confirm for the reference actually in use.
chrom_key = {'0':'1', '1':'2', '2':'3', '3':'4', '4':'5', '5':'6', '6':'7', '7':'8', '8':'9', '9':'10',
				 '10':'11', '11':'12', '12':'13', '13':'14', '14':'15', '15':'16', '16':'17', '17':'18', '18':'19','19':'20','20':'21','21':'22', '22':'X', '23':'Y'}

# Inverse table: chromosome name -> reference index string.
# dict.items() is used instead of the Python-2-only dict.iteritems(); it
# yields the same pairs on Python 2 and also works on Python 3.
chrom_key_reverse = {v: k for k, v in chrom_key.items()}

def overlaps(a,b):
	"""Return the signed overlap between two intervals.

	Each argument is indexed as interval[0] (start) and interval[1] (end).
	The result is the smaller end minus the larger start, so it is positive
	when the intervals share a stretch and zero or negative when they do not
	(assuming start <= end for both intervals).
	"""
	latest_start=max(a[0],b[0])
	earliest_end=min(a[1],b[1])
	return earliest_end-latest_start


def mate_dist(mc_list,order): ### 1 is end information ,  2 is start information,
	"""Measure a distance from the mate's CIGAR, given as parsed text pairs.

	mc_list is a sequence of (length, op) pairs in which both items are
	strings, e.g. ("90", "M"), as obtained by splitting an MC tag with a
	regular expression.

	order == 1 ("end information"): returns the summed lengths of all M and
	D operations plus every S operation that is not the first element
	(a leading soft clip is excluded; insertions are ignored).

	Any other order ("start information"): returns the length of a leading
	S operation, or 0 when the CIGAR does not start with one.

	The start branch previously compared the op letter with the integer 4
	and therefore always returned 0; it now inspects the op letter and adds
	the length field, mirroring what read_dist does for the read itself.
	"""
	if order==1:
		span=0
		for idx,cigar in enumerate(mc_list):
			op=cigar[1]
			if op=="M" or op=="D":
				span+=int(cigar[0])
			elif op=="S" and idx!=0:
				# only soft clips after the first operation extend the end
				span+=int(cigar[0])
		return span

	# start information: only a soft clip in first position counts
	if len(mc_list)>0 and mc_list[0][1]=="S":
		return int(mc_list[0][0])
	return 0

def read_dist(mc_list,order): ### 1 is start,2 is end.
	"""Return the soft-clipped length at one side of a read's CIGAR.

	mc_list is a sequence of (op, length) pairs with a numeric op code;
	only entries whose op code equals 4 are counted (4 is the soft-clip
	code in pysam's cigartuples encoding).

	order == 1: length of an op-4 entry in first position, else 0.
	Any other order: summed length of every op-4 entry that is not in
	first position.
	"""
	soft_clip=4
	want_first=(order==1)
	clipped=0
	for idx,cigar in enumerate(mc_list):
		if cigar[0]!=soft_clip:
			continue
		# start mode takes only the first entry, end mode only the others
		if (idx==0)==want_first:
			clipped+=int(cigar[1])
	return clipped


# Script setup: argv[1] is the input BAM, argv[2] is the output table path.
#out_file=file(sys.argv[1].split("/")[-1]+".depth.v8.txt","w")
# open() replaces the Python-2-only file() constructor; it returns the same
# kind of writable handle and also exists on Python 3.
out_file=open(sys.argv[2],"w") # modified to receive output file path by input - LWH

#vcf_line=vcf_file.readline().rstrip()
# Header line of the tab-separated output table.
out_file.write("CHR1\tPOS1\tPOS2\tPOS3\tPOS4\tCHR2\tF1R1(n,mq)\tF1R2(n,mq)\tF2R1(n,mq)\tF2R2(n,mq)\n")
bamfile=pysam.AlignmentFile(sys.argv[1],"rb")
start_time = time.time()
read_dic={} # block key -> per-read-type lists, filled by the main loop
k=0 # number of reads processed so far (used for progress output)

# The string literal below is disabled bed-file handling kept for reference;
# it is never executed. Note that it refers to sys.argv[2], which is now the
# output path.
"""
bed_file=file(sys.argv[2])
bed_line=bed_file.readline().rstrip()
bed_chr_list=[]
for i in range(0,21):
	bed_chr_list.append([])


while bed_line:
	bed_info=bed_line.split('\t')
#	print bed_info
	if "GL00" in bed_info[0]  or "MT" in bed_info[0] :
		bed_line=bed_file.readline().rstrip()
		continue
#	print bed_info
	bed_chr_list[int(bed_info[0].replace("X","20").replace("Y","21"))-1].append(bed_info[1]+"\t"+bed_info[2])
	bed_line=bed_file.readline().rstrip()
	continue
"""

pre_chr=-1 # chromosome of the previously processed read; -1 before the first read
pre_pos_key=["","","",""]
key_append = {} # list to search for pos_keys - LWH

# Tokenizer for the mate CIGAR stored in the MC tag; yields (length, op-letter)
# string pairs. Compiled once instead of on every read.
_MC_TOKEN=re.compile(r'(\d+)([A-Z]{1})')

# How the four block coordinates are assembled for a read.
_READ_FIRST=0          # POS1/POS2 from this read, POS3/POS4 from its mate
_MATE_FIRST=1          # POS1/POS2 from the mate, POS3/POS4 from this read
_READ_FIRST_CLIPPED=2  # as _READ_FIRST, but POS2 also adds the read's trailing soft clip

# Second-lowest hex digit of the SAM flag -> (layout, read-type column index).
# That digit holds the flag bits 0x10 (read reverse), 0x20 (mate reverse),
# 0x40 (first in pair) and 0x80 (second in pair).
# Column indexes follow the output header: 0=F1R1, 1=F1R2, 2=F2R1, 3=F2R2.
# NOTE(review): digit 0xb is the only read-first case that adds the trailing
# soft clip to POS2; kept as in the original -- confirm this is intended.
_ORIENTATION={
	0x4:(_READ_FIRST,0),
	0x5:(_MATE_FIRST,2),
	0x6:(_READ_FIRST,1),
	0x7:(_MATE_FIRST,3),
	0x8:(_MATE_FIRST,0),
	0x9:(_MATE_FIRST,1),
	0xa:(_READ_FIRST,2),
	0xb:(_READ_FIRST_CLIPPED,3),
}


def _block_order(block_key):
	"""Sort key for a "POS1<TAB>POS4" block key: the numeric pair (POS1, POS4)."""
	first,last=block_key.split("\t")
	return (int(first),int(last))


def _write_blocks(blocks,chrom,handle):
	"""Write one output line per collected block of chromosome `chrom`.

	blocks maps a block key to four entries (one per read type), each being
	[read names, mapping qualities, [POS1s, POS2s, POS3s, POS4s, chr codes]].
	For every block the line holds `chrom`, the integer median of each of the
	five coordinate lists pooled over all read types, and per read type the
	number of distinct read names and the median mapping quality ("NA" when
	the type has no reads).
	"""
	for block_key in sorted(blocks,key=_block_order):
		summary=[]
		merged=[[],[],[],[],[]]
		for names,mapqs,coords in blocks[block_key]:
			if len(mapqs)==0:
				mapq="NA"
			else:
				mapq=str(median(mapqs))
			summary.append(",".join([str(len(set(names))),mapq]))
			# the five coordinate lists of a type are always filled together,
			# so checking the first one is enough
			if len(coords[0])!=0:
				for slot,values in zip(merged,coords):
					slot.extend(values)
		POS=[str(int(median(values))) for values in merged]
		handle.write(chrom+"\t"+"\t".join(POS)+"\t"+"\t".join(summary)+'\n')


# Main pass over the BAM: group proper-pair reads into blocks keyed by
# (POS1, POS4) and write the blocks of a chromosome once the next one starts.
for read in bamfile:
	if not read.is_proper_pair or read.cigartuples is None or read.mate_is_unmapped:
		continue

	# modified for mm10 - LWH
	try:
		cur_chr=chrom_key[str(read.reference_id)]
	# pass the loop in case of noncanonical contig
	except KeyError:
		print(str(read.reference_id) + ' is not a canonical contig!')
		continue

	# Same digit the original took from hex(read.flag)[-2].
	orientation=_ORIENTATION.get((read.flag>>4)&0xF)
	if orientation is None:
		# Neither or both of the read1/read2 bits are set; no layout is
		# defined for that, so report the read and skip it instead of
		# failing later with an IndexError.
		print(str(read.query_name) + ' has an unexpected read1/read2 flag (' + str(read.flag) + '), skipped')
		continue
	layout,type_num=orientation

	MC_list=_MC_TOKEN.findall(read.get_tag("MC"))
	cigar_list=read.cigartuples
	mate_start=read.next_reference_start
	mate_end=mate_start+mate_dist(MC_list,1)

	if layout==_MATE_FIRST:
		positions=[
			1+mate_start-mate_dist(MC_list,2),
			mate_end,
			1+read.reference_start,
			read.reference_end+read_dist(cigar_list,2),
		]
	else:
		read_end=read.reference_end
		if layout==_READ_FIRST_CLIPPED:
			read_end=read_end+read_dist(cigar_list,2)
		positions=[
			1+read.reference_start-read_dist(cigar_list,1),
			read_end,
			1+mate_start,
			mate_end,
		]

	# Chromosome stored as a number next to the positions (X -> 23, Y -> 24).
	chrom_code=int(cur_chr.replace("X","23").replace("Y","24"))

	###to do :
	##1st and 4th match -> keep full inform.
	tmp_cur_pos_key=str(positions[0])+"\t"+str(positions[3]) # modified by LWH

	# A new chromosome starts: write out and drop everything collected for
	# the previous one before adding the current read.
	if pre_chr!=cur_chr and pre_chr!=-1:
		_write_blocks(read_dic,pre_chr,out_file)
		read_dic={}

	# Look the block up in read_dic itself (O(1)). The previous side index
	# was cleared every 10000 reads, which made an existing block look new
	# and replaced its accumulated reads with an empty entry.
	block=read_dic.get(tmp_cur_pos_key)
	if block is None:
		block=[[[],[],[[],[],[],[],[]]] for _ in range(0,4)]
		read_dic[tmp_cur_pos_key]=block

	entry=block[type_num]
	entry[0].append(read.query_name)
	entry[1].append(read.mapping_quality)
	for slot,value in zip(entry[2],positions+[chrom_code]):
		slot.append(value)

	pre_chr=cur_chr

	# progress report every 10000 processed reads
	if k%10000==0:
		print(str(pre_chr)+"\t:\t" + str(k))
	k+=1




# Write the blocks still held in read_dic (those of the last chromosome seen
# by the main loop), then close the output file.
# Keys have the form "POS1<TAB>POS4"; sort numerically by both fields (the
# earlier key used the first field twice, leaving ties unordered).
ordered_key=sorted(read_dic,key=lambda x: tuple(int(field) for field in x.split("\t")))
for key in ordered_key:
	out_list=[]
	key_list=[[],[],[],[],[]]
	for names,mapqs,coords in read_dic[key]:
		# per read type: number of distinct read names and median MAPQ
		if len(mapqs)==0:
			mapq="NA"
		else:
			mapq=str(median(mapqs))
		out_list.append(",".join([str(len(set(names))),mapq]))

		# pool the four position lists over all read types
		for j in range(0,4):
			if len(coords[j])!=0:
				key_list[j].extend(coords[j])
		# pool the chromosome codes once per type that has POS3/POS4 values;
		# adding them twice, as before, does not change the median
		if len(coords[2])!=0 or len(coords[3])!=0:
			key_list[4].extend(coords[4])

	POS=[str(int(median(values))) for values in key_list]

	#out_file.write(str(pre_chr+1).replace("23","X").replace("24","Y")+"\t"+"\t".join(POS)+"\t"+"\t".join(out_list)+'\n')
	out_file.write(pre_chr+"\t"+"\t".join(POS)+"\t"+"\t".join(out_list)+'\n')


out_file.close()
