#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse
import random
import math
import collections
import fileinput
# Matches one "<count><operation>" pair where the operation is one of M, I, S, D, H, N
# (presumably SAM CIGAR operations - TODO confirm).
# NOTE(review): not referenced anywhere in the visible part of this script.
prog = re.compile(r"(\d+)([MISDHN])")


def load_fai(faifile):
	"""
	Read a tab-separated fai index and return the length of every sequence.

	faifile: path to the fai file; column 1 is the sequence id, column 2 its length
	returns: dict mapping sequence id (str) -> sequence length (int)
	raises: IndexError / ValueError for non-empty lines that do not have an
	        integer in the second column
	"""
	toret={}
	# the context manager guarantees the handle is closed, also when a line fails to parse
	with open(faifile) as fh:
		for l in fh:
			# PPI251	2907	132096	2907	2908
			l=l.rstrip("\n")
			if not l.strip():
				# tolerate blank lines (e.g. a trailing empty line)
				continue
			a=l.split("\t")
			actseid=a[0]
			actlen=int(a[1])
			toret[actseid]=actlen
	return toret


def isSense(flag):
	"""
	Decide from a SAM FLAG whether a paired-end read stems from a sense transcript.

	In Read 1, sequencing reads map to the antisense strand.
	In Read 2, sequencing reads map to the sense strand.
	https://support.illumina.com/content/dam/illumina-support/documents/documentation/chemistry_documentation/samplepreps_truseq/truseq-stranded-mrna-workflow/truseq-stranded-mrna-workflow-reference-1000000040498-00.pdf

	flag: integer SAM FLAG of the alignment
	returns: True for a sense transcript, False for an antisense transcript
	raises: AssertionError if the paired bit (0x1) is not set;
	        Exception if neither the first (0x40) nor the second (0x80) in pair bit is set
	"""
	if not flag & 0x1:
		# read must be paired end; otherwise the thing does not work
		# (raised explicitly so the check also survives "python -O")
		raise AssertionError("read must be paired end")

	# 0x10 = read is aligned as reverse complement
	is_reverse = bool(flag & 0x10)

	if flag & 0x40:
		# first in pair:
		#   reverse complement -> sense transcript
		#   sense strand       -> antisense transcript
		return is_reverse
	elif flag & 0x80:
		# second in pair:
		#   reverse complement -> antisense transcript
		#   sense strand       -> sense transcript
		return not is_reverse
	else:
		raise Exception("weird read; neither first nor second")




# Command line interface; description and epilog are printed verbatim (raw formatter).
parser = argparse.ArgumentParser(
	formatter_class=argparse.RawDescriptionHelpFormatter,
	description="""           
Description
-----------
Summary statistics
""",
	epilog="""
miRNA: 21-23nt
piRNA: 23-28nt


Authors
-------
    Robert Kofler
""",
)
# the sam file is opened for reading by argparse itself
parser.add_argument(
	"--sam",
	type=argparse.FileType("r"),
	required=True,
	help="A sam file",
)
parser.add_argument(
	"--fai",
	type=str,
	required=True,
	help="A fai file for the reference",
)
parser.add_argument(
	"--min-mq",
	dest="minmq",
	type=int,
	default=1,
	help="min mapping quality",
)
parser.add_argument(
	"--sample-id",
	dest="sid",
	type=str,
	required=True,
	help="the sample id",
)

args = parser.parse_args()
minmq, sid = args.minmq, args.sid
# sequence id -> sequence length, used for the length normalization
fai = load_fai(args.fai)

# counthash: reference name -> [sense count, antisense count]
coha=collections.defaultdict(lambda:[0,0]) # index0 = sense, index1=antisense
# normalization! number of mapped reads, counted before the mapping quality filter
mappedreads=0
for line in args.sam:
	# SAM header lines start with "@" and carry no alignment; skip them and blank lines
	if line.startswith("@") or not line.strip():
		continue
	a=line.rstrip("\n").split("\t")

	# discard unmapped
	flag=int(a[1])
	if flag & 0x004:
		continue

	mappedreads+=1
	# discard low mapping quality
	mq=int(a[4])
	if mq<minmq:
		continue

	# Reads mapped to a reference gene are used for normalization
	refchr=a[2]

	# get the active transcript; isSense fails for reads that are not paired end
	seindex=0 if isSense(flag) else 1
	coha[refchr][seindex]+=1


# RPKM means: "Reads per kilobase of transcript per Million mapped reads"
#
# eg i have a gene with 6000 reads;
# the length is 6kb (norm_len)
# hence the expression per kb is 6000/6 = 1000
# further 100 million reads were mapped, so the expression per million mapped reads (norm_map) is
# 1000/100 = 10
# or expression/(norm_len * norm_map) = 6000/(6*100)

# the per-million factor is the same for every reference sequence
norm_map=float(mappedreads)/1000000.0
for refname,counts in coha.items():
	# raw counts: index0 = sense, index1 = antisense
	rawse=float(counts[0])
	rawase=float(counts[1])

	# length of the reference sequence, taken from the fai file
	gl=float(fai[refname])
	norm_len=float(gl)/1000.0

	# normalize
	scale=norm_map*norm_len
	se_rpkm=rawse/scale
	ase_rpkm=rawase/scale

	# columns: sample id, reference, raw sense, raw antisense, length,
	# normalized sense, normalized antisense, normalized total
	fields=[args.sid,refname]
	fields.extend(str(v) for v in (rawse,rawase,gl,se_rpkm,ase_rpkm,se_rpkm+ase_rpkm))
	print("\t".join(fields))