diff --git a/include/ast/astnode.h b/include/ast/astnode.h index 57ac448..fbd9e97 100644 --- a/include/ast/astnode.h +++ b/include/ast/astnode.h @@ -3,6 +3,8 @@ #include #include #include +#include "../semantic/visitor.h" +#include "tools.h" /* ASTNodeType is declared in tools.h */ using IdentifierType = std::string; struct ArrayType { bool has_base_type; @@ -15,17 +17,15 @@ class ASTNodeVisitorBase { virtual ~ASTNodeVisitorBase() = default; virtual void visit(class ASTNodeBase *context) = 0; }; -enum class ASTNodeType { - -}; class ASTNodeBase { + friend Visitor; /* Visitor writes type and the source-position fields directly while building nodes */ ASTNodeType type; // std::vector> children; size_t start_line, start_char_pos, end_line, end_char_pos; public: virtual ~ASTNodeBase() = default; - virtual void accept(class ASTNodeVisitorBase *visitor) = 0; + // virtual void accept(class ASTNodeVisitorBase *visitor) = 0; }; #endif \ No newline at end of file diff --git a/include/ast/expr_astnode.h b/include/ast/expr_astnode.h index 0b5da73..71922b6 100644 --- a/include/ast/expr_astnode.h +++ b/include/ast/expr_astnode.h @@ -14,112 +14,147 @@ class Expr_ASTNode : public ASTNodeBase { class BasicExpr_ASTNode : public Expr_ASTNode {}; // This is a virtual class class NewArrayExpr_ASTNode : public Expr_ASTNode { + friend Visitor; /* friendship is not inherited from ASTNodeBase, so every node class repeats this declaration */ bool has_initial_value; std::shared_ptr initial_value; }; -class NewConstructExpr_ASTNode : public Expr_ASTNode {}; -class NewExpr_ASTNode : public Expr_ASTNode {}; +class NewConstructExpr_ASTNode : public Expr_ASTNode { + friend Visitor; +}; +class NewExpr_ASTNode : public Expr_ASTNode { + friend Visitor; +}; class AccessExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr base; IdentifierType member; }; -class MemberVariableAccessExpr_ASTNode : public AccessExpr_ASTNode {}; +class MemberVariableAccessExpr_ASTNode : public AccessExpr_ASTNode { + friend Visitor; +}; class MemberFunctionAccessExpr_ASTNode : public AccessExpr_ASTNode { + friend Visitor; std::vector> arguments; }; class IndexExpr_ASTNode : public Expr_ASTNode { + friend Visitor; 
std::shared_ptr base; std::vector> indices; }; class SuffixExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr base; }; class PrefixExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr base; }; class OppositeExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr base; }; class LNotExpr_ASTNode : public Expr_ASTNode { + friend Visitor; /* NOTE(review): LNotExpr and BNotExpr hold left and right operands like the binary nodes that follow - confirm this is intended */ std::shared_ptr left; std::shared_ptr right; }; class BNotExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class MDMExpr_ASTNode : public Expr_ASTNode { + friend Visitor; /* grants Visitor access to the private left and right operands */ std::shared_ptr left; std::shared_ptr right; }; class PMExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class RLExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class GGLLExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class NEExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class BAndExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class BXorExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class BOrExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class LAndExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class LOrExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr left; std::shared_ptr right; }; class TernaryExpr_ASTNode : public Expr_ASTNode { + friend Visitor; /* condition selects between src1 and src2 (names only; evaluation is not defined in this header) */ std::shared_ptr condition; std::shared_ptr src1; std::shared_ptr src2; }; class AssignExpr_ASTNode : public Expr_ASTNode { + friend Visitor; std::shared_ptr dest; std::shared_ptr src; }; -class ThisExpr_ASTNode : public BasicExpr_ASTNode {}; +class 
ThisExpr_ASTNode : public BasicExpr_ASTNode { + friend Visitor; +}; class ParenExpr_ASTNode : public BasicExpr_ASTNode { + friend Visitor; std::shared_ptr expr; }; class IDExpr_ASTNode : public BasicExpr_ASTNode { + friend Visitor; IdentifierType id; }; class FunctionCallExpr_ASTNode : public BasicExpr_ASTNode { + friend Visitor; IdentifierType func_name; std::vector> arguments; }; class FormattedStringExpr_ASTNode : public BasicExpr_ASTNode { + friend Visitor; using SegmentType = std::variant>; std::vector segments; }; struct NullType {}; using AtomicConstantType = std::variant; struct ArrayConstantType { + friend Visitor; /* NOTE(review): members of this struct are public by default, so this friend declaration looks redundant */ std::vector, NullType, AtomicConstantType>> elements; size_t level; }; class ConstantExpr_ASTNode : public BasicExpr_ASTNode { + friend Visitor; std::variant value; }; diff --git a/include/ast/statement_astnode.h b/include/ast/statement_astnode.h index 858dc7b..0b7e48c 100644 --- a/include/ast/statement_astnode.h +++ b/include/ast/statement_astnode.h @@ -8,34 +8,43 @@ class Statement_ASTNode : public ASTNodeBase { virtual ~Statement_ASTNode() = default; }; -class EmptyStatement_ASTNode : public Statement_ASTNode {}; +class EmptyStatement_ASTNode : public Statement_ASTNode { + friend Visitor; +}; class DefinitionStatement_ASTNode : public Statement_ASTNode { - ExprTypeInfo type; + friend Visitor; + ExprTypeInfo var_type; /* renamed from type; ASTNodeBase already has a member called type */ std::vector>> vars; }; class ExprStatement_ASTNode : public Statement_ASTNode { + friend Visitor; std::shared_ptr expr; }; class IfStatement_ASTNode : public Statement_ASTNode { + friend Visitor; bool has_else_clause; std::shared_ptr condition; std::shared_ptr if_clause; std::shared_ptr else_clause; }; class WhileStatement_ASTNode : public Statement_ASTNode { + friend Visitor; std::shared_ptr condition; std::shared_ptr loop_body; }; class ForStatement_ASTNode : public Statement_ASTNode { + friend Visitor; std::shared_ptr initial; std::shared_ptr condition; std::shared_ptr update; std::shared_ptr loop_body; }; class 
JmpStatement_ASTNode : public Statement_ASTNode { + friend Visitor; std::shared_ptr return_value; }; class SuiteStatement_ASTNode : public Statement_ASTNode { + friend Visitor; std::vector> statements; }; diff --git a/include/ast/structural_astnode.h b/include/ast/structural_astnode.h index 99d7fb0..5c2d3bf 100644 --- a/include/ast/structural_astnode.h +++ b/include/ast/structural_astnode.h @@ -7,21 +7,35 @@ #include "expr_astnode.h" #include "statement_astnode.h" class FuncDef_ASTNode : public ASTNodeBase { + friend Visitor; + bool is_constructor; /* tells constructors apart from ordinary functions now that Constructor_ASTNode is removed */ IdentifierType func_name; + ExprTypeInfo return_type; /* written by Visitor::visitFunction_def */ std::vector> params; std::shared_ptr func_body; + + public: + FuncDef_ASTNode() = default; }; -class Constructor_ASTNode : public FuncDef_ASTNode {}; -class ClassVariable_ASTNode : public DefinitionStatement_ASTNode {}; class ClassDef_ASTNode : public ASTNodeBase { + friend Visitor; + private: - using ClassElement = std::variant, std::shared_ptr, - std::shared_ptr>; - std::vector elements; + std::string class_name; + std::vector> member_variables; + std::vector> member_functions; + std::shared_ptr constructor; /* stays null when the class declares no constructor */ + + public: + ClassDef_ASTNode() = default; }; class Program_ASTNode : public ASTNodeBase { - using ProgramElement = std::variant, std::shared_ptr, - std::shared_ptr>; - std::vector elements; + friend Visitor; + std::vector> global_variables; /* the three lists replace the former single variant-typed elements vector */ + std::vector> classes; + std::vector> functions; + + public: + Program_ASTNode() = default; }; #endif \ No newline at end of file diff --git a/include/semantic/semantic.h b/include/semantic/semantic.h index 6c45d2d..0bd3290 100644 --- a/include/semantic/semantic.h +++ b/include/semantic/semantic.h @@ -4,15 +4,6 @@ #include "ast/ast.h" #include "visitor.h" -class SemanticError : public std::exception { - std::string msg; - int error_code; - - public: - SemanticError(const std::string &msg, int error_code) : msg(msg), error_code(error_code) {} - const char *what() const noexcept override { return msg.c_str(); } - int 
GetErrorCode() const { return error_code; } -}; std::shared_ptr BuildAST(Visitor *visitor, antlr4::tree::ParseTree *tree); void SemanticCheck(std::istream &fin, std::shared_ptr &ast); #endif \ No newline at end of file diff --git a/include/semantic/visitor.h b/include/semantic/visitor.h index e22511e..0ccfe3f 100644 --- a/include/semantic/visitor.h +++ b/include/semantic/visitor.h @@ -1,9 +1,13 @@ #ifndef VISITOR_H #define VISITOR_H +#include +#include #include "MXParserVisitor.h" - +#include "tools.h" class Visitor : public MXParserVisitor { + std::vector nodetype_stk; /* kinds of the nodes currently being built; its depth sets the indentation of the debug output */ + public: std::any visitMxprog(MXParser::MxprogContext *context) override; std::any visitFunction_def(MXParser::Function_defContext *context) override; diff --git a/include/tools.h b/include/tools.h new file mode 100644 index 0000000..60ac182 --- /dev/null +++ b/include/tools.h @@ -0,0 +1,61 @@ +#pragma once +#include +enum class ASTNodeType { /* tag stored in ASTNodeBase::type */ + // Expression nodes + NewArrayExpr, + NewConstructExpr, + NewExpr, + MemberVariableAccessExpr, + MemberFunctionAccessExpr, + IndexExpr, + SuffixExpr, + PrefixExpr, + OppositeExpr, + LNotExpr, + BNotExpr, + MDMExpr, + PMExpr, + RLExpr, + GGLLExpr, + NEExpr, + BAndExpr, + BXorExpr, + BOrExpr, + LAndExpr, + LOrExpr, + TernaryExpr, + AssignExpr, + ThisExpr, + ParenExpr, + IDExpr, + FunctionCallExpr, + FormattedStringExpr, + ConstantExpr, + + // Statement nodes + EmptyStatement, + DefinitionStatement, + ExprStatement, + IfStatement, + WhileStatement, + ForStatement, + JmpStatement, + SuiteStatement, + + // Structural nodes + FuncDef, + Constructor, + ClassVariable, + ClassDef, + Program +}; + +class SemanticError : public std::exception { /* exception carrying a message and a numeric error code; moved here from semantic.h */ + std::string msg; + int error_code; + + public: + SemanticError(const std::string &msg, int error_code) : msg(msg), error_code(error_code) {} + const char *what() const noexcept override { return msg.c_str(); } + int GetErrorCode() const { return error_code; } +}; \ No newline at end of file diff --git a/src/ast/CMakeLists.txt 
b/src/ast/CMakeLists.txt index 2513fe2..d901faa 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -1,3 +1,6 @@ include_directories(${CMAKE_SOURCE_DIR}/include/ast) file(GLOB AST_SOURCES "*.cpp") -add_library(ast STATIC ${AST_SOURCES}) \ No newline at end of file +add_library(ast STATIC ${AST_SOURCES}) +target_include_directories(ast PUBLIC /usr/include/antlr4-runtime/) +target_include_directories(ast PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../semantic/antlr-generated) +target_link_libraries(ast PUBLIC antlr4-runtime) \ No newline at end of file diff --git a/src/semantic/semantic.cpp b/src/semantic/semantic.cpp index 9cf2ad0..884f373 100644 --- a/src/semantic/semantic.cpp +++ b/src/semantic/semantic.cpp @@ -13,13 +13,15 @@ class MXErrorListener : public antlr4::BaseErrorListener { MXErrorListener() : no_problem(true) {} void syntaxError(antlr4::Recognizer *recognizer, antlr4::Token *offendingSymbol, size_t line, size_t charPositionInLine, const std::string &msg, std::exception_ptr e) override { - std::cout << "line " << line << ":" << charPositionInLine << " AT " - << offendingSymbol->getText() << ": " << msg << std::endl; + std::cout << "line " << line << ":" << charPositionInLine << " AT " << offendingSymbol->getText() << ": " << msg + << std::endl; no_problem = false; } bool IsOk() { return no_problem; } }; -std::shared_ptr BuildAST(Visitor *visitor, antlr4::tree::ParseTree *tree) { ; } +std::shared_ptr BuildAST(Visitor *visitor, antlr4::tree::ParseTree *tree) { + return std::any_cast>(visitor->visit(tree)); /* std::any_cast throws std::bad_any_cast when the visited result holds a different type */ +} std::shared_ptr CheckAndDecorate(std::shared_ptr src) { ; } void SemanticCheck(std::istream &fin, std::shared_ptr &ast_out) { diff --git a/src/semantic/visitor.cpp b/src/semantic/visitor.cpp index 4369d97..40c24f4 100644 --- a/src/semantic/visitor.cpp +++ b/src/semantic/visitor.cpp @@ -1,16 +1,205 @@ #include "visitor.h" #include +#include +#include "MXParser.h" +#include "MXParserVisitor.h" +#include "ast/ast.h" -std::any 
Visitor::visitMxprog(MXParser::MxprogContext *context) { throw std::runtime_error("Not implemented"); } -std::any Visitor::visitFunction_def(MXParser::Function_defContext *context) { - throw std::runtime_error("Not implemented"); +std::any Visitor::visitMxprog(MXParser::MxprogContext *context) { /* builds the Program node from the top-level children of the parse tree */ + auto program = std::make_shared(); + program->type = ASTNodeType::Program; + program->start_line = context->getStart()->getLine(); + program->start_char_pos = context->getStart()->getCharPositionInLine(); + program->end_line = context->getStop()->getLine(); + program->end_char_pos = context->getStop()->getCharPositionInLine(); + nodetype_stk.push_back(ASTNodeType::Program); /* popped again before returning; stays pushed if a child visit throws */ + + for (auto def : context->children) { /* file each child under classes, global variables or functions; stop at the EOF token */ + if (auto classDefContext = dynamic_cast(def)) { + auto classNode = std::any_cast>(visit(classDefContext)); + program->classes.push_back(classNode); + } else if (auto defineStmtContext = dynamic_cast(def)) { + auto defineNode = std::any_cast>(visit(defineStmtContext)); + program->global_variables.push_back(defineNode); + } else if (auto funcDefContext = dynamic_cast(def)) { + auto funcNode = std::any_cast>(visit(funcDefContext)); + program->functions.push_back(funcNode); + } else if (auto EOFToken = dynamic_cast(def)) { + if (EOFToken == context->EOF()) break; + throw std::runtime_error("unknown subnode occurred in visitMxprog"); + } else { + throw std::runtime_error("unknown subnode occurred in visitMxprog"); + } + } + nodetype_stk.pop_back(); + + return program; /* returned inside std::any as a shared pointer to the program node */ +} + +std::any Visitor::visitFunction_def(MXParser::Function_defContext *context) { /* builds a FuncDef node: return type, name, parameters, body */ + auto func = std::make_shared(); + func->type = ASTNodeType::FuncDef; + func->is_constructor = false; + func->start_line = context->getStart()->getLine(); + func->start_char_pos = context->getStart()->getCharPositionInLine(); + func->end_line = context->getStop()->getLine(); + func->end_char_pos = context->getStop()->getCharPositionInLine(); + nodetype_stk.push_back(ASTNodeType::FuncDef); + std::string 
return_type_str; + + if (auto type_context = dynamic_cast(context->children[0])) { + // return type + return_type_str = type_context->getText(); + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "return_type_str=" << return_type_str << std::endl; + } else + throw std::runtime_error("unknown subnode occurred in visitFunction_def"); + int return_dimensions = 0; + for (size_t i = 1; i < context->children.size(); i++) { + if (dynamic_cast(context->children[i]) != nullptr && + dynamic_cast(context->children[i])->getSymbol()->getType() == MXParser::ID) { + break; + } + return_dimensions++; + } + return_dimensions >>= 1; + if (return_dimensions == 0) { + func->return_type = return_type_str; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded return type is [none-array]" << return_type_str + << std::endl; + } else { + func->return_type = ArrayType{true, return_type_str, static_cast(return_dimensions)}; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded return type is [array]" << return_type_str + << " with dimensions=" << return_dimensions << std::endl; + } + func->func_name = context->children[1 + 2 * return_dimensions]->getText(); + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "func_name=" << func->func_name << std::endl; + size_t cur = 3 + 2 * return_dimensions; + bool is_first_para = true; + while (true) { + if (cur >= context->children.size()) throw std::runtime_error("something strange happened in visitFunction_def"); + auto ptr = dynamic_cast(context->children[cur]); + if (ptr != nullptr && ptr->getSymbol()->getType() == MXParser::RPAREN) break; + if (is_first_para) goto cancel_first; + if (ptr->getSymbol()->getType() != MXParser::COMMA) + throw std::runtime_error("something strange happened in visitFunction_def"); + cur++; + cancel_first: + is_first_para = false; + std::string cur_para_type_base_str = context->children[cur]->getText(); + int cur_para_dimensions = 0; + while 
(dynamic_cast(context->children[cur + 1]) != nullptr && + (dynamic_cast(context->children[cur + 1])->getSymbol()->getType() == + MXParser::LBRACKET || + dynamic_cast(context->children[cur + 1])->getSymbol()->getType() == + MXParser::RBRACKET)) { + cur_para_dimensions++; /* counts every bracket token that follows the parameter's base type */ + cur++; + } + cur++; /* step onto the child that is read as the parameter name */ + cur_para_dimensions >>= 1; /* two bracket tokens were counted per array dimension */ + std::string cur_para_name = context->children[cur]->getText(); + ExprTypeInfo cur_para_type; + if (cur_para_dimensions == 0) { + cur_para_type = cur_para_type_base_str; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded parameter type is [none-array]" + << cur_para_type_base_str << std::endl; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded parameter name is " << cur_para_name + << std::endl; + } else { + cur_para_type = ArrayType{true, cur_para_type_base_str, static_cast(cur_para_dimensions)}; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded parameter type is [array]" + << cur_para_type_base_str << " with dimensions=" << cur_para_dimensions << std::endl; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded parameter name is " << cur_para_name + << std::endl; + } + func->params.push_back(std::make_pair(cur_para_name, cur_para_type)); /* parameters are stored as (name, type) pairs */ + cur++; + } + func->func_body = std::any_cast>(visit(context->suite())); + nodetype_stk.pop_back(); + return func; +} +std::any Visitor::visitClass_def(MXParser::Class_defContext *context) { /* builds a ClassDef node: name, optional constructor, member functions */ + auto class_def = std::make_shared(); + class_def->type = ASTNodeType::ClassDef; + class_def->start_line = context->getStart()->getLine(); + class_def->start_char_pos = context->getStart()->getCharPositionInLine(); + class_def->end_line = context->getStop()->getLine(); + class_def->end_char_pos = context->getStop()->getCharPositionInLine(); + nodetype_stk.push_back(ASTNodeType::ClassDef); + + class_def->class_name = context->ID()->getText(); + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "building a class named " << class_def->class_name + 
<< std::endl; + std::vector constructors = context->class_constructor(); + if (constructors.size() > 1) throw SemanticError("Multiple constructor found for class " + class_def->class_name, 2); /* at most one constructor is accepted per class */ + if (constructors.size() > 0) + class_def->constructor = std::any_cast>(visit(constructors[0])); + std::vector functions = context->function_def(); /* NOTE(review): member_variables is never filled in this function - confirm class variable definitions are collected elsewhere */ + for (auto func : functions) { + auto func_node = std::any_cast>(visit(func)); + class_def->member_functions.push_back(func_node); + } + + nodetype_stk.pop_back(); + return class_def; } -std::any Visitor::visitClass_def(MXParser::Class_defContext *context) { throw std::runtime_error("Not implemented"); } std::any Visitor::visitClass_var_def(MXParser::Class_var_defContext *context) { - throw std::runtime_error("Not implemented"); + auto member_var_def = std::make_shared(); /* builds the node for one member-variable definition: a type plus one or more names */ + member_var_def->type = ASTNodeType::ClassVariable; + member_var_def->start_line = context->getStart()->getLine(); + member_var_def->start_char_pos = context->getStart()->getCharPositionInLine(); + member_var_def->end_line = context->getStop()->getLine(); + member_var_def->end_char_pos = context->getStop()->getCharPositionInLine(); + nodetype_stk.push_back(ASTNodeType::ClassVariable); + + std::string define_type_base; + if (auto type_context = dynamic_cast(context->children[0])) { + define_type_base = type_context->getText(); + } else + throw std::runtime_error("unknown subnode occurred in visitClass_var_def"); + int define_dimensions = 0; + for (size_t i = 1; i < context->children.size(); i++) { /* count the children between the base type and the first ID token */ + if (dynamic_cast(context->children[i]) != nullptr && + dynamic_cast(context->children[i])->getSymbol()->getType() == MXParser::ID) { + break; + } + define_dimensions++; + } + define_dimensions >>= 1; /* halved: the counted children are treated as pairs, one pair per array dimension */ + if (define_dimensions == 0) { + member_var_def->var_type = define_type_base; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded member variable type is [none-array]" + << define_type_base << std::endl; + } else { + member_var_def->var_type = 
ArrayType{true, define_type_base, static_cast(define_dimensions)}; + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded member variable type is [array]" + << define_type_base << " with dimensions=" << define_dimensions << std::endl; + } + auto identifiers = context->ID(); + for (auto id : identifiers) { + member_var_def->vars.push_back(std::make_pair(id->getText(), nullptr)); /* member variables are recorded with a null initializer */ + std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "recorded member variable name is " << id->getText() + << std::endl; + } + + nodetype_stk.pop_back(); + return member_var_def; } std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *context) { - throw std::runtime_error("Not implemented"); + auto construct_func = std::make_shared(); /* builds the function node for a class constructor; only the body is filled in besides the bookkeeping fields */ + construct_func->type = ASTNodeType::Constructor; + construct_func->is_constructor = true; /* was false, which made constructors indistinguishable from ordinary functions by this flag */ + construct_func->start_line = context->getStart()->getLine(); + construct_func->start_char_pos = context->getStart()->getCharPositionInLine(); + construct_func->end_line = context->getStop()->getLine(); + construct_func->end_char_pos = context->getStop()->getCharPositionInLine(); + nodetype_stk.push_back(ASTNodeType::Constructor); + + construct_func->func_body = std::any_cast>(visit(context->suite())); /* NOTE(review): func_name and return_type are left default-initialized here - confirm later passes do not need them */ + + nodetype_stk.pop_back(); + return construct_func; } std::any Visitor::visitSuite(MXParser::SuiteContext *context) { throw std::runtime_error("Not implemented"); } std::any Visitor::visitEmpty_statement(MXParser::Empty_statementContext *context) {