#!/usr/bin/env python

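"""
Compute the normalised Robinson-Foulds (RF) distance between every pair of
trees listed in a file (all vs all). Each pair is compared only on the taxa
the two trees share.

Usage: <script> TREE_LIST OUTPUT_TSV
    TREE_LIST   text file with one newick tree path per line
    OUTPUT_TSV  tab-separated output file (Tree1, Tree2, NormalisedRFDistance)
"""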

import os
import sys
import dendropy
from dendropy.calculate import treecompare

import csv

##################################################

def _read_newick(path):
	"""Parse the newick tree at *path* into a fresh TaxonNamespace.

	Returns a ``(tree, labels)`` pair, where *labels* is the set of taxon
	labels found in the tree's namespace.
	"""
	tree = dendropy.Tree.get_from_path(path, "newick", taxon_namespace=dendropy.TaxonNamespace())
	labels = {taxon.label for taxon in tree.taxon_namespace}
	return tree, labels


def _normalise_rf(distance, num_taxa):
	"""Divide a raw RF distance by 2(n-3), where n is the number of shared taxa.

	2(n-3) is the maximum RF distance between two unrooted binary trees on n
	taxa. When n <= 3 that bound is zero or negative, so the normalised value
	is undefined and -1 is returned. This is the same sentinel used for
	"no shared taxa".
	"""
	max_distance = 2 * (num_taxa - 3)
	if max_distance <= 0:
		return -1
	# float() keeps true division even if run under Python 2.
	return float(distance) / max_distance


def find_normalised_RF(tree1,tree2):
	"""Return the normalised Robinson-Foulds distance between two newick files.

	tree1, tree2 -- paths to newick tree files.

	Both trees are first pruned to the taxa they have in common. The
	unweighted RF distance between the pruned trees is then divided by
	2(n-3), where n is the number of shared taxa. Returns -1 when the
	trees share no taxa, or when n <= 3 and the normalisation is undefined.
	"""
	tmp_tree1, tree1_labels = _read_newick(tree1)
	tmp_tree2, tree2_labels = _read_newick(tree2)

	# No overlap means there is nothing to compare; avoid pruning both trees to nothing.
	if not (tree1_labels & tree2_labels):
		return -1

	tmp_tree1.retain_taxa_with_labels(tree2_labels)
	tmp_tree2.retain_taxa_with_labels(tree1_labels)

	# Pruning removes nodes, but the two trees still live in separate
	# namespaces (presumably the dropped taxa also remain listed in them).
	# treecompare needs both trees in one TaxonNamespace. So round-trip the
	# pruned trees through newick into a single fresh namespace; it then
	# holds only the taxa that survived pruning.
	tns = dendropy.TaxonNamespace()
	tmp_tree1 = dendropy.Tree.get(data=tmp_tree1.as_string("newick"), schema="newick", taxon_namespace=tns)
	tmp_tree2 = dendropy.Tree.get(data=tmp_tree2.as_string("newick"), schema="newick", taxon_namespace=tns)

	tmp_tree1.encode_bipartitions()
	tmp_tree2.encode_bipartitions()

	distance = treecompare.unweighted_robinson_foulds_distance(tmp_tree1, tmp_tree2)
	return _normalise_rf(distance, len(tns))
##################################################


def get_all_files(filename):
	"""Return the file paths listed one per line in the text file *filename*.

	Trailing whitespace, including the newline, is stripped from each line.
	Lines that are blank or whitespace-only are skipped, so a stray empty
	line does not yield an empty path that would later fail to open.
	"""
	with open(filename, 'r') as f:
		return [line.rstrip() for line in f if line.strip()]


###################################################


def main(filename,output):
	"""Write the normalised RF distance for every unordered pair of listed trees.

	filename -- text file listing one tree path per line
	output   -- path of the tab-separated file to create. It gets a header
	            row, then one row per pair (first, second), where first
	            comes before second in the listing.
	"""
	paths = get_all_files(filename)
	with open(output, 'w') as out_fp:
		writer = csv.writer(out_fp, delimiter="\t")
		writer.writerow(['Tree1', 'Tree2', 'NormalisedRFDistance'])
		for idx, first in enumerate(paths):
			# Only look ahead of idx so each pair is written once and never self-compared.
			for second in paths[idx + 1:]:
				writer.writerow([first, second, find_normalised_RF(first, second)])


if __name__=="__main__":
	# Fail with a usage message instead of an IndexError when arguments are missing.
	if len(sys.argv) < 3:
		sys.exit("usage: %s TREE_LIST OUTPUT_TSV" % sys.argv[0])
	main(sys.argv[1],sys.argv[2])

