#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse
import random
import math
import collections
import fileinput
# CIGAR element: a run length followed by one operation code (M, I, S, D, H or N)
prog = re.compile(r"(\d+)" r"([MISDHN])")

class InsPos:
	"""An insertion site together with the window placed around it.

	insid     -- identifier of the insertion
	refchr    -- reference chromosome name
	pos       -- centre position of the window
	bandwidth -- half-width of the window
	start/end -- window bounds, pos - bandwidth and pos + bandwidth
	"""
	def __init__(self,insid,refchr,pos,bandwidth):
		self.insid, self.refchr = insid, refchr
		self.pos, self.bandwidth = pos, bandwidth
		# bounds derived from the centre and the half-width
		self.start, self.end = pos - bandwidth, pos + bandwidth

def readInsertionPositions(clufi,bw):
	"""Read insertion positions from a whitespace-separated text file.

	Each non-blank line holds ``chromosome position bandwidth``, e.g.::

		X	21519138	10000

	Parameters
	----------
	clufi : str
		Path of the insertion-position file.
	bw : int or None
		When not None, replaces the per-line bandwidth for every entry.

	Returns
	-------
	list of InsPos
		One entry per non-blank line, in file order; the id is "chr:pos".
	"""
	toret=[]
	# 'with' guarantees the file handle is closed even if a line fails to parse
	with open(clufi) as fh:
		for line in fh:
			# str.split() splits on any whitespace run and ignores leading or
			# trailing whitespace (including '\r'), unlike re.split on the raw line
			fields=line.split()
			if not fields:
				# tolerate blank lines, e.g. an extra newline at end of file
				continue
			refchr,inspos,bandwidth=fields[0],int(fields[1]),int(fields[2])
			insid="{0}:{1}".format(refchr,inspos)
			if bw is not None: # overwrite if parameter is provided explicitly by user
				bandwidth=bw
			toret.append(InsPos(insid,refchr,inspos,bandwidth))
	return toret
		


# Command-line interface. RawDescriptionHelpFormatter keeps the line breaks of
# the description/epilog below exactly as written.
parser = argparse.ArgumentParser(description="""           
Description
-----------
Summary statistics
""",formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""

piRNA: 23-29nt
min-mq only important for distinguishing ambiguous and unambiguous;
both are reported

Authors
-------
    Robert Kofler
""")
parser.add_argument('--sam', type=argparse.FileType('r'), default=None,dest="sam", required=True, help="A sam file")
parser.add_argument("--insert-pos", type=str, required=True, dest="insertpos", help="a file with insertion positions (columns: chromosome, position, bandwidth)")
parser.add_argument("--sample-id", type=str, required=True, dest="sid", help="sample ID")

parser.add_argument("--min-mq", type=int, required=False, dest="minmq", default=1, help="min mapping quality for a read to be counted as unambiguous")
parser.add_argument("--max-mm", type=int, required=False, dest="maxmm", default=10, help="max mismatches (NM tag) allowed per read")
parser.add_argument("--pi-min", type=int, required=False, dest="pistart", default=23, help="min read length (nt) counted as piRNA")
parser.add_argument("--pi-max", type=int, required=False, dest="piend", default=29, help="max read length (nt) counted as piRNA")
parser.add_argument("--bw", type=int, required=False, dest="bw", default=None, help="bandwidth: window half-width around each insertion; overrides the bandwidth column of --insert-pos")
args = parser.parse_args()
# reject option combinations that would otherwise silently produce no output
if args.pistart > args.piend:
	parser.error("--pi-min ({0}) must not be greater than --pi-max ({1})".format(args.pistart,args.piend))
if args.bw is not None and args.bw < 0:
	parser.error("--bw must be non-negative, got {0}".format(args.bw))
minmq=args.minmq
maxmm=args.maxmm
pistart=args.pistart
piend=args.piend
sid=args.sid # sample ID
inspos=readInsertionPositions(args.insertpos,args.bw)


picount=0  # reads passing every filter; denominator of the per-million normalisation


# Position counters nested as counter[chromosome][position] -> number of reads;
# a position never seen reads as 0.
#   ps / pas   : reads on the forward / reverse (FLAG 0x10) strand, any mapping quality
#   ups / upas : the same, restricted to reads with mapping quality >= --min-mq
def _position_counter():
	return collections.defaultdict(int)

ps=collections.defaultdict(_position_counter)
pas=collections.defaultdict(_position_counter)
ups=collections.defaultdict(_position_counter)
upas=collections.defaultdict(_position_counter)


# NM:i optional tag of a SAM record; group 1 holds its integer value,
# which the script uses as the mismatch count
reo=re.compile(r"NM:i:(\d+)")

# One CIGAR element: run length followed by any operation code from the SAM spec.
_CIGAR_OP = re.compile(r"(\d+)([MIDNSHP=X])")
# CIGAR operations that consume reference bases (SAM spec: M, D, N, =, X).
_REF_CONSUMING_OPS = frozenset("MDN=X")

def get_end(start,cig):
	"""Return the rightmost reference position covered by an alignment.

	Parameters
	----------
	start : int
		Leftmost reference position of the alignment (SAM POS column).
	cig : str
		CIGAR string of the alignment.

	Returns
	-------
	int
		start + (number of reference-consuming bases) - 1.  Insertions and
		clips (I, S, H, P) do not advance along the reference; M, D, N, = and
		X do.  A CIGAR with no recognised operation (e.g. "*") gives start - 1.
	"""
	reflen=sum(int(n) for n,op in _CIGAR_OP.findall(cig) if op in _REF_CONSUMING_OPS)
	return start+reflen-1
		


for line in args.sam:
	# Column indices used below (tab-separated SAM record):
	#   1 FLAG, 2 RNAME, 3 POS, 4 MAPQ, 5 CIGAR, 9 SEQ, 11+ optional tags
	# Example record:
	# HISEQ:273:CB6J0ANXX:7:2112:15337:99757	16	NIFY01000001.1	14	0	25M	*	0	0	TGTCAATGTCTGAAATCCATGCATA	FFFFFFFBFFBFFFFFFBFFBBBBB	PG:Z:novoalign	AS:i:0	UQ:i:0	NM:i:0	MD:Z:25	ZS:Z:R	NH:i:22	HI:i:1	IH:i:1
	if line.startswith("@") or not line.strip():
		# SAM header lines (@HD, @SQ, @PG, ...) and blank lines hold no
		# alignment; int() on their second column would abort the run
		continue
	a=line.rstrip("\r\n").split("\t")

	# discard unmapped reads (FLAG bit 0x4)
	flag=int(a[1])
	if flag & 0x004:
		continue

	# discard reads lacking an NM tag or with more than maxmm mismatches.
	# Only the optional fields (column 11 onward) are searched, so text in the
	# read name or quality string can never be mistaken for the tag.
	mo=None
	for tag in a[11:]:
		mo=reo.match(tag)
		if mo is not None:
			break
	if mo is None:
		continue
	mm=int(mo.group(1))
	if mm>maxmm:
		continue

	# discard reads whose length is outside the piRNA range [pistart, piend]
	readlen=len(a[9])
	if readlen<pistart or readlen>piend:
		continue

	mq=int(a[4])
	# normalisation denominator: every mapped read in the piRNA length range
	# with at most maxmm mismatches, whether ambiguously or unambiguously mapped
	picount+=1
	ref=a[2]
	start=int(a[3])
	if flag & 0x10:
		# reverse-strand read (FLAG 0x10): record the rightmost reference
		# position of the alignment instead of POS
		end=get_end(start,a[5])
		pas[ref][end]+=1
		if mq>=minmq:
			upas[ref][end]+=1
	else:
		ps[ref][start]+=1
		if mq>=minmq:
			ups[ref][start]+=1
			
	
  
  
  
  
  
  

# Report, for every insertion window, the normalised read count at each
# position that has at least one read.  Output columns (tab-separated):
#   sample ID, insertion ID, chromosome, 1-based offset within the window,
#   mapping class ("am": any mapping quality; "un": mq >= --min-mq, i.e.
#   unambiguous), strand ("s": forward; "as": reverse, FLAG 0x10),
#   reads per million of picount.
# Category order matches the original output order: am/s, am/as, un/s, un/as.
_categories=[(ps,"am","s"),(pas,"am","as"),(ups,"un","s"),(upas,"un","as")]
for ip in inspos:
	ref=ip.refchr
	for counter,mapclass,strand in _categories:
		# .get() instead of [] so the defaultdicts do not insert a zero
		# entry for every probed position of the window
		bychr=counter.get(ref,{})
		for refpos in range(ip.start,ip.end+1):
			count=bychr.get(refpos,0)
			if count<1:
				continue
			normc=float(count)/float(picount)
			normc*=1000000.0
			offset=refpos-ip.start+1
			topr=[str(x) for x in [sid,ip.insid,ref,offset,mapclass,strand,normc]]
			print("\t".join(topr))

