#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse

# One CIGAR operation: a run length followed by its operation code.
# All nine SAM operation codes are accepted (M, I, D, N, S, H, P, = and X) so
# that extended CIGARs using '=' / 'X' instead of 'M' are not silently dropped.
prog = re.compile(r"(\d+)([MIDNSHP=X])")

def get_end(start,result):
    """Return the last reference position covered by an alignment.

    start  -- reference position of the first aligned base
    result -- iterable of (count, operation) CIGAR pairs; count may be a
              string of digits (it is converted with int())

    The end is start + (number of reference bases consumed) - 1.
    """
    # operations that consume the reference, see
    # https://samtools.github.io/hts-specs/SAMv1.pdf
    ref_consuming = ("M", "D", "N", "=", "X")

    # convert every count up front so a malformed count is always reported
    parsed = [(int(length), op) for length, op in result]
    ref_span = sum(length for length, op in parsed if op in ref_consuming)
    return start + ref_span - 1

class PAF:
    """A single alignment record of a PAF file.

    The constructor arguments are stored unchanged as attributes of the same
    name.  In addition a compact textual description is kept in `shortrep`,
    formatted as "qstart..qend   <strand><tname>:tstart..tend".
    """

    def __init__(self,qname,qlen,qstart,qend,strand,tname,tlen,tstart,tend,matches,allen,mapq,cigar):
        # query side of the alignment
        self.qname, self.qlen = qname, qlen
        self.qstart, self.qend = qstart, qend
        self.strand = strand
        # target side of the alignment
        self.tname, self.tlen = tname, tlen
        self.tstart, self.tend = tstart, tend
        # alignment summary values and the CIGAR string
        self.matches, self.allen, self.mapq = matches, allen, mapq
        self.cigar = cigar
        # one-line description returned by __str__
        template = "{0}..{1}   {2}{3}:{4}..{5}"
        self.shortrep = template.format(qstart, qend, strand, tname, tstart, tend)

    def __str__(self):
        return self.shortrep

    def antistrand(self):
        """Return the opposite strand symbol ("+" <-> "-").

        Raises Exception when the strand is neither "+" nor "-".
        """
        opposite = {"+": "-", "-": "+"}
        if self.strand not in opposite:
            raise Exception("unknown strand "+self.strand)
        return opposite[self.strand]

class ReadPAF:
    """All alignments (PAF records) that belong to one read.

    The alignments are kept in `paflist`, sorted by their start position on
    the read (qstart).  `shortrep` is a multi-line description: the read name
    followed by the short representation of every alignment.
    """

    def __init__(self,paflist):
        assert len(paflist) > 0
        self.paflist = sorted(paflist, key=lambda aln: aln.qstart)
        # the read name is taken from the first record as it was passed in
        self.shortrep = self._join_shortreps(paflist[0].qname)

    def _join_shortreps(self, readname):
        # read name on the first line, then one line per alignment
        lines = [readname]
        lines.extend(aln.shortrep for aln in self.paflist)
        return "\n".join(lines)

    def __str__(self):
        return self.shortrep

    def subalignments(self):
        """Return the number of alignments of this read."""
        return len(self.paflist)

    def containsTargetName(self,tname):
        """Return True if at least one alignment hits the target `tname`."""
        return any(aln.tname == tname for aln in self.paflist)

    def readlen(self):
        """Return the query length stored in the first alignment."""
        return self.paflist[0].qlen

    def qname(self):
        """Return the query name stored in the first alignment."""
        return self.paflist[0].qname

    def updateShortCigar(self,minsize):
        """Append a condensed CIGAR to the short representation.

        Only operations longer than `minsize` are kept, and neighbouring kept
        operations of the same type are summed.  The result is meant purely
        for human readability; it is not a valid CIGAR and must not be
        parsed.  As a side effect the parsed (count, op) pairs of every
        alignment are stored in its `cigtuples` attribute.
        """
        for aln in self.paflist:
            aln.cigtuples = prog.findall(aln.cigar)
            condensed = []  # [operation, summed length] pairs
            for raw_count, op in aln.cigtuples:
                size = int(raw_count)
                if size <= minsize:
                    continue
                if condensed and condensed[-1][0] == op:
                    condensed[-1][1] += size
                else:
                    condensed.append([op, size])
            cigar_text = " ".join(op + str(size) for op, size in condensed)
            aln.shortrep = "{0}..{1}   {2}{3}:{4}..{5} {6}".format(
                aln.qstart, aln.qend, aln.strand, aln.tname,
                aln.tstart, aln.tend, cigar_text)
        self.shortrep = self._join_shortreps(self.paflist[0].qname)

    def deconvolute(self):
        """Return a new ReadPAF truncated at the first fold-back alignment.

        Deals with reads that seem to fold back on themselves (similar to 2D
        reads in ONT protocols).  The algorithm is conservative: alignments
        are walked in read order and the walk stops at the first alignment
        whose target was already seen on the opposite strand; that alignment
        and everything after it is discarded.
        """
        seen = set()
        kept = []
        for aln in self.paflist:
            own_key = "{0}{1}".format(aln.strand, aln.tname)
            mirrored_key = "{0}{1}".format(aln.antistrand(), aln.tname)
            if mirrored_key in seen:
                break
            seen.add(own_key)
            kept.append(aln)
        return ReadPAF(kept)


class PAFReader:
    """
    A light-weight PAF reader.

    Iterating over the reader yields one ReadPAF per read: consecutive lines
    sharing the same query name (first column) are collected into one
    ReadPAF, so all alignments of a read must be adjacent in the file.

    The underlying file is closed automatically once the end of the file has
    been reached.  The reader can also be used as a context manager, or
    closed explicitly with close().
    """
    def __init__(self,fh):
        # fh is the path of the PAF file
        self.__filehandle=open(fh)
        # name and first record of the read that follows the one returned last
        self.__prevqname=None
        self.__prevpaf=None

    def __iter__(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the underlying file; closing twice is harmless."""
        self.__filehandle.close()

    def __next__(self):
        return self.next()

    def next(self):
        """Return the ReadPAF of the next read.

        Raises StopIteration when no read is left or the reader was closed.
        """
        if self.__filehandle.closed:
            # exhausted (or closed by the user): nothing more to deliver
            raise StopIteration

        prevqname=self.__prevqname
        paflist=[]
        if self.__prevpaf is not None:
            # first record of this read was already consumed by the last call
            paflist.append(self.__prevpaf)

        while True:
            line=self.__filehandle.readline()
            if line=="":
                # end of file: release the handle, then flush the pending read
                self.close()
                if prevqname is None:
                    # nothing pending -> STOP
                    raise StopIteration
                self.__prevqname=None
                self.__prevpaf=None
                return ReadPAF(paflist)

            line=line.rstrip("\n")
            if not line.strip():
                # tolerate blank lines (e.g. a trailing empty line)
                continue
            a=line.split("\t")
            # columns (0-based):
            # 0 qname, 1 qlen, 2 qstart, 3 qend, 4 strand, 5 tname, 6 tlen,
            # 7 tstart, 8 tend, 9 matches, 10 alignment length, 11 mapq,
            # 12.. optional tags, e.g.
            # 3ba79a29-...  420  31  420  -  contig_422  19827742  2995553  2995952  374  401  60  NM:i:27 ... cg:Z:5M1D143M2I114M
            # the CIGAR is taken from the cg:Z: tag; empty string if absent
            cigar=""
            for tag in a[12:]:
                if tag.startswith("cg:Z:"):
                    cigar=tag[5:]
            curpaf=PAF(a[0],int(a[1]),int(a[2]),int(a[3]),a[4],a[5],int(a[6]),int(a[7]),int(a[8]),int(a[9]),int(a[10]),int(a[11]),cigar)

            if prevqname is None:
                # very first record of the file
                prevqname=curpaf.qname
                paflist.append(curpaf)
            elif curpaf.qname==prevqname:
                paflist.append(curpaf)
            else:
                # a new read starts: remember its first record for the next
                # call and return the read that just ended
                self.__prevqname=curpaf.qname
                self.__prevpaf=curpaf
                return ReadPAF(paflist)
class InsPos:
    """Plain record describing one insertion position.

    Stores the four constructor arguments unchanged as the attributes
    chrm, pos, support and pelekey.
    """
    def __init__(self,chrm,pos,support,pelekey):
        self.chrm, self.pos = chrm, pos
        self.support, self.pelekey = support, pelekey

def getOrientation(genome,pele):
    """Return "fwd" if both alignments are on the same strand, else "rev".

    Only the `strand` attribute of the two arguments is compared.
    """
    same_strand = genome.strand==pele.strand
    return "fwd" if same_strand else "rev"

			
def getInsertionposition(rpaf, pele_name="PPI251"):
    """Derive insertion positions of an element from the alignments of a read.

    rpaf      -- object with a `paflist` of alignments; each alignment needs
                 the attributes qstart, strand, tname, tstart and tend
    pele_name -- target name of the inserted element (default "PPI251")

    Alignments are ordered along the read.  Runs of consecutive element
    alignments (tandem insertions) are merged into one entry whose key lists
    every hit as "<strand><tstart>..<tend>", separated by blanks.  The key is
    stored as attribute `pkey` on the first alignment of each run, i.e. the
    alignment objects are modified.

    Returns a list with one 5-element list per element run that has at least
    one flanking alignment: [target, position, pkey, support, orientation]

      * flanked on both sides by the same target: support "FR", position is
        the midpoint between the two flanks
      * flanked on both sides by different targets: target and position hold
        both values separated by a blank, support and orientation are "NA"
      * only a preceding flank: support "F"; only a following flank: "R"

    Raises Exception if a flanking alignment has a strand other than "+"/"-".
    """
    subs=sorted(rpaf.paflist,key=lambda i: i.qstart)

    # aggregate several tandem element insertions (e.g. KP-element)
    psub=[]
    for cur in subs:
        # alignments to anything else are kept as they are
        if cur.tname!=pele_name:
            psub.append(cur)
            continue

        # element alignment: keep the first of a run, only add the keys of
        # the following ones
        pkey="{0}{1}..{2}".format(cur.strand,cur.tstart,cur.tend)
        if psub and psub[-1].tname==pele_name:
            psub[-1].pkey+=" {0}".format(pkey)
        else:
            cur.pkey=pkey
            psub.append(cur)

    toret=[]  # entries: [chrm, pos, pkey, support, orientation]
    last=len(psub)-1
    for i,cur in enumerate(psub):
        if cur.tname!=pele_name:
            continue
        prev=psub[i-1] if i>0 else None
        nex=psub[i+1] if i<last else None

        if prev is not None and nex is not None:
            if prev.tname!=nex.tname:
                # flanks map to different targets: report both, unresolved
                toret.append([prev.tname+" "+nex.tname,str(prev.tstart)+" "+str(nex.tstart),cur.pkey,"NA","NA"])
                continue
            assert(prev.strand==nex.strand) # after deconvolution, should be identical
            strand=prev.strand
            # midpoint between the two flank ends that face the element
            if strand=="+":
                pos=int(float(prev.tend+nex.tstart)/2.0)
            elif strand=="-":
                pos=int(float(prev.tstart+nex.tend)/2.0)
            else:
                raise Exception("unknown strand")
            toret.append([prev.tname,pos,cur.pkey,"FR",getOrientation(prev,cur)])
        elif prev is not None:
            # element at the end of the read: only the preceding flank exists
            strand=prev.strand
            if strand=="+":
                pos=prev.tend
            elif strand=="-":
                pos=prev.tstart
            else:
                raise Exception("Unknown strand "+strand)
            toret.append([prev.tname,pos,cur.pkey,"F",getOrientation(prev,cur)])
        elif nex is not None:
            # element at the start of the read: only the following flank exists
            strand=nex.strand
            if strand=="+":
                pos=nex.tstart
            elif strand=="-":
                pos=nex.tend
            else:
                raise Exception("Unknown strand "+strand)
            toret.append([nex.tname,pos,cur.pkey,"R",getOrientation(nex,cur)])
    return toret
	



def main():
    """Command-line entry point.

    Reads the PAF file given with --paf and prints, for every read that has
    at least two alignments after fold-back removal and at least one
    alignment to PPI251, one tab-separated line per insertion position
    returned by getInsertionposition().
    """
    parser = argparse.ArgumentParser(description="""           
Description
-----------
Summary statistics
""",formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Authors
-------
    Robert Kofler
""")
    parser.add_argument("--paf", type=str, required=True, dest="paf", default=None, help="paf")
    # parser.add_argument("--minlen", type=int, required=False, dest="minlen", default=100, help="minlen")
    args,unknown = parser.parse_known_args()

    reader = PAFReader(args.paf)
    try:
        for rpaf in reader:
            # deal with weird ONT foldback
            rpaf = rpaf.deconvolute()

            if rpaf.subalignments() < 2:
                continue
            if not rpaf.containsTargetName("PPI251"):
                continue
            rpaf.updateShortCigar(25)

            for k in getInsertionposition(rpaf):
                print("{0}\t{1}\t{2}\t{3}\t{4}".format(k[0], k[1], k[2], k[3],k[4]))
    finally:
        # release the input file even if processing fails half-way
        reader.close()


# run only when executed as a script, so the module can be imported
# without parsing the command line
if __name__ == "__main__":
    main()
		  

			
		

