# Arg1: point-mutation VCF (plain text or .gz)
# Arg2: tumor BAM
# Arg3: normal BAM
# Requirement: indel realignment must already have been applied to the BAMs

#updates
#181102: read sequence reassembly made
#181106 size calculation error correction
#181106 cutoff change
#181129 indel minimal range 10 -> 20
#181203 AAC AA -> AC A ;  CC TT -> C T; counting by query_name
#181211 query_alignment_sequence -> query_sequence

import sys,pysam,gzip,collections
from numpy import median
import re
# Flank size (bases on each side of the variant) reserved for indels.
# NOTE(review): this value has no effect on the windows actually used in this
# file, which are taken from snv_seq_size -- confirm whether it is still needed.
lt_seq_size=20 # for indel
# Flank size (bases on each side of the variant) used for point mutations.
snv_seq_size=20 # for snv
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a leftover from repeat-unit detection -- confirm before removing.
max_repeat_unit_seq = 8

def read_search(t_file,chr1,pos1,read_dic,count_stat):
	"""Group the tumor reads covering one site into blocks keyed by fragment span.

	Every properly paired read overlapping the site is handed to read_type(),
	which files it in *read_dic* under "start<TAB>end<TAB>orientation".

	Args:
		t_file: tumor BAM handle providing fetch(contig, start, stop).
		chr1: contig name of the site.
		pos1: site position; reads are fetched from the interval (pos1-1, pos1).
		read_dic: dict that is filled in place, block key -> list of reads.
		count_stat: incoming status flag, returned unchanged unless the depth
			limit is hit.

	Returns:
		(count_stat, read_dic); count_stat is 1 when more than about 1000
		reads were accepted and the scan was abandoned.
	"""
	accepted=0
	for read in t_file.fetch(chr1,pos1-1,pos1):
		if not read.is_proper_pair or read.cigartuples is None or read.mate_is_unmapped:
			continue
		# The mate's reference span is derived from the MC (mate CIGAR) tag.
		# get_tag() raises KeyError when the tag is absent, so reads without
		# it are skipped like the other unusable reads instead of aborting
		# the whole run.
		if not read.has_tag("MC"):
			continue
		mate_cigar=re.findall(r'(\d+)([A-Z]{1})', read.get_tag("MC"))
		mate_m,_mate_i,mate_d,_mate_sh=real_dist_calc(mate_cigar,0,0,0,0)
		# Matched plus deleted bases of the mate, added to the mate's start
		# by read_type() to obtain the fragment end.
		rel_dist=mate_m+mate_d  # start with 0
		# Second-lowest hex digit of the flag, i.e. the bits 0x10..0x80.
		h_digit=hex(read.flag)[-2]
		# read_type() updates read_dic in place and returns it, or "pass"
		# when the flag combination is not one it classifies.
		if read_type(read,h_digit,rel_dist,read_dic)=="pass":
			continue

		# Only accepted reads count towards the depth limit.
		if accepted>1000:
			count_stat=1
			break
		accepted+=1
	return count_stat,read_dic



def real_dist_calc(MC_list,current_m,current_i,current_d,current_sh):
	"""Add up the lengths of a parsed CIGAR string, grouped by operation.

	Args:
		MC_list: iterable of (length_text, op_letter) pairs, e.g. ("76", "M").
		current_m, current_i, current_d, current_sh: starting totals for
			M, I, D and for S/H combined.

	Returns:
		Tuple (m, i, d, sh) of the updated totals.  Operations other than
		M, I, D, S and H leave every total untouched.
	"""
	totals={"m":current_m,"i":current_i,"d":current_d,"sh":current_sh}
	# S and H share one counter.
	counter_for={"M":"m","I":"i","D":"d","S":"sh","H":"sh"}
	for length_text,op_letter in MC_list:
		counter=counter_for.get(op_letter)
		if counter is not None:
			totals[counter]+=int(length_text)
	return totals["m"],totals["i"],totals["d"],totals["sh"]

def read_type(read,h_digit,rel_dist,read_dic):
	"""File *read* in *read_dic* under its fragment span and orientation.

	Args:
		read: object exposing is_read1, reference_start, reference_end and
			next_reference_start.
		h_digit: one hex character selecting the orientation case (computed
			by the caller from the read's flag).
		rel_dist: offset added to next_reference_start when the key is
			anchored on this read's own start.
		read_dic: dict of "start<TAB>end<TAB>orientation" -> list of reads;
			updated in place.

	Returns:
		read_dic itself, or the string "pass" when h_digit is not one of the
		four cases handled for this read number.
	"""
	# value = (key anchored on this read's start?, orientation label)
	if read.is_read1:
		cases={"6":(True,"F1R2"),"5":(False,"F2R1"),"4":(True,"F1R1"),"7":(False,"F2R2")}
	else:
		# NOTE(review): "8" is labelled F2R1, whereas the matching read1 case
		# "4" is labelled F1R1 -- looks like a typo; confirm before changing.
		cases={"a":(True,"F2R1"),"9":(False,"F1R2"),"8":(False,"F2R1"),"b":(True,"F2R2")}

	case=cases.get(h_digit)
	if case is None:
		return "pass"

	own_start_first,orientation=case
	if own_start_first:
		span_start=read.reference_start
		span_end=read.next_reference_start+rel_dist
	else:
		span_start=read.next_reference_start
		span_end=read.reference_end
	pos_key="\t".join([str(span_start),str(span_end),orientation])

	read_dic.setdefault(pos_key,[]).append(read)
	return read_dic


def real_dist_cigar_calc(cigar_list,est_dist,current_m,current_i,current_d,target_del_stat):
	"""Convert a reference offset from the read start into a query offset.

	Walks (op, length) CIGAR tuples until the matched plus deleted bases pass
	*est_dist*, counting insertions (op 1) and deletions (op 2) met on the way.

	Args:
		cigar_list: sequence of (op_code, length) tuples.
		est_dist: reference offset of the target from the alignment start.
		current_m, current_i, current_d: starting totals of matched,
			inserted and deleted bases.
		target_del_stat: incoming flag, returned unchanged unless the target
			falls inside a deletion.

	Returns:
		(est_dist + inserted - deleted, target_del_stat); the flag is 1 when
		a deletion spans the target offset.
	"""
	# NOTE(review): ops other than 0/1/2 (e.g. soft clips) never shift the
	# result -- confirm this matches how the caller indexes the read sequence.
	matched,inserted,deleted=current_m,current_i,current_d
	for op_code,length in cigar_list:
		if matched+deleted > est_dist:
			# Already past the target: nothing further can affect the offset.
			break
		if op_code==0:
			matched+=length
		elif op_code==1:
			inserted+=length
		elif op_code==2:
			if matched+deleted+length > est_dist:
				# The target base itself is deleted in this read.
				target_del_stat=1
				break
			deleted+=length
	return est_dist+inserted-deleted,target_del_stat


def block_search(sorted_list,new_lt_size,new_rt_size,read_dic,in_indi,n_file,alt_nt,chr1,pos1,out_file2,vcf_line):
	"""Summarise every read block of one variant site, per orientation.

	For each key in *sorted_list* the tumor reads stored under it are split
	into reads showing *alt_nt* at the site and all other reads, a consensus
	sequence around the site is assembled from the former, and the normal
	reads overlapping the site are counted by whether they contain that
	consensus.

	Every block is evaluated independently, starting from the caller's
	window sizes with empty columns.

	Args:
		sorted_list: block keys "start<TAB>end<TAB>orientation", in the order
			they should be processed.
		new_lt_size, new_rt_size: number of bases kept left / right of the
			site when assembling the consensus.
		read_dic: block key -> list of tumor reads.
		in_indi: tab-split VCF record; only in_indi[1] (the position) is read.
		n_file: normal BAM handle providing fetch(contig, start, stop).
		alt_nt: base a read must show at the site to count as a variant read.
		chr1, pos1: contig and position; normal reads are fetched from
			(pos1-1, pos1).
		out_file2: writable handle; *vcf_line* is written to it whenever a
			block's consensus columns contain a gap.
		vcf_line: text of the VCF record being processed.

	Returns:
		dict "start<TAB>end" -> list of 9 strings.  Slots 0..3 hold the
		comma-joined summary for F1R1, F1R2, F2R1 and any other orientation;
		unfilled slots stay "NA".
	"""
	out_dic={}
	# The normal reads at the site are identical for every block, so they are
	# fetched once, and only when the first block actually needs them.
	normal_reads=None
	slot_of={"F1R1":0,"F1R2":1,"F2R1":2}

	for block in sorted_list:
		# Per-block window and columns: nothing carries over between blocks.
		lt_size=new_lt_size
		rt_size=new_rt_size
		seq_dic={}
		for i in range(-lt_size, rt_size+1):
			seq_dic[i]=[]
		ref_names=set();var_names=set();ukn_names=set()

		for read in read_dic[block]:
			cigar_list=read.cigartuples
			if cigar_list is None:
				continue
			est_dist=int(in_indi[1])-read.reference_start-1
			rel_dist,target_del_stat=real_dist_cigar_calc(cigar_list,est_dist,0,0,0,0)
			if target_del_stat==1:
				# The site is deleted in this read.
				continue
			read_seq=read.query_sequence
			# No stored sequence, or an offset outside it (e.g. a CIGAR with
			# operations the offset calculation ignores): the base at the site
			# cannot be read, so skip instead of raising.
			if read_seq is None or not 0 <= rel_dist < len(read_seq):
				continue

			if read_seq[rel_dist]==alt_nt:   #var_read
				var_names.add(read.query_name)
				# Spread the bases inside the window over the columns,
				# column 0 being the site itself.
				first=max(0,rel_dist-lt_size)
				last=min(len(read_seq),rel_dist+rt_size+1)
				for n in range(first,last):
					seq_dic[n-rel_dist].append(read_seq[n])
			else:   #ref_read
				ref_names.add(read.query_name)

		# Left flank: columns must be empty, then continuously filled up to 0.
		gap_found=False
		prv_len=0;tmp_lt_size=None
		for i in range(-lt_size,1):
			col_len=len(seq_dic[i])
			if col_len==0 and prv_len!=0:
				gap_found=True
				break
			if col_len!=0 and prv_len==0:
				tmp_lt_size=abs(i)
			prv_len=col_len

		# Right flank: once a column is empty, all later ones must be too.
		tmp_rt_size=None
		if not gap_found:
			prv_len=1
			for i in range(0,rt_size+1):
				col_len=len(seq_dic[i])
				if col_len!=0 and prv_len==0:
					gap_found=True
					break
				if col_len==0:
					# NOTE(review): overwritten on every empty column, so this
					# ends as rt_size-1 whenever the last column is empty
					# rather than the last covered offset -- confirm intent.
					tmp_rt_size=i-1
				prv_len=col_len

		if gap_found:
			print('Sequence Dictionary Error')
			out_file2.write(vcf_line+"\n")
			continue

		if tmp_lt_size is not None:
			lt_size=tmp_lt_size
		if tmp_rt_size is not None:
			rt_size=tmp_rt_size

		# Most common base of every non-empty column, left to right.
		cons_seq=''.join(
			collections.Counter(seq_dic[i]).most_common(1)[0][0]
			for i in range(-lt_size, rt_size+1)
			if len(seq_dic[i])!=0
		)

		if normal_reads is None:
			normal_reads=[]
			for n_read in n_file.fetch(chr1,pos1-1,pos1):
				if n_read.is_unmapped or n_read.is_duplicate:
					continue
				n_seq=n_read.query_sequence
				if n_seq is None:
					# Without a stored sequence the consensus cannot be tested.
					continue
				normal_reads.append((n_read.query_name,n_seq))

		# Counting is by query name, so both mates of a pair count once.
		ncons_names=set();other_names=set()
		for n_name,n_seq in normal_reads:
			if cons_seq in n_seq:
				ncons_names.add(n_name)
			else:
				other_names.add(n_name)

		repeat_seq='.'; ru='.'
		read_list=[repeat_seq, str(ru), str(lt_size), str(rt_size),
			str(len(ref_names)), str(len(var_names)), str(len(ukn_names)),
			str(len(other_names)), str(len(ncons_names))]

		fields=block.split('\t')
		span_key="\t".join(fields[0:2])
		slots=out_dic.setdefault(span_key,["NA","NA","NA","NA","NA","NA","NA","NA","NA"])
		slots[slot_of.get(fields[2],3)]=','.join(read_list)
	return out_dic

def output_gen(out_dic):
	"""Render the block summaries as four tab-terminated columns.

	Args:
		out_dic: dict "start<TAB>end" -> list whose first four items are the
			per-orientation summary strings.

	Returns:
		A string of four columns, one per slot 0..3.  Within a column the
		entries of all blocks, ordered numerically by (start, end), are
		joined with ";".  Every column, including the last, ends with a tab.
	"""
	def span_of(key):
		parts=key.split('\t')
		return int(parts[0]),int(parts[1])

	ordered_keys=sorted(out_dic.keys(),key=span_of)
	columns=[";".join([out_dic[key][slot] for key in ordered_keys]) for slot in range(0,4)]
	return "".join([column+"\t" for column in columns])

# Echo the input VCF path as soon as the module is loaded.  Guarded so that
# loading the module without command-line arguments does not raise IndexError.
if len(sys.argv)>1:
	print(sys.argv[1])


def main():
	"""Annotate every record of a VCF with per-orientation read-block summaries.

	Command line: <vcf or vcf.gz> <tumor BAM> <normal BAM>.

	Writes two files next to the VCF (a trailing ".gz" is dropped from the
	name first):
		<vcf>.v4.rasm          each processed record plus four extra columns
		                       (F1R1, F1R2, F2R1, F2R2);
		<vcf>.v4.rasm.out.txt  the header line plus every record that was
		                       skipped for excessive depth or flagged with a
		                       sequence-dictionary error.
	"""
	vcf_path=sys.argv[1]
	if vcf_path.endswith('.gz'):
		# gzip.open() defaults to binary mode; the str handling below needs text.
		vcf_file=gzip.open(vcf_path,'rt')
		out_prefix=vcf_path[:-3]
	else:
		vcf_file=open(vcf_path)
		out_prefix=vcf_path

	# All five handles are closed (and the outputs flushed) on any exit path.
	with vcf_file, \
			pysam.AlignmentFile(sys.argv[2],'rb') as t_file, \
			pysam.AlignmentFile(sys.argv[3],'rb') as n_file, \
			open(out_prefix+'.v4.rasm','w') as out_file, \
			open(out_prefix+'.v4.rasm.out.txt','w') as out_file2:

		# Skip the "##" meta lines; the first other line is taken as the
		# column header and extended with the four new column titles.
		vcf_line=vcf_file.readline().strip()
		while vcf_line.startswith("##"):
			vcf_line=vcf_file.readline().rstrip()
		head="repeat_unit;ref_repeat_count;lt_seq_size;rt_seq_size;t_ref;t_var;t_unknown;n_other;n_consen"
		out_file.write(vcf_line+'\t'+"F1R1(%s)\tF1R2(%s)\tF2R1(%s)\tF2R2(%s)\n"%(head,head,head,head))
		out_file2.write(vcf_line+"\n")

		for raw_line in vcf_file:
			vcf_line=raw_line.rstrip()
			if not vcf_line:
				# A blank line must not end the run; just skip it.
				continue
			in_indi=vcf_line.split('\t')
			chr1=in_indi[0];pos1=int(in_indi[1])
			# First alternate allele only.
			alt_nt=(in_indi[4].split(',')[0]).split('/')[0]

			count_stat,read_dic=read_search(t_file,chr1,pos1,{},0)
			if count_stat==1:
				# Too many reads at this site: record it and move on.
				out_file2.write(vcf_line+"\n")
				continue

			sorted_list=sorted(read_dic.keys(),key=lambda x :(int(x.split("\t")[0]),int(x.split("\t")[1]),x.split("\t")[2]))

			out_dic=block_search(sorted_list,snv_seq_size,snv_seq_size,read_dic,in_indi,n_file,alt_nt,chr1,pos1,out_file2,vcf_line)
			# output_gen() ends every column with a tab; drop the last one.
			out_file.write(vcf_line+"\t"+(output_gen(out_dic))[:-1]+"\n")

# Run the pipeline only when this file is executed as a script.
if __name__=="__main__":
	main()
