#!/usr/bin/env python
import argparse
import bisect
import collections
import math
import os
import re
import sys



class Gene:
    """All GTF records that share one gene id, plus the span they cover.

    Attributes:
        geneid: the gene identifier shared by the records.
        gtfs: the records themselves (objects exposing ``feat``, ``start``
            and ``end``), stored as passed in.
        start: smallest ``start`` over ``gtfs`` (inclusive).
        end: largest ``end`` over ``gtfs`` (inclusive).
    """

    # Codes written by get_code(), applied in this order so that where
    # features overlap the later (higher) code wins:
    # CDS > five_prime_UTR > three_prime_UTR > exon > plain gene span (1).
    _FEATURE_CODES = (
        ("exon", 2),
        ("three_prime_UTR", 3),
        ("five_prime_UTR", 4),
        ("CDS", 5),
    )

    def __init__(self,geneid,gtfs):
        self.geneid=geneid
        self.gtfs=gtfs
        if gtfs:
            # min/max rather than a hard-coded 1e9 start sentinel, so genes
            # with coordinates above 1,000,000,000 get the correct start.
            self.start=min(g.start for g in gtfs)
            self.end=max(g.end for g in gtfs)
        else:
            # Keep the historical empty span (start > end) so that
            # get_code() returns an empty dict for a gene without records.
            self.start=1000000000
            self.end=0

    def get_feat(self,feat):
        """Return the records whose ``feat`` equals ``feat``, in stored order."""
        return [g for g in self.gtfs if g.feat==feat]

    def get_code(self):
        """Return ``{position: code}`` for every base in ``[start, end]``.

        The codes are:
        1 = inside the gene span but not covered by any feature listed below;
        2 = exon; 3 = three_prime_UTR; 4 = five_prime_UTR; 5 = CDS.
        A base covered by several of these features gets the highest code.
        """
        tr=dict.fromkeys(range(self.start,self.end+1),1)
        for feat,code in self._FEATURE_CODES:
            for t in self.get_feat(feat):
                for i in range(t.start,t.end+1):
                    tr[i]=code
        return tr
              
        
        



def get_genelist(gtfes):
    """Group records by their ``geneid`` attribute and build one Gene per id.

    Args:
        gtfes: iterable of records exposing ``geneid``.

    Returns:
        A list with one Gene per distinct gene id; each Gene receives the
        records for that id in input order.
    """
    by_gene=collections.defaultdict(list)
    for entry in gtfes:
        by_gene[entry.geneid].append(entry)
    return [Gene(gene_id,members) for gene_id,members in by_gene.items()]


class TEentry:
    """One TE insertion row; load_te() builds these from the TE table.

    All values are stored exactly as passed in. load_te() supplies ``pos``
    as an int and the other fields as the raw column strings.
    """
    def __init__(self,sid,gen,chr,pos,popfreq):
        self.sid,self.gen=sid,gen
        self.chr,self.pos,self.popfreq=chr,pos,popfreq

class GTFentry:
    """One GTF line reduced to the fields used here; built by load_gtf().

    Holds the chromosome, feature type, start/end coordinates and gene id,
    stored exactly as passed in.
    """

    def __init__(self,chr,feat,start,end,geneid):
        self.chr,self.feat=chr,feat
        self.start,self.end=start,end
        self.geneid=geneid
    
def load_te(file,euchr):
    """Read TE insertions from a tab-separated file, grouped by chromosome.

    Example rows (tab-separated)::

        1	10	X	6709950	0.072
        1	10	X	16577419	0.093
        1	10	X	18325999	0.069

    Columns 1-5 become ``sid``, ``gen``, ``chr``, ``pos`` and ``popfreq`` of
    a TEentry. ``pos`` is parsed through float (so ``6709950.0`` is accepted,
    and any fraction is truncated). ``popfreq`` is kept as the raw string.

    Args:
        file: path of the TE table.
        euchr: collection of chromosome names to keep; rows on any other
            chromosome are skipped.

    Returns:
        defaultdict mapping chromosome name -> list of TEentry in file order.
    """
    teh=collections.defaultdict(list)
    # `with` guarantees the handle is closed once the file is read.
    with open(file) as handle:
        for line in handle:
            # Blank lines (e.g. a trailing empty line) have no columns to read.
            if not line.strip():
                continue
            # Strip "\r" too, so CRLF files do not leave it glued to popfreq.
            a=line.rstrip("\r\n").split("\t")
            chr=a[2]
            if chr not in euchr:
                continue
            t=TEentry(a[0],a[1],chr,int(float(a[3])),a[4])
            teh[chr].append(t)
    return teh
    
def load_gtf(file,euchr):
    """Read GTF records from ``file``, grouped by chromosome.

    Example input lines (tab-separated)::

        3L	Cufflinks	exon	13569022	13571086	566	-	.	gene_id "FBgn0026376"; transcript_id "CUFF.6507.2"; exon_number "1"; parent_feature "mRNA";
        3L	Cufflinks	exon	13586900	13587252	566	-	.	gene_id "FBgn0026376"; transcript_id "CUFF.6507.2"; exon_number "7"; parent_feature "mRNA";

    Only column 1 (chromosome), column 3 (feature), columns 4-5 (start, end)
    and the ``gene_id`` attribute of column 9 are kept. Lines whose first
    field is not in ``euchr`` are skipped. In practice this also drops blank
    lines and ``#`` comment lines, whose first field is not a chromosome name.

    Args:
        file: path of the GTF file.
        euchr: collection of chromosome names to keep.

    Returns:
        defaultdict mapping chromosome name -> list of GTFentry in file order.
    """
    # Locate gene_id wherever it sits in the attribute column, tolerating
    # arbitrary whitespace before the value and before the ';'.
    gene_id_re=re.compile(r'(?:^|;)\s*gene_id\s+"([^"]*)"')
    gtfh=collections.defaultdict(list)
    with open(file) as handle:
        for line in handle:
            a=line.rstrip("\r\n").split("\t")
            chr=a[0]
            if chr not in euchr:
                continue
            feat=a[2]
            start=int(a[3])
            end=int(a[4])
            attrs=a[8]
            m=gene_id_re.search(attrs)
            if m is not None:
                geneid=m.group(1)
            else:
                # No quoted gene_id attribute found: fall back to the
                # positional parse (second space-separated token with its
                # leading quote and trailing `";` removed).
                geneid=attrs.split(" ")[1][1:-2]
            gtfh[chr].append(GTFentry(chr,feat,start,end,geneid))
    return gtfh
        
    


# Integer codes produced by Gene.get_code(), mapped to the label printed in
# the last output column. Positions outside every gene are printed as "inter".
trans={1:"intron",2:"exon",3:"3p",4:"5p",5:"cds"}

# Chromosomes that are analysed; load_gtf()/load_te() skip everything else.
euchr=set(["X","2L","2R","3L","3R","4"])


def _build_parser():
    """Return the command-line parser; --gtf and --te are both required."""
    parser = argparse.ArgumentParser(description="""
Description
-----------
Annotate each TE insertion with the genomic feature it falls into
(cds, 5p, 3p, exon, intron, or inter when outside every gene).
""",formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog="""
""")
    parser.add_argument("--gtf", type=str, required=True, dest="gtf", help="gene annotation (GTF)" )
    parser.add_argument("--te", type=str, required=True, dest="te", help="TE insertions (tab-separated: sid, gen, chr, pos, popfreq)" )
    return parser


def _codes_at_positions(genelist, positions):
    """Return ``{position: code}`` for those ``positions`` that lie inside a gene.

    The result matches merging every gene's get_code() map (keeping the
    highest code per base) and then looking up ``positions``. The difference
    is that only genes containing at least one position are expanded, and
    only the queried positions are stored, instead of a per-base map of the
    whole chromosome.
    """
    ordered=sorted(set(positions))
    codes={}
    for g in genelist:
        lo=bisect.bisect_left(ordered,g.start)
        hi=bisect.bisect_right(ordered,g.end)
        if lo>=hi:
            continue  # no queried position inside this gene's span
        genecode=g.get_code()
        for pos in ordered[lo:hi]:
            code=genecode[pos]
            if code>codes.get(pos,0):
                codes[pos]=code
    return codes


def main():
    """Print one tab-separated line per TE on the ``euchr`` chromosomes.

    Each line holds the TE's five input fields followed by its feature label.
    """
    args=_build_parser().parse_args()
    gtfchrh=load_gtf(args.gtf,euchr)
    techrh=load_te(args.te,euchr)
    # sorted() makes the chromosome order of the output reproducible
    # (string hashing, and hence set order, is randomised on Python 3).
    for chrom in sorted(euchr):
        genelist=get_genelist(gtfchrh[chrom])
        telist=techrh[chrom]
        codes=_codes_at_positions(genelist,[te.pos for te in telist])
        for te in telist:
            code=trans[codes[te.pos]] if te.pos in codes else "inter"
            #sid,gen,chr,pos,popfreq,feature
            topr=[te.sid,te.gen,te.chr,str(te.pos),te.popfreq,code]
            print("\t".join(topr))


if __name__ == "__main__":
    main()
            
            
        
    
    
    
    
    
    
    



