#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse
import random
import math
import collections
import fileinput
# Tokenises a CIGAR string into (length, operation) pairs. Every SAM CIGAR
# operation (M I D N S H P = X) is listed so that re.findall does not silently
# skip '=', 'X' or 'P' tokens, which would corrupt reference-position tracking.
prog = re.compile(r"(\d+)([MIDNSHP=X])")


class SpliceSite:
	"""Read support for a single splice site, split by transcript strand.

	Attributes:
		key -- identifier of the site (a "start-end" string in this script)
		se  -- count of reads classified as sense
		ase -- count of reads classified as antisense
	"""

	def __init__(self,key):
		self.key=key
		# both strand counters start at zero; callers increment them
		self.se, self.ase = 0, 0

	def total(self):
		"""Return the combined sense + antisense read count."""
		combined = self.se + self.ase
		return combined



def isSense(flag):
	"""
	Classify a paired-end read as supporting a sense (True) or antisense (False) transcript.

	In Read 1, sequencing reads map to the antisense strand.
	In Read 2, sequencing reads map to the sense strand.
	https://support.illumina.com/content/dam/illumina-support/documents/documentation/chemistry_documentation/samplepreps_truseq/truseq-stranded-mrna-workflow/truseq-stranded-mrna-workflow-reference-1000000040498-00.pdf

	Hence read 1 aligned reverse-complemented and read 2 aligned forward both
	indicate a sense transcript. Both mates of one fragment therefore get the
	same answer.
	(NOTE(review): "sense" here presumably means relative to the reference
	forward strand; confirm against downstream use.)

	flag -- SAM FLAG field as an int
	Raises ValueError if the read is not paired (0x1 unset) and Exception if it
	is neither first (0x40) nor second (0x80) in its pair.
	"""
	# read must be paired end; otherwise the strand inference does not work.
	# Raised explicitly (not assert) so the check survives `python -O`.
	if not flag & 0x1:
		raise ValueError("read must be paired end; flag={0}".format(flag))

	reverse = bool(flag & 0x10)  # read aligned as reverse complement
	if flag & 0x40:
		# first in pair: reverse complement -> sense, forward -> antisense
		return reverse
	if flag & 0x80:
		# second in pair: reverse complement -> antisense, forward -> sense
		# (previously tested `flag & 0x0`, which is always 0, so every mate 2 was "sense")
		return not reverse
	raise Exception("weird read; neither first nor second")


def getKey(hash,key,tolerance):
	"""Map a splice-site key onto an already-known key within +/- tolerance.

	Keys have the form "start-end". A known key matches when its start and its
	end each differ from those of `key` by at most `tolerance` (inclusive on
	both sides). When several known keys match, the nearest one wins: smallest
	maximum coordinate shift first, then smallest total shift. An exact match
	therefore always has priority.

	hash      -- container of already-accepted keys (only `in` is used)
	key       -- the "start-end" key to look up
	tolerance -- maximum shift per coordinate; <= 0 disables fuzzy matching
	Returns the matching known key, or `key` itself when none is close enough.
	"""
	start,end=(int(x) for x in key.split("-"))
	# inclusive window; the former range(x-tol, x+tol) missed the +tol edge
	span=range(-tolerance,tolerance+1)
	# visit offsets nearest-first so the closest known site is chosen
	offsets=sorted(
		((di,dk) for di in span for dk in span),
		key=lambda o:(max(abs(o[0]),abs(o[1])),abs(o[0])+abs(o[1])),
	)
	for di,dk in offsets:
		candidate="{0}-{1}".format(start+di,end+dk)
		if candidate in hash:
			return candidate
	return key
	

def checkTolerance(sh,tolerance):
	"""Merge splice sites whose coordinates lie within `tolerance` of each other.

	sh        -- mapping seqid -> {key: SpliceSite}
	tolerance -- per-coordinate tolerance passed to getKey

	For each sequence, sites are visited from highest to lowest total support.
	Each site is folded into the first already-accepted site that getKey
	reports as within tolerance. Otherwise it starts a new accepted site.
	Sense and antisense counts of merged sites are summed.

	Returns a defaultdict seqid -> {key: SpliceSite} with the merged counts.
	"""
	toret=collections.defaultdict(dict)
	for seqid,keyhash in sh.items():
		# strongest sites first, so they become the anchors weaker neighbours merge into
		ranked=sorted(keyhash.values(),key=lambda site:-site.total())
		merged={}
		for site in ranked:
			newkey=getKey(merged,site.key,tolerance)
			if newkey not in merged:
				merged[newkey]=SpliceSite(newkey)
			merged[newkey].se+=site.se
			# accumulate (was '=', which discarded antisense counts of earlier merged sites)
			merged[newkey].ase+=site.ase
		toret[seqid]=merged
	return toret
	


# All SAM CIGAR operations, so no token is silently skipped by findall.
_CIGAR_OP_RE = re.compile(r"(\d+)([MIDNSHP=X])")


def getSpliceKeys(start,cigar,minlen):
	"""Return the "start-end" keys of all reference gaps (D/N) of length >= minlen.

	start  -- alignment start (SAM POS column)
	cigar  -- CIGAR string of the alignment
	minlen -- minimum gap length for a deletion/skip to count as a splice site

	For each qualifying gap the key is "<p>-<p+count+1>", where p is `start` plus
	all reference bases consumed (M/=/X/D/N) before the gap. Operations that do
	not consume the reference (I, S, H, P) do not move the position.

	NOTE(review): with SAM's 1-based POS, p is the first skipped base and the end
	coordinate is one past the first base after the gap. The original example
	comment used 0-based indices with p as the last exon base. Confirm which
	convention downstream consumers expect before changing it.
	"""
	toret=[]
	curpos=start
	for count,op in _CIGAR_OP_RE.findall(cigar):
		count=int(count)
		if op in ("M","=","X"): # match; perfect match, mismatch
			curpos+=count
		elif op in ("D","N"):
			if count>=minlen:
				# a gap long enough to be treated as a splice site
				toret.append("{0}-{1}".format(curpos,curpos+count+1))
			curpos+=count
		elif op in ("I","S","H","P"):
			# insertion, soft/hard clip, padding: no reference consumed
			# (H used to raise here although the regex accepted it)
			pass
		else:
			raise Exception("unrecognized cighcar "+op)
	return toret


# NOTE(review): the description/epilog (miRNA/piRNA lengths, "Summary statistics")
# look copied from another script and do not describe splice-site counting.
parser = argparse.ArgumentParser(description="""           
Description
-----------
Summary statistics
""",formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
miRNA: 21-23nt
piRNA: 23-28nt


Authors
-------
    Robert Kofler
""")
parser.add_argument('--sam', type=argparse.FileType('r'), default=None,dest="sam", required=True, help="A sam file")
parser.add_argument("--min-mq", type=int, required=False, dest="minmq", default=10, help="min mapping quality")
parser.add_argument("--tolerance", type=int, required=False, dest="tolerance", default=2, help="tolerance of the splice site")
parser.add_argument("--min-len", type=int, required=False, dest="minlen", default=10, help="min length of splicing")
# required option: the former default=10 (an int for a str option) could never take effect
parser.add_argument("--sample-id", type=str, required=True, dest="sid", help="the sample id")
parser.add_argument("--seqs", type=str, required=True, dest="seqs", help="the IDs of the sequences to analyze, coma-separated list")

args = parser.parse_args()
minmq=args.minmq
minlen=args.minlen
sid=args.sid
# Comma-separated IDs -> set. A single ID needs no special case. Surrounding
# whitespace and empty entries (e.g. a trailing comma) are ignored.
seqids={s.strip() for s in args.seqs.split(",") if s.strip()}
if not seqids:
	parser.error("--seqs must name at least one sequence")

# splice hash
# seqid -> {splice-site key "start-end" -> SpliceSite with sense (se) / antisense (ase) read counts}
sh = collections.defaultdict(dict)

# Number of mapped alignment records, used for the per-million normalization.
# NOTE(review): secondary (0x100) and supplementary (0x800) records are counted
# as well, so multi-mapping or chimeric reads contribute more than once - confirm.
mappedreads=0
for line in args.sam:
	# SAM header lines (@HD, @SQ, @PG, ...) and blank lines carry no alignment;
	# treating them as records would crash on int(a[1]).
	if line.startswith("@") or not line.strip():
		continue
	a=line.rstrip("\r\n").split("\t")
	# discard unmapped
	flag=int(a[1])
	if flag & 0x004 > 0:
		continue
	# counted before the MAPQ/target filters: the denominator is all mapped records
	mappedreads+=1
	# discard low mapping quality
	mq=int(a[4])
	if mq<minmq:
		continue
	# discard non-targeted genes
	refchr=a[2]
	if refchr not in seqids:
		continue
	# 0											1	2	3		4	5		6	7	8
	# A00902:303:H3VMCDRXY:2:2243:31901:17973	153	412	5787	37	22S78M	*	0	0
	cigar=a[5]
	start=int(a[3])
	# cheap pre-filter: only reference gaps (D/N) can yield splice keys
	if 'D' not in cigar and 'N' not in cigar:
		continue
	keys=getSpliceKeys(start,cigar,minlen)

	# strand of the transcript this read pair stems from
	issens=isSense(flag)
	for k in keys:
		if k not in sh[refchr]:
			sh[refchr][k]=SpliceSite(k)
		if issens:
			sh[refchr][k].se+=1
		else:
			sh[refchr][k].ase+=1

# Merge nearby splice sites when a positive --tolerance is given; with a
# tolerance of 0 (or less) the raw per-site counts are reported unchanged.
tolerancesh = checkTolerance(sh, args.tolerance) if args.tolerance > 0 else sh


# Normalization to a million mapped reads. With no mapped reads there are no
# sites to report, and the division would raise ZeroDivisionError, so use 0.
normratio = 1000000.0/mappedreads if mappedreads > 0 else 0.0
for seqid,sites in tolerancesh.items():
	for k,ss in sites.items():
		start,end=k.split("-")
		# One tab-separated line per strand with a non-zero raw count, sense first:
		# sample, strand, seqid, key, start, end, raw count, normalized count
		for strand,count in (("se",ss.se),("ase",ss.ase)):
			if count>0:
				row=[args.sid,strand,seqid,k,start,end,str(count),str(count*normratio)]
				print("\t".join(row))

