#!/usr/bin/env python

"""
From a list of trees (PartitionFinder list), calculate all-vs-all normalised Robinson-Foulds (RF) distances.
"""

from ete3 import Tree
import os
import sys

import pandas as pd

import csv

import statistics

from collections import defaultdict
import itertools



def count_internal(tree):
	"""
	Return the number of non-leaf nodes in the tree, minus one.

	The tree is unrooted in place (tree.unroot()) before counting, so the
	caller's object is modified. Any object providing unroot(), traverse()
	and nodes with is_leaf() is accepted.
	"""
	tree.unroot()
	internal_nodes = sum(1 for node in tree.traverse() if not node.is_leaf())
	# Starting the count at -1 in effect discounts one internal node.
	return internal_nodes - 1


def rf_distance(tree1,tree2,option=False):
	"""
	Normalised Robinson-Foulds (RF) distance between two trees.

	Both inputs are passed to the ete3 Tree constructor. The trees are
	pruned to the leaf names they share, unrooted, and compared with
	Tree.robinson_foulds(unrooted_trees=True).

	option selects the normalisation (n = number of shared leaves):
	  'reduced'  RF / (internal branches of tree1 + internal branches of tree2)
	  'upper'    (RF + splits missing from each tree) / (2n - 6), where a
	             tree's missing splits are n - 3 minus its internal branches
	  otherwise  RF / (2n - 6)   (default; any other value falls back here)

	Returns the normalised distance.

	Raises ValueError when the trees share fewer than four leaves (2n - 6
	would be zero or negative), or when option is 'reduced' and neither
	pruned tree has an internal branch (the divisor would be zero).
	"""
	t1 = Tree(tree1)
	t2 = Tree(tree2)

	leaves1 = [sp.name for sp in t1.get_leaves()]
	leaves2 = [sp.name for sp in t2.get_leaves()]
	common_leaves = set(leaves1) & set(leaves2)

	# Fail early with a clear message instead of a ZeroDivisionError, a
	# negative normaliser or an unrelated error from pruning/unrooting.
	if len(common_leaves) < 4:
		raise ValueError(
			"rf_distance needs at least 4 shared leaves, found {}".format(len(common_leaves))
		)

	t1.unroot()
	t2.unroot()

	t1.prune(common_leaves)
	t2.prune(common_leaves)

	# Unroot again after pruning, as the original code does.
	t1.unroot()
	t2.unroot()

	rf = t1.robinson_foulds(t2,unrooted_trees=True)

	rf_dist = rf[0]
	num_leaves = len(rf[2])

	# Normaliser used by the standard and 'upper' options.
	max_resolved_score = (2*num_leaves)-6

	internal_branches_t1 = count_internal(t1)
	internal_branches_t2 = count_internal(t2)

	# Normalise by the number of internal branches present in the two trees.
	if option=='reduced':
		total_internal = internal_branches_t1 + internal_branches_t2
		if total_internal <= 0:
			raise ValueError(
				"'reduced' normalisation is undefined: neither tree has an internal branch"
			)
		return rf_dist/total_internal

	# Upper bound: count every split missing because of a polytomy as a
	# potential disagreement.
	if option=='upper':
		num_missing_splits_t1 = num_leaves - 3 - internal_branches_t1
		num_missing_splits_t2 = num_leaves - 3 - internal_branches_t2
		rf_dist_upper = rf_dist + num_missing_splits_t1 + num_missing_splits_t2
		return rf_dist_upper/max_resolved_score

	return rf_dist/max_resolved_score








# Method whose trees are compared; also used to build the output path.
METHOD = "HaMStR"


# Tab-separated input table; the 'Method' and 'Tree' columns are read below.
FILENAME = '/SAN/biosciences/oma/research/jeremy_levy/OMAStandalone/Scripts/GENE_TEST_AUG2018/ALLGroups_GENETREES_OMAID.tsv'
# Output TSV for the selected method.
OUTPUT_ALL = '/SAN/biosciences/oma/research/jeremy_levy/OMAStandalone/Scripts/GENE_TEST_AUG2018/RF/All/{}/{}_NormalisedRF.tsv'.format(METHOD,METHOD)





df_sample = pd.read_csv(FILENAME,sep="\t")

# Group the tree strings by the method named in the same row.
tree_dict = defaultdict(list)
for index,row in df_sample.iterrows():
	method = row['Method']
	tree = row['Tree']
	# read_csv loads empty cells as NaN by default, so comparing with ""
	# alone never skips them; test for missing values as well as blanks.
	if pd.isnull(tree) or not str(tree).strip():
		continue
	tree_dict[method].append(tree)


# Compare every unordered pair of trees for METHOD and write one row per pair.
with open(OUTPUT_ALL,'w') as out_fp_all:
	tsv_writer_all = csv.writer(out_fp_all,delimiter="\t")
	row_all = ["Method","Tree1","Tree2","NormalisedRF"]
	tsv_writer_all.writerow(row_all)
	treelist = tree_dict[METHOD]
	# Scores under each normalisation are collected here; only the 'upper'
	# score is written to the file.
	rf_standard=[]
	rf_reduced=[]
	rf_upper=[]
	for tree1,tree2 in itertools.combinations(treelist,2):
		rf_standard_score = rf_distance(tree1,tree2)
		rf_standard.append(rf_standard_score)
		rf_reduced_score = rf_distance(tree1,tree2,'reduced')
		rf_reduced.append(rf_reduced_score)
		rf_upper_score = rf_distance(tree1,tree2,'upper')
		rf_upper.append(rf_upper_score)
		# Label the row with METHOD, whose trees are being compared. The
		# leftover loop variable `method` from the loading loop holds the
		# method of the last input row, which may be a different one.
		row_all = [METHOD,tree1,tree2,rf_upper_score]
		tsv_writer_all.writerow(row_all)




























