diff --git a/Makefile b/Makefile index f50091d..52c646b 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ build: # 运行目标,运行生成的可执行文件 run: - @cd $(BUILD_DIR) && ./zmxcc /dev/stdin -o /dev/null + @cd $(BUILD_DIR) && ./zmxcc /dev/stdin -o /dev/null 2>/dev/null # 清理目标 clean: diff --git a/include/ast/astnode.h b/include/ast/astnode.h index fbd9e97..2411067 100644 --- a/include/ast/astnode.h +++ b/include/ast/astnode.h @@ -4,15 +4,11 @@ #include #include #include "../semantic/visitor.h" +#include "scope.hpp" #include "tools.h" -using IdentifierType = std::string; -struct ArrayType { - bool has_base_type; - IdentifierType basetype; - size_t level; -}; -using ExprTypeInfo = std::variant; class ASTNodeVisitorBase { + friend Visitor; + public: virtual ~ASTNodeVisitorBase() = default; virtual void visit(class ASTNodeBase *context) = 0; @@ -20,6 +16,9 @@ class ASTNodeVisitorBase { class ASTNodeBase { friend Visitor; + + protected: + std::shared_ptr current_scope; ASTNodeType type; // std::vector> children; size_t start_line, start_char_pos, end_line, end_char_pos; diff --git a/include/ast/scope.hpp b/include/ast/scope.hpp new file mode 100644 index 0000000..aa3a57b --- /dev/null +++ b/include/ast/scope.hpp @@ -0,0 +1,165 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "tools.h" +class ScopeBase { + friend class Visitor; + friend class LocalScope; + friend class MemberFunctionScope; + friend class FunctionScope; + friend class ClassDefScope; + friend class GlobalScope; + + protected: + ScopeBase *parent; // cannot use std::shared_ptr because of circular dependency + virtual bool VariableNameAvailable(const std::string &name, int ttl) = 0; + virtual bool add_variable(const std::string &name, const ExprTypeInfo &type) = 0; + static inline bool IsKeyWord(const std::string &name) { + static const std::unordered_set keywords = {"void", "bool", "int", "string", "new", "class", + "null", "true", "false", "this", "if", "else", + "for", "while", "break", "continue", "return"}; + return keywords.find(name) != keywords.end(); + } +}; +class LocalScope : public ScopeBase { + friend class Visitor; + std::unordered_map local_variables; + bool add_variable(const std::string &name, const ExprTypeInfo &type) override { + if (!VariableNameAvailable(name, 0)) { + throw std::runtime_error("Variable name " + name + " is not available"); + } + local_variables[name] = type; + return true; + } + bool VariableNameAvailable(const std::string &name, int ttl) override { + if (ttl == 0 && IsKeyWord(name)) { + return false; + } + if (ttl == 0) { + if (local_variables.find(name) != local_variables.end()) { + return false; + } + } + return parent->VariableNameAvailable(name, ttl + 1); + } +}; +struct FunctionSchema { + friend class Visitor; + ExprTypeInfo return_type; + std::vector> arguments; +}; +class FunctionScope : public ScopeBase { + friend class Visitor; + FunctionSchema schema; + bool add_variable([[maybe_unused]] const std::string &name, [[maybe_unused]] const ExprTypeInfo &type) override { + throw std::runtime_error("FunctionScope does not support add_variable"); + } + bool VariableNameAvailable(const std::string &name, int ttl) override { + if (ttl == 0 && IsKeyWord(name)) { + return false; + } + if (ttl == 1) { + for (const auto &arg : schema.arguments) { + if (arg.second == name) { + return false; + } + } + } + return parent->VariableNameAvailable(name, ttl + 1); + } +}; +class ClassDefScope : public ScopeBase { + friend class Visitor; + std::unordered_map member_variables; + std::unordered_map> member_functions; + bool add_variable(const std::string &name, const ExprTypeInfo &type) override { + if (!VariableNameAvailable(name, 0)) { + throw std::runtime_error("Variable name " + name + " is not available"); + } + member_variables[name] = type; + return true; + } + bool add_function(const std::string &name, std::shared_ptr ptr) { + if (IsKeyWord(name)) return false; + if (member_variables.find(name) != member_variables.end()) { + return false; + } + if (member_functions.find(name) != member_functions.end()) { + return false; + } + member_functions[name] = ptr; + return true; + } + bool VariableNameAvailable(const std::string &name, int ttl) override { + if (ttl == 0 && IsKeyWord(name)) { + return false; + } + if (member_functions.find(name) != member_functions.end()) { + return false; + } + return parent->VariableNameAvailable(name, ttl + 1); + } +}; +class GlobalScope : public ScopeBase { + friend class Visitor; + std::unordered_map global_variables; + std::unordered_map> global_functions; + std::unordered_map> classes; + bool add_class(const std::string &name, std::shared_ptr ptr) { + if (IsKeyWord(name)) return false; + if (classes.find(name) != classes.end()) { + return false; + } + if (global_functions.find(name) != global_functions.end()) { + return false; + } + if (global_variables.find(name) != global_variables.end()) { + return false; + } + classes[name] = ptr; + return true; + } + bool add_function(const std::string &name, std::shared_ptr ptr) { + if (IsKeyWord(name)) return false; + if (classes.find(name) != classes.end()) { + return false; + } + if (global_functions.find(name) != global_functions.end()) { + return false; + } + if (global_variables.find(name) != global_variables.end()) { + return false; + } + global_functions[name] = ptr; + return true; + } + bool add_variable(const std::string &name, const ExprTypeInfo &type) override { + if (!VariableNameAvailable(name, 0)) { + throw std::runtime_error("Variable name " + name + " is not available"); + } + global_variables[name] = type; + return true; + } + bool VariableNameAvailable(const std::string &name, [[maybe_unused]] int ttl) override { + if (ttl == 0 && IsKeyWord(name)) { + return false; + } + if (global_variables.find(name) != global_variables.end()) { + return false; + } + if (classes.find(name) != classes.end()) { + return false; + } + if (classes.find(name) != classes.end()) { + return false; + } + return true; + } + + public: + GlobalScope() { parent = nullptr; } +}; \ No newline at end of file diff --git a/include/semantic/visitor.h b/include/semantic/visitor.h index 0ccfe3f..e4b2c80 100644 --- a/include/semantic/visitor.h +++ b/include/semantic/visitor.h @@ -4,9 +4,10 @@ #include #include #include "MXParserVisitor.h" +#include "ast/scope.hpp" #include "tools.h" class Visitor : public MXParserVisitor { - std::vector nodetype_stk; + std::vector>> nodetype_stk; public: std::any visitMxprog(MXParser::MxprogContext *context) override; diff --git a/include/tools.h b/include/tools.h index 60ac182..e01f4a1 100644 --- a/include/tools.h +++ b/include/tools.h @@ -1,5 +1,6 @@ #pragma once #include +#include enum class ASTNodeType { // Expression nodes NewArrayExpr, @@ -58,4 +59,12 @@ class SemanticError : public std::exception { SemanticError(const std::string &msg, int error_code) : msg(msg), error_code(error_code) {} const char *what() const noexcept override { return msg.c_str(); } int GetErrorCode() const { return error_code; } -}; \ No newline at end of file +}; + +using IdentifierType = std::string; +struct ArrayType { + bool has_base_type; + IdentifierType basetype; + size_t level; +}; +using ExprTypeInfo = std::variant; \ No newline at end of file diff --git a/src/semantic/semantic.cpp b/src/semantic/semantic.cpp index 564aa7c..6ddcdaf 100644 --- a/src/semantic/semantic.cpp +++ b/src/semantic/semantic.cpp @@ -27,14 +27,16 @@ std::shared_ptr CheckAndDecorate(std::shared_ptr &ast_out) { antlr4::ANTLRInputStream input(fin); MXLexer lexer(&input); + MXErrorListener error_listener; + lexer.removeErrorListeners(); + lexer.addErrorListener(&error_listener); antlr4::CommonTokenStream tokens(&lexer); tokens.fill(); MXParser parser(&tokens); parser.removeErrorListeners(); - MXErrorListener error_listener; parser.addErrorListener(&error_listener); antlr4::tree::ParseTree *tree = parser.mxprog(); - if (!error_listener.IsOk()) throw SemanticError("Fatal error: syntax error", 1); + if (!error_listener.IsOk()) throw SemanticError("Invalid Identifier", 1); Visitor visitor; std::shared_ptr ast = BuildAST(&visitor, tree); ast_out = CheckAndDecorate(ast); diff --git a/src/semantic/visitor.cpp b/src/semantic/visitor.cpp index 4dc7a4f..90835d2 100644 --- a/src/semantic/visitor.cpp +++ b/src/semantic/visitor.cpp @@ -4,6 +4,7 @@ #include "MXParser.h" #include "MXParserVisitor.h" #include "ast/ast.h" +#include "ast/scope.hpp" #include "tools.h" std::any Visitor::visitMxprog(MXParser::MxprogContext *context) { @@ -13,7 +14,9 @@ std::any Visitor::visitMxprog(MXParser::MxprogContext *context) { program->start_char_pos = context->getStart()->getCharPositionInLine(); program->end_line = context->getStop()->getLine(); program->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::Program); + auto cur_scope = std::make_shared(); + program->current_scope = cur_scope; + nodetype_stk.push_back({ASTNodeType::Program, cur_scope}); for (auto def : context->children) { if (auto classDefContext = dynamic_cast(def)) { @@ -46,7 +49,10 @@ std::any Visitor::visitFunction_def(MXParser::Function_defContext *context) { func->start_char_pos = context->getStart()->getCharPositionInLine(); func->end_line = context->getStop()->getLine(); func->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::FuncDef); + auto cur_scope = std::make_shared(); + cur_scope->parent = nodetype_stk.back().second.get(); + func->current_scope = cur_scope; + nodetype_stk.push_back({ASTNodeType::FuncDef, func->current_scope}); std::string return_type_str; if (auto type_context = dynamic_cast(context->children[0])) { @@ -73,6 +79,7 @@ std::any Visitor::visitFunction_def(MXParser::Function_defContext *context) { std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded return type is [array]" << return_type_str << " with dimensions=" << return_dimensions << std::endl; } + cur_scope->schema.return_type = func->return_type; func->func_name = context->children[1 + 2 * return_dimensions]->getText(); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "func_name=" << func->func_name << std::endl; size_t cur = 3 + 2 * return_dimensions; @@ -115,8 +122,19 @@ std::any Visitor::visitFunction_def(MXParser::Function_defContext *context) { << std::endl; } func->params.push_back(std::make_pair(cur_para_name, cur_para_type)); + cur_scope->schema.arguments.push_back(std::make_pair(cur_para_type, cur_para_name)); cur++; } + if (ClassDefScope *cparent = dynamic_cast(cur_scope->parent)) { + if (!cparent->add_function(func->func_name, cur_scope)) { + throw SemanticError("Function name " + func->func_name + " is not available", 1); + } + } else if (GlobalScope *gparent = dynamic_cast(cur_scope->parent)) { + if (!gparent->add_function(func->func_name, cur_scope)) { + throw SemanticError("Function name " + func->func_name + " is not available", 1); + } + } else + throw std::runtime_error("unknown parent scope type"); func->func_body = std::dynamic_pointer_cast( std::any_cast>(visit(context->suite()))); nodetype_stk.pop_back(); @@ -129,7 +147,15 @@ std::any Visitor::visitClass_def(MXParser::Class_defContext *context) { class_def->start_char_pos = context->getStart()->getCharPositionInLine(); class_def->end_line = context->getStop()->getLine(); class_def->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ClassDef); + auto cur_scope = std::make_shared(); + cur_scope->parent = nodetype_stk.back().second.get(); + class_def->current_scope = cur_scope; + GlobalScope *gparent = dynamic_cast(cur_scope->parent); + assert(gparent != nullptr); + if (!gparent->add_class(context->ID()->getText(), cur_scope)) { + throw SemanticError("Class name " + context->ID()->getText() + " is not available", 1); + } + nodetype_stk.push_back({ASTNodeType::ClassDef, class_def->current_scope}); class_def->class_name = context->ID()->getText(); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "building a class named " << class_def->class_name @@ -160,7 +186,8 @@ std::any Visitor::visitClass_var_def(MXParser::Class_var_defContext *context) { member_var_def->start_char_pos = context->getStart()->getCharPositionInLine(); member_var_def->end_line = context->getStop()->getLine(); member_var_def->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ClassVariable); + member_var_def->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::ClassVariable, member_var_def->current_scope}); std::string define_type_base; if (auto type_context = dynamic_cast(context->children[0])) { @@ -190,6 +217,9 @@ std::any Visitor::visitClass_var_def(MXParser::Class_var_defContext *context) { member_var_def->vars.push_back(std::make_pair(id->getText(), nullptr)); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded member variable name is " << id->getText() << std::endl; + if (!member_var_def->current_scope->add_variable(id->getText(), member_var_def->var_type)) { + throw SemanticError("Variable name " + id->getText() + " is not available", 1); + } } nodetype_stk.pop_back(); @@ -203,7 +233,11 @@ std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *con construct_func->start_char_pos = context->getStart()->getCharPositionInLine(); construct_func->end_line = context->getStop()->getLine(); construct_func->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::Constructor); + auto cur_scope = std::make_shared(); + cur_scope->parent = nodetype_stk.back().second.get(); + construct_func->current_scope = cur_scope; + cur_scope->schema.return_type = "null"; + nodetype_stk.push_back({ASTNodeType::Constructor, construct_func->current_scope}); construct_func->func_body = std::dynamic_pointer_cast( std::any_cast>(visit(context->suite()))); @@ -219,7 +253,12 @@ std::any Visitor::visitSuite(MXParser::SuiteContext *context) { suite_node->start_char_pos = context->getStart()->getCharPositionInLine(); suite_node->end_line = context->getStop()->getLine(); suite_node->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::SuiteStatement); + auto cur_scope = std::make_shared(); + cur_scope->parent = nodetype_stk.back().second.get(); + suite_node->current_scope = cur_scope; + assert(nodetype_stk.size() > 0); + cur_scope->parent = nodetype_stk.back().second.get(); + nodetype_stk.push_back({ASTNodeType::SuiteStatement, cur_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding suite statements" << std::endl; std::vector stmts = context->statement(); @@ -250,7 +289,9 @@ std::any Visitor::visitExpr_statement(MXParser::Expr_statementContext *context) expr_stmt->start_char_pos = context->getStart()->getCharPositionInLine(); expr_stmt->end_line = context->getStop()->getLine(); expr_stmt->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ExprStatement); + assert(nodetype_stk.size() > 0); + expr_stmt->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::ExprStatement, expr_stmt->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding an expression statement" << std::endl; expr_stmt->expr = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); @@ -263,7 +304,9 @@ std::any Visitor::visitIf_statement(MXParser::If_statementContext *context) { if_stmt->start_char_pos = context->getStart()->getCharPositionInLine(); if_stmt->end_line = context->getStop()->getLine(); if_stmt->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::IfStatement); + assert(nodetype_stk.size() > 0); + if_stmt->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::IfStatement, if_stmt->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding an if statement" << std::endl; if_stmt->condition = std::any_cast>(visit(context->expr())); std::vector sub_stmts = context->statement(); @@ -285,7 +328,9 @@ std::any Visitor::visitWhile_statement(MXParser::While_statementContext *context while_stmt->start_char_pos = context->getStart()->getCharPositionInLine(); while_stmt->end_line = context->getStop()->getLine(); while_stmt->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::WhileStatement); + assert(nodetype_stk.size() > 0); + while_stmt->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::WhileStatement, while_stmt->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding a while statement" << std::endl; while_stmt->condition = std::any_cast>(visit(context->expr())); while_stmt->loop_body = std::any_cast>(visit(context->statement())); @@ -299,7 +344,10 @@ std::any Visitor::visitFor_statement(MXParser::For_statementContext *context) { for_stmt->start_char_pos = context->getStart()->getCharPositionInLine(); for_stmt->end_line = context->getStop()->getLine(); for_stmt->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ForStatement); + auto cur_scope = std::make_shared(); + cur_scope->parent = nodetype_stk.back().second.get(); + for_stmt->current_scope = cur_scope; + nodetype_stk.push_back({ASTNodeType::ForStatement, for_stmt->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding a for statement" << std::endl; size_t cur = 2; if (dynamic_cast(context->children[cur]) != nullptr && @@ -371,7 +419,9 @@ std::any Visitor::visitJmp_statement(MXParser::Jmp_statementContext *context) { jmp_stmt->start_char_pos = context->getStart()->getCharPositionInLine(); jmp_stmt->end_line = context->getStop()->getLine(); jmp_stmt->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::JmpStatement); + assert(nodetype_stk.size() > 0); + jmp_stmt->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::JmpStatement, jmp_stmt->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding a jmp statement" << std::endl; if (context->RETURN() != nullptr) { jmp_stmt->jmp_type = 0; @@ -401,7 +451,9 @@ std::any Visitor::visitDefine_statement(MXParser::Define_statementContext *conte def_stmt->start_char_pos = context->getStart()->getCharPositionInLine(); def_stmt->end_line = context->getStop()->getLine(); def_stmt->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::DefinitionStatement); + assert(nodetype_stk.size() > 0); + def_stmt->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::DefinitionStatement, def_stmt->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding a definition statement" << std::endl; std::string define_type_base; if (auto type_context = dynamic_cast(context->children[0])) { @@ -436,6 +488,8 @@ std::any Visitor::visitDefine_statement(MXParser::Define_statementContext *conte if (dynamic_cast(context->children[i]) != nullptr && dynamic_cast(context->children[i])->getSymbol()->getType() == MXParser::ID) { def_stmt->vars.push_back(std::make_pair(context->children[i]->getText(), nullptr)); + if (!def_stmt->current_scope->add_variable(context->children[i]->getText(), def_stmt->var_type)) + throw SemanticError("Variable " + context->children[i]->getText() + " already defined", 1); } else throw std::runtime_error("unknown subnode occurred in visitDefine_statement"); i++; @@ -456,7 +510,9 @@ std::any Visitor::visitGgll_expression(MXParser::Ggll_expressionContext *context ggll_expr->start_char_pos = context->getStart()->getCharPositionInLine(); ggll_expr->end_line = context->getStop()->getLine(); ggll_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::GGLLExpr); + assert(nodetype_stk.size() > 0); + ggll_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::GGLLExpr, ggll_expr->current_scope}); ggll_expr->op = context->children[1]->getText(); ggll_expr->left = std::any_cast>(visit(context->children[0])); ggll_expr->right = std::any_cast>(visit(context->children[2])); @@ -470,7 +526,8 @@ std::any Visitor::visitBxor_expression(MXParser::Bxor_expressionContext *context bxor_expr->start_char_pos = context->getStart()->getCharPositionInLine(); bxor_expr->end_line = context->getStop()->getLine(); bxor_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::BXorExpr); + bxor_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::BXorExpr, bxor_expr->current_scope}); bxor_expr->op = context->children[1]->getText(); bxor_expr->left = std::any_cast>(visit(context->children[0])); bxor_expr->right = std::any_cast>(visit(context->children[2])); @@ -484,7 +541,8 @@ std::any Visitor::visitSuffix_expression(MXParser::Suffix_expressionContext *con suffix_expr->start_char_pos = context->getStart()->getCharPositionInLine(); suffix_expr->end_line = context->getStop()->getLine(); suffix_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::SuffixExpr); + suffix_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::SuffixExpr, suffix_expr->current_scope}); suffix_expr->op = context->children[1]->getText(); suffix_expr->base = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); @@ -497,7 +555,8 @@ std::any Visitor::visitLand_expression(MXParser::Land_expressionContext *context land_expr->start_char_pos = context->getStart()->getCharPositionInLine(); land_expr->end_line = context->getStop()->getLine(); land_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::LAndExpr); + land_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::LAndExpr, land_expr->current_scope}); land_expr->op = context->children[1]->getText(); land_expr->left = std::any_cast>(visit(context->children[0])); land_expr->right = std::any_cast>(visit(context->children[2])); @@ -511,7 +570,8 @@ std::any Visitor::visitPm_expression(MXParser::Pm_expressionContext *context) { pm_expr->start_char_pos = context->getStart()->getCharPositionInLine(); pm_expr->end_line = context->getStop()->getLine(); pm_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::PMExpr); + pm_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::PMExpr, pm_expr->current_scope}); pm_expr->op = context->children[1]->getText(); pm_expr->left = std::any_cast>(visit(context->children[0])); pm_expr->right = std::any_cast>(visit(context->children[2])); @@ -525,7 +585,8 @@ std::any Visitor::visitIndex_expression(MXParser::Index_expressionContext *conte idx_expr->start_char_pos = context->getStart()->getCharPositionInLine(); idx_expr->end_line = context->getStop()->getLine(); idx_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::IndexExpr); + idx_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::IndexExpr, idx_expr->current_scope}); auto sub_exprs = context->expr(); idx_expr->base = std::any_cast>(visit(sub_exprs[0])); for (size_t i = 1; i < sub_exprs.size(); i++) { @@ -541,7 +602,8 @@ std::any Visitor::visitOpposite_expression(MXParser::Opposite_expressionContext oppsite_expr->start_char_pos = context->getStart()->getCharPositionInLine(); oppsite_expr->end_line = context->getStop()->getLine(); oppsite_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::OppositeExpr); + oppsite_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::OppositeExpr, oppsite_expr->current_scope}); oppsite_expr->base = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); return std::static_pointer_cast(oppsite_expr); @@ -553,7 +615,8 @@ std::any Visitor::visitNew_array_expression(MXParser::New_array_expressionContex new_array->start_char_pos = context->getStart()->getCharPositionInLine(); new_array->end_line = context->getStop()->getLine(); new_array->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::NewArrayExpr); + new_array->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::NewArrayExpr, new_array->current_scope}); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "Adding a new array expression" << std::endl; size_t total_dimensions = context->LBRACKET().size(); @@ -596,7 +659,8 @@ std::any Visitor::visitAccess_expression(MXParser::Access_expressionContext *con access_expr->type = ASTNodeType::MemberFunctionAccessExpr; access_expr->is_function = true; } - nodetype_stk.push_back(access_expr->type); + access_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({access_expr->type, access_expr->current_scope}); auto sub_exprs = context->expr(); access_expr->base = std::any_cast>(visit(sub_exprs[0])); access_expr->member = context->ID()->getText(); @@ -613,7 +677,8 @@ std::any Visitor::visitBand_expression(MXParser::Band_expressionContext *context band_expr->start_char_pos = context->getStart()->getCharPositionInLine(); band_expr->end_line = context->getStop()->getLine(); band_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::BAndExpr); + band_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::BAndExpr, band_expr->current_scope}); band_expr->op = context->children[1]->getText(); band_expr->left = std::any_cast>(visit(context->children[0])); band_expr->right = std::any_cast>(visit(context->children[2])); @@ -627,7 +692,8 @@ std::any Visitor::visitNew_construct_expression(MXParser::New_construct_expressi new_construct_expr->start_char_pos = context->getStart()->getCharPositionInLine(); new_construct_expr->end_line = context->getStop()->getLine(); new_construct_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::NewConstructExpr); + new_construct_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::NewConstructExpr, new_construct_expr->current_scope}); new_construct_expr->expr_type_info = context->type()->getText(); nodetype_stk.pop_back(); return std::static_pointer_cast(new_construct_expr); @@ -639,7 +705,8 @@ std::any Visitor::visitTernary_expression(MXParser::Ternary_expressionContext *c ternary_expr->start_char_pos = context->getStart()->getCharPositionInLine(); ternary_expr->end_line = context->getStop()->getLine(); ternary_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::TernaryExpr); + ternary_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::TernaryExpr, ternary_expr->current_scope}); auto expr_subnodes = context->expr(); ternary_expr->condition = std::any_cast>(visit(expr_subnodes[0])); ternary_expr->src1 = std::any_cast>(visit(expr_subnodes[1])); @@ -654,7 +721,8 @@ std::any Visitor::visitBnot_expression(MXParser::Bnot_expressionContext *context bnot_expr->start_char_pos = context->getStart()->getCharPositionInLine(); bnot_expr->end_line = context->getStop()->getLine(); bnot_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::BNotExpr); + bnot_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::BNotExpr, bnot_expr->current_scope}); bnot_expr->base = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); return std::static_pointer_cast(bnot_expr); @@ -666,7 +734,8 @@ std::any Visitor::visitLnot_expression(MXParser::Lnot_expressionContext *context lnot_expr->start_char_pos = context->getStart()->getCharPositionInLine(); lnot_expr->end_line = context->getStop()->getLine(); lnot_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::LNotExpr); + lnot_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::LNotExpr, lnot_expr->current_scope}); lnot_expr->base = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); return std::static_pointer_cast(lnot_expr); @@ -678,7 +747,8 @@ std::any Visitor::visitPrefix_expression(MXParser::Prefix_expressionContext *con prefix_expr->start_char_pos = context->getStart()->getCharPositionInLine(); prefix_expr->end_line = context->getStop()->getLine(); prefix_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::PrefixExpr); + prefix_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::PrefixExpr, prefix_expr->current_scope}); prefix_expr->op = context->children[0]->getText(); prefix_expr->base = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); @@ -691,7 +761,8 @@ std::any Visitor::visitRl_expression(MXParser::Rl_expressionContext *context) { rl_expr->start_char_pos = context->getStart()->getCharPositionInLine(); rl_expr->end_line = context->getStop()->getLine(); rl_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::RLExpr); + rl_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::RLExpr, rl_expr->current_scope}); rl_expr->op = context->children[1]->getText(); rl_expr->left = std::any_cast>(visit(context->children[0])); rl_expr->right = std::any_cast>(visit(context->children[2])); @@ -705,7 +776,8 @@ std::any Visitor::visitAssign_expression(MXParser::Assign_expressionContext *con assign_expr->start_char_pos = context->getStart()->getCharPositionInLine(); assign_expr->end_line = context->getStop()->getLine(); assign_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::AssignExpr); + assign_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::AssignExpr, assign_expr->current_scope}); assign_expr->dest = std::any_cast>(visit(context->expr(0))); assign_expr->src = std::any_cast>(visit(context->expr(1))); nodetype_stk.pop_back(); @@ -718,7 +790,8 @@ std::any Visitor::visitMdm_expression(MXParser::Mdm_expressionContext *context) mdm_expr->start_char_pos = context->getStart()->getCharPositionInLine(); mdm_expr->end_line = context->getStop()->getLine(); mdm_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::MDMExpr); + mdm_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::MDMExpr, mdm_expr->current_scope}); mdm_expr->op = context->children[1]->getText(); mdm_expr->left = std::any_cast>(visit(context->children[0])); mdm_expr->right = std::any_cast>(visit(context->children[2])); @@ -732,7 +805,8 @@ std::any Visitor::visitNew_expression(MXParser::New_expressionContext *context) new_expr->start_char_pos = context->getStart()->getCharPositionInLine(); new_expr->end_line = context->getStop()->getLine(); new_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::NewExpr); + new_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::NewExpr, new_expr->current_scope}); new_expr->expr_type_info = context->type()->getText(); nodetype_stk.pop_back(); return std::static_pointer_cast(new_expr); @@ -744,7 +818,8 @@ std::any Visitor::visitNe_expression(MXParser::Ne_expressionContext *context) { ne_expr->start_char_pos = context->getStart()->getCharPositionInLine(); ne_expr->end_line = context->getStop()->getLine(); ne_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::NEExpr); + ne_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::NEExpr, ne_expr->current_scope}); ne_expr->op = context->children[1]->getText(); ne_expr->left = std::any_cast>(visit(context->children[0])); ne_expr->right = std::any_cast>(visit(context->children[2])); @@ -758,7 +833,8 @@ std::any Visitor::visitBor_expression(MXParser::Bor_expressionContext *context) bor_expr->start_char_pos = context->getStart()->getCharPositionInLine(); bor_expr->end_line = context->getStop()->getLine(); bor_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::BOrExpr); + bor_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::BOrExpr, bor_expr->current_scope}); bor_expr->op = context->children[1]->getText(); bor_expr->left = std::any_cast>(visit(context->children[0])); bor_expr->right = std::any_cast>(visit(context->children[2])); @@ -772,7 +848,8 @@ std::any Visitor::visitLor_expression(MXParser::Lor_expressionContext *context) lor_expr->start_char_pos = context->getStart()->getCharPositionInLine(); lor_expr->end_line = context->getStop()->getLine(); lor_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::LOrExpr); + lor_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::LOrExpr, lor_expr->current_scope}); lor_expr->op = context->children[1]->getText(); lor_expr->left = std::any_cast>(visit(context->children[0])); lor_expr->right = std::any_cast>(visit(context->children[2])); @@ -786,7 +863,8 @@ std::any Visitor::visitThis_expr(MXParser::This_exprContext *context) { this_expr->start_char_pos = context->getStart()->getCharPositionInLine(); this_expr->end_line = context->getStop()->getLine(); this_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ThisExpr); + this_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::ThisExpr, this_expr->current_scope}); nodetype_stk.pop_back(); return std::static_pointer_cast(this_expr); } @@ -797,7 +875,8 @@ std::any Visitor::visitParen_expr(MXParser::Paren_exprContext *context) { paren_expr->start_char_pos = context->getStart()->getCharPositionInLine(); paren_expr->end_line = context->getStop()->getLine(); paren_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ParenExpr); + paren_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::ParenExpr, paren_expr->current_scope}); paren_expr->expr = std::any_cast>(visit(context->expr())); nodetype_stk.pop_back(); return std::static_pointer_cast(paren_expr); @@ -809,7 +888,8 @@ std::any Visitor::visitId_expr(MXParser::Id_exprContext *context) { id_expr->start_char_pos = context->getStart()->getCharPositionInLine(); id_expr->end_line = context->getStop()->getLine(); id_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::IDExpr); + id_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::IDExpr, id_expr->current_scope}); id_expr->id = IdentifierType{context->ID()->getText()}; nodetype_stk.pop_back(); return std::static_pointer_cast(id_expr); @@ -821,7 +901,8 @@ std::any Visitor::visitFunction_call_expr(MXParser::Function_call_exprContext *c func_call_expr->start_char_pos = context->getStart()->getCharPositionInLine(); func_call_expr->end_line = context->getStop()->getLine(); func_call_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::FunctionCallExpr); + func_call_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::FunctionCallExpr, func_call_expr->current_scope}); func_call_expr->func_name = IdentifierType{context->ID()->getText()}; auto expr_subnodes = context->expr(); for (auto expr_subnode : expr_subnodes) { @@ -841,7 +922,8 @@ std::any Visitor::visitFormatted_string(MXParser::Formatted_stringContext *conte fmt_expr->start_char_pos = context->getStart()->getCharPositionInLine(); fmt_expr->end_line = context->getStop()->getLine(); fmt_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::FormattedStringExpr); + fmt_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::FormattedStringExpr, fmt_expr->current_scope}); if (context->FORMAT_STRING_WHOLE() != nullptr) { fmt_expr->literals.push_back(context->FORMAT_STRING_WHOLE()->getText()); @@ -868,7 +950,8 @@ std::any Visitor::visitConstant(MXParser::ConstantContext *context) { constant_expr->start_char_pos = context->getStart()->getCharPositionInLine(); constant_expr->end_line = context->getStop()->getLine(); constant_expr->end_char_pos = context->getStop()->getCharPositionInLine(); - nodetype_stk.push_back(ASTNodeType::ConstantExpr); + constant_expr->current_scope = nodetype_stk.back().second; + nodetype_stk.push_back({ASTNodeType::ConstantExpr, constant_expr->current_scope}); if (context->TRUE() != nullptr || context->FALSE() != nullptr) { constant_expr->expr_type_info = "bool";