/*
Copyright 2007 Daniel Zerbino (zerbino@ebi.ac.uk)

    This file is part of Velvet.

    Velvet is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    Velvet is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Velvet; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

*/

#include <stdlib.h>
#include <stdio.h>

#include "globals.h"
#include "graph.h"
#include "tightString.h"
#include "dfibHeap.h"
#include "recycleBin.h"
#include "passageMarker.h"

// Longest path sequence that extractSequence() accepts; also fixes the
// dimensions of the alignment matrix Fmatrix declared below.
#define MAXREADLENGTH 100
// Maximum number of backtracking steps comparePaths() takes along either
// path before it gives up.
#define MAXNODELENGTH 200

// Scores used by compareSequences(): INDEL is added for a gap, and
// SIM[a][b] for aligning nucleotide codes a and b (1 on a match, 0 otherwise).
static const Time INDEL = 0;
static const Time SIM[4][4] = {
	{1, 0, 0, 0},
	{0, 1, 0, 0},
	{0, 0, 1, 0},
	{0, 0, 0, 1}
};

// File-scope state shared by every function in this file (internal use only!)

// Per-node tables, indexed by getNodeID(node) + nodeCount(graph).
// A time of -1 marks a node that has not been reached yet.
static Time *times;
static Node **previous;

// dheapNodes uses the same indexing as the tables above and holds the heap
// handle of each node currently queued in dheap (NULL otherwise).
static DFibHeapNode **dheapNodes;
static DFibHeap *dheap;

// Scratch buffers receiving the sequences of the two paths being compared.
static TightString *fastSequence;
static TightString *slowSequence;

// Parameters set by prepareGraphForLocalCorrections().
static int WORDLENGTH;
static int SELF_LOOP_CUTOFF;
static double cutoff;
static Graph *graph;
// Parameters set by each call to correctGraphLocally(); markedNodes is a
// 0-terminated list of node IDs.
static Node *start;
static Node *target;
static IDnum *markedNodes;

// Passage marker chains built by comparePaths() and released by destroyPaths().
static PassageMarker *fastPath;
static PassageMarker *slowPath;

// Dynamic programming matrix filled by compareSequences().
static double Fmatrix[MAXREADLENGTH + 1][MAXREADLENGTH + 1];
// End of file-scope state

// Records `time` as the arrival time of `node` in the global times table.
// The slot is the node ID shifted by nodeCount(graph).
static void setNodeTime(Node * node, Time time)
{
	const IDnum slot = getNodeID(node) + nodeCount(graph);

	times[slot] = time;
}

// Returns the arrival time stored for `node` in the global times table.
// If the node ID lies outside [-nodeCount, nodeCount] a diagnostic line is
// printed, but the lookup is still performed afterwards.
static Time getNodeTime(Node * node)
{
	const IDnum limit = nodeCount(graph);
	const IDnum id = getNodeID(node);

	if (id > limit || id < -limit)
		printf("Bizarre node %li %p\n", getNodeID(node), node);

	return times[id + limit];
}

// Returns the predecessor recorded for `node` in the global previous table
// (same indexing as the times table: node ID shifted by nodeCount(graph)).
static Node *getNodePrevious(Node * node)
{
	const IDnum slot = getNodeID(node) + nodeCount(graph);

	return previous[slot];
}

// Tells whether `previous` can be reached from `target` by following the
// recorded predecessor links while staying on nodes whose time equals that
// of `target`.
//
// Returns true as soon as `previous` is met, false when the chain reaches a
// node that is its own predecessor or a node with a different time.
static boolean isPreviousToNode(Node * previous, Node * target)
{
	const Time referenceTime = getNodeTime(target);
	Node *lastVisited = NULL;
	Node *cursor = target;

	for (;;) {
		IDnum cursorID;
		IDnum limit;

		if (cursor == previous)
			return true;

		// A node that is its own predecessor ends the chain
		if (cursor == lastVisited)
			return false;

		// Diagnostic only: report IDs outside the graph's range
		cursorID = getNodeID(cursor);
		limit = nodeCount(graph);
		if (cursorID > limit || cursorID < -limit) {
			printf("Node ID??? %li %li\n",
			       getNodeID(cursor),
			       getNodeID(lastVisited));
		}

		if (getNodeTime(cursor) != referenceTime)
			return false;

		lastVisited = cursor;
		cursor = getNodePrevious(cursor);
	}
}

// Copies the concatenated node sequences of a passage marker chain into
// `sequence`. The marker `path` itself is skipped: the walk starts at its
// successor and stops at the first terminal marker.
//
// Returns false, leaving `sequence` untouched, when the total length exceeds
// MAXREADLENGTH; returns true otherwise.
static boolean
extractSequence(PassageMarker * path, TightString * sequence)
{
	PassageMarker *step;
	Coordinate totalLength = 0;
	Coordinate offset = 0;

	// First pass: add up the node lengths
	step = getNextInSequence(path);
	while (!isTerminal(step)) {
		totalLength += getNodeLength(getNode(step));
		step = getNextInSequence(step);
	}

	if (totalLength > MAXREADLENGTH)
		return false;

	setTightStringLength(sequence, totalLength);

	// Second pass: append each node's sequence at its offset
	step = getNextInSequence(path);
	while (!isTerminal(step)) {
		Node *node = getNode(step);

		appendNodeSequence(node, sequence, offset);
		offset += getNodeLength(node);
		step = getNextInSequence(step);
	}

	return true;
}

// Returns the largest of three values. A wins ties against B and C,
// and B wins ties against C.
static Time max(Time A, Time B, Time C)
{
	Time best = (B >= C) ? B : C;

	if (A >= B && A >= C)
		best = A;

	return best;
}

// Decides whether two sequences are similar enough to be merged.
//
// If either sequence is shorter than WORDLENGTH, only the lengths are
// compared: the result is true when they differ by less than 10.
// Otherwise a dynamic programming alignment is computed in Fmatrix, using
// the SIM and INDEL scores, and the sequences are accepted unless
// 1 - score / averageLength exceeds the global cutoff.
//
// Both lengths must be at most MAXREADLENGTH (the dimension of Fmatrix).
static boolean
compareSequences(TightString * sequence1, TightString * sequence2)
{
	const Coordinate length1 = getLength(sequence1);
	const Coordinate length2 = getLength(sequence2);
	Coordinate row, col;
	Coordinate avgLength;
	Time diagonal, fromAbove, fromLeft;
	Time maxScore;

	if (length1 < WORDLENGTH || length2 < WORDLENGTH) {
		const Coordinate gap = length1 - length2;

		return (gap < 10 && gap > -10);
	}

	avgLength = (length1 + length2) / 2;

	// Border cells: aligning against an empty prefix scores 0
	for (col = 0; col <= length2; col++)
		Fmatrix[0][col] = 0;
	for (row = 0; row <= length1; row++)
		Fmatrix[row][0] = 0;

	for (row = 1; row <= length1; row++) {
		// Score row of the nucleotide at position row - 1 of sequence1
		const Time *simRow =
		    SIM[(int) getNucleotide(row - 1, sequence1)];

		for (col = 1; col <= length2; col++) {
			diagonal =
			    Fmatrix[row - 1][col - 1] +
			    simRow[(int) getNucleotide(col - 1, sequence2)];
			fromAbove = Fmatrix[row - 1][col] + INDEL;
			fromLeft = Fmatrix[row][col - 1] + INDEL;
			Fmatrix[row][col] =
			    max(diagonal, fromAbove, fromLeft);
		}
	}

	maxScore = Fmatrix[length1][length2];

	return !((1 - maxScore / avgLength) > cutoff);
}

static void destroyPaths()
{
	PassageMarker *marker;

	while (slowPath != NULL) {
		marker = slowPath;
		getNodeTime(getNode(marker));
		getNodeTime(getTwinNode(getNode(marker)));

		slowPath = getNextInSequence(marker);
		destroyPassageMarker(marker);
	}

	while (fastPath != NULL) {
		marker = fastPath;
		getNodeTime(getNode(marker));
		getNodeTime(getTwinNode(getNode(marker)));
		fastPath = getNextInSequence(marker);
		destroyPassageMarker(marker);
	}
}

// Clears the status of every node on the slow path (excluding the first
// marker and stopping at the first terminal marker), then destroys both
// paths.
static void cleanUpRedundancy()
{
	PassageMarker *step = getNextInSequence(slowPath);

	while (!isTerminal(step)) {
		setSingleNodeStatus(getNode(step), false);
		step = getNextInSequence(step);
	}

	destroyPaths();
}

// Compares the two paths that lead to `destination`: the one recorded in the
// predecessor table (the "fast" path) and the one arriving through `origin`
// (the "slow" path).
//
// Both paths are first walked backwards through the predecessor table until
// they meet, counting the steps taken on each side. If the fast side needs no
// step, the case is treated as a self-loop. Otherwise passage marker chains
// are built for both paths, their sequences extracted and compared, and if
// they are similar enough the slow path's nodes are unmarked through
// cleanUpRedundancy(). The marker chains are always destroyed before
// returning.
static void comparePaths(Node * destination, Node * origin)
{
	IDnum slowLength, fastLength;
	Node *fastNode, *slowNode;
	IDnum i;
	PassageMarker *marker;
	Coordinate slowSeqLength;

	//Measure lengths
	slowLength = fastLength = 0;
	fastNode = destination;
	slowNode = origin;

	//puts("Looking into separate paths");

	// Step back on whichever side has the later time until both sides
	// reach the same node
	while (fastNode != slowNode) {
		//printf("Fast node %li Slow node %li\n", getNodeID(fastNode), getNodeID(slowNode));

		if (getNodeTime(fastNode) > getNodeTime(slowNode)) {
			fastLength++;
			fastNode = getNodePrevious(fastNode);
		} else if (getNodeTime(fastNode) < getNodeTime(slowNode)) {
			slowLength++;
			slowNode = getNodePrevious(slowNode);
		} else if (isPreviousToNode(slowNode, fastNode)) {
			// Equal times, and slowNode is an ancestor of fastNode:
			// walk the fast side all the way down to it
			while (fastNode != slowNode) {
				fastLength++;
				fastNode = getNodePrevious(fastNode);
			}
		} else if (isPreviousToNode(fastNode, slowNode)) {
			// Symmetric case: walk the slow side down to fastNode
			while (slowNode != fastNode) {
				slowLength++;
				slowNode = getNodePrevious(slowNode);
			}
		} else {
			// Equal times but neither precedes the other:
			// step back on both sides
			fastLength++;
			fastNode = getNodePrevious(fastNode);
			slowLength++;
			slowNode = getNodePrevious(slowNode);
		}

		// Give up once either path needs too many steps
		if (slowLength > MAXNODELENGTH
		    || fastLength > MAXNODELENGTH) {
			//printf("Paths too fragmented %li %li\n", slowLength, fastLength);
			return;
		}
	}

	// No step was taken on the fast side: the paths met at destination itself
	if (fastLength == 0) {
		//puts("Self-loop");

		if (target == NULL)
			return;

		// Add up the node lengths from origin back to destination
		slowNode = origin;
		slowSeqLength = 0;
		while (slowNode != destination) {
			slowSeqLength += getNodeLength(slowNode);
			slowNode = getNodePrevious(slowNode);
		}

		printf("Self-loop %li\n", slowSeqLength);

		if (slowSeqLength < SELF_LOOP_CUTOFF) {
			//puts("Negligible... simplifying!");
			// NOTE(review): status value 3 is not explained anywhere in
			// this file -- confirm its meaning against the graph module
			slowNode = origin;
			while (slowNode != destination) {
				setSingleNodeStatus(slowNode, 3);
				slowNode = getNodePrevious(slowNode);
			}
		}

		return;
	}
	//Backtracking to record actual paths
	// The fast chain is built from destination backwards, so that fastPath
	// ends up pointing at its earliest marker
	fastPath = addUncertainPassageMarker(1, destination);
	setPassageMarkerStatus(fastPath, true);

	for (i = 0; i < fastLength; i++) {
		marker =
		    addUncertainPassageMarker(1,
					      getNodePrevious(getNode
							      (fastPath)));
		setPassageMarkerStatus(marker, true);
		connectPassageMarkers(marker, fastPath, graph);
		fastPath = marker;
	}

	// The slow chain also ends on destination, preceded by origin, then by
	// slowLength predecessors of origin
	slowPath = addUncertainPassageMarker(2, destination);
	setPassageMarkerStatus(slowPath, true);

	marker = addUncertainPassageMarker(2, origin);
	setPassageMarkerStatus(marker, true);
	connectPassageMarkers(marker, slowPath, graph);
	slowPath = marker;

	for (i = 0; i < slowLength; i++) {
		marker =
		    addUncertainPassageMarker(2,
					      getNodePrevious(getNode
							      (slowPath)));
		setPassageMarkerStatus(marker, true);
		connectPassageMarkers(marker, slowPath, graph);
		slowPath = marker;
	}

	//Extract sequences
	// extractSequence() fails when a path is longer than MAXREADLENGTH
	if (!extractSequence(fastPath, fastSequence)
	    || !extractSequence(slowPath, slowSequence)) {
		//puts("Paths too long");
		destroyPaths();
		return;
	}
	//Compare sequences
	if (compareSequences(fastSequence, slowSequence)) {
		//puts("Correcting discrepancy");
		// cleanUpRedundancy() also destroys both paths
		cleanUpRedundancy();
		return;
	}
	//puts("\tFinished comparing paths, changes made");
	// Sequences too different: leave the graph as it is
	destroyPaths();
}

// Processes one arc leaving `origin`, which was reached at `originTime`.
//
// The arrival time at the arc's destination is originTime plus the length of
// `origin` divided by the arc's multiplicity. Then:
//  - a destination not reached yet (time -1) is given that time, queued in
//    the heap and linked to `origin`;
//  - a destination with a later recorded time that is still queued is
//    updated, and the new path is compared with the one through its former
//    predecessor;
//  - otherwise the recorded path is compared with the one through `origin`.
// Nothing is done if `origin` is already the destination's predecessor.
static void tourBusArc(Node * origin, Arc * arc, Time originTime)
{
	Node *reached = getDestination(arc);
	const IDnum slot = getNodeID(reached) + nodeCount(graph);
	Node *formerPrevious = previous[slot];
	Time arcTime, candidateTime, recordedTime;

	if (formerPrevious == origin)
		return;

	arcTime =
	    ((Time) getNodeLength(origin)) / ((Time) getMultiplicity(arc));
	candidateTime = originTime + arcTime;

	recordedTime = times[slot];

	// First visit of this node
	if (recordedTime == -1) {
		times[slot] = candidateTime;
		dheapNodes[slot] =
		    insertNodeIntoDHeap(dheap, candidateTime, reached);
		previous[slot] = origin;
		return;
	}

	// The recorded path is at least as fast: origin is on the slow path
	if (!(recordedTime > candidateTime)) {
		comparePaths(reached, origin);
		return;
	}

	// Faster arrival, but the node has already left the heap
	if (dheapNodes[slot] == NULL)
		return;

	// Faster arrival on a queued node: the former predecessor now
	// belongs to the slow path
	times[slot] = candidateTime;
	replaceKeyInDHeap(dheap, dheapNodes[slot], candidateTime);
	previous[slot] = origin;

	comparePaths(reached, formerPrevious);
}

// Processes every outgoing arc of `node` whose destination still has a
// non-zero status, stopping early if `node` itself loses its status while
// an arc is being processed.
static void tourBusNode(Node * node)
{
	const Time departureTime = getNodeTime(node);
	Arc *outgoing = getArc(node);

	while (outgoing != NULL) {
		// Destinations outside the marked area are skipped
		if (getNodeStatus(getDestination(outgoing))) {
			tourBusArc(node, outgoing, departureTime);

			if (!getNodeStatus(node))
				return;
		}

		outgoing = getNextArc(outgoing);
	}
}

// Returns true when none of the outgoing arcs of `node` leads to a node with
// a non-zero status.
static boolean isLocalDeadEnd(Node * node)
{
	Arc *outgoing = getArc(node);

	while (outgoing != NULL) {
		if (getNodeStatus(getDestination(outgoing)))
			return false;
		outgoing = getNextArc(outgoing);
	}

	return true;
}

// Same test as isLocalDeadEnd() but in the opposite direction: looks at the
// arcs leaving the twin of `node`, and returns true when the twin of every
// destination has a zero status.
static boolean isLocalTwinDeadEnd(Node * node)
{
	Arc *incoming = getArc(getTwinNode(node));

	while (incoming != NULL) {
		if (getNodeStatus(getTwinNode(getDestination(incoming))))
			return false;
		incoming = getNextArc(incoming);
	}

	return true;
}

// Repeatedly clears the status of marked nodes that are dead ends within the
// marked area, in either direction, until a full pass over markedNodes
// changes nothing.
//
// Nodes that are missing from the graph or already unmarked are skipped, and
// the start node and the twin of the target node (when a target is set) are
// never clipped.
static void clipTipsVeryHardLocally()
{
	IDnum index, nodeID;
	Node *current;
	boolean modified = true;

	//puts("Clipping short tips off graph HARD");

	while (modified) {
		modified = false;
		index = 0;

		// markedNodes is terminated by a 0 entry
		for (nodeID = markedNodes[index++]; nodeID != 0;
		     nodeID = markedNodes[index++]) {
			current = getNodeInGraph(graph, nodeID);
			//printf("Checking node HARD %li %i\n", nodeID, simpleArcCount(current));

			if (current == NULL || !getNodeStatus(current))
				continue;

			if (current == start)
				continue;

			if (target != NULL
			    && current == getTwinNode(target))
				continue;

			// Both kinds of dead end are handled the same way
			if (isLocalDeadEnd(current)
			    || isLocalTwinDeadEnd(current)) {
				setSingleNodeStatus(current, false);
				modified = true;
			}
		}
	}
}

// Explores the marked area from `startingPoint`. The starting node gets
// time 0 and is recorded as its own predecessor; nodes are then processed
// one at a time, each new one being taken from the heap, until the heap
// returns NULL.
static void tourBus(Node * startingPoint)
{
	const IDnum startSlot =
	    getNodeID(startingPoint) + nodeCount(graph);
	Node *current;

	//printf("Tour bus from node %li...\n", getNodeID(startingPoint));

	times[startSlot] = 0;
	previous[startSlot] = startingPoint;

	for (current = startingPoint; current != NULL;
	     current = removeNextNodeFromDHeap(dheap)) {
		// The node is no longer queued: forget its heap handle
		dheapNodes[getNodeID(current) + nodeCount(graph)] = NULL;
		tourBusNode(current);
	}
}

// Sets up the file-scope state needed by correctGraphLocally().
//
// argGraph:  graph to be corrected; kept in the global `graph`.
// argCutoff: maximum value of 1 - score / averageLength for which two path
//            sequences are still considered similar (see compareSequences()).
//
// Allocates the per-node tables (one slot per node ID from -nodeCount to
// +nodeCount), the heap and the two sequence buffers, and marks every node
// as not reached yet. The program exits with an error message if one of the
// tables cannot be allocated. deactivateLocalCorrectionSettings() releases
// everything allocated here.
void prepareGraphForLocalCorrections(Graph * argGraph, double argCutoff)
{
	IDnum nodes = nodeCount(argGraph);
	// Slot count computed in size_t so the multiplication is not done in
	// the narrower, signed IDnum type
	size_t slots = 2 * (size_t) nodes + 1;
	size_t index;

	//Setting global params
	graph = argGraph;
	cutoff = argCutoff;
	WORDLENGTH = getWordLength(graph);
	SELF_LOOP_CUTOFF = WORDLENGTH;
	// Done with global params

	printf("Preparing to correct graph with cutoff %f\n", cutoff);

	// Allocating memory
	times = malloc(slots * sizeof *times);
	previous = malloc(slots * sizeof *previous);
	dheapNodes = malloc(slots * sizeof *dheapNodes);

	// This function has no way to report failure to its caller, so stop
	// here rather than write through a NULL pointer below
	if (times == NULL || previous == NULL || dheapNodes == NULL) {
		fputs("Out of memory while preparing local graph correction\n",
		      stderr);
		exit(EXIT_FAILURE);
	}

	dheap = newDFibHeap();

	fastSequence = newTightString(MAXREADLENGTH);
	slowSequence = newTightString(MAXREADLENGTH);

	// -1 marks a node that has not been reached yet
	for (index = 0; index < slots; index++) {
		times[index] = -1;
		dheapNodes[index] = NULL;
		previous[index] = NULL;
	}

	//Done with memory
}

void
correctGraphLocally(Node * argStart, Node * argTarget,
		    IDnum * argMarkedNodes)
{
	IDnum nodeID, index, nodeIndex;

	//printf("Correcting graph from node %li\n", getNodeID(start));
	start = argStart;
	target = argTarget;
	markedNodes = argMarkedNodes;

	clipTipsVeryHardLocally();

	index = 0;
	for (nodeID = markedNodes[index++]; nodeID != 0;
	     nodeID = markedNodes[index++]) {
		nodeIndex = nodeID + nodeCount(graph);
		times[nodeIndex] = -1;
		dheapNodes[nodeIndex] = NULL;
		previous[nodeIndex] = NULL;
	}

	tourBus(start);

	index = 0;
	for (nodeID = markedNodes[index++]; nodeID != 0;
	     nodeID = markedNodes[index++]) {
		nodeIndex = nodeID + nodeCount(graph);
		times[nodeIndex] = -1;
		dheapNodes[nodeIndex] = NULL;
		previous[nodeIndex] = NULL;
	}

	clipTipsVeryHardLocally();
}

// Releases everything allocated by prepareGraphForLocalCorrections().
//
// Each file-scope pointer is reset to NULL once released, so that none of
// them is left dangling for a later stale use.
void deactivateLocalCorrectionSettings()
{
	free(times);
	times = NULL;
	free(previous);
	previous = NULL;
	free(dheapNodes);
	dheapNodes = NULL;

	destroyDHeap(dheap);
	dheap = NULL;

	// NOTE(review): these buffers come from newTightString(); if a
	// TightString owns memory of its own, a plain free() would leak it --
	// confirm whether tightString.h provides a dedicated destructor
	free(fastSequence);
	fastSequence = NULL;
	free(slowSequence);
	slowSequence = NULL;
}
