#!/usr/bin/env python
import os
import sys
import inspect
import re
import argparse
import random
import math
import collections
import fileinput
# Tokenizer for CIGAR strings: each match is a (length, op) pair of strings.
# Only the ops M, I, S, D, H and N are recognized; '=', 'X' and 'P' are not matched
# and are therefore silently skipped by re.findall.
prog = re.compile(r"(\d+)([MISDHN])")

def filter_key(hash,mincount):
     """Return a new plain dict with only the entries of *hash* whose value is >= *mincount*.

     Insertion order of the surviving entries is preserved; *hash* itself is not modified.
     """
     return {key: count for key, count in hash.items() if count >= mincount}


def compute_venn(s1,s2):
     """Compare two collections as sets and report their overlap.

     Returns [only_in_s1, shared, only_in_s2, frac_only_s1, frac_shared, frac_only_s2],
     where the counts are numbers of distinct elements and the fractions are relative
     to the size of the union. When both inputs are empty the union is empty; all
     counts and fractions are then reported as 0 instead of raising ZeroDivisionError.
     """
     set1=set(s1)
     set2=set(s2)
     specific1=len(set1-set2)
     shared=len(set1&set2)
     specific2=len(set2-set1)
     total=specific1+shared+specific2
     if total==0:
          # nothing to compare (e.g. every position was filtered out upstream)
          return [0,0,0,0.0,0.0,0.0]
     total=float(total)
     return [specific1,shared,specific2,specific1/total,shared/total,specific2/total]


def get_end(start,cig):
     """Return the reference coordinate of the last aligned base of a read.

     *start* is the leftmost aligned reference position (SAM POS) and *cig* the
     CIGAR string. The result is start + (reference span) - 1, i.e. an inclusive
     end in the same coordinate system as *start*. Only operations that consume
     the reference contribute to the span: M, D, N, = and X. I, S, H and P do not.
     """
     alignmentleng=0
     # full CIGAR alphabet, so that '=' / 'X' / 'P' are tokenized rather than dropped
     for count,cigchar in re.findall(r"(\d+)([MIDNSHP=X])",cig):
          if cigchar in ("M","D","N","=","X"):
               alignmentleng+=int(count)
     return start+alignmentleng-1

def load_5p_position(file,minmq,maxmm,minlen,maxlen,ref="PPI251_te"):
     """Count read 5' ends on one reference sequence, separately per strand.

     *file* is any iterable of SAM lines. Header lines (starting with '@') and
     blank lines are skipped. A record is counted only if all of these hold:
       - it aligns to *ref* (RNAME; defaults to the P-element "PPI251_te"),
       - it is mapped (flag 0x4 unset),
       - its MAPQ is >= *minmq*,
       - its NM tag is <= *maxmm* (a missing NM tag counts as 0),
       - its SEQ length lies within [*minlen*, *maxlen*].

     Forward reads are counted at POS. Reverse reads (flag 0x10) are counted at
     the last aligned reference position computed by get_end from POS and CIGAR.

     Returns (ps, pas): defaultdict(int) mapping position -> read count for
     forward and reverse reads respectively.
     """
     ps=collections.defaultdict(int)
     pas=collections.defaultdict(int)
     for line in file:
          line=line.rstrip("\r\n")
          # header and blank lines are not alignment records (and are too short to index)
          if not line or line.startswith("@"):
               continue
          a=line.split("\t")

          # discard alignments to other references
          if a[2]!=ref:
               continue

          # discard unmapped reads
          flag=int(a[1])
          if flag & 0x4:
               continue

          # discard low mapping quality
          if int(a[4])<minmq:
               continue

          # discard reads with too many mismatches. Optional tags start at column 11
          # and are tab-separated, so every column from 11 on is scanned for NM;
          # space-separated tags within one column are tolerated as well.
          mm=0
          for field in a[11:]:
               for tag in field.split(" "):
                    if tag.startswith("NM:i:"):
                         mm=int(tag[5:])
          if mm>maxmm:
               continue

          # discard reads outside the length window
          readlen=len(a[9])
          if readlen<minlen or readlen>maxlen:
               continue

          # record the 5' end: POS for forward reads, alignment end for reverse reads
          start=int(a[3])
          if flag & 0x10:
               pas[get_end(start,a[5])]+=1
          else:
               ps[start]+=1
     return ps,pas





def normalizecounts(sig):
     """Scale the entries of *sig* in place so they sum to 1, and return *sig*.

     If the total of *sig* is not strictly positive, nothing is changed.
     """
     total=float(sum(sig))
     if not total>0.0:
          # zero/negative (or NaN) total: leave the signal untouched
          return sig
     for idx in range(len(sig)):
          sig[idx]/=total
     return sig


     

parser = argparse.ArgumentParser(description="""           
Description
-----------
Summary statistics
""",formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
miRNA: 21-23nt
piRNA: 23-28nt


Authors
-------
    Robert Kofler
""")
parser.add_argument('--si-sam', type=argparse.FileType('r'), default=None,dest="sisam", required=True, help="A sam file for the siRNAs")
parser.add_argument('--pi-sam', type=argparse.FileType('r'), default=None,dest="pisam", required=True, help="A sam file for the piRNAs")

parser.add_argument("--min-mq", type=int, required=False, dest="minmq", default=1, help="min mapping quality")
parser.add_argument("--max-mm", type=int, required=False, dest="maxmm", default=2, help="max mismatches")
parser.add_argument("--mc-si", type=int, required=False, dest="mcsi", default=1, help="min count for sirna")
parser.add_argument("--mc-pi", type=int, required=False, dest="mcpi", default=5, help="min count for pirna")

# Read-length windows (inclusive) used to tell siRNAs and piRNAs apart;
# defaults reproduce the previously hard-coded 20-22 nt and 23-29 nt.
parser.add_argument("--min-len-si", type=int, required=False, dest="minlensi", default=20, help="min read length for sirna")
parser.add_argument("--max-len-si", type=int, required=False, dest="maxlensi", default=22, help="max read length for sirna")
parser.add_argument("--min-len-pi", type=int, required=False, dest="minlenpi", default=23, help="min read length for pirna")
parser.add_argument("--max-len-pi", type=int, required=False, dest="maxlenpi", default=29, help="max read length for pirna")

parser.add_argument("--sample-id", type=str, required=True, dest="sid", help="the sample id")


args = parser.parse_args()

# Collect per-strand 5' positions for each small-RNA class.
pi_se,pi_as = load_5p_position(args.pisam,args.minmq,args.maxmm,args.minlenpi,args.maxlenpi)
si_se,si_as = load_5p_position(args.sisam,args.minmq,args.maxmm,args.minlensi,args.maxlensi)
args.pisam.close()
args.sisam.close()

# Drop positions supported by fewer reads than the per-class minimum count.
pi_se,pi_as=filter_key(pi_se,args.mcpi),filter_key(pi_as,args.mcpi)
si_se,si_as=filter_key(si_se,args.mcsi),filter_key(si_as,args.mcsi)

sid=args.sid

# One tab-separated row per strand combination: sample id, siRNA label,
# piRNA label, then the compute_venn counts and fractions
# (siRNA-only / shared / piRNA-only positions).
comparisons=[
     ("si_se",si_se,"pi_se",pi_se),  # sense - sense
     ("si_as",si_as,"pi_as",pi_as),  # antisense - antisense
     ("si_se",si_se,"pi_as",pi_as),  # sense - antisense: piRNA side is antisense
     ("si_as",si_as,"pi_se",pi_se),  # antisense - sense
]
for si_label,si_pos,pi_label,pi_pos in comparisons:
     topr=[sid,si_label,pi_label]
     topr.extend(compute_venn(si_pos.keys(),pi_pos.keys()))
     print("\t".join([str(i) for i in topr]))



