#!/usr/bin/Rscript
# predictions for HGT

# clean up+


# Variable names keep the "firm" prefix only because they were inherited from
# the hgt_predict_firm.R script; the data are not actually firmicutes.
#firmadj = as.matrix(read.table(gzfile('processed_data/hgt_net_treduce_adjmat_071614.txt.gz'),header=T))
# load() returns the names of the objects it restored; insist that `adjmat`
# is one of them so a stale `adjmat` left in the workspace is never used.
stopifnot('adjmat' %in% load('processed_data/hgt_net_dag_transreduced_repro.Rdat'))
firmadj <- adjmat
rm(adjmat)

# total of all entries in the adjacency matrix, reported at the end of the run
edgenum <- sum(firmadj)
# this file is just a parsed version of the gainloss reconstruction, except
# that species are mapped not to the node at the end of the branch, but rather
# to the ancestor.  done using the parse_reconstruction_foranc.py script.
# this is because we are trying to predict based on the ancestor, otherwise it
# is circular.
# header = TRUE is spelled out: `T` is an ordinary variable that load() above
# (or anything else in the workspace) could have overwritten.
allanc <- read.table('processed_data/021214_MOtree_ancestralrecon.txt.gz.anc_pres', header = TRUE)

firmanc <- read.table('processed_data/AncestralReconstructPosterior.txt.pres_probs.firm', header = TRUE)
realgains <- read.table('gainLoss_results/MOtree_GLrun/AncestralReconstructPosterior.txt.gain_probs', header = TRUE)

# Branches to predict on: every row name of allanc that has no match among the
# row names of firmanc.
# NOTE(review): the original comment called these "the firmicutes branches",
# but the selection keeps rows that are absent from firmanc - confirm intent.
# note that in this script i am using a model trained on full dataset to predict, but only predicting firms
firmspecs <- rownames(allanc)[is.na(match(rownames(allanc), rownames(firmanc)))]
# per-gene sum of the rounded gain values over the selected branches
firmgains <- colSums(round(realgains[firmspecs, ]))

# Relevant genes - those in the pgce net, i.e. with a non-zero column sum or a
# non-zero row sum in the adjacency matrix.
coled <- colnames(firmadj)[which(colSums(firmadj) > 0)]
rowed <- colnames(firmadj)[which(rowSums(firmadj) > 0)]
both <- union(rowed, coled)
# keep only the genes whose column sum over allanc is not NA
both <- names(Filter(Negate(is.na), colSums(allanc[, both])))

# ancestral presence values (branches x genes) and the matching sub-network
firmpres <- allanc[firmspecs, both]
firmmodel <- as.matrix(firmadj[both, both])

# make it so that branches where gene is present will not get predicted
# there are actually better ways
#diag(firmmodel) = -999

# check that everything is accounted for - surprisingly irritating!
stopifnot(!any(is.na(firmpres)))
stopifnot(!any(is.na(firmmodel)))

# make predictions: for every branch (row) and gene (column), the sum of the
# presence values of the genes that have an edge into that gene in the model
firmpredict <- as.matrix(firmpres) %*% as.matrix(firmmodel)

# restrict to genes that can even be predicted, i.e. with a non-zero column
# sum in the adjacency matrix.  Genes that were dropped from `both` by the NA
# filter have no column in firmpredict, so they are excluded here instead of
# triggering a "subscript out of bounds" error.
predictable <- names(which(colSums(firmadj) > 0))
predictable <- predictable[predictable %in% both]
# drop = FALSE keeps these two objects matrices even when only one gene is left
firmpredict <- firmpredict[, predictable, drop = FALSE]

firmtrue <- as.matrix(realgains[firmspecs, predictable, drop = FALSE])
# Binarize in a single step: strictly above .5 counts as a gain (1), anything
# else as no gain (0).  The former two-step version left values of exactly .5
# untouched; NA entries stay NA as before.
firmtrue[] <- as.numeric(firmtrue > .5)

# Build the table handed to ROC(): one row per (gene, branch) pair whose
# presence value in firmpres is not above .4, i.e. the gene is not counted as
# already present.  Each row holds, in order:
#   1. the predicted value divided by the gene's column sum in firmadj
#   2. the binarized true gain value
#   3. the branch name
#   4. the gene name
# Mixing numbers and names makes the result a character matrix, exactly as the
# former c()/rbind() loop produced.  Rows are ordered gene by gene, and within
# a gene in firmspecs order.  The result is NULL when no pair qualifies.
#
# Rows are assembled for a whole gene at a time and bound together once;
# growing the matrix with rbind() for every cell copied the entire table on
# each iteration.
roc_chunks <- lapply(predictable, function(gene) {
	# normaliser: column sum of the full adjacency matrix for this gene
	max_in <- sum(firmadj[, gene])

	# skip branches where the gene is already present
	absent <- !(firmpres[firmspecs, gene] > .4)
	if (!any(absent)) {
		return(NULL)
	}
	branches <- firmspecs[absent]

	# deparse.level = 0 stops cbind() from inventing column names
	cbind(
		as.character(firmpredict[branches, gene] / max_in),
		as.character(firmtrue[branches, gene]),
		branches,
		rep(gene, length(branches)),
		deparse.level = 0
	)
})

# rbind() ignores the NULL chunks and returns NULL when there is nothing to bind
all_for_roc <- do.call(rbind, roc_chunks)
if (!is.null(all_for_roc)) {
	# the old loop labelled every row "comparison"; keep that for compatibility
	rownames(all_for_roc) <- rep('comparison', nrow(all_for_roc))
}
rm(roc_chunks)

# Hand the collected comparisons to ROC(); it is not defined in this file and
# is presumably provided by the sourced roc.R script.
source("code/figure_scripts/roc.R")
print("ROC for using whole dataset as predictor (overfit, probably)")
ROC(all_for_roc)
#rm(firmadj,firmanc,firmpres,firmmodel)

# Summary counts for the run.
cat("number of edges:", edgenum, "\n")
cat("number of predictable genes:", length(predictable), "\n")
# genes whose summed rounded gain value is at least 1 (NA sums are not counted)
cat("number of gained genes:", sum(firmgains >= 1, na.rm = TRUE), "\n")
