#include "symtab.h"
#include <string.h>
#include <stdlib.h>
#include <stdio.h>

/* Bucket heads of the chained hash table. hash_pjw() reduces its result
 * modulo HASH_SIZE, so within this file only indices [0, HASH_SIZE) are used
 * as buckets; the extra 100 slots appear to be unused padding. */
HashNode* symbolTable[HASH_SIZE + 100];
/* scopeStack[d] heads a list (linked through stack_next) of every symbol
 * inserted while currentDepth == d; exitScope() walks it to unlink those
 * symbols from their buckets. Capacity is 2048 nesting levels. */
HashNode* scopeStack[2048]; 
/* Depth of the innermost open scope; initSymbolTable() resets it to 0 and
 * enterScope()/exitScope() increment/decrement it. */
int currentDepth = 0;

/*
 * hash_pjw - map an identifier to a bucket index in symbolTable.
 *
 * Shift-and-fold string hash: each character is added after shifting the
 * accumulator left by two bits, and any bits outside the HASH_SIZE mask are
 * folded back in and masked off.
 * NOTE(review): the `& ~HASH_SIZE` / `& HASH_SIZE` fold only acts as a bit
 * mask when HASH_SIZE has the form 2^k - 1 (e.g. 0x3fff) -- confirm in
 * symtab.h. The final modulo keeps the result in range either way.
 *
 * Returns 0 for a NULL name, otherwise a value in [0, HASH_SIZE).
 */
unsigned int hash_pjw(char* name) {
    unsigned int h = 0;
    unsigned int overflow;

    if (name == NULL) {
        return 0;
    }

    while (*name != '\0') {
        h = (h << 2) + *name++;
        overflow = h & ~HASH_SIZE;
        if (overflow != 0) {
            h = (h ^ (overflow >> 12)) & HASH_SIZE;
        }
    }
    return h % HASH_SIZE;
}

/*
 * initSymbolTable - reset the symbol table to an empty state.
 *
 * Clears every bucket head (including the padding slots), empties the
 * depth-0 scope chain and makes depth 0 the current scope. Nodes that were
 * reachable from the previous state are not freed.
 */
void initSymbolTable() {
    for (size_t bucket = 0; bucket < sizeof symbolTable / sizeof symbolTable[0]; ++bucket) {
        symbolTable[bucket] = NULL;
    }
    scopeStack[0] = NULL;
    currentDepth = 0;
}

/*
 * enterScope - open a new nested scope.
 *
 * The depth counter is always incremented, so it stays paired with the
 * decrement in exitScope(). The new scope's chain is reset only while the
 * depth still fits inside scopeStack (2048 slots); deeper scopes get no slot.
 * NOTE(review): code that indexes scopeStack[currentDepth] must check the
 * bound itself.
 */
void enterScope() {
    ++currentDepth;
    if (currentDepth < 2048) {
        scopeStack[currentDepth] = NULL;
    }
}

/*
 * exitScope - close the innermost scope.
 *
 * Every symbol recorded on scopeStack[currentDepth] is unlinked from its hash
 * bucket, so lookups fall back to any outer declaration it shadowed. The depth
 * counter is then decremented. When no scope is open (currentDepth < 0) this
 * is a no-op.
 *
 * enterScope() gives no scopeStack slot to depths beyond the array's
 * capacity. For those depths only the counter is adjusted, because reading
 * scopeStack[currentDepth] would be out of bounds.
 *
 * Unlinked nodes are deliberately NOT freed. Freeing them previously crashed
 * the program, presumably because callers still hold pointers returned by
 * lookupSymbol(). The nodes are therefore leaked rather than risking a
 * use-after-free.
 */
void exitScope() {
    if (currentDepth < 0) return;

    const int scopeCapacity = (int)(sizeof scopeStack / sizeof scopeStack[0]);
    const size_t bucketCount = sizeof symbolTable / sizeof symbolTable[0];

    if (currentDepth < scopeCapacity) {
        for (HashNode* node = scopeStack[currentDepth]; node != NULL; node = node->stack_next) {
            unsigned int index = hash_pjw(node->name);
            if (index >= bucketCount) continue;  /* defensive: hash out of range */

            /* Walk the bucket via the link that points at each node, so the
             * head and interior cases are handled uniformly. */
            HashNode** link = &symbolTable[index];
            while (*link != NULL && *link != node) {
                link = &(*link)->next;
            }
            if (*link == node) {
                *link = node->next;
            }
        }
        /* The chain's nodes are no longer in any bucket; drop the stale list. */
        scopeStack[currentDepth] = NULL;
    }
    currentDepth--;
}

/*
 * insertSymbol - declare `name` with `type` in the innermost open scope.
 *
 * @name: identifier to declare. The pointer itself is stored, not copied, so
 *        the string must outlive the symbol-table entry.
 * @type: type to associate with the symbol; stored as-is.
 *
 * The new node is pushed onto the front of its hash bucket, so it shadows
 * outer declarations of the same name. It is also pushed onto the current
 * scope's chain, so exitScope() can unlink it later.
 *
 * Returns 1 on success. Returns 0 when:
 *   - name is NULL;
 *   - name is already declared at the current depth;
 *   - the current depth has no scopeStack slot (negative, or deeper than the
 *     array can track); writing there would be out of bounds;
 *   - allocation fails.
 * The last two cases also print a diagnostic to stderr.
 */
int insertSymbol(char* name, Type* type) {
    if (!name) return 0;
    if (inCurrentScope(name)) return 0;

    const int scopeCapacity = (int)(sizeof scopeStack / sizeof scopeStack[0]);
    if (currentDepth < 0 || currentDepth >= scopeCapacity) {
        fprintf(stderr, "symtab: cannot insert '%s' at scope depth %d (supported 0..%d)\n",
                name, currentDepth, scopeCapacity - 1);
        return 0;
    }

    unsigned int index = hash_pjw(name);
    if (index >= sizeof symbolTable / sizeof symbolTable[0]) index = 0;  /* defensive */

    /* calloc zeroes any HashNode fields not explicitly set below. */
    HashNode* node = calloc(1, sizeof *node);
    if (node == NULL) {
        fprintf(stderr, "symtab: out of memory inserting '%s'\n", name);
        return 0;
    }
    node->name = name;
    node->type = type;
    node->depth = currentDepth;

    node->next = symbolTable[index];
    symbolTable[index] = node;

    node->stack_next = scopeStack[currentDepth];
    scopeStack[currentDepth] = node;
    return 1;
}

/*
 * lookupSymbol - find the visible declaration of `name`.
 *
 * insertSymbol() pushes new nodes onto the front of a bucket. The first
 * matching node is therefore the most recently inserted one still linked,
 * which normally is the innermost declaration shadowing any outer ones.
 *
 * Returns that node, or NULL if name is NULL or not declared.
 */
HashNode* lookupSymbol(char* name) {
    if (name == NULL) {
        return NULL;
    }

    unsigned int bucket = hash_pjw(name);
    if (bucket >= HASH_SIZE + 100) {
        return NULL;
    }

    for (HashNode* cur = symbolTable[bucket]; cur != NULL; cur = cur->next) {
        if (!strcmp(cur->name, name)) {
            return cur;
        }
    }
    return NULL;
}

/*
 * inCurrentScope - report whether `name` is declared at the current depth.
 *
 * Scans the entire bucket for a node whose name matches and whose recorded
 * depth equals currentDepth. Outer declarations of the same name do not count.
 *
 * Returns 1 if such a node exists, 0 otherwise (including for a NULL name).
 */
int inCurrentScope(char* name) {
    if (name == NULL) {
        return 0;
    }

    unsigned int bucket = hash_pjw(name);
    if (bucket >= HASH_SIZE + 100) {
        return 0;
    }

    for (HashNode* cur = symbolTable[bucket]; cur != NULL; cur = cur->next) {
        if (strcmp(cur->name, name) == 0 && cur->depth == currentDepth) {
            return 1;
        }
    }
    return 0;
}
