#!/usr/bin/env python3
# -*- coding: utf-8 -*-
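"""
Build randomized (composition-preserving) counterparts of the sequences in a
FASTA file: each nucleotide sequence is shuffled, in-frame stop codons are
removed by re-shuffling them, and records whose translation reaches the
minimum length are written to <prefix>.nfasta (nucleotides) and
<prefix>.pfasta (proteins).

Usage: <script> -fasta <input.fasta> -out <prefix>

Created on Wed Jun 30 22:28:42 2021

@author: christospapadopoulos
"""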


import os,sys,random,re
from Bio.Seq import Seq
from Bio import SeqIO



def read_multiFASTA(fasta_file):
    """Parse a (multi-)FASTA file into an ordered ``{name: sequence}`` dict.

    The record name is the first whitespace-separated token of the header
    line with its leading ``>`` removed; the rest of the header is dropped.
    Sequence lines are stripped of surrounding whitespace and concatenated.
    Blank or whitespace-only lines are ignored. If the same name appears in
    two headers, the later record replaces the earlier one.

    Parameters
    ----------
    fasta_file : str
        Path to the FASTA file.

    Returns
    -------
    dict
        Mapping of record name to its full sequence string, in file order.

    Raises
    ------
    ValueError
        If sequence data appears before the first ``>`` header.
    """
    chunks = {}
    name = None
    with open(fasta_file, 'r') as fasta:
        for line_number, line in enumerate(fasta, start=1):
            if line.startswith('>'):
                name = line.split()[0][1:]
                # Re-seen headers restart their record (later one wins).
                chunks[name] = []
                continue
            seq = line.strip()
            if not seq:
                continue
            if name is None:
                raise ValueError(
                    "{}: sequence data on line {} before any '>' header".format(
                        fasta_file, line_number))
            # Collect pieces and join once: repeated str '+' is quadratic.
            chunks[name].append(seq)
    return {key: ''.join(parts) for key, parts in chunks.items()}

def find_all(s, ch):
    """Return every 0-based index at which *ch* occurs in *s*, ascending.

    Each element of *s* is compared to *ch* with ``==``; an empty list is
    returned when there is no match.
    """
    positions = []
    for position, symbol in enumerate(s):
        if symbol == ch:
            positions.append(position)
    return positions

def repare_STOPs(AA,nt):
    """Remove in-frame stop codons by shuffling their own nucleotides.

    For every ``*`` in the protein *AA*, the codon of *nt* at the same index
    (reading frame 0) is re-shuffled until it is no longer TAA, TAG or TGA.
    The nucleotide composition of *nt* is therefore unchanged. The check is
    case-insensitive, so lowercase (soft-masked) stops are repaired too, and
    the codon keeps its original letter case.

    NOTE(review): codons containing IUPAC ambiguity letters that Biopython
    may still translate to ``*`` are not recognised here.

    Parameters
    ----------
    AA : str
        Translation of *nt*; ``AA[i]`` must correspond to codon ``i``.
    nt : str
        Nucleotide sequence.

    Returns
    -------
    tuple of (str, str)
        ``(nt, AA)`` unchanged when *AA* has no stop, otherwise the repaired
        nucleotide sequence and its fresh translation.
    """
    stops = find_all(ch="*", s=AA)
    if not stops:
        return nt, AA

    stop_codons = ("TAA", "TAG", "TGA")
    codons = [nt[i:i + 3] for i in range(0, len(nt), 3)]
    for j in stops:
        codon = codons[j]
        # Every stop codon has a non-stop permutation (e.g. TAA -> ATA), so
        # this loop terminates with probability 1.
        while codon.upper() in stop_codons:
            codon = "".join(random.sample(codon, k=len(codon)))
        codons[j] = codon

    repaired = "".join(codons)
    # Build a Seq explicitly instead of calling the unbound method on a str.
    return repaired, str(Seq(repaired).translate())

def calculate_NT_freqs(fasta):
    """Return the pooled A, T, C, G frequencies over all sequences in *fasta*.

    Parameters
    ----------
    fasta : dict
        Mapping of record name to nucleotide sequence, as returned by
        ``read_multiFASTA``.

    Returns
    -------
    list of float
        ``[freq_A, freq_T, freq_C, freq_G]``, each rounded to 3 decimals.
        The denominator is the total length of all sequences, so any other
        symbol (N, gaps, ...) lowers the four values. Counting is
        case-insensitive. Returns all zeros when there is no sequence data,
        instead of dividing by zero.
    """
    # One join is linear; repeated '+' on a growing string is quadratic.
    pooled = "".join(fasta.values()).upper()
    total = len(pooled)
    if total == 0:
        return [0.0, 0.0, 0.0, 0.0]
    return [round(pooled.count(base) / total, 3) for base in "ATCG"]
    

# ---------------------------------------------------------------------------
# Command-line driver.
# Usage: <script> -fasta <input.fasta> -out <prefix> [-seed <value>]
# Writes <prefix>.nfasta (shuffled nucleotides) and <prefix>.pfasta (their
# translations) for every record whose randomized protein reaches
# MIN_PROTEIN_LENGTH residues.
# ---------------------------------------------------------------------------

# Shortest randomized protein (in residues) that is written to the outputs.
MIN_PROTEIN_LENGTH = 20


def _get_cli_value(flag, required=True):
    """Return the command-line value that follows *flag*.

    Exits with a readable message when a required flag is absent or when
    *flag* is the last argument. Returns None for an absent optional flag.
    """
    if flag not in sys.argv:
        if required:
            sys.exit("Missing required argument: {} <value>".format(flag))
        return None
    position = sys.argv.index(flag) + 1
    if position >= len(sys.argv):
        sys.exit("Argument {} expects a value".format(flag))
    return sys.argv[position]


def _report_nt_freqs(nt_freqs):
    """Print A/T/C/G frequencies and their combined G+C value."""
    freq_a, freq_t, freq_c, freq_g = nt_freqs
    # Round the sum: adding two rounded floats can print as 0.5670000000000001.
    # NOTE: the value is a fraction (0-1) despite the "%GC" label.
    gc = round(freq_c + freq_g, 3)
    print("A : {} | T : {} | C : {} | G : {} | ----> {} %GC".format(
        str(freq_a), str(freq_t), str(freq_c), str(freq_g), str(gc)))


fasta_file = _get_cli_value("-fasta")
outputs = _get_cli_value("-out")

# Optional seed so a randomization run can be reproduced exactly.
seed = _get_cli_value("-seed", required=False)
if seed is not None:
    random.seed(seed)

fasta = read_multiFASTA(fasta_file)
_report_nt_freqs(calculate_NT_freqs(fasta=fasta))

with open(outputs + ".nfasta", "w") as fn, open(outputs + ".pfasta", "w") as fp:
    for name, sequence in fasta.items():
        # Sampling every position without replacement is a random
        # permutation, so nucleotide composition is preserved.
        randomized_nt = "".join(random.sample(sequence, k=len(sequence)))
        randomized_aa = str(Seq(randomized_nt).translate())
        # Re-shuffle in-frame stop codons so the protein has no internal stop.
        randomized_nt, randomized_aa = repare_STOPs(AA=randomized_aa, nt=randomized_nt)

        if len(randomized_aa) >= MIN_PROTEIN_LENGTH:
            fp.write(">{}_random\n{}\n".format(name, randomized_aa))
            fn.write(">{}_random\n{}\n".format(name, randomized_nt))

# Re-read the written file to report the composition of what was actually kept.
randomized_fasta = read_multiFASTA(outputs + ".nfasta")
if randomized_fasta:
    _report_nt_freqs(calculate_NT_freqs(fasta=randomized_fasta))
else:
    print("No randomized protein reached {} residues; nothing was written.".format(
        MIN_PROTEIN_LENGTH))

