# cython: profile=True

"""Compare the transcripts of a GTF file and annotate each transcript
with the number of times it was found in the GTF."""

import re
import sys

# Attribute extractors for column 9 of a GTF line (non-greedy up to the
# closing quote). Raw strings keep the regex escapes out of string-literal
# escape processing.
GENE_ID_REG = re.compile(r'gene_id "(.*?)"')
TRANSCRIPT_ID_REG = re.compile(r'transcript_id "(.*?)"')
EXON_NUMBER_REG = re.compile(r'exon_number "(.*?)"')
# Sample name = transcript_id with its two trailing ".<digits>" groups removed,
# e.g. "sampleA.12.3" -> "sampleA".
SAMPLE_REG = re.compile(r'(.*)\.\d+\.\d+')

# Chromosome name <-> compact integer code used by Exon/Transcript storage.
# chr1..chr23 map to 1..23, then chrX=24, chrY=25, chrM=26.
_CHR2NUM = {'chr%d' % i: i for i in range(1, 24)}
_CHR2NUM.update({'chrX': 24, 'chrY': 25, 'chrM': 26})

_NUM2CHR = {v: k for k, v in _CHR2NUM.items()}

# Strand symbol <-> integer code ('.' = unstranded).
_STANRD2NUM = {'+': 1, '-': 0, '.': 3}
_NUM2STRAND = {1: '+', 0: '-', 3: '.'}

cdef class Exon:
    """A single exon parsed from one tab-separated GTF line.

    Chromosome and strand are stored as integer codes (``_CHR2NUM`` /
    ``_STANRD2NUM``) and decoded back when printing.
    """
    cdef unsigned long start
    cdef unsigned long end
    cdef int chrom                  # chromosome code, see _CHR2NUM
    cdef object gene_id
    cdef object transcript_id
    cdef unsigned int strand        # strand code, see _STANRD2NUM
    cdef unsigned int exon_number

    def __cinit__(Exon self, gtf_line):
        """Initialise an Exon from a GTF line.

        Reads column 1 (chromosome), columns 4-5 (start/end), column 7
        (strand) and the ``gene_id``, ``transcript_id`` and ``exon_number``
        attributes of column 9.

        Raises:
            KeyError: if the chromosome or strand symbol is not a known code.
            AttributeError: if one of the three attributes is missing (the
                regex search returns None).
        """
        sp = gtf_line.split('\t')
        attr = sp[8]
        self.gene_id = GENE_ID_REG.search(attr).group(1)
        self.transcript_id = TRANSCRIPT_ID_REG.search(attr).group(1)
        self.exon_number = int(EXON_NUMBER_REG.search(attr).group(1))
        self.chrom = _CHR2NUM[sp[0]]
        self.strand = _STANRD2NUM[sp[6]]
        self.start = long(sp[3])
        self.end = long(sp[4])

    def __str__(Exon self):
        """Return ``chrom:start-end(strand)\\tgene_id\\ttranscript_id\\texon_number``."""
        p = "{c}:{s}-{e}({strand})\t{gid}\t{tid}\t{exon_number}".format(
            c=_NUM2CHR[self.chrom],
            s=self.start,
            e=self.end,
            strand=_NUM2STRAND[self.strand],
            gid=self.gene_id,
            tid=self.transcript_id,
            exon_number=self.exon_number)
        return p

    def __repr__(Exon self):
        """Same as ``str()`` with tabs replaced by spaces."""
        return self.__str__().replace('\t', ' ')

    def __richcmp__(Exon self, other, int op):
        """Fuzzy equality between exons.

        Two exons are equal when they are on the same chromosome and both
        their starts and their ends are fewer than 10 bp apart. Strand, ids
        and exon number are not compared. The relation is not transitive.

        Only ``==`` (op 2) and ``!=`` (op 3) are supported; ordering
        comparisons print a warning and return None. Comparing with a
        non-Exon returns NotImplemented.
        """
        cdef Exon o
        if op != 2 and op != 3:
            print("WARNING: ONLY USE == WHAT YOU DO IS WRONG")
            return None
        if not isinstance(other, Exon):
            # Also guards against None, whose cdef attributes must not be read.
            return NotImplemented
        o = <Exon>other
        # Signed casts avoid unsigned wrap-around when subtracting coordinates.
        same = (self.chrom == o.chrom and
                abs(<long long>self.start - <long long>o.start) < 10 and  # start within 10 bp
                abs(<long long>self.end - <long long>o.end) < 10)  # end within 10 bp
        if op == 2:
            return same
        return not same

    def get_start(Exon self):
        """Return the start coordinate as read from the GTF."""
        return self.start

    def get_end(Exon self):
        """Return the end coordinate as read from the GTF."""
        return self.end

    def to_bed(Exon self):
        """Return a tab-separated BED6-like line.

        Columns: chrom, start, end, transcript_id, exon_number (in the score
        column), strand.
        """
        # NOTE(review): start is written exactly as in the GTF (1-based);
        # BED starts are 0-based — confirm whether start - 1 is intended.
        return "\t".join([_NUM2CHR[self.chrom], str(self.start), str(self.end),
                          self.transcript_id, str(self.exon_number),
                          _NUM2STRAND[self.strand]])


cdef class Transcript:
    """A transcript: its exons in insertion order plus match bookkeeping.

    Chromosome, strand and ids are copied from the first exon added; the
    start/end span is widened to cover every exon added afterwards.
    ``found``, ``similar_ls`` and ``sample_ls`` are filled in by callers via
    set_found(), add_similar() and add_sample().
    """

    cdef unsigned int found         # match count, set via set_found()
    cdef unsigned int n             # number of exons, kept equal to len(exons)
    cdef object exons               # list of Exon objects

    cdef unsigned long start        # smallest exon start added
    cdef unsigned long end          # largest exon end added
    cdef unsigned int chrom         # chromosome code, see _CHR2NUM
    cdef object gene_id
    cdef object transcript_id
    cdef object similar_ls          # ids recorded via add_similar() (duplicates kept)
    cdef object sample_ls           # distinct names recorded via add_sample()
    cdef unsigned int strand        # strand code, see _STANRD2NUM
    cdef unsigned int exon_number   # NOTE(review): never read or written in this class

    def __cinit__(self):
        """Create an empty transcript with no exons and an empty id."""
        self.exons = []
        self.transcript_id = ''
        self.found = 0
        self.similar_ls = []
        self.sample_ls = []
        self.n = len(self.exons)

    def set_found(Transcript self, unsigned int found):
        """Store the number of matching transcripts."""
        self.found = found

    def get_found(Transcript self):
        """Return the stored number of matching transcripts."""
        return self.found

    def add_similar(Transcript self, str name):
        """Record the id of a matching transcript (duplicates are kept)."""
        self.similar_ls.append(name)

    def add_sample(Transcript self, str name):
        """Record a sample name once; repeated names are ignored."""
        if name not in self.sample_ls:
            self.sample_ls.append(name)

    def get_similar(Transcript self):
        """Return the list of recorded matching transcript ids."""
        return self.similar_ls

    def get_tid(Transcript self):
        """Return the transcript id copied from the first exon ('' if none)."""
        return self.transcript_id

    def get_sample(Transcript self):
        """Return the transcript id without its two trailing ``.<digits>`` groups.

        Raises AttributeError if the id does not match SAMPLE_REG.
        """
        return SAMPLE_REG.match(self.transcript_id).group(1)

    def get_tss(Transcript self):
        """Return the transcription start: end on '-' strand, start on '+'.

        Raises:
            ValueError: for unstranded ('.') transcripts.
        """
        if self.strand == 0:
            return self.end
        elif self.strand == 1:
            return self.start
        else:
            raise ValueError('strand is neither 1 or 0')

    def get_end(Transcript self):
        """Return the largest exon end added so far."""
        return self.end

    def get_strand(Transcript self):
        """Return the strand symbol ('+', '-' or '.')."""
        return _NUM2STRAND[self.strand]

    def __str__(self):
        """Return a header line followed by one line per exon."""
        tr_p = "Transcript {tid} with {n} exons:\n".format(
            tid=self.transcript_id,
            n=self.n)
        exons_p = "\n".join([str(ex) for ex in self.exons])
        return tr_p + exons_p

    def __getitem__(Transcript self, key):
        """Return the exon(s) at *key* (index or slice)."""
        return self.exons[key]

    def __iadd__(Transcript self, Exon exon):
        """Append *exon* and widen the transcript span to include it."""
        if not self.exons:
            self.chrom = exon.chrom
            self.start = exon.start
            self.end = exon.end
            self.strand = exon.strand
            self.gene_id = exon.gene_id
            self.transcript_id = exon.transcript_id
        else:
            # Exons are not guaranteed to arrive in ascending coordinate
            # order, so both bounds must be tracked.
            if exon.start < self.start:
                self.start = exon.start
            if exon.end > self.end:
                self.end = exon.end
        self.exons.append(exon)
        self.n = len(self.exons)
        return self

    def __len__(Transcript self):
        """Return the number of exons."""
        return self.n

    def __repr__(Transcript self):
        """Return ``<tid> <found> <similar ids joined by '|'>``."""
        return " ".join([self.transcript_id, str(self.found),
                         "|".join(self.similar_ls)])

    def __richcmp__(Transcript self, other, int op):
        """Transcripts are equal when they have the same number of exons and
        each exon equals the exon at the same position in the other.

        Only ``==`` (op 2) and ``!=`` (op 3) are supported; ordering
        comparisons print a warning and return None. Comparing with a
        non-Transcript returns NotImplemented.
        """
        cdef Transcript o
        if op != 2 and op != 3:
            print("WARNING: ONLY USE == WHAT YOU ARE DOING IS WRONG")
            return None
        if not isinstance(other, Transcript):
            # Also guards against None, whose cdef attributes must not be read.
            return NotImplemented
        o = <Transcript>other
        same = self.n == o.n
        if same:
            for mine, theirs in zip(self.exons, o.exons):
                if not mine == theirs:
                    same = False
                    break
        if op == 2:
            return same
        return not same

    def to_bed(Transcript self):
        """Return a tab-separated BED6-like line.

        Columns: chrom, start, end, transcript_id, exon count (in the score
        column), strand.
        """
        # NOTE(review): start is the GTF (1-based) value; BED starts are
        # 0-based — confirm whether start - 1 is intended.
        return "\t".join([_NUM2CHR[self.chrom], str(self.start), str(self.end),
                          self.transcript_id, str(self.n),
                          _NUM2STRAND[self.strand]])

    def to_list(Transcript self):
        """Return [chrom, start, end, tid, n_exons, strand, found,
        similar ids joined by '|', samples joined by '|']."""
        ret = [_NUM2CHR[self.chrom], self.start, self.end,
               self.transcript_id, self.n, _NUM2STRAND[self.strand],
               self.found, "|".join(self.similar_ls), "|".join(self.sample_ls)]
        return ret


def get_transcr_from_gtf(filename):
    """From a GTF file, extract all the transcripts and exons.

    Every ``transcript`` line starts a new Transcript; every following
    ``exon`` line is added to the most recent one. Comment lines (``#``)
    and blank lines are skipped; other feature types are ignored.

    Returns:
        (transcripts, count): the list of Transcript objects in file order
        and its length.

    Raises:
        ValueError: if an exon line appears before any transcript line.
    """
    my_transcripts = []
    tr = None
    with open(filename) as gtf:
        for line in gtf:
            if line.startswith('#') or not line.strip():
                continue
            sp = line.split('\t')
            feature = sp[2]
            if feature == 'transcript':
                tr = Transcript()
                my_transcripts.append(tr)
            elif feature == 'exon':
                if tr is None:
                    raise ValueError(
                        'exon line found before any transcript line: %r' % line)
                tr += Exon(line)
    return my_transcripts, len(my_transcripts)


cdef void transcripts_counter(object transcripts_ls, unsigned long n) except *:
    """Count, for each transcript, how many of the first *n* transcripts equal it.

    Every transcripts_ls[i] is compared (with ``==``) to every
    transcripts_ls[j], j included i. For each match, the matching id is
    recorded with add_similar() and its sample with add_sample(). The total
    is then stored with set_found(). This performs n * n comparisons.

    ``except *`` lets Python exceptions raised inside (e.g. from get_sample()
    on an id that does not match SAMPLE_REG) reach the caller. A plain
    ``cdef void`` function would print and discard them in Cython 0.29.
    """
    cdef unsigned long found
    cdef unsigned long i, j
    for i in range(n):
        current = transcripts_ls[i]
        found = 0
        for j in range(n):
            candidate = transcripts_ls[j]
            if current == candidate:
                found += 1
                current.add_similar(candidate.get_tid())
                current.add_sample(candidate.get_sample())
        current.set_found(found)


def get_transc_counts_from_gtf(gtffile):
    """Tabulate how often each transcript of *gtffile* occurs.

    Parses the file, runs the pairwise matcher over all transcripts and
    returns one ``Transcript.to_list()`` row per transcript, in file order.
    """
    transcripts, n_transcripts = get_transcr_from_gtf(gtffile)
    transcripts_counter(transcripts, n_transcripts)
    rows = []
    for transcript in transcripts:
        rows.append(transcript.to_list())
    return rows


def tests(f='../tests/all_chim.gtf'):
    """Smoke test: parse *f*, count transcripts and print the first five.

    The default path is relative to the current working directory. Files
    with fewer than five transcripts are handled (only those are printed).
    """
    a, n = get_transcr_from_gtf(f)
    print(a[:5])
    transcripts_counter(a, n)
    for tr in a[:5]:
        print(tr.get_found())
        print(tr.get_similar())


if __name__ == "__main__":
    # Only runs when this module is the program entry point; a normal
    # import leaves the smoke test untouched.
    tests()
