"""
Basic biology related functions
"""
import string
import array
import os
from pyfasta import Fasta, NpyFastaRecord
from fileOps import opengz


class UpperNpyFastaRecord(NpyFastaRecord):
    """
    NpyFastaRecord variant whose item access always yields upper-cased bases.

    With as_string set, a slice comes back as an upper-cased string. Otherwise
    each element of the fetched data is upper-cased individually and the
    result is built element by element rather than returned as the raw data
    object from getdata().
    """
    def __getitem__(self, islice):
        fetched = self.getdata(islice)
        if self.as_string:
            # Raw buffer -> text -> upper case.
            return fetched.tostring().decode().upper()
        return map(string.upper, fetched)

def read_fasta(path_or_handle, validate='DNA'):
    """Iterate over a FASTA file, yielding a (name, sequence) tuple per record.

    A record starts at every line beginning with '>'; the name is the rest of
    that line without its line terminator. The following lines, up to the next
    '>' line or end of file, are concatenated into the sequence after removing
    newlines, carriage returns and spaces. Lines that are empty or start with
    '#' after that clean-up are skipped. Anything before the first '>' line is
    ignored.

    :param path_or_handle: a str path (opened through opengz and closed by
        this function) or an already open handle providing readline(), which
        is left open.
    :param validate: 'DNA', 'protein' or None. When not None, every sequence
        is checked against the matching alphabet before it is yielded.
    :raises AssertionError: if validate is not one of the accepted values.
    :raises RuntimeError: if a sequence contains characters outside the
        selected alphabet.
    """
    assert validate in ['DNA', 'protein', None], "Valid options for validate are DNA, protein or None"
    # Compare by value: identity ('is') on strings only works by accident of interning.
    if validate == 'DNA':
        valid_chars = set('ACGTUYSWKMBDHVNacgtuyswkmbdhvn.-*')
    elif validate == 'protein':
        valid_chars = set('ABCDEFGHIKLMPQSRTVWXYZUabcdefghiklmpqsrtvwxyzuNn.-*')
    else:
        valid_chars = set()
    chars_to_remove = "\r\n "
    opened_here = isinstance(path_or_handle, str)
    fh = opengz(path_or_handle) if opened_here else path_or_handle
    try:
        line = fh.readline()
        while line != '':
            if line[0] != '>':
                line = fh.readline()
                continue
            # Strip only the line terminator so a header on the very last
            # line (no trailing newline) keeps its final character.
            name = line[1:].rstrip('\r\n')
            chunks = []
            line = fh.readline()
            while line != '' and line[0] != '>':
                for char in chars_to_remove:
                    line = line.replace(char, '')
                if line and line[0] != '#':
                    chunks.append(line)
                line = fh.readline()
            seq = ''.join(chunks)
            if validate is not None:
                # Explicit check rather than assert so it survives python -O.
                bad_chars = set(seq) - valid_chars
                if bad_chars:
                    raise RuntimeError("Invalid FASTA character(s) seen in fasta sequence: {}".format(bad_chars))
            yield name, seq
    finally:
        # Runs on exhaustion, on error and when the generator is closed early.
        if opened_here:
            fh.close()


def write_fasta(path_or_handle, name, seq, chunk_size=100, validate=None):
    """Write one FASTA record: a '>name' header followed by the wrapped sequence.

    All argument checks happen before the destination is opened, so invalid
    input never creates or truncates the output file.

    :param path_or_handle: a str path (opened for writing through opengz and
        closed by this function; the original notes a path ending in gz is
        gzipped) or an open handle providing write(), which is left open.
    :param name: record name written after '>'.
    :param seq: sequence as a byte or unicode string.
    :param chunk_size: number of sequence characters per output line; must be
        at least 1.
    :param validate: 'DNA', 'protein' or None. When not None, seq is checked
        against the matching alphabet. Any other value selects an empty
        alphabet, so every non-empty sequence is rejected.
    :raises RuntimeError: if seq is not a string, or contains characters
        outside the selected alphabet.
    :raises ValueError: if chunk_size is smaller than 1.
    """
    # type(u'') is unicode on Python 2 and str on Python 3.
    if not isinstance(seq, (str, type(u''))):
        raise RuntimeError("Sequence is not unicode or string")
    if chunk_size < 1:
        # A zero step used to raise after the header was written and a
        # negative one silently produced a record without sequence.
        raise ValueError("chunk_size must be a positive integer")
    # Compare by value: identity ('is') on strings only works by accident of interning.
    if validate == 'DNA':
        valid_chars = set('ACGTUYSWKMBDHVNacgtuyswkmbdhvn.-*')
    elif validate == 'protein':
        valid_chars = set('ABCDEFGHIKLMPQSRTVWXYZUabcdefghiklmpqsrtvwxyzuNn.-*')
    else:
        valid_chars = set()
    if validate is not None:
        # Explicit check rather than assert so it survives python -O.
        bad_chars = set(seq) - valid_chars
        if bad_chars:
            raise RuntimeError("Invalid FASTA character(s) seen in fasta sequence: {}".format(bad_chars))
    opened_here = isinstance(path_or_handle, str)
    fh = opengz(path_or_handle, 'w') if opened_here else path_or_handle
    try:
        fh.write(">%s\n" % name)
        start = 0
        seq_len = len(seq)
        while start < seq_len:
            fh.write("%s\n" % seq[start:start + chunk_size])
            start += chunk_size
    finally:
        # Close what we opened even if a write fails.
        if opened_here:
            fh.close()


# Translation-table builder that exists on the running interpreter:
# str.maketrans on Python 3, string.maketrans on Python 2. The `or`
# short-circuits, so string.maketrans is only looked up where it exists.
_maketrans = getattr(str, 'maketrans', None) or string.maketrans


def complement(seq, comp=_maketrans("ATGCatgc", "TACGtacg")):
    """
    Given a sequence, return the complement.

    Only A, T, G and C (either case) are swapped; every other character is
    passed through unchanged. The sequence is converted with str() first.

    :param seq: sequence to complement.
    :param comp: translation table handed to str.translate; the default is
        built once, when the function is defined.
    :return: the complemented sequence as a str, in the original order.
    """
    return str(seq).translate(comp)


def reverse_complement(seq):
    """
    Return the reverse complement of seq: complement every base, then reverse
    the order of the result.
    """
    complemented = complement(seq)
    return complemented[::-1]


# Codon -> single-letter amino acid lookup ('*' marks a stop codon).
# Besides plain A/C/G/T codons it holds some IUPAC-degenerate codons
# (for example 'GCN' or 'TAR'). Entries are grouped by amino acid, one
# whitespace-separated codon list per group.
_codon_table = {
    codon: amino_acid
    for amino_acid, codons in (
        ('M', 'ATG'),
        ('*', 'TAA TAG TGA TAR TRA'),
        ('A', 'GCT GCC GCA GCG GCN'),
        ('R', 'CGT CGC CGA CGG AGA AGG CGN MGR'),
        ('N', 'AAT AAC AAY'),
        ('D', 'GAT GAC GAY'),
        ('C', 'TGT TGC TGY'),
        ('Q', 'CAA CAG CAR'),
        ('E', 'GAA GAG GAR'),
        ('G', 'GGT GGC GGA GGG GGN'),
        ('H', 'CAT CAC CAY'),
        ('I', 'ATT ATC ATA ATH'),
        ('L', 'TTA TTG CTT CTC CTA CTG YTR CTN'),
        ('K', 'AAA AAG AAR'),
        ('F', 'TTT TTC TTY'),
        ('P', 'CCT CCC CCA CCG CCN'),
        ('S', 'TCT TCC TCA TCG AGT AGC TCN AGY'),
        ('T', 'ACT ACC ACA ACG ACN'),
        ('W', 'TGG'),
        ('Y', 'TAT TAC TAY'),
        ('V', 'GTT GTC GTA GTG GTN'),
    )
    for codon in codons.split()
}
# The empty codon maps to an empty translation.
_codon_table[''] = ''


def codon_to_amino_acid(c):
    """
    Given a codon c, return its single-letter amino acid, or '?' if the codon
    is not in the codon table. Codons could be unrecognized due to ambiguity
    in IUPAC characters.

    :param c: a three-character codon, or None.
    :return: None when c is None, otherwise the table entry or '?'.
    :raises AssertionError: if c is not None and its length is not 3.
    """
    # The None check has to come first: len(None) raises TypeError, which
    # previously made this branch unreachable.
    if c is None:
        return None
    assert len(c) == 3, c
    return _codon_table.get(c, '?')


def translate_sequence(sequence):
    """
    Translates a given DNA sequence to single-letter amino acid space.

    The sequence is upper-cased and read codon by codon from position 0.
    If two bases are left over, they are padded with 'N' and translated as
    well; the result is kept only when that degenerate codon is recognized.
    Any other leftover is dropped silently.

    :param sequence: DNA sequence string.
    :return: string of single-letter amino acids ('?' for unrecognized
        complete codons).
    """
    sequence = sequence.upper()
    remainder = len(sequence) % 3
    # Index where the leftover bases begin. Computed directly rather than
    # from the loop variable, which stays at 0 when the sequence is shorter
    # than one codon and made two-base inputs fail.
    tail_start = len(sequence) - remainder
    result = [codon_to_amino_acid(sequence[i:i + 3]) for i in range(0, tail_start, 3)]
    if remainder == 2:
        c = codon_to_amino_acid(sequence[tail_start:] + 'N')
        if c != '?':
            result.append(c)
    return ''.join(result)


def read_codons(seq, offset=0, skip_last=True):
    """
    Provides an iterator that reads through a sequence one codon at a time.

    :param seq: sequence to read.
    :param offset: index of the first codon start; later codons start every
        three positions after it.
    :param skip_last: when true, the readable length is shortened by three
        before it is rounded down to a multiple of three, which for an
        offset of 0 leaves out the final complete codon.
    :return: generator of three-character slices of seq. A partial codon at
        the end of the sequence is never yielded.
    """
    seq_len = len(seq)
    usable = seq_len - 3 if skip_last else seq_len
    # The second bound keeps every slice complete: with an offset that is not
    # a multiple of three the rounded length alone could still let the last
    # slice run past the end of the sequence.
    stop = min(usable - usable % 3, seq_len - 2)
    i = offset
    while i < stop:
        yield seq[i:i + 3]
        i += 3


def read_codons_with_position(seq, offset=0, skip_last=True):
    """
    Provides an iterator that reads through a sequence one codon at a time,
    returning both the codon and the start position in the sequence.

    :param seq: sequence to read.
    :param offset: index of the first codon start; later codons start every
        three positions after it.
    :param skip_last: when true, the readable length is shortened by three
        before it is rounded down to a multiple of three, which for an
        offset of 0 leaves out the final complete codon.
    :return: generator of (start index, three-character slice) tuples. A
        partial codon at the end of the sequence is never yielded.
    """
    seq_len = len(seq)
    usable = seq_len - 3 if skip_last else seq_len
    # The second bound keeps every slice complete: with an offset that is not
    # a multiple of three the rounded length alone could still let the last
    # slice run past the end of the sequence.
    stop = min(usable - usable % 3, seq_len - 2)
    i = offset
    while i < stop:
        yield i, seq[i:i + 3]
        i += 3


def get_sequence_dict(file_path, upper=True):
    """
    Open a FASTA file as a Fasta object of records.

    The FASTA file and its '.gdx' and '.flat' companion files must already
    exist; each is checked in that order and a missing one fails an
    assertion. When upper is exactly True the records are served through
    UpperNpyFastaRecord so that all bases come back upper-cased; for any
    other value the default record class is used.
    """
    assert os.path.exists(file_path), ('Error: FASTA file {} does not exist'.format(file_path))
    missing_companion = ("Error: {} does not exist for this fasta. We need the fasta files to be "
                         "flattened in place prior to running the pipeline because of concurrency issues.")
    # Companion files are expected next to the FASTA, same name plus a suffix.
    for suffix, label in ((".gdx", "gdx"), ('.flat', "flat file")):
        assert os.path.exists(file_path + suffix), missing_companion.format(label)
    if upper is not True:
        return Fasta(file_path)
    return Fasta(file_path, record_class=UpperNpyFastaRecord)
