/*
Copyright (c) 2011-2012 Daniel Marbach(1,2)

(1) Massachusetts Institute of Technology, Cambridge MA, USA
(2) Broad Institute of MIT and Harvard
 
We release this software open source under an MIT license (see below). If this
software was useful for your scientific work, please cite our paper available at:
http://compbio.mit.edu/flynet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
 */
package edu.mit.compbio.flynet;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;


/**
 * Generate the unsupervised and the supervised integrative networks.
 */
public class FlynetInference {

	/**
	 * Number of features computed for each TF-target pair (see generateFeatures()):
	 * motif, chip, chromatinTC, chromatinCL, microarray, rnaSeq.
	 * (A seventh, FlyAtlas-based feature was used in an earlier version and has been removed.)
	 */
	private static final int NUM_FEATURES = 6;
	/** ChIP feature value when the TF has ChIP data and the edge is present in the ChIP network */
	private static final double CHIP_EDGE_PRESENT = 0.95;
	/** ChIP feature value when the TF has ChIP data but the edge is absent from the ChIP network */
	private static final double CHIP_EDGE_ABSENT = 0.05;

	/** Flag indicating whether inference is supervised or unsupervised */
	private boolean supervised_ = false;
	/** The ChIP network */
	private Network chip_ = null;
	/** The motif network */
	private Network motif_ = null;
	/** The chromatin time-course */
	private GeneExpression chromatinTC_ = null;
	/** The chromatin cell-lines */
	private GeneExpression chromatinCL_ = null;
	/** The RNAseq data */
	private GeneExpression rnaSeq_ = null;
	/** The microarray data */
	private GeneExpression microarray_ = null;

	/** The inferred network */
	private Network network_ = null;
	/** 
	 * The coefficients learned using the logistic regression with cross-validation.
	 * The first entry is the bias (beta_0), followed by the weights for:
	 * motif, chip, chromatinTC, chromatinCL, microarray, rnaSeq.
	 * 
	 * The weights that Patrick learned using the R script were (the last entry was for
	 * flyatlas, which was later removed; I suspect that microarray and rnaSeq were switched):
	 * -1.5790, 2.2950, 1.0125, 2.049, 1.109, 0.9751, 0.3797, 0.4526
	 */
	private double[] coeffs_ = null;


	// ============================================================================
	// PUBLIC METHODS

	/**
	 * Run network inference.
	 * 
	 * If Settings.generateSup_ is set, the feature weights are first learned on the
	 * RedFly universe (and that network is saved to "integrativeSupervised_redfly.txt").
	 * Then the network over all TFs and targets is generated and saved, either as
	 * "integrativeSupervised.txt" or "integrativeUnsupervised.txt".
	 */
	public void run() {

		// Flag indicates whether we generate the sup or unsup network
		supervised_ = Settings.generateSup_;

		// Learn the weights for the supervised network
		if (supervised_) {
			// Load the input datasets, restricted to redfly TFs and target genes
			loadInputDatasets(true);
			// Generate the features using the loaded datasets (compute correlations etc.)
			generateFeatures();
			// Run the cross validation, computes coeffs_
			runSupervisedLearning();
			// Generate the integrative network, restricted to redfly TFs and targets
			generateNetwork("integrativeSupervised_redfly.txt");
		}

		// Generate the network for all TFs and targets (reuses coeffs_ learned above if supervised)
		loadInputDatasets(false);
		generateFeatures();
		if (supervised_)
			generateNetwork("integrativeSupervised.txt");
		else
			generateNetwork("integrativeUnsupervised.txt");
	}


	// ============================================================================
	// PRIVATE METHODS

	/** 
	 * Load the six input datasets / feature-specific networks (motif, ChIP, chromatin
	 * time-course, chromatin cell-lines, RNAseq, microarray), as well as the list of
	 * TFs and TGs to be considered. A fresh network_ is created, replacing any previous one.
	 * 
	 * NOTE: sets the global Settings.networkDir_ to "resources/input_data"; presumably the
	 * relative file names below are resolved against it (verify in Network/GeneExpression).
	 * 
	 * @param redflyUniverseOnly if true, only the RedFly TFs and TGs are loaded;
	 *        otherwise all TFs and TGs listed in tf.txt / tg.txt
	 */
	private void loadInputDatasets(boolean redflyUniverseOnly) {

		Settings.networkDir_ = "resources/input_data";

		// The files with the list of TFs and TGs to be considered
		String tfFile = redflyUniverseOnly ? "../networks/redfly_tfs.txt" : "tf.txt";
		String tgFile = redflyUniverseOnly ? "../networks/redfly_tgs.txt" : "tg.txt";

		// Load the list of regulators and targets to be considered
		network_ = new Network();
		network_.loadNodes(tfFile, tgFile);

		// Load the input datasets
		chip_ = loadNetwork("chip_net.txt", "Chip");
		motif_ = loadNetwork("motif_regnet_highestconf.txt", "Motif");

		chromatinTC_ = loadExpressionDataset("chrom-tc.txt", "ChromatinTC");
		chromatinCL_ = loadExpressionDataset("chrom-cl.txt", "ChromatinCL");
		rnaSeq_ = loadExpressionDataset("rnaseq.txt", "RNAseq");
		microarray_ = loadExpressionDataset("microarray.txt", "Microarray");
	}


	// ----------------------------------------------------------------------------

	/** Load a feature-specific network from the given file and give it the given name */
	private static Network loadNetwork(String filename, String name) {

		Network network = new Network();
		network.load(filename, true);
		network.setName(name);
		return network;
	}


	// ----------------------------------------------------------------------------

	/** Load a gene expression / chromatin dataset from the given file and give it the given name */
	private static GeneExpression loadExpressionDataset(String filename, String name) {

		GeneExpression dataset = new GeneExpression(true);
		dataset.load(filename);
		dataset.setName(name);
		return dataset;
	}


	// ----------------------------------------------------------------------------

	/**
	 * Compute the input features based on the loaded datasets (correlations etc.).
	 * One edge with an undefined weight (-1) and a feature vector of length NUM_FEATURES
	 * is appended to network_ for every (TF, target) pair with TF != target.
	 */
	private void generateFeatures() {

		Flynet.println("Computing input features...\n");

		HashSet<Node> regulators = network_.getRegulators();
		HashSet<Node> targets = network_.getTargets();

		// Loop over all TF-target gene pairs
		for (Node tf : regulators) {
			for (Node target : targets) {
				// Skip self-regulation (same Node instance used as TF and target)
				if (tf == target)
					continue;

				// The feature vector (order must match the coefficients in coeffs_)
				double[] features = new double[NUM_FEATURES];
				features[0] = motifWeight(tf, target);
				features[1] = chipWeight(tf, target);
				features[2] = correlation(tf, target, chromatinTC_);
				features[3] = correlation(tf, target, chromatinCL_);
				features[4] = correlation(tf, target, microarray_);
				features[5] = correlation(tf, target, rnaSeq_);

				Edge edge = new Edge(tf, target, -1); // the weight of the edge is not yet defined (-1)
				edge.setFeatures(features);
				network_.getEdges().add(edge);
			}
		}
	}


	// ----------------------------------------------------------------------------

	/** Perform supervised learning on the current edges and store the learned coefficients in coeffs_ */
	private void runSupervisedLearning() {

		SupervisedLearning learner = new SupervisedLearning();
		coeffs_ = learner.run(network_.getEdges());
	}


	// ----------------------------------------------------------------------------

	/** 
	 * Generate the integrative network based on the list of TFs and targets that has been previously 
	 * loaded (NOTE, the logistic regression for the supervised network is not run here, we just use
	 * the feature weights that have been previously learned using runSupervisedLearning()).
	 * 
	 * Each edge weight is set to the mean of its features (unsupervised) or to the logistic
	 * function of the coeffs_-weighted sum of its features (supervised). The edges are then
	 * sorted and the network is saved to the given file. The edges' feature vectors are
	 * left unchanged.
	 * 
	 * @param filename the file the network is saved to
	 * @throws IllegalStateException in supervised mode, if coeffs_ has not been learned or
	 *         does not have one coefficient per feature plus the bias
	 */
	private void generateNetwork(String filename) {

		Flynet.println("Generating integrative network ...");

		for (Edge edge : network_.getEdges()) {
			double[] features = edge.getFeatures();
			double weight = supervised_ ? supervisedScore(features) : unsupervisedScore(features);
			edge.setWeight(weight);
		}

		// Sort the edges by their weight
		Collections.sort(network_.getEdges());

		// Save the networks
		Flynet.println("Saving network: " + filename);
		network_.save(filename);
	}


	// ----------------------------------------------------------------------------

	/** Unsupervised edge score: the unweighted mean of the features */
	private static double unsupervisedScore(double[] features) {

		double sum = 0;
		for (double feature : features)
			sum += feature;
		return sum / features.length;
	}


	// ----------------------------------------------------------------------------

	/**
	 * Supervised edge score: 1 / (1 + exp(-z)) with z = coeffs_[0] + sum_f coeffs_[f+1] * features[f].
	 * Unlike the previous implementation, the features array is NOT modified (it used to be
	 * scaled in place by the coefficients, corrupting the edge's stored feature values).
	 */
	private double supervisedScore(double[] features) {

		if (coeffs_ == null)
			throw new IllegalStateException("Supervised coefficients have not been learned (coeffs_ is null)");
		if (coeffs_.length < features.length + 1)
			throw new IllegalStateException("Expected at least " + (features.length + 1)
					+ " coefficients (bias + one per feature), got " + coeffs_.length);

		double z = coeffs_[0];
		for (int f = 0; f < features.length; f++)
			z += features[f] * coeffs_[f + 1];
		return 1 / (1 + Math.exp(-z));
	}


	// ----------------------------------------------------------------------------

	/**
	 * Return the ChIP feature for this TF-target pair:
	 * Settings.missingValue_ if the TF is absent from the ChIP network or has no outgoing edges there,
	 * otherwise CHIP_EDGE_PRESENT if the edge is in the ChIP network and CHIP_EDGE_ABSENT if it is not.
	 */
	private double chipWeight(Node tf, Node target) {

		// The TF must have ChIP data (be present with at least one outgoing edge)
		Node chipTf = chip_.getNodes().get(tf.getLabel());
		if (chipTf == null || chipTf.getOutgoingEdges().size() == 0)
			return Settings.missingValue_;

		// The target is deliberately not required to be present in the ChIP network: for a TF
		// with ChIP data, any target it does not bind gets the low value CHIP_EDGE_ABSENT
		Edge edge = chip_.getEdge(tf.getLabel(), target.getLabel());
		return (edge != null) ? CHIP_EDGE_PRESENT : CHIP_EDGE_ABSENT;
	}


	// ----------------------------------------------------------------------------

	/**
	 * Return the weight of this edge from the motif network, or Settings.missingValue_ if the
	 * TF or target is absent from the motif network or the edge itself is absent.
	 */
	private double motifWeight(Node tf, Node target) {

		// Check if the TF and target are in the feature-specific network
		if (motif_.getNodes().get(tf.getLabel()) == null)
			return Settings.missingValue_;
		if (motif_.getNodes().get(target.getLabel()) == null)
			return Settings.missingValue_;

		Edge edge = motif_.getEdge(tf.getLabel(), target.getLabel());
		return (edge != null) ? edge.getWeight() : Settings.missingValue_;
	}


	// ----------------------------------------------------------------------------

	/**
	 * The functional relation (correlation) between the two profiles in the given dataset,
	 * or Settings.missingValue_ if either gene is absent from the dataset.
	 */
	private double correlation(Node tf, Node target, GeneExpression dataset) {

		// Check if the TF and target are in the dataset
		if (!dataset.contains(tf) || !dataset.contains(target))
			return Settings.missingValue_;

		return dataset.functionalRelation(tf, target);
	}

}
