#!/usr/bin/env python
from __future__ import print_function
import sys
import os
import re
import argparse
import pandas as pd
from collections import defaultdict, namedtuple
my_bin_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
sys.path.insert(0, os.path.join(my_bin_dir, "../tools"))
from transcripts import gene_pred_iterator
from defaultOrderedDict import DefaultOrderedDict
from intervals import ChromosomeInterval

# Module-wide tracing flag.  parse_args() sets it from --verbose; when true,
# tracing output is printed to stderr.
verbose = False

def parse_args():
    """Parse the command line, set the module-level ``verbose`` flag and return the arguments."""
    desc = """convert a CAT generated genePred and associated information to an NCBI submission table file
      The following FATAL table2asn errors in the report file can be ignored:
         FATAL: JOINED_FEATURES_NO_EXCEPTION
         FATAL: CONTAINED_CDS
         FATAL: BACTERIAL_PARTIAL_NONEXTENDABLE_PROBLEMS
      The following table2asn report messages can be ignored:
         FEATURE_LOCATION_CONFLICT: Coding region location does not match gene location
         GENE_PARTIAL_CONFLICT: 1 feature locations conflict with partialness of overlapping gene (and similar messages)
         SHORT_INTRON: 1 intron is shorter than 10 nt
         SUSPECT_PRODUCT_NAMES: 12 features are all capital letters
    """
    # RawDescriptionHelpFormatter keeps the line breaks of the lists above; the
    # default formatter would re-wrap the whole description into one paragraph.
    parser = argparse.ArgumentParser(description=desc,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--verbose', action="store_true", default=False,
                        help="turn on some tracing")
    parser.add_argument("cat_genepred",
                        help="input CAT genePred results")
    parser.add_argument("cat_genepred_info",
                        help="associated genePred info")
    parser.add_argument("locus_tag_prefix",
                        help="locus tag prefix to use")
    parser.add_argument("ncbi_tbl_file",
                        help="output NCBI table file")
    args = parser.parse_args()
    global verbose
    verbose = args.verbose
    return args

# Maps Ensembl/GENCODE transcript biotypes to the NCBI feature key used for a
# transcript with no CDS (get_transcript_feature() always uses 'mRNA' when a
# CDS is present, and raises KeyError for a non-coding biotype missing here).
# The intended rule: non-transcribed pseudogenes are called mRNAs, as they are
# mostly annotated from protein alignments; transcribed pseudogenes are called
# ncRNAs, because they actually produce a transcript that could have
# non-coding function.
# NOTE(review): some entries do not follow that rule -- e.g.
# processed_pseudogene, unprocessed_pseudogene and TR_V_pseudogene map to
# 'ncRNA'.  Confirm whether that is intended.
biotype_map = {
    '3prime_overlapping_ncRNA': 'ncRNA',
    'IG_C_gene': 'mRNA',
    'IG_C_pseudogene': 'mRNA',
    'IG_D_gene': 'mRNA',
    'IG_J_gene': 'mRNA',
    'IG_V_gene': 'mRNA',
    'IG_V_pseudogene': 'mRNA',
    'IG_pseudogene': 'mRNA',
    'Mt_rRNA': 'rRNA',
    'Mt_tRNA': 'tRNA',
    'TEC': 'mRNA',
    'TR_C_gene': 'mRNA',
    'TR_J_gene': 'mRNA',
    'TR_J_pseudogene': 'mRNA',
    'TR_V_gene': 'mRNA',
    'TR_V_pseudogene': 'ncRNA',
    'antisense': 'ncRNA',
    'antisense_RNA': 'ncRNA',
    'bidirectional_promoter_lncRNA': 'ncRNA',
    'lincRNA': 'ncRNA',
    'macro_lncRNA': 'ncRNA',
    'miRNA': 'ncRNA',
    'misc_RNA': 'ncRNA',
    'non_coding': 'ncRNA',
    'non_stop_decay': 'ncRNA',
    'nonsense_mediated_decay': 'mRNA',
    'polymorphic_pseudogene': 'mRNA',
    'processed_pseudogene': 'ncRNA',
    'processed_transcript': 'ncRNA',
    'protein_coding': 'mRNA',
    'pseudogene': 'mRNA',
    'rRNA': 'rRNA',
    'retained_intron': 'ncRNA',
    'ribozyme': 'ncRNA',
    'sRNA': 'ncRNA',
    'scaRNA': 'ncRNA',
    'scRNA': 'ncRNA',
    'sense_intronic': 'ncRNA',
    'sense_overlapping': 'ncRNA',
    'snRNA': 'ncRNA',
    'snoRNA': 'ncRNA',
    'transcribed_processed_pseudogene': 'ncRNA',
    'transcribed_unitary_pseudogene': 'ncRNA',
    'transcribed_unprocessed_pseudogene': 'ncRNA',
    'translated_unprocessed_pseudogene': 'ncRNA',
    'translated_processed_pseudogene': 'ncRNA',
    'unitary_pseudogene': 'mRNA',
    'unknown_likely_coding': 'mRNA',
    'unprocessed_pseudogene': 'ncRNA',
    'vaultRNA': 'ncRNA'}


# For transcripts written as ncRNA features, the value of the /ncRNA_class
# qualifier, keyed by transcript biotype.  Values come from the INSDC
# ncRNA_class vocabulary ('other' when nothing more specific fits).
# Some biotypes appear with two capitalizations: the lowercase forms originally
# listed here and the mixed-case forms used as keys in biotype_map.  Both are
# listed so that lookups agree with biotype_map.  antisense_RNA, misc_RNA and
# non_coding, which biotype_map also maps to 'ncRNA', are included as well.
ncrna_class = {
    '3prime_overlapping_ncrna': 'other',
    '3prime_overlapping_ncRNA': 'other',
    'antisense': 'antisense_RNA',
    'antisense_RNA': 'antisense_RNA',
    'bidirectional_promoter_lncrna': 'lncRNA',
    'bidirectional_promoter_lncRNA': 'lncRNA',
    'lincRNA': 'lncRNA',
    'macro_lncRNA': 'lncRNA',
    'miRNA': 'miRNA',
    'misc_RNA': 'other',
    'non_coding': 'other',
    'non_stop_decay': 'other',
    'processed_pseudogene': 'other',
    'processed_transcript': 'other',
    'pseudogene': 'other',
    'retained_intron': 'other',
    'ribozyme': 'ribozyme',
    'sRNA': 'other',
    'scRNA': 'other',
    'scaRNA': 'other',
    'sense_intronic': 'other',
    'sense_overlapping': 'other',
    'snRNA': 'snRNA',
    'snoRNA': 'snoRNA',
    'transcribed_processed_pseudogene': 'other',
    'transcribed_unitary_pseudogene': 'other',
    'transcribed_unprocessed_pseudogene': 'other',
    'translated_unprocessed_pseudogene': 'other',
    'unitary_pseudogene': 'other',
    'unprocessed_pseudogene': 'other',
    'vaultRNA': 'vault_RNA'}


# Value of the gene-level /pseudogene qualifier for each pseudogene biotype,
# built one qualifier value at a time.
pseudo_map = dict.fromkeys(
    ('IG_C_pseudogene', 'IG_V_pseudogene', 'IG_pseudogene',
     'TR_J_pseudogene', 'TR_V_pseudogene', 'pseudogene'),
    'unknown')
pseudo_map.update(dict.fromkeys(
    ('processed_pseudogene', 'transcribed_processed_pseudogene',
     'translated_processed_pseudogene'),
    'processed'))
pseudo_map.update(dict.fromkeys(
    ('unprocessed_pseudogene', 'transcribed_unprocessed_pseudogene',
     'translated_unprocessed_pseudogene'),
    'unprocessed'))


class TranscriptData(namedtuple("TranscriptData", "tx attrs")):
    """Immutable pair of a genePred transcript (``tx``) and its row of CAT attributes (``attrs``)."""
    __slots__ = ()  # no per-instance __dict__, like the namedtuple base

def load_annotations(gp_file, gp_info_file):
    """Load CAT transcripts and their attributes, grouped by chromosome and then gene.

    :param gp_file: genePred file of CAT transcripts
    :param gp_info_file: tab-separated attribute file with a ``transcript_id`` column
    :returns: DefaultOrderedDict chrom -> DefaultOrderedDict gene (``name2``) -> list of
        TranscriptData, grouped in the order transcripts are read (assuming
        DefaultOrderedDict preserves insertion order)
    """
    txd_by_chrom_gene = DefaultOrderedDict(lambda: DefaultOrderedDict(list))
    attrs = pd.read_csv(gp_info_file, sep='\t')
    attrs = attrs.set_index(['transcript_id'])
    for tx in gene_pred_iterator(gp_file):
        # label-based row lookup; DataFrame.ix was removed in pandas 1.0
        txd_by_chrom_gene[tx.chromosome][tx.name2].append(TranscriptData(tx, attrs.loc[tx.name]))
    return txd_by_chrom_gene
        

class NcbiTblWriter(object):
    """Writer for NCBI 5-column feature tables (.tbl, as read by table2asn).

    Lines are queued in ``recs`` rather than written immediately, because a
    gene's transcripts must be examined before the gene feature can be written.
    push() parks the current queue on a stack and starts a new one; pop()
    restores the most recently parked queue (the current queue must already be
    flushed); flush() writes the current queue to the file and clears it.
    """

    def __init__(self, fh):
        self.fh = fh        # open text file the table is written to
        self.recs = []      # pending output lines
        self.saved = []     # stack of parked record lists

    def push(self):
        "park the current records and start a new, empty list"
        self.saved.append(self.recs)
        self.recs = []

    def pop(self):
        "replace top with pushed"
        assert len(self.recs) == 0
        self.recs = self.saved.pop()

    def flush(self):
        "write to file"
        for rec in self.recs:
            print(rec, file=self.fh)
        self.recs = []

    def start_seq(self, seqname):
        "queue the '>Features' header line that begins a sequence's features"
        self.recs.append(">Features {}".format(seqname))

    def __write_feature_region(self, region, strand, feature_key, start_incmpl, end_incmpl):
        "write one region range from list, feature_key should be None for continuations"
        start, stop = str(region[0] + 1), str(region[1])
        if strand == '-':
            start, stop = stop, start
        if start_incmpl:
            start = "<{}".format(start)
        if end_incmpl:
            stop = ">{}".format(stop)
        row = [start, stop]
        if feature_key is not None:
            row.append(feature_key)
        self.recs.append("\t".join(row))

    def write_feature(self, regions, strand, feature_key, start_incmpl=False, end_incmpl=False):
        """Write feature definition rows.

        ``regions`` is a list of the form ((start1, end1), (start2, end2), ...),
        one per region of the feature.  Ranges are zero-based and half-open; they
        are written in transcription order, with coordinates swapped on the
        negative strand.  ``start_incmpl``/``end_incmpl`` refer to the genomic
        start/end.  They become the '<' on the first row's start column and the
        '>' on the last row's stop column.  A single-region feature can carry both.

        :raises ValueError: if ``regions`` is empty
        """
        # correct order for strand
        regions = sorted(regions, key=lambda r: r[0] if strand == '+' else -r[1])
        if not regions:
            raise ValueError("no regions given for {} feature".format(feature_key))
        # on the minus strand the genomic end is the 5' (first-written) end
        if strand == '-':
            start_incmpl, end_incmpl = end_incmpl, start_incmpl
        last = len(regions) - 1
        for i, region in enumerate(regions):
            # the first region may also be the last, so it must get both markers
            self.__write_feature_region(region, strand,
                                        feature_key if i == 0 else None,
                                        start_incmpl=bool(start_incmpl) and i == 0,
                                        end_incmpl=bool(end_incmpl) and i == last)

    def write_qualifier(self, qualifier_key, qualifier_value=None):
        """write a qualifier with optional value"""
        row = ["", "", "", qualifier_key]
        if qualifier_value is not None:
            row += [qualifier_value]
        self.recs.append("\t".join(row))

    def write_note(self, note):
        """write a note qualifier"""
        self.write_qualifier("note", note)


class CdsSpec(object):
    """Accumulates the pieces that define one transcript's CDS."""

    def __init__(self, strand):
        self.strand = strand
        # frame-adjusted CDS blocks, each (start, stop)
        self.regions = []
        # the same blocks, cut at the first in-frame stop
        self.no_stop_regions = []
        # (start, stop) ranges skipped to restore the frame
        self.gaps = []
        # completeness flags relative to genomic start/end (not transcription
        # order); None until compute_cds_completeness() fills them in
        self.start_incmpl = None
        self.end_incmpl = None

    @property
    def cds_len(self):
        """total length of the frame-adjusted CDS blocks"""
        return sum(region[1] - region[0] for region in self.regions)

    @property
    def no_stop_len(self):
        """total length of the CDS blocks before the first in-frame stop"""
        return sum(region[1] - region[0] for region in self.no_stop_regions)

    def is_mult_three(self):
        "is CDS length a multiple of three?"
        return self.cds_len % 3 == 0
        
def frame_incr(frame, amt=1):
    """Advance a reading frame by ``amt`` bases (positive or negative), modulo 3.

    For integer arguments the separate negative-frame formula this used to have,
    ``(frame - (amt - (-amt) % 3)) % 3``, is congruent to ``frame + amt``
    (since ``-2 * amt == amt (mod 3)``), so one expression covers both cases.
    """
    return (frame + amt) % 3

def adjust_cds_start(cds_interval, expected_frame, frame):
    """Trim the transcription-order start of a CDS block so its frame becomes expected_frame.

    Bases are skipped one at a time (advancing ``frame`` with frame_incr) until
    the frame equals ``expected_frame``.  On the '+' strand they come off the
    low coordinate, otherwise off the high coordinate.  The trimmed block is
    clamped so it never inverts, which means it may become zero length.

    :returns: (gap, adjusted_interval), where gap is the (start, stop) range skipped
    """
    # NOTE(review): gap is not clamped the way the interval is, so when more
    # bases are skipped than the block holds, gap extends past the block.
    skipped = 0
    while frame != expected_frame:
        frame = frame_incr(frame)
        skipped += 1

    chrom = cds_interval.chromosome
    strand = cds_interval.strand
    lo, hi = cds_interval.start, cds_interval.stop
    if strand == '+':
        gap = (lo, lo + skipped)
        lo = min(lo + skipped, hi)
    else:
        gap = (hi - skipped, hi)
        hi = max(hi - skipped, lo)
    cds_interval = ChromosomeInterval(chrom, lo, hi, strand)
    if verbose:
        print("  adj", cds_interval, "[{}]".format(len(cds_interval)), "frame", "{}..{}".format(frame, frame_incr(frame, len(cds_interval))), skipped, "gap", gap, file=sys.stderr)
    return gap, cds_interval


def inframe_stop_cds(cds_interval, ifs_position):
    """Return the part of cds_interval that precedes an in-frame stop at ifs_position.

    On the '+' strand, bases at or after ``ifs_position`` are dropped; otherwise
    bases below ``ifs_position`` are dropped.  The result is clamped so it never
    inverts, so it may be empty.
    """
    lo, hi = cds_interval.start, cds_interval.stop
    if cds_interval.strand == '+':
        hi = max(min(hi, ifs_position), lo)
    else:
        lo = min(max(lo, ifs_position), hi)
    return ChromosomeInterval(cds_interval.chromosome, lo, hi, cds_interval.strand)


def get_gene_locus_tag(locus_tag_prefix, txd):
    """Return the gene part of a locus tag: the CAT gene name (name2) after its last '_'.

    This strips the species identifier from CAT names (Clint_Chim, etc).
    ``locus_tag_prefix`` is not used here.
    """
    return txd.tx.name2.rpartition('_')[2]

def get_transcript_base_id(txd):
    """Return the CAT transcript name after its last '_'.

    This removes the species identifier from CAT names (Clint_Chim, etc).
    """
    return txd.tx.name.rpartition('_')[2]

def get_transcript_gnl_id(locus_tag_prefix, txd):
    """Return the transcript ID in ``gnl|<locus_tag_prefix>|<base transcript id>`` form."""
    base_id = get_transcript_base_id(txd)
    return "|".join(("gnl", locus_tag_prefix, base_id))

def get_transcript_feature(txd):
    """Return the NCBI feature key for a transcript.

    Coding transcripts are always 'mRNA'.  Non-coding ones are looked up by
    biotype in biotype_map, which raises KeyError for an unknown biotype.
    """
    if txd.tx.cds_size <= 0:
        return biotype_map[txd.attrs.transcript_biotype]
    return 'mRNA'

def is_non_coding(txd):
    """True if the transcript's biotype has an entry in ncrna_class."""
    biotype = txd.attrs.transcript_biotype
    return biotype in ncrna_class
    
def trans_has_cds(txd):
    """True if the genePred transcript has a non-empty CDS."""
    return 0 < txd.tx.cds_size
    
def get_gene_name(txd):
    """Return the source gene's common name, or None when it is unusable.

    GENCODE uses clone-derived names (e.g. ``AC012345.1``) for some genomic
    annotations, which NCBI (understandably) doesn't like, so those are dropped.
    A missing name (NaN/None from the attribute table) also yields None rather
    than raising TypeError in ``re.match``.
    """
    name = txd.attrs.source_gene_common_name
    if pd.isnull(name):
        return None
    if re.match(r'^A[CL][0-9]+\.[0-9]+$', name):
        return None
    return name

def get_transcript_product(locus_tag_prefix, txd, trans_num):
    """Return the product name ``<label>_<trans_num>``.

    ``<label>`` is the source gene's common name, or the CAT base transcript id
    when get_gene_name() rejects that name.  ``locus_tag_prefix`` is not used.
    """
    gene_name = get_gene_name(txd)
    label = get_transcript_base_id(txd) if gene_name is None else gene_name
    return "{}_{}".format(label, trans_num)

def get_transcript_protein_id(locus_tag_prefix, txd, trans_num):
    """Return the protein_id value ``<base transcript id>_<trans_num>_prot`` (prefix unused)."""
    base_id = get_transcript_base_id(txd)
    return "{}_{}_prot".format(base_id, trans_num)

def write_protein_id(locus_tag_prefix, txd, trans_num, tblwr):
    """Queue this transcript's protein_id qualifier on ``tblwr``."""
    protein_id = get_transcript_protein_id(locus_tag_prefix, txd, trans_num)
    tblwr.write_qualifier('protein_id', protein_id)

def write_product(locus_tag_prefix, txd, trans_num, tblwr):
    """Queue a product qualifier, but only if the transcript has a source gene.

    A source gene that is not a string (e.g. a missing value read by pandas)
    means no product is written.
    """
    if not isinstance(txd.attrs.source_gene, str):
        return
    tblwr.write_qualifier("product", get_transcript_product(locus_tag_prefix, txd, trans_num))

def make_exon_idx_iter(txd):
    """Return an iterable of exon indexes in transcription order.

    Indexes ascend for '+' strand transcripts and descend otherwise.  Uses
    ``range`` (not the Python 2 only ``xrange``) so it also works on Python 3.
    """
    num_exons = len(txd.tx.exon_intervals)
    if txd.tx.strand == '+':
        return range(num_exons)
    return range(num_exons - 1, -1, -1)
        
def add_cds_region(cds_interval, frame, expected_frame, ifs_position, cds_spec):
    """Add the next CDS block (in transcription order) to ``cds_spec``.

    NCBI doesn't explicitly support frame shifts, so when the annotated frame
    differs from the expected one the block start is trimmed and the skipped
    range is recorded in ``cds_spec.gaps``.  The block is also cut at the
    in-frame stop when ``ifs_position`` is given.  Zero-length blocks are not
    recorded.

    :returns: the expected frame for the next block
    """
    # FIXME: /start_codon could be used rather than adjusting the CDS at the start
    if verbose:
        print("  cds", cds_interval, "[{}]".format(len(cds_interval)), "frame", frame, "exframe", expected_frame, file=sys.stderr)

    if frame != expected_frame:
        gap, cds_interval = adjust_cds_start(cds_interval, expected_frame, frame)
        cds_spec.gaps.append(gap)

    if ifs_position is None:
        no_stop_cds_interval = cds_interval
    else:
        no_stop_cds_interval = inframe_stop_cds(cds_interval, ifs_position)

    for interval, dest in ((cds_interval, cds_spec.regions),
                           (no_stop_cds_interval, cds_spec.no_stop_regions)):
        if len(interval) != 0:
            dest.append((interval.start, interval.stop))

    return frame_incr(expected_frame, len(cds_interval))

def compute_cds_completeness(txd, trans_start_incmpl, cds_spec):
    """Set ``cds_spec.start_incmpl`` / ``cds_spec.end_incmpl``.

    The flags are relative to genomic coordinates (start = lowest coordinate),
    so the transcription-order 5'/3' results are swapped for non-'+' strands.

    - 5' end incomplete: the first coding exon is not in frame 0
      (``trans_start_incmpl``), or the genome has no valid start codon.
    - 3' end incomplete: the CDS length is not a multiple of three, or the
      genome has no valid stop codon (previously this wrongly tested valid_start).

    Incompleteness recorded in the source genePred (cds_start_stat /
    cds_end_stat == "incmpl", genomic start/end) is then applied on top.
    """
    five_prime_incmpl = bool(trans_start_incmpl) or not txd.attrs.valid_start
    three_prime_incmpl = (not cds_spec.is_mult_three()) or not txd.attrs.valid_stop

    if txd.tx.strand == '+':
        cds_spec.start_incmpl = five_prime_incmpl
        cds_spec.end_incmpl = three_prime_incmpl
    else:
        cds_spec.start_incmpl = three_prime_incmpl
        cds_spec.end_incmpl = five_prime_incmpl

    # completeness based on source
    if txd.tx.cds_start_stat == "incmpl":
        cds_spec.start_incmpl = True
    if txd.tx.cds_end_stat == "incmpl":
        cds_spec.end_incmpl = True

def get_cds_regions(txd):
    """Build the CdsSpec for a coding transcript.

    Exons are visited in transcription order.  Each exon's overlap with the
    coding interval is passed to add_cds_region, which realigns it to the
    running frame (inserting a gap on a frame shift) and cuts it at the in-frame
    stop, if the attributes record one.  The 5' end counts as incomplete when
    the first coding exon's frame is not 0.
    """
    tx = txd.tx
    cds_spec = CdsSpec(tx.strand)
    raw_ifs = txd.attrs["ifs_position"]
    ifs_position = None if pd.isnull(raw_ifs) else int(raw_ifs)

    trans_start_incmpl = None  # stays None if no exon overlaps the CDS
    expected_frame = 0
    for iexon in make_exon_idx_iter(txd):
        cds_interval = tx.exon_intervals[iexon].intersection(tx.coding_interval)
        if cds_interval is None:
            continue
        exon_frame = tx.exon_frames[iexon]
        if trans_start_incmpl is None:
            trans_start_incmpl = exon_frame != 0
        expected_frame = add_cds_region(cds_interval, exon_frame, expected_frame, ifs_position, cds_spec)

    compute_cds_completeness(txd, trans_start_incmpl, cds_spec)
    return cds_spec


def is_possible_unitary_pseudo(txd, cds_spec):
    """Decide whether a coding transcript should be called a putative unitary pseudogene.

    True when the CDS is non-empty and cutting it at the first in-frame stop
    either removes more than 30 bases or leaves 80% or less of the CDS.  The
    original expression ``A and B or C`` let the 80% test bypass the non-empty
    guard, so an empty CDS (0 <= 0) was called a pseudogene.

    :param txd: TranscriptData (unused; kept for interface compatibility)
    :param cds_spec: CdsSpec with ``regions`` and ``no_stop_regions`` filled in
    """
    cds_len = cds_spec.cds_len
    no_stop_len = cds_spec.no_stop_len
    assert no_stop_len <= cds_len
    if cds_len <= 0:
        return False  # an empty CDS cannot have lost any sequence
    return ((cds_len - no_stop_len) > 30) or (no_stop_len <= (0.80 * cds_len))

def write_cds_regions(locus_tag_prefix, txd, trans_num, cds_spec, tblwr):
    """Write the CDS feature and its qualifiers.

    The CDS spans ``cds_spec.no_stop_regions`` (the frame-adjusted CDS cut at
    any in-frame stop) and carries the completeness flags from ``cds_spec``.
    product, protein_id and transcript_id use the same helpers as the transcript
    feature, so the values match.  A note is added when frame gaps were inserted.
    """
    strand = txd.tx.strand
    tblwr.write_feature(cds_spec.no_stop_regions, strand, "CDS",
                        cds_spec.start_incmpl, cds_spec.end_incmpl)

    # FIXME: codon_start could be set rather than advancing to first codon
    tblwr.write_qualifier('codon_start', "1")

    # product and protein_id must match those on the mRNA
    write_product(locus_tag_prefix, txd, trans_num, tblwr)
    write_protein_id(locus_tag_prefix, txd, trans_num, tblwr)

    # same transcript_id as the mRNA feature
    tblwr.write_qualifier('transcript_id', get_transcript_gnl_id(locus_tag_prefix, txd))

    if cds_spec.gaps:
        tblwr.write_note("gaps were added to CDS to maintain frame")
        

def write_cds_gaps(locus_tag_prefix, txd, cds_spec, tblwr):
    """Write a misc_feature with an explanatory note for each gap added to keep the CDS in frame.

    ``locus_tag_prefix`` is not used.
    """
    strand = txd.tx.strand
    note = "gap added in CDS to maintain frame, possibly due to error in genome"
    for gap in cds_spec.gaps:
        tblwr.write_feature([gap], strand, "misc_feature")
        tblwr.write_note(note)


def write_cds(locus_tag_prefix, txd, trans_num, cds_spec, tblwr):
    """Write the CDS feature if any CDS remains after the in-frame-stop cut, then any frame-gap misc_features."""
    if cds_spec.no_stop_regions:
        write_cds_regions(locus_tag_prefix, txd, trans_num, cds_spec, tblwr)
    if cds_spec.gaps:
        write_cds_gaps(locus_tag_prefix, txd, cds_spec, tblwr)

def get_transcript_exons(txd):
    """Return the transcript's exons as (start, stop) tuples, in exon_intervals order."""
    return [(exon.start, exon.stop) for exon in txd.tx.exon_intervals]

def write_transcript_features(locus_tag_prefix, txd, trans_num, cds_spec, unitary_pseudo, tblwr):
    """Write the transcript-level feature and its qualifiers for one transcript.

    The feature key comes from get_transcript_feature() and the feature spans
    all exons.  Qualifiers are written in this order:
      - transcript_id
      - ncRNA_class (ncRNA features only)
      - an exception when any intron is 50 bases or shorter
      - product
      - protein_id (only when a CDS feature will follow)
      - the unitary pseudogene markers
      - notes recording CAT provenance

    :param cds_spec: CdsSpec for coding transcripts, None for non-coding ones
    :param unitary_pseudo: True if the transcript was called a putative unitary pseudogene
    """
    tx = txd.tx
    txa = txd.attrs
    feature_key = get_transcript_feature(txd)

    # feature with correct type and all exons
    tblwr.write_feature(get_transcript_exons(txd), tx.strand, feature_key)

    tblwr.write_qualifier('transcript_id', get_transcript_gnl_id(locus_tag_prefix, txd))

    # /ncRNA_class is mandatory on ncRNA features and is not a legal qualifier
    # on any other feature key (e.g. the mRNA of a coding transcript whose
    # biotype happens to be in ncrna_class).  Unlisted biotypes fall back to 'other'.
    if feature_key == 'ncRNA':
        tblwr.write_qualifier('ncRNA_class', ncrna_class.get(txa.transcript_biotype, 'other'))

    # check for gaps: introns this short are presumed to be alignment/genome
    # gaps rather than real introns
    if any(len(x) <= 50 for x in tx.intron_intervals):
        tblwr.write_qualifier('exception', 'low-quality sequence region')

    # product must be the same on mRNA and CDS
    write_product(locus_tag_prefix, txd, trans_num, tblwr)

    # protein_id pairs an mRNA with its CDS.  The CDS is written only for coding,
    # non-unitary-pseudo transcripts with CDS left after the in-frame-stop cut,
    # so don't emit a dangling protein_id on ncRNAs or CDS-less transcripts.
    writes_cds = (cds_spec is not None and not unitary_pseudo
                  and len(cds_spec.no_stop_regions) > 0)
    if writes_cds:
        write_protein_id(locus_tag_prefix, txd, trans_num, tblwr)

    if unitary_pseudo:
        tblwr.write_qualifier('pseudogene', "unitary")
        tblwr.write_note("Putative inactivated transcript (unitary pseudogene)")

    # record cat pipeline information
    tblwr.write_note("CAT transcript id: {}".format(get_transcript_base_id(txd)))
    tblwr.write_note("CAT alignment id: {}".format(txa.alignment_id))
    tblwr.write_note("CAT source transcript id: {}".format(txa.source_transcript))
    if txa.transcript_biotype != 'unknown_likely_coding':
        tblwr.write_note("CAT source GENCODE transcript biotype: {}".format(txa.transcript_biotype))
    else:
        tblwr.write_note('CAT novel prediction: {}'.format(txa.transcript_modes))

def convert_transcript(locus_tag_prefix, txd, trans_num, tblwr):
    """Queue the transcript feature and, for non-pseudo coding transcripts, the CDS features.

    :returns: (transcript has a CDS, transcript called a unitary pseudogene)
    """
    if verbose:
        print("converting", txd.tx.name, file=sys.stderr)
    has_cds = trans_has_cds(txd)
    cds_spec = get_cds_regions(txd) if has_cds else None
    unitary_pseudo = is_possible_unitary_pseudo(txd, cds_spec) if has_cds else False

    write_transcript_features(locus_tag_prefix, txd, trans_num, cds_spec, unitary_pseudo, tblwr)
    if has_cds and not unitary_pseudo:
        write_cds(locus_tag_prefix, txd, trans_num, cds_spec, tblwr)
    return (has_cds, unitary_pseudo)

def write_gene_features(locus_tag_prefix, gene_txds, unitary_pseudo, tblwr):
    """Write the gene feature covering every transcript of one gene.

    Strand, locus tag and gene biotype are taken from the first transcript.  The
    locus tag is ``<locus_tag_prefix>_<gene name2 without species prefix>``.  A
    /pseudogene qualifier is added when the gene biotype is in pseudo_map, or
    otherwise when ``unitary_pseudo`` is set.
    """
    first = gene_txds[0]
    gene_start = min(t.tx.start for t in gene_txds)
    gene_end = max(t.tx.stop for t in gene_txds)
    tblwr.write_feature([(gene_start, gene_end)], first.tx.strand, "gene")

    gene_tag = get_gene_locus_tag(locus_tag_prefix, first)
    tblwr.write_qualifier("locus_tag", '{}_{}'.format(locus_tag_prefix, gene_tag))

    gene_biotype = first.attrs.gene_biotype
    if gene_biotype in pseudo_map:
        tblwr.write_qualifier("pseudogene", pseudo_map[gene_biotype])
    elif unitary_pseudo:
        tblwr.write_qualifier('pseudogene', "unitary")
    
def convert_gene_transcripts(locus_tag_prefix, gene_txds, tblwr):
    """Convert each transcript of a gene, numbering them from 1 in list order.

    :returns: (number of transcripts with a CDS, number called unitary pseudogenes)
    """
    with_cds = 0
    unitary = 0
    for trans_num, txd in enumerate(gene_txds, 1):
        has_cds, unitary_pseudo = convert_transcript(locus_tag_prefix, txd, trans_num, tblwr)
        with_cds += 1 if has_cds else 0
        unitary += 1 if unitary_pseudo else 0
    return with_cds, unitary

def convert_gene(locus_tag_prefix, gene_txds, tblwr):
    """Convert one gene: write its gene feature, followed by its transcripts' features.

    Transcripts are converted first because the gene feature depends on them:
    the gene is flagged unitary pseudo only if it has coding transcripts and all
    of them were called unitary pseudogenes.  Their queued records are parked on
    the writer's stack so that the gene record can be written ahead of them.
    """
    num_with_cds, num_unitary_pseudo = convert_gene_transcripts(locus_tag_prefix, gene_txds, tblwr)
    tblwr.push()  # save transcripts
    all_coding_unitary_pseudo = (num_with_cds > 0) and (num_unitary_pseudo == num_with_cds)
    write_gene_features(locus_tag_prefix, gene_txds, all_coding_unitary_pseudo, tblwr)
    tblwr.flush()  # write genes
    tblwr.pop()  # restore the parked transcript records
    tblwr.flush()  # write transcripts

def convert_chrom(locus_tag_prefix, chrom, chrom_txds, tblwr):
    """Write the '>Features' header for one chromosome, then each of its genes in stored order."""
    tblwr.start_seq(chrom)
    tblwr.flush()  # emit the header now so it precedes every feature
    for txds_of_gene in chrom_txds.values():
        convert_gene(locus_tag_prefix, txds_of_gene, tblwr)

def cat_to_ncbi_tbl(locus_tag_prefix, txd_by_chrom_gene, tblwr):
    """Write the features of every chromosome in ``txd_by_chrom_gene``, in its iteration order."""
    for chrom, chrom_txds in txd_by_chrom_gene.items():
        convert_chrom(locus_tag_prefix, chrom, chrom_txds, tblwr)
        
def main(args):
    """Load the CAT annotations and write them to ``args.ncbi_tbl_file`` as an NCBI feature table."""
    txd_by_chrom_gene = load_annotations(args.cat_genepred, args.cat_genepred_info)
    with open(args.ncbi_tbl_file, "w") as tbl_fh:
        writer = NcbiTblWriter(tbl_fh)
        cat_to_ncbi_tbl(args.locus_tag_prefix, txd_by_chrom_gene, writer)


# Run only when executed as a script, so importing this module (e.g. to reuse
# or test its helpers) neither parses sys.argv nor writes output.
if __name__ == "__main__":
    main(parse_args())
