import sys
import time
import numpy as np

# ---- Command-line arguments ----
# argv[1]: coverage table (opened further down in this script)
# argv[2]: minimum UTR length a gene needs to enter the 5'/3' UTR gene lists
# argv[3]: minimum CDS read count a gene needs to be analysed
# argv[4]: normalization method, one of VALID_NORM_METHODS
VALID_NORM_METHODS = ('RawReadCount', 'NormToTotal', 'NormToGeneCodonAVG')

if len(sys.argv) < 5:
    # Fail with a usage message instead of a bare IndexError traceback.
    sys.stderr.write(
        "Usage: python %s <coverage_file> <min_UTRlen> <RC_criteria> <%s>\n"
        % (sys.argv[0], '|'.join(VALID_NORM_METHODS)))
    sys.exit(1)

min_UTRlen = int(sys.argv[2])
RC_criteria = int(sys.argv[3])
NormMethod = sys.argv[4]
if NormMethod not in VALID_NORM_METHODS:
    # Report on stderr and exit non-zero so calling pipelines notice the error
    # (a bare sys.exit() would report success with status 0).
    sys.stderr.write(
        "Error: normalization method is not specified (expected one of: %s)\n"
        % ', '.join(VALID_NORM_METHODS))
    sys.exit(1)
    
    



# ---- Load per-transcript coverage from the input table (sys.argv[1]) ----
# Lines starting with '[' or '#' are treated as non-data and skipped.
# Tab-separated columns read from each remaining line:
#   [0] transcript ID
#   [1] used together with [4] to derive the 3'UTR length (presumably the
#       transcript length -- TODO confirm against the tool producing the file)
#   [3] 1-based position of the first CDS nucleotide (5'UTR length = [3] - 1)
#   [4] position of the CDS end; the CDS slice below drops 3 more nt so that
#       the stop codon is not counted
#   [6] Python literal holding the per-nucleotide coverage sequence

# ID -> [CDS start, CDS end, coverage tuple, CDS density (or "NA")]
dict_gene_coverage = {}
genelist_5utr = []   # stored genes whose 5'UTR is at least min_UTRlen long
genelist_3utr = []   # stored genes whose 3'UTR is at least min_UTRlen long
total_RC = 0

# Read counts of all non-excluded transcripts, split by position modulo 3
# within the CDS.
Frame0_RC = 0
Frame1_RC = 0
Frame2_RC = 0

# Transcripts that are always left out. Built once as a frozenset so the
# per-line membership test does not rebuild and scan a list.
EXCLUDED_TRANSCRIPTS = frozenset([
    "Solyc11g022605.1.1", "Solyc12g043160.2.1", "Solyc07g009330.3.1",
    "Solyc03g095550.2.1", "Solyc06g043000.2.1", "Solyc09g018440.2.1",
    "Solyc03g115190.3.1",
])

with open(sys.argv[1], 'r') as inp_coverage:
    for line in inp_coverage:
        # Skip header/comment lines, and blank lines (which would otherwise
        # fail with an IndexError when the columns are accessed).
        if line.startswith(('[', '#')) or not line.strip():
            continue
        fields = line.strip("\n").split("\t")
        gene_id = fields[0]
        if gene_id in EXCLUDED_TRANSCRIPTS:
            continue

        fUTRlen = int(fields[3]) - 1
        tUTRlen = int(fields[1]) - int(fields[4])
        # NOTE(security): eval() executes arbitrary code found in the input
        # file. Only run this script on trusted coverage files; if column 6
        # is always a plain list literal, ast.literal_eval would be a safer
        # replacement.
        COVseq = eval(fields[6])

        CDS_COVseq = COVseq[fUTRlen:-1 * (tUTRlen + 3)]
        CDSReadCount = sum(CDS_COVseq)  # not include stop codon
        cds_len = len(COVseq) - fUTRlen - (tUTRlen + 3)
        # "NA" marks a zero-length CDS, for which no density can be computed.
        if cds_len:
            CDS_density = CDSReadCount / float(cds_len)
        else:
            CDS_density = "NA"

        total_RC += sum(COVseq)
        if CDSReadCount >= RC_criteria:
            dict_gene_coverage[gene_id] = [int(fields[3]), int(fields[4]),
                                           tuple(COVseq), CDS_density]
            if fUTRlen >= min_UTRlen:
                genelist_5utr.append(gene_id)
            if tUTRlen >= min_UTRlen:
                genelist_3utr.append(gene_id)

        Frame0_RC += sum(CDS_COVseq[::3])
        Frame1_RC += sum(CDS_COVseq[1::3])
        Frame2_RC += sum(CDS_COVseq[2::3])
        if len(CDS_COVseq) % 3 != 0:
            # A CDS whose length is not a multiple of 3 hints at a bad
            # annotation; report it but keep going.
            print("===============gene annotation in CDS seems error %s" % gene_id)

total_CDS_RC = float(Frame0_RC + Frame1_RC + Frame2_RC)

# From here on total_RC is expressed in millions of reads.
total_RC = total_RC / 1000000.0


# ---- Collect coverage by distance to the start (TIS) and stop (TTS) codon ----
# Both tables map a distance in nucleotides (-300..300) to the list of
# (optionally normalized) coverage values observed at that distance, one value
# per gene that has a nucleotide there.
_MAX_DIST = 300

dict_ATG_distance_number = dict((d, []) for d in range(-_MAX_DIST, _MAX_DIST + 1))
dict_STOP_distance_number = dict((d, []) for d in range(-_MAX_DIST, _MAX_DIST + 1))


def _norm_divisor(mean_cov):
    """Return the value each coverage count is divided by, or None for raw counts.

    mean_cov is the gene's stored CDS density; it is only used for the
    'NormToGeneCodonAVG' method.
    """
    if NormMethod == 'RawReadCount':
        return None
    if NormMethod == 'NormToTotal':
        return float(total_RC)
    return mean_cov  # 'NormToGeneCodonAVG'


def _accumulate_by_distance(gene_ids, anchor_field, anchor_shift, table, report_high):
    """Append each gene's coverage around an anchor position to `table`.

    gene_ids     -- IDs to process (keys of dict_gene_coverage)
    anchor_field -- index into the gene record holding the 1-based anchor
                    position (0: CDS start, 1: CDS end)
    anchor_shift -- offset added to that position to get distance 0
    table        -- dict distance -> list, filled in place
    report_high  -- when true, print positions whose value exceeds 5000
    """
    for gene in gene_ids:
        record = dict_gene_coverage[gene]
        cov_seq = record[2]
        divisor = _norm_divisor(record[3])
        if divisor is not None and (divisor == "NA" or divisor == 0):
            # Dividing by "NA" or 0 would abort the whole run; skip the gene
            # and say so instead.
            sys.stderr.write(
                "Warning: skipping %s, cannot normalize by %s\n" % (gene, divisor))
            continue

        anchor = record[anchor_field] + anchor_shift
        # Index s holds 1-based position s + 1, i.e. distance (s + 1) - anchor.
        # Only indices whose distance falls inside the table are visited.
        first = max(0, anchor - _MAX_DIST - 1)
        last = min(len(cov_seq), anchor + _MAX_DIST)
        for s in range(first, last):
            if divisor is None:
                cov = int(cov_seq[s])
            else:
                cov = int(cov_seq[s]) / divisor
            dist = (s + 1) - anchor
            table[dist].append(cov)
            if report_high and cov > 5000:
                print ("high values_TIS_gene_position:", gene, dist, cov)


# Distance 0 for the TIS table is the first CDS nucleotide; for the TTS table
# it is two nucleotides before the stored CDS end.
_accumulate_by_distance(genelist_5utr, 0, 0, dict_ATG_distance_number, True)
_accumulate_by_distance(genelist_3utr, 1, -2, dict_STOP_distance_number, False)

# ---- Write per-distance summary statistics for both tables ----
_DISTANCE_TABLE_HEADER = 'Distance\tReads_accumulation\tGene_number\tNon-zero_number\tAverage\tStdev\tCV\tmaximum\tPercentile99\tPercentile90\tQ3\tMedian\tQ1\tMinimum\n'


def _distance_row(distance, values):
    """Return one tab-separated summary line for the values seen at `distance`.

    When `values` is empty the statistics columns are written as "NA";
    max()/min() of an empty list would otherwise raise ValueError and abort
    the script with a half-written file.
    """
    if not values:
        return "%s\t%s\t0\t0" % (distance, round(np.sum(values), 4)) + "\tNA" * 10 + "\n"

    mean = np.mean(values)
    std = np.std(values)
    # A distance where every value is 0 yields 0/0 = nan for the CV; that is
    # the value written out, so only the numpy warning is silenced here.
    with np.errstate(divide='ignore', invalid='ignore'):
        cv = (std / mean) * 100

    columns = [
        distance,
        round(np.sum(values), 4),
        len(values),
        len([x for x in values if x > 0]),
        round(mean, 4),
        round(std, 4),
        round(cv, 4),
        round(max(values), 4),
    ]
    # Percentile99, Percentile90, Q3, Median, Q1 -- in header order.
    columns.extend(round(np.percentile(values, q), 4) for q in (99, 90, 75, 50, 25))
    columns.append(round(min(values), 4))
    return "\t".join("%s" % c for c in columns) + "\n"


def _write_distance_table(path, table):
    """Write the command line, the header and one row per distance to `path`."""
    with open(path, 'w') as out:
        out.write('#python %s\n' % ' '.join(sys.argv))
        out.write(_DISTANCE_TABLE_HEADER)
        for distance in sorted(table):
            out.write(_distance_row(distance, table[distance]))


# Output files are named after the input file and the normalization method.
_write_distance_table(sys.argv[1] + '.' + NormMethod + '_COV_toTIS_MJ.txt',
                      dict_ATG_distance_number)
_write_distance_table(sys.argv[1] + '.' + NormMethod + '_COV_toTTS_MJ.txt',
                      dict_STOP_distance_number)
