#!/usr/bin/env python

"""
Calculates monophyly score given a list of PartitionPaper trees



Output into table
"""



import sys
import os

from collections import defaultdict
import csv

from ete3 import Tree
import statistics

import pandas as pd

# Positional command-line arguments; indexing sys.argv directly means all
# four must be supplied or the script stops with an IndexError.

# Tab-separated table read with pandas; its "Method" and "Tree" columns are
# used, one row per tree to score.
FILENAME = sys.argv[1]

# Clade definition file: on each line, a clade name followed by its species.
CLADE_FILE = sys.argv[2]
# Leaf name the trees are rooted on when it is present in the tree.
OUTGROUP = sys.argv[3]

# Path of the tab-separated results table that is written.
OUTPUT = sys.argv[4]


def find_clade_classifications(filename):
    """Read a clade definition file into a clade -> species mapping.

    Every non-blank line holds whitespace-separated tokens: the first token
    is the clade name and the remaining tokens are the species assigned to
    that clade. Blank and whitespace-only lines are skipped.

    Args:
        filename: Path of the clade definition file.

    Returns:
        A ``defaultdict`` mapping each clade name to the ``set`` of its
        species. When a clade name appears on several lines, the last line
        replaces the earlier ones.
    """
    # The default factory is never relied on here (every key is assigned
    # explicitly); the defaultdict type is kept so callers see the same
    # return type as before.
    clade_dict = defaultdict(list)
    with open(filename, 'r') as in_fp:
        for line in in_fp:
            fields = line.split()
            if not fields:
                # A blank line (e.g. a trailing newline at the end of the
                # file) has no clade name; indexing it would raise
                # IndexError.
                continue
            clade, *species = fields
            clade_dict[clade] = set(species)
    return clade_dict


def find_monophyly(clade,tree):
    """Yield every subtree of ``tree`` that consists only of ``clade`` members.

    Each node returned by ``tree.traverse()`` is inspected. A node qualifies
    when it has more than one leaf and none of its leaves falls outside
    ``clade``.

    Args:
        clade: Set of leaf names that make up the clade.
        tree: Object providing ``traverse()``; each traversed node must
            provide ``get_leaf_names()`` and a ``support`` attribute.

    Yields:
        ``(leaf_count, support)`` for each qualifying node, where
        ``leaf_count`` is the number of the node's leaves that belong to
        ``clade`` and ``support`` is the node's support value.
    """
    for subtree in tree.traverse():
        members = set(subtree.get_leaf_names())
        branch_support = subtree.support
        outsiders = members - clade
        # Skip nodes that contain a non-member, and trivial single-leaf nodes.
        if outsiders or len(members) <= 1:
            continue
        yield len(members & clade), branch_support



def get_monophyly_score(tree_file,clade_file,outgroup):
    """Score how well a tree recovers the clades defined in ``clade_file``.

    The tree is rooted on ``outgroup`` when a leaf with that name exists.
    Otherwise it is rooted on the first entry of a built-in fallback list
    that is present among the leaves. If none is present the tree is used
    as loaded.

    For every clade with more than one species the score is::

        monophyly = (n - 1) * support / (clade_size - 1)

    where ``n`` is the leaf count of the largest subtree made up solely of
    clade members (ties broken by the higher support) and ``support`` is
    that subtree's support value. Support values above 1 are taken to be
    percentages and divided by 100. A clade with no such subtree scores 0.

    Args:
        tree_file: Tree specification handed to ``ete3.Tree``.
        clade_file: Path of the clade definition file read by
            ``find_clade_classifications``.
        outgroup: Name of the preferred outgroup leaf.

    Returns:
        ``(mono_score, weighted_score)``. ``mono_score`` is the mean of the
        per-clade scores. ``weighted_score`` is the sum of each score
        multiplied by its clade size, divided by the species count of all
        clades (clades with one species or fewer add to the divisor only).

    Raises:
        statistics.StatisticsError: If no clade has more than one species.
    """
    clade_dict = find_clade_classifications(clade_file)
    tree = Tree(tree_file)
    leaf_names = {leaf.name for leaf in tree.get_leaves()}

    if outgroup in leaf_names:
        tree.set_outgroup(outgroup)
    else:
        fallback_outgroups = ['AMPQE','TRIAD','HYDM','CIOIN','STRPU','SACKO','HUMAN','ACYPI','CAEEL','PRIPA','DROME']
        for candidate in fallback_outgroups:
            if candidate in leaf_names:
                tree.set_outgroup(candidate)
                # Stop at the first fallback that is present. Without this
                # the tree was re-rooted on every later match as well, so
                # the last entry of the list silently won.
                break

    weighted_scores = []
    ave_scores = []
    total_sp = 0
    for clade in sorted(clade_dict):
        species = clade_dict[clade]
        clade_size = len(species)
        total_sp += clade_size
        if clade_size <= 1:
            # Monophyly is undefined for a clade of one species or fewer;
            # it only counts towards the species total.
            continue

        # (leaf_count, support) for every subtree made up of clade members.
        candidates = list(find_monophyly(species, tree))
        if candidates:
            best_size, support = max(candidates)
            if support > 1:
                # Values above 1 are treated as percentages.
                support = support / 100
            numerator = (best_size - 1) * support
        else:
            numerator = 0

        monophyly = numerator / (clade_size - 1)
        ave_scores.append(monophyly)
        weighted_scores.append(monophyly * clade_size)

    mono_score = statistics.mean(ave_scores)
    weighted_score = sum(weighted_scores) / total_sp
    return mono_score, weighted_score





# Score every tree listed in the input table and write one result row per
# tree to the tab-separated output file.
df = pd.read_csv(FILENAME, sep="\t")

# newline="" hands line-ending control to the csv module, as its
# documentation requires; without it, platforms that translate "\n" on write
# would turn the writer's "\r\n" terminator into "\r\r\n".
with open(OUTPUT, 'w', newline="") as out_fp:
    tsv_writer = csv.writer(out_fp, delimiter="\t")
    tsv_writer.writerow(["Method", "Tree", "Monophyly", "WeightedMonophyly"])
    for _, record in df.iterrows():
        method = record["Method"]
        tree = record["Tree"]
        mono_score, weighted_score = get_monophyly_score(tree, CLADE_FILE, OUTGROUP)
        tsv_writer.writerow([method, tree, mono_score, weighted_score])










