/*
Copyright (c) 2011-2012 Daniel Marbach(1,2)

(1) Massachusetts Institute of Technology, Cambridge MA, USA
(2) Broad Institute of MIT and Harvard
 
We release this software open source under an MIT license (see below). If this
software was useful for your scientific work, please cite our paper available at:
http://compbio.mit.edu/flynet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
 */
package edu.mit.compbio.flynet;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;


/**
 * Compute the recovery rate vs. number of edges considered (Fig. 2B of the paper).
 * Note, self-interactions are not considered.
 * <p>
 * For every position k in the edge list of each predicted network, the number of
 * gold standard edges found among the first k+1 edges is written to a tab-separated
 * file, together with the number expected by chance.
 */
public class RecoveryRate {

	/** The gold standard network */
	private final Network goldStandard_;
	/** The predicted networks; getEdge(k) of each network is treated as its k-th ranked edge */
	private final ArrayList<Network> networks_;
	
	
	// ============================================================================
	// PUBLIC METHODS
	    
	/** 
	 * Constructor loading the networks passed as command-line arguments in Settings.
	 * The networks are being filtered for TFs and targets that are in the gold standard.
	 * <p>
	 * Side effect: the regulators and targets of the gold standard are saved to two files
	 * in Settings.networkDir_ (gold standard filename with "_tfs" / "_tgs" inserted before
	 * the extension).
	 */
	public RecoveryRate() {

		goldStandard_ = new Network();
		goldStandard_.loadEdgesAboveThreshold(Settings.goldStandardNetwork_, 1, false);
		
		if (Settings.tgFile_ != null)
			goldStandard_.loadNodes(Settings.tgFile_, false);
		
		// The other networks will be filtered for the TFs and TGs of the gold standard
		String tffile = Flynet.insertBeforeFileExtension(Settings.goldStandardNetwork_, "_tfs");
		String tgfile = Flynet.insertBeforeFileExtension(Settings.goldStandardNetwork_, "_tgs");
		goldStandard_.saveRegulators(Settings.networkDir_ + "/" + tffile);
		goldStandard_.saveTargets(Settings.networkDir_ + "/" + tgfile);
		
		networks_ = new ArrayList<Network>();
		
		// NOTE(review): the TF/TG files are saved with the networkDir_ prefix but passed to
		// loadSubnet() without it -- presumably loadSubnet() resolves names relative to
		// Settings.networkDir_; confirm against Network.loadSubnet().
		for (String file : Settings.argFileList_) {
			Network net = new Network();
			net.loadSubnet(file, tffile, tgfile, false);
			// Bring the edges into their natural order, which defines the ranking used by run()
			// (replace by Collections.shuffle(net.getEdges()) for a randomized control)
			Collections.sort(net.getEdges());
			networks_.add(net);
		}
	}

	
	// ----------------------------------------------------------------------------

	/** 
	 * Constructor with pre-loaded networks.
	 * The given objects are stored as is (no copy, no filtering, no sorting of edges).
	 * 
	 * @param goldStandard the gold standard network
	 * @param networks the predicted networks to be evaluated against the gold standard
	 */
	public RecoveryRate(Network goldStandard, ArrayList<Network> networks) {
		goldStandard_ = goldStandard;
		networks_ = networks;
	}

	// ----------------------------------------------------------------------------

	/** 
	 * Compute and export the recovery rate curves.
	 * <p>
	 * The output file is written to Settings.networkDir_ and named after the gold standard
	 * with "_recovery" inserted before the file extension. After the header, line k contains
	 * (tab-separated): the fraction of possible edges considered, for each network the number
	 * of gold standard edges among its first k+1 edges, and the number expected by chance.
	 * The writer is closed even if an error occurs while writing.
	 */
	public void run() {
		
		String filename = Settings.networkDir_ + "/" + Flynet.insertBeforeFileExtension(goldStandard_.getName(), "_recovery");
		Flynet.println("Writing recovery rates to file: " + filename);
		
		FileExport writer = new FileExport(filename);
		try {
			writeFileHeader(writer);
			
			// The total number of possible edges in the gold standard "universe"
			// (considering only possible edges between actual TFs and targets of the gold standard)
			int numPossibleEdges = computeNumPossibleEdges();
			int numGoldStandardEdges = goldStandard_.getNumEdges();
			
			// The number of recovered gold standard edges at the current point in the list
			// (Java initializes int arrays to zero)
			int numNetworks = networks_.size();
			int[] numRecovered = new int[numNetworks];
			
			// For each position in the ranked list of edges
			for (int k=0; k<numPossibleEdges; k++) {
				
				// The next line to be written to the file
				StringBuilder nextLine = new StringBuilder();
				// Fraction of all possible edges considered so far
				double fractionConsidered = (k+1) / (double)numPossibleEdges;
				nextLine.append(fractionConsidered).append('\t');
				
				// Update the number of recovered edges
				for (int i=0; i<numNetworks; i++) {
					Network net = networks_.get(i);
					
					// Networks with fewer than k+1 edges simply keep their last count
					if (net.getNumEdges() > k) {
						Edge edge = net.getEdge(k);
						if (goldStandard_.contains(edge))
							numRecovered[i]++;
					}
					nextLine.append(numRecovered[i]).append('\t');
				}
				double expectedByChance = numGoldStandardEdges * fractionConsidered;
				nextLine.append(expectedByChance);
				
				// write the line to the file
				writer.println(nextLine.toString());
			}
		} finally {
			// Close the output stream, also when an exception interrupted the export
			writer.close();
		}
		
		// Optional analysis, disabled by default:
		//fractionPhysicallySupported();
	}

		
	// ============================================================================
	// PRIVATE METHODS

	/** 
	 * Write the header of the file: "#Fraction_considered", one column per network
	 * (named after the network), and a final "Random" column.
	 * 
	 * @param writer the open writer of the output file
	 */
	private void writeFileHeader(FileExport writer) {
		
		StringBuilder header = new StringBuilder("#");
		header.append("Fraction_considered");
		
		for (Network net : networks_)
			header.append("\t").append(net.getName());
		header.append("\tRandom");
		
		writer.println(header.toString());
	}
	
	
    // ----------------------------------------------------------------------------

	/**
	 * Compute the total number of possible edges in the gold standard "universe" without self-loops:
	 * (number of regulators) x (number of targets) - (number of regulators with incoming edges).
	 * 
	 * @return the number of possible edges
	 * @throws IllegalStateException if the result does not fit into an int
	 */
	private int computeNumPossibleEdges() {
		
		HashSet<Node> regulators = goldStandard_.getRegulators();
		
		// Count the number of regulators that are also targets (they have at least one incoming edge)
		int numRegulatorsThatAreTargets = 0;
		for (Node tf : regulators) {
			if (tf.getIncomingEdges().size() > 0)
				numRegulatorsThatAreTargets++;
		}
		
		// Number of possible edges minus the self-edges, computed as long so that the
		// product of the two sizes cannot silently wrap around
		long numPossibleEdges = (long)regulators.size() * goldStandard_.getTargets().size() - numRegulatorsThatAreTargets;
		if (numPossibleEdges > Integer.MAX_VALUE)
			throw new IllegalStateException("Number of possible edges exceeds the int range: " + numPossibleEdges);
		
		return (int)numPossibleEdges;
	}
	
	
	// ----------------------------------------------------------------------------

	/** 
	 * Fraction of edges supported by chip, motif, or both as we go down in the list of edges.
	 * Important, we assume the following order of the networks: unsup, sup, chip, motif, ... (further networks are ignored).
	 * <p>
	 * The cumulative counts per category (see getCategory()) are written to the file
	 * "fractionPhysicallySupported.tsv" in the working directory. This method is currently
	 * not called (the call in run() is commented out).
	 * 
	 * @throws IllegalStateException if fewer than four networks are available
	 */
	@SuppressWarnings("unused")
	private void fractionPhysicallySupported() {
		
		if (networks_.size() < 4)
			throw new IllegalStateException("Four networks (unsup, sup, chip, motif) are required, but only " + networks_.size() + " given");
		
		Network unsup = networks_.get(0);
		Network sup = networks_.get(1);
		Network chip = networks_.get(2);
		Network motif = networks_.get(3);

		Flynet.println("Computing fraction of edges supported by chip, motif, or both ...");
		Flynet.println("IMPORTANT: Networks must be given in the right order!");
		Flynet.println("Type\tPassed as argument (should match the type of network)");
		Flynet.println("Unsup\t" + unsup.getName());
		Flynet.println("Sup\t" + sup.getName());
		Flynet.println("Chip\t" + chip.getName());
		Flynet.println("Motif\t" + motif.getName());

		// Both lists are traversed in parallel, so they must have the same length
		int numEdges = unsup.getNumEdges();
		assert sup.getNumEdges() == numEdges;
		
		// Number of edges supported by chip, chip&motif, motif, and the rest
		// (Java initializes int arrays to zero)
		int[] unsup_count = new int[4];
		int[] sup_count = new int[4];
		
		String filename = "fractionPhysicallySupported.tsv";
		Flynet.println("Writing fraction of physically supported edges to file: " + filename);
		FileExport writer = new FileExport(filename);
		
		try {
			String header = "#Fraction_considered\tunsup_chip\tunsup_chipAndMotif\tunsup_motif\tunsup_rest\t" +
					"sup_chip\tsup_chipAndMotif\tsup_motif\tsup_rest";
			writer.println(header);
			
			int numPossibleEdges = computeNumPossibleEdges();
			
			// As we go down the list, write the number of edges for each category
			for (int k=0; k<numEdges; k++) {

				countSupport(unsup.getEdge(k), chip, motif, unsup_count);
				countSupport(sup.getEdge(k), chip, motif, sup_count);
				
				// The next line to be written to the file
				StringBuilder nextLine = new StringBuilder();
				
				// Fraction of all possible edges considered so far
				double fractionConsidered = (k+1) / (double)numPossibleEdges;
				nextLine.append(fractionConsidered);

				for (int c=0; c<4; c++)
					nextLine.append("\t").append(unsup_count[c]);
				for (int c=0; c<4; c++)
					nextLine.append("\t").append(sup_count[c]);
				writer.println(nextLine.toString());
			}
		} finally {
			// Close the output stream, also when an exception interrupted the export
			writer.close();
		}
	}

	
	// ----------------------------------------------------------------------------

	/**
	 * If the given edge is part of the gold standard, increment the counter of the
	 * category (see getCategory()) defined by its weights in the chip and motif networks.
	 * Edges that are not in the gold standard are ignored.
	 * 
	 * @param edge the edge to be categorized
	 * @param chip the network providing the chip weight of the edge
	 * @param motif the network providing the motif weight of the edge
	 * @param counts the four category counters, updated in place
	 */
	private void countSupport(Edge edge, Network chip, Network motif, int[] counts) {
		
		if (!goldStandard_.contains(edge))
			return;
		
		// NOTE(review): assumes chip and motif both contain every gold standard edge of the
		// ranked lists; if getEdge(edge) can return null this throws -- confirm in Network
		double chipWeight = chip.getEdge(edge).getWeight();
		double motifWeight = motif.getEdge(edge).getWeight();
		
		// Thresholds as in the original analysis: chip weight of (numerically) at least 1,
		// motif weight strictly above 0.1. Alternative based on edge presence only:
		//   getCategory(chip.contains(edge), motif.contains(edge))
		int category = getCategory(chipWeight>1-1e-6, motifWeight>0.1+1e-6);
		counts[category]++;
	}
	
	
	// ----------------------------------------------------------------------------

	/** 
	 * Map the two kinds of physical support to a category index.
	 * Returns:
	 * 0: chip
	 * 1: chip&motif
	 * 2: motif
	 * 3: rest
	 * 
	 * @param isChip true if the edge is supported by chip
	 * @param isMotif true if the edge is supported by motif
	 * @return the category index in [0, 3]
	 */
	private static int getCategory(boolean isChip, boolean isMotif) {
		
		if (isChip)
			return isMotif ? 1 : 0;
		return isMotif ? 2 : 3;
	}

	
}
