/*
 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "lambda.h"

/*
 * Report a syntax error at the character *s.  If e is nonzero it names the
 * character that was expected instead.
 *
 * Reaching the end of the input (*s == '\0') is reported explicitly: passing
 * a NUL through '%c' would either emit a raw NUL byte or truncate the
 * message, depending on how error() formats its output.
 */
static void unexpected(const char *s, char e)
{
	if (*s == '\0') {
		if (e)
			error("Syntax error: unexpected end of input (expected '%c')", e);
		else
			error("Syntax error: unexpected end of input");
	} else if (e) {
		error("Syntax error: unexpected '%c' (expected '%c') at: %s", *s, e, s);
	} else {
		error("Syntax error: unexpected '%c' at: %s", *s, s);
	}
}

/*
 * Skip leading whitespace, then require the literal text str at the cursor
 * and advance *sptr past it.  A mismatch is reported as a syntax error that
 * quotes the input from the point where str was expected.  The caller's
 * cursor is written back only once, after the comparison loop.
 */
static void expect(const char **sptr, const char *str)
{
	const char *cur  = *sptr;
	const char *want = str;
	const char *start;
	
	skipSpace(&cur);
	start = cur;
	
	while (*want != '\0') {
		if (*cur != *want)
			error("Syntax error: expected %s at: %s", str, start);
		cur++;
		want++;
	}
	
	*sptr = cur;
}

/* Build an E_VAR node for the variable name var (the pointer is stored, not copied). */
static Expression *mkVarExpr(char *var)
{
	Expression *node = mkExpression(E_VAR);
	
	node->var = var;
	return node;
}

/* Wrap an already-constructed Value in an E_VALUE expression node. */
static Expression *mkValueExpr(Value *value)
{
	Expression *node = mkExpression(E_VALUE);
	
	node->value = value;
	return node;
}

/* Build an E_LAMBDA node that binds the name var within the body expr. */
static Expression *mkLambda(char *var, Expression *expr)
{
	Expression *node = mkExpression(E_LAMBDA);
	
	node->lambda.expr = expr;
	node->lambda.var  = var;
	return node;
}

/* Build an E_AP node applying the function expression f to the argument x. */
static Expression *mkAp(Expression *f, Expression *x)
{
	Expression *node = mkExpression(E_AP);
	
	node->ap.x = x;
	node->ap.f = f;
	return node;
}

/* Build an E_IF_THEN node: pred selects between on_true and on_false. */
static Expression *mkIfThen(Expression *pred, Expression *on_true, Expression *on_false)
{
	Expression *node = mkExpression(E_IF_THEN);
	
	node->if_then.on_false = on_false;
	node->if_then.on_true  = on_true;
	node->if_then.pred     = pred;
	return node;
}

/*
 * parseLowest is the entry point used wherever a complete expression is
 * expected (lambda bodies, list elements, if/then/else branches, top level).
 * It is an alias for parseInfix, which additionally rejects operator
 * sections that are not enclosed in parentheses.
 */
#define parseLowest parseInfix

/* Forward declarations: parsePrimary and parseSection/parseInfix are mutually recursive. */
static Expression *parseSection(const char **sptr);
static Expression *parseInfix(const char **sptr);


/*
 * Parse one primary term at *sptr, advancing *sptr past it.
 *
 * A primary is one of:
 *   ()                   unit value
 *   ( section )          parenthesized expression or operator section
 *   \x body              lambda (see the condition below for the lead-in)
 *   [] / [a, b, ...]     nil / list literal
 *   "..."  '...'         string / character literal
 *   if p then t else f   conditional
 *   integer literal
 *   name                 variable (operators are resolved later by parseSection)
 *
 * Returns NULL, consuming nothing but leading whitespace, when the next
 * token ends the current term sequence: ')', ']', ',', end of input, or the
 * keywords "then"/"else".  Any other unrecognized character is reported as
 * a syntax error.
 */
static Expression *parsePrimary(const char **sptr)
{
	/* s aliases the caller's cursor, so every advance is seen by the caller. */
	#define     s (*sptr)
	Expression *e;
	char       *var;
	
	skipSpace(&s);
	
	/* Terminators of a term sequence: hand control back to parseSection. */
	if (*s == ')' || *s == ']' || *s == ',' || *s == '\0'
	    || startsWithKeyword(s, "then")
	    || startsWithKeyword(s, "else"))
		return NULL;
	
	/* "()" is unit; otherwise a parenthesized expression or section. */
	if (*s == '(') {
		s++;
		skipSpace(&s);
		
		if (*s == ')') {
			s++;
			return mkValueExpr(mkUnit());
		}
		
		e = parseSection(&s);
		expect(&s, ")");
		
		/*
		 * parseSection tags a top-level operator tree E_INFIX_CLOSED.
		 * Retagging it E_SECTION makes it an ordinary operand to the
		 * enclosing parseSection and exempts it from parseInfix's
		 * "section must be enclosed in parenthesis" check.
		 */
		assert(e->tag != E_INFIX);
		if (e->tag == E_INFIX_CLOSED)
			e->tag = E_SECTION;
		
		return e;
	}
	
	/*
	 * Lambda: a '\' that is not followed by a symbol character (a '\'
	 * followed by one is not treated as a lambda here), or the single byte
	 * 0x89 (NOTE(review): presumably a lambda glyph in some 8-bit encoding;
	 * confirm).  The bound name must start like a name; the body extends as
	 * far as parseLowest will go.
	 */
	if ((*s == '\\'   && !is_symbol(*(s+1))) ||
	    (*s == '\x89')) {
		const char *lambda_s = s;
		
		s++;
		skipSpace(&s);
		
		if (!is_name_start(*s))
			error("Syntax error: expected variable name after \\ at: %s", lambda_s);
		
		var = parseSymbol(&s);
		e   = parseLowest(&s);
		return mkLambda(var, e);
	}
	
	/* List literal: "[]" is nil; otherwise comma-separated expressions. */
	if (*s == '[') {
		s++;
		skipSpace(&s);
		
		if (*s == ']') {
			s++;
			return mkValueExpr(nil_v);
		}
		
		/*
		 * NOTE(review): assuming listCons() prepends, elements end up in
		 * reverse source order; confirm the consumer of E_LIST expects that.
		 */
		List(Expression) *list = NULL;
		
		for (;;) {
			list = listCons(parseLowest(&s), list);
			
			skipSpace(&s);
			if (*s == ']') {
				s++;
				break;
			}
			if (*s != ',')
				unexpected(s, ',');
			s++;
		}
		
		e = mkExpression(E_LIST);
		e->list = list;
		return e;
	}
	
	if (*s == '\"')
		return mkValueExpr(parseString(&s));
	
	if (*s == '\'')
		return mkValueExpr(parseChar(&s));
	
	/* if <pred> then <on_true> else <on_false>; both keywords are required. */
	if (startsWithKeyword(s, "if")) {
		s += 2;
		
		Expression *pred = parseLowest(&s);
		
		skipSpace(&s);
		if (!startsWithKeyword(s, "then"))
			error("Missing 'then' after 'if'");
		s += 4;
	
		Expression *on_true = parseLowest(&s);
		
		skipSpace(&s);
		if (!startsWithKeyword(s, "else"))
			error("Missing 'else' in if-then statement");
		s += 4;
		
		Expression *on_false = parseLowest(&s);
		
		return mkIfThen(pred, on_true, on_false);
	}
	
	if (is_int_start(*s)) {
		Value *v = mkValue(V_INT);
		v->i = parseInt(&s);
		return mkValueExpr(v);
	}
	
	if (is_name_start(*s)) {
		var = parseSymbol(&s);
		return mkVarExpr(var);
	}
	
	/*
	 * Nothing above matched.  Report through the shared helper so the
	 * message matches the other syntax errors, and return explicitly so
	 * the non-void function never falls off its end.
	 */
	unexpected(s, 0);
	return NULL;  /* only reached if unexpected()/error() returns */
	
	#undef s
}

/*
 * Append node x to the infix tree rooted at *root.  prev is the node that
 * was appended immediately before x, or NULL if x is the first node.
 *
 * - First node: x becomes the root.
 * - x is an operator (E_INFIX): starting at prev, walk up through its
 *   ancestors for as long as compareFixity(ancestor, x) > 0.  x is then
 *   inserted above the last node visited (c):
 *   - c becomes x's left operand;
 *   - x takes c's place as the right child of the ancestor where the walk
 *     stopped, or becomes the new root if the walk ran past the top.
 * - x is an operand and prev is an operator: x becomes prev's right operand.
 *
 * Reports a syntax error for two adjacent operators, or for two operators
 * whose fixities compare equal (compareFixity() == 0).  Reports an internal
 * error for two adjacent operands; parseSection folds those into
 * applications before calling here.
 */
static void pushInfix(Expression **root, Expression *prev, Expression *x)
{
	if (prev == NULL) {
		/* First node of the expression. */
		x->parent = NULL;
		*root = x;
	} else if (x->tag == E_INFIX) {
		Expression *p, *c;
		
		/* An operator must follow an operand. */
		if (prev->tag == E_INFIX)
			error("Syntax error: nothing between operators %s and %s", prev->infix.op, x->infix.op);
		
		/* Climb from prev while the ancestor operator p compares greater than x. */
		for (c = prev, p = prev->parent; p != NULL; c = p, p = p->parent) {
			int cmp;
			
			/* Every ancestor in this tree must be an operator node. */
			assert(p->tag == E_INFIX);
			cmp = compareFixity(p->infix.fixity, x->infix.fixity);
			
			/* Fixities that compare equal cannot be ordered. */
			if (cmp == 0)
				error("Precedence parsing error: cannot mix %s and %s in the same infix expression",
				      p->infix.op, x->infix.op);
			
			/* x belongs below p: stop climbing. */
			if (cmp < 0)
				break;
		}
		
		if (p == NULL) {
			/* Climbed past the root: x becomes the new root. */
			x->parent = NULL;
			*root = x;
		} else {
			/*
			 * Nodes are only ever attached on the right (infix.b), so c,
			 * reached by climbing from the last appended node, must be
			 * p's right operand.  x takes its place.
			 */
			assert(p->infix.b == c);
			x->parent = p;
			p->infix.b = x;
		}
		
		/* The displaced subtree c becomes x's left operand. */
		c->parent = x;
		x->infix.a = c;
	} else if (prev->tag == E_INFIX) {
		/* Operand after an operator: fill the operator's right slot. */
		x->parent = prev;
		prev->infix.b = x;
	} else {
		error("Internal error: pushInfix() passed two consecutive nullary terms");
	}
}

/*
 * Verify that every operator node in the subtree rooted at expr has both
 * operands.  A missing operand is reported for the first offending operator
 * found: the node itself is checked before its children, and the left
 * subtree before the right.  Descent stops at any node that is not E_INFIX,
 * such as operands and parenthesized sections.
 */
static void checkInfix(Expression *expr)
{
	while (expr->tag == E_INFIX) {
		if (expr->infix.a == NULL)
			error("Missing left operand to operator %s", expr->infix.op);
		if (expr->infix.b == NULL)
			error("Missing right operand to operator %s", expr->infix.op);
		
		/* Left subtree by recursion; right subtree by continuing the loop. */
		checkInfix(expr->infix.a);
		expr = expr->infix.b;
	}
}

/*
 * Parse primaries up to the next terminator (see parsePrimary) and combine
 * them into one expression.
 *
 * - Adjacent operands are combined left-associatively into applications
 *   (f x y becomes (f x) y).
 * - A variable whose fixity is not PREFIX is turned into an operator node
 *   and placed in the tree by pushInfix.
 *
 * If the result is an operator tree, its root is retagged E_INFIX_CLOSED.
 * The root may lack an operand (a section such as "+ 1"), but every
 * operator below it must have both.  An empty sequence is an error.
 */
static Expression *parseSection(const char **sptr)
{
	Expression *tree    = NULL;  /* root of the tree built so far */
	Expression *last    = NULL;  /* node most recently handed to pushInfix */
	Expression *pending = NULL;  /* application chain not yet pushed */
	Expression *term;
	
	for (;;) {
		term = parsePrimary(sptr);
		if (term == NULL)
			break;
		
		/* A variable naming a non-prefix operator becomes an infix node. */
		if (term->tag == E_VAR) {
			char         *name = term->var;
			const Fixity *fix  = getFixity(name);
			
			if (fix->fixity != PREFIX) {
				term->tag          = E_INFIX;
				term->infix.op     = name;
				term->infix.fixity = fix;
				term->infix.a      = NULL;
				term->infix.b      = NULL;
			}
		}
		
		/* Operands accumulate into a left-nested application. */
		if (term->tag != E_INFIX) {
			pending = (pending == NULL) ? term : mkAp(pending, term);
			continue;
		}
		
		/* An operator first flushes the pending operand, then is pushed itself. */
		if (pending != NULL) {
			pushInfix(&tree, last, pending);
			last    = pending;
			pending = NULL;
		}
		pushInfix(&tree, last, term);
		last = term;
	}
	
	if (pending != NULL)
		pushInfix(&tree, last, pending);
	
	if (tree == NULL)
		error("Missing expression");
	
	if (tree->tag != E_INFIX)
		return tree;
	
	/* The root operator may be a section; its subtrees must be complete. */
	tree->tag = E_INFIX_CLOSED;
	if (tree->infix.a != NULL)
		checkInfix(tree->infix.a);
	if (tree->infix.b != NULL)
		checkInfix(tree->infix.b);
	
	return tree;
}

/*
 * Parse a complete expression with parseSection, rejecting a top-level
 * operator tree whose root is missing an operand.  Such an incomplete
 * section is only accepted inside parentheses, where parsePrimary
 * retags it E_SECTION.
 */
static Expression *parseInfix(const char **sptr)
{
	Expression *result = parseSection(sptr);
	int         open_section = result->tag == E_INFIX_CLOSED
	                           && (result->infix.a == NULL || result->infix.b == NULL);
	
	if (open_section)
		error("Syntax error: A section must be enclosed in parenthesis");
	
	return result;
}

/*
 * Final pass over a parsed tree.  It recursively visits every node and
 * retags the parse-time operator tags E_INFIX_CLOSED and E_SECTION as
 * plain E_INFIX.  Operator nodes may have a NULL operand (sections); those
 * are skipped.  Any other, unrecognized tag is reported as a corrupt
 * structure.
 */
static void cleanupExpression(Expression *expr)
{
	List(Expression) *node;
	
	switch (expr->tag) {
		case E_INFIX:
		case E_INFIX_CLOSED:
		case E_SECTION:
			expr->tag = E_INFIX;
			if (expr->infix.a != NULL)
				cleanupExpression(expr->infix.a);
			if (expr->infix.b != NULL)
				cleanupExpression(expr->infix.b);
			return;
		
		case E_AP:
			cleanupExpression(expr->ap.f);
			cleanupExpression(expr->ap.x);
			return;
		
		case E_LAMBDA:
			cleanupExpression(expr->lambda.expr);
			return;
		
		case E_IF_THEN:
			cleanupExpression(expr->if_then.pred);
			cleanupExpression(expr->if_then.on_true);
			cleanupExpression(expr->if_then.on_false);
			return;
		
		case E_LIST:
			for (node = expr->list; node != NULL; node = node->next)
				cleanupExpression(node->item);
			return;
		
		case E_VAR:
		case E_VALUE:
			return;
		
		default:
			error("Corrupt data structure passed to cleanupExpression");
	}
}

/*
 * Parse the complete expression in the NUL-terminated string s.  Anything
 * other than whitespace after the expression is a syntax error.  The
 * returned tree has its parse-time operator tags normalized by
 * cleanupExpression.
 */
Expression *parseExpression(const char *s)
{
	const char *cursor = s;
	Expression *tree   = parseLowest(&cursor);
	
	skipSpace(&cursor);
	if (*cursor != '\0')
		unexpected(cursor, 0);
	
	cleanupExpression(tree);
	return tree;
}
