#!/usr/bin/env python

"""
Monophyly test for OMA Standalone paper
1) Check how many samples we need by picking the method with the minimum number of groups for each OG size
2) Pick the groups with the highest overlap from the other methods
3) Copy groups into separate folder

Then we can do following steps (not in this script)
1) Calculate monophyly score for each group
2) Find average score - one score for each method
"""

"""
This is hardcoded for picking groups of at least 15 species (from the 30)
- for the OMA Standalone paper
"""

from functools import lru_cache
import sys
import os


from random import shuffle

from Bio import SeqIO

from collections import defaultdict

import csv


# Root directory with one sub-directory per method; get_dict_list() joins each
# method name onto this path and get_seq_dict() reads the ".out" files there.
SEQ_DIR='/SAN/biosciences/oma/research/jeremy_levy/OMAStandalone/Sequences/'
# Directory whose entry names are taken as the list of methods to compare
# (read by get_method_list()).
# NOTE(review): this path is under /home/jerelevy/... while SEQ_DIR is under
# /SAN/biosciences/... - confirm both refer to the same dataset.
TREE_DIR='/home/jerelevy/oma_research/jeremy_levy/OMAStandalone/GeneTrees'


"""
First find how many species are present in each group
"""

"""
Given the sequence directory, return a dictionary of all the sequences for each group size
"""

def get_dict_list(SEQ_DIR,TREE_DIR):
	"""Build one [method, seq_dict] pair per method listed in TREE_DIR.

	For every method name returned by get_method_list(TREE_DIR), the
	alignment directory SEQ_DIR/<method> is scanned with get_seq_dict(),
	which maps a sequence count to the list of ".out" files with that count.
	Pairs are returned as two-element lists, in method-list order.
	"""
	return [
		[method_name, get_seq_dict(os.path.join(SEQ_DIR, method_name))]
		for method_name in get_method_list(TREE_DIR)
	]


def get_method_list(gene_tree_dir):
	"""Return the entry names in gene_tree_dir, sorted alphabetically.

	Each entry name is treated as a method name (get_dict_list() joins it
	onto SEQ_DIR). os.listdir() returns entries in arbitrary, filesystem-
	dependent order. That order decides which method wins a tie in
	find_minimum_groups() (first one wins) and the row order of the output
	TSV, so the names are sorted to make runs reproducible.
	"""
	return sorted(os.listdir(gene_tree_dir))



def get_seq_dict(seq_directory):
	"""Group the ".out" alignment files in seq_directory by sequence count.

	Returns a defaultdict(list) mapping the count returned by check_num_sp()
	for each file to the list of file paths with that count.

	Files are visited in sorted name order. os.listdir() order is arbitrary,
	so without sorting the per-count lists (and hence the rows written to the
	output TSV) could come out in a different order on every filesystem/run.
	"""
	seq_dict = defaultdict(list)
	for fn in sorted(os.listdir(seq_directory)):
		if not fn.endswith(".out"):
			continue
		path = os.path.join(seq_directory, fn)
		seq_dict[check_num_sp(path)].append(path)
	return seq_dict



"""
Counts the number of species present in an MSA
"""
def check_num_sp(group):
	"""Return how many FASTA records the file at path `group` contains.

	Every record is counted, including repeated names, so this is the number
	of species only when each species contributes one sequence.
	"""
	return sum(1 for _record in SeqIO.parse(group, "fasta"))




#########################################

"""
Find out how many species we need for each group size
Find this by counting minimum number of species for each group
"""


def get_min_count_dict(dict_list, min_size=15, max_size=30):
	"""For each group size, find the method with the fewest groups of that size.

	Args:
		dict_list: list of [method, seq_dict] pairs as built by get_dict_list().
		min_size, max_size: inclusive range of group sizes to examine. The
			defaults (15..30) are the range that was previously hard-coded.

	Returns:
		defaultdict(list) mapping each group size to [method, count], where
		count is the number of groups of that size produced by `method`, as
		returned by find_minimum_groups().
	"""
	minimum_count_dict = defaultdict(list)
	for groupsize in range(min_size, max_size + 1):
		method, count = find_minimum_groups(dict_list, groupsize)
		minimum_count_dict[groupsize] = [method, count]
	return minimum_count_dict



"""
Given list of the dictionaries which contain the sequences
Find the minimum number for a given size
"""

def find_minimum_groups(dict_list,groupsize):
	"""Return (method, count) for the method with the fewest groups of `groupsize`.

	Args:
		dict_list: non-empty list of [method, seq_dict] pairs, where seq_dict
			maps a group size to the list of group files of that size.
		groupsize: the group size to look up.

	Returns:
		Tuple (method, count). A size that is missing (or maps to an empty
		value) in a seq_dict counts as 0. On a tie, the method that appears
		first in dict_list wins.

	Raises:
		ValueError: if dict_list is empty.
	"""
	if not dict_list:
		raise ValueError("find_minimum_groups: dict_list is empty")

	def _count(pair):
		# .get() instead of [] so a defaultdict is not mutated by the lookup
		# and a plain dict without this size does not raise KeyError.
		return len(pair[1].get(groupsize) or ())

	# min() returns the first minimal element, which matches the original
	# strict "<" comparison. Unlike the old fixed sentinel of 1000, it works
	# for any group count and never leaves the result unset.
	best = min(dict_list, key=_count)
	return best[0], _count(best)

###########################################




"""
Now we want to get seqs for each group size
"""
def get_seqs_certain_size(dict_list,sub_method,groupsize):
	"""Return the group files of size `groupsize` produced by method `sub_method`.

	Args:
		dict_list: list of [method, seq_dict] pairs.
		sub_method: method name to look up.
		groupsize: group size used as the key into that method's seq_dict.

	Returns:
		The list stored under `groupsize` for that method, or a new empty
		list if the method has no groups of that size. The lookup does not
		insert a key into the seq_dict.

	Raises:
		ValueError: if no pair in dict_list has method name `sub_method`
			(previously this surfaced as an UnboundLocalError).
	"""
	for method, seq_dict in dict_list:
		if method == sub_method:
			return seq_dict.get(groupsize, [])
	raise ValueError("get_seqs_certain_size: unknown method %r" % (sub_method,))


"""
This function checks which species are present in a group
- uses SeqIO
"""
@lru_cache(1500)
def check_species_presence(group):
	"""Return the set of record names in the FASTA file at path `group`.

	The result is an immutable frozenset, so the value memoised by lru_cache
	(up to 1500 distinct paths) cannot be modified by callers.
	"""
	return frozenset(record.name for record in SeqIO.parse(group, "fasta"))


"""
This function checks percentage overlap of two lists
"""
def find_percent_overlap(set1,set2):
	"""Return the fraction of set1's elements that also occur in set2.

	The result is a float in [0.0, 1.0]; it is a fraction, not a percentage.
	Both arguments must support the `&` set operator (set / frozenset, such
	as the values returned by check_species_presence()).

	An empty set1 returns 0.0 rather than raising ZeroDivisionError.
	"""
	if not set1:
		return 0.0
	# Python 3 "/" is true division, so this is already a float.
	return len(set1 & set2) / len(set1)

# Read the command line before the (slow) scan of every alignment file, so a
# missing argument fails immediately with a usage message instead of an
# IndexError after all the work is done.
if len(sys.argv) < 2:
	sys.exit("usage: %s OUTPUT_TSV" % os.path.basename(sys.argv[0]))
output = sys.argv[1]

dict_list = get_dict_list(SEQ_DIR,TREE_DIR)
# NOTE(review): minimum_count_dict is not used by the code below; kept as
# module-level state in case other code relies on it - confirm.
minimum_count_dict = get_min_count_dict(dict_list)

# Write one TSV row per group file: method name, group size, file path.
# newline='' is what the csv module requires for files it writes to.
with open(output, 'w', newline='') as out_fp:
	tsv_writer = csv.writer(out_fp, delimiter="\t")
	for groupsize in range(15, 31):
		for method, seq_dict in dict_list:
			# .get() avoids inserting empty lists into the defaultdicts.
			for seq in seq_dict.get(groupsize, []):
				tsv_writer.writerow([method, groupsize, seq])
































































