/**
 * SIFTER project. This implements the high level core
 * functionality for the inference and learning methods.
 * $Id: PFun.java,v 1.2 2005/05/27 06:56:01 sprite Exp $
 * 
 * Copyright (C) 2010, Barbara Engelhardt (bee@compbio.berkeley.edu)
 *
 * @author Barbara Engelhardt (primary investigator and author)
 * @author Steven R. Chan (later code hacking, documentation, GUI)
 * @version 1.0
 */

import java.util.Enumeration;
import java.util.Vector;
import java.util.Hashtable;
import java.io.*;

import stat.*;
import util.*;

/**
 * A ton of helper functions for PFun.
 * 
 * $Id: PFunLibrary.java,v 1.2 2005/05/27 06:56:01 sprite Exp $
 * @author Barbara Engelhardt, primary author
 * @author (some modifications by Steven Chan)
 */
public class PFunLibrary {
    /**
     * Proteins of the family being processed; replaced with a fresh table
     * each time buildDatasetGODAG() runs (null before the first call).
     * KEY: swissprotID (not name), i.e. the value of PfunProtein.getID();
     * VALUE: PfunProtein
     */
    private Hashtable proteinList;  
    /**
     * The single protein family most recently read by buildDatasetGODAG()
     * (null before the first call). Returned by getFamily().
     */
    private PfunFamily family;
    /**
     * Holds user's preferences. The fields below are copies of individual
     * settings, refreshed by initSettings() whenever setSettingsObject()
     * is called.
     */
    private Settings settings;
    /** Copy of the "verbose" setting: enables extra progress output. */
    private boolean verbose;
    /** Copy of the "em" setting: crossValidation() runs EM when true, exact inference otherwise. */
    private boolean em;
    /** Copy of the "familyFilename" setting; passed to the PFunTransMatrix constructor. */
    private String familyFilename;
    /** Copy of the "scaleParamsFilename" setting; passed to PFunTransMatrix.readInScale(). */
    private String scaleParamsFilename;
    /** Copy of the "alphaParamsFilename" setting; passed to PFunTransMatrix.readInAlpha(). */
    private String alphaParamsFilename;
    /** Copy of the "folds" setting; default fold count for crossValidation(). */
    private int setfolds;
    /** Copy of the "truncation" setting; passed to PFunTransMatrix.setMaxFunctions(). */
    private int truncation;

    /**
     * Creates an empty library: a default Settings object is installed and
     * no protein list or family is loaded yet.
     *
     * Only the settings reference is set here. The cached option fields
     * (verbose, em, file names, folds, truncation) keep their Java default
     * values until setSettingsObject() is called.
     */
    public PFunLibrary()
    {
        this.settings = new Settings(); // start from default preferences
        this.family = null;
        this.proteinList = null;
    }
    
    /*****************************************************************
        ONLY USED BY -lib
    ******************************************************************/   
    
    /**
     * Writes every protein in pList to the file outFileName by calling
     * printProteinToFile() on each value of the table.
     *
     * (Original note: only called from testLibrary() and
     * testLibrarySmall().)
     *
     * If anything goes wrong (file cannot be opened, a value is not a
     * PfunProtein, ...) a message is printed to System.err and the JVM is
     * terminated with exit status 1, exactly as before. The output stream
     * is closed on every path, including the failure path, before the
     * process exits.
     *
     * @param outFileName path of the file to create or overwrite
     * @param pList       table whose values are PfunProtein objects
     */
    public void printOutProteinList(String outFileName, Hashtable pList)
    {
        PrintStream fout = null;
        Exception failure = null;
        try {
            fout = new PrintStream(new FileOutputStream(new File(outFileName)));

            Enumeration pnames = pList.keys();
            while (pnames.hasMoreElements()) {
                PfunProtein pfp = (PfunProtein) pList.get(pnames.nextElement());
                pfp.printProteinToFile(fout);
            }
        }
        catch (Exception ioe) {
            // Remember the failure; exit only after the stream is closed,
            // because System.exit() does not run pending finally blocks.
            failure = ioe;
        }
        finally {
            if (fout != null) {
                fout.close();
            }
        }
        if (failure != null) {
            System.err.println("PrintOutProteinList: " + outFileName + " " +
                               failure.getMessage());
            System.exit(1);
        }
    }
    	
    //////////////////////////////////////////////////
    // Ontology building section                    //
    //////////////////////////////////////////////////
    
    /*
     * NOTE: this comment describes STEP 1 of the buildDatasetGODAG chain,
     * a convenience overload that is not defined in this class as it
     * stands; only the method documented directly below remains. The
     * original description is kept for reference.
     *
     * Builds DAG: Most of the time she uses this (string-calling).
     * All this does is take an input file "inFiles", put it in a
     * vector, and pass it on to another overloaded buildDatasetGODAG
     * function (passing along all of its own parameters except the
     * "outfileName" argument, which is ignored).
     *
     * STEP 1.
     * TODO: Check to see if this calls other overloaded 
     * buildDatasetGODAG() functions.
     * 
     * @param inFamilyFile
     * @param goFile
     * @param noIEA
     * @return
     */
    
    /**
     * Builds the GO DAG for one protein family file and attaches evidence
     * likelihoods to every protein read from it.
     *
     * Steps, in order:
     * <ol>
     * <li>Construct a PFunGODAG from goFile, hand it this library's
     *     settings and call fillOutList().</li>
     * <li>Read the proteins from the family XML file, collect the GO
     *     numbers and MOC entries of every protein that has GO numbers,
     *     and index the proteins by ID in proteinList.</li>
     * <li>Tally the collected functions on the DAG, then prune, pad and
     *     compute the R value.</li>
     * <li>For each indexed protein: use ratio likelihoods if it has a
     *     function ratio; otherwise, if it has GO numbers, push its
     *     evidence through the DAG and store the resulting leaf
     *     likelihoods on the protein.</li>
     * </ol>
     *
     * Side effects: proteinList is replaced with a fresh table and family
     * is set to the family read from file.
     *
     * If the file contains several proteins with the same ID, the first
     * one is kept.
     *
     * @param file   path of the family XML file
     * @param goFile passed to the PFunGODAG constructor
     * @param noIEA  forwarded to incorporateEvidenceShort(); the tally step
     *               does not receive it
     * @return the DAG that was built
     */
    public PFunGODAG buildDatasetGODAG(String file, String goFile, 
                                       boolean noIEA) {
        PFunGODAG pfgodag = new PFunGODAG(goFile, true, true, true);
        pfgodag.setSettingsObject(this.settings);
        // Per the original author's note, this compiles the (large) data
        // structure for the function file given to the constructor.
        pfgodag.fillOutList();

        proteinList = new Hashtable();
        PfunFamily proteinFam = new PfunFamily(file);
        Vector p = proteinFam.readInFromXMLFile(file);
        if (p != null) {
            if (verbose) System.out.println(file);
            Vector fns = new Vector();      // evidence functions
            Vector methods = new Vector();  // methods list, parallel to fns
            for (int i = 0; i < p.size(); i++) {
                PfunProtein pro = (PfunProtein) p.elementAt(i);
                if (pro.getGONumber() != null) {
                    fns.addAll(pro.getGONumber());
                    methods.addAll(pro.getMOC());
                }
                // containsKey, not contains: Hashtable.contains() searches
                // the VALUES, so it could never detect a duplicate ID.
                if (!proteinList.containsKey(pro.getID())) {
                    proteinList.put(pro.getID(), pro);
                }
            }
            // should be noIEA, but tally is only IDA
            pfgodag.tallyFunctions(fns, methods);
        }
        family = proteinFam;
        pfgodag.pruneZeroHitsAndLeaves();
        pfgodag.padSingletonLeaves();
        pfgodag.findRValue();

        // Iterate through the protein list, incorporating evidence.
        Enumeration proteinIDs = proteinList.keys();
        while (proteinIDs.hasMoreElements()) {
            String pID = (String) proteinIDs.nextElement();
            PfunProtein pro = (PfunProtein) proteinList.get(pID);
            if (pro.getFunctionRatio() != null) {
                pro.addRatioLikelihoods();
            } else if (pro.getGONumber() != null) {
                Vector evidenceMethods = pro.getMOC();
                Vector functions = pro.getGONumber();
                pfgodag.incorporateEvidenceShort(functions,
                                                 evidenceMethods, noIEA);
                if (verbose) {
                    for (int i = 0; i < evidenceMethods.size(); i++) {
                        System.out.println("Evidence "
                                           +proteinFam.getProteinNameFromID(pID).toUpperCase()+" "
                                           +evidenceMethods.elementAt(i)+" "
                                           +functions.elementAt(i));
                    }
                }
                // Read the likelihoods out before the DAG is cleared for
                // the next protein.
                double[] leaflike = pfgodag.pullOutLeafLikelihoods();
                pfgodag.clearDAGProbabilities();
                pro.addLikelihoods(leaflike);
            }
        }
        if (verbose) System.out.println("Finished building GO graphs");
        return pfgodag;
    }
    
    /**
     * Builds the reconciled tree for the current family, attaches the
     * per-protein leaf likelihoods to matching tree leaves and adds the
     * tree to the family.
     *
     * (Original note: called by parseTestNetworkInputExact() and many
     * other functions. The reconciled tree and the GODAG are never merged;
     * the likelihoods from the GODAG are associated with tree leaves.)
     *
     * A protein contributes evidence only if the tree has a node with its
     * name, it is present in proteinList, it has leaf likelihoods and
     * hasProperEvidence(settings) is true. Family proteins with no
     * proteinList entry are skipped.
     *
     * Must be called after buildDatasetGODAG(), which is the only method
     * that sets family and proteinList.
     *
     * @param pfgodag DAG supplying the leaf-subset and single-leaf priors
     * @param nexfile file holding the reconciled tree, passed to
     *                PFunTree.createReconciled()
     * @param noIEA   not used by this method
     * @throws IllegalStateException if no family has been loaded yet
     */
    public void buildGraphicalModel(PFunGODAG pfgodag, String nexfile, 
                                    boolean noIEA)
    {
        if (family == null || proteinList == null) {
            throw new IllegalStateException(
                "buildGraphicalModel: no family loaded; "
                + "call buildDatasetGODAG first");
        }

        // Read in their reconciled tree (not the GODAG).
        PFunTree t = new PFunTree(family.getMaxAlignment());
        t.createReconciled(nexfile);
        if (verbose) t.printTree();

        // Add evidence at leaves with evidence.
        // PFam-type names (not swiss prot ids)
        Vector proteins = family.getProteinNames();
        for (int i = 0; i < proteins.size(); i++) {
            // Match protein with reconciled tree leaf.
            String currentProtein = (String) proteins.elementAt(i);
            PfunProtein pfp = (PfunProtein) proteinList
                .get(family.getProteinName(currentProtein));
            // pfp is null when the family lists a protein that is not in
            // proteinList; such proteins carry no evidence and are skipped.
            if (t.hasNode(currentProtein)
                && pfp != null
                && pfp.getLeafLikelihoods() != null
                && pfp.hasProperEvidence(settings)) {
                t.setNodeEvidenceProbabilities(currentProtein,
                                               pfp.getLeafLikelihoods());
            }
        }
        // Give the tree the priors from the GODAG, then attach the tree
        // to the family.
        t.setSingleNodeSampleLikelihoods(pfgodag.getLeafSubsetPrior(),
                                         pfgodag.getSingleLeafPrior());
        family.addTree(t);
    }

    /**
     * Accessor for the protein family held by this library.
     *
     * @return the PfunFamily assigned by buildDatasetGODAG(), or null if
     *         none has been assigned yet
     */
    public PfunFamily getFamily()
    {
        return this.family;
    }
    
    /**
     * Runs exact posterior inference for a family.
     *
     * (Original notes: this is the main inference entry point in use. The
     * computation is exponential in the number of functions in the tree,
     * though linear in the number of proteins.
     * TODO: figure out how to speed this up for larger datasets.)
     *
     * A transition matrix is created from familyFilename, loaded with the
     * alpha and scale parameter files, truncated to the configured number
     * of functions, and handed to a PFunLearn instance that performs the
     * inference.
     *
     * @param pfgodag GO DAG passed to the PFunLearn constructor
     * @param fam     family to run inference on
     * @return whatever PFunLearn.inferPosteriorsExact(fam) returns
     */
    public Hashtable inferWithTransitionMatrixExact(PFunGODAG pfgodag, 
                                                    PfunFamily fam) 
    {
        final PFunTransMatrix transMatrix = new PFunTransMatrix(familyFilename);
        transMatrix.readInAlpha(alphaParamsFilename);
        transMatrix.readInScale(scaleParamsFilename);
        transMatrix.setMaxFunctions(truncation); // truncation here
        System.out.println("Setting truncation level to " + truncation);

        final PFunLearn learner =
            new PFunLearn(transMatrix, proteinList, pfgodag, settings);
        learner.setSettingsObject(settings);
        learner.setFamily(fam);

        Hashtable posteriors = learner.inferPosteriorsExact(fam);
        return posteriors;
    }
    
    /**
     * Runs expectation-maximization for a family.
     *
     * A transition matrix is created from familyFilename and loaded with
     * the scale and alpha parameter files; a PFunLearn instance then runs
     * ExpectationMaximization() on it. The result of that call is
     * discarded.
     *
     * @param pfgodag GO DAG passed to the PFunLearn constructor
     * @param fam     family to learn on
     * @return always 0.0
     */
    public double estimateParameters(PFunGODAG pfgodag, PfunFamily fam) 
    {
        final PFunTransMatrix transMatrix = new PFunTransMatrix(familyFilename);
        transMatrix.readInScale(scaleParamsFilename);
        transMatrix.readInAlpha(alphaParamsFilename);

        final PFunLearn learner =
            new PFunLearn(transMatrix, proteinList, pfgodag, settings);
        learner.setSettingsObject(settings);
        learner.setFamily(fam);
        learner.ExpectationMaximization(transMatrix, fam);

        final double result = 0.0;
        return result;
    }
    
    /**
     * Cross-validates using the fold count taken from the "folds" setting.
     *
     * @param pfgodag GO DAG used for learning and inference
     * @param fam     family whose annotated leaves are held out
     * @return the result of crossValidation(pfgodag, fam, folds) with the
     *         configured fold count
     */
    public double crossValidation(PFunGODAG pfgodag, PfunFamily fam)
    {
        final int configuredFolds = this.setfolds;
        return crossValidation(pfgodag, fam, configuredFolds);
    }

    /**
     * Cross-validates function prediction on a family.
     *
     * Every tree leaf that belongs to a protein in proteinList and has
     * local probabilities is a holdout candidate. The candidates are split
     * into folds; for each fold the evidence of its nodes is removed,
     * posteriors are computed (EM if the "em" setting is on, exact
     * inference otherwise), each held-out node is scored and its evidence
     * is restored.
     *
     * Scoring: a node counts as correct when the index of its unique
     * maximum posterior equals the index of a maximum of its held-out
     * local probabilities. Ties for the held-out maximum are all accepted;
     * a tie for the posterior maximum counts as a miss. Negative held-out
     * values are passed through Math.exp() before comparison (presumably
     * they are log-probabilities -- TODO confirm).
     *
     * Progress and per-node results are printed to System.out.
     *
     * @param pfgodag GO DAG used for learning and inference
     * @param fam     family whose tree is cross-validated
     * @param folds   number of folds; 0 means one fold per holdout
     *                (leave-one-out). When folds is positive the protein
     *                name vector of fam is randomized first.
     * @return fraction of scored holdouts predicted correctly; NaN when no
     *         holdout was scored
     */
    public double crossValidation(PFunGODAG pfgodag, PfunFamily fam, int folds)
    {
        int correctGO = 0;
        int totalGO = 0;
        Vector toRemove = new Vector();

        Vector proteinNames = fam.getProteinNames();
        if (folds > 0) PFunUtil.randomizeVector(proteinNames);

        // Build the list of nodes whose evidence will be held out.
        PFunTree tree = fam.getTree();
        for (int i = 0; i < proteinNames.size(); i++) {
            // Match protein with reconciled tree leaf.
            String currentProtein = (String) proteinNames.elementAt(i);
            PfunProtein p = (PfunProtein)
                proteinList.get(fam.getProteinName(currentProtein));
            if (p != null && tree.hasNode(currentProtein)) {
                PFunTree.Node n = tree.getNode(currentProtein);
                if (n.hasLocalProbabilities() && n.isLeaf()) {
                    toRemove.addElement(n);
                }
            }
        }

        if (folds == 0) folds = toRemove.size(); // leave-one-out
        if (folds == 0) {
            // Leave-one-out with nothing to hold out: the fold arithmetic
            // below would divide by zero, so report the empty result.
            System.out.println("Cross-validation: no proteins with "
                               +"evidence to hold out");
            return Double.NaN;
        }
        int foldSize = (int)Math.floor((double)toRemove.size()/(double)folds);
        int foldSizeRemainder = toRemove.size()%folds;
        System.out.println("Fold size: "+foldSize);
        System.out.println("Number of holdouts: "+toRemove.size());

        // heldOut[k] is filled when node k is removed and read back when
        // it is scored and restored within the same fold.
        double[][] heldOut = new double[toRemove.size()][];
        int startIndex;
        int endIndex = 0;
        for (int i = 0; i < folds; i++) {
            startIndex = endIndex;
            endIndex = Math.min(endIndex + foldSize, toRemove.size());
            // The first foldSizeRemainder folds take one extra node each.
            if (i < foldSizeRemainder) ++endIndex;

            for (int j = startIndex; j < endIndex; j++) {
                PFunTree.Node n = (PFunTree.Node) toRemove.elementAt(j);
                heldOut[j] = n.getLocalProbabilities();
                n.removeLocalProbabilities();
                if (verbose)
                    System.out.println("X Val: removing evidence for "
                                       +n.getNodeID());
            }

            // Run EM like usual.
            PFunTransMatrix pfx = new PFunTransMatrix(familyFilename);
            pfx.setMaxFunctions(truncation); // truncation here
            System.out.println("Setting truncation level to "+truncation);

            pfx.readInScale(scaleParamsFilename);
            pfx.readInAlpha(alphaParamsFilename);

            PFunLearn pfl =
                new PFunLearn(pfx, proteinList,
                              pfgodag, settings);
            pfl.setSettingsObject(settings);
            pfl.setFamily(fam);
            Hashtable posteriors;
            if (em) {
                posteriors = pfl.ExpectationMaximization(pfx, fam);
            } else {
                posteriors = inferWithTransitionMatrixExact(pfgodag, fam);
            }

            // Check to see if we got the held-out ones correct.
            for (int k = startIndex; k < endIndex; k++) {
                PFunTree.Node n = (PFunTree.Node) toRemove.elementAt(k);
                double[] lp = heldOut[k];
                if (lp == null) {
                    // Nothing to compare against: report and skip scoring.
                    System.out.println("Error: don't have local probabilities for "+n.getNodeID());
                } else if (posteriors.containsKey(n)) {
                    double[] d = (double[]) posteriors.get(n);
                    System.out.print("x-val ("+n.getNodeID()+") ");
                    for (int di = 0; di < d.length; di++) {
                        System.out.print(d[di]+" ");
                    }
                    System.out.println();

                    double maxPrior = 0.0;
                    double maxPosterior = 0.0;
                    // Indices tied for the held-out maximum; seeded with 0.
                    int[] maxPriorIndex = new int[1];
                    // -1 marks a tie for the posterior maximum.
                    int maxPosteriorIndex = 0;
                    for (int j = 0; j < d.length; j++) {
                        if (d[j] > maxPosterior) {
                            maxPosterior = d[j];
                            maxPosteriorIndex = j;
                        }
                        else if (PFunUtil.areEqual(d[j], maxPosterior)) {
                            maxPosteriorIndex = -1;
                        }

                        // The held-out vector may be shorter than the
                        // posterior vector; only compare where it exists.
                        if (j >= lp.length) continue;
                        double lpj = lp[j];
                        if (lpj < 0) lpj = Math.exp(lpj);
                        if (lpj > maxPrior) {
                            // New maximum: discard indices tied at the
                            // old one.
                            maxPrior = lpj;
                            maxPriorIndex = new int[] { j };
                        }
                        else if (j > 0 && PFunUtil.areEqual(lpj, maxPrior)) {
                            // Tie: remember this index too. j == 0 is
                            // already present as the seed entry.
                            int[] grown = new int[maxPriorIndex.length+1];
                            System.arraycopy(maxPriorIndex, 0, grown, 0,
                                             maxPriorIndex.length);
                            grown[maxPriorIndex.length] = j;
                            maxPriorIndex = grown;
                        }
                    }
                    System.out.println("In x-validation: "
                                       +n.getNodeID()+" node name, "
                                       +d.length+" = d length, "
                                       +maxPosteriorIndex
                                       +" = max posterior index "
                                       +maxPriorIndex.length+" = max prior index "
                                       +maxPosterior+" = max posterior "
                                       +maxPrior+" = max prior");

                    boolean thisCorrect = false;
                    if (maxPosteriorIndex != -1) {
                        for (int pi = 0; pi < maxPriorIndex.length; pi++) {
                            if (maxPosteriorIndex == maxPriorIndex[pi]) {
                                correctGO++;
                                thisCorrect = true;
                                break; // count each node at most once
                            }
                        }
                    }
                    if (!thisCorrect) {
                        System.out.print("Missed "+(n.getNodeID())
                                         +", predicted "
                                         +maxPosteriorIndex
                                         +", real");
                        for (int pi = 0; pi < maxPriorIndex.length; pi++) {
                            System.out.print(" "+maxPriorIndex[pi]);
                        }
                        System.out.println();
                    }
                    totalGO++;
                }
                // Put the held-out evidence back for the next fold.
                tree.setNodeEvidenceProbabilities((String) n.getNodeID(),
                                                  lp);
            }
        }
        System.out.println("Cross-validation results: "
                           +((double)correctGO/(double)totalGO)
                           +" ("+correctGO+" out of "+totalGO+")");
        return ((double)correctGO/(double)totalGO);
    }
    
    /**
     * Points this library at an already existing Settings object and
     * re-reads the cached option fields from it.
     *
     * @param settings Settings object to use from now on
     */
    public void setSettingsObject(Settings settings) {
        this.settings = settings;
        this.initSettings();
    }
    

    /**
     * Copies the settings this class cares about out of the current
     * Settings object into fields: "verbose", "em", "familyFilename",
     * "scaleParamsFilename", "alphaParamsFilename", "folds" and
     * "truncation", read in that order.
     */
    private void initSettings() {
        final Boolean verboseFlag = (Boolean) settings.getSetting("verbose");
        this.verbose = verboseFlag.booleanValue();

        final Boolean emFlag = (Boolean) settings.getSetting("em");
        this.em = emFlag.booleanValue();

        this.familyFilename = (String) settings.getSetting("familyFilename");
        this.scaleParamsFilename =
            (String) settings.getSetting("scaleParamsFilename");
        this.alphaParamsFilename =
            (String) settings.getSetting("alphaParamsFilename");

        final Integer foldCount = (Integer) settings.getSetting("folds");
        this.setfolds = foldCount.intValue();

        final Integer truncationLevel =
            (Integer) settings.getSetting("truncation");
        this.truncation = truncationLevel.intValue();
    }


}
