From a116c24e8a53e11860e29ac8b189d52b8c9ec192 Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Wed, 14 Aug 2024 13:31:11 +0000 Subject: [PATCH] ready to debug --- include/ast/astnode_visitor.h | 1 + include/ast/scope.hpp | 29 ++++ include/tools.h | 26 +++- src/ast/semanticvisitor.cpp | 245 ++++++++++++++++++++++++++++------ src/semantic/semantic.cpp | 32 ++++- src/semantic/visitor.cpp | 2 +- 6 files changed, 293 insertions(+), 42 deletions(-) diff --git a/include/ast/astnode_visitor.h b/include/ast/astnode_visitor.h index c0c30ae..5f91a61 100644 --- a/include/ast/astnode_visitor.h +++ b/include/ast/astnode_visitor.h @@ -57,6 +57,7 @@ class ASTNodeVirturalVisitor : public ASTNodeVisitorBase { class ASTSemanticCheckVisitor : public ASTNodeVirturalVisitor { bool is_in_func; FunctionSchema cur_func_schema; + std::string cur_class_name; size_t loop_level; std::shared_ptr global_scope; friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); diff --git a/include/ast/scope.hpp b/include/ast/scope.hpp index 71ecc4e..7cb38b6 100644 --- a/include/ast/scope.hpp +++ b/include/ast/scope.hpp @@ -71,6 +71,7 @@ class FunctionScope : public ScopeBase { friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); friend class Visitor; friend class ASTSemanticCheckVisitor; + friend class GlobalScope; FunctionSchema schema; bool add_variable([[maybe_unused]] const std::string &name, [[maybe_unused]] const ExprTypeInfo &type) override { throw std::runtime_error("FunctionScope does not support add_variable"); @@ -99,6 +100,8 @@ class FunctionScope : public ScopeBase { }; class ClassDefScope : public ScopeBase { friend class Visitor; + friend class GlobalScope; + friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); std::unordered_map member_variables; std::unordered_map> member_functions; bool add_variable(const std::string &name, const ExprTypeInfo &type) override { @@ -156,6 +159,32 @@ class GlobalScope : public ScopeBase { std::unordered_map global_variables; std::unordered_map> global_functions; std::unordered_map> classes; + ExprTypeInfo FetchClassMemberVariable(const std::string &class_name, const std::string &var_name) { + if (classes.find(class_name) == classes.end()) { + throw SemanticError("Class " + class_name + " not found", 1); + } + auto ptr = classes[class_name]; + if (ptr->member_variables.find(var_name) == ptr->member_variables.end()) { + throw SemanticError("Variable " + var_name + " not found in class " + class_name, 1); + } + return ptr->member_variables[var_name]; + } + FunctionSchema FetchClassMemberFunction(const std::string &class_name, const std::string &func_name) { + if (classes.find(class_name) == classes.end()) { + throw SemanticError("Class " + class_name + " not found", 1); + } + auto ptr = classes[class_name]; + if (ptr->member_functions.find(func_name) == ptr->member_functions.end()) { + throw SemanticError("Function " + func_name + " not found in class " + class_name, 1); + } + return ptr->member_functions[func_name]->schema; + } + FunctionSchema FetchFunction(const std::string &name) { + if (global_functions.find(name) == global_functions.end()) { + throw SemanticError("Function " + name + " not found", 1); + } + return global_functions[name]->schema; + } bool add_class(const std::string &name, std::shared_ptr ptr) { if (IsKeyWord(name)) return false; if (classes.find(name) != classes.end()) { diff --git a/include/tools.h b/include/tools.h index e01f4a1..b1d67b0 100644 --- a/include/tools.h +++ b/include/tools.h @@ -67,4 +67,28 @@ struct ArrayType { IdentifierType basetype; size_t level; }; -using ExprTypeInfo = std::variant; \ No newline at end of file +inline bool operator==(const ArrayType &l, const ArrayType &r) { + return l.has_base_type == r.has_base_type && l.basetype == r.basetype && l.level == r.level; +} +using ExprTypeInfo = std::variant; + +inline bool operator==(const ExprTypeInfo &l, const ExprTypeInfo &r) { + if (std::holds_alternative(r) && std::get(r) == "null") { + if (std::holds_alternative(l)) { + std::string l_type = std::get(l); + if (l_type == "int" || l_type == "bool" || l_type == "string") { + return false; + } + return true; + } + return true; + } + if (std::holds_alternative(l)) { + return std::holds_alternative(r) && std::get(l) == std::get(r); + } + if (std::holds_alternative(l)) { + return std::holds_alternative(r) && std::get(l) == std::get(r); + } + throw std::runtime_error("something strange happened"); +} +inline bool operator!=(const ExprTypeInfo &l, const ExprTypeInfo &r) { return !(l == r); } \ No newline at end of file diff --git a/src/ast/semanticvisitor.cpp b/src/ast/semanticvisitor.cpp index 2e770c4..d92d29e 100644 --- a/src/ast/semanticvisitor.cpp +++ b/src/ast/semanticvisitor.cpp @@ -1,3 +1,4 @@ +#include #include #include "astnode_visitor.h" #include "scope.hpp" @@ -17,6 +18,7 @@ void ASTSemanticCheckVisitor::ActuralVisit(ClassDef_ASTNode *node) { for (auto var : node->member_variables) { var->accept(this); } + cur_class_name = node->class_name; for (auto ch : node->sorted_children) { if (std::dynamic_pointer_cast(ch) == nullptr) { ch->accept(this); @@ -59,7 +61,10 @@ void ASTSemanticCheckVisitor::ActuralVisit(ExprStatement_ASTNode *node) { node-> void ASTSemanticCheckVisitor::ActuralVisit(IfStatement_ASTNode *node) { node->condition->accept(this); - // TODO type check + const static ExprTypeInfo standard = "bool"; + if (node->condition->expr_type_info != standard) { + throw SemanticError("If condition must be bool", 1); + } node->if_clause->accept(this); if (node->has_else_clause) { node->else_clause->accept(this); @@ -68,7 +73,10 @@ void ASTSemanticCheckVisitor::ActuralVisit(IfStatement_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(WhileStatement_ASTNode *node) { node->condition->accept(this); - // TODO type check + const static ExprTypeInfo standard = "bool"; + if (node->condition->expr_type_info != standard) { + throw SemanticError("While condition must be bool", 1); + } loop_level++; node->loop_body->accept(this); loop_level--; @@ -80,7 +88,10 @@ void ASTSemanticCheckVisitor::ActuralVisit(ForStatement_ASTNode *node) { } if (node->condition) { node->condition->accept(this); - // TODO type check + const static ExprTypeInfo standard = "bool"; + if (node->condition->expr_type_info != standard) { + throw SemanticError("For condition must be bool", 1); + } } if (node->update) { node->update->accept(this); @@ -93,7 +104,16 @@ void ASTSemanticCheckVisitor::ActuralVisit(ForStatement_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(JmpStatement_ASTNode *node) { if (loop_level == 0 && node->jmp_type > 0) throw SemanticError("Jump statement outside loop", 1); if (node->jmp_type == 0) { - // TODO : return type check + if (node->return_value) { + node->return_value->accept(this); + if (node->return_value->expr_type_info != cur_func_schema.return_type) { + throw SemanticError("Return type mismatch", 1); + } + } else { + if (cur_func_schema.return_type != "void") { + throw SemanticError("Return type mismatch", 1); + } + } } } @@ -105,16 +125,20 @@ void ASTSemanticCheckVisitor::ActuralVisit(SuiteStatement_ASTNode *node) { // Expression AST Nodes void ASTSemanticCheckVisitor::ActuralVisit(NewArrayExpr_ASTNode *node) { - // TODO: Implement this method for (size_t i = 0; i < node->dim_size.size(); i++) { if (node->dim_size[i]) { node->dim_size[i]->accept(this); - // TODO type check + const static ExprTypeInfo standard = "int"; + if (node->dim_size[i]->expr_type_info != standard) { + throw SemanticError("Array dimension must be int", 1); + } } } if (node->has_initial_value) { node->initial_value->accept(this); - // TODO type check + if (node->expr_type_info != node->initial_value->expr_type_info) { + throw SemanticError("Array type mismatch", 1); + } } } @@ -125,82 +149,161 @@ void ASTSemanticCheckVisitor::ActuralVisit(NewExpr_ASTNode *node) {} void ASTSemanticCheckVisitor::ActuralVisit(AccessExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: member check - if (node->is_function) { - // TODO arg number check - for (auto arg : node->arguments) { - arg->accept(this); - // TODO type check + if (std::holds_alternative(node->base->expr_type_info)) { + if (node->is_function && node->member == "size" && node->arguments.size() == 0) { + node->expr_type_info = "int"; + return; } + throw SemanticError("Access on non-class", 1); + } + std::string base_type; + try { + base_type = std::get(node->base->expr_type_info); + } catch (...) { + throw SemanticError("Access on non-class", 1); + } + if (node->is_function) { + auto schema = global_scope->FetchClassMemberFunction(base_type, node->member); + if (schema.arguments.size() != node->arguments.size()) { + throw SemanticError("Argument number mismatch", 1); + } + for (auto &arg : node->arguments) { + arg->accept(this); + if (arg->expr_type_info != schema.arguments[&arg - &node->arguments[0]].first) { + throw SemanticError("Argument type mismatch", 1); + } + } + } else { + node->expr_type_info = global_scope->FetchClassMemberVariable(base_type, node->member); } } void ASTSemanticCheckVisitor::ActuralVisit(IndexExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: dimension check + if (std::holds_alternative(node->base->expr_type_info)) { + throw SemanticError("Indexing on non-array", 1); + } + const auto &tp = std::get(node->base->expr_type_info); + if (tp.level < node->indices.size()) { + throw SemanticError("Indexing on non-array", 1); + } for (auto idx : node->indices) { idx->accept(this); - // TODO type check + const static ExprTypeInfo standard = "int"; + if (idx->expr_type_info != standard) { + throw SemanticError("Index must be int", 1); + } + } + if (tp.level == node->indices.size()) { + node->expr_type_info = tp.basetype; + } else { + node->expr_type_info = ArrayType{true, tp.basetype, tp.level - node->indices.size()}; } } void ASTSemanticCheckVisitor::ActuralVisit(SuffixExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->base->expr_type_info != standard) { + throw SemanticError("Suffix operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(PrefixExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->base->expr_type_info != standard) { + throw SemanticError("Prefix operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(OppositeExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->base->expr_type_info != standard) { + throw SemanticError("Opposite operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(LNotExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "bool"; + if (node->base->expr_type_info != standard) { + throw SemanticError("Logical not operation on non-bool", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(BNotExpr_ASTNode *node) { // TODO: Implement this method node->base->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->base->expr_type_info != standard) { + throw SemanticError("Bitwise not operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(MDMExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("MDM operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(PMExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo STRING = "string"; + if (node->left->expr_type_info == STRING && node->right->expr_type_info == STRING) { + node->expr_type_info = STRING; + return; + } + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("PM operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(RLExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("RL operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(GGLLExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo STRING = "string"; + if (node->left->expr_type_info == STRING && node->right->expr_type_info == STRING) { + node->expr_type_info = "bool"; + return; + } + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("GGLL operation on non-int", 1); + } + node->expr_type_info = "bool"; } void ASTSemanticCheckVisitor::ActuralVisit(NEExpr_ASTNode *node) { @@ -208,63 +311,91 @@ void ASTSemanticCheckVisitor::ActuralVisit(NEExpr_ASTNode *node) { node->left->accept(this); node->right->accept(this); // TODO: type check + node->expr_type_info = "bool"; } void ASTSemanticCheckVisitor::ActuralVisit(BAndExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("BAnd operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(BXorExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("BXor operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(BOrExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "int"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("BOr operation on non-int", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(LAndExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "bool"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("LAnd operation on non-bool", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(LOrExpr_ASTNode *node) { // TODO: Implement this method node->left->accept(this); node->right->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "bool"; + if (node->left->expr_type_info != standard || node->right->expr_type_info != standard) { + throw SemanticError("LOr operation on non-bool", 1); + } + node->expr_type_info = standard; } void ASTSemanticCheckVisitor::ActuralVisit(TernaryExpr_ASTNode *node) { // TODO: Implement this method node->condition->accept(this); - // TODO: type check + const static ExprTypeInfo standard = "bool"; node->src1->accept(this); node->src2->accept(this); - // TODO: type check + if (node->src1->expr_type_info != node->src2->expr_type_info) { + throw SemanticError("Ternary operation on different type", 1); + } + node->expr_type_info = node->src1->expr_type_info; } void ASTSemanticCheckVisitor::ActuralVisit(AssignExpr_ASTNode *node) { // TODO: Implement this method node->dest->accept(this); node->src->accept(this); - // TODO: check type and assignability + if (node->dest->expr_type_info != node->src->expr_type_info) { + throw SemanticError("Assign operation on different type", 1); + } } -void ASTSemanticCheckVisitor::ActuralVisit(ThisExpr_ASTNode *node) {} +void ASTSemanticCheckVisitor::ActuralVisit(ThisExpr_ASTNode *node) { + // TODO + node->expr_type_info = cur_class_name; +} void ASTSemanticCheckVisitor::ActuralVisit(ParenExpr_ASTNode *node) { - // TODO: Implement this method node->expr->accept(this); node->expr_type_info = node->expr->expr_type_info; } @@ -278,18 +409,32 @@ void ASTSemanticCheckVisitor::ActuralVisit(IDExpr_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(FunctionCallExpr_ASTNode *node) { // TODO: Implement this method // TODO: check function existence and arg number - for (auto arg : node->arguments) { - arg->accept(this); - // TODO: type check + auto schema = global_scope->FetchFunction(node->func_name); + std::cerr << "function to call is " << node->func_name << std::endl; + if (schema.arguments.size() != node->arguments.size()) { + throw SemanticError("Argument number mismatch", 1); } + for (auto &arg : node->arguments) { + arg->accept(this); + int idx = &arg - &node->arguments[0]; // for debug; + if (arg->expr_type_info != schema.arguments[&arg - &node->arguments[0]].first) { + throw SemanticError("Argument type mismatch", 1); + } + } + node->expr_type_info = schema.return_type; } void ASTSemanticCheckVisitor::ActuralVisit(FormattedStringExpr_ASTNode *node) { // TODO: Implement this method for (auto arg : node->exprs) { arg->accept(this); - // TODO: type check + const static ExprTypeInfo valid_types[] = {"int", "bool", "string"}; + if (arg->expr_type_info != valid_types[0] && arg->expr_type_info != valid_types[1] && + arg->expr_type_info != valid_types[2]) { + throw SemanticError("Invalid type in formatted string", 1); + } } + node->expr_type_info = "string"; } void ASTSemanticCheckVisitor::ActuralVisit(ConstantExpr_ASTNode *node) { @@ -298,6 +443,28 @@ void ASTSemanticCheckVisitor::ActuralVisit(ConstantExpr_ASTNode *node) { if (std::holds_alternative(node->expr_type_info)) { return; } else { - ; + std::string base_type; + bool found_base_type = false; + size_t found_level = 0; + std::function search = [&](ConstantExpr_ASTNode *node, int depth) { + if (std::holds_alternative(node->expr_type_info)) { + if (!found_base_type) { + found_base_type = true; + base_type = std::get(node->expr_type_info); + found_level = depth; + } else { + if (base_type != std::get(node->expr_type_info) || found_level != depth) { + throw SemanticError("Invalid const array type", 1); + } + } + } else { + const auto &sub_nodes = std::get>>(node->value); + for (auto sub_node : sub_nodes) { + search(sub_node.get(), depth + 1); + } + } + }; + search(node, 0); + node->expr_type_info = ArrayType{true, base_type, found_level}; } } diff --git a/src/semantic/semantic.cpp b/src/semantic/semantic.cpp index 2a199bf..89bbdea 100644 --- a/src/semantic/semantic.cpp +++ b/src/semantic/semantic.cpp @@ -36,7 +36,37 @@ std::shared_ptr CheckAndDecorate(std::shared_ptrclasses["string"] = nullptr; // TODO: add string class + global_scope->classes["string"] = std::make_shared(); + global_scope->classes["string"]->member_functions["length"] = std::make_shared(); + global_scope->classes["string"]->member_functions["length"]->schema.return_type = "int"; + global_scope->classes["string"]->member_functions["substring"] = std::make_shared(); + global_scope->classes["string"]->member_functions["substring"]->schema.return_type = "string"; + global_scope->classes["string"]->member_functions["substring"]->schema.arguments = {{"int", "left"}, + {"int", "right"}}; + global_scope->classes["string"]->member_functions["parseInt"] = std::make_shared(); + global_scope->classes["string"]->member_functions["parseInt"]->schema.return_type = "int"; + global_scope->classes["string"]->member_functions["ord"] = std::make_shared(); + global_scope->classes["string"]->member_functions["ord"]->schema.return_type = "int"; + global_scope->classes["string"]->member_functions["ord"]->schema.arguments = {{"int", "pos"}}; + global_scope->global_functions["print"] = std::make_shared(); + global_scope->global_functions["print"]->schema.return_type = "void"; + global_scope->global_functions["print"]->schema.arguments = {{"string", "str"}}; + global_scope->global_functions["println"] = std::make_shared(); + global_scope->global_functions["println"]->schema.return_type = "void"; + global_scope->global_functions["println"]->schema.arguments = {{"string", "str"}}; + global_scope->global_functions["printInt"] = std::make_shared(); + global_scope->global_functions["printInt"]->schema.return_type = "void"; + global_scope->global_functions["printInt"]->schema.arguments = {{"int", "n"}}; + global_scope->global_functions["printlnInt"] = std::make_shared(); + global_scope->global_functions["printlnInt"]->schema.return_type = "void"; + global_scope->global_functions["printlnInt"]->schema.arguments = {{"int", "n"}}; + global_scope->global_functions["getString"] = std::make_shared(); + global_scope->global_functions["getString"]->schema.return_type = "string"; + global_scope->global_functions["getInt"] = std::make_shared(); + global_scope->global_functions["getInt"]->schema.return_type = "int"; + global_scope->global_functions["toString"] = std::make_shared(); + global_scope->global_functions["toString"]->schema.return_type = "string"; + global_scope->global_functions["toString"]->schema.arguments = {{"int", "n"}}; visitor.visit(src.get()); return src; } diff --git a/src/semantic/visitor.cpp b/src/semantic/visitor.cpp index 2cc3766..597aaba 100644 --- a/src/semantic/visitor.cpp +++ b/src/semantic/visitor.cpp @@ -257,7 +257,7 @@ std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *con auto cur_scope = std::make_shared(); cur_scope->parent = nodetype_stk.back().second.get(); construct_func->current_scope = cur_scope; - cur_scope->schema.return_type = "null"; + cur_scope->schema.return_type = "void"; nodetype_stk.push_back({ASTNodeType::Constructor, construct_func->current_scope}); construct_func->func_body = std::dynamic_pointer_cast(