#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse
import random
import math
import collections
import fileinput
# Splits a CIGAR string into (length, operation) pairs. Covers every SAM
# CIGAR operation: M, I, D, N, S, H, P, = and X (SAM spec, section 1.4).
prog = re.compile(r"(\d+)([MIDNSHP=X])")


def isSense(flag):
	"""
	Return True if a paired-end read derives from the sense transcript.

	Follows the Illumina TruSeq Stranded mRNA convention:
	In Read 1, sequencing reads map to the antisense strand.
	In Read 2, sequencing reads map to the sense strand.
	https://support.illumina.com/content/dam/illumina-support/documents/documentation/chemistry_documentation/samplepreps_truseq/truseq-stranded-mrna-workflow/truseq-stranded-mrna-workflow-reference-1000000040498-00.pdf

	flag: the SAM FLAG field as an int.
	Returns True for sense-transcript reads, False for antisense ones.
	Raises ValueError if the read is not paired (0x1 unset), and Exception if
	it is neither first (0x40) nor second (0x80) in its pair.
	"""
	if not flag & 0x1:
		# The strand rule depends on the mate number, so single-end reads
		# cannot be classified. (Explicit raise: an assert vanishes under -O.)
		raise ValueError("read must be paired end; SAM flag " + str(flag) + " lacks 0x1")

	is_reverse = bool(flag & 0x10)  # read aligned as reverse complement
	if flag & 0x40:
		# first in pair: reverse complement -> sense transcript,
		# forward -> antisense transcript
		return is_reverse
	if flag & 0x80:
		# second in pair: forward -> sense transcript,
		# reverse complement -> antisense transcript
		return not is_reverse
	raise Exception("weird read; neither first nor second")


def fill_in_coverage(at,start,cig,a):
	"""
	Add one read's aligned bases to a coverage list, in place.

	at:    coverage list indexed by 1-based reference position (index 0 is
	       unused, matching the lists built by initialize_se_ase).
	start: 1-based leftmost mapping position (SAM POS column).
	cig:   CIGAR string (SAM column 6); "*" adds no coverage.
	a:     the split SAM record; only used in error messages.

	M, = and X increment coverage. D and N advance along the reference
	without adding coverage. I, S, H and P consume no reference bases.
	See https://samtools.github.io/hts-specs/SAMv1.pdf (CIGAR operations).

	Raises IndexError if the alignment runs past the end of `at` (nothing is
	written for that operation), and Exception on an unknown CIGAR operation.
	"""
	# Example: POS=3, CIGAR=3M on positions 123456789 covers 3,4,5; the next
	# operation starts at 3+3=6.
	refpos = start
	# Match any operation character, not only the valid ones, so that a
	# malformed CIGAR fails loudly instead of being silently skipped.
	for match in re.finditer(r"(\d+)(\D)", cig):
		count = int(match.group(1))
		op = match.group(2)
		if op in ("M", "=", "X"):  # alignment match, sequence match, mismatch
			end = refpos + count
			if end > len(at):
				raise IndexError("alignment extends past the reference end: position "
					+ str(end - 1) + " > length " + str(len(at) - 1)
					+ "; segment start " + str(refpos) + ", CIGAR " + cig + ", record " + str(a))
			for i in range(refpos, end):
				at[i] += 1
			refpos = end
		elif op in ("D", "N"):  # deletion / skipped region: reference only
			refpos += count
		elif op in ("I", "S", "H", "P"):  # consume no reference positions
			pass
		else:
			raise Exception("unrecognized CIGAR operation " + op + " in " + cig)



def initialize_se_ase(seqids,fai):
	"""
	Create zeroed sense and antisense coverage lists for the requested sequences.

	seqids: collection of sequence IDs to analyze; other sequences listed
	        in the index are ignored.
	fai:    path to a samtools .fai index (tab-separated; column 1 is the
	        sequence name, column 2 its length).

	Returns (se, ase): two dicts mapping sequence ID -> list of zeros of
	length len+1, so coverage can be indexed by 1-based position (index 0
	unused). The two dicts hold independent lists.
	"""
	se={}
	ase={}
	with open(fai) as fh:
		for line in fh:
			# e.g. PPI251	2907	132096	2907	2908
			fields=line.rstrip("\r\n").split("\t")
			seqid=fields[0]
			if not seqid or seqid not in seqids:
				continue  # blank line or a sequence we do not analyze
			length=int(fields[1])
			# plus 1 so coverage can be indexed by 1-based position
			se[seqid]=[0]*(length+1)
			ase[seqid]=[0]*(length+1)
	return se,ase



          
     

parser = argparse.ArgumentParser(description="""           
Description
-----------
Summary statistics
""",formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
miRNA: 21-23nt
piRNA: 23-28nt


Authors
-------
    Robert Kofler
""")
parser.add_argument('--sam', type=argparse.FileType('r'), default=None,dest="sam", required=True, help="A sam file")
parser.add_argument('--fai', type=str, dest="fai", required=True, help="A fai file for the reference")
parser.add_argument("--min-mq", type=int, required=False, dest="minmq", default=1, help="min mapping quality")
parser.add_argument("--sample-id", type=str, required=True, dest="sid",  help="the sample id")
parser.add_argument("--normalize", type=str, required=True, dest="norm",  help="the IDs of the sequences for normalizing the expression, comma-separated list")
parser.add_argument("--seqs", type=str, required=True, dest="seqs", help="the IDs of the sequences to analyze, comma-separated list")
parser.add_argument("-a",action="store_true",dest="addpolygon",help="add data for correct visualization of polygon")

args = parser.parse_args()
minmq=args.minmq
sid=args.sid

# Comma-separated ID lists -> sets. Surrounding whitespace and empty items
# (e.g. from a trailing comma) are ignored, so "a, b," yields {"a", "b"}.
seqids={x.strip() for x in args.seqs.split(",") if x.strip()}
normids={x.strip() for x in args.norm.split(",") if x.strip()}
# Fail early with a usage message instead of producing empty output or a
# division by zero during normalization.
if not seqids:
	parser.error("--seqs must list at least one sequence ID")
if not normids:
	parser.error("--normalize must list at least one sequence ID")
addpoly=args.addpolygon

# Sense and antisense coverage: dicts keyed by sequence ID; each value is a
# list indexed by 1-based reference position (see initialize_se_ase).
secov,asecov= initialize_se_ase(seqids,args.fai)

# Counters: mappedreads counts every mapped record (not used elsewhere in this
# script); normreads counts mapped records on the normalization sequence(s)
# and is the denominator for normalization.
# NOTE(review): secondary (0x100) and supplementary (0x800) records are not
# filtered, so one read may be counted more than once - confirm intended.
mappedreads=0
normreads=0
with args.sam as samfile:
	for line in samfile:
		# Header lines (@HD, @SQ, @PG, ...) and blank lines are not alignments;
		# parsing their second column as a FLAG would crash int().
		if line.startswith("@") or not line.strip():
			continue
		a=line.rstrip("\n").split("\t")
		# discard unmapped reads (FLAG 0x4)
		flag=int(a[1])
		if flag & 0x4:
			continue

		mappedreads+=1
		# Reads mapped to a reference gene are used for normalization; they are
		# counted before the mapping-quality filter below.
		refchr=a[2]
		if refchr in normids:
			normreads+=1

		# discard low mapping quality
		if int(a[4]) < minmq:
			continue
		# discard sequences that are not being analyzed
		if refchr not in secov:
			continue

		# add the read's aligned bases to the sense or antisense track
		activeTranscript = secov[refchr] if isSense(flag) else asecov[refchr]
		fill_in_coverage(activeTranscript, int(a[3]), a[5], a)


# Normalize: coverage is multiplied by 1000 / (number of mapped records on the
# normalization sequence(s)). One tab-separated line is printed per position:
# sample id, track ("se" or "ase"), sequence id, 1-based position, coverage.
if normreads == 0:
	# report the cause instead of an opaque ZeroDivisionError
	sys.exit("error: no mapped reads on the normalization sequence(s) "
		+ ",".join(sorted(normids)) + "; cannot normalize")
normratio=float(1000)/float(normreads)
for key in secov:
	for strand, track in (("se", secov[key]), ("ase", asecov[key])):
		# Last 1-based position of this sequence (index 0 of the list is unused).
		# Computed directly rather than taken from the loop variable, which is
		# unset or stale for a zero-length sequence.
		lastpos = max(len(track) - 1, 1)
		if addpoly:
			# zero-coverage point at the first position, for polygon plotting (-a)
			print("\t".join([args.sid, strand, key, "1", "0.0"]))
		for pos in range(1, len(track)):
			print("\t".join([args.sid, strand, key, str(pos), str(normratio * track[pos])]))
		if addpoly:
			# zero-coverage point at the last position, for polygon plotting (-a)
			print("\t".join([args.sid, strand, key, str(lastpos), "0.0"]))

"""
Example SAM input records (the first line gives the 0-based index of each field):
0   1  2      		3  4     5 6   7   8            9                        10                  11
r1	16	M14653_te	172	70	23M	*	0	0	ATGTCGAGTTTCGTGCCGAATAA	FFFFFFFFFFFFFFFFFFBBBBB	PG:Z:novoalign	AS:i:0	UQ:i:0	NM:i:0	MD:Z:23
r2	0	M14653_te	240	70	27M	*	0	0	AACAGCTGCGGAATCGCACCGAATGCT	BBBBBFFFFFBFFFFFFFFFFFFFFFF	PG:Z:novoalign	AS:i:0	UQ:i:0	NM:i:0	MD:Z:27
"""