#include "semantic.h"
#include "symtab.h"
#include <string.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

/* Shared Type objects for the two basic types. semantic_analysis() creates
 * them (newBasicType(0) for int, newBasicType(1) for float) before the tree
 * walk; Specifier() and Exp() hand out these same pointers for int/float. */
Type* TYPE_INT;
Type* TYPE_FLOAT;

/* Forward declarations of the analysis routines, each named after the
 * syntax-tree node it handles. A non-NULL `structFields` argument means the
 * declarations are being collected as struct members instead of being
 * inserted into the symbol table; `retType` is the enclosing function's
 * declared return type, used to check return statements. */
void Program(struct TreeNode* node);
void ExtDefList(struct TreeNode* node);
void ExtDef(struct TreeNode* node);
Type* Specifier(struct TreeNode* node);
void ExtDecList(struct TreeNode* node, Type* type);
void CompSt(struct TreeNode* node, Type* retType);
void DefList(struct TreeNode* node, FieldList** structFields);
void Def(struct TreeNode* node, FieldList** structFields);
void DecList(struct TreeNode* node, Type* type, FieldList** structFields);
void Dec(struct TreeNode* node, Type* type, FieldList** structFields);
void StmtList(struct TreeNode* node, Type* retType);
void Stmt(struct TreeNode* node, Type* retType);
Type* Exp(struct TreeNode* node);
void Args(struct TreeNode* node, FieldList* paramDef);

/* Print one diagnostic as "Error type <type> at Line <lineno>: <message>",
 * where <message> is `format` expanded with the variadic arguments, followed
 * by a newline. stdout is flushed immediately so diagnostics are never lost
 * or reordered if the process dies later. */
void semantic_error(int type, int lineno, const char* format, ...) {
    va_list ap;

    printf("Error type %d at Line %d: ", type, lineno);

    va_start(ap, format);
    vprintf(format, ap);
    va_end(ap);

    putchar('\n');
    fflush(stdout);
}

/* Variable-name extraction.
 *
 * Walks a VarDec subtree down to its identifier. Every level shaped
 * `VarDec LB INT RB` wraps the type built so far with
 * newArrayType(current, INT value), starting at the top of the subtree.
 *
 * On success the declared name (the ID token's text) is returned and the
 * resulting type is stored in *outType. If no ID is found, the literal
 * "unknown" is returned and *outType is left untouched.
 */
char* check_VarDec(struct TreeNode* node, Type* type, Type** outType) {
    Type* built = type;
    struct TreeNode* level = node;

    while (level != NULL) {
        int isArrayLevel = level->child_num >= 4 &&
                           strcmp(level->children[1]->name, "LB") == 0;
        if (isArrayLevel) {
            built = newArrayType(built, level->children[2]->value.int_val);
            level = level->children[0];
            continue;
        }
        if (level->child_num >= 1 && strcmp(level->children[0]->name, "ID") == 0) {
            *outType = built;
            return level->children[0]->value.str_val;
        }
        break;
    }
    return "unknown";
}

/* Parameter extraction.
 *
 * Builds the parameter list of a function from a VarList subtree. Each list
 * node holds a ParamDec in children[0] (Specifier in its children[0], VarDec
 * in its children[1]) and, when it has at least three children, the rest of
 * the list in children[2]. Every parameter becomes one FieldList entry, in
 * source order; a parameter whose name cannot be extracted is kept under
 * the name "unknown". Returns NULL for an empty list.
 */
FieldList* get_Params(struct TreeNode* node) {
    FieldList* params = NULL;
    FieldList** slot = &params;   /* where the next entry gets linked in */
    struct TreeNode* list = node;

    while (list != NULL && list->child_num != 0) {
        struct TreeNode* decl = list->children[0];
        if (decl->child_num < 2) break;

        Type* baseType = Specifier(decl->children[0]);
        Type* paramType = NULL;
        char* paramName = check_VarDec(decl->children[1], baseType, &paramType);

        *slot = newFieldList(paramName, paramType);
        slot = &(*slot)->next;

        list = (list->child_num >= 3) ? list->children[2] : NULL;
    }
    return params;
}

void semantic_analysis(struct TreeNode* root) {
    if (!root) return;
    TYPE_INT = newBasicType(0);
    TYPE_FLOAT = newBasicType(1);
    initSymbolTable();
    Program(root);
}

void Program(struct TreeNode* node) {
    if (!node) return;
    if (node->child_num > 0) ExtDefList(node->children[0]);
}

void ExtDefList(struct TreeNode* node) {
    if (!node) return;
    if (node->child_num > 0) ExtDef(node->children[0]);
    if (node->child_num > 1) ExtDefList(node->children[1]);
}

/* ExtDef: one top-level definition. children[0] is the Specifier; the child
 * after it decides the form:
 *   - "ExtDecList": global variables of the specified type;
 *   - "FunDec": a function definition whose body (if present) is the next
 *     child. The function is registered in the current scope, then a new
 *     scope holding the parameters is opened around the body.
 * A function name that is already in the symbol table is reported as
 * error 4, and its parameters and body are not analysed.
 * A duplicated parameter name is reported as error 3.
 */
void ExtDef(struct TreeNode* node) {
    if (node == NULL) return;
    Type* specType = Specifier(node->children[0]);

    for (int idx = 1; idx < node->child_num; ++idx) {
        struct TreeNode* part = node->children[idx];

        if (strcmp(part->name, "ExtDecList") == 0) {
            ExtDecList(part, specType);
            continue;
        }
        if (strcmp(part->name, "FunDec") != 0) continue;

        /* FunDec: children[0] is the function name; with 4 children the
           parameter list sits in children[2]. */
        if (part->child_num < 1) return;
        char* fname = part->children[0]->value.str_val;
        if (lookupSymbol(fname) != NULL) {
            semantic_error(4, part->lineno, "redefine function: %s", fname);
            return;
        }

        Type* sig = newType(FUNCTION);
        sig->u.function.returnType = specType;
        sig->u.function.params = NULL;
        sig->u.function.paramNum = 0;
        if (part->child_num == 4) {
            sig->u.function.params = get_Params(part->children[2]);
            for (FieldList* q = sig->u.function.params; q != NULL; q = q->next) {
                sig->u.function.paramNum++;
            }
        }
        insertSymbol(fname, sig);

        enterScope();
        for (FieldList* q = sig->u.function.params; q != NULL; q = q->next) {
            if (strcmp(q->name, "unknown") == 0) continue;
            if (!insertSymbol(q->name, q->type)) {
                semantic_error(3, part->lineno, "redefine variable: %s", q->name);
            }
        }
        if (idx + 1 < node->child_num) CompSt(node->children[idx + 1], specType);
        exitScope();
        return;
    }
}

/* CompSt: analyse a compound statement's local definitions ("DefList") and
 * statements ("StmtList") in child order. It opens no scope of its own;
 * callers that need one wrap the call in enterScope()/exitScope(). */
void CompSt(struct TreeNode* node, Type* retType) {
    if (node == NULL) return;
    for (int k = 0; k < node->child_num; ++k) {
        struct TreeNode* part = node->children[k];
        if (strcmp(part->name, "DefList") == 0) {
            DefList(part, NULL);
        } else if (strcmp(part->name, "StmtList") == 0) {
            StmtList(part, retType);
        }
    }
}

Type* Specifier(struct TreeNode* node) {
    if (!node || node->child_num == 0) return NULL;
    struct TreeNode* child = node->children[0];
    if (strcmp(child->name, "TYPE") == 0) {
        if (child->value.str_val && strcmp(child->value.str_val, "float") == 0) return TYPE_FLOAT;
        return TYPE_INT;
    } 
    else if (strcmp(child->name, "StructSpecifier") == 0) {
        if (child->child_num == 5) {
            char* name = child->children[1]->value.str_val;
            if (lookupSymbol(name)) {
                semantic_error(15, child->lineno, "redefine the same structure type");
                return NULL;
            }
            Type* t = newStructType(name, NULL);
            insertSymbol(name, t);
            
            FieldList* fields = NULL;
            DefList(child->children[3], &fields);
            t->u.structure.member = fields;
            return t;
        } 
        else {
            char* name = child->children[1]->value.str_val;
            HashNode* sym = lookupSymbol(name);
            if (!sym || sym->type->kind != STRUCTURE) {
                semantic_error(17, child->lineno, "undefined structure: %s", name);
                return NULL;
            }
            return sym->type;
        }
    }
    return NULL;
}

/* ExtDecList: declare each global variable in the list with base type
 * `type`. A list node holds a VarDec in children[0] and, when it has more
 * than one child, the rest of the list in children[2]. A name already in
 * the symbol table is reported as error 3. */
void ExtDecList(struct TreeNode* node, Type* type) {
    for (struct TreeNode* item = node; item != NULL;
         item = (item->child_num > 1) ? item->children[2] : NULL) {
        Type* varType = NULL;
        char* varName = check_VarDec(item->children[0], type, &varType);
        if (strcmp(varName, "unknown") == 0) continue;
        if (!insertSymbol(varName, varType)) {
            semantic_error(3, item->lineno, "redefine variable: %s", varName);
        }
    }
}

/* DefList: analyse each Def in order (children[0] is a Def, children[1] the
 * rest of the list). `structFields` is passed through unchanged: non-NULL
 * when collecting struct members. */
void DefList(struct TreeNode* node, FieldList** structFields) {
    struct TreeNode* rest = node;
    while (rest != NULL) {
        if (rest->child_num > 0) Def(rest->children[0], structFields);
        rest = (rest->child_num > 1) ? rest->children[1] : NULL;
    }
}

/* Def: resolve the Specifier in children[0], then declare every declarator
 * of the DecList in children[1] with that type. */
void Def(struct TreeNode* node, FieldList** structFields) {
    if (node == NULL) return;
    Type* baseType = Specifier(node->children[0]);
    struct TreeNode* declarators = node->children[1];
    DecList(declarators, baseType, structFields);
}

/* DecList: analyse each Dec in order (children[0] is a Dec; when there is
 * more than one child, children[2] is the rest of the list). */
void DecList(struct TreeNode* node, Type* type, FieldList** structFields) {
    for (struct TreeNode* cur = node; cur != NULL;
         cur = (cur->child_num > 1) ? cur->children[2] : NULL) {
        Dec(cur->children[0], type, structFields);
    }
}

/* Dec: one declarator. children[0] is the VarDec; when child_num > 1 an
 * initializer expression sits in children[2].
 *
 * structFields == NULL: declare a variable in the current scope (error 3 on
 *   redefinition) and check the initializer's type against it (error 5).
 * structFields != NULL: append a struct member to *structFields; a duplicate
 *   member name (member not added) or an initializer is error 15.
 */
void Dec(struct TreeNode* node, Type* type, FieldList** structFields) {
    if (node == NULL) return;

    Type* declType = NULL;
    char* declName = check_VarDec(node->children[0], type, &declType);
    if (strcmp(declName, "unknown") == 0) return;
    int hasInit = node->child_num > 1;

    if (structFields == NULL) {
        if (!insertSymbol(declName, declType)) {
            semantic_error(3, node->lineno, "redefine variable: %s", declName);
        }
        if (hasInit) {
            Type* initType = Exp(node->children[2]);
            if (initType != NULL && !checkType(declType, initType)) {
                semantic_error(5, node->lineno, "unmatching type on both sides of assignment");
            }
        }
        return;
    }

    /* One pass both rejects duplicates and finds the empty link at the tail. */
    FieldList** link = structFields;
    for (; *link != NULL; link = &(*link)->next) {
        if (strcmp((*link)->name, declName) == 0) {
            semantic_error(15, node->lineno, "redefined field: %s", declName);
            return;
        }
    }
    *link = newFieldList(declName, declType);
    if (hasInit) semantic_error(15, node->lineno, "cannot initialize field in struct");
}

/* StmtList: analyse each Stmt in order (children[0] is a Stmt, children[1]
 * the rest of the list). */
void StmtList(struct TreeNode* node, Type* retType) {
    struct TreeNode* rest = node;
    while (rest != NULL) {
        if (rest->child_num > 0) Stmt(rest->children[0], retType);
        rest = (rest->child_num > 1) ? rest->children[1] : NULL;
    }
}

/* Stmt: dispatch on the name of the first child.
 *   "Exp"    -> expression statement;
 *   "CompSt" -> nested block, analysed in its own scope;
 *   "RETURN" -> the value in children[1] must match retType (error 8);
 *   "IF"     -> condition children[2], then-branch children[4], optional
 *               else-branch children[6];
 *   "WHILE"  -> condition children[2], body children[4].
 */
void Stmt(struct TreeNode* node, Type* retType) {
    if (node == NULL || node->child_num == 0) return;
    const char* lead = node->children[0]->name;

    if (strcmp(lead, "Exp") == 0) {
        Exp(node->children[0]);
        return;
    }
    if (strcmp(lead, "CompSt") == 0) {
        enterScope();
        CompSt(node->children[0], retType);
        exitScope();
        return;
    }
    if (strcmp(lead, "RETURN") == 0) {
        Type* got = Exp(node->children[1]);
        if (got != NULL && !checkType(retType, got)) {
            semantic_error(8, node->lineno, "return value type mismatches the declared type");
        }
        return;
    }

    int isIf = strcmp(lead, "IF") == 0;
    if (isIf || strcmp(lead, "WHILE") == 0) {
        Exp(node->children[2]);
        Stmt(node->children[4], retType);
        if (isIf && node->child_num > 5) Stmt(node->children[6], retType);
    }
}

/* Nonzero when `t` (non-NULL) is the basic int type: kind BASIC with
 * u.basic == 0, the encoding passed to newBasicType() for TYPE_INT. */
static int exp_is_int_type(Type* t) {
    return t->kind == BASIC && t->u.basic == 0;
}

/* Nonzero when `expr` has the syntactic shape of an assignable expression:
 * a bare ID, an array element (Exp LB Exp RB) or a member access (Exp DOT ID). */
static int exp_is_lvalue(struct TreeNode* expr) {
    if (!expr || expr->child_num < 1 || !expr->children[0]) return 0;
    switch (expr->child_num) {
    case 1: return strcmp(expr->children[0]->name, "ID") == 0;
    case 3: return strcmp(expr->children[1]->name, "DOT") == 0;
    case 4: return strcmp(expr->children[1]->name, "LB") == 0;
    default: return 0;
    }
}

/* Analyse an expression node and return its type.
 *
 * Returns NULL when the type cannot be determined. This normally happens
 * after an error has been reported for the expression or one of its parts.
 * Callers treat NULL as "already reported" and skip further type checks, so
 * one mistake does not produce a cascade of diagnostics.
 *
 * Errors reported: 1 undefined variable, 2 undefined function,
 * 5 assignment type mismatch, 6 rvalue on the left of '=',
 * 7 operand/operator mismatch, 9 argument mismatch (also via Args),
 * 10 indexing a non-array, 11 calling a non-function, 12 non-int index,
 * 13 member access on a non-struct, 14 undefined member.
 *
 * NOTE(review): the operator token names "RELOP", "AND", "OR" and "NOT" are
 * assumed to match the parser's node names. If they differ, such
 * expressions fall through to the generic branch at the end, which still
 * analyses their operands.
 */
Type* Exp(struct TreeNode* node) {
    if (!node || node->child_num == 0) return NULL;

    const char* head = node->children[0]->name;
    /* Name of the second child (operator token, LB, DOT, LP ...), or "". */
    const char* op = (node->child_num >= 2 && node->children[1]) ? node->children[1]->name : "";

    /* Exp -> ID */
    if (node->child_num == 1 && strcmp(head, "ID") == 0) {
        char* name = node->children[0]->value.str_val;
        HashNode* sym = lookupSymbol(name);
        if (!sym) {
            semantic_error(1, node->lineno, "undefined variable: %s", name);
            return NULL;
        }
        return sym->type;
    }
    /* Exp -> INT | FLOAT */
    if (strcmp(head, "INT") == 0) return TYPE_INT;
    if (strcmp(head, "FLOAT") == 0) return TYPE_FLOAT;

    /* Exp -> Exp ASSIGN Exp */
    if (node->child_num == 3 && strcmp(op, "ASSIGN") == 0) {
        struct TreeNode* lhs = node->children[0];
        if (!exp_is_lvalue(lhs)) {
            semantic_error(6, node->lineno, "rvalue on the left side of assignment operator");
            return NULL;
        }
        Type* t1 = Exp(lhs);
        Type* t2 = Exp(node->children[2]);
        if (t1 && t2 && !checkType(t1, t2)) {
            semantic_error(5, node->lineno, "unmatching type on both sides of assignment");
            return NULL;
        }
        return t1;
    }

    /* Exp -> ID LP Args RP | ID LP RP */
    if (node->child_num >= 3 && strcmp(head, "ID") == 0) {
        char* name = node->children[0]->value.str_val;
        HashNode* sym = lookupSymbol(name);
        if (!sym) {
            semantic_error(2, node->lineno, "undefined function: %s", name);
            return NULL;
        }
        /* A variable declared with an unresolved type has a NULL type: it is
           not a function and must not be dereferenced. */
        if (!sym->type || sym->type->kind != FUNCTION) {
            semantic_error(11, node->lineno, "applying function invocation operator on non-function names");
            return NULL;
        }
        if (node->child_num == 4) {
            Args(node->children[2], sym->type->u.function.params);
        } else if (sym->type->u.function.params != NULL) {
            semantic_error(9, node->lineno, "arguments mismatch the declared parameters");
        }
        return sym->type->u.function.returnType;
    }

    /* Exp -> Exp (PLUS | MINUS | MUL | DIV) Exp */
    if (node->child_num == 3 &&
        (strcmp(op, "PLUS") == 0 || strcmp(op, "MINUS") == 0 ||
         strcmp(op, "MUL") == 0 || strcmp(op, "DIV") == 0)) {
        Type* t1 = Exp(node->children[0]);
        Type* t2 = Exp(node->children[2]);
        if (!t1 || !t2) return NULL;
        if (t1->kind != BASIC || t2->kind != BASIC || t1->u.basic != t2->u.basic) {
            semantic_error(7, node->lineno, "unmatching operands");
        }
        return t1;
    }

    /* Exp -> Exp RELOP Exp: operands must be basic types of the same kind;
       the result is always int. */
    if (node->child_num == 3 && strcmp(op, "RELOP") == 0) {
        Type* t1 = Exp(node->children[0]);
        Type* t2 = Exp(node->children[2]);
        if (t1 && t2 &&
            (t1->kind != BASIC || t2->kind != BASIC || t1->u.basic != t2->u.basic)) {
            semantic_error(7, node->lineno, "unmatching operands");
        }
        return TYPE_INT;
    }

    /* Exp -> Exp AND Exp | Exp OR Exp: both operands must be int. */
    if (node->child_num == 3 && (strcmp(op, "AND") == 0 || strcmp(op, "OR") == 0)) {
        Type* t1 = Exp(node->children[0]);
        Type* t2 = Exp(node->children[2]);
        if ((t1 && !exp_is_int_type(t1)) || (t2 && !exp_is_int_type(t2))) {
            semantic_error(7, node->lineno, "unmatching operands");
        }
        return TYPE_INT;
    }

    /* Exp -> MINUS Exp: operand must be a basic type; result has its type. */
    if (node->child_num == 2 && strcmp(head, "MINUS") == 0) {
        Type* t = Exp(node->children[1]);
        if (t && t->kind != BASIC) {
            semantic_error(7, node->lineno, "unmatching operands");
        }
        return t;
    }

    /* Exp -> NOT Exp: operand must be int; result is int. */
    if (node->child_num == 2 && strcmp(head, "NOT") == 0) {
        Type* t = Exp(node->children[1]);
        if (t && !exp_is_int_type(t)) {
            semantic_error(7, node->lineno, "unmatching operands");
        }
        return TYPE_INT;
    }

    /* Exp -> Exp LB Exp RB */
    if (node->child_num == 4 && strcmp(op, "LB") == 0) {
        Type* base = Exp(node->children[0]);
        Type* idx = Exp(node->children[2]);
        if (!base) return NULL;   /* no usable type for the base: don't cascade */
        if (base->kind != ARRAY) {
            semantic_error(10, node->lineno, "applying indexing operator on non-array type variables");
            return NULL;
        }
        /* A NULL index type was already reported inside the index expression. */
        if (idx && !exp_is_int_type(idx)) {
            semantic_error(12, node->lineno, "array indexing with non-integer type expression");
        }
        return base->u.array.elem;
    }

    /* Exp -> Exp DOT ID */
    if (node->child_num == 3 && strcmp(op, "DOT") == 0) {
        Type* base = Exp(node->children[0]);
        if (!base) return NULL;   /* no usable type for the base: don't cascade */
        if (base->kind != STRUCTURE) {
            semantic_error(13, node->lineno, "accessing member of non-structure variable");
            return NULL;
        }
        char* field = node->children[2]->value.str_val;
        for (FieldList* f = base->u.structure.member; f; f = f->next) {
            if (strcmp(f->name, field) == 0) return f->type;
        }
        semantic_error(14, node->lineno, "accessing an undefined structure member");
        return NULL;
    }

    /* Exp -> LP Exp RP */
    if (strcmp(head, "LP") == 0) return Exp(node->children[1]);

    /* Any other shape: the type is unknown, but still analyse the operand
       subexpressions so errors inside them (e.g. undefined variables) are
       reported. */
    for (int i = 0; i < node->child_num; i++) {
        struct TreeNode* sub = node->children[i];
        if (sub && strcmp(sub->name, "Exp") == 0) Exp(sub);
    }
    return NULL;
}

/* Check the actual arguments of a call against the declared parameter list
 * `paramDef`. An Args node holds an Exp in children[0] and, when it has
 * three children, the remaining arguments in children[2].
 *
 * Every argument expression is analysed, even after a mismatch, so errors
 * inside later arguments are still reported. Error 9 is emitted at most once
 * per call:
 *   - at the first argument whose type does not match its parameter;
 *   - at the first surplus argument;
 *   - at the last argument, if parameters remain after the arguments run out.
 * An argument whose type is NULL (already reported) is not compared.
 */
void Args(struct TreeNode* node, FieldList* paramDef) {
    if (!node) return;

    FieldList* param = paramDef;
    struct TreeNode* last = node;
    int reported = 0;

    for (struct TreeNode* cur = node; cur && cur->child_num > 0;
         cur = (cur->child_num == 3) ? cur->children[2] : NULL) {
        last = cur;
        Type* actual = Exp(cur->children[0]);
        if (!reported && (!param || (actual && !checkType(actual, param->type)))) {
            semantic_error(9, cur->lineno, "arguments mismatch the declared parameters");
            reported = 1;
        }
        if (param) param = param->next;
    }

    if (!reported && param) {
        semantic_error(9, last->lineno, "arguments mismatch the declared parameters");
    }
}
