#!/usr/bin/env python

"""
Finds the RF distances between all trees in a directory
all vs all
"""

import os
import sys
import dendropy
from dendropy.calculate import treecompare
import itertools


import csv

##################################################

def find_normalised_RF(tree1,tree2):
	"""
	Returns the normalised Robinson-Foulds distance between two newick trees.

	Both trees are pruned down to the taxa they have in common, the
	unweighted RF distance between the pruned trees is calculated, and the
	result is divided by 2(n-3), where n is the number of shared taxa.

	tree1, tree2: paths to newick tree files.

	Returns a float. NaN is returned when the trees share 3 or fewer taxa,
	because the 2(n-3) normalisation is zero or negative in that case.
	"""
	# Load both trees into ONE namespace: dendropy only compares trees that
	# share the same TaxonNamespace object.
	tns = dendropy.TaxonNamespace()
	tmp_tree1 = dendropy.Tree.get(path=tree1,schema="newick",taxon_namespace=tns)
	tmp_tree2 = dendropy.Tree.get(path=tree2,schema="newick",taxon_namespace=tns)

	# The shared namespace holds the union of both trees' taxa, so the labels
	# of each tree have to be collected from its own leaves.
	tree1_labels = set(
		leaf.taxon.label for leaf in tmp_tree1.leaf_node_iter()
		if leaf.taxon is not None
	)
	tree2_labels = set(
		leaf.taxon.label for leaf in tmp_tree2.leaf_node_iter()
		if leaf.taxon is not None
	)
	shared_labels = sorted(tree1_labels & tree2_labels)

	# normalising divides by 2(n-3), which is undefined for n <= 3
	num_taxa = len(shared_labels)
	if num_taxa <= 3:
		return float("nan")

	tmp_tree1.retain_taxa_with_labels(shared_labels)
	tmp_tree2.retain_taxa_with_labels(shared_labels)

	# bipartitions must be (re)encoded after pruning
	tmp_tree1.encode_bipartitions()
	tmp_tree2.encode_bipartitions()

	distance = treecompare.unweighted_robinson_foulds_distance(tmp_tree1, tmp_tree2)

	#now we normalise the distance based on number of taxa
	#divide by 2(n-3)
	return float(distance) / (2 * (num_taxa - 3))

##################################################


def get_tree_list(directory):
	"""
	Returns the paths of all tree files found directly inside a directory.

	A file counts as a tree if its name ends with ".treefile" or contains
	"RAxML_bestTree". Sub-directories are skipped even if their name
	matches. The result is sorted by file name, so the order is the same on
	every run (os.listdir returns entries in arbitrary order).

	directory: path of the directory to search (not searched recursively).

	Returns a list of paths, each one joined onto directory.
	"""
	treelist = []
	for name in sorted(os.listdir(directory)):
		if not (name.endswith(".treefile") or "RAxML_bestTree" in name):
			continue
		path = os.path.join(directory,name)
		# a matching directory name is not a tree file
		if os.path.isfile(path):
			treelist.append(path)
	return treelist

###################################################





###################################################
def main(directory,output):
	"""
	Writes the normalised RF distance of every pair of trees in a directory
	to a tab-separated file.

	directory: directory handed to get_tree_list to find the tree files.
	output: path of the file to write; it is opened in 'w' mode.

	The file starts with the header row Tree1, Tree2, NormalisedRF, followed
	by one row per unordered pair of trees holding the two file base names
	and the value returned by find_normalised_RF.
	"""
	trees = get_tree_list(directory)
	with open(output,'w') as handle:
		writer = csv.writer(handle,delimiter="\t")
		writer.writerow(["Tree1","Tree2","NormalisedRF"])
		# combinations() yields each unordered pair exactly once
		for path_a,path_b in itertools.combinations(trees,2):
			rf = find_normalised_RF(path_a,path_b)
			writer.writerow([os.path.basename(path_a),os.path.basename(path_b),rf])





# Command-line entry point: first argument is the directory holding the
# trees, second argument is the path of the TSV file to write.
if __name__=="__main__":
	# exit with a usage message instead of an IndexError traceback
	if len(sys.argv) < 3:
		sys.exit("usage: %s <tree_directory> <output_tsv>" % sys.argv[0])
	directory = sys.argv[1]
	output = sys.argv[2]
	main(directory,output)

