# Build a dataset of probabilistic ancestral reconstructions of gene presence/absence
# from GLOOME output: for each branch, the parent node's value at every KO.
from collections import defaultdict
import sys
import re
import math
import gzip


# Map every tree node to its parent (child -> parent), read from the GLOOME
# ".parentage" file in which each non-blank line is "<parent> <child>".
with open('gainLoss_results/MOtree_GLrun/TheTree.INodes.ph.parentage') as parent_file:
	parent_maps = parent_file.readlines()

parentage = {}
for lineno, line in enumerate(parent_maps, 1):
	fields = line.split()
	# Skip blank and whitespace-only lines (including '\r\n'), not just '\n'.
	if not fields:
		continue
	if len(fields) < 2:
		raise ValueError('malformed parentage line %d: %r' % (lineno, line))
	parentage[fields[1]] = fields[0]

# N1 gets the sentinel parent 'NA' so the output loop can recognise and skip
# it. NOTE(review): presumably N1 is the tree root -- TODO confirm.
parentage['N1'] = 'NA'

# Read the gzipped GLOOME ancestral-reconstruction table named by sys.argv[1].
# Data lines are tab-separated; the columns used here are:
#   0 = alignment position (1-based; mapped to a KO name via the .meta file)
#   1 = tree node name
#   3 = value for that node/position, kept as a string. NOTE(review):
#       presumably the posterior probability of presence -- TODO confirm.
# Header lines start with 'POS' and are skipped, as are blank lines.
#
# Branch-state codes (not used by any live code in this script; they belonged
# to a classifier that is no longer called):
# 1 = absent both
# 2 = present both
# 3 = 'gain'
# 4 = 'loss'
# 5 = other

genomes = defaultdict(dict)	# node -> {position: value}
# parent node -> {position: value of a child}. NOTE(review): sibling children
# overwrite each other here, and nothing in this file reads it.
ancestors = defaultdict(dict)
positions = set()	# unique alignment positions (KOs) seen in the input
with gzip.open(sys.argv[1], 'rb') as anc:
	for line in anc:
		if line.startswith('POS') or not line.strip():
			continue
		fields = line.split('\t')
		pos = int(fields[0])
		node = fields[1]
		value = fields[3].strip()
		genomes[node][pos] = value
		positions.add(pos)
		ancestors[parentage[node]][pos] = value
# Map alignment positions to KO identifiers: line i (1-based) of the .meta
# file names the KO at position i. The .meta path may be passed as an
# optional second command-line argument; it defaults to the original location.
DEFAULT_META = '/Users/maximilianpress/Dropbox/LGT_project/gain_loss_inference/MOtree_GLrun/MOtree.GL.011514.meta'
meta_path = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_META
with open(meta_path) as meta_file:
	pos_maps = meta_file.readlines()

# Sorted unique positions; this fixes the column order of the output table.
poss = sorted(set(positions))

# Position 0 would silently wrap to pos_maps[-1] (the wrong KO), so validate
# the range explicitly instead of relying on IndexError.
out_of_range = [ko for ko in poss if not 1 <= ko <= len(pos_maps)]
if out_of_range:
	raise ValueError('positions outside the .meta file range 1..%d: %r'
		% (len(pos_maps), out_of_range[:10]))

allkos = [pos_maps[ko - 1].strip() for ko in poss]

# Output table: one row per branch, labelled with the branch's descendant
# node, holding the parent node's value for each KO column. The header's
# first field is left empty so the KO names line up above the data columns.
out_anc = open(sys.argv[1] + '.anc_pres', 'w')
out_anc.write('\t%s\n' % '\t'.join(allkos))

# Placeholder entry for the sentinel parent 'NA'. The output loop skips the
# key 'NA' explicitly, so it produces no row.
# NOTE(review): appears to have no effect on the output.
genomes['NA'] = 'NA'

def probability(anc, desc):
	"""Combine two node values into per-branch state probabilities.

	anc  -- value for the ancestor node (number or numeric string),
	        interpreted as the probability that the gene is present there.
	desc -- the same for the descendant node.

	The two values are combined by simple products, i.e. as if independent.
	Returns a dict with keys 'present', 'absent', 'gain' and 'loss'.
	"""
	p_anc = float(anc)
	p_desc = float(desc)
	q_anc = 1 - p_anc	# ancestor absent
	q_desc = 1 - p_desc	# descendant absent
	return {
		'present': p_anc * p_desc,	# present at both ends
		'absent': q_anc * q_desc,	# absent at both ends
		'gain': q_anc * p_desc,	# absent -> present
		'loss': p_anc * q_desc,	# present -> absent
	}

# For each branch (named by its descendant node `org`), write one row of the
# parent node's values, in the same column order as the header (`poss`).
for org in sorted(genomes):
	# Skip the 'NA' placeholder and any node whose parent is the 'NA' sentinel.
	if org == 'NA' or parentage[org] == 'NA':
		continue
	anc_genome = genomes[parentage[org]]
	# Iterate the shared sorted position list rather than this node's own keys,
	# so every row has exactly one value per header column. A position the
	# parent lacks is written as 'NA' instead of shifting later columns.
	row = [org] + [anc_genome.get(ko, 'NA') for ko in poss]
	out_anc.write('\t'.join(row) + '\n')

out_anc.close()
