/*
Copyright (c) 2011-2012 Daniel Marbach(1,2)

(1) Massachusetts Institute of Technology, Cambridge MA, USA
(2) Broad Institute of MIT and Harvard
 
We release this software open source under an MIT license (see below). If this
software was useful for your scientific work, please cite our paper available at:
http://compbio.mit.edu/flynet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
 */
package edu.mit.compbio.flynet;

import java.util.HashMap;
import java.util.Iterator;

import jsc.util.Rank;


/**
 * Gene expression dataset: one numeric expression profile per gene, all of the
 * same length. The functional relation between two genes is the squared
 * correlation (coefficient of determination, R^2) of their profiles.
 * <p>
 * Usage: construct, then call {@link #load(String)} before any query method.
 */
public class GeneExpression implements FunctionalDataset {

	/** The name of this dataset (by default the filename up to its first '.') */
	private String name_ = null;
	/** The expression profile for each gene, keyed by gene label */
	private HashMap<String, double[]> profiles_ = null;
	/** The number of measurements for each gene (length of the profiles) */
	private int numMeasurements_ = -1;

	/**
	 * For each gene a two-element array: [0] the sum and [1] the sum of squares
	 * of its profile (precomputed to efficiently compute correlations)
	 */
	private HashMap<String, double[]> sums_ = null;
	
	/** If true, load() rank transforms every profile (set by the constructor) */
	private boolean rankTransform_ = false;
	
	
	// ============================================================================
	// PUBLIC METHODS
	
	/**
	 * Constructor.
	 * @param rankTransform true to rank transform the profiles when they are loaded
	 */
	public GeneExpression(boolean rankTransform) {
		rankTransform_ = rankTransform;
	}
	
	
	// ----------------------------------------------------------------------------

	/** 
	 * Load the expression profiles from the given file and initialize sums_.
	 * The data is rank transformed first if the flag was set in the constructor,
	 * so the sums are computed on the transformed values.
	 * 
	 * @param filename name of the file, relative to Settings.networkDir_
	 * @throws IllegalArgumentException if the file contains no profiles, if the
	 *         profiles differ in length, or if a gene is listed more than once
	 */
	public void load(String filename) {

		loadData(filename);
		
		if (rankTransform_)
			rankTransformData();
		initializeSums();
		
		Flynet.println(profiles_.size() + "\tgene profiles");
		Flynet.println(numMeasurements_ + "\tvalues each\n");
	}

	
	// ----------------------------------------------------------------------------

	/** Returns true if the given gene is part of the expression dataset */
	public boolean contains(Node gene) {
		return profiles_.containsKey(gene.getLabel());
	}


	// ----------------------------------------------------------------------------

	/** 
	 * Returns the squared correlation (coefficient of determination R^2) of the
	 * expression profiles of the two genes. Returns 0 if either profile has no
	 * variance, in which case the correlation is undefined.
	 * 
	 * @throws RuntimeException if one of the genes is not part of this dataset
	 */
	public double functionalRelation(Node gene1, Node gene2) {

		String label1 = gene1.getLabel();
		String label2 = gene2.getLabel();
		
		double[] profile1 = profiles_.get(label1);
		double[] profile2 = profiles_.get(label2);
		
		if (profile1 == null || profile2 == null)
			throw new RuntimeException("One of the given genes is not part of this dataset");

		// sums_ has an entry for every gene in profiles_ (see initializeSums)
		double[] sum1 = sums_.get(label1);
		double[] sum2 = sums_.get(label2);

		double sumOfProducts = 0;
		for (int i=0; i<numMeasurements_; i++)
			sumOfProducts += profile1[i] * profile2[i];
		
		// Numerator and denominators of the correlation in their "computational" form:
		// sum(xy) - sum(x)sum(y)/n  and  sum(x^2) - sum(x)^2/n
		double num =  sumOfProducts - (sum1[0]*sum2[0])/numMeasurements_;
		double den1 = sum1[1] - (sum1[0]*sum1[0])/numMeasurements_;
		double den2 = sum2[1] - (sum2[0]*sum2[0])/numMeasurements_;
		
		// The denominators are (scaled) variances and thus mathematically >= 0. Rounding can
		// make them slightly negative for a constant profile, so test <= 0 rather than == 0.
		if (den1 <= 0 || den2 <= 0)
			return 0;
		
		return (num*num) / (den1*den2);
	}
	
	
	// ============================================================================
	// PRIVATE FUNCTIONS
	
	/** 
	 * Load the expression profiles from the given file. The first line is assumed to be
	 * the header and is skipped; in every following line the first column is the gene
	 * name and the remaining columns are its expression values.
	 */
	private void loadData(String filename) {

		// The dataset name is the filename up to the first '.', or the whole filename if there is none
		int dot = filename.indexOf('.');
		name_ = (dot < 0) ? filename : filename.substring(0, dot);
		profiles_ = new HashMap<String, double[]>();
		
		String path = Settings.networkDir_ + "/" + filename;
		Flynet.println("Reading file: " + path);
		
		FileParser parser = new FileParser(path);
		try {
			// Read and discard the header
			parser.readLine();
			
			// The first data line defines the expected profile length
			String[] nextLine = parser.readLine();
			if (nextLine == null)
				throw new IllegalArgumentException("The file contains no expression profiles: " + path);
			numMeasurements_ = nextLine.length - 1; // the first col is the gene name 
			
			while (nextLine != null) {
				if (numMeasurements_ != nextLine.length-1)
					throw new IllegalArgumentException("The expression profiles must all have the same length");
				
				String gene = nextLine[0];
				// Every gene should be listed only once in the file
				if (profiles_.containsKey(gene))
					throw new IllegalArgumentException("A gene must have a single profile");
				
				// The expression profile of the gene
				double[] expr = new double[numMeasurements_];
				for (int i=0; i<numMeasurements_; i++)
					expr[i] = Double.parseDouble(nextLine[i+1]);
				
				profiles_.put(gene, expr);
				nextLine = parser.readLine();
			}
		} finally {
			// Release the file also when a malformed line aborts the loading
			parser.close();
		}
	}


	// ----------------------------------------------------------------------------

	/** Initialize the sum and sum of squares for each gene (sums_) */
	private void initializeSums() {

		sums_ = new HashMap<String, double[]>();
		
		for (String gene : profiles_.keySet()) {
			double[] expr = profiles_.get(gene);
			
			double total = 0;
			double totalOfSquares = 0;
			for (int i=0; i<numMeasurements_; i++) {
				total += expr[i];
				totalOfSquares += expr[i]*expr[i];
			}
			sums_.put(gene, new double[] { total, totalOfSquares });
		}
	}
    
	
	// ----------------------------------------------------------------------------

	/** Rank transform the profile of each gene (the profiles are overwritten in place) */
	private void rankTransformData() {

		for (double[] profile : profiles_.values()) {
			Rank rank = new Rank(profile, 0);
			double[] rankedValues = rank.getRanks();
			assert rankedValues.length == numMeasurements_;
			
			// Overwrite the stored profile with its ranks
			System.arraycopy(rankedValues, 0, profile, 0, numMeasurements_);
		}
	}

	// ============================================================================
	// SETTERS AND GETTERS
	
	public void setName(String name) { name_ = name; }
	public String getName() { return name_; }
		
}
