#!/usr/bin/env python

"""
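For every method subdirectory of DIR:
- keeps the ``.out`` OG alignment files holding at least MIN_SIZE sequences
- NUM_SAMPLES (50) times, picks NUM_OG orthologous groups (OGs) at random
  without repeats, redrawing until they cover at least MIN_TAXA taxa
- concatenates each sample into a supermatrix
- writes the supermatrix (.phy, .fas, .nex) and a .txt listing its OG files
  into OUT_DIR/<method>

Usage: python <this script> OUT_DIR NUM_OG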
"""

from numpy import random
import os


from Bio import SeqIO, AlignIO
from Bio.Seq import Seq, UnknownSeq
from Bio.Align import MultipleSeqAlignment
from Bio.Alphabet import IUPAC, Gapped
from Bio.SeqRecord import SeqRecord

import sys

# Root directory containing one subdirectory per method; each subdirectory
# holds the per-OG alignment files (*.out, parsed as FASTA) scanned below.
DIR = "/SAN/biosciences/oma/research/jeremy_levy/OMAStandalone/Sequences"
#OUT_DIR = "/SAN/biosciences/oma/research/jeremy_levy/OMAStandalone/Subsamples/400_OG/Alignments"

# Command-line arguments, read at import time (a missing one raises IndexError):
#   argv[1] -> root output directory; results go to OUT_DIR/<method>/
#   argv[2] -> number of OGs concatenated into each supermatrix
OUT_DIR = sys.argv[1]

#NUM_OG = 400

NUM_OG = int(sys.argv[2])

# Number of random supermatrices built per method.
NUM_SAMPLES = 50
# Minimum number of sequences an OG file needs to be eligible for sampling.
MIN_SIZE = 15
# Minimum number of distinct taxa a sample must cover; smaller samples are redrawn.
MIN_TAXA = 30


def get_all_seqs(directory,MIN_SIZE):
	"""Return the paths of ``.out`` files in *directory* with enough records.

	Every ``.out`` entry is parsed as FASTA through count_seqs(); files holding
	fewer than MIN_SIZE records are discarded. Paths keep os.listdir() order.
	"""
	candidates = [
		os.path.join(directory, name)
		for name in os.listdir(directory)
		if name.endswith(".out")
	]
	return [path for path in candidates if count_seqs(path) >= MIN_SIZE]




def concatenate(msa_list,all_species):
	"""Concatenate per-OG FASTA alignments into a single supermatrix.

	Each file in *msa_list* is read in order. Every record's sequence is
	appended to the row keyed by its ID. Species in *all_species* that are
	absent from a file are padded with an UnknownSeq as long as that file's
	last record. Rows therefore stay equally long, assuming each input file
	is itself aligned. Empty files contribute nothing and are skipped.

	Args:
		msa_list: paths of FASTA alignment files (one per OG).
		all_species: set of every record ID that must get a row
			(e.g. the output of find_all_species(msa_list)).

	Returns:
		MultipleSeqAlignment with one SeqRecord per species, whose id and
		description are both the species ID.
	"""
	tmp_dict = {}

	def append_seq(species, seq):
		# Extend the species' row, creating it the first time it is seen.
		if species in tmp_dict:
			tmp_dict[species] += seq
		else:
			tmp_dict[species] = seq

	for msa in msa_list:
		records = list(SeqIO.parse(msa,'fasta'))
		if not records:
			# An empty file adds no columns; skip it instead of padding with a
			# width that does not belong to this file.
			continue
		present = set()
		for record in records:
			# NOTE(review): an ID repeated within one file is appended twice,
			# which would make that row longer than the others.
			present.add(record.id)
			append_seq(record.id, record.seq)
		# Width of this OG alignment, taken from its own last record.
		last_seq = records[-1].seq
		unknown_seq = UnknownSeq(len(last_seq), alphabet=last_seq.alphabet)
		for species in all_species - present:
			append_seq(species, unknown_seq)
	return MultipleSeqAlignment(SeqRecord(v, description=k, id=k) for (k, v) in tmp_dict.items())

"""
Requires SeqIO
"""
def count_seqs(filename):
	"""Return how many records Bio.SeqIO parses from *filename* as FASTA."""
	total = 0
	for _record in SeqIO.parse(filename, 'fasta'):
		total += 1
	return total

def find_all_species(msa_list):
	"""Return the set of record IDs appearing in any FASTA file of *msa_list*.

	These IDs are what concatenate() treats as species identifiers.
	"""
	return {
		record.id
		for path in msa_list
		for record in SeqIO.parse(path, 'fasta')
	}


def convert_to_nex(filename,start_format,end_format,output_extension):
	"""Convert an alignment file from one AlignIO format to another.

	The output is written next to *filename*, with its extension replaced by
	*output_extension* (e.g. ``x.fas`` -> ``x.nex``). Input alignments are
	parsed with a gapped IUPAC protein alphabet.

	Args:
		filename: path of the input alignment file.
		start_format: AlignIO format name of the input (e.g. 'fasta').
		end_format: AlignIO format name of the output (e.g. 'nexus').
		output_extension: extension of the output file, without the dot.
	"""
	# splitext only inspects the final path component. A '.' inside a
	# directory name is therefore never mistaken for the extension, and a name
	# without an extension (or ending in '.') still yields "<root>.<ext>".
	root, _old_extension = os.path.splitext(filename)
	output_name = root + "." + output_extension
	# Context managers close both handles even if parsing or writing fails.
	with open(filename, "r") as input_handle, open(output_name, "w") as output_handle:
		alignments = AlignIO.parse(input_handle, start_format, alphabet=Gapped(IUPAC.protein))
		AlignIO.write(alignments, output_handle, end_format)





def write_ogfile(subsample_og_list,output):
	"""Write the sampled OG file paths to ``output + ".txt"``, one per line."""
	with open(output + ".txt", 'w') as handle:
		handle.writelines(name + "\n" for name in subsample_og_list)



def create_supermatrix(subsample_og_list,output):
	"""Concatenate the OG alignments in *subsample_og_list* and save them.

	Writes ``output + '.phy'`` (relaxed PHYLIP), ``output + '.fas'`` (FASTA)
	and ``output + '.nex'`` (NEXUS). The NEXUS file is produced by re-reading
	the FASTA file through convert_to_nex(), so the FASTA write must come first.
	"""
	supermatrix = concatenate(subsample_og_list, find_all_species(subsample_og_list))
	fasta_path = output + '.fas'
	AlignIO.write(supermatrix, output + '.phy', 'phylip-relaxed')
	AlignIO.write(supermatrix, fasta_path, 'fasta')
	# NOTE(review): presumably NEXUS goes through a FASTA round trip so that
	# convert_to_nex() can attach a protein alphabet; confirm before changing.
	convert_to_nex(fasta_path, 'fasta', 'nexus', 'nex')



def count_taxa(og_list):
	"""Return the number of distinct record IDs across the FASTA files in *og_list*."""
	seen_ids = set()
	for og in og_list:
		seen_ids.update(record.id for record in SeqIO.parse(og, "fasta"))
	return len(seen_ids)


# Driver: build NUM_SAMPLES random supermatrices for each method subdirectory of DIR.
for method in os.listdir(DIR):
	method_dir = os.path.join(DIR,method)
	# Stray files next to the method directories would make os.listdir() fail.
	if not os.path.isdir(method_dir):
		continue
	file_list = get_all_seqs(method_dir,MIN_SIZE)
	# Too few eligible OGs to draw NUM_OG of them without replacement.
	if len(file_list) < NUM_OG:
		continue
	# If even all eligible OGs together cover fewer than MIN_TAXA taxa, no
	# subsample can reach the threshold and the redraw loop below would never
	# terminate. (Sparse coverage can still require many redraws.)
	if count_taxa(file_list) < MIN_TAXA:
		print("Method: {} skipped, fewer than {} taxa overall".format(method, MIN_TAXA))
		continue
	method_out = os.path.join(OUT_DIR,method)
	if not os.path.exists(method_out):
		os.makedirs(method_out)
	for sample_num in range(1,NUM_SAMPLES+1):
		# Redraw until the sampled OGs cover at least MIN_TAXA distinct taxa.
		while True:
			subsample_og_list = random.choice(file_list,NUM_OG,replace=False)
			num_taxa = count_taxa(subsample_og_list)
			print("Method: {}".format(method))
			print("Sample Number: {}".format(sample_num))
			print("Taxa Coverage: {}".format(num_taxa))
			if num_taxa >= MIN_TAXA:
				break
		output = "{}_Supermatrix_{}".format(method,sample_num)
		output = os.path.join(method_out,output)
		create_supermatrix(subsample_og_list,output)
		write_ogfile(subsample_og_list,output)









