import sys, time
from ast import literal_eval

def get_ge_info(inp_coverage):
    """Collect per-gene CDS bounds and TIS calls grouped by transcript region.

    Only lines starting with "Solyc" are parsed. Each such line is split on
    tabs:
      * column 0 is the gene id;
      * columns 1 and 2 are the integer CDS start/end;
      * column 4 is a Python literal (read with ``literal_eval``), a sequence
        of items whose second element is a possibly-empty list of TIS strings.

    Only the first string of each non-empty list is used. Its position is the
    integer before the first ":" and its codon is the text after the first "_".

    A TIS at position p is assigned to:
      * 5'UTR  if p < cds_start
      * aTIS   if p == cds_start
      * dTIS   if cds_start < p < cds_start + (cds_end - cds_start + 1) / 3
      * 3'UTR  if p > cds_end
    Any other position is reported on stdout and skipped.

    Args:
        inp_coverage: iterable of text lines (e.g. an open file).

    Returns:
        Tuple ``(ge_l, ge_5UTR, ge_aTIS, ge_dtis, ge_3UTR)``.
        ``ge_l`` maps gene id to ``[cds_start, cds_end]``. Each of the other
        four maps gene id to a list of "<pos>_<codon>" strings, with
        duplicates kept. A gene appears in a region dict only if it has at
        least one TIS in that region.
    """
    ge_5UTR = {}
    ge_aTIS = {}
    ge_dtis = {}
    ge_3UTR = {}
    ge_l = {}
    for line in inp_coverage:
        if not line.startswith("Solyc"):
            continue
        data = line.strip().split("\t")
        gene = data[0]
        cds_s = int(data[1])
        cds_t = int(data[2])
        tis_f = literal_eval(data[4])
        # A repeated gene id overwrites the earlier bounds (same as before).
        ge_l[gene] = [cds_s, cds_t]

        # Exclusive upper bound of the downstream-TIS window: first third of the CDS.
        cds_1_3 = cds_s + (cds_t - cds_s + 1) / 3.0
        for i in tis_f:
            if len(i[1]) == 0:
                continue
            codon_po = int(i[1][0].split(":")[0])
            codon_tis = i[1][0].split("_")[1]
            if codon_po < cds_s:
                target = ge_5UTR
            elif codon_po == cds_s:
                target = ge_aTIS
            elif cds_s < codon_po < cds_1_3:
                target = ge_dtis
            elif codon_po > cds_t:
                target = ge_3UTR
            else:
                # Single-argument print: same output under Python 2 and 3.
                print("error of finding TIS position on a transcript: %s %s"
                      % (i, line.strip()))
                continue
            # setdefault is an O(1) dict lookup; the previous `key not in d.keys()`
            # built a new list of every key on each TIS under Python 2.
            target.setdefault(gene, []).append(str(codon_po) + "_" + codon_tis)

    return ge_l, ge_5UTR, ge_aTIS, ge_dtis, ge_3UTR


def get_TIScodon(inp_coverage):
    """Count TIS start codons per transcript region across all genes.

    Lines are parsed exactly as in ``get_ge_info``:
      * only "Solyc" lines are used;
      * columns 1 and 2 hold the CDS start/end;
      * column 4 is a literal whose items carry a list of
        "<pos>:..._<codon>" strings, of which only the first is used.

    Within one line, each TIS position is counted at most once. A repeat
    prints "redundant TIS site:" with the line and is skipped.

    Region rules:
      * 5'UTR (p < cds_start)
      * aTIS (p == cds_start)
      * dTIS (cds_start < p < cds_start + (cds_end - cds_start + 1) / 3)
      * 3'UTR (p > cds_end)
    Any other position is reported on stdout and not counted.

    Args:
        inp_coverage: iterable of text lines (e.g. an open file).

    Returns:
        Tuple ``(codon_5UTR, codon_aTIS, codon_dtis, codon_3UTR)``. Each
        maps a codon string to the number of distinct TIS positions using it
        in that region. Codons never seen in a region are absent.
    """
    codon_5UTR = {}
    codon_aTIS = {}
    codon_dtis = {}
    codon_3UTR = {}

    for line in inp_coverage:
        if not line.startswith("Solyc"):
            continue
        data = line.strip().split("\t")
        cds_s = int(data[1])
        cds_t = int(data[2])
        tis_f = literal_eval(data[4])

        # Exclusive upper bound of the downstream-TIS window: first third of the CDS.
        cds_1_3 = cds_s + (cds_t - cds_s + 1) / 3.0

        seen_po = set()  # positions already handled for this line
        for i in tis_f:
            if len(i[1]) == 0:
                continue
            codon_tis = i[1][0].split("_")[1]
            codon_po = int(i[1][0].split(":")[0])
            if codon_po in seen_po:
                print("redundant TIS site: %s" % line.strip())
                continue
            # Marked as seen before classification, so a repeated
            # out-of-window position is reported as redundant, as before.
            seen_po.add(codon_po)

            if codon_po < cds_s:
                target = codon_5UTR
            elif codon_po == cds_s:
                target = codon_aTIS
            elif cds_s < codon_po < cds_1_3:
                target = codon_dtis
            elif codon_po > cds_t:
                target = codon_3UTR
            else:
                print("error of finding TIS position on a transcript: %s %s"
                      % (i, line.strip()))
                continue
            target[codon_tis] = target.get(codon_tis, 0) + 1

    return codon_5UTR, codon_aTIS, codon_dtis, codon_3UTR


def outwrite_codon(out, list_co, UTR5, aTIS, dtis, UTR3, near):
    """Write one tab-separated row of per-region TIS counts per codon, then close *out*.

    Each row is: codon label, then the counts from UTR5, aTIS, dtis and UTR3.
    A codon missing from a count dict is written as 0. The codon is suffixed
    with "*" when it is in *near*.

    Args:
        out: writable text file-like object; it is closed on return.
        list_co: iterable of codon strings, written in this order.
        UTR5, aTIS, dtis, UTR3: mappings codon -> count. They are read only
            and no longer padded with zero entries.
        near: container of codons to label with "*".
    """
    for codon in list_co:
        label = codon + "*" if codon in near else codon
        out.write("%s\t%s\t%s\t%s\t%s\n" % (label,
                                            UTR5.get(codon, 0),
                                            aTIS.get(codon, 0),
                                            dtis.get(codon, 0),
                                            UTR3.get(codon, 0)))
    # Kept for backward compatibility: callers may rely on this closing *out*.
    # Under a `with` block this is a harmless double close.
    out.close()

def outwrite_ge(out, list_co, UTR5, aTIS, dtis, UTR3):
    """Write one tab-separated row per gene with its TIS sites per region, then close *out*.

    Row layout (11 columns):
      * gene id;
      * list_co[gene][0] and list_co[gene][1];
      * the de-duplicated TIS lists for 5'UTR, aTIS, dTIS and 3'UTR;
      * the lengths of those four de-duplicated lists.

    A gene missing from a region dict is written as an empty list and a
    count of 0.

    Args:
        out: writable text file-like object; it is closed on return.
        list_co: mapping gene id -> indexable pair. The script passes
            get_ge_info's ``[cds_start, cds_end]`` here.
        UTR5, aTIS, dtis, UTR3: mappings gene id -> list of TIS strings.
            They are read only and no longer padded with empty lists.
    """
    for gene in list_co:
        # De-duplicate once per region. Element order follows set iteration
        # order, as before.
        regions = [list(set(region.get(gene, [])))
                   for region in (UTR5, aTIS, dtis, UTR3)]
        fields = ([gene, list_co[gene][0], list_co[gene][1]]
                  + regions
                  + [len(sites) for sites in regions])
        out.write("%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" % tuple(fields))
    # Kept for backward compatibility: callers may rely on this closing *out*.
    out.close()



# All 64 DNA triplets, in the same A/T/C/G nested order as three nested loops.
atcg = ["A", "T", "C", "G"]
all_co = [i_n + ii_n + iii_n for i_n in atcg for ii_n in atcg for iii_n in atcg]

# Codons obtained by varying one position of ATG: xTG, AxG and ATx.
# NOTE(review): x = A, T and G each regenerate ATG itself, so ATG is in
# near_co (three times) and gets starred in the codon table. Confirm that is
# intended, since ATG is the cognate start codon.
near_co = []
for i_n in atcg:
    near_co.extend((i_n + "TG", "A" + i_n + "G", "AT" + i_n))

# Fail with a usage message instead of an IndexError traceback.
if len(sys.argv) < 2:
    sys.exit("usage: python %s <coverage_file>" % sys.argv[0])

# Each parser consumes the file handle it is given, so the input is opened once per pass.
with open(sys.argv[1], "r") as inp_coverage:
    cod_5UTR, cod_aTIS, cod_dtis, cod_3UTR = get_TIScodon(inp_coverage)

with open(sys.argv[1], "r") as inp_coverage:
    ge_info, g_5UTR, g_aTIS, g_dtis, g_3UTR = get_ge_info(inp_coverage)

# Codon table: one row per codon; codons in near_co are suffixed with "*".
with open(sys.argv[1] + ".TIS_codon_CDS1.3", "w") as out1:
    out1.write("#python %s\n" % (" ".join(sys.argv)))
    out1.write("#%s\t%s\t%s\t%s\t%s\n" % ("codon", "TIS_number_5UTR", "TIS_number_aTIS",
                                          "TIS_number_dTIS", "TIS_number_3UTR"))
    outwrite_codon(out1, all_co, cod_5UTR, cod_aTIS, cod_dtis, cod_3UTR, near_co)

# Gene table: one row per gene with its TIS sites and counts per region.
with open(sys.argv[1] + ".TIS_gene_CDS1.3", "w") as out2:
    out2.write("#python %s\n" % (" ".join(sys.argv)))
    out2.write("#%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" % (
        "GENE_id", "cds_start", "cds_end", "TIS_5UTR", "TIS_aTIS", "TIS_dTIS", "TIS_3UTR",
        "TIS_number_5UTR", "TIS_number_aTIS", "TIS_number_dTIS", "TIS_number_3UTR"))
    outwrite_ge(out2, ge_info, g_5UTR, g_aTIS, g_dtis, g_3UTR)






