#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright or Copr. Centre National de la Recherche Scientifique (CNRS) (2018)
# Contributors:
# - Carine Rey <carine.rey@ens-lyon.org>

# This software is a computer program whose purpose is to provide a set of scripts for pre and post processing of data for
# convergence detection programs.

# This software is governed by the CeCILL-C license under French law and abiding by the rules of distribution of free software.
# You can use, modify and/ or redistribute the software under the terms of the CeCILL-C license as circulated by CEA, CNRS and
# INRIA at the following URL "http://www.cecill.info".

# As a counterpart to the access to the source code and rights to copy, modify and redistribute granted by the license, users
# are provided only with a limited warranty and the software's author, the holder of the economic rights, and the successive
# licensors have only limited liability.

# In this respect, the user's attention is drawn to the risks associated with loading, using, modifying and/or developing or
# reproducing the software by the user in light of its specific status of free software, that may mean that it is complicated
# to manipulate, and that also therefore means that it is reserved for developers and experienced professionals having in-depth
# computer knowledge. Users are therefore encouraged to load and test the software's suitability as regards their requirements
# in conditions enabling the security of their systems and/or data to be ensured and, more generally, to use and operate it in
# the same conditions as regards security.

# The fact that you are presently reading this means that you have had knowledge of the CeCILL-C license and that you accept
# its terms.


import argparse
import sys
import os, re
import logging

from Bio import AlignIO, SeqIO
from Bio.Seq import Seq
from Bio.Alphabet import IUPAC
from Bio.SeqRecord import SeqRecord
from Bio.Align import MultipleSeqAlignment

#===================================================================================================
# inputs
#===================================================================================================

### Command-line interface definition
parser = argparse.ArgumentParser(
    prog="back_translate_ali.py",
    description='')
parser.add_argument('--version', action='version', version='%(prog)s 1.0')
parser.add_argument('--debug', default=False, action="store_true", help="debug mode")

# The mandatory arguments live in their own group so that they are listed
# under a dedicated heading in the --help output.
requiredOptions = parser.add_argument_group('REQUIRED OPTIONS')
requiredOptions.add_argument(
    '-a', "--aa",
    required=True,
    type=argparse.FileType('r'),
    help='AA Alignment fasta filename to back translate')
requiredOptions.add_argument(
    '-n', "--nt",
    required=True,
    type=argparse.FileType('r'),
    help='NT fasta filename to help the back translation')
requiredOptions.add_argument(
    '-o', '--output',
    required=True,
    type=str,
    help="Output name")


### Command-line parsing
args = parser.parse_args()

# argparse has already opened both inputs for reading (FileType('r'));
# the output is kept as a plain name.
AaAliFile, NtFile = args.aa, args.nt
OutFile = args.output


#===================================================================================================
# Set up logger
#===================================================================================================
### Logger configuration
# The logger itself lets every record through; the console handler decides
# what is actually displayed (DEBUG with --debug, INFO otherwise).
logger = logging.getLogger("back_translate_ali")
logger.setLevel(logging.DEBUG)

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG if args.debug else logging.INFO)

# Timestamped, level-tagged messages on the console handler
formatter_ch = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter_ch)
logger.addHandler(ch)

# Keep a trace of the exact command line that was used
logger.info(sys.argv)


#===================================================================================================
# Read input aa alignment
#===================================================================================================
# Parse the amino-acid alignment and index its records by identifier.
# Any failure is logged and ends the script with exit status 1.
try:
    ali = AlignIO.read(AaAliFile, "fasta")
    aa_ali_d = {record.id: record for record in ali}
except Exception as exc:
    logger.error(str(exc))
    sys.exit(1)
else:
    logger.info("Ali aa (%s) ok after checking", AaAliFile.name)

# Number of columns of the amino-acid alignment
nb_sites = ali.get_alignment_length()



#===================================================================================================
# Read input nt sequences
#===================================================================================================
import time

# Load every nucleotide record into a dict keyed by record identifier and
# report how long it took. Any failure is logged and ends the script with
# exit status 1.
start = time.time()
try:
    nt_seq_d = SeqIO.to_dict(SeqIO.parse(NtFile, "fasta"))
except Exception as exc:
    logger.error(str(exc))
    sys.exit(1)
else:
    # "end" holds the elapsed time in seconds, not a timestamp
    end = time.time() - start
    logger.info("Loading time: %s", end)
logger.info("Sequences nt (%s) ok after checking", NtFile.name)


#===================================================================================================
# Create output file
#===================================================================================================

from Bio.Data import CodonTable

# Ambiguity-aware standard genetic code, used by back_translate() to check
# its result. Two extra codons are registered so that alignment gaps and
# fully undetermined codons can be looked up.
standard_table = CodonTable.ambiguous_dna_by_name["Standard"]

# NOTE(review): this edits the mapping wrapped by the ambiguous forward
# table; it is presumably shared with Biopython's own "Standard" table, in
# which case the extra entries are visible process-wide -- confirm before
# reusing this code as a library.
standard_table.forward_table.forward_table.update({
    "---": "-",  # gap codon -> gap
    "NNN": "X",  # undetermined codon -> unknown residue
})

def back_translate(aa_seq, nt_seq, alphabet="Standard"):
    """Thread an ungapped nucleotide sequence onto an aligned protein sequence.

    Every gap ("-") of *aa_seq* becomes the gap codon "---"; every other
    position consumes the next three characters of *nt_seq*.

    Parameters:
        aa_seq: aligned amino-acid sequence (string), gaps written as "-".
        nt_seq: ungapped coding sequence (string) corresponding to *aa_seq*.
            Nucleotides left over after the last residue are ignored.
        alphabet: only its truthiness is used. When true (the default), the
            result is translated with the module-level ``standard_table``
            and compared with *aa_seq*; when false, no check is performed.

    Returns:
        The codon-aligned nucleotide sequence, ``3 * len(aa_seq)`` characters
        long. Codons containing "N" are returned exactly as found in
        *nt_seq*; they are masked as "NNN" only for the translation check.

    Raises:
        ValueError: *nt_seq* does not provide a complete codon for every
            non-gap position of *aa_seq*.
        SystemExit: (status 1) the translation check failed. Positions where
            *aa_seq* holds "X" are never counted as mismatches.
    """
    codons = []
    i_nt = 0
    for i_aa, residue in enumerate(aa_seq):
        if residue == "-":
            codons.append("---")
            continue
        codon = nt_seq[i_nt:i_nt + 3]
        if len(codon) != 3:
            # Slicing past the end would silently yield a truncated or empty
            # codon and shift every following column of the alignment.
            raise ValueError(
                "nucleotide sequence too short: no complete codon for "
                "amino-acid position %d" % i_aa)
        codons.append(codon)
        i_nt += 3

    nt_ali_seq = "".join(codons)

    if alphabet:
        # Check the back translation. Any codon holding an "N" is masked as
        # "NNN" for the check only; the returned sequence is not modified.
        masked = "".join("NNN" if "N" in codon else codon for codon in codons)
        translated = str(
            Seq(masked, alphabet=IUPAC.ambiguous_dna).translate(table=standard_table))
        expected = str(aa_seq)

        if translated == expected:
            logger.debug("Back_translation is okay ! Houra °\\_o_/° !")
            aa_ok = len(aa_seq)
        else:
            aa_ok = 0
            for i_aa, (aa_exp, aa_obs) in enumerate(zip(expected, translated)):
                # An "X" in the protein alignment matches any codon
                if aa_exp == aa_obs or aa_exp == "X":
                    aa_ok += 1
                else:
                    logger.debug("%s %s %s ", i_aa, aa_exp, aa_obs)
        logger.debug("%s %s ", aa_ok, len(aa_seq))

        if aa_ok != len(aa_seq):
            logger.error("Error in the back_translation, check alphabet !")
            logger.error(translated)
            logger.error(expected)
            # Every differing position is listed, "X" positions included
            for i_aa, (aa_exp, aa_obs) in enumerate(zip(expected, translated)):
                if aa_exp != aa_obs:
                    logger.error("%s %s %s ", i_aa, aa_exp, aa_obs)
            sys.exit(1)
    return nt_ali_seq

# Back translate every sequence of the amino-acid alignment, then write the
# resulting nucleotide alignment in fasta format. Failures handled here are
# logged and end the script with exit status 42.
out_seq_l = []
for name in aa_ali_d:
    logger.debug("Proccess sequence: %s", name)
    aa_record = aa_ali_d[name]
    aa_ali_seq = str(aa_record.seq)

    # The nucleotide sequence is looked up under the same identifier as the
    # amino-acid record.
    try:
        nt_seq = str(nt_seq_d[name].seq)
    except KeyError:
        logger.error("%s is not present in %s", name, NtFile.name)
        sys.exit(42)

    try:
        nt_ali_seq = back_translate(aa_ali_seq, nt_seq)
        record = SeqRecord(Seq(nt_ali_seq), id=aa_record.id, description=aa_record.description)
    except Exception as exc:
        logger.error("nt: %s", nt_seq)
        logger.error("aa: %s", aa_ali_seq)
        logger.error(str(exc))
        logger.error("Back translation of %s failed", name)
        sys.exit(42)
    out_seq_l.append(record)

ali_nt = MultipleSeqAlignment(out_seq_l)
logger.info("Write ouptut in %s (%s seq, %s pos)", OutFile, len(ali_nt), ali_nt.get_alignment_length())

try:
    SeqIO.write(ali_nt, OutFile, "fasta")
except Exception as exc:
    logger.error(str(exc))
    sys.exit(42)
