import sys
import time

def ProgressBar(COUNT, LinesNumber):
    """Redraw an in-place "COUNT/LinesNumber" counter on stdout.

    The leading carriage return moves the cursor back to the start of the
    current terminal line, so repeated calls overwrite the previous counter.
    The stream is flushed right away so the update is visible immediately.
    """
    stream = sys.stdout
    stream.write('\r%s/%s' % (str(COUNT), str(LinesNumber)))
    stream.flush()
    
# Optional 4th command-line argument: a file of wanted gene IDs, one per
# line.  Lines starting with "#" and blank lines are ignored.  Only the text
# after the last ":" is kept (e.g. "gene:AT1G01010" -> "AT1G01010").
# Without that argument every gene is used (Without_repre_list = True).
Without_repre_list = False
list_repre_protein_genes = []
if len(sys.argv) > 4:
    with open(sys.argv[4], 'r') as inp_repre:
        # Single-argument print gives the same output on Python 2 and 3.
        print(time.strftime("%H:%M:%S") + " Obtaining wanted gene list...")
        for line in inp_repre:
            entry = line.strip()
            # Blank lines used to add "" as a wanted ID; skip them.
            if entry and not entry.startswith("#"):
                list_repre_protein_genes.append(entry.split(":")[-1])
else:
    Without_repre_list = True
    print("Without gene list given")

def _first_id(value):
    """Reduce the text following an attribute tag to a single bare ID.

    Cuts at the next ';' (the GFF3 attribute separator), keeps only the text
    after the last ':' (e.g. 'transcript:AT1G01010.1' -> 'AT1G01010.1') and
    keeps only the first entry of a comma-separated list.
    """
    return value.split(';')[0].split(':')[-1].split(',')[0]


def ReadGff(PATH):
    """Collect exon and coding-span coordinates per ID from a GFF3 file.

    Only ``exon``, ``protein`` and ``CDS`` features are used.  The ID comes
    from the ``Parent`` attribute (exon, CDS) or ``Derives_from`` (protein).
    Comment lines and lines with fewer than nine tab-separated columns are
    skipped.  A feature is kept when the module-level ``Without_repre_list``
    flag is true or its ID is in ``list_repre_protein_genes``.

    Parameters
    ----------
    PATH : str
        Path to the GFF3 file.

    Returns
    -------
    dict
        ``{id: [seqid, strand, exons, cds]}``.  ``seqid`` and ``strand`` come
        from the first kept line for that ID.  ``exons`` is a list of
        ``[start, end]`` pairs in file order.  ``cds`` is ``[start, end]``,
        or ``[]`` if no protein/CDS line was seen.  A ``protein`` line
        overwrites ``cds``; ``CDS`` lines set it when empty and otherwise
        only widen it.
    """
    # Build the lookup set once; a list scan per GFF line is O(n*m).
    wanted = set(list_repre_protein_genes)
    records = {}
    with open(PATH, "r") as inp_gff:
        for line in inp_gff:
            if line.startswith('#'):
                continue
            fields = line.rstrip("\r\n").split("\t")
            if len(fields) < 9:
                # Blank or truncated line: previously an IndexError on i[2].
                continue
            feature = fields[2]
            if feature == "exon":
                gene_id = _first_id(fields[8].split("Parent=")[1])
                if Without_repre_list or gene_id in wanted:
                    span = [int(fields[3]), int(fields[4])]
                    if gene_id in records:
                        records[gene_id][2].append(span)
                    else:
                        records[gene_id] = [fields[0], fields[6], [span], []]
            elif feature == "protein":
                gene_id = _first_id(fields[8].split("Derives_from=")[-1])
                if Without_repre_list or gene_id in wanted:
                    span = [int(fields[3]), int(fields[4])]
                    if gene_id in records:
                        records[gene_id][3] = span
                    else:
                        records[gene_id] = [fields[0], fields[6], [], span]
            elif feature == "CDS":
                gene_id = _first_id(fields[8].split("Parent=")[-1])
                if Without_repre_list or gene_id in wanted:
                    start, end = int(fields[3]), int(fields[4])
                    if gene_id not in records:
                        records[gene_id] = [fields[0], fields[6], [], [start, end]]
                    elif not records[gene_id][3]:
                        records[gene_id][3] = [start, end]
                    else:
                        # Grow the span to cover every CDS segment seen.
                        cds = records[gene_id][3]
                        cds[0] = min(cds[0], start)
                        cds[1] = max(cds[1], end)
    return records

dict_gene_exons = ReadGff(sys.argv[1])
# Sort each entry's exons by coordinate, then drop entries that never got a
# protein/CDS span (counted in NUMnoCDS).  Iterate over a snapshot of the
# keys: deleting while iterating the dict itself raises RuntimeError on
# Python 3 (it only worked on Python 2 because keys() returned a list).
NUMnoCDS = 0
for gene in list(dict_gene_exons):
    record = dict_gene_exons[gene]
    record[2] = sorted(record[2])
    if not record[3]:
        NUMnoCDS += 1
        del dict_gene_exons[gene]


# For the TAIR10 annotation, rename seqids "ChrC" -> "Pt", "ChrM" -> "Mt" and
# "ChrN" -> "N".  This presumably makes them match the chromosome names used
# in the bedGraph inputs (TODO confirm).  IDs of genes whose seqid fits none
# of these patterns are printed.
if 'TAIR10_GFF3_genes.gff' in sys.argv[1]:
    for gene in dict_gene_exons:
        seqid = dict_gene_exons[gene][0]
        # Slice rather than index so seqids shorter than 4 characters fall
        # through to the report instead of raising IndexError.
        tag = seqid[3:4]
        if tag == 'C':
            dict_gene_exons[gene][0] = 'Pt'
        elif tag == 'M':
            dict_gene_exons[gene][0] = 'Mt'
        elif 'Chr' in seqid:
            # Keep everything after "Chr"; the old [3][-1] kept only a
            # single character, which truncates multi-digit names.
            dict_gene_exons[gene][0] = seqid[3:]
        else:
            print(gene)
            

def ReadBGintoDict2(PATH):
    """Load a bedGraph file into per-chromosome, per-base coverage lists.

    A list is created for every chromosome that carries a gene in the
    module-level ``dict_gene_exons``.  Each list is sized to the largest exon
    end among that chromosome's genes, so index ``p`` holds the value for
    0-based position ``p`` (1-based position ``p + 1``).

    Parameters
    ----------
    PATH : str
        Tab-separated bedGraph file (chrom, start, end, value) with 0-based,
        half-open intervals.

    Returns
    -------
    dict
        ``{chromosome: [int coverage per base]}``.  Values are truncated with
        ``int(float(value))`` and zero values are skipped.  Intervals running
        past the end of a list are clipped.  Lines for chromosomes without
        genes, blank lines, and track/browser/comment/header lines are
        ignored.
    """
    # One pass over the genes to find how far each chromosome must extend.
    # Genes without exons contribute 0 (previously an IndexError).
    chrom_length = {}
    for record in dict_gene_exons.values():
        seqid, exons = record[0], record[2]
        furthest = max([exon[1] for exon in exons] or [0])
        chrom_length[seqid] = max(chrom_length.get(seqid, 0), furthest)
    dict_chr_coverage = dict((seqid, [0] * length)
                             for seqid, length in chrom_length.items())

    with open(PATH, "r") as inp_coverage:
        for line in inp_coverage:
            if line.startswith(("track", "browser", "#")):
                continue
            fields = line.rstrip("\r\n").split("\t")
            # Skip blank lines and column headers such as
            # "chrom\tstart\tend\tvalue".  Testing for numeric coordinates
            # (instead of a "chrom" prefix) keeps seqids like "chromosome1".
            if len(fields) < 4 or not (fields[1].isdigit() and fields[2].isdigit()):
                continue
            coverage = dict_chr_coverage.get(fields[0])
            if coverage is None:
                # No genes on this chromosome (previously a KeyError crash).
                continue
            value = int(float(fields[3]))
            if value == 0:
                continue
            start = int(fields[1])
            end = min(int(fields[2]), len(coverage))
            if start < end:
                coverage[start:end] = [value] * (end - start)
    return dict_chr_coverage
 
def function_gene_coverage_seq3(Cov_DICT, gene):
    """Return per-base coverage along the spliced transcript of ``gene``.

    Looks ``gene`` up in the module-level ``dict_gene_exons``.  For each exon
    ``[start, end]`` (treated as 1-based inclusive), it takes the matching
    slice of ``Cov_DICT[chromosome]`` and concatenates the slices in the
    order the exons are stored.  For a ``'-'`` gene the result is reversed.
    Any other strand prints an error and returns ``""``.
    """
    record = dict_gene_exons[gene]
    chrom, strand, exons = record[0], record[1], record[2]
    chrom_cov = Cov_DICT[chrom]
    spliced = [depth
               for exon in exons
               for depth in chrom_cov[exon[0] - 1:exon[1]]]
    if strand == '+':
        return spliced
    if strand == '-':
        return spliced[::-1]
    print("ERROR: direction is unavailable")
    return ""
        
def theUTRlen(cds, exons, strand='+'):
    """Return the (5' UTR, 3' UTR) lengths of a spliced transcript.

    Parameters
    ----------
    cds : sequence of two int-convertible values
        Genomic ``[start, end]`` of the coding span (1-based, inclusive).
    exons : list of ``[start, end]`` pairs
        Genomic exon coordinates, expected in ascending order.
    strand : str
        ``'+'`` or ``'-'``.

    Returns
    -------
    tuple of int, or str
        For ``'+'``: (exonic bases before the CDS start, exonic bases after
        the CDS end).  For ``'-'`` the pair is swapped so the first value is
        still the 5' UTR.  Any other strand prints an error and returns
        ``""``.  If the CDS start (or end) lies in no exon, every exon on
        that side is counted in full.
    """
    cds_start = int(cds[0])
    cds_end = int(cds[1])

    # Walk exons left to right, summing exonic bases before the CDS start.
    # A chained comparison replaces `x in range(a, b + 1)`, which builds a
    # full list for every exon on Python 2.
    left = 0
    for exon in exons:
        if exon[0] <= cds_start <= exon[1]:
            left += cds_start - exon[0]
            break
        left += exon[1] - exon[0] + 1

    # Walk exons right to left, summing exonic bases after the CDS end.
    right = 0
    for exon in reversed(exons):
        if exon[0] <= cds_end <= exon[1]:
            right += exon[1] - cds_end
            break
        right += exon[1] - exon[0] + 1

    if strand == "+":
        return (left, right)
    elif strand == '-':
        return (right, left)
    else:
        print("ERROR: direction is not given")
        return ""

Dict_plus = ReadBGintoDict2(sys.argv[2])
Dict_minus = ReadBGintoDict2(sys.argv[3])


def _write_gene_rows(output_gene_coverage, Cov_DICT, strand):
    """Write one tab-separated row per gene on ``strand`` to an open file.

    Coverage comes from ``Cov_DICT``, the bedGraph dict for that strand.
    Columns: gene id, cDNA length, strand, CDS start and CDS end (1-based
    positions within the cDNA), read density (rounded to 6 places), and
    the per-base coverage list.  Prints an error and exits with status 1
    if the coverage length disagrees with the summed exon lengths.
    """
    for gene in dict_gene_exons:
        if dict_gene_exons[gene][1] != strand:
            continue
        exons = dict_gene_exons[gene][2]
        cov = function_gene_coverage_seq3(Cov_DICT, gene)
        length_cDNA = len(cov)
        if length_cDNA != sum([e[1] - e[0] + 1 for e in exons]):
            # Previously exited with status 0 (and silently on '-'),
            # so callers could not detect the failure.
            print("ERROR, length calculation error")
            sys.exit(1)
        FiveUTRlen, ThreeUTRlen = theUTRlen(dict_gene_exons[gene][3], exons, strand)
        CDS_start = FiveUTRlen + 1
        CDS_end = length_cDNA - ThreeUTRlen
        # Density excludes the last three cDNA positions, as before (the
        # reason is not recorded here -- TODO confirm).  cDNAs of 3 nt or
        # fewer used to divide by zero or by a negative number.
        if length_cDNA > 3:
            ReadDensity = sum(cov[:-3]) / float(length_cDNA - 3)
        else:
            ReadDensity = 0.0
        output_gene_coverage.write('%s\t%s\t%s\t%s\t%s\t%s\t%s\n' % (
            gene, length_cDNA, strand, CDS_start, CDS_end,
            round(ReadDensity, 6), cov))


# A single output file holds both strands: plus-strand genes first (coverage
# from argv[2]), then minus-strand genes (coverage from argv[3]).
# NOTE(review): the "_minus_" suffix is misleading for a two-strand file;
# it is kept unchanged for compatibility.
with open(sys.argv[2] + "_minus_BedGraph.GeneCovN", "w") as output_gene_coverage:
    output_gene_coverage.write("#" + str(sys.argv) + '\n')
    output_gene_coverage.write("#Gene_id\tcDNA_length\tStrand\tCDS_start\tCDS_end\tRead_density\tCoverage\n")
    _write_gene_rows(output_gene_coverage, Dict_plus, '+')
    _write_gene_rows(output_gene_coverage, Dict_minus, '-')

