#!/usr/bin/env python

from Bio.Seq import Seq
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.Alphabet import  IUPAC
from Bio import AlignIO
import random
import os
import sys
import glob

# Positional command-line arguments:
#   1. text file in which the kept site indices of each gene are recorded
#   2. number of random sites to keep per gene
#   3. number of genes to keep in the concatenation
#   4. output FASTA file for the concatenated alignment
if len(sys.argv) < 5:
    sys.exit("usage: %s kept_sites_file nb_sites_per_gene nb_genes output_concat.fasta"
             % os.path.basename(sys.argv[0]))
out_site_garder = sys.argv[1]  # file where the kept sites are recorded
try:
    nb_rd_site_keep = int(sys.argv[2])
    nb_rd_gene_keep = int(sys.argv[3])
except ValueError:
    sys.exit("number of sites and number of genes must be integers (got %r and %r)"
             % (sys.argv[2], sys.argv[3]))
# Non-positive counts would produce an empty concatenation or make
# random.sample() fail later with a confusing error.
if nb_rd_site_keep < 1 or nb_rd_gene_keep < 1:
    sys.exit("number of sites and number of genes must be at least 1")
output_conca_file = sys.argv[4]

# Echo the parameters that will be used.
print("output file kept sites : %s" %(out_site_garder))
print("output concat: %s" %(output_conca_file))
print("number of sites to keep per gene: %s" %(nb_rd_site_keep))
print("number of genes to keep: %s" %(nb_rd_gene_keep))


# Containers filled later: per-gene record of the kept site indices, and the
# concatenated sequence of each species.
liste_site_garder = []
dict_seq_conca = {}

# Candidate alignments: every *.fasta file in the working directory except the
# output concatenation itself (otherwise a re-run would ingest its own output).
# Paths are compared in absolute form so "out.fasta" and "./out.fasta" match.
# Sorting makes the candidate order independent of filesystem listing order.
output_conca_abs = os.path.abspath(output_conca_file)
fich_fasta = [f for f in sorted(glob.glob('*.fasta'))
              if os.path.abspath(f) != output_conca_abs]

# Pre-select up to 3*nb_rd_gene_keep files at random. The surplus leaves room
# to discard alignments that later turn out to be incomplete or too short.
nb_fasta = len(fich_fasta)
if nb_fasta > nb_rd_gene_keep:
    fich_fasta_random = random.sample(fich_fasta, min(3 * nb_rd_gene_keep, nb_fasta))
else:
    # Warn only when there really are fewer files than requested.
    if nb_fasta < nb_rd_gene_keep:
        print("WARNING: less than %s fasta for sampling and concatenation" %(nb_rd_gene_keep))
    fich_fasta_random = fich_fasta

# Scan every pre-selected alignment once. Record its size (number of sequences,
# number of columns) and the set of species (sequence IDs) it contains.
fich_fasta_random_size_dict = {}
fich_fasta_species_dict = {}
for fasta_file in fich_fasta_random:
    alignment = AlignIO.read(fasta_file, "fasta")
    fich_fasta_random_size_dict[fasta_file] = [len(alignment), alignment.get_alignment_length()]
    fich_fasta_species_dict[fasta_file] = set(str(record.id) for record in alignment)

# All species seen in the pre-selected alignments. Sorted so the order does not
# depend on string-hash randomisation between interpreter runs.
list_sp_ok = sorted(set().union(*fich_fasta_species_dict.values()))
nb_sp_ok = len(list_sp_ok)
print("\nNumber of detected species: %s" %(nb_sp_ok))

# "Good" alignments contain every detected species and have at least
# nb_rd_site_keep columns. The others are kept aside as a fallback pool.
# Distinct IDs are counted, not sequences, so duplicated IDs cannot hide a
# missing species.
fich_fasta_random_ok = []
fich_fasta_random_bad = []
for fasta_file in fich_fasta_random:
    length_ali = fich_fasta_random_size_dict[fasta_file][1]
    if len(fich_fasta_species_dict[fasta_file]) < nb_sp_ok or length_ali < nb_rd_site_keep:
        fich_fasta_random_bad.append(fasta_file)
    else:
        fich_fasta_random_ok.append(fasta_file)

# When there are not enough good alignments, top up with randomly chosen
# incomplete ones.
if len(fich_fasta_random_ok) < nb_rd_gene_keep and fich_fasta_random_bad:
    print("Not enought good ali (nb_sp and nb pos) we will used not complete fasta")
    nb_missing = nb_rd_gene_keep - len(fich_fasta_random_ok)
    fich_fasta_random_ok.extend(
        random.sample(fich_fasta_random_bad, min(len(fich_fasta_random_bad), nb_missing)))

# Final random selection of at most nb_rd_gene_keep genes.
fich_fasta_random_final = random.sample(
    fich_fasta_random_ok, min(len(fich_fasta_random_ok), nb_rd_gene_keep))
print("Nb of used genes %s/%s" %(len(fich_fasta_random_final),nb_rd_gene_keep))

# Rows of the concatenation: every species present in at least one selected
# gene. Species found only in rejected alignments are left out, so the output
# never contains an all-gap row.
list_sp_final = sorted(set().union(
    *[fich_fasta_species_dict[f] for f in fich_fasta_random_final]))

# For each selected gene, keep a random set of columns and append them to the
# concatenated sequence of every species.
for fich in fich_fasta_random_final:
    nom = os.path.splitext(os.path.basename(fich))[0]
    alignment = AlignIO.read(fich, "fasta")
    length_ali = alignment.get_alignment_length()

    # Column indices to keep, in increasing order. If the alignment is not
    # longer than the requested number of sites, every column is kept.
    # list() is required: on Python 3 a bare range has no sort() and would be
    # written to the kept-sites file as "range(0, N)".
    if nb_rd_site_keep >= length_ali:
        print("warning for %s ali length < to number to keep so it's all keep " %(nom))
        list_random_site = list(range(length_ali))
    else:
        list_random_site = sorted(random.sample(range(length_ali), nb_rd_site_keep))

    # Record the gene name and the kept sites so the sampling can be checked later.
    liste_site_garder.append([nom, str(list_random_site)])

    # Extract the chosen columns from every sequence of this gene. Direct
    # indexing replaces the original O(len * k) list-membership scan.
    gene_subseqs = {}
    for record in alignment:
        seq = str(record.seq)
        gene_subseqs[str(record.id)] = "".join(seq[i] for i in list_random_site)

    # Append this gene's block to every row. A species missing from this gene
    # gets gaps exactly as long as the kept block, so all rows stay the same
    # length.
    gap_block = "-" * len(list_random_site)
    for nom_sp in list_sp_final:
        dict_seq_conca[nom_sp] = dict_seq_conca.get(nom_sp, "") + gene_subseqs.get(nom_sp, gap_block)

# Turn each species' concatenated string into a SeqRecord and report its
# length as it is built.
# NOTE(review): Bio.Alphabet / IUPAC were removed in Biopython 1.78. This
# requires an older Biopython; confirm the pinned version.
print("\nWrite concat in %s" %(output_conca_file))
subseq4ali = []
for sp, concat_seq in dict_seq_conca.items():
    subseq4ali.append(SeqRecord(Seq(concat_seq, IUPAC.protein), id=sp, description=""))
    print("%s: %s pos" %(sp, len(concat_seq)))

# Assemble the records into an alignment and write it as FASTA.
SeqIO.write(MultipleSeqAlignment(subseq4ali, IUPAC.protein), output_conca_file, "fasta")

# One tab-separated line per gene: gene name, then the list of kept site indices.
with open(out_site_garder, 'w') as out_file:
    out_file.writelines(gene_name + "\t" + kept_sites + "\n"
                        for gene_name, kept_sites in liste_site_garder)
