import sys
import os
import subprocess
import re
import pysam
import math

#Positional command-line arguments
ibam=sys.argv[1]	#input BAM aligned by minimap2
rebam=sys.argv[2]		#STAR-realigned bam
realn_info=sys.argv[3]		#tab-separated table: whether each read needs 5' or 3' end realignment
obam=sys.argv[4]	#output BAM path

g_bDebug=False; #set to True to print anchor diagnostics from CheckAnchorPos
ostat=obam+".stat"	#per-read status table written next to the output BAM
nDelta=0	#not referenced in the visible code
thread=10	#not referenced in the visible code

##V6. Handle if one read has more than one split read

##Filtering is conducted in two steps procedure (why each candidate read is not updated as a result of realignment)
###	NOT_CANDIDATE
###1) filtering uninformative reads
####	FAIL_UNINFOR_UNMAPPED		
####	FAIL_UNINFOR_NOTSPLITREAD
###2) incompatible realigned result
####	FAIL_ANCHOR_NOTMATCH	
####	FAIL_ANCHOR_NOTMATCH	
####	ERR_COMBINE_CIGAR	
###3) ambiguous realigned result
####	FAIL_MULTIMAP			
###4) realignment doesn't change the aligned position
####	FAIL_POS_UNCHANGED


#Using the realign bam,
#1) confirm the realigned position overlaps with the original aligned position
#2) update CIGAR
#3) update MD tag

##Realignment changes the aligned coordinate only when the result is certain
#If one read has multiple realigned reads,
##1) choose one with higher NM (number of matched bases)
##2) if tie, don't change the annotation



#Lookup from numeric CIGAR operation code (as found in cigartuples) to its CIGAR character
dictCIGAR={0:"M",1:"I",2:"D",3:"N",4:"S",5:"H",6:"P",7:"=",8:"X",9:"B" };
NUM_LARGE=9999999999	#large sentinel value; not referenced in the visible code

class cRealnReadMgr():
	"""Bookkeeping for the STAR realignments of one query read.

	For each read end (5' / 3') it counts every alignment handed to AddRead
	and stores the distinct split alignments (CIGAR contains an N operation).
	"""
	#Class-level defaults are kept for backward compatibility; every instance
	#gets its own values in __init__.
	qname="";
	nAlignedCount5=0;	#split, nonsplit reads
	nAlignedCount3=0;
	tpStatus=None;
	arrRead_5end=[];	#only save split reads
	arrRead_3end=[];

	def __init__(self, a_qname):
		"""Create an empty manager for the read named a_qname."""
		self.qname=a_qname
		self.tpStatus=None
		self.arrRead_5end=[]
		self.arrRead_3end=[]
		#Previously assigned to throw-away locals; make the counters real instance state.
		self.nAlignedCount5=0
		self.nAlignedCount3=0

	def GetReadArr(self, a_b5end):
		"""Return the list of stored split alignments for the requested end."""
		return self.arrRead_5end if a_b5end else self.arrRead_3end

	def AddRead(self, a_read, a_b5end):
		"""Register one alignment for the 5' end (a_b5end true) or the 3' end.

		The per-end alignment counter is always incremented. The alignment is
		stored only if it is a split read and no stored alignment of the same
		end has the same reference name, start and end.

		Returns -1 when the alignment is not stored, None when it is.
		"""
		if a_b5end:
			self.nAlignedCount5+=1
		else:
			self.nAlignedCount3+=1

		#Only add split reads; a read without CIGAR cannot be split
		arrCigar=a_read.cigartuples or ()
		if not any( op[0]==3 for op in arrCigar ):	#3 = N
			return -1

		#If duplicated reads, don't add
		## The same read 1 can be multi-mapped by read 2 multi-mapping
		arrTarget=self.arrRead_5end if a_b5end else self.arrRead_3end
		for stored in arrTarget:
			if a_read.reference_name == stored.reference_name and a_read.reference_start == stored.reference_start and a_read.reference_end == stored.reference_end:
				return -1

		arrTarget.append( a_read )

	def hasAlignedReads(self, a_b5end):
		"""True if AddRead was called at least once for the requested end."""
		if a_b5end:
			return self.nAlignedCount5>0
		return self.nAlignedCount3>0

	def hasSplitReads(self, a_b5end):
		"""True if at least one split alignment is stored for the requested end."""
		if a_b5end:
			return len(self.arrRead_5end)>0
		return len(self.arrRead_3end)>0


print("Load candidate list")
#qname -> (5end flag, 3end flag); only reads where at least one end is not "X" are kept.
#The main loop treats the flag "O" as "this end needs realignment".
dictRealnInfo={}
#Use a context manager so the table is closed as soon as it has been read
with open(realn_info, "r") as fRealnInfo:
	for line in fRealnInfo:
		arr=line.rstrip("\n").split("\t")
		#Columns 6 and 7 (0-based 5 and 6) carry the 5' / 3' end flags
		if arr[5]=="X" and arr[6]=="X":
			continue
		dictRealnInfo[ arr[0] ]=(arr[5], arr[6])
print(len(dictRealnInfo))


print("Load STAR realigned bam")
#reference name -> { read name -> cRealnReadMgr }
dictRealnMgr={}
irebam=pysam.AlignmentFile( rebam, "rb" )
nIndex=0
for read in irebam.fetch():
	#Query names are "<name>;<end label>"; commas in <name> stand for "/"
	arr=read.query_name.split(";")
	strName=arr[0].replace(",","/")
	b5end=(arr[1]=="5end")

	#only read1 is considered
	if not read.is_read1:
		continue

	#Create the per-chromosome table and the per-read manager on first sight
	dictByName=dictRealnMgr.setdefault( read.reference_name, {} )
	if strName not in dictByName:
		dictByName[ strName ]=cRealnReadMgr( strName )
	dictByName[ strName ].AddRead( read, b5end )

	#Progress report every 100000 accepted reads
	nIndex+=1
	if nIndex%100000==1:
		print( "STAR reads loaded: "+ str(nIndex-1) )

print( len(dictRealnMgr) )


def GetRefDistanceFromEnd(a_cigar, a_strand, a_type, a_overhang ):
	"""Return the reference distance covered while walking a_overhang query bases into the read.

	a_cigar is a list of (operation, length) tuples. It is walked left to right
	when (a_type==1 and a_strand=="+") or (a_type==2 and a_strand=="-"),
	otherwise right to left. I adds to the consumed query bases only, D/N add to
	the distance only, M/=/X add to both (capped at the remaining overhang).
	Other operations (e.g. S, H) are ignored. The walk stops as soon as the
	consumed query bases reach a_overhang.
	"""
	bLeftToRight=(a_type==1 and a_strand=="+" ) or (a_type==2 and a_strand=="-")
	arrOps=a_cigar if bLeftToRight else list( reversed(a_cigar) )

	nRefSpan=0
	nQueryUsed=0
	for op in arrOps:
		nCode=op[0]
		if nCode==1:	#I: query only
			nQueryUsed+=op[1]
		elif nCode==2 or nCode==3:	#D, N: reference only
			nRefSpan+=op[1]
		elif nCode in (0, 7, 8):	#M, =, X: query and reference
			nStep=min(op[1], a_overhang-nQueryUsed)
			nRefSpan+=nStep
			nQueryUsed+=nStep
		if nQueryUsed>=a_overhang:
			break

	return nRefSpan
	

def GetRefPos(a_read, a_type, a_overhang):
	"""Return the reference coordinate reached after trimming a_overhang bases from one read end.

	a_type 1 means the 5' end, 2 the 3' end; any other value yields -1.
	The distance comes from GetRefDistanceFromEnd and is applied inward from
	reference_start or reference_end depending on the read's strand.
	"""
	strStrand="-" if a_read.is_reverse else "+"
	nOffset=GetRefDistanceFromEnd(a_read.cigartuples, strStrand, a_type, a_overhang)

	if a_type==1:	#5' end
		if a_read.is_reverse:
			return a_read.reference_end-nOffset
		return a_read.reference_start+nOffset

	if a_type==2:	#3' end
		if a_read.is_reverse:
			return a_read.reference_start+nOffset
		return a_read.reference_end-nOffset

	return -1


def GetCIGARstring(a_cigartuples):
	"""Render (operation, length) tuples as a CIGAR string such as "5M100N20M".

	Operation codes are translated through dictCIGAR; an unknown code raises KeyError.
	"""
	arrParts=[]
	for op in a_cigartuples:
		arrParts.append( str(op[1])+str(dictCIGAR[ op[0] ]) )
	return "".join(arrParts)


def GetPatchRealign( a_cigarrealn ):
	"""Return the CIGAR prefix that ends with the last intron (N) operation.

	The CIGAR is expected to be oriented so that the end of interest is leftmost.
	Everything after the last N is dropped; when there is no N at all the input
	list itself is returned unchanged.
	"""
	nLastIntronIndx=-1
	for nIdx, op in enumerate( a_cigarrealn ):
		if op[0]==3:	#intron
			nLastIntronIndx=nIdx

	if nLastIntronIndx<0:
		return a_cigarrealn
	return a_cigarrealn[:nLastIntronIndx+1]

def GetPatchRealignExtend( a_cigarrealn, a_qlen_dist ):
	"""Return the leftmost part of a_cigarrealn that covers a_qlen_dist query bases.

	Exception handling for very long soft clips: the anchor is located by query
	length instead of by the last intron.

	Query-consuming operations (M, I, S, =, X) are taken until a_qlen_dist bases
	are covered; the operation that crosses the limit is truncated. Operations
	that do not consume query (e.g. N, D) seen before the limit is reached are
	kept as they are. If the limit falls exactly on an operation boundary and
	the immediately following operation is N or D, that one operation is kept too.

	Fix: the boundary flag used to stay set when the next operation was not N/D,
	so a later, non-adjacent N/D was appended to the patch. The scan now stops
	at the first operation after the limit, whatever it is.
	"""
	nConsume=0
	arrCigar=[]
	bPrevCIGAR=False	#True when the last kept operation ended exactly on the limit without being split
	for op in a_cigarrealn:
		nBuffer=a_qlen_dist-nConsume

		if nBuffer>0:
			if op[0] in (0, 1, 4, 7, 8):	#M, I, S, =, X
				nRetain=min( nBuffer, op[1] )
				nConsume+=nRetain
				arrCigar.append( (op[0], nRetain) )
				if nConsume==a_qlen_dist and nRetain==op[1]:
					bPrevCIGAR=True
			else:
				arrCigar.append( op )
			continue

		#Limit reached: only an N or D directly adjacent to the boundary is included
		if bPrevCIGAR and op[0] in (3, 2):	#N, D
			arrCigar.append( op )
		break

	return arrCigar



def GetPatchRealignWIns( a_cigarrealn ):
	"""Return the CIGAR prefix ending at an insertion located after the last intron.

	Exception handling: operations after the last N (or the whole CIGAR when
	there is no N) are scanned for insertions (I). The prefix up to and
	including the LAST such insertion is returned; without any insertion the
	input list itself is returned.
	"""
	arrIntronIdx=[ k for k, op in enumerate( a_cigarrealn ) if op[0]==3 ]	#intron
	nScanFrom=arrIntronIdx[-1]+1 if arrIntronIdx else 0

	arrInsIdx=[ k for k in range( nScanFrom, len( a_cigarrealn ) ) if a_cigarrealn[k][0]==1 ]	#insertion
	if not arrInsIdx:
		return a_cigarrealn
	return a_cigarrealn[:arrInsIdx[-1]+1]



def GetPatchOrig( a_cigarread, a_qlen_realn ):
	"""Split a CIGAR into the part behind and the part within the first a_qlen_realn query bases.

	The CIGAR is expected to be oriented so that the end of interest is leftmost.
	Query-consuming operations (M, I, S, =, X) are moved to the discarded part
	until a_qlen_realn bases are used up; the operation crossing the limit is
	split between both parts. Any other operation met before the limit is
	reached is discarded whole. If the limit falls exactly on an operation
	boundary and the next operation is an intron (N), that intron is discarded too.

	Returns (retained, discarded), both lists of (operation, length) tuples.
	"""
	arrRetained=[]
	arrDropped=[]
	nUsed=0	#query bases already assigned to the discarded part
	bAtBoundary=False	#last discarded operation ended exactly on the limit
	for item in a_cigarread:
		nCode=item[0]
		nLen=item[1]
		nRemaining=a_qlen_realn-nUsed

		#Default: the whole operation is retained
		nDrop=0
		nKeep=nLen
		if nRemaining>0:
			if nCode in (0, 1, 4, 7, 8):	#M, I, S, =, X
				nDrop=min( nLen, nRemaining )
				nKeep=max( nLen-nDrop, 0 )
				nUsed+=nDrop
				if nUsed==a_qlen_realn and nKeep==0:
					bAtBoundary=True
			else:	#N
				nDrop=nLen
				nKeep=0
		elif bAtBoundary:
			#Adding intron to discarded
			if nCode==3:
				nDrop=nLen
				nKeep=0
			bAtBoundary=False

		if nDrop>0:
			arrDropped.append( (nCode, nDrop) )
		if nKeep>0:
			arrRetained.append( (nCode, nKeep) )

	return (arrRetained, arrDropped)


def GetSoftclipsize(a_read, a_b5end):
	"""Return the soft-clip length at the requested read end, or 0 if that end is not soft-clipped.

	The 5' end of a forward read and the 3' end of a reverse read correspond to
	the first CIGAR operation; the other two combinations to the last one.
	"""
	bEndIsLeftmost=( a_b5end and not a_read.is_reverse ) or ( a_b5end!=True and a_read.is_reverse )
	tpEdge=a_read.cigartuples[0] if bEndIsLeftmost else a_read.cigartuples[-1]
	if tpEdge[0]==4:	#S
		return tpEdge[1]
	return 0

def GetRefLen( a_cigars ):
	"""Return the number of reference bases spanned by the given CIGAR tuples (M, D, N, =, X)."""
	return sum( op[1] for op in a_cigars if op[0] in (0, 2, 3, 7, 8) )

def GetQLEN(a_cigars):
	"""Return the number of query bases covered by the given CIGAR tuples (M, I, S, =, X)."""
	return sum( op[1] for op in a_cigars if op[0] in (0, 1, 4, 7, 8) )

def GetQLEN_MDmissing(a_cigars):
	"""Return the summed length of the S, H, P, N and I operations in the given CIGAR tuples.

	Used by the MD recalculation code as the amount that is ignored there.
	"""
	return sum( op[1] for op in a_cigars if op[0] in (4, 5, 6, 3, 1) )	#S, H, P, N, I


def GetAnchorPos( a_read, a_bstrand_reverse, a_b5end, a_delta ):
	"""Return the reference coordinate a_delta bases inward from one side of a_read.

	a_bstrand_reverse is the strand of the ORIGINAL (minimap2) read, not of
	a_read. For the 5' end of a forward read or the 3' end of a reverse read the
	offset is counted from reference_start, otherwise back from reference_end.
	"""
	bFromLeft=( a_b5end and not a_bstrand_reverse) or (a_b5end!=True and a_bstrand_reverse)
	if not bFromLeft:
		return a_read.reference_end-a_delta
	return a_read.reference_start+a_delta



#Numeric result codes. NOTE(review): not referenced in the visible code, which
#reports these conditions with string statuses (e.g. "ERR_COMBINE_CIGAR") instead.
FAIL_CIGAR_COORDINATE=-1
ERR_COMBINE_CIGAR=-2

def GetUpdatedCIGAR(a_read, a_read_realn, a_arrCIGARPatch_orig, a_arrCIGARPatch_orig_discard, a_arrCIGARPatch_realn, a_b5end):
	"""Join the STAR patch and the retained minimap2 patch into one CIGAR.

	Both patches are oriented with the realigned end leftmost:
	    5'---realn----3'
	                  5'------orig-----3'
	a_arrCIGARPatch_realn supplies the intron structure of the realigned end,
	a_arrCIGARPatch_orig the rest of the original alignment. When the last
	operation of the STAR patch and the first operation of the original patch
	have the same type they are merged into a single operation.

	Returns (combined CIGAR tuples in reference order, query length of the STAR
	patch, GetQLEN_MDmissing of the STAR patch, GetQLEN_MDmissing of the
	discarded original patch), or ("ERR_COMBINE_CIGAR", 0, 0, 0) when the
	junction operations have an unexpected type. a_read_realn is not used.
	"""
	#After 5' end trimming, the first cigar element can be only M, I, S, =, X
	tpOrigHead=a_arrCIGARPatch_orig[0]
	if tpOrigHead[0] not in (0, 1, 4, 7, 8):
		print("ERR a_arrCIGARPatch_orig "+a_read.query_name)
		return ("ERR_COMBINE_CIGAR", 0, 0, 0)

	#After 3' end trimming, the last cigar element can be only M, I, S, =, X + N, D
	tpRealnTail=a_arrCIGARPatch_realn[-1]
	if tpRealnTail[0] not in (0, 1, 4, 7, 8, 3, 2):
		print("ERR a_arrCIGARPatch_realn "+a_read.query_name)
		return ("ERR_COMBINE_CIGAR", 0, 0, 0)

	if tpRealnTail[0]==tpOrigHead[0]:
		#Same operation on both sides of the junction: fuse them
		tpLink=(tpRealnTail[0], tpRealnTail[1]+tpOrigHead[1])
		combinedCIGAR=a_arrCIGARPatch_realn[:-1]+[tpLink]+a_arrCIGARPatch_orig[1:]
	else:
		combinedCIGAR=a_arrCIGARPatch_realn+a_arrCIGARPatch_orig

	#Patches were oriented with the realigned end leftmost; flip back when that end is the right side of the alignment
	bRealnOnLeft=( a_b5end and not a_read.is_reverse ) or ( a_b5end!=True and a_read.is_reverse )
	if not bRealnOnLeft:
		combinedCIGAR=list(reversed(combinedCIGAR))

	return (combinedCIGAR, GetQLEN( a_arrCIGARPatch_realn ), GetQLEN_MDmissing( a_arrCIGARPatch_realn ), GetQLEN_MDmissing( a_arrCIGARPatch_orig_discard) )


def GetPatchMD(a_arrMD, a_length, a_include):
	"""Split a tokenised MD tag after a_length matched/mismatched bases.

	a_arrMD is a list of MD tokens: digit strings (match runs), base strings
	(mismatches) and "^..." strings (deletions). Returns the leftmost part that
	covers a_length bases when a_include is true, otherwise the remainder.
	The token that crosses the boundary is split between both parts.

	Fix: when a mismatch token ended exactly on the boundary, the remainder was
	computed with a negative-zero slice and the whole token was duplicated into
	the right part. An empty remainder is now skipped.

	NOTE(review): deletion tokens seen before the boundary are dropped from both
	parts, as in the original code; confirm this is intended.
	"""
	arrLeft=[]
	arrRight=[]
	nConsume=0
	for token in a_arrMD:
		if nConsume>=a_length:
			arrRight.append( token )
			continue

		nBuffer=a_length-nConsume
		bDigit=token.isdigit()
		if not bDigit and token[0]=="^":	#deletion
			continue

		#match/mismatch
		nSpan=int(token) if bDigit else len(token)
		nConsume+=min( nSpan, nBuffer )

		if nConsume!=a_length:
			arrLeft.append( token )
		elif bDigit:
			arrLeft.append( str( nBuffer ) )
			arrRight.append( str( int(token)-nBuffer ) )
		else:
			arrLeft.append( token[:nBuffer] )
			strRest=token[nBuffer:]
			if strRest:	#nothing left over when the token ends exactly on the boundary
				arrRight.append( strRest )

	return arrLeft if a_include else arrRight


def ConvertCIGAR2MD( a_trimmedcigar, a_md, a_bleftmost):
	"""Return the MD tokens of a_md that correspond to a trimmed CIGAR.

	a_md is split into digit runs and non-digit tokens. Tokens are taken from
	the left (a_bleftmost true) or from the right until the base budget is used
	up; the result is always returned in left-to-right order. Match runs are
	truncated to the remaining budget, deletions ("^...") are copied without
	using budget, and each mismatch token uses one base.

	NOTE(review): the budget is GetQLEN_MDmissing(a_trimmedcigar), i.e. the
	summed S/H/P/N/I lengths of the trimmed CIGAR; confirm this is the intended
	length.

	Fixes: a mismatch now really advances the consumed count (the original
	statement "nConsume+1" had no effect); the unreachable branch for "0"
	(already handled as a digit run) is removed; the regex is a raw string.
	"""
	nConsume=0
	arrMD=[ i for i in re.split( r'(\d+)', a_md) if i!='' ]
	arrMD_trimmed=[]
	nqlen=GetQLEN_MDmissing( a_trimmedcigar )

	if not a_bleftmost:
		arrMD.reverse()

	for i in arrMD:
		nBuffer=nqlen-nConsume
		if nBuffer<=0:
			break

		if i.isdigit():	#Match run (includes the "0" placed between two mismatches)
			nMatch=min( int(i), nBuffer)
			arrMD_trimmed.append( str(nMatch) )
			nConsume+=nMatch
		elif i[0]=="^":	#Deletion
			arrMD_trimmed.append( i )
		else:	#mismatch
			arrMD_trimmed.append( i )
			nConsume+=1

			if len(i)>1:
				print("Unexpected mismatch length ")
				print( a_trimmedcigar )
				print( a_md )

	if not a_bleftmost:
		arrMD_trimmed.reverse()
	return arrMD_trimmed
	
def GetUpdatedMD( a_read, a_read_realn, a_arrCIGARPatch_orig, a_arrCIGARPatch_orig_discard, a_arrCIGARPatch_realn, a_b5end):
	"""Combine the MD tokens matching the STAR patch and the retained minimap2 patch.

	Returns a list of MD tokens, or None when either read lacks an MD tag or one
	of the two token lists comes back empty. a_arrCIGARPatch_orig_discard is not
	used.

	Fix: both return paths now yield a list (one used to return a `reversed`
	iterator).

	NOTE(review): not called anywhere in the visible code; GetUpdatedMM2read
	states that MD/NM are refreshed with samtools.
	"""
	if not a_read.has_tag("MD") or not a_read_realn.has_tag("MD"):
		return None

	#True when the realigned end is the left side of the alignment
	bRealnLeft=True if (a_b5end and not a_read.is_reverse) or (not a_b5end and a_read.is_reverse) else False
	arrMDPatch_orig=ConvertCIGAR2MD( a_arrCIGARPatch_orig, a_read.get_tag("MD"), not bRealnLeft )
	arrMDPatch_realn=ConvertCIGAR2MD( a_arrCIGARPatch_realn, a_read_realn.get_tag("MD"), bRealnLeft)

	if len(arrMDPatch_orig)<1 or len(arrMDPatch_realn)<1:
		print("ERR GetUpdatedMD"+str(len(arrMDPatch_orig))+"\t"+str(len(arrMDPatch_realn)))
		return None

	#combine MD tags
	bRealnTailIsMatch=arrMDPatch_realn[-1].isdigit()
	bOrigHeadIsMatch=arrMDPatch_orig[0].isdigit()
	if bRealnTailIsMatch and bOrigHeadIsMatch:	#match runs on both sides: fuse them
		linkMD=[str( int(arrMDPatch_realn[-1])+int(arrMDPatch_orig[0]) )]
		combinedMD=arrMDPatch_realn[:-1]+linkMD+arrMDPatch_orig[1:]
	elif not bRealnTailIsMatch and not bOrigHeadIsMatch:	#two non-match tokens need a "0" separator
		combinedMD=arrMDPatch_realn+['0']+arrMDPatch_orig
	else:
		combinedMD=arrMDPatch_realn+arrMDPatch_orig

	return combinedMD if bRealnLeft else list( reversed( combinedMD ) )

def IsAlignPosChanged( a_read, a_read_realn, a_b5end ):
	"""Return True if the realigned read places the requested end at a different coordinate.

	The chromosome is assumed to have been compared upstream. The strand of the
	original read decides which coordinate represents the end: reference_end
	for the 5' end of a reverse read or the 3' end of a forward read,
	reference_start otherwise.
	"""
	if a_b5end:
		bUseEnd=a_read.is_reverse
	else:
		bUseEnd=not a_read.is_reverse

	if bUseEnd:
		return a_read.reference_end != a_read_realn.reference_end
	return a_read.reference_start != a_read_realn.reference_start


#Open the minimap2 BAM for reading and create the output BAM using it as the header template
samfile=pysam.AlignmentFile( ibam, "rb")
osamfile=pysam.AlignmentFile( obam, "wb", template=samfile);

#Status table (<obam>.stat); the header row names the columns written by the main loop
ostatfile=open(ostat, "w+");
ostatfile.write("\t".join(["#qname", "realnresult_5end", "errmsg_5end",  "realnresult_3end", "errmsg_3end", "chrom", "start", "end", "strand" ])+"\n")
##main


def GetStatLineSuffix(a_origread, a_realn):
	"""Build twelve tab-separated fields describing the original read and its realignment.

	Fields: original chrom, strand, start, end; realigned chrom, strand, start,
	end; original CIGAR and MD; realigned CIGAR and MD. A missing MD tag and all
	realignment fields of a missing realignment (a_realn is None) are "NA".
	"""
	arrOrigPos=[ a_origread.reference_name, "-" if a_origread.is_reverse else "+", str(a_origread.reference_start), str(a_origread.reference_end) ]
	strOrigMD=a_origread.get_tag("MD") if a_origread.has_tag("MD") else "NA"

	if a_realn is None:
		arrRealnPos=["NA", "NA", "NA", "NA"]
		strRealnCIGAR="NA"
		strRealnMD="NA"
	else:
		arrRealnPos=[ a_realn.reference_name, "-" if a_realn.is_reverse else "+", str(a_realn.reference_start), str(a_realn.reference_end) ]
		strRealnCIGAR=a_realn.cigarstring
		strRealnMD=a_realn.get_tag("MD") if a_realn.has_tag("MD") else "NA"

	return "\t".join( arrOrigPos+arrRealnPos+[ a_origread.cigarstring, strOrigMD, strRealnCIGAR, strRealnMD ] )


def CheckAnchorPos(a_read, a_read_realn, a_b5end):
	"""Compare the anchor position of the minimap2 read with that of one STAR realignment.

	a_read is the original (minimap2) alignment, a_read_realn a STAR alignment of
	one end of the same read, a_b5end selects the 5' (True) or 3' end.

	Returns a 6-tuple:
	  [0] new anchor position minus old anchor position (0 means the anchors agree)
	  [1] retained part of the original CIGAR (oriented, end of interest leftmost)
	  [2] discarded part of the original CIGAR
	  [3] CIGAR patch taken from the STAR alignment
	  [4] query length used to split the original CIGAR
	  [5] absolute distance between the far end of the STAR alignment and the new anchor
	"""
	#1) Take the STAR CIGAR up to its last intron junction (see GetPatchRealign) and measure its query length
	#2) Trim LRCAGE reads by 1)
	#3) compare whether the position from 1 and 2 match
	

	#Reorder CIGAR string so that end of interest should be on leftmost position.
	arrCIGAR=a_read.cigartuples if ( a_b5end==True and not a_read.is_reverse ) or ( a_b5end!=True and a_read.is_reverse ) else list( reversed(a_read.cigartuples) );	

	##Note only the strand of minimap2 alignment matters!!!
	arrRealnCIGAR=a_read_realn.cigartuples if ( a_b5end==True and not a_read.is_reverse ) or ( a_b5end!=True and a_read.is_reverse ) else list( reversed( a_read_realn.cigartuples ));


	##Combine CIGAR of realignment and old alignment
	arrCIGARPatch_realn=GetPatchRealign( arrRealnCIGAR );
	if a_b5end and arrCIGARPatch_realn[-1][0]==4:	#If the last CIGAR is soft-clipped, exclude that to calculate QLEN
		arrCIGARPatch_realn=arrCIGARPatch_realn[:-1];
	nQLEN_CIGARPatch_realn=GetQLEN( arrCIGARPatch_realn );	

	##Split original cigar string into two groups, 1) one being discarded, 2) one being retained and used for combining with STAR realigned reads
	arr2CIGARPatch_orig=GetPatchOrig( arrCIGAR, nQLEN_CIGARPatch_realn );
	arrCIGARPatch_orig=arr2CIGARPatch_orig[0]
	arrCIGARPatch_orig_discard=arr2CIGARPatch_orig[1]
	
	##What is the position of reads after trimming by nQLEN_CIGARPatch_realn
	nDeltaOld=GetRefLen( arrCIGARPatch_orig_discard )
	nDeltaNew=GetRefLen( arrCIGARPatch_realn )
	nOldAnchorPos=GetAnchorPos( a_read, a_read.is_reverse, a_b5end, nDeltaOld )
	nNewAnchorPos=GetAnchorPos( a_read_realn, a_read.is_reverse, a_b5end, nDeltaNew )

	if g_bDebug:
		print(a_read.query_name)
		print("\t".join([str(nDeltaOld), str(nDeltaNew),str(nOldAnchorPos),str(nNewAnchorPos)]))
		print(arrCIGARPatch_orig)
		print(arrCIGARPatch_orig_discard)
		print(nQLEN_CIGARPatch_realn)

	##Exception handling
	#NOTE(review): indexes arrCIGARPatch_orig[0] without checking that the retained part is non-empty
	if arrCIGARPatch_orig[0][0]==4: #softclip
		#If you enter here, softclipped bases are very long 
		#You need to extend anchor search
		#In this case, use the soft-clipped base length to find the anchor position		

		#read with 9bp overlap with 5' end exon and insertion error at 5'end of 2'nd exon
		#inserted base is the same as the base of 5' end of 2' nd exon
		#minimap2 align it to 2nd exon with 9bp softclip and inserted bases are soft-clipped
		#STAR aligns 8bp to 1st exon and add 1 insertion to 2nd exon alignment
		
		#Length of the first operation of the oriented original CIGAR (presumably the soft clip - confirm)
		nQLEN_CIGARPatch_realn=arrCIGAR[0][1]
		arrCIGARPatch_realn=GetPatchRealignExtend( arrRealnCIGAR, nQLEN_CIGARPatch_realn );

		#Redo the split and both anchor positions with the extended patch
		arr2CIGARPatch_orig=GetPatchOrig( arrCIGAR, nQLEN_CIGARPatch_realn );
		arrCIGARPatch_orig=arr2CIGARPatch_orig[0]
		arrCIGARPatch_orig_discard=arr2CIGARPatch_orig[1]
	
		nDeltaOld=GetRefLen( arrCIGARPatch_orig_discard )
		nDeltaNew=GetRefLen( arrCIGARPatch_realn )
		nOldAnchorPos_updated=GetAnchorPos( a_read, a_read.is_reverse, a_b5end, nDeltaOld )
		nNewAnchorPos=GetAnchorPos( a_read_realn, a_read.is_reverse, a_b5end, nDeltaNew )

		nOldAnchorPos=nOldAnchorPos_updated


	#Outer coordinate of the STAR alignment on the side of the end of interest
	nFarPos=a_read_realn.reference_start if ( a_b5end==True and not a_read.is_reverse ) or ( a_b5end!=True and a_read.is_reverse ) else a_read_realn.reference_end;

	##(delta between anchor position, retained CIGAR, discarded CIGAR, retained CIGAR from STAR, qlen of retained STAR CIGAR, distance from the farthest realigned position to new anchor position)
	return (nNewAnchorPos-nOldAnchorPos, arrCIGARPatch_orig, arrCIGARPatch_orig_discard, arrCIGARPatch_realn, nQLEN_CIGARPatch_realn, abs(nFarPos-nNewAnchorPos));



def GetRealnStatus(a_read, a_b5end):
	"""Classify whether STAR produced a usable alignment for one end of a_read.

	Returns (status, a_read) with status one of:
	  "ERR"                        no STAR read on this chromosome at all
	  "FAIL_UNINFOR_UNMAPPED"      no STAR alignment for this read, or none for this end
	  "FAIL_UNINFOR_NOTSPLITREAD"  STAR alignments exist for this end but none is split
	  "PROCEED"                    at least one split STAR alignment is available
	"""
	if a_read.reference_name not in dictRealnMgr:
		return ("ERR", a_read)

	dictByName=dictRealnMgr[ a_read.reference_name ]
	if a_read.query_name not in dictByName:
		return ("FAIL_UNINFOR_UNMAPPED", a_read)

	mgr=dictByName[ a_read.query_name ]
	if not mgr.hasAlignedReads( a_b5end ):
		#Both ends may need realignment while STAR aligned only the other end
		return ("FAIL_UNINFOR_UNMAPPED", a_read)
	if not mgr.hasSplitReads( a_b5end ):
		return ("FAIL_UNINFOR_NOTSPLITREAD", a_read)
	return ("PROCEED", a_read)

def GetArrMatchAnchor(a_arrIntron, a_read, a_b5end):
	"""Return the split STAR reads whose anchor position agrees with a_read.

	Each element of the result is (STAR read, CheckAnchorPos result); only reads
	with an anchor delta of 0 are kept, in input order.
	"""
	arrChecked=[ (readRealn, CheckAnchorPos( a_read, readRealn, a_b5end )) for readRealn in a_arrIntron ]
	return [ tpPair for tpPair in arrChecked if tpPair[1][0]==0 ]
	
def GetRealnRead(a_TParrMatchAnchor, a_read, a_b5end):
	"""Pick the best anchor-matching STAR read that actually moves the read end.

	a_TParrMatchAnchor is a list of (STAR read, CheckAnchorPos result) tuples.
	Candidates whose outer coordinate of the end of interest equals that of
	a_read are skipped. Among the rest the one with the highest AS tag wins;
	ties on AS are broken by the larger template_length, and on a full tie the
	first candidate is kept. Returns the chosen tuple, or None when no candidate
	changes the position.

	Fix: the previous if/elif chain let a candidate with a LOWER AS replace the
	current best whenever its template_length was larger, and it tracked the two
	maxima independently, so the result depended on input order.
	"""
	bUseStart=(a_b5end and not a_read.is_reverse ) or (False==a_b5end and a_read.is_reverse)
	nStart_orig=a_read.reference_start if bUseStart else a_read.reference_end

	tpRealn=None
	tpBestKey=None	#(AS, template_length) of the current best candidate
	for tpCandidate in a_TParrMatchAnchor:
		readRealn=tpCandidate[0]
		nStart_new=readRealn.reference_start if bUseStart else readRealn.reference_end
		if nStart_new==nStart_orig:
			continue

		tpKey=(readRealn.get_tag("AS"), readRealn.template_length)
		if tpBestKey is None or tpKey>tpBestKey:
			tpRealn=tpCandidate
			tpBestKey=tpKey
	return tpRealn

def GetSTARrealnRead(a_read, a_b5end):
	"""Select the STAR alignment to use for realigning one end of a_read.

	Returns (status, a_read) on failure, with status "ERR", "FAIL_EMPTYREAD",
	"FAIL_EMPTYREAD_INTRON", "FAIL_ANCHOR_NOTMATCH" or "FAIL_POS_UNCHANGED".
	On success returns ("PROCEED", STAR read, retained original CIGAR,
	discarded original CIGAR, STAR CIGAR patch, patch query length), the last
	four taken from the CheckAnchorPos result of the chosen read.
	"""
	if a_read.reference_name not in dictRealnMgr:
		return ("ERR", a_read)

	dictByName=dictRealnMgr[ a_read.reference_name ]
	if a_read.query_name not in dictByName:
		return ("FAIL_EMPTYREAD", a_read)

	arrRead=dictByName[ a_read.query_name ].GetReadArr( a_b5end )
	if len( arrRead )==0:
		return ("FAIL_EMPTYREAD", a_read)

	#Interested in reads that are on the same chromosome and have intron
	arrRead_samechr=[ i for i in arrRead if i.reference_name==a_read.reference_name ]
	arrRead_intron=[ i for i in arrRead_samechr if any( op[0]==3 for op in i.cigartuples ) ]	#3 = N
	if len(arrRead_samechr)==0:
		return ("FAIL_ANCHOR_NOTMATCH", a_read)
	if len(arrRead_intron)==0:
		return ("FAIL_EMPTYREAD_INTRON", a_read)

	#Keep the STAR reads whose anchor position matches, then let GetRealnRead choose among them
	arrTPAnchorMatch=GetArrMatchAnchor( arrRead_intron, a_read, a_b5end )
	if len( arrTPAnchorMatch )==0:
		return ("FAIL_ANCHOR_NOTMATCH", a_read)

	tp_read_realn=GetRealnRead( arrTPAnchorMatch, a_read, a_b5end )
	if tp_read_realn is None:
		return ("FAIL_POS_UNCHANGED", a_read)

	tpAnchor=tp_read_realn[1]
	return ("PROCEED", tp_read_realn[0], tpAnchor[1], tpAnchor[2], tpAnchor[3], tpAnchor[4])

def GetUpdatedMM2read(a_read, a_starread, a_arrCIGARPatch_orig, a_arrCIGARPatch_orig_discard, a_arrCIGARPatch_realn, a_b5end):
	"""Apply the combined CIGAR to the minimap2 read.

	a_read is modified IN PLACE (and also returned): its CIGAR is replaced by the
	combined CIGAR, the previous CIGAR is recorded in the OC tag, and
	reference_start is moved to the STAR start when the realigned end is the
	left side of the alignment.

	Returns ("PROCEED", read) or ("ERR_COMBINE_CIGAR", a_read).

	Fix: OC is only written when the read does not carry one yet. When both ends
	are realigned the second pass used to overwrite OC with the CIGAR produced
	by the first pass, losing the original CIGAR.
	"""
	tpUpdatedCIGAR=GetUpdatedCIGAR( a_read, a_starread, a_arrCIGARPatch_orig, a_arrCIGARPatch_orig_discard, a_arrCIGARPatch_realn, a_b5end )
	if tpUpdatedCIGAR[0]=="ERR_COMBINE_CIGAR":
		return ("ERR_COMBINE_CIGAR", a_read)

	arrNewCIGAR=tpUpdatedCIGAR[ 0 ]

	#MD tag and NM tag are updated using samtools

	#TLEN	 #always 0 for single end read
	#The leftmost coordinate only moves when the realigned end is on the left side
	nNewRefPos_start=a_read.reference_start
	if (a_b5end and not a_read.is_reverse ) or (False==a_b5end and a_read.is_reverse):
		nNewRefPos_start=a_starread.reference_start

	newRead=a_read	#same object: the caller's read is updated in place
	if not newRead.has_tag("OC"):
		newRead.set_tag("OC", a_read.cigarstring)
	newRead.cigarstring=GetCIGARstring( arrNewCIGAR )
	newRead.reference_start=nNewRefPos_start
	return ("PROCEED", newRead )

def GetRealignResult(a_read, a_b5end):
	"""Run the realignment pipeline for one end of a_read.

	Steps: GetRealnStatus -> GetSTARrealnRead -> GetUpdatedMM2read. The first
	step that does not report "PROCEED" ends the pipeline and its result tuple
	is returned as is. Otherwise returns ("PASS", updated read), or the failure
	status of GetUpdatedMM2read together with the read.
	"""
	tpStatus=GetRealnStatus( a_read, a_b5end )
	if tpStatus[0]!="PROCEED":
		return tpStatus

	tpStar=GetSTARrealnRead( a_read, a_b5end )
	if tpStar[0]!="PROCEED":
		return tpStar

	#tpStar: (status, STAR read, retained orig CIGAR, discarded orig CIGAR, STAR CIGAR patch, ...)
	tpUpdated=GetUpdatedMM2read( a_read, tpStar[1], tpStar[2], tpStar[3], tpStar[4], a_b5end)
	strStatus=tpUpdated[0]
	if strStatus=="PROCEED":
		strStatus="PASS"
	return ( strStatus, tpUpdated[1])

print("Combine minimap2 and STAR realigned")
#Walk the minimap2 alignments, realign the flagged ends and write every read to the output BAM.
#ZR tag: 1 character for 5 and 3 end each. N: not candidate, O: candidate but keep original, M: candidate but multi-mapped; R: candidate and realigned
for read in samfile.fetch():
	##Read does not need realignment: pass it through unchanged.
	##No row is written to the stat table for such reads (the row that used to
	##be built here was never written, so it is no longer computed).
	if read.query_name not in dictRealnInfo:
		read.set_tag("ZR", "NN")
		osamfile.write(read)
		continue

	##Candidate read: flag "O" in the candidate table marks an end that needs realignment
	tpNewRead5=("NOT_REALN_CANDIDATE", read)
	if dictRealnInfo[ read.query_name ][0]=="O": #5' end realignment is needed
		tpNewRead5=GetRealignResult( tpNewRead5[1], True )

	#Continue with the 5'-updated read when that step succeeded. GetRealignResult
	#reports success as "PASS" (the old comparison with "PROCEED" never matched;
	#the read is updated in place, so the same object was used either way).
	tpNewRead3=("NOT_REALN_CANDIDATE", tpNewRead5[1] if tpNewRead5[0]=="PASS" else read)
	if dictRealnInfo[ read.query_name ][1]=="O": #3' end realignment is needed
		tpNewRead3=GetRealignResult( tpNewRead3[1] , False )
	newRead=tpNewRead3[1]

	##MAPQ update when multi-mapping
	if tpNewRead5[0]=="FAIL_MULTIMAP" or tpNewRead3[0]=="FAIL_MULTIMAP":
		newRead.mapping_quality=0

	strTag5="O"
	if tpNewRead5[0]=="NOT_REALN_CANDIDATE":
		strTag5="N"
	elif tpNewRead5[0]=="PASS":
		strTag5="R"
	elif tpNewRead5[0]=="FAIL_MULTIMAP":
		strTag5="M"
	strTag3="O"
	if tpNewRead3[0]=="NOT_REALN_CANDIDATE":
		strTag3="N"
	elif tpNewRead3[0]=="PASS":
		strTag3="R"
	elif tpNewRead3[0]=="FAIL_MULTIMAP":
		strTag3="M"

	strZRtag=strTag5+strTag3
	newRead.set_tag("ZR", strZRtag)
	osamfile.write( newRead )

	#Stat row. The read is modified in place by a successful realignment, so the
	#coordinates written here are the ones after the update.
	strStat5="\t".join([ strTag5, tpNewRead5[0] ])
	strStat3="\t".join([ strTag3, tpNewRead3[0] ] )
	strStat="\t".join([read.query_name, strStat5, strStat3, read.reference_name, str(read.reference_start), str(read.reference_end), "-" if read.is_reverse else "+"])+"\n"
	ostatfile.write(strStat)


irebam.close()
samfile.close()
osamfile.close()
ostatfile.close()

