/*
Copyright (c) 2011-2012 Daniel Marbach(1,2)

(1) Massachusetts Institute of Technology, Cambridge MA, USA
(2) Broad Institute of MIT and Harvard
 
We release this software open source under an MIT license (see below). If this
software was useful for your scientific work, please cite our paper available at:
http://compbio.mit.edu/flynet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
 */
package edu.mit.compbio.flynet;

import java.util.ArrayList;
import java.util.Random;

import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;


/**
 * Learns the weights of the different features using logistic regression on the RedFly
 * network, as described in the paper
 */
public class SupervisedLearning {

	/** The list of edges with the features */
	private ArrayList<Edge> edges_ = null;
	/** The number of edges */
	private int numEdges_ = -1;
	/** The RedFly network used for the class labels */
	private Network redfly_ = null;
	
	/** The weka instances (one for each edges, contains features and class according to redfly) */
	private Instances instances_ =  null;
	/** The number of positive instances (redfly edges) */
	private int numPositives_ = -1;
	/** The weight assigned to the positive instances to balance the classes */
	private double posWeight_;
	/** The weight assigned to the negative instances to balance the classes */
	private double negWeight_;
	
	/** The random seed used for CV */
	private int randomSeed_ = 2;
	/** The random number generator used for CV */
	private Random random_ = null;
	/** The number of folds in cross validation (CV) */
	private int folds_ = 10;
	/** The number of times we repeat cross validation */
	private int numRepeatsCV_ = 1000;
	/** The coefficients learned in each fold of the CV and for each repeat */
	ArrayList<double[]> coeffsCV_ = null;

//	/** The predicted networks (sup using CV, unsup, followed by the feature-specific networks) */
//	private ArrayList<Network> testSetPredictions_ = null;
//	/** The class labels for the test set predictions */
//	private Network testSetPositives_ = null;
	
	
	// ============================================================================
	// PUBLIC METHODS
		
	public SupervisedLearning() {
		random_ = new Random(randomSeed_);
	}
	
	
	// ----------------------------------------------------------------------------

	/** Performs supervised learning with cross-validation, returns the average coefficients */
	public double[] run(ArrayList<Edge> edges) {

		redfly_ = new Network();
		redfly_.load("../networks/redfly.txt", false);
		redfly_.setName("redfly.txt");

		edges_ = edges;
		numEdges_ = edges_.size();

		coeffsCV_ = new ArrayList<double[]>();

		Flynet.println("Running " + numRepeatsCV_ + " iteration(s) of " + folds_ + "-fold cross validation ...");

		// Create the weka instances based on edges_ and redfly_
		createInstances();

		// Perform CV
		for (int i=0; i<numRepeatsCV_; i++)
			crossValidation(i);
		
		// Save the learned coeffs
		writeCoeffsCV();
		
		// Return the average coeffs
		int numCoeffs = coeffsCV_.get(0).length;
		double[] avgCoeffs = new double[numCoeffs];
		for (int c=0; c<numCoeffs; c++)
			avgCoeffs[c] = 0;
		
		for (int i=0; i<coeffsCV_.size(); i++)
			for (int c=0; c<numCoeffs; c++)
				avgCoeffs[c] += coeffsCV_.get(i)[c];

		for (int c=0; c<numCoeffs; c++)
			avgCoeffs[c] /= coeffsCV_.size();
		
		Flynet.println("Avgerage coefficients from cross validation:");
		Flynet.println("Intercept  \t" + avgCoeffs[0]);
		Flynet.println("Motif      \t" + avgCoeffs[1]);
		Flynet.println("Chip       \t" + avgCoeffs[2]);
		Flynet.println("ChromatinTC\t" + avgCoeffs[3]);
		Flynet.println("ChromatinCL\t" + avgCoeffs[4]);
		Flynet.println("Microarray \t" + avgCoeffs[5]);
		Flynet.println("RnaSeq     \t" + avgCoeffs[6]);
		Flynet.println("");
		
		return avgCoeffs;
	}

	
	// ============================================================================
	// PRIVATE METHODS
        
	/** Create the set of weka instances based on edges_ and redfly_ */
	private void createInstances() {

		// Following example from: http://weka.wikispaces.com/Programmatic+Use
		Attribute motif = new Attribute("motif");
		Attribute chip = new Attribute("chip");
		Attribute chromatinTC = new Attribute("chromatinTC");
		Attribute chromatinCL = new Attribute("chromatinCL");
		Attribute microarray = new Attribute("microarray");
		Attribute rnaSeq = new Attribute("rnaSeq");

		// Declare the class attribute along with its values
		FastVector classVal = new FastVector(2);
		classVal.addElement("1");
		classVal.addElement("0");
		Attribute classLabel = new Attribute("RedFly", classVal);

		// Declare the feature vector
		FastVector attributes = new FastVector(7);
		attributes.addElement(motif);    
		attributes.addElement(chip);    
		attributes.addElement(chromatinTC);
		attributes.addElement(chromatinCL);
		attributes.addElement(microarray);
		attributes.addElement(rnaSeq);
		attributes.addElement(classLabel);    

		// Create an empty training set with initial capacity edges.size()
		instances_ = new Instances("RedFly", attributes, numEdges_);           
		instances_.setClassIndex(attributes.size()-1);
		numPositives_ = 0;

		// Create the instances
		for (int i=0; i<numEdges_; i++) {
			double[] features = edges_.get(i).getFeatures();
			assert features.length == attributes.size() - 1;
			
			// Set the features
			Instance instance = new Instance(attributes.size());
			for (int f=0; f<features.length; f++)
				instance.setValue((Attribute)attributes.elementAt(f), features[f]);
			
			// Set the class label
			if (redfly_.contains(edges_.get(i))) {
				instance.setValue((Attribute)attributes.elementAt(attributes.size()-1), "1");
				numPositives_++;
			} else {
				instance.setValue((Attribute)attributes.elementAt(attributes.size()-1), "0");
			}
			// Add the instance
			instances_.add(instance);	
		}
		assert instances_.numInstances() == numEdges_;
		
		// Set the weights of the instances to balance the classes
		posWeight_ = 1 - numPositives_/(double)numEdges_;
		negWeight_ = numPositives_/(double)numEdges_;
		
		for (int i=0; i<instances_.numInstances(); i++) {
			Instance inst = instances_.instance(i);
			if (isPositive(inst))
				inst.setWeight(posWeight_);
			else
				inst.setWeight(negWeight_);
		}
	}

	
	// ----------------------------------------------------------------------------

	/** Run cross-validation */
	private void crossValidation(int iteration) {
		
//		testSetPredictions_ = new ArrayList<Network>();
//		testSetPredictions_.add(new Network("motif"));
//		testSetPredictions_.add(new Network("chip"));
//		testSetPredictions_.add(new Network("chromatinTC"));
//		testSetPredictions_.add(new Network("chromatinCL"));
//		testSetPredictions_.add(new Network("microarray"));
//		testSetPredictions_.add(new Network("rnaseq"));
//		testSetPredictions_.add(new Network("unsup"));
//		testSetPredictions_.add(new Network("sup"));
//		
//		testSetPositives_ = new Network("labels");
		
		// Following example: http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29
		// Create copy of original data
		Instances learningSet = new Instances(instances_);
		// Randomize
		learningSet.randomize(random_);
		// Stratify for CV (makes sure approximately same number of positives is in each fold)
		learningSet.stratify(folds_);

		// Perform CV
		for (int n=0; n<folds_; n++) {
			// Get the training and test sets
			Instances train = learningSet.trainCV(folds_, n);
			Instances test = learningSet.testCV(folds_, n);
			// Checks that stratification and weighting of instances is done to balance classes
			assert isBalanced(test);
			assert isBalanced(train);
			 
			// Train the classifier
			Logistic classifier = new Logistic();
			classifier.setRidge(0);
			try {
				classifier.buildClassifier(train);
			} catch (Exception e) {
				Flynet.error(e);
			}
			//System.out.println(classifier.toString());

			// Predict the test set
			double[] coeffs = coeffsAsVector(classifier.coefficients());
			coeffsCV_.add(coeffs);
//			predictTestSet(test, coeffs);
		}
		
//		// Do the recovery rates only for the first iteration
//		// (ideally, we would compute an avg recovery rate and 95% confidence intervals
//		// over all iterations)
//		if (iteration == 0)
//			testSetRecoveryRates();
	}
		 
	// ----------------------------------------------------------------------------

	/** 
	 * For some reason the weka learned coefficients are stored in a matrix with a 
	 * single column, let's convert to a vector 
	 */
	private double[] coeffsAsVector(double[][] coeffs) {

		double[] vect = new double[coeffs.length];
		for (int i=0; i<coeffs.length; i++) {
			assert coeffs[i].length == 1;
			vect[i] = coeffs[i][0];
		}
		return vect;
	}
		
	
	// ----------------------------------------------------------------------------

	/** Returns true if the instance is a positive (has class label "1") */
	private boolean isPositive(Instance instance) {

		String label = instance.stringValue(instance.classIndex()); 
		return label.equals("1");
	}
	

	// ----------------------------------------------------------------------------

	/** Returns true if the set of instances is balanced (sanity check) */
	private boolean isBalanced(Instances set) {
		
		double expectedNumPositives = set.numInstances() * numPositives_/(double)numEdges_;

		// Count the number of positives and the sum of the weights
		int numP = 0;
		double weightP = 0;
		double weightN = 0;

		for (int i=0; i<set.numInstances(); i++) {
			if (isPositive(set.instance(i))) {
				numP++;
				weightP += set.instance(i).weight();
			} else {
				weightN += set.instance(i).weight();
			}
		}
		
		double delta = Math.abs(weightP - weightN);
		if (Math.abs(expectedNumPositives - numP) >= 1)
			return false;
		if (delta >= posWeight_)
			return false;
		
		return true;
	}
	

	// ----------------------------------------------------------------------------

//	/**  */
//	private void predictTestSet(Instances testSet, double[] coeffs) {
//		
//		// A fictional TF
//		String tf = "NULL";
//		int numFeatures = testSet.numAttributes() - 1; // -1 because they include class labels
//		
//		assert numFeatures == 6; // to be removed, sanity check
//		assert numFeatures == coeffs.length-1;
//		
//		for (int i=0; i<testSet.numInstances(); i++) {
//			Instance instance = testSet.instance(i);
//			
//			// A unique, fictional target
//			String target = "EDGE_" + (testSetPredictions_.get(0).getNumEdges()+1);
//			
//			// Compute the supervised, unsupervised, and feature-specific prediction for this instance/edge
//			double sum = 0;			// Sum of weights for unsup
//			double z = coeffs[0];	// Weighted sum of weights for sup, coeffs[0] is the intercept
//			
//			// For each feature
//			int k;
//			for (k=0; k<numFeatures; k++) {
//				// The value of this feature for this instance
//				double x_k = instance.value(k);
//
//				// Update the sums
//				sum += x_k;
//				z += coeffs[k+1]*x_k;
//				
//				// Add the feature-specific prediction
//				testSetPredictions_.get(k).addEdge(tf, target, x_k);
//			}
//
//			// Add the unsupervised prediction
//			testSetPredictions_.get(k++).addEdge(tf, target, sum/numFeatures);
//
//			// Add the supervised prediction
//			double confidence = 1 / (1+Math.exp(-z));
//			testSetPredictions_.get(k++).addEdge(tf, target, confidence);
//			
//			// Add to the gold standard if it is a positive
//			if (isPositive(instance))
//				testSetPositives_.addEdge(tf, target, 1);
//		}
//	}

	
	// ----------------------------------------------------------------------------

	/** Write the CV coefficients to a file */
	private void writeCoeffsCV() {
		
		FileExport writer = new FileExport("coeffsCV.txt");
		for (int i=0; i<coeffsCV_.size(); i++) {
			double[] coeffs = coeffsCV_.get(i);
			
			String line = Double.toString(coeffs[0]);
			for (int c=1; c<coeffs.length; c++)
				line += "\t" + Double.toString(coeffs[c]);
			
			writer.println(line);
		}
		writer.close();
	}

	
	// ----------------------------------------------------------------------------

//	/** Compute and write the recovery rates of the feature specific and integrated networks */
//	private void testSetRecoveryRates() {
//		
//		RecoveryRate recovery = new RecoveryRate(testSetPositives_, testSetPredictions_);
//		recovery.run();
//	}

		
	// ----------------------------------------------------------------------------

	/** Some tests for the logistic regression */
	@SuppressWarnings("unused")
	public void tests() {
		
		// Following example from: http://weka.wikispaces.com/Programmatic+Use
		// Declare numeric attributes
		Attribute one = new Attribute("one");
		//Attribute two = new Attribute("two");

		// Declare the class attribute along with its values
		FastVector fvClassVal = new FastVector(2);
		fvClassVal.addElement("0");
		fvClassVal.addElement("1");
		Attribute classLabel = new Attribute("Class", fvClassVal);

		// Declare the feature vector
		FastVector attributes = new FastVector();
		attributes.addElement(one);    
		attributes.addElement(classLabel);    

		// Create an empty training set with initial capacity edges.size()
		Instances learningSet = new Instances("Learning", attributes, 10);           
		// Set class index
		learningSet.setClassIndex(attributes.size()-1);
		
		folds_ = 10;
		numPositives_ = 10;
		int numInstances = 100;
		double posWeight = 1 - numPositives_ / (double)numInstances;
		double negWeight = numPositives_ / (double)numInstances;
		
		// Create the instances
		for (int i=0; i<90; i++) {
			Instance negInstance = new Instance(attributes.size());
			negInstance.setValue((Attribute)attributes.elementAt(0), -(i+1));
			negInstance.setValue((Attribute)attributes.elementAt(attributes.size()-1), "0");
			negInstance.setWeight(negWeight);
			learningSet.add(negInstance);
		}
		
		for (int i=0; i<10; i++) {
			Instance posInstance = new Instance(attributes.size());
			posInstance.setValue((Attribute)attributes.elementAt(0), i+1);
			posInstance.setValue((Attribute)attributes.elementAt(attributes.size()-1), "1");
			posInstance.setWeight(posWeight);
			learningSet.add(posInstance);	
		}

		learningSet.randomize(new Random(randomSeed_));         // randomize data with number generator
		learningSet.stratify(folds_);

		Logistic classifier = new Logistic();
		classifier.setRidge(0);
		
		 for (int n=0; n<folds_; n++) {
			 Instances train = learningSet.trainCV(folds_, n);
			 Instances test = learningSet.testCV(folds_, n);

			 assert isBalanced(test);
			 assert isBalanced(train);
		 }
		 
		try {
			classifier.buildClassifier(learningSet);

			double result = classifier.classifyInstance(learningSet.instance(0));
			
			double[][] coeffs = classifier.coefficients();
			System.out.println(classifier.toString());
			System.out.println();
			
		} catch (Exception e) {
			Flynet.error(e);
		}
	}
    
}
