/*
Copyright (c) 2011-2012 Daniel Marbach(1,2)

(1) Massachusetts Institute of Technology, Cambridge MA, USA
(2) Broad Institute of MIT and Harvard
 
We release this software open source under an MIT license (see below). If this
software was useful for your scientific work, please cite our paper available at:
http://compbio.mit.edu/flynet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
 */
package edu.mit.compbio.flynet;

import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;


/**
 * A directed, weighted regulatory network.
 * Nodes are stored by label; every edge goes from a regulator (TF) to a target.
 * The network also keeps the set of regulators and the set of targets. Nodes enter these sets
 * through added edges or through node lists read by loadNodes(). A node may be both a
 * regulator and a target.
 */
public class Network {

	/** Default value of minEdgeWeight_, used as "no minimal weight specified" */
	private static final double NO_MIN_EDGE_WEIGHT = -1e12;

	/** The name of the network */
	private String name_ = null;
	/** The nodes are stored in a hash map with their labels as key */
	private HashMap<String, Node> nodes_ = null;
	/** The edges */
	private ArrayList<Edge> edges_ = null;

	/** The set of regulators (nodes with at least one target). Regulators may also be targets. */
	private HashSet<Node> regulators_ = null;
	/** The set of targets (nodes with at least one regulator). Targets may also be regulators. */
	private HashSet<Node> targets_ = null;
	
	/** 
	 * Threshold for considering edges present (only edges with weight greater or equal this
	 * threshold are added). Defaults to a large negative value, i.e., edges are added
	 * irrespective of their value. Not reset by initialize().
	 */
	private double minEdgeWeight_ = NO_MIN_EDGE_WEIGHT;
	
	/** Number of redundant (multi) edges that were removed in a loaded network */
	private int numMultiEdgesRemoved_ = 0;
	/** Number of self loops that were removed in a loaded network */
	private int numSelfEdgesRemoved_ = 0;

	
	// ============================================================================
	// PUBLIC METHODS
	
	/** Constructor */
	public Network() {
		initialize();
	}
	
	/** Constructor with network name */
	public Network(String name) {
		initialize();
		name_ = name;
	}

	
	// ----------------------------------------------------------------------------

	/**
	 * Initialize an empty network.
	 * Clears nodes, edges, regulators, targets and the removal counters. The name and the edge
	 * weight threshold are left unchanged.
	 */
	public void initialize() {
		
		nodes_ = new HashMap<String, Node>();
		edges_ = new ArrayList<Edge>();
		regulators_ = new HashSet<Node>();
		targets_ = new HashSet<Node>();
		
		numMultiEdgesRemoved_ = 0;
		numSelfEdgesRemoved_ = 0;
	}

	
	// ----------------------------------------------------------------------------

	/**
	 * Calls load(), but adds only edges with weight greater or equal the given threshold.
	 * The threshold stays in effect for subsequent loads and addEdge() calls on this network.
	 */
	public void loadEdgesAboveThreshold(String filename, double minWeight, boolean header) {
		
		minEdgeWeight_ = minWeight;
		load(filename, header);
	}

		
	// ----------------------------------------------------------------------------

	/** 
	 * Load the network from the given file (relative to the network directory specified in Settings).
	 * Any previous content of the network is discarded.
	 * Each line is {@code <TF> \t <target> \t <weight>}, or {@code <TF> \t <target>} (weight 1).
	 * Self loops are skipped. Edges with weight below the current threshold are skipped, but
	 * their nodes are still created (see loadEdgesAboveThreshold()). Repeated edges are merged,
	 * keeping the maximum weight.
	 * If the file has a header line, set header=true.
	 * @throws RuntimeException if a line has neither two nor three columns
	 */
	public void load(String filename, boolean header) {
		
		initialize();
		name_ = filename;
		readEdgeFile(filename, header, false);
	}
	
	
	// ----------------------------------------------------------------------------

	/** 
	 * Same as load(), but filter for the given list of regulators and targets, i.e., only
	 * load edges between the specified regulators and targets. No nodes are created from the
	 * edge file; all nodes come from regulatorFile and targetFile.
	 */
	public void loadSubnet(String filename, String regulatorFile, String targetFile, boolean header) {
		
		initialize();
		name_ = filename;
		
		// Load the nodes first: edges are only accepted between these regulators and targets
		loadNodes(regulatorFile, targetFile);
		readEdgeFile(filename, header, true);
	}

	
	// ----------------------------------------------------------------------------

	/** Load a list of TFs and targets, add them to nodes_, and to regulators_ and targets_, respectively */
	public void loadNodes(String regulatorFile, String targetFile) {
		
		loadNodes(regulatorFile, true);
		loadNodes(targetFile, false);
		Flynet.println(regulators_.size() + "\tTFs");
		Flynet.println(targets_.size() + "\ttarget genes (may also be TFs)\n");		
	}

	
	// ----------------------------------------------------------------------------

	/** Display number of nodes and edges on the console */
	public void printInfo() {
		
		Flynet.println(regulators_.size() + "\tTFs");
		Flynet.println(targets_.size() + "\ttarget genes (may also be TFs)");
		Flynet.println(edges_.size() + "\tedges");
		Flynet.println("");
	}


	// ----------------------------------------------------------------------------

	/** 
	 * Load a list of nodes (one label per line, relative to Settings.networkDir_), add them to
	 * nodes_, and either to regulators_ or targets_ (depending on the flag isRegulator).
	 * Nodes that already exist are reused.
	 * @throws RuntimeException if a line does not have exactly one column
	 */
	public void loadNodes(String filename, boolean isRegulator) {
		
		String path = Settings.networkDir_ + "/" + filename;
		Flynet.println("Reading file: " + path);
		
		FileParser parser = new FileParser(path);
		try {
			String[] nextLine = parser.readLine();
			
			while (nextLine != null) {
				if (nextLine.length != 1)
					throw new RuntimeException("Parse error at line " + parser.getLineCounter() + ": expected one column");
				
				Node node = getOrCreateNode(nextLine[0]);
				if (isRegulator)
					regulators_.add(node);
				else
					targets_.add(node);
				
				nextLine = parser.readLine();
			}
		} finally {
			// Close the file also when a parse error is thrown
			parser.close();
		}
	}

	
	// ----------------------------------------------------------------------------

	/** 
	 * Filter the network: remove all regulators and targets that are *not* in the
	 * provided lists of regulators and targets (corresponding edges are also removed)
	 */
	public void filter(String regulatorFile, String targetFile) {
		
		// Load the lists of nodes to keep
		Network filter = new Network();
		filter.loadNodes(regulatorFile, targetFile);
		HashMap<String, Node> filterNodes = filter.getNodes();
		HashSet<Node> filterTargets = filter.getTargets();
		HashSet<Node> filterRegulators = filter.getRegulators();
		
		// Iterate over a snapshot: the remove*() methods modify nodes_, which would break
		// an iterator over nodes_.values() (ConcurrentModificationException)
		Node[] nodeArray = nodes_.values().toArray(new Node[0]);
		for (Node node : nodeArray) {
			
			String label = node.getLabel();
			Node filterNode = filterNodes.get(label);
			
			// If the node is not part of the loaded lists
			if (filterNode == null)
				removeNode(label);
			// If the node is not part of the loaded targets
			else if (!filterTargets.contains(filterNode))
				removeTarget(label);
			// If the node is not part of the loaded regulators
			else if (!filterRegulators.contains(filterNode))
				removeRegulator(label);
		}
	}

	
	// ----------------------------------------------------------------------------

	/** Remove the given node from the network */
	public void removeNode(String label) {
		
		removeRegulator(label);
		removeTarget(label);
	}

		
	// ----------------------------------------------------------------------------

	/** 
	 * Remove the given regulator from the network (keep it as target, if it is one).
	 * Its outgoing edges are removed from the edge list, from the incoming-edge lists of
	 * their targets, and from the node itself, so getEdge()/contains() no longer find them.
	 */
	public void removeRegulator(String label) {
		
		Node node = nodes_.get(label);
		if (node == null)
			return;
		
		if (!targets_.contains(node))
			nodes_.remove(label);
		regulators_.remove(node);
		
		// NOTE(review): assumes getOutgoingEdges()/getIncomingEdges() return the node's live lists
		ArrayList<Edge> outgoing = node.getOutgoingEdges();
		edges_.removeAll(outgoing);
		for (Edge edge : outgoing)
			edge.getTarget().getIncomingEdges().remove(edge);
		outgoing.clear();
	}

	
	// ----------------------------------------------------------------------------

	/** 
	 * Remove the given target from the network (keep it as regulator, if it is one).
	 * Its incoming edges are removed from the edge list, from the outgoing-edge lists of
	 * their regulators, and from the node itself, so getEdge()/contains() no longer find them.
	 */
	public void removeTarget(String label) {
		
		Node node = nodes_.get(label);
		if (node == null)
			return;
		
		if (!regulators_.contains(node))
			nodes_.remove(label);
		targets_.remove(node);
		
		// NOTE(review): assumes getOutgoingEdges()/getIncomingEdges() return the node's live lists
		ArrayList<Edge> incoming = node.getIncomingEdges();
		edges_.removeAll(incoming);
		for (Edge edge : incoming)
			edge.getTf().getOutgoingEdges().remove(edge);
		incoming.clear();
	}

	
	// ----------------------------------------------------------------------------

	/** 
	 * Write the network to the given file.
	 * The format is: {@code <TF> \t <target> \t <weight>}, followed by one tab-separated column
	 * per edge feature, if any. Numbers are written with five decimals and a '.' decimal
	 * separator, independent of the default locale, so that load() can read them back.
	 */
	public void save(String filename) {
		
		// Locale.US symbols: a default-locale format may write "0,50000", which
		// Double.parseDouble() in load() cannot parse
		DecimalFormat fiveDec = new DecimalFormat("0.00000", new DecimalFormatSymbols(Locale.US));
		fiveDec.setGroupingUsed(false);
		
		FileExport writer = new FileExport(filename);
		try {
			for (Edge edge : edges_) {
				StringBuilder line = new StringBuilder();
				line.append(edge.getTf().getLabel()).append('\t');
				line.append(edge.getTarget().getLabel()).append('\t');
				line.append(fiveDec.format(edge.getWeight()));
				
				// append the features, if any
				double[] features = edge.getFeatures();
				if (features != null)
					for (double feature : features)
						line.append('\t').append(fiveDec.format(feature));
				
				writer.println(line.toString());
			}
		} finally {
			writer.close();
		}
	}

		
	// ----------------------------------------------------------------------------

	/** Write the regulators to the given file (one label per line) */
	public void saveRegulators(String filename) {
		
		saveLabels(regulators_, filename);
	}


	// ----------------------------------------------------------------------------

	/** Write the targets to the given file (one label per line) */
	public void saveTargets(String filename) {
		
		saveLabels(targets_, filename);
	}

	
	// ----------------------------------------------------------------------------

	/** Returns true if this network contains an edge with the same TF and target labels as the given edge */
	public boolean contains(Edge edge) {
		
		return getEdge(edge) != null;
	}
	
	
	// ----------------------------------------------------------------------------

	/** Returns the corresponding edge of this network (same regulator/target labels), null if not found */
	public Edge getEdge(Edge edge) {
		
		return getEdge(edge.getTf().getLabel(), edge.getTarget().getLabel());
	}

	
	// ----------------------------------------------------------------------------

	/** Returns the edge from the given regulator label to the given target label, null if not found */
	public Edge getEdge(String regulatorId, String targetId) {
		
		Node target = nodes_.get(targetId);
		if (target == null)
			return null;
		
		// Scan the (usually short) list of incoming edges of the target
		for (Edge edge : target.getIncomingEdges())
			if (edge.getTf().getLabel().equals(regulatorId))
				return edge;
		
		return null;
	}


	// ----------------------------------------------------------------------------

	/**
	 * Add a copy of the given edge (creates a new edge/nodes, see addEdge()).
	 * Only the labels and the weight are copied; features are not.
	 */
	public void addCopyOfEdge(Edge edge) {
	
		addEdge(edge.getTf().getLabel(), edge.getTarget().getLabel(), edge.getWeight());
	}

	
	// ----------------------------------------------------------------------------

	/**
	 * Add an edge (and the corresponding nodes, if they are not yet present in the network).
	 * Self loops are not added; they are counted in numSelfEdgesRemoved_.
	 */
	public void addEdge(String tfLabel, String targetLabel, double weight) {

		if (tfLabel.equals(targetLabel)) {
			numSelfEdgesRemoved_++;
			return;
		}
		
		// Get (or create) the nodes
		Node tf = getOrCreateNode(tfLabel);
		Node target = getOrCreateNode(targetLabel);
		
		addEdge(tf, target, weight);
	}
	

	// ----------------------------------------------------------------------------

	/**
	 * Add an edge, but only if this TF and target are already in the list of regulators and
	 * targets, respectively. Self loops are not added; they are counted in numSelfEdgesRemoved_.
	 */
	public void addEdgeBetweenExistingNodes(String tfLabel, String targetLabel, double weight) {

		if (tfLabel.equals(targetLabel)) {
			numSelfEdgesRemoved_++;
			return;
		}
		
		// Get the nodes
		Node tf = nodes_.get(tfLabel);
		if (tf == null || !regulators_.contains(tf))
			return;
		
		Node target = nodes_.get(targetLabel);
		if (target == null || !targets_.contains(target))
			return;

		addEdge(tf, target, weight);
	}
	

	// ----------------------------------------------------------------------------

	/**
	 * Add an edge between the given nodes.
	 * The nodes are NOT put into nodes_; callers are expected to pass nodes of this network.
	 * Edges with weight below the threshold are ignored. If an edge with the same labels
	 * already exists, no new edge is added: the existing edge keeps the maximum of the two
	 * weights and numMultiEdgesRemoved_ is incremented. Otherwise the TF is added to the
	 * regulators and the target to the targets.
	 */
	public void addEdge(Node tf, Node target, double weight) {

		// Don't add edges with weight smaller than the threshold
		if (weight < minEdgeWeight_)
			return;

		// If the edge is already present, don't add it and set the weight to the max of the two
		Edge existingEdge = getEdge(tf.getLabel(), target.getLabel());
		if (existingEdge != null) {
			if (existingEdge.getWeight() < weight)
				existingEdge.setWeight(weight);
			numMultiEdgesRemoved_++;
			return;
		}
		
		Edge edge = new Edge(tf, target, weight);
		edges_.add(edge);
		
		// Add it as incoming/outgoing edge of the target/TF
		target.addIncomingEdge(edge);
		tf.addOutgoingEdge(edge);
		
		// Add the TF to the set of regulators, and the target to the set of targets
		regulators_.add(tf);
		targets_.add(target);
	}
	

	// ----------------------------------------------------------------------------

	/**
	 * Randomize the network by shuffling the labels of the TFs (regulators_), and separately
	 * the labels of the targets that are not TFs. The edge structure is unchanged.
	 */
	public void randomize() {
		
		// Targets that are not in the regulator set. Checking the set (rather than the number
		// of outgoing edges) keeps the two groups disjoint, e.g., for regulators loaded with
		// loadNodes() that have no edges.
		HashSet<Node> notRegulators = new HashSet<Node>();
		for (Node gene : targets_)
			if (!regulators_.contains(gene))
				notRegulators.add(gene);
		
		// "<=" because nodes created for edges below the weight threshold are in nodes_
		// but in neither set
		assert notRegulators.size() + regulators_.size() <= nodes_.size();
		
		randomizeNodeLabels(regulators_);
		randomizeNodeLabels(notRegulators);
	}

	
	// ----------------------------------------------------------------------------

	/**
	 * Randomize the labels of the given set of nodes (a random permutation among them).
	 * Nodes of this network are re-keyed in nodes_ under their new labels, so that
	 * lookups by label (getEdge(), contains(), remove*()) stay consistent.
	 * NOTE(review): if Node.hashCode() depends on the label, HashSets containing the
	 * relabeled nodes would need rehashing -- confirm against Node.
	 */
	public void randomizeNodeLabels(HashSet<Node> nodes) {
		
		// Snapshot the nodes in iteration order, and their labels in the same order
		ArrayList<Node> nodeList = new ArrayList<Node>(nodes);
		ArrayList<String> randomizedLabels = new ArrayList<String>(nodeList.size());
		for (Node node : nodeList)
			randomizedLabels.add(node.getLabel());
		Collections.shuffle(randomizedLabels);

		// Detach the nodes that belong to this network from their old keys
		ArrayList<Node> ownNodes = new ArrayList<Node>();
		for (Node node : nodeList) {
			if (nodes_.get(node.getLabel()) == node) {
				nodes_.remove(node.getLabel());
				ownNodes.add(node);
			}
		}
		
		for (int i = 0; i < nodeList.size(); i++)
			nodeList.get(i).setLabel(randomizedLabels.get(i));
		
		// Re-insert them under their new labels
		for (Node node : ownNodes)
			nodes_.put(node.getLabel(), node);
	}

		
	// ============================================================================
	// PRIVATE METHODS

	/** Get the node if it exists, otherwise create a new one */
	private Node getOrCreateNode(String label) {
		
		Node node = nodes_.get(label);
		// If the node doesn't exist, create a new one
		if (node == null) {
			node = new Node(label);
			nodes_.put(label, node);
		}
	
		return node;
	}
	
	
	// ----------------------------------------------------------------------------

	/**
	 * Read an edge file (relative to Settings.networkDir_) and add its edges.
	 * Then sort the edges and print a summary.
	 * Lines must have two ({@code <TF> \t <target>}, weight 1) or three
	 * ({@code <TF> \t <target> \t <weight>}) columns.
	 * @param existingNodesOnly if true, edges are added with addEdgeBetweenExistingNodes(),
	 *        otherwise with addEdge(), which creates missing nodes
	 * @throws RuntimeException if a line has neither two nor three columns
	 */
	private void readEdgeFile(String filename, boolean header, boolean existingNodesOnly) {
		
		String path = Settings.networkDir_ + "/" + filename;
		Flynet.println("Reading file: " + path);
		
		FileParser parser = new FileParser(path);
		try {
			String[] nextLine = parser.readLine();
			
			// Skip the header line of the file
			if (header)
				nextLine = parser.readLine();
			
			while (nextLine != null) {
				double weight;
				if (nextLine.length == 3)
					weight = Double.parseDouble(nextLine[2]);
				else if (nextLine.length == 2)
					weight = 1;
				else
					throw new RuntimeException("Parse error at line " + parser.getLineCounter() + ": expected two or three columns");
				
				if (existingNodesOnly)
					addEdgeBetweenExistingNodes(nextLine[0], nextLine[1], weight);
				else
					addEdge(nextLine[0], nextLine[1], weight);
				
				nextLine = parser.readLine();
			}
		} finally {
			// Close the file also when a parse error is thrown
			parser.close();
		}
		
		// Sort the edges according to their weight
		Collections.sort(edges_);
		
		printLoadSummary();
	}
	
	
	// ----------------------------------------------------------------------------

	/** Print the number of TFs, targets, edges and removed multi edges / self loops */
	private void printLoadSummary() {
		
		Flynet.println(regulators_.size() + "\tTFs");
		Flynet.println(targets_.size() + "\ttarget genes (may also be TFs)");
		if (minEdgeWeight_ == NO_MIN_EDGE_WEIGHT)
			Flynet.println(edges_.size() + "\tunique edges (no minimal weight specified)");
		else
			Flynet.println(edges_.size() + "\tunique edges (minimal weight considered: " + minEdgeWeight_ + ")");
		
		if (numMultiEdgesRemoved_ > 0)
			Flynet.println(numMultiEdgesRemoved_ + "\tmulti edges have been removed");
		if (numSelfEdgesRemoved_ > 0)
			Flynet.println(numSelfEdgesRemoved_ + "\tself loops have been removed");
		
		Flynet.println("");
	}
	
	
	// ----------------------------------------------------------------------------

	/** Write the labels of the given nodes to the given file, one per line */
	private void saveLabels(HashSet<Node> nodes, String filename) {
		
		FileExport writer = new FileExport(filename);
		try {
			for (Node node : nodes)
				writer.println(node.getLabel());
		} finally {
			writer.close();
		}
	}

		
	// ============================================================================
	// SETTERS AND GETTERS

	public void setName(String name) { name_ = name; }
	public String getName() { return name_; }
	
	public HashMap<String, Node> getNodes() { return nodes_; }
	public ArrayList<Edge> getEdges() { return edges_; }
	public HashSet<Node> getRegulators() { return regulators_; }
	public HashSet<Node> getTargets() { return targets_; }
	
	public int getNumEdges() { return edges_.size(); }
	public Edge getEdge(int i) { return edges_.get(i); }
   
}
