#!/usr/bin/Rscript
# Simulate random positions (0/1 gain/loss patterns) on the MO tree, using the
# rates inferred by gainLoss (GL).
# 2/17/14 by max
# Updated 2/22/14: now collects the functions for simulating, post-processing,
# and analyzing the distribution and/or relevance of various null distributions.

#data_dir = 'MOtree_GLrun'

require(ape)
require(MASS)

# Both directories must be defined (e.g. by a calling script) before this runs:
# data_dir holds the GL outputs, working_dir receives the saved results.
stopifnot(exists('working_dir'))
stopifnot(exists('data_dir'))

# ---- Configuration ----
outfile_stub <- 'rep10'   # suffix appended to every saved output file
ngenes <- 10000           # number of genes to simulate
# Multipliers applied to the ML-fitted gamma shape parameters before sampling
# rates; with the rate parameter held fixed, scaling the shape scales the mean.
transform_shape_gain <- 3
transform_shape_loss <- 1.5

# The seed is deliberately NOT set when building the whole 10K null; fixing it
# would defeat the purpose of drawing independent replicates.
#set.seed(311)	# what can i say, the 90s were my childhood.

# ---- Inputs from the GL run ----
treefile <- file.path(data_dir, 'TheTree.INodes.ph')
print(treefile)
tree <- read.tree(treefile)	# GL-refitted MO tree
# The simulated matrix has one column per tip AND per internal node, named by
# the node labels, so the tree must carry internal node labels.
if (is.null(tree$node.label)) {
	stop('tree read from ', treefile, ' has no internal node labels', call. = FALSE)
}

# per-site gain and loss rates estimated by GL ('Rate' column is used below)
gainrates <- read.table(file.path(data_dir, 'gain4site.txt'), header = TRUE, comment.char = '#', sep = '\t')
lossrates <- read.table(file.path(data_dir, 'loss4site.txt'), header = TRUE, comment.char = '#', sep = '\t')
nrates <- nrow(gainrates)	# only needed by the retired rate-resampling scheme

# ML fit of a gamma distribution to each rate set, so rates can be sampled
# independently for every simulated gene.
gainparams <- fitdistr(gainrates[, 'Rate'], 'gamma')
lossparams <- fitdistr(lossrates[, 'Rate'], 'gamma')

allgainrate <- rgamma(ngenes, shape = transform_shape_gain * gainparams$estimate['shape'], rate = gainparams$estimate['rate'])
alllossrate <- rgamma(ngenes, shape = transform_shape_loss * lossparams$estimate['shape'], rate = lossparams$estimate['rate'])

# One row per simulated gene; one column per tip followed by one per internal
# node (the root included), filled in by the simulation loop.
simed_genes <- matrix(NA_real_, nrow = ngenes,
	ncol = length(tree$tip.label) + length(tree$node.label))
colnames(simed_genes) <- c(tree$tip.label, tree$node.label)
#colnames(simed_genes)[colnames(simed_genes)=='[N1]'] = 'N1'
# ---- Simulate ngenes independent genes on the tree ----
# Each gene uses its own independently sampled (gain, loss) rate pair from the
# gamma fits, instead of resampling a rate pair from the GL list, in order to
# flesh out the right end of the null distribution. The transformed shape
# parameters were the parameterization whose genes looked most like real genes.
for (i in seq_len(ngenes)) {
	if (i %% 100 == 0) {
		print(i)
	}
	gainrate <- allgainrate[i]
	lossrate <- alllossrate[i]

	# Two-state rate matrix (states 0, 1): [1,2] = gainrate, [2,1] = lossrate.
	# NOTE(review): relies on ape's row = "from", column = "to" convention.
	model <- matrix(c(0, lossrate, gainrate, 0), 2)

	# Draw the root state from the stationary distribution of this model:
	# P(0) = loss / (gain + loss), P(1) = gain / (gain + loss).
	root.prob <- c(lossrate / (gainrate + lossrate), gainrate / (gainrate + lossrate))
	rooting <- sample(c(1, 2), prob = root.prob, size = 1)	# index into `states`

	pattern <- rTraitDisc(tree, model = model, k = 2, states = c(0, 1),
		root.value = rooting, ancestor = TRUE)

	# Explicit conversion to numeric, required since the R 2.15.3 -> 3.0.0 update.
	simed_genes[i, ] <- as.numeric(as.vector(pattern))
}

# Rename the root column '[N1]' to 'N1', the root name the branch
# classification below looks up. Done once after the loop, so it also happens
# when no genes are simulated.
colnames(simed_genes)[colnames(simed_genes) == '[N1]'] <- 'N1'
	
# ---- Branch table used to infer changes in the simulated genes ----
# Column 1 is a branch's ancestor node, column 2 its descendant node; the
# descendants define the branches (columns) of the classification below.
parentage_file <- file.path(data_dir, 'TheTree.INodes.ph.parentage')
parent_maps <- as.matrix(read.table(parentage_file))
descs <- parent_maps[, 2]
# Classify what happened to one simulated gene along each branch of the tree.
#
# Arguments:
#   desc    - node names at the lower end of each branch. The root (N1) is never
#             a descendant, so it does not appear here.
#   pos     - row of `mat` holding the gene to classify.
#   mat     - matrix of simulated 0/1 states: one row per gene, one named column
#             per tip / internal node.
#   parents - two-column (ancestor, descendant) parentage matrix; defaults to
#             the global `parent_maps`.
# Returns a numeric vector parallel to `desc`:
#   1 = absent at both ends (0 -> 0), 2 = present at both ends (1 -> 1),
#   3 = gained on the branch (0 -> 1), 4 = lost on the branch (1 -> 0).
# Stops (after printing diagnostics) if any branch cannot be classified.
infer_change <- function(desc, pos, mat, parents = parent_maps) {
	# Look up each descendant's ancestor by name, so `desc` may be any subset or
	# ordering of the descendants rather than exactly parents[, 2].
	anc <- parents[match(desc, parents[, 2]), 1]
	anc_poses <- mat[pos, anc]
	desc_poses <- mat[pos, desc]
	changed <- desc_poses - anc_poses

	# 5 is a sentinel for "not classified yet".
	outvector <- rep(5, length(anc))
	# Masks must be full-length and elementwise, so each branch is judged on its
	# own two ends. A `which(...) %in% which(...)` mask is shorter than
	# `outvector` and gets recycled onto the wrong branches.
	outvector[anc_poses == 0 & desc_poses == 0] <- 1
	outvector[anc_poses == 1 & desc_poses == 1] <- 2
	outvector[changed == 1] <- 3
	outvector[changed == -1] <- 4

	unclassified <- outvector == 5
	if (any(unclassified)) {
		print(anc_poses)
		print(desc_poses)
		print(changed)
		print(anc_poses[unclassified])
		print(desc_poses[unclassified])
		print(changed[unclassified])
		print(outvector)
	}
	stopifnot(!any(unclassified))
	outvector
}

# Branch classification for every simulated gene: one row per gene, one column
# per branch (in the order of `descs`); 5 marks "not classified".
classed <- matrix(5, nrow = nrow(simed_genes), ncol = length(descs))

cat('there are ', nrow(simed_genes), 'simulated genes\n')

for (pos in seq_len(nrow(simed_genes))) {
	if (pos %% 100 == 0) {
		print(pos)
	}
	# genes with any missing simulated state are left unclassified
	if (any(is.na(simed_genes[pos, ]))) {
		next
	}
	classed[pos, ] <- infer_change(descs, pos, simed_genes)
}

# Every gene must be classified; the only way a row keeps the sentinel is by
# having been skipped above for missing states.
if (any(classed == 5)) {
	stop(sum(rowSums(classed == 5) > 0),
		' simulated gene(s) left unclassified (missing simulated states?)',
		call. = FALSE)
}

# ---- Save results to working_dir, each file suffixed with outfile_stub ----
# Sampled rates: row i holds the gain/loss rates used to simulate gene i.
rates <- cbind(Gain_rate = allgainrate, Loss_rate = alllossrate)
save(rates, file = file.path(working_dir, paste0('MOtree_simrates.', outfile_stub)))

# Presence indicator: 1 where a branch is classed 2 (present -> present), else 0.
present <- (classed == 2) * 1
save(present, file = file.path(working_dir, paste0('apesim_presence.Rdat.', outfile_stub)))

# Gain indicator: 1 where a branch is classed 3 (gained on that branch), else 0.
gain <- (classed == 3) * 1
print('calced gain')
save(gain, file = file.path(working_dir, paste0('apesim_gain.Rdat.', outfile_stub)))

# Free the large intermediate objects.
rm(classed, gain, present, simed_genes, rates, allgainrate, alllossrate)


# next, multiply the two- big step.
