# Build per-branch probability tables (gene present on both ends, absent on both
# ends, gained, lost) from GLOOME probabilistic ancestral reconstructions of gene
# presence/absence.
from collections import defaultdict
import sys
import re
import math

# Command-line arguments, as used in this script:
#   argv[1] = GLOOME ancestral reconstruction file
#   argv[2] = directory holding TheTree.INodes.ph.parentage and the metafile
#   argv[3] = name of the metafile (inside argv[2]) mapping positions to KO ids
data_dir = sys.argv[2]

# Read the node -> parent map; each non-blank line is "<parent> <child>".
with open(data_dir + '/TheTree.INodes.ph.parentage') as parent_file:
	parent_maps = parent_file.readlines()

parentage = {}

for line in parent_maps:
	fields = line.split()
	# Skip blank or whitespace-only lines (e.g. '\r\n'), which would otherwise
	# raise IndexError below.
	if not fields:
		continue
	parentage[fields[1]] = fields[0]

# 'N1' gets no entry from the file (presumably it is the root); map it to the
# 'NA' sentinel so parentage lookups on it do not raise KeyError.
parentage['N1'] = 'NA'

print 'opening reconstruction'
anc = open(sys.argv[1])

# sadly, i couldn't come up with anything better than a nested dict in 15 minutes
# to hold all the genome content info across species

# coding for branches:
# 1 = absent both
# 2 = present both
# 3 = 'gain'
# 4 = 'loss'
# 5 = other

print 'mapping to ancestors'
genomes = defaultdict(list)
ancestors = defaultdict(list)
positions = []	# really unique set of kos
for ko in anc:
#	print ko
	if ko.startswith('POS'):
		continue
	else:
		fields = ko.split('\t')
#		print fields
#		genomes[fields[1]][int(fields[0])] = fields[3].strip()
		genomes[fields[1]].append( fields[3].strip() )
		positions.append(int(fields[0]))
#		ancestors[parentage[fields[1]]][int(fields[0])] = fields[3].strip()
		ancestors[parentage[fields[1]]].append( fields[3].strip() )
#		if (len(ancestors) % 128) == 0:
#			print len(ancestors)
anc.close()



metafile = sys.argv[3]
#pos_maps = open('gainLoss_results'+'/'+data_dir + '/' + metafile).readlines()
pos_maps = open(data_dir + '/' + metafile).readlines()

# print header of positions (kos)
poss = sorted(set(positions))
allkos = [pos_maps[int(ko)-1].strip() for ko in poss]

# MODIFIED to grab states both ancestor and descendant nodes for each branch (labeling each branch
# with the name of the descendant node)- to ultimately label each branch for each gene with
# a presence/absence label describing the various possible configurations of presence/absence (see above)

print 'writing'
out_pres = open(sys.argv[1] + '.pres_probs','w')
out_abs = open(sys.argv[1] + '.abs_probs','w')
out_gain = open(sys.argv[1] + '.gain_probs','w')
out_loss = open(sys.argv[1] + '.loss_probs','w')

out_pres.write('\t'+'\t'.join(allkos)+'\n')
out_abs.write('\t'+'\t'.join(allkos)+'\n')
out_gain.write('\t'+'\t'.join(allkos)+'\n')
out_loss.write('\t'+'\t'.join(allkos)+'\n')

genomes['NA'] = 'NA'

#print '\t'+'\t'.join(allkos)

def probability(anc,desc):
	"""Return the four branch-state probabilities for one gene on one branch.

	anc  -- presence probability at the parent node (anything float() accepts)
	desc -- presence probability at the child node (anything float() accepts)

	The result is a dict with keys 'present' (present at both ends), 'absent'
	(absent at both ends), 'gain' (absent -> present) and 'loss'
	(present -> absent). Each value is a product of the two node probabilities,
	i.e. the two node states are treated as independent.
	"""
	p_parent = float(anc)
	p_child = float(desc)
	q_parent = 1 - p_parent	# probability absent at the parent
	q_child = 1 - p_child	# probability absent at the child
	return {
		'present': p_parent * p_child,
		'absent': q_parent * q_child,
		'gain': q_parent * p_child,
		'loss': p_parent * q_child,
	}

print 'writing probs'
counter = 0
for org in sorted(genomes.keys()):
	counter += 1
	if (counter % 128) ==0:
		print counter
	if org == 'NA' or parentage[org] =='NA':
		continue
	genome = genomes[org]
	anc_genome = genomes[parentage[org]]
	pres_prof = org
	abs_prof = org
	gain_prof = org
	loss_prof = org

#	print org
#	print parentage[org]
#	if org=='N73':
		#print parentage[org]
		#print anc_genome
		#print genome

	# making into a list --> better?
#	for ko in sorted(genome.keys()):
	for ko in range(len(genome)):
		prob_dict = probability(anc_genome[ko],genome[ko])
		pres_prof += '\t'+str(prob_dict['present'])
		abs_prof += '\t'+str(prob_dict['absent'])
		gain_prof += '\t'+str(prob_dict['gain'])
		loss_prof += '\t'+str(prob_dict['loss'])

	out_pres.write(pres_prof+'\n')
	out_abs.write(abs_prof+'\n')
	out_gain.write(gain_prof+'\n')
	out_loss.write(loss_prof+'\n')

print 'finished parsing'
out_pres.close()
out_abs.close()
out_gain.close()
out_loss.close()
