import sys
import pysam

#Command-line arguments: seven positional values, all required.
#Fail with a usage line instead of a bare IndexError when some are missing.
if len(sys.argv)<8:
	sys.exit("Usage: "+sys.argv[0]+" <in.bam> <outer_exon> <5exon_innerSA> <5exon_upperintron> <3exon_innerSA> <3exon_upperintron> <out.tsv>")

ibam=sys.argv[1]	#alignment file, opened by pysam in "rb" mode
iexon=sys.argv[2]	#outer exon only; tab-separated, column 4 is the read name, column 7 is "5end"/"3end"/"single"
i5exon_innerSA=sys.argv[3]	#Distance from TSS to SA (tab-separated, signed distance in column 7)
i5exon_upperintron=sys.argv[4]	#5' end to intron distances (tab-separated, signed distance in column 7)
i3exon_innerSA=sys.argv[5]	#3' counterpart of i5exon_innerSA
i3exon_upperintron=sys.argv[6]	#3' counterpart of i5exon_upperintron
ofpath=sys.argv[7]	#output table, overwritten if it exists

##v2 was too slow
##v3: membership tests use the dictionary itself instead of dictionary.keys(), since dictionary.keys() is O(n) while the dictionary is O(1)
#### v2 took 30 minutes for the LRCAGE.BR1 data, while v3 took 0.5 minutes

##Criteria for realignment
### A read protrudes from a splice junction when 1) the aligned position itself is protruding, or 2) it is protruding after accounting for soft-clipped bases
##For 5' end realignment
### For 1), need the distance to downstream non-5' exons; the neighbor exon should be located within the subject exon
### For 2), need the distance to upstream introns
##For 3' end realignment
### For 1), need the distance to upstream non-3' exons; the neighbor exon should be located within the subject exon
### For 2), need the distance to downstream introns

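##Output (tab-separated; only reads with at least one non-zero value are written)
#readname	
#5end_protrud		: length of protruding bases at the 5' end
#5end_protrudbysc	: length of protruding bases at the 5' end once soft-clipped bases are counted
#3end_protrud		: length of protruding bases at the 3' end
#3end_protrudbysc	: length of protruding bases at the 3' end once soft-clipped bases are counted
#Need5endRealn		: not written by the current version
#Need3endRealn		: not written by the current version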


#Sentinel stored in place of a distance when no distance is available (9999999999)
NUM_LARGE = 10**10 - 1

def LoadBorderExon(a_strExonPath=None, a_str5ExonPath=None, a_str3ExonPath=None, a_nNoInfo=None):
	"""Build, for every read, the distance from its border exons to the closest SA.

	All three inputs are tab-separated with at least seven columns. An exon is
	identified by "chrom_start_end_strand" built from columns 1, 2, 3 and 6.

	a_strExonPath	: per-read exon file (column 4: read name, column 7: "5end",
			  "3end" or "single"); defaults to the global iexon
	a_str5ExonPath	: 5' exon file (column 7: signed distance); defaults to the
			  global i5exon_innerSA
	a_str3ExonPath	: 3' exon file (column 7: signed distance); defaults to the
			  global i3exon_innerSA
	a_nNoInfo	: placeholder stored when an end has no distance; defaults to
			  the global NUM_LARGE

	Returns a nested dictionary: chrom -> read name -> (5' distance, 3' distance).
	Blank lines are skipped; every input file is closed before returning.
	"""
	def _LoadClosestSA(a_strPath):
		#chrom -> exon ID -> signed distance with the smallest absolute value
		dictExon={}
		with open(a_strPath, "r") as ifile:
			for line in ifile:
				if not line.strip():
					continue
				arr=line.rstrip("\n").split("\t")
				strID="_".join([arr[0], arr[1], arr[2], arr[5]])
				nDist=int(arr[6])
				dictChrom=dictExon.setdefault(arr[0], {})
				#only min value is stored; on a tie the first one seen is kept
				if strID not in dictChrom or abs(nDist) < abs(dictChrom[strID]):
					dictChrom[strID]=nDist
		return dictExon

	#Fall back to the command-line globals only for values the caller left out
	if a_strExonPath is None:
		a_strExonPath=iexon
	if a_str5ExonPath is None:
		a_str5ExonPath=i5exon_innerSA
	if a_str3ExonPath is None:
		a_str3ExonPath=i3exon_innerSA
	if a_nNoInfo is None:
		a_nNoInfo=NUM_LARGE

	print("Load 5exon")
	dict5Exon=_LoadClosestSA(a_str5ExonPath)	#For each 5' exon, what is the closest SA

	print("Load 3exon")
	dict3Exon=_LoadClosestSA(a_str3ExonPath)	#For each 3' exon, what is the closest SA

	print("Combine two exon info")
	dictExonPerRead={}	#dictExonPerRead[ chrom ][ read_name ]=(5end, 3end)
	with open(a_strExonPath, "r") as ifile:
		for line in ifile:
			if not line.strip():
				continue
			arr=line.rstrip("\n").split("\t")
			strID="_".join([arr[0], arr[1], arr[2], arr[5]])

			dict5Chrom=dict5Exon.get(arr[0], {})
			dict3Chrom=dict3Exon.get(arr[0], {})
			bIn5=strID in dict5Chrom
			bIn3=strID in dict3Chrom
			if not bIn5 and not bIn3:
				continue

			#arr[3]: query read name. The entry is created even when neither
			#end is updated below, so such a read keeps both placeholders.
			dictReads=dictExonPerRead.setdefault(arr[0], {})
			n5Dist, n3Dist=dictReads.get(arr[3], (a_nNoInfo, a_nNoInfo))

			if bIn5 and arr[6] in ("5end", "single"):
				n5Dist=dict5Chrom[strID]
			if bIn3 and arr[6] in ("3end", "single"):
				n3Dist=dict3Chrom[strID]

			dictReads[arr[3]]=(n5Dist, n3Dist)

	return dictExonPerRead

def LoadIntron(a_bCTSS, a_strPath=None):
	"""Load the distance to the closest intron for every position in a file.

	The input is tab-separated with at least seven columns. A position is
	identified by "chrom_start_end_strand" built from columns 1, 2, 3 and 6;
	column 7 is the signed distance. When a position occurs more than once,
	the distance with the smallest absolute value is kept (first one on a tie).

	a_bCTSS		: True selects the global i5exon_upperintron, False the global
			  i3exon_upperintron (only consulted when a_strPath is None)
	a_strPath	: explicit file to read instead of the globals

	Returns a nested dictionary: chrom -> position ID -> signed distance.
	Blank lines are skipped; the file is closed before returning.
	"""
	if a_strPath is None:
		a_strPath=i5exon_upperintron if a_bCTSS else i3exon_upperintron

	dictIntron={}
	print("Load Intron")
	with open(a_strPath, "r") as ifile:
		for line in ifile:
			if not line.strip():
				continue
			arr=line.rstrip("\n").split("\t")
			strID="_".join([arr[0], arr[1], arr[2], arr[5]])
			nDist=int(arr[6])

			dictChrom=dictIntron.setdefault(arr[0], {})
			#only min value is stored
			if strID not in dictChrom or abs(nDist) < abs(dictChrom[strID]):
				dictChrom[strID]=nDist
	return dictIntron

#Lookup tables, all nested dictionaries keyed by chromosome first
dictExonPerRead = LoadBorderExon()		#chrom -> read name -> (5end, 3end)
dictIntron5 = LoadIntron(a_bCTSS=True)		#read from i5exon_upperintron
dictIntron3 = LoadIntron(a_bCTSS=False)	#read from i3exon_upperintron

#for key in dictExonPerRead.keys():
#	print("\t".join( [ key, str(dictExonPerRead[key][0]), str(dictExonPerRead[key][1])] ));
#for key in dictIntron5.keys():
#	print("\t".join( [ key, str(dictIntron5[key]) ]) );
#for key in dictIntron3.keys():
#	print("\t".join( [ key, str(dictIntron3[key]) ]) );

def CompareSoftclipToIntronDist( a_read, a_upperintron, a_bCTSS ):
	"""Pair the absolute intron distance with the soft-clip length at one read end.

	a_read		: object exposing is_reverse and cigartuples (a list of
			  (operation, length) pairs, or None)
	a_upperintron	: signed distance to the intron; only its absolute value is used
	a_bCTSS		: True for the 5' end of the read, False for the 3' end

	Returns (abs(a_upperintron), soft-clip length) when the requested end of the
	read is soft-clipped, otherwise None. None is also returned when the read
	carries no CIGAR at all.
	"""
	nSoftClipOp=4	#CIGAR operation code tested for soft clipping

	lstCigar=a_read.cigartuples
	if not lstCigar:
		return None

	#The 5' end is the first CIGAR entry on the forward strand and the last one
	#on the reverse strand; the 3' end is the opposite.
	bUseFirst=bool(a_bCTSS) != bool(a_read.is_reverse)
	nOp, nLen=lstCigar[0] if bUseFirst else lstCigar[-1]
	if nOp != nSoftClipOp:
		return None
	return (abs(a_upperintron), nLen)
		
def GetSoftclipLength( a_read , a_bCTSS ):
	"""Return the number of soft-clipped bases at one end of a read.

	a_read	: object exposing is_reverse and cigartuples (a list of
		  (operation, length) pairs, or None)
	a_bCTSS	: True for the 5' end of the read, False for the 3' end

	Returns the soft-clip length, or 0 when that end is not soft-clipped or the
	read carries no CIGAR at all.
	"""
	nSoftClipOp=4	#CIGAR operation code tested for soft clipping

	lstCigar=a_read.cigartuples
	if not lstCigar:
		return 0

	#The 5' end is the first CIGAR entry on the forward strand and the last one
	#on the reverse strand; the 3' end is the opposite.
	bUseFirst=bool(a_bCTSS) != bool(a_read.is_reverse)
	nOp, nLen=lstCigar[0] if bUseFirst else lstCigar[-1]
	return nLen if nOp == nSoftClipOp else 0



#Scan the BAM and write one line per read that has a protruding 5' or 3' end.
#Both files are managed by "with" so they are closed even if the scan fails.
with open(ofpath, "w+") as ofile:
	ofile.write("\t".join(["#readname", "5end_protrud", "5end_protrudbysc", "3end_protrud", "3end_protrudbysc"])+"\n")
	with pysam.AlignmentFile( ibam, "rb") as ibamfile:
		print("Finding realign candidates")
		for read in ibamfile.fetch():
			strChrom=read.reference_name

			#NOTE(review): a chromosome missing from any one of the three tables
			#skips the read entirely, even if the other tables have info - confirm
			#this is intended.
			if strChrom not in dictExonPerRead or strChrom not in dictIntron5 or strChrom not in dictIntron3:
				continue

			#1bp IDs of the leftmost and rightmost aligned positions
			strStrand="-" if read.is_reverse else "+"
			strID1="_".join( [ strChrom, str(read.reference_start), str(read.reference_start+1), strStrand ])
			strID2="_".join( [ strChrom, str(read.reference_end-1), str(read.reference_end), strStrand ])

			#The read's 5' end (CTSS) is the leftmost position on the forward
			#strand and the rightmost on the reverse strand; CTES is the opposite.
			strCTSSID=strID2 if read.is_reverse else strID1
			strCTESID=strID1 if read.is_reverse else strID2

			tupExonDist=dictExonPerRead[strChrom].get(read.query_name)
			nIntron5=dictIntron5[strChrom].get(strCTSSID)
			nIntron3=dictIntron3[strChrom].get(strCTESID)
			if tupExonDist is None and nIntron5 is None and nIntron3 is None:
				continue

			#Protruding?  NUM_LARGE marks "no distance" and counts as 0
			n5endProtruding=0
			n3endProtruding=0
			if tupExonDist is not None:
				n5endProtruding=abs(tupExonDist[0])
				n3endProtruding=abs(tupExonDist[1])
				if n5endProtruding==NUM_LARGE:
					n5endProtruding=0
				if n3endProtruding==NUM_LARGE:
					n3endProtruding=0

			#Protruding with softclip?
			##While CTES, CTSS are located within the exon, the intron is 1bp away from the exon border
			##For that reason even a CTSS/CTES right next to the intron border has distance 1
			##This is adjusted by subtracting 1bp
			n5endProtrudingSC=0
			if nIntron5 is not None:
				n5endProtrudingSC=max( GetSoftclipLength(read, True)-(abs(nIntron5)-1), 0)
			n3endProtrudingSC=0
			if nIntron3 is not None:
				n3endProtrudingSC=max( GetSoftclipLength(read, False)-(abs(nIntron3)-1), 0)

			if n5endProtruding==0 and n5endProtrudingSC==0 and n3endProtruding==0 and n3endProtrudingSC==0:
				continue

			#Written straight to the file object: no line can be left behind in a
			#side buffer when the loop ends.
			ofile.write("\t".join([read.query_name, str(n5endProtruding), str(n5endProtrudingSC), str(n3endProtruding), str(n3endProtrudingSC) ])+"\n")

	#readname       5end protruding length  5end distance2upstream intron   5end softcliplength     3end protruding length  3end distance2downstream intron 3end softclip length	


