import sys, time

def coverage_gene(inp_coverage, gene_list):
    """Load the coverage column for the genes of interest.

    Parameters
    ----------
    inp_coverage : iterable of str
        Lines of a tab-separated table. Lines starting with "#" are skipped.
        Column 0 is the gene id; column 6 is a Python expression that is
        evaluated with ``eval`` (downstream it is sliced and summed, so it is
        presumably a list of per-position read counts -- confirm).
    gene_list : iterable of str
        Gene ids to keep; rows for any other gene are ignored.

    Returns
    -------
    dict
        Gene id -> evaluated column 6. If a gene occurs on several rows the
        last row wins.
    """
    # Build the lookup set once: on Python 2 the caller passes dict.keys(),
    # a list, which made every membership test a linear scan.
    wanted = set(gene_list)
    coverage_read = {}
    for line in inp_coverage:
        if line.startswith("#"):
            continue
        fields = line.strip("\n").split("\t")
        if fields[0] in wanted:
            # SECURITY NOTE(review): eval() executes arbitrary code taken from
            # the input file. Only run on trusted files; if column 6 is always
            # a plain literal, ast.literal_eval would be a safer replacement.
            coverage_read[fields[0]] = eval(fields[6])
    return coverage_read

def frame_po(ge_i, f_i, seq, f_frame_ge, iid):
    """Record the reading frame starting at ``f_i`` up to its first in-frame stop.

    Starting with the codon right after the one at ``f_i`` (index ``f_i + 3``),
    ``seq`` is scanned codon by codon until "TAG", "TAA" or "TGA" is found
    (case-sensitive). The record ``[ge_i, f_i, end, seq[f_i:end + 1]]`` is then
    stored in ``f_frame_ge[iid][f_i]``, where ``end`` is the 0-based index of
    the last base of the stop codon.

    If no stop codon is found, "NOT FOUND STOP CODON <iid> <f_i>" is written to
    stdout and ``end`` is the last base of the final (possibly partial) codon
    examined, which can lie beyond ``len(seq) - 1``. If there is no codon at
    all after ``f_i``, the record covers just the codon at ``f_i``
    (``end = f_i + 2``).

    Parameters
    ----------
    ge_i : object
        Gene identifier; stored as element 0 of the record.
    f_i : int
        0-based start index of the frame in ``seq``.
    seq : str
        Nucleotide sequence.
    f_frame_ge : dict
        Mapping whose entry ``f_frame_ge[iid]`` must already be a dict
        (otherwise KeyError); it is updated in place.
    iid : hashable
        Key of the TIS entry in ``f_frame_ge``.

    Returns
    -------
    dict
        ``f_frame_ge`` itself (mutated in place).
    """
    stop_codons = ("TAG", "TAA", "TGA")
    # Default used when the scan loop never runs (f_i + 3 >= len(seq)).
    # Previously this case raised UnboundLocalError because the loop
    # variables were never assigned.
    end = f_i + 2
    for i_r in range(f_i + 3, len(seq), 3):
        end = i_r + 2
        if seq[i_r:i_r + 3] in stop_codons:
            break
    else:
        # Loop finished without hitting a stop codon.
        sys.stdout.write("NOT FOUND STOP CODON %s %s\n" % (iid, f_i))
    f_frame_ge[iid][f_i] = [ge_i, f_i, end, seq[f_i:end + 1]]
    return f_frame_ge


def get_TIS_codon(cDNA_seq, ltm_po, cds_po, f_frame):
    """Build three candidate reading frames around every identified TIS.

    For each gene present in both ``cDNA_seq`` and ``ltm_po``, and for each
    TIS position ``i`` in ``ltm_po[gene]``, ``frame_po`` is called with frame
    starts ``i - 1``, then ``i - 2`` (or ``i + 1`` when ``i <= 1``, so the start
    index is never negative), then ``i``. ``i - 1`` is used as the 0-based
    start, so TIS positions appear to be 1-based (confirm against the input).
    Results land in ``f_frame["<gene>_<i>"][<0-based frame start>]``.

    Every TIS entry of a gene is reset to ``{}`` before any frame of that gene
    is computed; an already-present key is reported on stdout as
    "duplicated in identified TIS: <key>".

    ``cds_po`` is accepted for interface compatibility but is not used.

    Returns ``f_frame`` (mutated in place by ``frame_po``).
    """
    for ge in cDNA_seq:
        # Hash lookup on the dict; `in ltm_po.keys()` built and scanned a
        # list on Python 2 for every gene.
        if ge not in ltm_po:
            continue
        ge_seq = cDNA_seq[ge]
        ltn_tis = ltm_po[ge]

        # First pass: reset entries (and report duplicates) before computing,
        # so duplicate messages precede any frame_po output for this gene.
        for i in ltn_tis:
            id_TIS = ge + "_" + str(i)
            if id_TIS in f_frame:
                sys.stdout.write("duplicated in identified TIS: %s\n" % id_TIS)
            f_frame[id_TIS] = {}

        # Second pass: the three frames for each TIS.
        for i in ltn_tis:
            id_TIS = ge + "_" + str(i)
            f_frame = frame_po(ge, i - 1, ge_seq, f_frame, id_TIS)
            if i > 1:
                f_frame = frame_po(ge, i - 2, ge_seq, f_frame, id_TIS)
            else:
                f_frame = frame_po(ge, i + 1, ge_seq, f_frame, id_TIS)
            f_frame = frame_po(ge, i, ge_seq, f_frame, id_TIS)

    return f_frame

def get_frame_index(CHX_gene, f_frame, cDNA_seq):
    """Append a frame-0 coverage fraction and start codon to every candidate frame.

    For each TIS key ``"<gene>_<pos>"`` in ``f_frame`` and each frame record
    ``[gene, start, end, orf_seq]`` stored under it:

    * the coverage slice ``CHX_gene[gene][start:stop]`` is taken, with
      ``stop = end - 2`` when the frame spans at most 9 nt and ``end - 5``
      otherwise (assuming ``end`` is the last base of the stop codon, this
      drops the stop codon, and for longer frames also the last sense codon);
    * the slice is summed per codon position (offsets 0, 1, 2 modulo 3);
    * ``"<frac>_<codon>"`` is appended to the record, where ``frac`` is the
      position-0 share of the total (``"0"`` if the total is zero) and
      ``codon`` is ``cDNA_seq[gene][start:start + 3]``.

    Each record is echoed to stdout before the append, and a warning line is
    written when the coverage slice length is not a multiple of 3.

    Raises KeyError if a gene has no entry in ``CHX_gene`` or ``cDNA_seq``.
    Returns ``f_frame`` (records mutated in place).
    """
    for ge_TIS in f_frame:
        # Keys are built as gene + "_" + str(pos): split on the LAST "_" so
        # gene ids that themselves contain "_" (e.g. "NM_000014") are not
        # truncated to a different or non-existent gene.
        ge_id = ge_TIS.rsplit("_", 1)[0]
        c_signal = CHX_gene[ge_id]
        seq_c = cDNA_seq[ge_id]
        for i in f_frame[ge_TIS]:
            da = f_frame[ge_TIS][i]
            sys.stdout.write("%s\n" % (da,))
            start, end = da[1], da[2]
            if end - start + 1 <= 9:
                CDS_COVseq = c_signal[start:end - 2]
            else:
                CDS_COVseq = c_signal[start:end - 5]
            Frame0_RC = sum(CDS_COVseq[::3])
            Frame1_RC = sum(CDS_COVseq[1::3])
            Frame2_RC = sum(CDS_COVseq[2::3])
            if len(CDS_COVseq) % 3 != 0:
                sys.stdout.write(
                    "===============gene annotation in CDS seems error %s\n" % (da,))
            total_CDS_RC = Frame0_RC + Frame1_RC + Frame2_RC
            codon = seq_c[start:start + 3]
            if total_CDS_RC == 0:
                da.append(str(0) + "_" + codon)
            else:
                da.append(str(Frame0_RC / float(total_CDS_RC)) + "_" + codon)
    return f_frame

def output_TIS_frame(f_frame_in, ltm_po, cds_po, out1, near):
    """Write one line per gene with the best-supported frame for each TIS.

    For gene ``g`` and each TIS position ``p`` in ``ltm_po[g]`` (ascending,
    converted with ``int``), every candidate frame in ``f_frame_in["g_p"]`` is
    scored from its 5th element ``"<frac>_<codon>"``:

    * 1.2 if ``codon`` is "ATG",
    * 1.1 if ``codon`` is in ``near``,
    * otherwise ``float(frac)``.

    The top-scoring frame is reported as ``"<start + 1>:<frac>_<codon>"`` if its
    score, rounded to 2 decimals, is at least 0.5. Ties go to the smallest
    frame start. A TIS with no candidate frames gets an empty list.

    Each line written to ``out1`` is tab-separated: gene, CDS start, CDS end
    (``cds_po[g]`` split on "_"), ``ltm_po[g]`` as-is, and the list of
    ``(position, [best_frame])`` pairs sorted by position.
    """
    for gene in ltm_po:
        fram_d = {}
        for pos in sorted(map(int, ltm_po[gene])):
            id_i = gene + "_" + str(pos)
            fram_d[pos] = []
            # Missing TIS (e.g. gene absent from the sequence file) yields no
            # candidates instead of a KeyError.
            frames = f_frame_in.get(id_i, {})
            scores = []
            for start in sorted(map(int, frames)):
                fr = frames[start][4].split("_")
                if fr[1] == "ATG":
                    scores.append((start, 1.2))
                elif fr[1] in near:
                    scores.append((start, 1.1))
                else:
                    scores.append((start, float(fr[0])))
            if not scores:
                continue
            # `scores` is ordered by start, so max() keeps the smallest start
            # among equal scores (deterministic, unlike dict iteration order).
            best_start, best_score = max(scores, key=lambda kv: kv[1])
            if round(best_score, 2) >= 0.5:
                fram_d[pos].append(str(best_start + 1) + ":" + str(frames[best_start][4]))

        cds_fields = cds_po[gene].split("_")
        out1.write("%s\t%s\t%s\t%s\t%s\n" % (
            gene,
            cds_fields[0],
            cds_fields[1],
            ltm_po[gene],
            sorted(fram_d.items(), key=lambda t: t[0]),
        ))


# Load the LTM-TIS table (argv[1]; the handle name is historical): column 0
# gene id, columns 1-2 CDS start/end, column 3 an expression listing the TIS
# positions. Comment lines, blank lines and rows with an "NA" value are skipped.
with open(sys.argv[1], "r") as inp_coverage:
    sys.stdout.write("%s Loading LTM-TIS...\n" % time.strftime("%H:%M:%S"))
    cds_po = {}
    ltm_po = {}
    ltm_ge = 0  # number of records loaded
    for l1 in inp_coverage:
        if l1.startswith("#") or l1 == "\n":
            continue
        i = l1.strip("\n").split("\t")
        # Skip rows with an "NA" value in a data column. The gene id (column 0)
        # is excluded: the previous whole-line substring test also discarded
        # genes whose ids merely contain "NA" (e.g. "DNAJB1", "SNAP25").
        if any("NA" in field for field in i[1:]):
            continue
        ltm_ge += 1
        if i[0] in ltm_po:
            sys.stdout.write("ERROR in LTM...............................\n")
        # SECURITY NOTE(review): eval() executes arbitrary code from the input
        # file; only run on trusted input (ast.literal_eval would be safer if
        # column 3 is always a plain literal).
        ltm_po[i[0]] = eval(i[3])
        cds_po[i[0]] = i[1] + "_" + i[2]


# Load sequences (argv[2]): tab-separated "<gene id>\t<sequence>" lines; "#"
# lines and blank lines are skipped, and a later line for a gene overwrites
# an earlier one.
with open(sys.argv[2], "r") as in2:
    cDNA_seq = {}
    records = (row.strip().split("\t") for row in in2
               if not row.startswith("#") and row != "\n")
    for rec in records:
        cDNA_seq[rec[0]] = rec[1]

# Load coverage (argv[3]) only for genes that have LTM-identified TIS.
# Pass the dict itself: membership tests on it are hash lookups, whereas
# ltm_po.keys() is a list on Python 2 and was scanned once per input line.
with open(sys.argv[3], "r") as inp_coverage:
    CHX_gene = coverage_gene(inp_coverage, ltm_po)

atcg = ["A", "T", "C", "G"]
# Codons differing from ATG at no more than one position: for each base, the
# variant with that base at position 1, then 2, then 3. ATG itself occurs
# three times (once per position, when the substituted base is the original).
near_co = [
    codon
    for base in atcg
    for codon in (base + "TG", "A" + base + "G", "AT" + base)
]


# Candidate frames per TIS -> frame-0 coverage index -> filtered report.
f_frame = {}
f_frame = get_TIS_codon(cDNA_seq, ltm_po, cds_po, f_frame)
f_frame_index = get_frame_index(CHX_gene, f_frame, cDNA_seq)

sys.stdout.write("%s outwrite TIS_frame index ...\n" % time.strftime("%H:%M:%S"))
# `with` guarantees the report is flushed and closed even if writing raises.
with open(sys.argv[1] + ".TIS_ATG_near_frame.filtered", "w") as out1:
    out1.write("#python %s\n" % (" ".join(sys.argv)))
    out1.write(("%s\t%s\t%s\t%s\t%s\n") % ("Gene", "CDS_start", "CDS_end", "TIS", "TIS_frame_index"))
    output_TIS_frame(f_frame_index, ltm_po, cds_po, out1, near_co)





