/*
Copyright 2007 Daniel Zerbino (zerbino@ebi.ac.uk)

    This file is part of Velvet.

    Velvet is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    Velvet is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Velvet; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

*/

#include <stdlib.h>
#include <stdio.h>

#include "globals.h"
#include "graph.h"
#include "concatenatedGraph.h"
#include "recycleBin.h"
#include "locallyCorrectedGraph.h"
#include "passageMarker.h"

// Second argument handed to newRecycleBin() for the Connection pool
// (presumably the number of records allocated per chunk — confirm in
// recycleBin.c).
#define BLOCK_SIZE  1000
// Connections to anchors whose count is at or below this value are
// discarded by removeUnreliableConnections().
#define UNRELIABLE_CONNECTION_CUTOFF 5

// Module state. It is set by exploitShortReadPairs() and shared by every
// helper below, so this file is not reentrant.
static Graph *graph;
// Pool from which all Connection records are drawn; created lazily by
// allocateConnection().
static RecycleBin *connectionMemory = NULL;
// Nodes strictly longer than this are anchors (see isAnchor()).
// exploitShortReadPairs() overwrites it with the caller's cutoff.
static Coordinate LONG_NODE_CUTOFF = 250;
// Per node index (nodeID + nodeCount(graph)) flags filled in by
// computeNodeToNodeMappings(). True when an anchor has a pair link to
// another anchor that is not its own twin.
// NOTE(review): no reader of this array is visible in this file.
static boolean *isPaired;

typedef struct connection_st Connection;

// Singly linked list entry recording read-pair support from the list owner
// (scaffold slot nodeID + nodeCount(graph)) to `destination`.
struct connection_st {
	Node *destination;	// node reached through paired reads
	IDnum count;		// support accumulated by createConnection()
	Connection *next;	// next entry of the same owner's list
};

/*
 * Anchor test: a node qualifies only when its length is strictly greater
 * than LONG_NODE_CUTOFF.
 */
static boolean isAnchor(Node * node)
{
	if (getNodeLength(node) <= LONG_NODE_CUTOFF)
		return false;

	return true;
}

/*
 * Hands out one Connection record from the module-wide recycle bin.
 * The bin is created on the first call.
 */
static Connection *allocateConnection()
{
	if (connectionMemory != NULL)
		return allocatePointer(connectionMemory);

	connectionMemory = newRecycleBin(sizeof(Connection), BLOCK_SIZE);
	return allocatePointer(connectionMemory);
}

/*
 * Gives a Connection record back to the recycle bin it was drawn from.
 */
static void deallocateConnection(Connection * connect)
{
	RecycleBin *bin = connectionMemory;

	deallocatePointer(bin, connect);
}

/*
 * Counts, for every read ID, how many node-to-read entries reference it.
 *
 * nodeReads / nodeReadCounts are indexed by nodeID + nodeCount(graph) and
 * cover 2 * nodeCount(graph) + 1 slots.
 * Returns a zero-initialised array of sequenceCount(graph) + 1 counters,
 * indexed by read ID. The caller owns it.
 */
static IDnum *computeReadToNodeCounts(IDnum ** nodeReads,
				      IDnum * nodeReadCounts)
{
	IDnum nodeSlots = 2 * nodeCount(graph) + 1;
	IDnum *counts = calloc(sequenceCount(graph) + 1, sizeof(IDnum));
	IDnum slot, entry;

	for (slot = 0; slot < nodeSlots; slot++) {
		IDnum *reads = nodeReads[slot];
		IDnum total = nodeReadCounts[slot];

		for (entry = 0; entry < total; entry++)
			counts[reads[entry]] += 1;
	}

	return counts;
}

/*
 * Inverts the node-to-read mapping.
 *
 * Returns an array indexed by read ID (1 .. sequenceCount(graph)). Each
 * entry is an array of the node IDs that carry that read, or NULL when the
 * read is on no node. The caller owns the result.
 * readNodeCounts must hold the per-read totals from
 * computeReadToNodeCounts(). It is reset and used as a fill cursor, so on
 * return it again holds those totals.
 */
static IDnum **computeReadToNodeMappings(IDnum ** nodeReads,
					 IDnum * nodeReadCounts,
					 IDnum * readNodeCounts)
{
	IDnum nodeOffset = nodeCount(graph);
	IDnum nodeSlots = 2 * nodeOffset + 1;
	IDnum readSlots = sequenceCount(graph) + 1;
	IDnum **readNodes = calloc(readSlots, sizeof(IDnum *));
	IDnum slot, entry, readID;

	// Size every non-empty per-read array, then rewind its cursor.
	for (readID = 1; readID < readSlots; readID++) {
		if (readNodeCounts[readID] == 0)
			continue;

		readNodes[readID] =
		    malloc(readNodeCounts[readID] * sizeof(IDnum));
		readNodeCounts[readID] = 0;
	}

	// Scatter each node ID into the arrays of the reads it carries.
	for (slot = 0; slot < nodeSlots; slot++) {
		IDnum *reads = nodeReads[slot];
		IDnum nodeID = slot - nodeOffset;

		for (entry = 0; entry < nodeReadCounts[slot]; entry++) {
			IDnum *target;

			readID = reads[entry];
			target = readNodes[readID];
			target[readNodeCounts[readID]] = nodeID;
			readNodeCounts[readID] += 1;
		}
	}

	return readNodes;
}

/*
 * Adds `count` to the connection from nodeID to node2ID stored in
 * `connections` (indexed by nodeID + nodeCount(graph)). When no such
 * connection exists yet, a new one is pushed at the head of nodeID's list.
 */
static void createConnection(IDnum nodeID, IDnum node2ID,
			     Connection ** connections, IDnum count)
{
	Node *target = getNodeInGraph(graph, node2ID);
	Connection **head = &connections[nodeID + nodeCount(graph)];
	Connection *link;

	for (link = *head; link != NULL; link = link->next) {
		if (link->destination != target)
			continue;

		link->count += count;
		return;
	}

	link = allocateConnection();
	link->destination = target;
	link->count = count;
	link->next = *head;
	*head = link;
}

/*
 * Unlinks and frees the connection from nodeID to node2ID, if present.
 * Returns the removed connection's count, or 0 when none was found.
 */
static IDnum destroyConnection(IDnum nodeID, IDnum node2ID,
			       Connection ** connections)
{
	Node *target = getNodeInGraph(graph, node2ID);
	Connection **link = &connections[nodeID + nodeCount(graph)];

	while (*link != NULL) {
		Connection *victim = *link;

		if (victim->destination == target) {
			IDnum removed = victim->count;

			*link = victim->next;
			deallocateConnection(victim);
			return removed;
		}

		link = &victim->next;
	}

	return 0;
}

/*
 * Builds the scaffold.
 *
 * For every anchor node, each of its reads is followed to its mate
 * (readPairs[read - 1], where -1 marks an unpaired read). One unit of
 * support is then added to a connection from the anchor to every node
 * carrying that mate.
 * Also allocates isPaired, overwriting any previous pointer, and sets
 * isPaired[slot] for anchors that reach another anchor other than their
 * own twin.
 * Returns 2 * nodeCount(graph) + 1 list heads indexed by
 * nodeID + nodeCount(graph). The caller owns them.
 */
static Connection **computeNodeToNodeMappings(IDnum ** readNodes,
					      IDnum * readNodeCounts,
					      IDnum ** nodeReads,
					      IDnum * nodeReadCounts,
					      IDnum * readPairs)
{
	IDnum slots = 2 * nodeCount(graph) + 1;
	Connection **nodeNodes = calloc(slots, sizeof(Connection *));
	IDnum slot;

	isPaired = calloc(slots, sizeof(boolean));

	for (slot = 0; slot < slots; slot++) {
		IDnum nodeID = slot - nodeCount(graph);
		Node *anchor = getNodeInGraph(graph, nodeID);
		IDnum *reads;
		IDnum i;

		if (anchor == NULL)
			continue;
		if (!isAnchor(anchor))
			continue;

		reads = nodeReads[slot];
		for (i = 0; i < nodeReadCounts[slot]; i++) {
			IDnum mate = readPairs[reads[i] - 1] + 1;
			IDnum *mateNodes;
			IDnum j;

			if (mate == 0)
				continue;

			mateNodes = readNodes[mate];
			for (j = 0; j < readNodeCounts[mate]; j++) {
				IDnum otherID = mateNodes[j];

				createConnection(nodeID, otherID, nodeNodes, 1);
				if (isAnchor(getNodeInGraph(graph, otherID))
				    && nodeID != -otherID)
					isPaired[slot] = true;
			}
		}
	}

	return nodeNodes;
}

/*
 * Drains node's scaffold list and looks for a single opposite anchor.
 *
 * Marking: node itself and the twin of every destination on its list get
 * status true, and their IDs are appended to markedNodes. markedNodes is
 * 0-terminated on every return path.
 * Each drained connection is also removed in the reverse direction
 * (destroyConnection(destination, node)). This drain happens whatever the
 * return value.
 *
 * The candidate opposite node is an anchor destination with count > 1
 * that is not node's twin. NULL is returned when:
 *   - there is no candidate,
 *   - there is more than one candidate, or
 *   - the candidate's own list still points to another anchor besides
 *     node and its own twin.
 * Otherwise the candidate's list is drained as well (marking its
 * destinations and recording their IDs), and the candidate is returned.
 *
 * NOTE(review): markedNodes is written without a bound check. The caller
 * in this file allocates 2 * nodeCount(graph) + 1 entries; confirm that
 * suffices.
 */
static Node *markInterestingNodes(Node * node, Connection ** scaffold,
				  IDnum * markedNodes)
{
	IDnum nodeIndex = getNodeID(node) + nodeCount(graph);
	Connection *connect;
	Node *destination;
	Node *oppositeNode = NULL;
	boolean conflict = false;
	IDnum nodeCounter = 0;

	// The starting node is always part of the marked neighbourhood
	setSingleNodeStatus(node, true);
	markedNodes[nodeCounter] = getNodeID(node);
	nodeCounter++;

	// Pop every connection of node, marking the destination's twin and
	// looking for a unique strongly supported anchor
	while (scaffold[nodeIndex] != NULL) {
		connect = scaffold[nodeIndex];
		scaffold[nodeIndex] = connect->next;

		destination = connect->destination;
		setSingleNodeStatus(getTwinNode(destination), true);
		markedNodes[nodeCounter] = -getNodeID(destination);
		nodeCounter++;
		if (connect->count > 1 && isAnchor(destination)
		    && destination != getTwinNode(node)) {
			if (oppositeNode == NULL)
				oppositeNode = destination;
			else {
				conflict = true;
			}
		}

		// destination was saved above, so freeing connect first is safe;
		// then drop the mirror connection destination -> node
		deallocateConnection(connect);
		destroyConnection(getNodeID(destination), getNodeID(node),
				  scaffold);
	}

	// More than one strongly supported anchor: ambiguous
	if (conflict) {
		markedNodes[nodeCounter] = 0;
		return NULL;
	}

	if (oppositeNode == NULL) {
		markedNodes[nodeCounter] = 0;
		return NULL;
	}

	// The opposite anchor must not point to any other anchor than node
	// or its own twin
	nodeIndex = getNodeID(oppositeNode) + nodeCount(graph);
	for (connect = scaffold[nodeIndex]; connect != NULL;
	     connect = connect->next) {
		destination = connect->destination;
		if (isAnchor(destination)
		    && destination != getTwinNode(oppositeNode)
		    && destination != node)
			conflict = true;
	}

	if (conflict) {
		markedNodes[nodeCounter] = 0;
		return NULL;
	}

	// Drain the opposite anchor's list too, marking the destinations
	// themselves (not their twins) and dropping mirror connections
	while (scaffold[nodeIndex] != NULL) {
		connect = scaffold[nodeIndex];
		scaffold[nodeIndex] = connect->next;

		destination = connect->destination;
		setSingleNodeStatus(destination, true);
		markedNodes[nodeCounter] = getNodeID(destination);
		nodeCounter++;

		deallocateConnection(connect);
		destroyConnection(getNodeID(destination),
				  getNodeID(oppositeNode), scaffold);
	}

	markedNodes[nodeCounter] = 0;
	return oppositeNode;
}

/*
 * Appends extension's sequence descriptors to node.
 * Node's outgoing arcs are then replaced by copies of extension's
 * outgoing arcs.
 */
static void absorbExtension(Node * node, Node * extension)
{
	Arc *arc;

	appendDescriptors(node, extension);

	// Drop every arc currently leaving node
	for (arc = getArc(node); arc != NULL; arc = getArc(node))
		destroyArc(arc, graph);

	// Mirror extension's arcs onto node
	arc = getArc(extension);
	while (arc != NULL) {
		createAnalogousArc(node, getDestination(arc), arc, graph);
		arc = getNextArc(arc);
	}
}

/*
 * Folds `source` into `node` in the scaffold and destroys source.
 * It is called after node has absorbed twin(source) in the graph.
 *  - Anchors (and node's twin) linked to source are relinked to -nodeID,
 *    node's twin, carrying over the removed connection's count.
 *  - Anchors (and node itself) linked to twin(source) are relinked to
 *    node, likewise.
 *  - node's and source's own lists are discarded, and node takes over
 *    the list of twin(source).
 *  - source's per-category virtual coverage is added to node, and the
 *    read starts of twin(source) are shared with node.
 */
static void absorbExtensionInScaffold(Node * node, Node * source,
				      Connection ** scaffold)
{
	IDnum nodeID = getNodeID(node);
	IDnum nodeIndex = nodeID + nodeCount(graph);
	IDnum sourceID = getNodeID(source);
	IDnum sourceIndex = sourceID + nodeCount(graph);
	Node *twinSource = getTwinNode(source);
	IDnum twinSourceIndex = getNodeID(twinSource) + nodeCount(graph);
	Connection *tmp, *connect, *next;
	Node *destination;
	IDnum destinationID;
	IDnum count;
	Category cat;

	// Inward pointers.
	// `next` is captured before the body runs. If destination is the
	// list owner itself (a self-connection), destroyConnection() unlinks
	// and frees `connect`, and createConnection() may hand that memory
	// out again. So connect->next must not be read after the body.
	for (connect = scaffold[sourceIndex]; connect != NULL;
	     connect = next) {
		next = connect->next;
		destination = connect->destination;

		if (!isAnchor(destination)
		    && destination != getTwinNode(node))
			continue;

		destinationID = getNodeID(destination);
		count =
		    destroyConnection(destinationID, sourceID, scaffold);
		createConnection(destinationID, -nodeID, scaffold, count);
	}

	for (connect = scaffold[twinSourceIndex]; connect != NULL;
	     connect = next) {
		next = connect->next;
		destination = connect->destination;

		if (!isAnchor(destination) && destination != node)
			continue;

		destinationID = getNodeID(destination);
		count =
		    destroyConnection(destinationID, -sourceID, scaffold);
		createConnection(destinationID, nodeID, scaffold, count);
	}

	// Outward pointers: node's and source's lists are discarded;
	// node inherits twin(source)'s list
	while (scaffold[nodeIndex] != NULL) {
		tmp = scaffold[nodeIndex];
		scaffold[nodeIndex] = tmp->next;
		deallocateConnection(tmp);
	}

	while (scaffold[sourceIndex] != NULL) {
		tmp = scaffold[sourceIndex];
		scaffold[sourceIndex] = tmp->next;
		deallocateConnection(tmp);
	}

	scaffold[nodeIndex] = scaffold[twinSourceIndex];
	scaffold[twinSourceIndex] = NULL;

	// Read coverage
	for (cat = 0; cat < CATEGORIES; cat++)
		incrementVirtualCoverage(node, cat,
					 getVirtualCoverage(source, cat));

	// Read starts
	shareReadStarts(node, getTwinNode(source), graph);

	destroyNode(source, graph);
}

/*
 * Walks forward from node toward twin(oppositeNode), absorbing each step
 * into node with absorbExtension().
 *
 * Node statuses used here:
 *   1 - marked by markInterestingNodes() (preferred next step)
 *   2 - already absorbed on this walk (node itself starts at 2)
 *   3 - absorbed into oppositeNode by an earlier backward walk
 * At each step the unique outgoing neighbour with status 1 is taken. If
 * there is none, the unique one with status 3 is taken. Ambiguity or a
 * dead end stops the forward walk.
 *
 * Returns true when twin(oppositeNode) is reached. With force_jumps, the
 * gap is instead bridged by a LONG_NODE_CUTOFF buffer, and true is
 * returned as well. In both cases oppositeNode is merged into node's
 * scaffold entry and destroyed.
 *
 * Otherwise (oppositeNode != NULL) oppositeNode is extended along its
 * unique neighbour whose twin has status 1 or 2. Twins with status 1 are
 * absorbed and set to 3. Reaching a status-2 twin, or losing uniqueness,
 * stops the loop, and false is returned.
 */
static boolean pushNeighbours(Node * node, Node * oppositeNode,
			      Connection ** scaffold, boolean force_jumps)
{
	Arc *arc;
	Node *candidate = NULL;
	Node *dest;
	boolean doubt;

	setSingleNodeStatus(node, 2);

	while (true) {
		candidate = NULL;
		doubt = false;

		// First round for priority nodes (status 1); two or more
		// means the next step is ambiguous
		for (arc = getArc(node); arc != NULL;
		     arc = getNextArc(arc)) {
			dest = getDestination(arc);

			if (getNodeStatus(dest) != 1)
				continue;

			if (candidate != NULL) {
				doubt = true;
				candidate = NULL;
				break;
			}

			candidate = dest;
		}

		// Second round for second level nodes (status 3), only when
		// the first round found nothing and was not ambiguous
		if (candidate == NULL && doubt == false) {
			for (arc = getArc(node); arc != NULL;
			     arc = getNextArc(arc)) {
				dest = getDestination(arc);

				if (getNodeStatus(dest) != 3)
					continue;

				if (candidate != NULL) {
					doubt = true;
					candidate = NULL;
					break;
				}

				candidate = dest;
			}
		}
		// In case of failure (dead end or ambiguity)
		if (candidate == NULL) {
			break;
		}
		// In case of previously visited node.
		// NOTE(review): candidates are only chosen with status 1 or 3
		// above, so this test looks unreachable; kept as a guard.
		if (getNodeStatus(candidate) == 2) {
			break;
		}
		// If valid decision, step forward
		absorbExtension(node, candidate);

		// If actual destination, celebrate
		if (oppositeNode != NULL
		    && candidate == getTwinNode(oppositeNode)) {
			absorbExtensionInScaffold(node, oppositeNode,
						  scaffold);
			return true;
		}
		// Prepare next step
		setSingleNodeStatus(candidate, 2);
	}

	if (oppositeNode == NULL)
		return false;

	// Forced jump: bridge the gap with a buffer and merge anyway
	if (force_jumps) {
		addBufferToDescriptor(node, LONG_NODE_CUTOFF);
		absorbExtension(node, getTwinNode(oppositeNode));
		absorbExtensionInScaffold(node, oppositeNode, scaffold);
		return true;
	}

	// Extend oppositeNode backwards over uniquely marked neighbours
	while (true) {
		candidate = NULL;
		for (arc = getArc(oppositeNode); arc != NULL;
		     arc = getNextArc(arc)) {
			dest = getDestination(arc);
			if (getNodeStatus(getTwinNode(dest)) != 1
			    && getNodeStatus(getTwinNode(dest)) != 2)
				continue;

			if (candidate != NULL) {
				candidate = NULL;
				break;
			}

			candidate = getTwinNode(dest);
		}

		if (candidate == NULL) {
			break;
		}

		// Met the forward walk: stop here
		if (getNodeStatus(candidate) == 2) {
			return false;
		}

		absorbExtension(oppositeNode, getTwinNode(candidate));
		setSingleNodeStatus(candidate, 3);
	}

	return false;
}

/*
 * Clears the status of every node listed in markedNodes.
 * markedNodes is a 0-terminated list of node IDs, as written by
 * markInterestingNodes(). IDs whose node no longer exists are skipped.
 */
static void unmarkInterestingNodes(IDnum * markedNodes)
{
	IDnum *cursor;

	for (cursor = markedNodes; *cursor != 0; cursor++) {
		Node *marked = getNodeInGraph(graph, *cursor);

		if (marked == NULL)
			continue;

		setNodeStatus(marked, false);
	}
}

/*
 * Repeatedly extends node toward the anchor its read pairs point to.
 *
 * Each round first marks the neighbourhood with markInterestingNodes().
 * If a single opposite anchor is found, the graph is corrected locally and
 * pushNeighbours() tries to merge that anchor into node. Rounds continue
 * while a merge succeeds.
 * The marks of a round are always cleared before leaving it, including
 * the hairpin case where node's partner is node itself. Otherwise stale
 * status-1 marks would steer later pushNeighbours() walks.
 */
static void expandLongNode(Node * node, Connection ** scaffold,
			   IDnum * markedNodes, boolean force_jumps)
{
	Node *oppositeLongNode;
	boolean hit = true;

	while (hit) {
		hit = false;
		oppositeLongNode =
		    markInterestingNodes(node, scaffold, markedNodes);

		if (oppositeLongNode == node) {
			// Hairpin: clear this round's marks before giving up
			puts("Hairpin pointing to oneself!");
			unmarkInterestingNodes(markedNodes);
			break;
		}

		if (oppositeLongNode != NULL) {
			correctGraphLocally(node, oppositeLongNode,
					    markedNodes);
			hit =
			    pushNeighbours(node, oppositeLongNode,
					   scaffold, force_jumps);
		}

		unmarkInterestingNodes(markedNodes);
	}
}

/*
 * Runs expandLongNode() on both strands of every anchor.
 * Node IDs 1 .. nodeCount(graph) are visited in order, and empty ID slots
 * are skipped.
 */
static void expandLongNodes(Connection ** scaffold, IDnum * markedNodes,
			    boolean force_jumps)
{
	IDnum id;

	for (id = 1; id <= nodeCount(graph); id++) {
		Node *anchor = getNodeInGraph(graph, id);

		if (anchor == NULL)
			continue;
		if (!isAnchor(anchor))
			continue;

		expandLongNode(anchor, scaffold, markedNodes, force_jumps);
		expandLongNode(getTwinNode(anchor), scaffold, markedNodes,
			       force_jumps);
	}
}

/*
 * Releases the working structures built by exploitShortReadPairs():
 *  - every per-read node array, plus the readNodes and readNodeCounts
 *    arrays,
 *  - every Connection still held in the scaffold (returned to the recycle
 *    bin), plus the scaffold array itself,
 *  - the markedNodes buffer.
 */
static void cleanMemory(IDnum ** readNodes, IDnum * readNodeCounts,
			Connection ** scaffold, IDnum * markedNodes)
{
	IDnum readCount = sequenceCount(graph);
	IDnum maxNodeIndex = 2 * nodeCount(graph) + 1;
	IDnum index;
	Connection *connect;

	// readNodes has sequenceCount(graph) + 1 slots. Read IDs run from 1
	// up to AND including sequenceCount(graph); slot 0 is never filled.
	for (index = 1; index <= readCount; index++)
		free(readNodes[index]);

	free(readNodes);
	free(readNodeCounts);

	// The scaffold covers both strands: 2 * nodeCount(graph) + 1 heads,
	// indexed by nodeID + nodeCount(graph).
	for (index = 0; index < maxNodeIndex; index++) {
		while (scaffold[index] != NULL) {
			connect = scaffold[index];
			scaffold[index] = connect->next;
			deallocateConnection(connect);
		}
	}

	free(scaffold);

	free(markedNodes);
}

/*
 * Prunes weakly supported links from the scaffold.
 * A connection to an anchor is deleted when its count is at or below
 * UNRELIABLE_CONNECTION_CUTOFF; connections to non-anchors are kept.
 * In each list, the entries after the head are filtered first, and the
 * head is examined last.
 */
static void removeUnreliableConnections(Connection ** scaffold)
{
	IDnum slots = nodeCount(graph) * 2 + 1;
	IDnum slot;

	for (slot = 0; slot < slots; slot++) {
		Connection *head = scaffold[slot];
		Connection **link;

		if (head == NULL)
			continue;

		// Filter everything after the head in place
		link = &head->next;
		while (*link != NULL) {
			Connection *current = *link;

			if (current->count > UNRELIABLE_CONNECTION_CUTOFF
			    || !isAnchor(current->destination)) {
				link = &current->next;
			} else {
				*link = current->next;
				deallocateConnection(current);
			}
		}

		// Then decide on the head itself
		if (head->count <= UNRELIABLE_CONNECTION_CUTOFF
		    && isAnchor(head->destination)) {
			scaffold[slot] = head->next;
			deallocateConnection(head);
		}
	}
}

/*
 * Uses short-read pairing to extend anchor nodes (longer than
 * argCutoffLength) toward the anchors their read pairs point to, then
 * concatenates the graph.
 *
 * readPairs[i] gives the mate of read i + 1 (0-based), or -1 when the
 * read is unpaired. force_jumps is passed through to pushNeighbours().
 * Prints a message and returns early when the graph has no node-to-read
 * mappings.
 */
void exploitShortReadPairs(Graph * argGraph, IDnum * readPairs,
			   Coordinate argCutoffLength, boolean force_jumps)
{
	IDnum **nodeReads;
	IDnum *nodeReadCounts;
	IDnum *readNodeCounts;
	IDnum **readNodes;
	Connection **scaffold;
	IDnum *markedNodes;

	graph = argGraph;
	LONG_NODE_CUTOFF = argCutoffLength;
	resetNodeStatus(graph);

	nodeReads = getNodeToReadMappings(graph);
	if (nodeReads == NULL) {
		puts("No short reads to work with!");
		return;
	}

	// Build read -> node and node -> node (scaffold) mappings
	nodeReadCounts = getNodeReadCounts(graph);
	readNodeCounts = computeReadToNodeCounts(nodeReads, nodeReadCounts);
	readNodes = computeReadToNodeMappings(nodeReads, nodeReadCounts,
					      readNodeCounts);
	scaffold = computeNodeToNodeMappings(readNodes, readNodeCounts,
					     nodeReads, nodeReadCounts,
					     readPairs);
	markedNodes = malloc((2 * nodeCount(graph) + 1) * sizeof(IDnum));

	removeUnreliableConnections(scaffold);

	// Local-correction settings (parameter 0.2) are active only while
	// the anchors are expanded
	prepareGraphForLocalCorrections(graph, 0.2);
	expandLongNodes(scaffold, markedNodes, force_jumps);
	deactivateLocalCorrectionSettings();

	cleanMemory(readNodes, readNodeCounts, scaffold, markedNodes);

	concatenateGraph(graph);
}
