import sys
import pysam
import re
from numpy import median
import time


# Annotate each VCF record with per-orientation read-fragment counts.
#
# Usage: python <this script> <input.vcf> <input.bam>
#
# Writes <input.vcf>.vaf_info.v2.txt containing the VCF column-header line
# and every record, each followed by four tab-separated columns
# (F1R1, F1R2, F2R1, F2R2). Each column is a ';'-joined list with one entry
# per fragment span (start, end) found at the record's position, sorted by
# start then end: the number of distinct read names in that span's merged
# block, or "NA" when no fragment of that orientation has that span.

# One CIGAR operation in an MC (mate CIGAR) tag string, e.g. "50M" or "3=".
_CIGAR_RE = re.compile(r"(\d+)([MIDNSHP=X])")
# CIGAR operations that consume reference bases (SAM spec): M, D, N, =, X.
_REF_CONSUMING_OPS = frozenset("MDN=X")
# pysam cigartuples codes: 2 = D; 0 = M, 3 = N, 7 = '=', 8 = X also advance
# along the reference. I (1), S (4), H (5) and P (6) do not.
_CIGAR_DEL = 2
_REF_ADVANCING_CODES = frozenset((0, 3, 7, 8))

# Key: FLAG bits 0x10-0x80 as one nibble, i.e. (flag >> 4) & 0xF, where
# 0x1 = read reverse, 0x2 = mate reverse, 0x4 = first in pair,
# 0x8 = second in pair.
# Value: (orientation label, anchor). anchor True -> the fragment span is
# (this read's start, mate's end); False -> (mate's start, this read's end).
# Any other nibble (both or neither pair bit set) is ignored.
_ORIENTATION_BY_NIBBLE = {
    0x4: ("F1R1", True),   # read1 fwd, mate fwd
    0x5: ("F2R1", False),  # read1 rev, mate fwd
    0x6: ("F1R2", True),   # read1 fwd, mate rev
    0x7: ("F2R2", False),  # read1 rev, mate rev
    0x8: ("F1R1", False),  # read2 fwd, mate fwd
    0x9: ("F1R2", False),  # read2 rev, mate fwd
    0xA: ("F2R1", True),   # read2 fwd, mate rev
    0xB: ("F2R2", True),   # read2 rev, mate rev
}
# Output column of each orientation label.
_COLUMN_INDEX = {"F1R1": 0, "F1R2": 1, "F2R1": 2, "F2R2": 3}
_HEADER_SUFFIX = "F1R1(tot_n)\tF1R2(tot_n)\tF2R1(tot_n)\tF2R2(tot_n)\n"


def classify_variant(ref_nt, alt_nt):
    """Classify a REF/ALT allele pair and trim it to its anchoring bases.

    Returns (mttype, ref_nt, alt_nt) with mttype "snv", "del" or "ins":
    equal-length alleles are reduced to their first bases; for a deletion
    the ALT becomes its first base and the REF loses len(ALT) - 1 trailing
    bases (symmetrically for an insertion).

    Raises ValueError when either allele is empty.
    """
    if not ref_nt or not alt_nt:
        raise ValueError("cannot classify empty allele")
    if len(ref_nt) == len(alt_nt):
        return "snv", ref_nt[0], alt_nt[0]
    if len(ref_nt) > len(alt_nt):
        trim = len(alt_nt) - 1
        return "del", ref_nt[:len(ref_nt) - trim], alt_nt[0]
    trim = len(ref_nt) - 1
    return "ins", ref_nt[0], alt_nt[:len(alt_nt) - trim]


def mate_reference_length(mc_cigar):
    """Return the number of reference bases spanned by a CIGAR string."""
    return sum(int(length) for length, op in _CIGAR_RE.findall(mc_cigar)
               if op in _REF_CONSUMING_OPS)


def fragment_key(flag, read_start, read_end, mate_start, mate_ref_len):
    """Return (span_start, span_end, orientation) for one read, or None.

    The span is built from this read's and its mate's coordinates as
    described by _ORIENTATION_BY_NIBBLE; both reads of a pair map to the
    same key. None means the flag has no recognised orientation.
    """
    entry = _ORIENTATION_BY_NIBBLE.get((flag >> 4) & 0xF)
    if entry is None:
        return None
    orientation, anchored_on_read = entry
    if anchored_on_read:
        return (read_start, mate_start + mate_ref_len, orientation)
    return (mate_start, read_end, orientation)


def deletion_at_offset(cigartuples, offset):
    """Return True if a CIGAR deletion covers reference `offset`.

    `offset` is 0-based and relative to the read's reference_start.
    Only D operations count as covering; the scan stops once the
    reference position passes `offset`.
    """
    ref_pos = 0
    for op, length in cigartuples:
        if ref_pos > offset:
            break
        if op == _CIGAR_DEL:
            if ref_pos + length > offset:
                return True
            ref_pos += length
        elif op in _REF_ADVANCING_CODES:
            ref_pos += length
    return False


def count_fragments(reads, pos1):
    """Count distinct read names among `reads` not deleted at 1-based `pos1`."""
    names = set()
    for read in reads:
        # pos1 is 1-based; reference_start is 0-based.
        if deletion_at_offset(read.cigartuples,
                              pos1 - read.reference_start - 1):
            continue
        names.add(read.query_name)
    return len(names)


def merge_nearby_blocks(sorted_keys, reads_by_key, tolerance=5):
    """Merge fragment blocks whose spans nearly coincide.

    `sorted_keys` must be sorted (start, end, orientation) tuples. Each block
    is compared only with the previous block of the same orientation. It
    joins that block's group when they share a start and their ends differ
    by at most `tolerance`, or share an end and their starts differ by at
    most `tolerance`.

    Returns {key: reads of the key's whole group}. Members of one group
    share the same list object; `reads_by_key` is not modified.
    """
    groups = {}
    previous = {}  # orientation -> last key seen with that orientation
    for key in sorted_keys:
        start, end, orientation = key
        prev = previous.get(orientation)
        if prev is not None and (
                (prev[0] == start and abs(prev[1] - end) <= tolerance)
                or (prev[1] == end and abs(prev[0] - start) <= tolerance)):
            group = groups[prev]
            group.extend(reads_by_key[key])
        else:
            group = list(reads_by_key[key])
        groups[key] = group
        previous[orientation] = key
    return groups


def fragment_counts(bamfile, chrom, pos1):
    """Return {(start, end, orientation): fragment count} at chrom:pos1.

    Only proper pairs with a mapped mate and a CIGAR are used. Every read
    must carry an MC tag; get_tag raises KeyError otherwise.
    """
    reads_by_key = {}
    for read in bamfile.fetch(chrom, pos1 - 1, pos1):
        if (not read.is_proper_pair or read.cigartuples is None
                or read.mate_is_unmapped):
            continue
        key = fragment_key(read.flag, read.reference_start,
                           read.reference_end, read.next_reference_start,
                           mate_reference_length(read.get_tag("MC")))
        if key is None:
            continue
        reads_by_key.setdefault(key, []).append(read)
    groups = merge_nearby_blocks(sorted(reads_by_key), reads_by_key)
    return {key: count_fragments(reads, pos1)
            for key, reads in groups.items()}


def format_count_columns(counts):
    """Render fragment counts as the four tab-separated output columns."""
    rows = {}
    for (start, end, orientation), count in counts.items():
        row = rows.setdefault((start, end), ["NA", "NA", "NA", "NA"])
        row[_COLUMN_INDEX[orientation]] = str(count)
    columns = [[], [], [], []]
    for span in sorted(rows):
        for column, value in zip(columns, rows[span]):
            column.append(value)
    return "\t".join(";".join(column) for column in columns)


def annotate_record(vcf_line, bamfile):
    """Return `vcf_line` plus its four count columns and a newline.

    Exits the process with status 1 when the alleles cannot be classified.
    """
    fields = vcf_line.split("\t")
    chrom, pos1 = fields[0], int(fields[1])
    try:
        # Only validates the alleles; the trimmed forms are not used yet.
        classify_variant(fields[3], fields[4])
    except ValueError:
        print('ERROR: unknown mutation type. exiting')
        print(vcf_line)
        sys.exit(1)
    return (vcf_line + "\t"
            + format_count_columns(fragment_counts(bamfile, chrom, pos1))
            + "\n")


def main(argv=None):
    """Run the script: argv[1] is the input VCF, argv[2] the indexed BAM."""
    if argv is None:
        argv = sys.argv
    vcf_path, bam_path = argv[1], argv[2]
    with open(vcf_path) as vcf_file:
        with open(vcf_path + ".vaf_info.v2.txt", "w") as out_file:
            # Skip "##" meta lines; the next line is the column header.
            header = vcf_file.readline().rstrip()
            while header.startswith("##"):
                header = vcf_file.readline().rstrip()
            out_file.write(header + "\t" + _HEADER_SUFFIX)
            with pysam.AlignmentFile(bam_path, "rb") as bamfile:
                start_time = time.time()
                for raw_line in iter(vcf_file.readline, ""):
                    vcf_line = raw_line.rstrip()
                    if not vcf_line:
                        # Skip blank lines instead of stopping at them.
                        continue
                    out_file.write(annotate_record(vcf_line, bamfile))
    print("--- %s seconds ---" % (time.time() - start_time))


if __name__ == "__main__":
    main()
