/*
 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "gc.h"

#include "gc/hook.h"
#include "gc/private.h"

/* The collector's single global state (pages, free lists, roots, jump table, ...). */
GC            gc_state;
/* Bitmask of GC_VERBOSE_* flags selecting which diagnostics the collector prints. */
unsigned int  gc_verbose;
/*
 * Checked by gc_main: when this is set and the error code is ER_MEMORY,
 * the error is not passed on.  It is not written anywhere in this file;
 * presumably memoryError sets it after showing its message -- confirm.
 */
volatile bool memoryError_reported;

/* Shorthand used throughout this file for the global state. */
#define gc gc_state

/*
 * Run func() under the garbage collector.
 *
 * Initializes the collector, records the current stack pointer as
 * gc.stack_base, calls func(), and runs gc_finish() from the FINALLY clause.
 *
 * TRY / FINALLY / ENDFINAL / ONERR / ENDTRY / PASS, errCode and the ER_* codes
 * are defined outside this file (presumably the TIGCC error-handling
 * macros -- confirm).
 */
void gc_main(void func(void))
{
	TRY
		gc_init();
		
		TRY
			{
				/* performGC scans the stack from its own stack pointer up to this base. */
				get_stack_ptr(gc.stack_base);
				func();
			}
		FINALLY
			gc_finish();
		ENDFINAL
	
	ONERR
		/* An ER_MEMORY error that has already been reported is not passed on. */
		if (!(errCode == ER_MEMORY && memoryError_reported)) {
			/*
			 * NOTE(review): for ER_NO_MSG the key queue is flushed and we wait
			 * for a key press before passing the error on, presumably so that
			 * text already on screen can be read -- confirm.
			 */
			if (errCode == ER_NO_MSG) {
				GKeyFlush();
				ngetchx();
			}
			PASS;
		}
	ENDTRY
}

/*
 * Allocate a block with room for `size` payload bytes and tag it with `type`.
 *
 * Requests larger than MAX_SIZE are reported through memoryError().
 * Otherwise the request is converted to a block size: payload plus the
 * 4-byte header, rounded up to a multiple of 4, and never less than 12.
 *
 * Returns a pointer to the block's data area.
 */
void *gc_alloc_type(size_t size, BlockType type)
{
	if (size > MAX_SIZE) {
		memoryError("Cannot allocate %lu-byte block", size);
	} else {
		/* Tiny requests all map to the 12-byte minimum block. */
		size = (size < 8) ? 12 : ((size + 7) & ~3);
	}
	
	return gc_alloc_(size, type)->data;
}

/*
 * Obtain a block of exactly-adjusted size `size` (header included; must be a
 * multiple of 4 and at least 12) and claim it with the given type.
 *
 * Strategy:
 *   1. Try the free lists.
 *   2. On a miss, run a collection; if it freed no more than a third of the
 *      heap, also add a page.  Try the free lists again.
 *   3. On another miss, add a page and try one last time.
 *   4. Otherwise report "Out of memory" through memoryError().
 */
Block *gc_alloc_(unsigned short size, BlockType type)
{
	Block *found;
	
	assert(size >= 12 && (size & 3) == 0);
	
	/* Fast path: a suitable free block already exists. */
	found = get_from_free_list(size);
	if (__builtin_expect(found != NULL, 1))
		return claim(found, size, type);
	
	/* Slow path: collect, and grow the heap if the collection was unproductive. */
	performGC();
	if (gc.bytes_freed <= gc.bytes_total / 3)
		add_page();
	
	found = get_from_free_list(size);
	if (found == NULL) {
		/* Still nothing: force another page and retry once more. */
		add_page();
		found = get_from_free_list(size);
	}
	
	if (found != NULL)
		return claim(found, size, type);
	
	memoryError("Out of memory");
}

/*
 * Explicitly release the block whose data area starts at `ptr`.
 *
 * The block header is assumed to sit BLOCK_HEADER_SIZE bytes before `ptr`.
 * If the block is not marked as used, a double free / corruption error is
 * raised through memoryError().  When `run_finalizer` is true, finalize() is
 * called on the block before it is released.  The header bits are then
 * cleared and the block is pushed onto the matching free list.
 */
void gc_free_(void *ptr, bool run_finalizer)
{
	// Can't do this because the jump table
	// is invalidated by block allocations.
	#if 0
	Block *b = lookup(ptr);
	if (b == NULL)
		memoryError("gc_free%s: invalid pointer 0x%05lX",
		            run_finalizer ? "" : "_no_finalize",
		            (unsigned long)ptr);
	#endif
	
	Block *blk = (Block*)((char*)ptr - BLOCK_HEADER_SIZE);
	
	if (blk->used == 0) {
		memoryError("gc_free%s: double free or memory corruption (0x%05lX)",
		             run_finalizer ? "" : "_no_finalize",
		             (unsigned long)ptr);
	}
	
	if (run_finalizer) {
		finalize(blk->data, blk->type);
	}
	
	/* Wipe used/marked/type flags in one go, then recycle the block. */
	blk->bits = 0;
	add_to_free_list(blk);
}

/*
 * Allocate a GC_LABELED block with room for `size` payload bytes and attach
 * `label` to it.
 *
 * The label pointer is stored in the last 4 bytes of the block (located via
 * the block's actual size, which may exceed the requested size), where
 * gc_show_ptr() reads it back.  `label` must not be NULL; the pointer itself
 * is stored, the string is not copied.
 *
 * Requests that do not leave room for the label within MAX_SIZE are reported
 * through memoryError().
 *
 * Returns a pointer to the block's data area.
 */
void *gc_alloc_labeled(size_t size, const char *label)
{
	assert(label != NULL);
	
	/*
	 * Compare against MAX_SIZE - 4 instead of computing size + 4:
	 * the addition wraps around for sizes near the top of size_t,
	 * which would let an enormous request slip through as a tiny one.
	 * (Assumes MAX_SIZE >= 4.)
	 */
	if (size > (size_t)(MAX_SIZE) - 4)
		memoryError("Cannot allocate %lu-byte block", size);
	else if (size < 8)
		size = 16;                /* 4-byte header + 8 bytes minimum + 4-byte label */
	else
		size = (size + 11) & ~3;  /* header + label, rounded up to a multiple of 4 */
	
	Block *b = gc_alloc_(size, GC_LABELED);
	
	/* Use b->size rather than the request: the label lives at the real end of the block. */
	const char **lptr = (const char**)((char*)b + b->size - 4);
	*lptr = label;
	
	return b->data;
}

char *gc_show_ptr(void *ptr, char buffer[64])
{
	if (ptr == NULL) {
		strcpy(buffer, "NULL");
		return buffer;
	}
	
	if ((unsigned long)ptr < 0x400 ||
	    (unsigned long)ptr > 0x3FFFF)
		goto invalid;
	
	if (IN_STACK(ptr)) {
		sprintf(buffer, "(stack)0x%05lX", (unsigned long)ptr);
		return buffer;
	}
	if (IN_BSS(ptr)) {
		sprintf(buffer, "(bss)0x%05lX", (unsigned long)ptr);
		return buffer;
	}
	
	Block      *b = ptr - 4;
	const char *label = NULL;
	
	if (!b->used)
		error("gc_show_ptr: %s0x%05lX is free", label ? label : "", (unsigned long)ptr);
	
	if (b->type == GC_LABELED) {
		label = *(const char**)((char*)b + b->size - 4);
		if (label == NULL)
			error("gc_show_ptr: 0x%05lX has null label", (unsigned long)ptr);
	} else if (b->type == GC_ATOMIC) {
		label = "(char*)";
	}
	
	if (b->size > MAX_SIZE + 4) {
		error("gc_show_ptr: invalid pointer 0x%05lX (size is %u)", (unsigned long)ptr, b->size);
		goto invalid;
	}
	
	sprintf(buffer, "%s0x%05lX", label ? label : "", (unsigned long)ptr);
	
	return buffer;

invalid:
	error("gc_show_ptr: invalid pointer 0x%05lX", (unsigned long)ptr);
}

/*
 * Turn the free block `b` into a used block of the given type.
 *
 * `b` must be free and at least `request` bytes long.  If at least 12 bytes
 * would be left over, the tail is split off into a new free block and `b` is
 * shrunk to exactly `request` bytes; smaller leftovers stay attached to `b`.
 * The data area of `b` is zero-filled before it is returned.
 */
static Block *claim(Block *b, unsigned short request, BlockType type)
{
	assert(b->size >= request && b->used == 0);
	
	b->type = type;
	b->used = 1;
	
	/* Split only when the remainder is big enough to be a block of its own. */
	if (b->size - request >= 12) {
		Block *tail = (Block*)((char*)b + request);
		
		memzero4(tail, BLOCK_HEADER_SIZE);
		tail->size = b->size - request;
		b->size    = request;
		
		add_to_free_list(tail);
	}
	
	memzero4(b->data, b->size - BLOCK_HEADER_SIZE);
	
	return b;
}

static void gc_init(void)
{
	assert(BLOCK_HEADER_SIZE == 4);
	assert((char*)&((Block*)0)->data - (char*)0 == BLOCK_HEADER_SIZE);
	assert((char*)&gc.free_lists[0] == (char*)&gc);
	
	/* Set the GC root to the BSS, but cut out the gc structure itself. */
	char   *bss_start = __ld_bss_start;
	size_t  bss_size  = __ld_bss_size;
	char   *gc_start  = (char*)&gc;
	size_t  gc_size   = sizeof(gc);
	
	assert(gc_start >= bss_start && gc_start + gc_size <= bss_start + bss_size);
	
	gc.root_count = 2;
	gc.roots[0].base = bss_start;
	gc.roots[0].size = gc_start - bss_start;
	gc.roots[1].base = gc_start + gc_size;
	gc.roots[1].size = (bss_start + bss_size) - (gc_start + gc_size);
	
	/* Allocate one page to start out. */
	gc.page_count = 1;
	if (!alloc_page(&gc.pages[0]))
		memoryError("No memory available");
	
	#ifndef NDEBUG
	gc_check_invariants();
	#endif
}

static bool add_page(void)
{
	Page         page;
	unsigned int i;
	
	assert(gc.page_count > 0);
	
	/* Allocate the new page */
	if (gc.page_count >= PAGE_MAX ||
	    !alloc_page(&page))
		return false;
	
	/* Insert it into the page array, keeping start addresses in order. */
	for (i = 0; i < gc.page_count; i++)
		if ((char*)page.first < (char*)gc.pages[i].first)
			break;
	memmove(&gc.pages[i+1], &gc.pages[i], (gc.page_count - i) * sizeof(Page));
	gc.pages[i] = page;
	gc.page_count++;
	
	/* Update sentinel links so that all blocks are in one big list starting at gc.pages[0].first */
	if (i > 0) {
		assert(gc.pages[i-1].first < page.first);
		gc.pages[i-1].sentinel->nextPage = page.first;
	}
	if (i + 1 < gc.page_count) {
		assert(gc.pages[i+1].first > page.first);
		page.sentinel->nextPage = gc.pages[i+1].first;
	}
	
	return true;
}

/*
 * Allocate a new heap page from the system heap and describe it in *page.
 *
 * The page size is the largest block the system can provide while leaving
 * RESERVE_SIZE + 2 bytes available, capped at MAX_PAGE_SIZE.  Nothing is
 * allocated unless at least MIN_PAGE_SIZE can be obtained.
 *
 * Resulting layout:
 *   - page->first:    one free block covering everything except the last 8 bytes
 *   - page->sentinel: the last 8 bytes; zeroed header and nextPage == NULL
 *                     (other code in this file recognises it by size == 0)
 *
 * The initial free block is added to the free lists.  The page is NOT
 * entered into gc.pages; that is the caller's job.
 *
 * Returns true on success, false if memory is insufficient.
 *
 * HeapMax / HeapAvail / HeapAllocHigh / HeapDeref are external
 * (presumably the TI-OS heap API -- confirm).
 */
static bool alloc_page(Page *page)
{
	size_t  max, avail, size;
	
	max   = HeapMax();
	avail = HeapAvail();
	if (max   < MIN_PAGE_SIZE ||
	    avail < (RESERVE_SIZE + 2) + MIN_PAGE_SIZE)
		return false;
	
	/* Take as much as possible, but leave the reserve untouched. */
	size = min(max, avail - (RESERVE_SIZE + 2));
	
	if (size > MAX_PAGE_SIZE)
		size = MAX_PAGE_SIZE;
	
	HANDLE handle = HeapAllocHigh(size);
	if (handle == H_NULL)
		return false;
	
	char *ptr = HeapDeref(handle);
	
	/* Align the start up to 4 bytes and trim the usable size down to a multiple of 4. */
	page->first    = (Block*) (((unsigned long)ptr + 3) & ~3);
	page->size     = (size - ((char*)page->first - ptr)) & ~3;
	page->sentinel = (Block*) ((char*)page->first + page->size - 8);
	page->handle   = handle;
	
	assert((char*)page->first + page->size <= ptr + size);
	assert((char*)page->sentinel - (char*)page->first >= 12);
	assert((char*)page->sentinel + 8 == (char*)page->first + page->size);
	assert((char*)&page->sentinel->nextPage + 4 == (char*)page->first + page->size);
	
	/* One big free block spanning the page up to the sentinel. */
	memzero4(page->first, BLOCK_HEADER_SIZE);
	page->first->size = page->size - 8;
	
	/* Sentinel: zeroed header, no following page yet. */
	memzero4(page->sentinel, BLOCK_HEADER_SIZE);
	page->sentinel->nextPage = NULL;
	
	/* Make allocator aware of the new memory. */
	add_to_free_list(page->first);
	
	return true;
}

/*
 * Push the free block `b` onto the head of the free list selected by its size.
 * `b` must be unused, at least 12 bytes long, and a multiple of 4 in size.
 */
static void add_to_free_list(Block *b)
{
	assert((b->size & 3) == 0 && b->size >= 12);
	assert(b->used == 0);
	
	unsigned int bucket = FREE_LIST_INDEX(b->size);
	
	b->nextFree           = gc.free_lists[bucket];
	gc.free_lists[bucket] = b;
}

/*
 * Remove and return a free block of at least `size` bytes, or NULL if the
 * free lists hold none.
 *
 * The search starts at the list matching `size` and moves on to lists of
 * larger blocks until a non-empty one is found.  Within that list the first
 * block that is large enough is taken; a too-small block can only occur in
 * the final list, which holds blocks of mixed sizes.
 */
static Block *get_from_free_list(unsigned short size)
{
	assert((size & 3) == 0 && size >= 12);
	
	/* Locate the first non-empty list at or above the requested size class. */
	unsigned int bucket;
	for (bucket = FREE_LIST_INDEX(size); bucket < FREE_LIST_COUNT; bucket++) {
		if (gc.free_lists[bucket] != NULL)
			break;
	}
	
	if (bucket >= FREE_LIST_COUNT)
		return NULL;
	
	/* Walk that list, keeping a pointer to the link that leads to the candidate. */
	Block **link = &gc.free_lists[bucket];
	
	for (Block *cand = *link; cand != NULL; cand = *link) {
		if (cand->size >= size) {
			/* Unlink and hand it out. */
			*link = cand->nextFree;
			
			assert(cand->used == 0);
			
			return cand;
		}
		link = &cand->nextFree;
	}
	
	return NULL;
}

/*
 * Rebuild gc.jump_table, which lookup() uses to find the block containing
 * an arbitrary address.
 *
 * Each entry covers a window of (1 << JUMP_TABLE_SHIFT) bytes of address
 * space.  An entry holds the near-pointer of a used block touching that
 * window, or 0 if no used block does.  Blocks are walked in chain order from
 * gc.pages[0].first, and the table index `i` only ever moves forward, so when
 * several used blocks touch the same window the first one visited keeps the
 * entry.
 */
static void build_jump_table(void)
{
	Block          *b = gc.pages[0].first;
	unsigned short  i = 0;
	
	if (gc_verbose & GC_VERBOSE_TASKS)
		puts("Building jump table...\n");
	
	do {
		if (b->used) {
			/* Windows spanned by this block, from its data start to its end address (inclusive). */
			unsigned short start = (unsigned long)b->data >> JUMP_TABLE_SHIFT,
			               end   = ((unsigned long)b + b->size) >> JUMP_TABLE_SHIFT;
			
			assert(start <= end && end < JUMP_TABLE_COUNT);
			
			/* Windows skipped since the previous used block contain no used data. */
			while (i < start)
				gc.jump_table[i++] = 0;
			/* Remaining windows of this block that are not already claimed. */
			while (i <= end)
				gc.jump_table[i++] = toNear(b);
		}
		
		b = next_block(b);
	} while (b != NULL);
	
	/* Everything past the last used block is empty. */
	while (i < JUMP_TABLE_COUNT)
		gc.jump_table[i++] = 0;
}

/*
 * Find the block that contains `ptr`, using the jump table built by
 * build_jump_table().
 *
 * Returns the block if `ptr` lies between the start of its data area and its
 * end address (inclusive), or NULL if `ptr` is at or above HEAP_END, its
 * window has no table entry, or no matching block is found within six steps.
 * The returned block is not checked for being in use; mark() does that.
 *
 * Each step() tests the current block and otherwise advances to the next
 * one, following the sentinel's nextPage link at the end of a page.
 */
static Block *lookup(void *ptr)
{
	if ((char*)ptr >= HEAP_END)
		return NULL;
	
	/* Start from the block recorded for the window containing ptr. */
	Block *b = fromNear(gc.jump_table[(unsigned long)ptr >> JUMP_TABLE_SHIFT]);
	
	#define step()                           \
		if (b == NULL || (char*)ptr < b->data) \
			return NULL;                         \
		                                       \
		if ((char*)ptr <= (char*)b + b->size)  \
			return b;                            \
		                                       \
		b = (Block*)((char*)b + b->size);      \
		if (b->size == 0)                      \
			b = b->nextPage;
	
	/*
	 * A window of size N can span at most (N+19) / 12 blocks.
	 * The worst case scenario is when 12-byte blocks are packed together.
	 * For example, we could have a situation like this:
	 *
	 * XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	 *
	 * *---12345678*---12345678*---12345678*---12345678*---12345678*---
	 *
	 * This window spans 6 blocks.  The * is the end of a block,
	 * which counts as being "inside" the block.
	 */
	
	/*
	unsigned int i = ((1 << JUMP_TABLE_SHIFT) + 19) / 12 - 1;
	do {
		step();
	} while (i--);
	*/
	
	/*
	 * Manually unrolled form of the loop above.
	 * NOTE(review): six steps matches the 64-byte window drawn in the comment;
	 * confirm against the actual value of JUMP_TABLE_SHIFT.
	 */
	step();
	step();
	step();
	step();
	step();
	step();
	
	#undef step
	
	return NULL;
}

/*
 * Treat `ptr` as a possible heap reference.  If it falls inside a used block
 * that has not been marked yet, mark the block; blocks whose type is nonzero
 * are additionally queued so that their contents get scanned.
 */
static void mark(void *ptr)
{
	Block *hit = lookup(ptr);
	
	if (hit == NULL)
		return;
	
	/* Already visited, or not a live block: nothing to do. */
	if (hit->marked || !hit->used)
		return;
	
	hit->marked = 1;
	
	if (hit->type != 0)
		defer_scan(hit);
}

/*
 * Schedule block `b` for scanning.
 *
 * Normally the block is pushed onto the gc.defer stack.  Once that stack is
 * full, the collector switches to overflow mode (gc.need_to_scan): blocks are
 * then flagged individually via their need_to_scan bit, to be found later by
 * a heap walk in flush().
 */
static void defer_scan(Block *b)
{
	/* Enter overflow mode the first time the stack runs out of room. */
	if (!gc.need_to_scan && gc.defer_count >= DEFER_MAX) {
		if (gc_verbose & GC_VERBOSE_NEED_TO_SCAN)
			puts("GC stack overflow: rescanning heap\n");
		gc.need_to_scan = true;
	}
	
	if (!gc.need_to_scan) {
		gc.defer[gc.defer_count++] = toNear(b);
		return;
	}
	
	b->need_to_scan = true;
}

/*
 * Scan the contents of block `b` for references: every pointer-sized value
 * found at a 2-byte step within the block's data is passed to mark().
 */
static void scan(Block *b)
{
	char *cursor = b->data;
	char *limit  = (char*)b + b->size - 2;
	
	while (cursor < limit) {
		mark(*(void**)cursor);
		cursor += 2;
	}
}

/*
 * Finish the mark phase by scanning every block that was scheduled through
 * defer_scan().
 *
 * The defer stack is drained first.  If the stack overflowed at some point
 * (gc.need_to_scan), the heap is then walked from its first block looking for
 * blocks flagged need_to_scan.  Each such block is scanned, after which
 * control jumps back to drain whatever that scan pushed.
 *
 * `b` keeps its value across the jump, so the walk resumes where it left off.
 * It restarts from the first block only if the scan set gc.need_to_scan again.
 * When no overflow has happened `b` stays NULL and the walk is skipped.
 */
static void flush(void)
{
	Block *b = NULL;
	
flush_stack:
	/* Scanning may push further blocks; keep going until the stack is empty. */
	while (gc.defer_count > 0)
		scan(fromNear(gc.defer[--gc.defer_count]));
	
	/* A (new) overflow means flagged blocks may be anywhere: restart the walk. */
	if (gc.need_to_scan) {
		gc.need_to_scan = false;
		b = gc.pages[0].first;
	}
	
	for (; b != NULL; b = next_block(b)) {
		if (b->need_to_scan) {
			b->need_to_scan = false;
			scan(b);
			goto flush_stack;
		}
	}
}

/*
 * Conservatively scan the memory range [base, base + size) for references.
 * The start is rounded up to an even address; from there every pointer-sized
 * value that fits entirely inside the range, taken at 2-byte steps, is passed
 * to mark().
 */
static void mark_range(void *base, size_t size)
{
	char *limit  = (char*)base + size;
	char *cursor = (char*) ((unsigned long)((char*)base + 1) & ~1);
	
	if (gc_verbose & GC_VERBOSE_RANGES)
		printf("Scanning 0x%05lX..0x%05lX\n", (unsigned long)cursor, (unsigned long)limit);
	
	while (cursor + 4 <= limit) {
		mark(*(void**)cursor);
		cursor += 2;
	}
}

/*
 * Sweep phase: walk every block of the heap once and
 *   - finalize and free used blocks that are neither marked nor flagged dont_gc,
 *   - merge runs of adjacent free blocks within a page into single blocks,
 *   - rebuild all free lists from scratch,
 *   - clear the marked bit of every block visited,
 *   - record byte statistics in gc.bytes_freed / bytes_retained / bytes_total.
 *
 * bytes_freed counts only blocks released by this sweep, not blocks that were
 * already free.
 */
static void sweep(void)
{
	Block *b         = gc.pages[0].first;
	Block *freeClump = NULL; /* accumulator block.  Adjacent free blocks are clumped together. */
	unsigned long bytes_freed = 0, bytes_retained = 0, bytes_total = 0;
	
	if (gc_verbose & GC_VERBOSE_TASKS)
		puts("Sweeping...\n");
	
	#ifndef NDEBUG
	gc_check_invariants();
	#endif
	
	/* We'll be rebuilding the free lists from scratch. */
	memzero4(gc.free_lists, sizeof(gc.free_lists));
	
	for (;;) {
		bytes_total += b->size;
		
		/* Unreachable and collectable: run its finalizer and release it. */
		if (b->used && !b->marked && !b->dont_gc) {
			bytes_freed += b->size;
			finalize(b->data, b->type);
			b->used = 0;
		}
		
		/*
		 * If this block is used, punctuate the free clump (if present).
		 * Otherwise, add this block to it, or start a new free clump with it.
		 */
		if (b->used) {
			bytes_retained += b->size;
			if (freeClump) {
				add_to_free_list(freeClump);
				freeClump = NULL;
			}
		} else {
			/* Only the clump's size grows; b's own size is left intact for the step below. */
			if (freeClump)
				freeClump->size += b->size;
			else
				freeClump = b;
		}
		
		/* While we're at it, clear marked bits of blocks. */
		b->marked = 0;
		
		/*
		 * Proceed to the next block.
		 *
		 * If we encounter a page break or end of heap, punctuate the free clump.
		 * We don't want to create blocks spanning the no man's land between two pages.
		 */
		b = (Block*)((char*)b + b->size);
		if (b->size == 0) {
			/* Size 0 identifies the page sentinel; follow it to the next page. */
			b = b->nextPage;
			
			if (freeClump != NULL) {
				add_to_free_list(freeClump);
				freeClump = NULL;
			}
			
			if (b == NULL)
				break;
		}
	}
	
	if (gc_verbose & GC_VERBOSE_STATS)
		printf("Freed %lu bytes\nRetained %lu bytes\n", bytes_freed, bytes_retained);
	if (gc_verbose & GC_VERBOSE_TERSE) {
		char buffer[64];
		sprintf(buffer, "%lu freed / %lu used / %lu total", bytes_freed, bytes_retained, bytes_total);
		ST_helpMsg(buffer);
	}
	
	/* gc_alloc_ reads these to decide whether the heap should grow. */
	gc.bytes_freed    = bytes_freed;
	gc.bytes_retained = bytes_retained;
	gc.bytes_total    = bytes_total;
}

/*
 * Run one full garbage collection cycle.
 *
 * Steps: rebuild the jump table, conservatively mark from the saved
 * registers, the stack (current stack pointer up to gc.stack_base) and the
 * registered roots, scan all deferred blocks, then sweep.
 *
 * Re-entering this function while a collection is in progress is reported as
 * a bug through memoryError().  Diagnostic output is controlled by the
 * GC_VERBOSE_* bits in gc_verbose.
 */
void performGC(void)
{
	void *sp;
	get_stack_ptr(sp);
	
	// setjmp gives us the registers that parent calls might still be using.
	JMP_BUF regs;
	setjmp(regs);
	/* NOTE(review): presumably cleared so the saved program counter is not scanned as a heap reference -- confirm. */
	regs->PC = 0;
	
	if (gc.collecting)
		memoryError("performGC entered twice (bug)");
	gc.collecting = true;
	
	#ifndef NDEBUG
	gc_check_invariants();
	#endif
	
	/* The stack base recorded by gc_main must lie above the current stack pointer. */
	assert((char*)sp < (char*)gc.stack_base);
	assert(gc.defer_count == 0 && gc.need_to_scan == false);
	
	if ((gc_verbose & ~GC_VERBOSE_TERSE) != 0) {
		putchar('\n');
		
		if (gc_verbose & GC_VERBOSE_BORDER)
			puts("---------------------\n");
		if (gc_verbose & GC_VERBOSE_COLLECTING)
			puts("Garbage collecting...\n");
	}
	
	/* mark() relies on lookup(), which needs an up-to-date jump table. */
	build_jump_table();
	
	if (gc_verbose & GC_VERBOSE_TASKS)
		puts("Marking...\n");
	
	/* Root set: saved registers, the live part of the stack, then the registered roots. */
	mark_range(regs, sizeof(regs));
	mark_range(sp, (char*)gc.stack_base - (char*)sp);
	
	unsigned int i;
	for (i = 0; i < gc.root_count; i++)
		mark_range(gc.roots[i].base, gc.roots[i].size);
	
	/* Scan everything reachable from the roots, then reclaim what was not reached. */
	flush();
	sweep();
	
	if (gc_verbose & GC_VERBOSE_BORDER)
		puts("---------------------\n");
	
	#ifndef NDEBUG
	gc_check_invariants();
	#endif
	
	if (gc_verbose & GC_VERBOSE_PAUSE)
		ngetchx();
	
	gc.collecting = false;
}

/*
 * Shut the collector down: call finalize() on every block that is still in
 * use, then release every page's heap handle.
 */
static void gc_finish(void)
{
	Block *cur = gc.pages[0].first;
	
	/* Walk the whole block chain, across pages, finalizing live blocks. */
	do {
		if (cur->used) {
			finalize(cur->data, cur->type);
		}
		cur = next_block(cur);
	} while (cur != NULL);
	
	/* Hand the pages back to the system heap. */
	for (unsigned int p = 0; p < gc.page_count; p++) {
		HeapFree(gc.pages[p].handle);
	}
}

/*
 * Debug aid: verify the heap's structural invariants and raise an error
 * (via error()) naming the first violated condition.
 *
 * Checks performed:
 *   - there is at least one page, and gc.pages is sorted by start address;
 *   - every page lies within its heap handle's memory, ends with its
 *     sentinel, and its blocks tile it exactly;
 *   - the sentinel links chain the pages together in array order;
 *   - a block is on a free list if and only if it is unused, and blocks on
 *     the fixed-size lists have exactly the size that list stands for.
 *
 * The in_free_list bit of every block is used as scratch space: pass 1
 * clears it, pass 2 sets it for each block found on a free list.
 */
void gc_check_invariants(void)
{
	/* Raise an error quoting the failed predicate's source text. */
	#define errorIf(pred) do { \
			if (pred) \
				error("GC check: %s", #pred); \
		} while (0)
	
	Page         *page;
	Block        *b;
	unsigned int  i;
	
	errorIf(gc.page_count < 1);
	errorIf(gc.pages[0].first == NULL);
	errorIf(gc.pages[0].size == 0);
	errorIf(gc.pages[0].sentinel == NULL);
	errorIf(gc.pages[0].handle == H_NULL);
	
	/* Pages must be strictly ordered by start address. */
	for (i = 1; i < gc.page_count; i++)
		errorIf((char*)gc.pages[i-1].first >= (char*)gc.pages[i].first);
	
	/* Pass 1: Make sure page array and block links agree with each other. */
	page = &gc.pages[0];
	b    = page->first;
	for (i = 0; i < gc.page_count; i++, page++) {
		char *page_start = (char*)page->first;
		char *page_end   = page_start + page->size;
		
		errorIf(page->handle == H_NULL);
		
		/* The page must lie inside the memory owned by its handle. */
		errorIf(page_start >= page_end);
		errorIf(page_start < (char*)HeapDeref(page->handle));
		errorIf(page_end > (char*)HeapDeref(page->handle) + HeapSize(page->handle));
		
		errorIf((char*)page->sentinel + 8 != page_end);
		
		/* b arrived here through the previous page's sentinel link (or is the first block). */
		errorIf(b == NULL);
		errorIf(b->size == 0);
		errorIf((char*)b != page_start);
		
		/* Walk the page's blocks up to the sentinel (size 0). */
		while (b->size != 0) {
			errorIf(b->size > MAX_SIZE + 12);
			
			b->in_free_list = 0;
			
			b = (Block*)((char*)b + b->size);
			errorIf((char*)b + 8 > page_end);
		}
		
		/* The walk must end exactly on the sentinel. */
		errorIf((char*)b + 8 != page_end);
		
		b = b->nextPage;
	}
	/* The last page's sentinel must terminate the chain. */
	errorIf(b != NULL);
	
	/*
	 * Pass 2: Make sure that, for every block, it is in the free list if and only if it is free.
	 *         Also make sure that blocks in the fixed-width free lists all have the right size.
	 */
	for (i = 0; i < FREE_LIST_COUNT; i++) {
		for (b = gc.free_lists[i]; b != NULL; b = b->nextFree) {
			if (b->used)
				error("GC check: used block in free list");
			/* List i holds blocks of exactly i*4 + 12 bytes; the last list holds that size and up. */
			if (i < FREE_LIST_COUNT - 1)
				errorIf(b->size != i*4 + 12);
			else
				errorIf(b->size < i*4 + 12);
			b->in_free_list = 1;
		}
	}
	
	/* Compare each block's used flag against the scratch bit set above. */
	b = gc.pages[0].first;
	do {
		errorIf(b->used && b->in_free_list);
		errorIf(!b->used && !b->in_free_list);
		
		b = next_block(b);
	} while (b != NULL);
	
	#undef errorIf
}

/*
 * Copy the first `len` bytes of `str` into a freshly allocated block obtained
 * from gc_alloc_atomic() and append a terminating NUL.
 * Returns the new string.
 */
char *gc_strdup_len(const char *str, size_t len)
{
	char *copy = gc_alloc_atomic(len + 1);
	
	memcpy(copy, str, len);
	copy[len] = '\0';
	
	return copy;
}
