#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 17 17:16:21 2021

@author: christospapadopoulos
"""

import os,sys,random,re
from Bio.Seq import Seq
from Bio import SeqIO

def read_multiFASTA(fasta_file):
    """Parse a (multi-)FASTA file into a dict mapping record name to sequence.

    The record name is the first whitespace-delimited token of each header
    line, without the leading '>'. Sequence lines are stripped of surrounding
    whitespace and concatenated in order; blank (or whitespace-only) lines
    are ignored. If the same name appears in two headers, the later record
    replaces the earlier one's sequence (the key keeps its first position).

    Parameters
    ----------
    fasta_file : str
        Path to the FASTA file.

    Returns
    -------
    dict
        ``{name: sequence}`` in file order.

    Raises
    ------
    ValueError
        If sequence data appears before the first '>' header.
    """
    chunks = {}
    name = None
    with open(fasta_file, 'r') as fasta:
        for line in fasta:
            if line.startswith('>'):
                name = line.split()[0][1:]
                # Collect pieces and join once at the end: repeated
                # ``str + str`` on a dict value is quadratic in record length.
                chunks[name] = []
                continue
            seq = line.strip()
            if not seq:
                continue
            if name is None:
                raise ValueError(
                    "{}: sequence data found before the first '>' header".format(fasta_file)
                )
            chunks[name].append(seq)
    return {key: ''.join(parts) for key, parts in chunks.items()}


def calculate_NT_freqs(fasta):
    """Return the relative frequencies of A, T, C and G over all sequences.

    Parameters
    ----------
    fasta : dict
        Mapping of record name to nucleotide sequence, as returned by
        ``read_multiFASTA``.

    Returns
    -------
    list of float
        ``[freq_A, freq_T, freq_C, freq_G]``, each rounded to 3 decimals.
        The denominator is the total number of characters over all
        sequences, so other symbols (e.g. ``N``) lower every frequency and
        the four values may sum to less than 1. Counting is
        case-insensitive, so lower-case (soft-masked) bases are included.
        All four values are 0.0 when the input holds no characters.
    """
    counts = dict.fromkeys("ATCG", 0)
    total = 0
    for seq in fasta.values():
        # Count per sequence instead of concatenating everything first;
        # single-character counts are additive across pieces.
        upper = seq.upper()
        for base in counts:
            counts[base] += upper.count(base)
        total += len(seq)

    # Avoid ZeroDivisionError on empty input (e.g. an empty output file).
    if total == 0:
        return [0.0, 0.0, 0.0, 0.0]
    return [round(counts[base] / total, 3) for base in "ATCG"]
        
def find_all(s, ch):
    """Return, in ascending order, every index at which *ch* occurs in *s*.

    Each element of *s* is compared to *ch* with ``==``; an empty list
    means there was no match.
    """
    positions = []
    for index, item in enumerate(s):
        if item == ch:
            positions.append(index)
    return positions

def repare_STOPs(AA, nt):
    """Remove stop codons from *nt* by randomly re-ordering their bases.

    For every '*' in *AA*, the codon at the same index of *nt* (read in
    frame from position 0) is shuffled with ``random.sample`` until it is no
    longer TAA, TAG or TGA. Shuffling, rather than drawing new bases, keeps
    the sequence's nucleotide composition unchanged.

    Parameters
    ----------
    AA : str
        Protein translation of *nt*, with '*' marking stop codons.
    nt : str
        Nucleotide sequence. It is presumably upper-case ACGT, since only
        upper-case stop codons are recognised for shuffling.

    Returns
    -------
    tuple of (str, str)
        ``(nt, AA)`` unchanged when *AA* has no '*'. Otherwise, the repaired
        nucleotide sequence and its translation with Biopython's default
        table.
    """
    stops = find_all(s=AA, ch="*")
    if not stops:
        return (nt, AA)

    stop_codons = {"TAA", "TAG", "TGA"}
    codons = [nt[i:i + 3] for i in range(0, len(nt), 3)]
    for j in stops:
        codon = codons[j]
        while codon in stop_codons:
            codon = "".join(random.sample(codon, k=len(codon)))
        codons[j] = codon

    repaired_nt = "".join(codons)
    # Translate through a Seq instance. Passing a plain str as ``self`` to
    # Seq.translate relies on Biopython internals and fails on releases
    # before 1.78, which read ``self.alphabet``.
    return repaired_nt, str(Seq(repaired_nt).translate())
    

def _cli_value(flag):
    """Return the command-line token that follows *flag*.

    Exits through ``sys.exit`` with a usage message on stderr (status 1)
    when *flag* is absent from ``sys.argv`` or is its last token.
    """
    try:
        return sys.argv[sys.argv.index(flag) + 1]
    except (ValueError, IndexError):
        sys.exit(
            "usage: {} -size SIZE_FASTA -freq FREQ_FASTA -out OUT_PREFIX\n"
            "error: missing value for {}".format(os.path.basename(sys.argv[0]), flag)
        )


# Read every option up front so a missing one is reported before any work.
size_file = _cli_value("-size")
freq_file = _cli_value("-freq")
outputs = _cli_value("-out")

# The length of each record here sets the length of its random counterpart.
size = read_multiFASTA(size_file)

# Base frequencies of this file weight the random nucleotide draws.
freq = read_multiFASTA(freq_file)
nt_freqs = calculate_NT_freqs(fasta=freq)

print(nt_freqs)

n_written = 0
with open(outputs + ".nfasta", "w") as fn, open(outputs + ".pfasta", "w") as fp:
    for name, sequence in size.items():
        # Population order must match the [A, T, C, G] order of nt_freqs.
        randomized_nt = ''.join(
            random.choices(population=["A", "T", "C", "G"], weights=nt_freqs, k=len(sequence))
        )
        randomized_aa = str(Seq(randomized_nt).translate())

        randomized_nt, randomized_aa = repare_STOPs(AA=randomized_aa, nt=randomized_nt)

        # Keep only random proteins of at least 20 residues.
        if len(randomized_aa) >= 20:
            fp.write(">{}_random\n{}\n".format(name, randomized_aa))
            fn.write(">{}_random\n{}\n".format(name, randomized_nt))
            n_written += 1

# Report the composition of the written nucleotide file. Skip it when nothing
# was written, since frequencies of an empty file are undefined.
if n_written:
    print(calculate_NT_freqs(fasta=read_multiFASTA(outputs + ".nfasta")))
else:
    print("no random sequence reached 20 residues; nothing written", file=sys.stderr)





