#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse
import collections


def compute_maternaltransmission(embryo,ovary):
	"""
	Ratio of the embryo value to the ovary value of a tile.

	[5,10] => 0.5
	[5,0] => "na" (undefined; the ovary value is zero)
	"""
	if ovary!=0:
		return float(embryo)/float(ovary)
	return "na"

def head(tonorm,number):
	"""
	Print the first `number` entries of a nested tile dictionary (debug helper).

	tonorm is walked as  category -> contig -> tilenumber -> count  and every
	entry is written to stdout as one tab-separated line:
	category, contig, tilenumber, count
	Nothing is printed when number is zero or negative.
	"""
	remaining=number
	for tsa,contigs in tonorm.items():
		for contig,tiles in contigs.items():
			for tilenumber,count in tiles.items():
				# test the limit BEFORE printing; testing it afterwards
				# emitted one line more than requested
				if remaining<=0:
					return
				print("\t".join([tsa,contig,str(tilenumber),str(count)]))
				remaining-=1

def compute_unfolded_strandbias(se,ase):
	"""
	Signed (unfolded) strand bias: (se-ase)/(se+ase)

	[500,0] => 1.0 (only se)
	[0,500] => -1.0 (only ase)
	[112,112] => 0.0 (equal distribution)
	[0,0] => "na" (se+ase is zero)
	"""
	total=se+ase
	if total==0:
		return "na"
	difference=se-ase
	return float(difference)/float(total)



def compute_strandbias(se,ase):
	"""
	Folded strand bias: 1 - smaller/larger of the two strand counts.

	[112,112] => 0.0 (equal distribution; dual strand cluster)
	[0,500] => 1.0 (unequal distribution, uni strand cluster)
	[0,0] => "na" (the larger count is zero)
	"""
	smaller,larger=sorted([se,ase])
	if larger==0:
		return "na"
	ratio=float(smaller)/float(larger)
	return 1.0-ratio

def normalize_library(tonorm,normcount):
	"""
	Scale every tile count to counts per million of normcount.

	tonorm is walked as  category -> contig -> tilenumber -> count ; the result
	is a nested defaultdict of the same layout holding
	count*1000000/normcount  as float. Tiles missing from the result read as 0.0.
	"""
	permillion=collections.defaultdict(lambda:collections.defaultdict(lambda:collections.defaultdict(float)))
	for tsa,contigs in tonorm.items():
		for contig,tiles in contigs.items():
			for tilenumber,count in tiles.items():
				scaled=float(count)*1000000.0
				permillion[tsa][contig][tilenumber]=scaled/float(normcount)
	return permillion

def load_pirna_tiles(file,tilesize,minmq):
	"""
	Count piRNA-sized reads per tile from SAM formatted lines.

	file     -- iterable of SAM lines (e.g. an open file handle)
	tilesize -- tile width; the tile of a read is int(POS/tilesize)
	minmq    -- reads with a mapping quality below this value are not tiled

	Returns (pirnacount, tiles)
	pirnacount -- number of reads with a length of 23..29; counted before the
	              unmapped / mapping quality filters, so it includes reads
	              that are not tiled
	tiles      -- nested defaultdict  category -> reference -> tilenumber -> count
	              category "tot": every tiled read
	              category "ase": tiled reads with SAM flag 0x10 set
	              category "se":  tiled reads without SAM flag 0x10

	SAM header lines (starting with "@") and blank lines are skipped.
	"""
	toret=collections.defaultdict(lambda:collections.defaultdict(lambda:collections.defaultdict(int)))
	minlen,maxlen=23,29
	pirnacount=0
	for line in file:
		# SAM columns used (0-based): 1=FLAG 2=RNAME 3=POS 4=MAPQ 9=SEQ, e.g.
		# r1  16  M14653_te  172  70  23M  *  0  0  ATGTCGAGTTTCGTGCCGAATAA  FFFFFFFFFFFFFFFFFFBBBBB  ...
		# r2  0   M14653_te  240  70  27M  *  0  0  AACAGCTGCGGAATCGCACCGAATGCT  BBBBBFFFFFBFFFFFFFFFFFFFFFF  ...
		line=line.rstrip("\r\n")
		# header lines ("@HD", "@SQ", ...) and blank lines are no alignments;
		# they have no SEQ column and would raise an IndexError below
		if not line or line.startswith("@"):
			continue
		a=line.split("\t")

		# discard reads outside of the piRNA length range
		readlen=len(a[9])
		if readlen<minlen or readlen>maxlen:
			continue
		pirnacount+=1

		# discard unmapped (SAM flag 0x4)
		flag=int(a[1])
		if flag & 0x4:
			continue
		# discard low mapping quality
		if int(a[4])<minmq:
			continue

		ref=a[2]
		# NOTE(review): SAM POS is 1-based, so POS == k*tilesize falls into tile k;
		# confirm this is the intended tile boundary
		tilenumber=int(float(a[3])/tilesize)
		strand="ase" if flag & 0x10 else "se"
		toret["tot"][ref][tilenumber]+=1
		toret[strand][ref][tilenumber]+=1

	return pirnacount,toret



parser = argparse.ArgumentParser(description="""           
Description
-----------
Summary statistics
""",formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Authors
-------
    Robert Kofler
""")
parser.add_argument('--ovary', type=argparse.FileType('r'), default=None,dest="ovary", required=True, help="A sam file; ovary")
parser.add_argument('--embryo', type=argparse.FileType('r'), default=None,dest="embryo", required=True, help="A sam file; embryo")
# optional: required=True made the declared default of 1000 unreachable
parser.add_argument('--ts', type=int, default=1000,dest="ts", required=False, help="tile size (default: 1000)")
parser.add_argument('--mirna-ovary', type=str,dest="miovary", required=True, help="miRNAs in ovary; use 'pirna' for 1mil piRNA normalization")
parser.add_argument('--mirna-embryo', type=str,dest="miembryo", required=True, help="miRNAs in embryo")
parser.add_argument('--min-mq', type=int,dest="minmq", required=True, help="minimum mapping quality")

# tiles with fewer raw ovary reads than this are not reported
minexp=10
args,unknown = parser.parse_known_args()

# the tile number is POS/ts, so a tile size of zero cannot be processed
if args.ts<1:
	parser.error("--ts must be a positive integer")

# argparse.FileType hands over open handles; close them once both are read
try:
	pcembryo,rawembryo=load_pirna_tiles(args.embryo, args.ts, args.minmq)
	pcovary, rawovary=load_pirna_tiles(args.ovary, args.ts, args.minmq)
finally:
	args.embryo.close()
	args.ovary.close()

normembryo=None
normovary=None

# choose NORMALIZATION mode
# 'pirna' => per million piRNA-sized reads of the library itself
# otherwise the argument is read as an integer count (miRNAs) to normalize by
if(args.miembryo.lower()=="pirna"):
	normembryo	=	normalize_library(rawembryo,pcembryo)
else:
	normembryo	=	normalize_library(rawembryo,int(args.miembryo))

if(args.miovary.lower()=="pirna"):
	normovary	=	normalize_library(rawovary,pcovary)
else:
	normovary	=	normalize_library(rawovary,int(args.miovary))


# one tab-separated line per ovary tile with at least minexp raw reads:
# ref, start, end, size, raw ovary, ovary se/ase, raw embryo,
# normalized ovary, normalized embryo, unfolded strand bias, maternal transmission
for ref,tmp in normovary["tot"].items():
	for tilenumber,normoc in tmp.items():
		start=tilenumber*args.ts
		end=start+args.ts-1
		size=args.ts
		rawoc,rawec=rawovary["tot"][ref][tilenumber],rawembryo["tot"][ref][tilenumber]
		if(rawoc<minexp):
			continue

		# strand counts and strand bias are taken from the ovary library only
		sec=rawovary["se"][ref][tilenumber]
		asec=rawovary["ase"][ref][tilenumber]
		tsease="{0}/{1}".format(sec,asec)
		strandbias=compute_unfolded_strandbias(sec,asec)
		# tiles without embryo reads read as 0.0 (defaultdict)
		normcountembryo=normembryo["tot"][ref][tilenumber]

		mattrans=compute_maternaltransmission(normcountembryo,normoc)

		toprint=[ref,str(start),str(end),str(size),str(rawoc),tsease,str(rawec),str(normoc),str(normcountembryo),str(strandbias),str(mattrans)]
		print("\t".join(toprint))



