#Arg1: point-mutation VCF (plain text or .gz)
#Arg2: tumor BAM file

#for BWA
#updates
#180503: cigar loop update: error correction for deletion spanning the alt
#180518: cigar loop multiple error correction;add var_loca_reverse;change output format; get rlength from read
#180518: add information: presence of clipping (soft or hard), INS, and DEL in ref and var reads
#180531: add information: min/max location, refNM, varNM
#180717: indels are not supported for now (NA is written in all columns)
#180830: remove indel nt count from nm count
#180903: support gzip input
#181010: print NA when variant read count was zero.
#181022: indel can be annotated
#181101: unmapped read, duplicate read -> 'blank'
#181114: correction for multiple nucleotides in mutect2 such as CTT > CT, CT>AG
#181128: deletion length calculation error correction
#181129: deletion and insertion variant read detection error correction

import sys,pysam,gzip
from numpy import median
from numpy import std
# Open the variant table (plain or gzip), the tumour BAM, and the output
# table, which is named after the input with any '.gz' dropped plus '.std'.
if len(sys.argv) < 3:
	sys.exit("usage: %s <point_mutation.vcf[.gz]> <tumor.bam>" % sys.argv[0])
vcf_path = sys.argv[1]
print(vcf_path)
is_gzipped = vcf_path.endswith('.gz')
if is_gzipped:
	# gzip.open is binary by default; on Python 3 request text mode so the
	# line handling below works on str rather than bytes.
	if sys.version_info[0] >= 3:
		in_file = gzip.open(vcf_path, 'rt')
	else:
		in_file = gzip.open(vcf_path)
else:
	in_file = open(vcf_path)
bam_file = pysam.AlignmentFile(sys.argv[2], 'rb')  # bam file
out_file = open((vcf_path[:-3] if is_gzipped else vcf_path) + '.std', 'w')



# Prime the main loop with the first input line.
in_line = in_file.readline().strip()
# Module-level counter; the main loop increments it for each variant-supporting read.
i = 0

# Base names of the added columns. Each is emitted for F1R2 and then F2R1,
# giving F1R2_var_LocaLeftSTD, F2R1_var_LocaLeftSTD, F1R2_var_LocaRightSTD, ...
sudo_head_list = ["var_LocaLeftSTD", "var_LocaRightSTD"]
head_list = [orient + "_" + head for head in sudo_head_list for orient in ("F1R2", "F2R1")]

# ---------------------------------------------------------------------------
# Per-site scan. For every data line, fetch the reads covering the variant,
# keep those whose aligned base equals the first ALT base, and report the
# standard deviation of the variant's distance from the left and right ends
# of those reads, separately for F1R2 and F2R1 read orientations.
# Only these distances feed the output columns, so nothing else is collected.
# ---------------------------------------------------------------------------

# pysam cigartuples operation codes (SAM spec).
_MATCH_OPS = (0, 7, 8)  # M, =, X: consume reference and query
_CIGAR_INS = 1          # I: consumes query only
_CIGAR_DEL = 2          # D: consumes reference only
_CIGAR_SKIP = 3         # N: consumes reference only
_CLIP_OPS = (4, 5)      # S, H

# Orientation order inside every output column group (matches head_list).
_ORIENTATIONS = ('F1R2', 'F2R1')


def _stripped_lines(first_line, handle):
	"""Yield first_line, then every remaining line of handle, stripped."""
	yield first_line
	for raw_line in handle:
		yield raw_line.strip()


def _target_query_offset(cigar_list, est_dist):
	"""Map a reference offset onto the read's aligned (clip-free) sequence.

	cigar_list: the read's pysam cigartuples.
	est_dist:   0-based offset of the target base from read.reference_start.
	Returns the 0-based index of the target base in
	read.query_alignment_sequence, or None when the target reference base
	lies inside a deletion (D) or skipped region (N) of this read.
	"""
	matched = inserted = skipped = 0  # skipped = reference bases in D/N ops
	for op, length in cigar_list:
		ref_consumed = matched + skipped
		if ref_consumed > est_dist:
			break  # all remaining ops lie past the target
		if op in _MATCH_OPS:
			matched += length
		elif op == _CIGAR_INS:
			inserted += length  # insertion before the target shifts it right
		elif op in (_CIGAR_DEL, _CIGAR_SKIP):
			if ref_consumed + length > est_dist:
				return None
			skipped += length
		# clipping/padding ops consume no reference bases: ignore them
	return est_dist + inserted - skipped


def _orientation(flag):
	"""Return 'F1R2', 'F2R1' or None from SAM flag bits 0x10-0x80.

	(flag >> 4) & 0xF packs read-reverse (1), mate-reverse (2), first-in-pair (4)
	and second-in-pair (8); this is the second-lowest hex digit of the flag.
	"""
	nibble = (flag >> 4) & 0xF
	if nibble in (0x6, 0x9):  # read1 forward / mate reverse, or read2 reverse
		return 'F1R2'
	if nibble in (0x5, 0xA):  # read1 reverse, or read2 forward / mate reverse
		return 'F2R1'
	return None


def _std_or_na(values):
	"""Population standard deviation of values as a string, or 'NA' if empty."""
	return str(std(values)) if values else 'NA'


for in_line in _stripped_lines(in_line, in_file):
	if not in_line:
		continue  # blank line: skip it instead of treating it as end of file
	if in_line.startswith('CHR'):
		# column-header line: append the new column names
		out_file.write(in_line + "\t" + "\t".join(head_list) + "\n")
		continue
	if in_line.startswith('#'):
		continue  # meta lines are not copied to the output

	in_indi = in_line.split('\t')
	chr1 = in_indi[0]
	pos1 = int(in_indi[1])
	# Only the first ALT allele is used, and only its first base is compared
	# against the read below.
	alt_nt = in_indi[4].split(',')[0].split('/')[0]

	# orientation -> {query_name: (left_dist, right_dist)} for variant reads
	var_loca = dict((orient, {}) for orient in _ORIENTATIONS)
	for read in bam_file.fetch(chr1, pos1 - 1, pos1):
		if read.is_unmapped or read.is_duplicate:
			continue
		cigar_list = read.cigartuples
		if cigar_list is None:
			continue
		est_dist = pos1 - read.reference_start - 1
		rel_dist = _target_query_offset(cigar_list, est_dist)
		if rel_dist is None:
			continue  # target base is deleted/skipped in this read
		orientation = _orientation(read.flag)
		if orientation is None:
			print("warning: low quality read")
			continue
		aligned_seq = read.query_alignment_sequence
		if not aligned_seq or rel_dist >= len(aligned_seq):
			continue  # no stored sequence (SEQ '*') or target outside it
		if aligned_seq[rel_dist] != alt_nt:
			continue  # not a variant read: does not contribute to the columns
		i += 1

		first_op, first_len = cigar_list[0]
		if first_op in _MATCH_OPS:
			dist = rel_dist
		elif first_op in _CLIP_OPS:
			dist = rel_dist + first_len  # measured from the start of the clipped read
		else:
			continue
		# infer_read_length() includes hard clips, matching the left distance,
		# which also counts a leading hard clip.
		rlength = read.infer_read_length()
		var_loca[orientation][read.query_name] = (dist, rlength - 1 - dist - 1)

	# Column order matches head_list: left STD (F1R2, F2R1), then right STD.
	fin_info = []
	for side in (0, 1):
		for orientation in _ORIENTATIONS:
			dists = [loc[side] for loc in var_loca[orientation].values()]
			fin_info.append(_std_or_na(dists))
	out_file.write(in_line + '\t' + '\t'.join(fin_info) + '\n')

in_file.close()
out_file.close()
bam_file.close()

