import sys, os

import re
import Queue
import pickle
import numpy as np
import multiprocessing

from random import random
from pysam import Fastafile
from time import sleep, clock
from collections import defaultdict
from itertools import product, repeat

# complement of each upper-case DNA base (looked up by reverse_complement)
COMP_BASES = {'A':'T', 'T':'A', 'C':'G', 'G':'C'}
# number of genes bundled into one work item on the genes queue
GENE_BLOCK_SIZE = 50
# number of completed genes represented by one progress dot in verbose mode
BLOCK_REPORT = 500
# when True, mers whose sequence had to be discarded are collected and
# reported on stderr
WARNINGS = True
# matches any character that is not one of the four RNA bases A, C, G, U
NON_RNA_PAT = re.compile('[^ACGU]')
# when True, sequence that is not entirely upper case is rejected by
# get_rna_seq; when False it is converted to upper case instead
FILTER_LOW_CMPLX = True

def reverse_complement(seq):
    """ Return the reverse complement of a DNA sequence.

    Stand-in for Biopython's reverse_complement. Any base without an
    entry in COMP_BASES is copied through unchanged.
    """
    # falling back to the base itself keeps non-standard bases in place
    return ''.join(COMP_BASES.get(base, base) for base in reversed(seq))

def get_rna_seq(chrm, strand, regions, fasta):
    """ Return the RNA sequence spelled by the given genomic regions.

    chrm    -- chromosome name; 'chr' is prepended when it is missing
    strand  -- '-' requests the reverse complement, anything else does not
    regions -- iterable of inclusive (start, stop) pairs, joined in order
    fasta   -- object providing fetch(chrm, start, stop)

    Raises LookupError when a fetch does not return stop - start + 1 bases,
    or when the joined sequence is not entirely upper case while
    FILTER_LOW_CMPLX is set.
    """
    contig = chrm if chrm.startswith('chr') else 'chr' + chrm

    pieces = []
    for first, last in regions:
        # the fetch is made with start - 1 and must yield the whole region
        piece = fasta.fetch(contig, first - 1, last)
        if len(piece) != last - first + 1:
            raise LookupError
        pieces.append(piece)
    genomic = ''.join(pieces)

    # lower case letters indicate low complexity
    if not genomic.isupper():
        if FILTER_LOW_CMPLX:
            raise LookupError
        genomic = genomic.upper()

    if strand == '-':
        genomic = reverse_complement(genomic)

    # transcribe: every T becomes U
    return genomic.replace('T', 'U')

def get_mer_seq(fasta, chrm, strand, mer_positions):
    """ Count the RNA sequence found at every mer position of one gene.

    Returns (counts, err_mes). counts maps mer sequence to the number of
    positions carrying it. err_mes is None unless WARNINGS is set and at
    least one mer was discarded; then it is 'chrm:min-max:strand', spanning
    every discarded mer.
    """
    mer_counts = defaultdict(int)
    rejected = []

    for regions in mer_positions:
        try:
            seq = get_rna_seq(chrm, strand, regions, fasta)
        except LookupError:
            seq = None

        # discard mers that could not be read or hold a non-RNA character
        if seq is None or NON_RNA_PAT.search(seq):
            if WARNINGS:
                rejected.append(regions)
            continue

        mer_counts[seq] += 1

    if not (WARNINGS and len(rejected) > 0):
        return mer_counts, None

    lowest = min(reg[0] for regions in rejected for reg in regions)
    highest = max(reg[1] for regions in rejected for reg in regions)
    summary = '{}:{}-{}:{}'.format(chrm, str(lowest), str(highest), strand)

    return mer_counts, summary

def process_gene(gene_exons, gene_splices, mer_len):
    """ List the genomic position of every mer of one gene.

    gene_exons   -- iterable of inclusive (start, stop) exons
    gene_splices -- iterable of splice contexts; each is a sequence of
                    inclusive (start, stop) regions in transcript order
                    whose last region lies downstream of the junction
    mer_len      -- length of the mers

    Returns a list of mer positions. A mer crossing a junction is a tuple
    of regions; a mer inside a single exon is a one-element list holding
    its (start, stop) region. Exonic mers are unique; junction mers are
    emitted once per splice context.
    NOTE(review): two splice contexts that differ only far upstream can
    emit the same junction mer twice -- confirm whether that is intended.
    """
    gene_mer_pos = []

    def add_spliced_mers(exons):
        """ Add mer positions which span one or more splice junctions
        """
        last_start, last_stop = exons[-1]
        # a junction mer takes between 1 and mer_len - 1 bases from the
        # downstream exon, but never more than the exon actually holds
        # (the exon length is stop - start + 1)
        max_downstream = min(mer_len - 1, last_stop - last_start + 1)

        # longest downstream part first: it needs the fewest upstream
        # bases, so once one mer runs out of upstream sequence every
        # remaining (shorter downstream) mer would run out as well
        for n_downstream in range(max_downstream, 0, -1):
            mer_pos = [(last_start, last_start + n_downstream - 1), ]
            current_len = n_downstream
            next_exon = -2
            while current_len < mer_len:
                # no upstream region left (short first exon)
                if -next_exon > len(exons): break

                up_start, up_stop = exons[next_exon]
                up_len = up_stop - up_start + 1
                needed = mer_len - current_len
                if up_len < needed:
                    # small region: take all of it and keep walking upstream
                    mer_pos.insert(0, exons[next_exon])
                    current_len += up_len
                    next_exon -= 1
                else:
                    # take just the bases still missing and finish the mer
                    mer_pos.insert(0, (up_stop - needed + 1, up_stop))
                    current_len = mer_len

            # not enough upstream sequence for this or any later mer
            if current_len < mer_len: break

            gene_mer_pos.append(tuple(mer_pos))

    for splice_exons in gene_splices:
        add_spliced_mers(splice_exons)

    # every window of mer_len bases inside a single exon; the set removes
    # windows shared by overlapping exons
    exonic_mers = set(
        (start + offset, start + offset + mer_len - 1)
        for start, stop in gene_exons
        for offset in range(stop - start - mer_len + 2))
    gene_mer_pos.extend([exon_mer, ] for exon_mer in exonic_mers)

    return gene_mer_pos

def get_trans_elements(regions, mer_len):
    """ Split one transcript into its exons and its splice contexts.

    regions -- position-sorted list of inclusive (start, stop) regions
               belonging to a single transcript
    mer_len -- length of the mers the splice contexts have to support

    Regions that are directly adjacent (one starts on the base after the
    previous one stops, e.g. at a CDS-exon boundary) are merged into one
    exon before anything else happens.

    Returns (trans_exons, trans_splices). trans_exons is the list of merged
    exons. trans_splices holds, for each junction with at least mer_len
    bases around it, a tuple of regions in transcript order: upstream
    sequence followed by the start of the downstream exon. An empty
    regions list yields ([], []).
    """
    def get_splice_pos(exons):
        """ Return the splice context for the junction in front of the
        last exon in exons, or None when there is no junction or fewer
        than mer_len bases are available.
        """
        if len(exons) <= 1: return None

        # initialize splice with mer_len positions into the downstream exon
        if exons[-1][1] - exons[-1][0] + 1 >= mer_len:
            splice_pos = [(exons[-1][0], exons[-1][0] + mer_len - 1), ]
        # if this exon is too short then add the whole exon
        else:
            splice_pos = [exons[-1], ]

        # add upstream exons in order to define the splice with mer_len pos's
        current_len = 0
        curr_exon = -2
        while True:
            # if there are no more exons to add then stop
            if -1 * curr_exon > len(exons): break

            # if the next exon has enough length then add the part of it
            # that completes mer_len - 1 upstream positions and stop
            if (current_len + exons[curr_exon][1] -
                exons[curr_exon][0]) >= mer_len:
                splice_pos.insert(
                    0, (exons[curr_exon][1] - mer_len + current_len + 2,
                        exons[curr_exon][1]))
                break

            # otherwise add the whole exon and continue further upstream
            splice_pos.insert(0, exons[curr_exon])
            current_len += exons[curr_exon][1] - exons[curr_exon][0] + 1
            curr_exon -= 1

        # if there is not at least mer_len to return then return None
        if sum(reg[1] - reg[0] + 1 for reg in splice_pos) < mer_len:
            return None
        return tuple(splice_pos)

    if not regions:
        return [], []

    trans_exons = []
    trans_splices = []

    def close_exon(exons):
        # the last entry of exons is complete: record it together with
        # the junction leading into it (if that junction is usable)
        trans_exons.append(exons[-1])
        splice_pos = get_splice_pos(exons)
        if splice_pos is not None:
            trans_splices.append(splice_pos)

    current_exons = [regions[0], ]
    for exon in regions[1:]:
        # adjacent regions (previous stop + 1 == this start) are two pieces
        # of the same exon, e.g. a CDS-exon boundary: collapse them
        if current_exons[-1][1] + 1 == exon[0]:
            current_exons[-1] = (current_exons[-1][0], exon[1])
            continue

        close_exon(current_exons)

        # keep only as many upstream exons as are needed to supply
        # mer_len - 1 bases in front of the new exon
        new_current_exons = [current_exons[-1], exon]
        while (sum(reg[1] - reg[0] + 1
                   for reg in new_current_exons[:-1]) < mer_len - 1 and
               len(current_exons) >= len(new_current_exons)):
            new_current_exons.insert(0, current_exons[
                    len(current_exons) - len(new_current_exons)])
        current_exons = new_current_exons

    # the final exon has no successor to trigger its recording
    close_exon(current_exons)

    return trans_exons, trans_splices

def process_gene_worker(genes_queue, mers_queue, fasta_fn, mer_len):
    """ Worker loop: turn blocks of genes into blocks of mer counts.

    genes_queue -- queue of lists of (gene_id, gene_data) pairs, where
                   gene_data is a list of
                   (chrm, strand, trans_id, start, stop) tuples
    mers_queue  -- output queue; receives one list of
                   (gene_id, mer_counts) pairs per input block and, in
                   between, one string per gene that had discarded mers
    fasta_fn    -- path of the genome fasta file
    mer_len     -- length of the mers

    Returns when genes_queue is empty.
    """
    # the fasta handle is opened here so that each worker owns its own
    fasta = Fastafile(fasta_fn)

    try:
        while True:
            try:
                genes_block = genes_queue.get(block=False)
            except Queue.Empty:
                break

            mers_block = []
            for gene_id, gene_data in genes_block:
                # skip trans-genes: every region has to share the
                # chromosome and the strand of the first region
                chrm, strand = gene_data[0][0], gene_data[0][1]
                if any(row[0] != chrm or row[1] != strand
                       for row in gene_data):
                    sys.stderr.write(
                        'Not processing gene ' + gene_id + ' as it ' +
                        'spans multiple chromosomes or strands.\n')
                    continue

                # group regions by transcript
                locus_trans = defaultdict(list)
                for _, _, trans_id, start, stop in gene_data:
                    locus_trans[trans_id].append((start, stop))

                # sets merge exons and splices shared between transcripts
                gene_exons = set()
                gene_splices = set()
                for regions in locus_trans.values():
                    trans_exons, trans_splices = get_trans_elements(
                        sorted(regions), mer_len)
                    gene_exons.update(trans_exons)
                    gene_splices.update(trans_splices)

                gene_mer_pos = process_gene(gene_exons, gene_splices, mer_len)
                gene_mer_dist, err_mes = get_mer_seq(
                    fasta, chrm, strand, gene_mer_pos)
                # error strings travel on the same queue as result blocks;
                # the consumer tells them apart by type
                if err_mes is not None:
                    mers_queue.put(err_mes)
                mers_block.append((gene_id, gene_mer_dist))

            mers_queue.put(mers_block)
    finally:
        fasta.close()

    return

def process_all_genes(raw_trans, num_threads, fasta_fn, mer_len, save_fn):
    """ Count mers for every gene in parallel and pickle the result.

    raw_trans   -- dict mapping gene id to its list of
                   (chrm, strand, trans_id, start, stop) tuples
    num_threads -- number of worker processes to start
    fasta_fn    -- path of the genome fasta file
    mer_len     -- length of the mers
    save_fn     -- path the pickled result is written to

    Returns (all_mer_dist, gene_indicies, mer_indicies): an int16 matrix
    with one row per gene and one column per possible mer, a dict mapping
    gene id to row, and a dict mapping mer sequence to column. The same
    three objects are pickled, as a list, into save_fn.
    """
    manager = multiprocessing.Manager()
    genes_queue = manager.Queue()
    mers_queue = manager.Queue()

    # fill genes_queue with blocks of at most GENE_BLOCK_SIZE genes
    genes_block = []
    for gene_dat in raw_trans.items():
        if len(genes_block) >= GENE_BLOCK_SIZE:
            genes_queue.put(genes_block)
            genes_block = []
        genes_block.append(gene_dat)
    if len(genes_block) > 0:
        genes_queue.put(genes_block)

    # start the worker processes that turn gene blocks into mer blocks
    args = (genes_queue, mers_queue, fasta_fn, mer_len)
    processes = []
    for _ in range(num_threads):
        p = multiprocessing.Process(target=process_gene_worker, args=args)
        p.start()
        processes.append(p)

    if VERBOSE:
        sys.stderr.write('Each dot represents ' + str(BLOCK_REPORT) +
                         ' genes completed: ')
    gene_count = 0

    all_mer_seqs = []
    err_mess = []
    while True:
        # sample the worker state before polling: if every worker had
        # already exited and the queue is still empty, no further block is
        # expected, so nothing a worker wrote just before exiting is lost
        workers_alive = any(p.is_alive() for p in processes)
        try:
            mers_block = mers_queue.get(block=False)
        except Queue.Empty:
            if not workers_alive:
                break
            sleep(1)
            continue

        # lists are result blocks, anything else is an error message
        if not isinstance(mers_block, list):
            err_mess.append(mers_block)
            continue

        all_mer_seqs.extend(mers_block)
        if VERBOSE:
            gene_count += len(mers_block)
            while gene_count >= BLOCK_REPORT:
                sys.stderr.write('.')
                sys.stderr.flush()
                gene_count -= BLOCK_REPORT

    for p in processes:
        p.join()

    if VERBOSE: sys.stderr.write('\n')

    if WARNINGS and len(err_mess) > 0:
        sys.stderr.write('DNA mers contained invalid genomic sequence:\n')
        sys.stderr.write('\t' + ','.join(err_mess) + '\n')

    if VERBOSE: sys.stderr.write('Saving motif distribution object for later runs...\n')
    # row of every gene and column of every possible mer sequence
    gene_indicies = {gene_id: row
                     for row, (gene_id, _) in enumerate(all_mer_seqs)}
    mer_indicies = {''.join(bases): col
                    for col, bases in enumerate(
                        product(('A', 'C', 'G', 'U'), repeat=mer_len))}

    # NOTE(review): int16 cannot hold counts above 32767 -- confirm that no
    # gene can reach that many copies of one mer (short mer_len is the
    # risky case)
    all_mer_dist = np.zeros((len(all_mer_seqs), 4**mer_len), dtype='int16')
    for gene_id, gene_mer_dist in all_mer_seqs:
        if not gene_mer_dist:
            continue
        # only touch the columns of mers that occur in this gene
        mer_seqs, mer_counts = zip(*gene_mer_dist.items())
        row = gene_indicies[gene_id]
        cols = [mer_indicies[mer_seq] for mer_seq in mer_seqs]
        all_mer_dist[row, cols] = all_mer_dist[row, cols] + list(mer_counts)

    # pickle data is binary, so the file is opened in binary mode
    with open(save_fn, 'wb') as fp:
        pickle.dump([all_mer_dist, gene_indicies, mer_indicies], fp)

    return all_mer_dist, gene_indicies, mer_indicies

def parse_gtf(gtf_fn):
    """ Read a GTF file into a dict of regions per gene.

    gtf_fn -- path of the GTF file. Each record is split on whitespace;
              fields 10 and 12 are expected to be the quoted gene id and
              transcript id, each followed by ';' (e.g. "ABC";).

    When the module-level TRACK flag is set the first line is skipped.
    Blank lines and lines starting with '#' are ignored, and fields after
    the transcript id are not looked at.

    Returns a defaultdict(list) mapping gene id to a list of
    (chrm, strand, trans_id, start, stop) tuples with integer coordinates.
    Raises ValueError for a record with fewer than 12 fields or with
    non-integer coordinates.
    """
    raw_trans = defaultdict(list)

    with open(gtf_fn) as gtf_fp:
        if TRACK:
            # discard the track line
            gtf_fp.readline()

        for line in gtf_fp:
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                continue
            if len(fields) < 12:
                raise ValueError(
                    'Malformed GTF line (expected at least 12 ' +
                    'whitespace-separated fields): ' + line.rstrip('\n'))

            chrm, start, stop, strand = (
                fields[0], fields[3], fields[4], fields[6])
            # strip the leading quote and the trailing quote + semicolon
            gene_id = fields[9][1:-2]
            trans_id = fields[11][1:-2]
            raw_trans[gene_id].append(
                (chrm, strand, trans_id, int(start), int(stop)))

    if VERBOSE:
        sys.stderr.write('Num Loci: ' + str(len(raw_trans)) + '\n')

    return raw_trans

def parse_arguments():
    """ Parse the command line and set the VERBOSE and TRACK globals.

    Returns (gtf, fasta, mer_length, processes). Prints a message and the
    help text and exits with status 1 when --gtf or --fasta is missing.
    """
    global VERBOSE
    global TRACK

    import argparse
    parser = argparse.ArgumentParser(
        description='Run .' )
    parser.add_argument('--gtf',
                        help='GTF annotation file.')
    parser.add_argument('--fasta',
                        help='Fasta file of genome sequence.')

    parser.add_argument('--mer-length', '-l', default=7, type=int,
                        help='Length of mers to search for enrichment. ' +
                        'Default: %(default)i')

    parser.add_argument('--verbose', '-v', default=False, action='store_true',
                        help='Whether or not to print status information.')
    parser.add_argument('--track', '-t', default=False, action='store_true',
                        help='Whether or not gtf contains a track line.')
    parser.add_argument('--processes', '-p', default=1, type=int,
                        help='Number of processes for getting unique mers. ' +
                        'Default: %(default)i')
    args = parser.parse_args()

    # both inputs are required; there is no saved-mers option to fall back
    # on, so only these two are checked
    if args.gtf is None or args.fasta is None:
        sys.stderr.write('Must provide a gtf and fasta file.\n\n')
        parser.print_help()
        sys.exit(1)

    VERBOSE = args.verbose
    TRACK = args.track

    return args.gtf, args.fasta, args.mer_length, args.processes

def main():
    """ Parse the command line, read the GTF and compute the mer counts.

    The counts are pickled next to the GTF file, under a name built from
    the GTF name, the fasta name and the mer length.
    """
    gtf_fn, fasta_fn, mer_len, num_threads = parse_arguments()

    if VERBOSE:
        sys.stderr.write('Parsing GTF... ')
    raw_trans = parse_gtf(gtf_fn)

    save_fn = '{}.{}.mer_len{}.saved_mers.pickle'.format(
        gtf_fn, fasta_fn, str(mer_len))

    if VERBOSE:
        sys.stderr.write('Getting mer sequences...\n')
    # the returned matrix and index dicts are not needed here; the call is
    # made for the pickle it writes
    process_all_genes(raw_trans, num_threads, fasta_fn, mer_len, save_fn)

# run the pipeline only when this file is executed as a script
if __name__ == '__main__':
    main()
