#!/usr/bin/env python

"""
From list of trees (PartitionFinder list), calculates all vs all RF distances
"""
import dendropy
from dendropy.calculate import treecompare
import os
import sys

import pandas as pd

import csv

import statistics

from collections import defaultdict
import itertools

# Command-line arguments; all three are required and are read positionally.
# Path of the input table, read further down with pandas as tab-separated text.
FILENAME = sys.argv[1]
# Path of the output table that receives one summary row per method.
OUTPUT = sys.argv[2]
# Path of the output table that receives one row per compared pair of trees.
OUTPUT_ALL = sys.argv[3]



"""
takes the filename of the two trees and finds their euclidian distance
"""
def find_euclidian(tree1,tree2):
	tns = dendropy.TaxonNamespace()
	tmp_tree1 = dendropy.Tree.get(data=tree1,schema="newick",taxon_namespace=tns)
	tree1_labels = []
	for taxa in list(tmp_tree1.taxon_namespace):
		tree1_labels.append(taxa.label)

	tns = dendropy.TaxonNamespace()
	tmp_tree2 = dendropy.Tree.get(data=tree2,schema="newick",taxon_namespace=tns)
	tree2_labels = []
	for taxa in list(tmp_tree2.taxon_namespace):
		tree2_labels.append(taxa.label)


	tmp_tree1.retain_taxa_with_labels(tree2_labels)
	tmp_tree2.retain_taxa_with_labels(tree1_labels)

	tmp_str1 = tmp_tree1.as_string("newick")
	tmp_str2 = tmp_tree2.as_string("newick")

	# FIX THISSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSS
	tns = dendropy.TaxonNamespace()
	tmp_tree1 = dendropy.Tree.get(data = tmp_str1,schema = "newick",taxon_namespace=tns)
	tmp_tree2 = dendropy.Tree.get(data = tmp_str2,schema = "newick",taxon_namespace=tns)
	##############################################################################

	tmp_tree1.encode_bipartitions()
	tmp_tree2.encode_bipartitions()

	distance = dendropy.calculate.treecompare.euclidean_distance(tmp_tree1, tmp_tree2)

	#now we normalise the distance based on number of taxa
	#divide by 2(n-3)
	num_taxa = len(tmp_tree1.taxon_namespace)

	distance = float(distance) / (2 * (float(num_taxa) - 3))

	return distance




# Read the input table and group the tree strings by method. Only the
# 'Method' and 'Tree' columns are used; dict insertion order keeps the methods
# in the order in which they first appear in the table.
df = pd.read_csv(FILENAME,sep="\t")
tree_dict = defaultdict(list)
for method, tree in zip(df['Method'], df['Tree']):
	tree_dict[method].append(tree)


# newline='' is what the csv module documentation asks for on files handed to
# csv.writer, so that the writer's own line endings are not translated again.
with open(OUTPUT,'w',newline='') as out_fp, open(OUTPUT_ALL,'w',newline='') as out_fp_all:
	# Per-method summary table.
	tsv_writer = csv.writer(out_fp,delimiter="\t")
	tsv_writer.writerow(["Method","Euclidean_mean","Euclidean_std"])
	# Table with one row per compared pair of trees.
	tsv_writer_all = csv.writer(out_fp_all,delimiter="\t")
	tsv_writer_all.writerow(["Method","Tree1","Tree2","Euclidean"])

	for method, treelist in tree_dict.items():
		print(method)
		euclidean_dist = []
		# Every unordered pair of trees belonging to the same method.
		for tree_a, tree_b in itertools.combinations(treelist,2):
			euclidean_score = find_euclidian(tree_a,tree_b)
			euclidean_dist.append(euclidean_score)
			tsv_writer_all.writerow([method,tree_a,tree_b,euclidean_score])

		# statistics.mean needs at least one value and statistics.stdev at
		# least two. A method with a single tree has no pairs and a method
		# with two trees has one pair, so write NaN for whatever is undefined
		# instead of letting StatisticsError abort the whole run.
		if euclidean_dist:
			mean_standard = statistics.mean(euclidean_dist)
		else:
			mean_standard = float("nan")
		if len(euclidean_dist) > 1:
			std_standard = statistics.stdev(euclidean_dist)
		else:
			std_standard = float("nan")
		tsv_writer.writerow([method,mean_standard,std_standard])





























