diff --git a/include/ast/astnode.h b/include/ast/astnode.h index 2411067..8d40368 100644 --- a/include/ast/astnode.h +++ b/include/ast/astnode.h @@ -16,6 +16,7 @@ class ASTNodeVisitorBase { class ASTNodeBase { friend Visitor; + friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); protected: std::shared_ptr current_scope; diff --git a/include/ast/scope.hpp b/include/ast/scope.hpp index daaca55..0777291 100644 --- a/include/ast/scope.hpp +++ b/include/ast/scope.hpp @@ -32,6 +32,13 @@ class LocalScope : public ScopeBase { if (!VariableNameAvailable(name, 0)) { return false; } + if (std::holds_alternative(type) && std::get(type) == "void") { + throw SemanticError("Variable cannot be void", 1); + } + if (std::holds_alternative(type) && std::get(type).has_base_type && + std::get(type).basetype == "void") { + throw SemanticError("Variable cannot be void", 1); + } local_variables[name] = type; return true; } @@ -53,6 +60,7 @@ struct FunctionSchema { std::vector> arguments; }; class FunctionScope : public ScopeBase { + friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); friend class Visitor; FunctionSchema schema; bool add_variable([[maybe_unused]] const std::string &name, [[maybe_unused]] const ExprTypeInfo &type) override { @@ -80,6 +88,13 @@ class ClassDefScope : public ScopeBase { if (!VariableNameAvailable(name, 0)) { return false; } + if (std::holds_alternative(type) && std::get(type) == "void") { + throw SemanticError("Variable cannot be void", 1); + } + if (std::holds_alternative(type) && std::get(type).has_base_type && + std::get(type).basetype == "void") { + throw SemanticError("Variable cannot be void", 1); + } member_variables[name] = type; return true; } @@ -111,6 +126,7 @@ class ClassDefScope : public ScopeBase { }; class GlobalScope : public ScopeBase { friend class Visitor; + friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); std::unordered_map global_variables; std::unordered_map> global_functions; std::unordered_map> classes; @@ -146,6 +162,13 @@ class GlobalScope : public ScopeBase { if (!VariableNameAvailable(name, 0)) { return false; } + if (std::holds_alternative(type) && std::get(type) == "void") { + throw SemanticError("Variable cannot be void", 1); + } + if (std::holds_alternative(type) && std::get(type).has_base_type && + std::get(type).basetype == "void") { + throw SemanticError("Variable cannot be void", 1); + } global_variables[name] = type; return true; } diff --git a/src/semantic/semantic.cpp b/src/semantic/semantic.cpp index 6ddcdaf..21ab397 100644 --- a/src/semantic/semantic.cpp +++ b/src/semantic/semantic.cpp @@ -22,7 +22,19 @@ class MXErrorListener : public antlr4::BaseErrorListener { std::shared_ptr BuildAST(Visitor *visitor, antlr4::tree::ParseTree *tree) { return std::any_cast>(visitor->visit(tree)); } -std::shared_ptr CheckAndDecorate(std::shared_ptr src) { return nullptr; } +std::shared_ptr CheckAndDecorate(std::shared_ptr src) { + auto global_scope = std::dynamic_pointer_cast(src->current_scope); + if (global_scope->global_functions.find("main") == global_scope->global_functions.end()) { + throw SemanticError("No main() function", 1); + } else { + const auto &main_schema = global_scope->global_functions["main"]->schema; + if ((!std::holds_alternative(main_schema.return_type)) || + std::get(main_schema.return_type) != "int" || main_schema.arguments.size() != 0) { + throw SemanticError("main() function should be int main()", 1); + } + } + return src; +} void SemanticCheck(std::istream &fin, std::shared_ptr &ast_out) { antlr4::ANTLRInputStream input(fin); diff --git a/src/semantic/visitor.cpp b/src/semantic/visitor.cpp index 31802f2..7cd1501 100644 --- a/src/semantic/visitor.cpp +++ b/src/semantic/visitor.cpp @@ -629,15 +629,25 @@ std::any Visitor::visitNew_array_expression(MXParser::New_array_expressionContex std::string base_type = context->type()->getText(); new_array->expr_type_info = ArrayType{true, base_type, total_dimensions}; size_t total_dim_count = 0; + bool dim_size_specified = false; + bool dim_with_size_end = false; new_array->dim_size.resize(total_dimensions); for (size_t i = 3; i < context->children.size() && total_dim_count < total_dimensions; i++) { if (dynamic_cast(context->children[i]) != nullptr && dynamic_cast(context->children[i])->getSymbol()->getType() == MXParser::RBRACKET) { total_dim_count++; + if (!dim_size_specified) { + dim_with_size_end = true; + } + dim_size_specified = false; } else if (dynamic_cast(context->children[i]) != nullptr) { new_array->dim_size[total_dim_count] = std::any_cast>(visit(context->children[i])); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "dim " << total_dim_count << " has size " << std::endl; + dim_size_specified = true; + if (dim_with_size_end) { + throw SemanticError("The shape of multidimensional array must be specified from left to right.", 1); + } } } if (context->constant() != nullptr) {