/*
 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "avl.h"
#include "gc.h"

#include <assert.h>

#ifdef remove
#undef remove
#endif

/*
 * State shared by the recursive internal functions.
 *
 * Everything here either stays fixed for the whole operation (the inputs)
 * or is handed back up the recursion (the outputs), so it travels in one
 * struct instead of as separate parameters.  The current node is
 * deliberately absent: it changes at every level and is passed explicitly.
 */
typedef struct AvlContext {
	/* Inputs, filled in by the public entry point before descending. */
	AvlCompare  compare;    /* key ordering */
	const void *key;        /* key being looked up, inserted or removed */
	const void *value;      /* value to store (used by insert only) */

	/* Outputs, filled in by insert()/remove() on the way back up. */
	int height_change;      /* height change of the returned subtree */
	union {
		AvlNode replaced;   /* insert: node whose key matched, or NULL */
		AvlNode removed;    /* remove: node that was taken out, or NULL */
	};
} AvlContext;

/* Allocation of tree handles and nodes. */
static AVL mkRoot(AvlCompare compare, AvlNode root, size_t count);
static AvlNode mkSingleton(const void *key, const void *value);
static AvlNode mkNode(const void *key, const void *value, const AvlNode lr[2], int balance);

/* Search. */
static AvlNode lookup(AvlNode node, AvlContext *ctx);

/* Insertion and removal; each returns the resulting subtree root. */
static AvlNode insert(AvlNode node, AvlContext *ctx);
static AvlNode remove(AvlNode node, AvlContext *ctx);
static AvlNode removeExtremum(AvlNode node, int side, AvlContext *ctx);
static AvlNode replaceWithNeighbor(AvlNode node, AvlContext *ctx);

/* Rebalancing. */
static AvlNode balance(const void *key, const void *value, const AvlNode lr[2], int bal);
static AvlNode setChild(AvlNode node, int side, AvlNode replacement, int *height_change);

/* Helpers for avl_check_invariants(). */
static size_t countNode(AvlNode node);
static bool checkOrder(AVL avl);
static bool checkBalances(AvlNode node, int *height);


/*** Data constructors ***/

/*
 * Allocate (with gc_alloc) a tree handle that records the comparator,
 * the root node and the number of elements.
 */
static AVL mkRoot(AvlCompare compare, AvlNode root, size_t count)
{
	struct AVL *tree = gc_alloc(sizeof *tree);

	tree->root    = root;
	tree->count   = count;
	tree->compare = compare;

	return tree;
}

/*
 * Allocate (with gc_alloc) a node holding key/value, the children
 * lr[0] (smaller keys) and lr[1] (larger keys), and the given balance factor.
 */
static AvlNode mkNode(const void *key, const void *value, const AvlNode lr[2], int balance)
{
	struct AvlNode *fresh = gc_alloc(sizeof *fresh);

	fresh->balance = balance;
	fresh->lr[0]   = lr[0];
	fresh->lr[1]   = lr[1];
	fresh->key     = key;
	fresh->value   = value;

	return fresh;
}



/* Allocate a leaf node: no children, balance 0. */
static AvlNode mkSingleton(const void *key, const void *value)
{
	const AvlNode no_children[2] = {NULL, NULL};

	return mkNode(key, value, no_children, 0);
}


/*
 * Conversions between a balance direction (-1 = left-heavy, +1 = right-heavy)
 * and a child index into lr[] (0 = left, 1 = right).
 *
 *   bal(0)   == -1        side(-1) == 0
 *   bal(1)   == +1        side(+1) == 1
 *
 * Any nonzero argument to bal() gives +1.  Any argument to side() other
 * than 1 gives 0.
 *
 * These must stay macros.  Callers write "int side = side(cmp);", and a
 * function named side would already be hidden by the new local variable
 * inside that variable's own initializer.
 */
#define bal(s) ((s) != 0 ? 1 : -1)
#define side(b) ((b) == 1)

/* Collapse an arbitrary comparator result to exactly -1, 0 or +1. */
static int sign(int cmp)
{
	return (cmp > 0) - (cmp < 0);
}


/*** Public functions ***/

/* Create an empty tree whose keys are ordered by compare. */
AVL avl_new(AvlCompare compare)
{
	AVL empty = mkRoot(compare, NULL, 0);

	return empty;
}

/*
 * Return the value stored under key, or NULL if the key is absent.
 * A stored NULL value cannot be told apart from a missing key here;
 * avl_lookup_node can distinguish the two cases.
 */
void *avl_lookup(AVL avl, const void *key)
{
	AvlNode hit;

	#ifndef NDEBUG
	if (avl == NULL)
		error("avl_lookup called with NULL tree");
	#endif

	hit = avl_lookup_node(avl, key);
	if (hit == NULL)
		return NULL;

	return (void *) hit->value;
}

/* Return the node whose key compares equal to key, or NULL if there is none. */
AvlNode avl_lookup_node(AVL avl, const void *key)
{
	#ifndef NDEBUG
	if (avl == NULL)
		error("avl_lookup_node called with NULL tree");
	#endif

	/* lookup() only consults compare and key. */
	AvlContext ctx = { .compare = avl->compare, .key = key };

	return lookup(avl->root, &ctx);
}

/*
 * Return a new tree equal to avl with key mapped to value.  Any existing
 * entry whose key compares equal is replaced.  avl itself is not modified.
 */
AVL avl_insert(AVL avl, const void *key, const void *value)
{
	#ifndef NDEBUG
	if (avl == NULL)
		error("avl_insert called with NULL tree");
	#endif

	AvlContext ctx   = { .compare = avl->compare, .key = key, .value = value };
	AvlNode    root  = insert(avl->root, &ctx);
	size_t     count = avl->count;

	/* Only a key that was not already present grows the tree. */
	if (ctx.replaced == NULL)
		count++;

	return mkRoot(avl->compare, root, count);
}

/*
 * Return a new tree equal to avl without key.  avl itself is not modified.
 * If the key is absent, the result has the same elements and count.
 */
AVL avl_remove(AVL avl, const void *key)
{
	#ifndef NDEBUG
	if (avl == NULL)
		error("avl_remove called with NULL tree");
	#endif

	AvlContext ctx   = { .compare = avl->compare, .key = key };
	AvlNode    root  = remove(avl->root, &ctx);
	size_t     count = avl->count;

	if (ctx.removed != NULL)
		count--;

	return mkRoot(avl->compare, root, count);
}


/*** Implementation ***/

/*
 * Binary search for ctx->key in the subtree rooted at node.
 * Returns the matching node, or NULL if none compares equal.
 */
static AvlNode lookup(AvlNode node, AvlContext *ctx)
{
	while (node != NULL) {
		int cmp = ctx->compare(ctx->key, node->key);

		if (cmp == 0)
			return node;

		/* Smaller keys live under lr[0], larger ones under lr[1]. */
		node = node->lr[cmp > 0 ? 1 : 0];
	}

	return NULL;
}

/*
 * Insert ctx->key/ctx->value into the subtree rooted at node, rebalancing
 * if necessary, and return the new subtree root.  Existing nodes are not
 * modified: the search path is rebuilt from fresh nodes.
 *
 * On return, ctx->replaced is the node whose key compared equal (NULL if
 * the key is new), and ctx->height_change is the change in subtree height.
 */
static AvlNode insert(AvlNode node, AvlContext *ctx)
{
	int     cmp;
	int     dir;
	AvlNode grown;

	/* Ran off the tree: the key is new, so it becomes a leaf here. */
	if (node == NULL) {
		ctx->replaced      = NULL;
		ctx->height_change = 1;
		return mkSingleton(ctx->key, ctx->value);
	}

	cmp = sign(ctx->compare(ctx->key, node->key));

	/* Equal key: overwrite key and value, keep children and balance. */
	if (cmp == 0) {
		ctx->replaced      = node;
		ctx->height_change = 0;
		return mkNode(ctx->key, ctx->value, node->lr, node->balance);
	}

	dir   = side(cmp);
	grown = insert(node->lr[dir], ctx);
	return setChild(node, dir, grown, &ctx->height_change);
}

/*
 * Remove the node matching ctx->key from the subtree rooted at node,
 * rebalancing if necessary, and return the resulting subtree root.
 *
 * On return, ctx->removed is the node whose key matched (NULL if none did),
 * and ctx->height_change is the change in height of the returned subtree.
 *
 * If the key is not present, the original subtree is returned as-is instead
 * of being copied node by node along the search path.  Nothing in this file
 * modifies a node after mkNode() has built it, so sharing it is safe.
 */
static AvlNode remove(AvlNode node, AvlContext *ctx)
{
	int     cmp;
	int     dir;
	AvlNode subtree;

	if (node == NULL) {
		/* Key not found. */
		ctx->removed       = NULL;
		ctx->height_change = 0;
		return NULL;
	}

	cmp = sign(ctx->compare(ctx->key, node->key));

	if (cmp == 0) {
		/* Two children: splice in the predecessor or successor instead. */
		if (node->lr[0] != NULL && node->lr[1] != NULL)
			return replaceWithNeighbor(node, ctx);

		ctx->removed       = node;
		ctx->height_change = -1;

		/* Replace with left or right child, if either exists. */
		return node->lr[0] != NULL ? node->lr[0] : node->lr[1];
	}

	dir     = side(cmp);
	subtree = remove(node->lr[dir], ctx);

	/*
	 * Nothing was removed below (height_change is 0 in this case), so this
	 * subtree is unchanged: share it rather than allocating a copy.
	 */
	if (ctx->removed == NULL)
		return node;

	return setChild(node, dir, subtree, &ctx->height_change);
}

/*
 * Remove node, which must have children on both sides, by substituting
 * the key/value of its in-order predecessor or successor.
 *
 * On return, ctx->removed is node itself and ctx->height_change is the
 * change in height of the returned subtree.
 */
static AvlNode replaceWithNeighbor(AvlNode node, AvlContext *ctx)
{
	/*
	 * Borrow from the left (predecessor) unless the right side is taller.
	 * Shrinking the side that is not shorter keeps |balance| <= 1 here,
	 * so this node never needs a rotation.
	 */
	int from = (node->balance > 0) ? 1 : 0;
	AvlNode rest;
	AvlNode neighbor;
	struct AvlNode stand_in;

	/* Pull out the extremum nearest to node on that side. */
	rest     = removeExtremum(node->lr[from], !from, ctx);
	neighbor = ctx->removed;

	/*
	 * A throw-away stack node standing for the one being removed, with its
	 * key/value swapped for the neighbour's.  setChild() copies these fields
	 * into new nodes and does not keep a pointer to stand_in.
	 */
	stand_in.key     = neighbor->key;
	stand_in.value   = neighbor->value;
	stand_in.balance = node->balance;
	stand_in.lr[0]   = node->lr[0];
	stand_in.lr[1]   = node->lr[1];

	ctx->removed = node;

	return setChild(&stand_in, from, rest, &ctx->height_change);
}

/*
 * Remove the left-most (side == 0) or right-most (side == 1) node of a
 * non-empty subtree (node must not be NULL) and return the new subtree root.
 *
 * On return, ctx->removed is the extracted node and ctx->height_change is
 * the change in subtree height.
 */
static AvlNode removeExtremum(AvlNode node, int side, AvlContext *ctx)
{
	AvlNode rest;

	if (node->lr[side] != NULL) {
		/* Not there yet: descend further, then rebuild this level. */
		rest = removeExtremum(node->lr[side], side, ctx);
		return setChild(node, side, rest, &ctx->height_change);
	}

	/* node is the extremum; its only possible child takes its place. */
	ctx->removed       = node;
	ctx->height_change = -1;
	return node->lr[1 - side];
}

/*
 * Replace node->lr[side] with a new subtree, rebalancing if necessary.
 * Think of this function as a higher-level interface to balance().
 *
 * node itself is never modified.  A new node (or rotated subtree) is built
 * with mkNode()/balance() and returned.
 *
 * Balance convention (the one checkBalances() verifies):
 *   node->balance == height(lr[1]) - height(lr[0]).
 *
 * *height_change is both an input and output argument:
 *   - As input,  it is height(replacement) - height(node->lr[side]).  Must be -1, 0, or 1.
 *   - As output, it is height(node') - height(node).
 */
static AvlNode setChild(AvlNode node, int side, AvlNode replacement, int *height_change)
{
	AvlNode lr[2] = {node->lr[0], node->lr[1]};
	int     sway;
	
	lr[side] = replacement;
	
	/* Child height unchanged: copy with the new child and keep the balance. */
	if (*height_change == 0)
		return mkNode(node->key, node->value, lr, node->balance);
	
	/*
	 * sway is the resulting change to node->balance.  A taller right child
	 * (side 1) raises it and a taller left child (side 0) lowers it; a
	 * shorter child does the reverse.
	 */
	sway = *height_change;
	if (side == 0)
		sway = -sway;
	
	/*
	 * Assuming node->balance is in -1..1, balance + sway reaches +/-2 only
	 * when balance == sway.  That case needs a rotation; otherwise the new
	 * balance is simply recorded.
	 */
	if (node->balance != sway)
		node = mkNode(node->key, node->value, lr, node->balance + sway);
	else
		node = balance(node->key, node->value, lr, sway);
	
	/*
	 * Derive the subtree's own height change from its new balance:
	 *   - child grew (+1):   a new balance of 0 means the growth was absorbed
	 *                        (height unchanged); nonzero means it grew by 1.
	 *   - child shrank (-1): a new balance of 0 means the subtree shrank by 1;
	 *                        nonzero means its height is unchanged.
	 * The test below clears *height_change in the two "unchanged" cases.
	 */
	if ((*height_change == 1) == (node->balance == 0))
		*height_change = 0;
	
	return node;
}

/*
 * Perform tree rotations on an unbalanced node.
 *
 * key/value/lr describe the node being rebuilt; lr already holds the
 * updated child.  bal says which way the node is overweight:
 *
 * bal == -1 means the node's balance is -2 .
 * bal == +1 means the node's balance is +2 .
 *
 * Returns the root of the rotated subtree.  The nodes that move (b, c, and
 * a for the double rotation) are freshly built with mkNode().  The subtrees
 * hanging beneath them are reused unchanged.
 */
static AvlNode balance(const void *key, const void *value, const AvlNode lr[2], int bal)
{
	int side      = side(bal);
	int opposite  = 1 - side;
	
	/* The too-tall child. */
	AvlNode child = lr[side];
	
	AvlNode    b,       c;
	AvlNode lr_b[2], lr_c[2];
	
	if (child->balance != -bal) {
		/* Left-left (side == 0) or right-right (side == 1) */
		/*
		 * Single rotation.  child becomes the subtree root (b), and the old
		 * node (c) becomes its 'opposite' child, adopting child's inner
		 * subtree.  Assuming |child->balance| <= 1, child->balance is
		 * either bal or 0 here; the balance formulas below cover both.
		 */
		
		lr_c[side]     = child->lr[opposite];
		lr_c[opposite] = lr[opposite];
		c = mkNode(key, value, lr_c, bal - child->balance);
		
		lr_b[side]     = child->lr[side];
		lr_b[opposite] = c;
		b = mkNode(child->key, child->value, lr_b, child->balance - bal);
		
		return b;
		
	} else {
		/* Left-right (side == 0) or right-left (side == 1) */
		/*
		 * Double rotation.  child leans the other way, so its inner child
		 * (grandchild) becomes the subtree root (b).  child (a) goes on
		 * 'side' and the old node (c) on 'opposite'.  grandchild's two
		 * subtrees are split between a and c.  Their new balances depend on
		 * which way grandchild leaned, and b always ends up at 0.
		 */
		
		AvlNode grandchild = child->lr[opposite];
		
		AvlNode    a;
		AvlNode lr_a[2];
		
		lr_a[side]     = child->lr[side];
		lr_a[opposite] = grandchild->lr[side];
		a = mkNode(child->key, child->value, lr_a, grandchild->balance == -bal ? bal : 0);
		
		lr_c[side]     = grandchild->lr[opposite];
		lr_c[opposite] = lr[opposite];
		c = mkNode(key, value, lr_c, grandchild->balance == bal ? -bal : 0);
		
		lr_b[side]     = a;
		lr_b[opposite] = c;
		b = mkNode(grandchild->key, grandchild->value, lr_b, 0);
		
		return b;
	}
}


/***************************** Debugging *****************************/

/*
 * Debug helper.  Returns true when all of the following hold:
 *   - every node's balance equals height(right) - height(left) and is in -1..1;
 *   - keys visited by avl_foreach are strictly increasing under avl->compare;
 *   - avl->count equals the actual number of nodes.
 */
bool avl_check_invariants(AVL avl)
{
	int height;

	if (!checkBalances(avl->root, &height))
		return false;
	if (!checkOrder(avl))
		return false;

	return countNode(avl->root) == avl->count;
}

/*
 * Debug helper: walk both trees in FORWARD order and print every key that
 * appears in only one of them.  "< key" marks a key found only in a, and
 * "> key" a key found only in b.  Keys present in both are skipped.
 *
 * Keys are compared with a's comparator and printed with %s, so this
 * assumes the keys are NUL-terminated strings.
 */
void avl_print_diff(AVL a, AVL b)
{
	AvlIter left, right;

	avl_iter_begin(&left, a, FORWARD);
	avl_iter_begin(&right, b, FORWARD);

	while (left.node != NULL || right.node != NULL) {
		int order;

		if (left.node == NULL)
			order = 1;      /* only b has keys left */
		else if (right.node == NULL)
			order = -1;     /* only a has keys left */
		else
			order = a->compare(left.key, right.key);

		if (order < 0) {
			printf("< %s\n", (char*)left.key);
			avl_iter_next(&left);
		} else if (order > 0) {
			printf("> %s\n", (char*)right.key);
			avl_iter_next(&right);
		} else {
			avl_iter_next(&left);
			avl_iter_next(&right);
		}
	}
}

/*
 * Recursively verify that every node's balance equals
 * height(lr[1]) - height(lr[0]) and lies in -1..1.
 * On success, stores the subtree height (0 for an empty subtree) in *height.
 */
static bool checkBalances(AvlNode node, int *height)
{
	int left, right;

	if (node == NULL) {
		*height = 0;
		return true;
	}

	if (!checkBalances(node->lr[0], &left) || !checkBalances(node->lr[1], &right))
		return false;

	if (node->balance < -1 || node->balance > 1 || node->balance != right - left)
		return false;

	*height = 1 + (left > right ? left : right);
	return true;
}

/*
 * Return true when the keys visited by avl_foreach are strictly increasing
 * under avl->compare (which also rules out duplicates), and the number
 * visited equals avl->count.
 */
static bool checkOrder(AVL avl)
{
	AvlIter     it;
	const void *prev = NULL;
	size_t      seen = 0;

	avl_foreach(it, avl) {
		if (seen != 0 && avl->compare(prev, it.key) >= 0)
			return false;
		prev = it.key;
		seen++;
	}

	return seen == avl->count;
}

/* Count the nodes in the subtree rooted at node (0 for NULL). */
static size_t countNode(AvlNode node)
{
	if (node == NULL)
		return 0;

	return countNode(node->lr[0]) + countNode(node->lr[1]) + 1;
}


/************************* Traversal *************************/

/*
 * Position iter on the first node of avl when traversing in direction dir.
 * dir is used directly as a child index: the first node is reached by
 * following lr[dir] from the root as far as possible.  Every node passed on
 * the way is pushed onto iter->stack to be visited later.  For an empty tree,
 * node/key/value are all NULL.
 */
void avl_iter_begin(AvlIter *iter, AVL avl, AvlDirection dir)
{
	AvlNode cur = avl->root;

	iter->direction   = dir;
	iter->stack_index = 0;

	if (cur == NULL) {
		iter->node  = NULL;
		iter->key   = NULL;
		iter->value = NULL;
		return;
	}

	for (; cur->lr[dir] != NULL; cur = cur->lr[dir])
		iter->stack[iter->stack_index++] = cur;

	iter->node  = cur;
	iter->key   = (void*) cur->key;
	iter->value = (void*) cur->value;
}

/*
 * Advance iter to the next node in its direction.  When the traversal is
 * exhausted, node/key/value become NULL.  Calling this on an exhausted
 * iterator does nothing.
 */
void avl_iter_next(AvlIter *iter)
{
	AvlDirection dir = iter->direction;
	AvlNode      next;

	if (iter->node == NULL)
		return;

	next = iter->node->lr[1 - dir];

	if (next != NULL) {
		/* Step once against dir, then as far as possible along dir,
		 * stacking every node passed for later. */
		for (; next->lr[dir] != NULL; next = next->lr[dir])
			iter->stack[iter->stack_index++] = next;
	} else if (iter->stack_index > 0) {
		/* No subtree on that side: resume at the most recently stacked node. */
		next = iter->stack[--iter->stack_index];
	} else {
		/* Nothing left to visit. */
		iter->node  = NULL;
		iter->key   = NULL;
		iter->value = NULL;
		return;
	}

	iter->node  = next;
	iter->key   = (void*) next->key;
	iter->value = (void*) next->value;
}



/*
 * Insert key/value into *avl in place (replacing *avl with the updated tree).
 * Returns true when the key was new, i.e. the element count changed.
 */
bool avl_insert_ptr(AVL *avl, const void *key, const void *value)
{
	size_t before  = avl_count(*avl);
	AVL    updated = avl_insert(*avl, key, value);

	*avl = updated;
	return avl_count(updated) != before;
}

/*
 * Remove key from *avl in place (replacing *avl with the updated tree).
 * Returns true when an element was actually removed (the count changed).
 */
bool avl_remove_ptr(AVL *avl, const void *key)
{
	size_t before  = avl_count(*avl);
	AVL    updated = avl_remove(*avl, key);

	*avl = updated;
	return avl_count(updated) != before;
}