passed local tests

This commit is contained in:
2024-08-14 14:41:02 +00:00
parent f11ff0693b
commit 746e5bae07
4 changed files with 54 additions and 18 deletions

View File

@ -55,9 +55,11 @@ class ASTNodeVirturalVisitor : public ASTNodeVisitorBase {
}; };
class ASTSemanticCheckVisitor : public ASTNodeVirturalVisitor { class ASTSemanticCheckVisitor : public ASTNodeVirturalVisitor {
bool is_in_func; bool is_in_func_def;
bool has_return;
FunctionSchema cur_func_schema; FunctionSchema cur_func_schema;
std::string cur_class_name; std::string cur_class_name;
bool is_in_class_def;
size_t loop_level; size_t loop_level;
std::shared_ptr<GlobalScope> global_scope; std::shared_ptr<GlobalScope> global_scope;
friend std::shared_ptr<Program_ASTNode> CheckAndDecorate(std::shared_ptr<Program_ASTNode> src); friend std::shared_ptr<Program_ASTNode> CheckAndDecorate(std::shared_ptr<Program_ASTNode> src);
@ -68,7 +70,7 @@ class ASTSemanticCheckVisitor : public ASTNodeVirturalVisitor {
} }
public: public:
ASTSemanticCheckVisitor() : is_in_func(false), loop_level(0) {} ASTSemanticCheckVisitor() : is_in_func_def(false), loop_level(0) {}
// Structural AST Nodes // Structural AST Nodes
void ActuralVisit(FuncDef_ASTNode *node) override; void ActuralVisit(FuncDef_ASTNode *node) override;
void ActuralVisit(ClassDef_ASTNode *node) override; void ActuralVisit(ClassDef_ASTNode *node) override;

View File

@ -84,6 +84,9 @@ inline bool operator==(const ExprTypeInfo &l, const ExprTypeInfo &r) {
return true; return true;
} }
if (std::holds_alternative<IdentifierType>(l)) { if (std::holds_alternative<IdentifierType>(l)) {
bool x = std::holds_alternative<IdentifierType>(r);
std::string a = std::get<IdentifierType>(l);
std::string b = std::get<IdentifierType>(r);
return std::holds_alternative<IdentifierType>(r) && std::get<IdentifierType>(l) == std::get<IdentifierType>(r); return std::holds_alternative<IdentifierType>(r) && std::get<IdentifierType>(l) == std::get<IdentifierType>(r);
} }
if (std::holds_alternative<ArrayType>(l)) { if (std::holds_alternative<ArrayType>(l)) {

View File

@ -6,12 +6,17 @@
// Structural AST Nodes // Structural AST Nodes
void ASTSemanticCheckVisitor::ActuralVisit(FuncDef_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(FuncDef_ASTNode *node) {
is_in_func = true; is_in_func_def = true;
cur_func_schema = std::dynamic_pointer_cast<FunctionScope>(node->current_scope)->schema; cur_func_schema = std::dynamic_pointer_cast<FunctionScope>(node->current_scope)->schema;
std::cerr << "enter function " << node->func_name << std::endl; std::cerr << "enter function " << node->func_name << std::endl;
has_return = false;
node->func_body->accept(this); node->func_body->accept(this);
if (!has_return && std::holds_alternative<IdentifierType>(cur_func_schema.return_type) &&
std::get<IdentifierType>(cur_func_schema.return_type) != "void" && node->func_name != "main") {
throw SemanticError("Non-void function must have a return value", 1);
}
std::cerr << "leave function " << node->func_name << std::endl; std::cerr << "leave function " << node->func_name << std::endl;
is_in_func = false; is_in_func_def = false;
} }
void ASTSemanticCheckVisitor::ActuralVisit(ClassDef_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(ClassDef_ASTNode *node) {
@ -19,11 +24,23 @@ void ASTSemanticCheckVisitor::ActuralVisit(ClassDef_ASTNode *node) {
// var->accept(this); // var->accept(this);
// } // }
cur_class_name = node->class_name; cur_class_name = node->class_name;
is_in_class_def = true;
if (node->constructor) {
if (node->constructor->func_name != node->class_name) {
throw SemanticError("Constructor name mismatch", 1);
}
}
for (auto func : node->member_functions) {
if (func->func_name == node->class_name) {
throw SemanticError("Constructor Type Error", 1);
}
}
for (auto ch : node->sorted_children) { for (auto ch : node->sorted_children) {
if (std::dynamic_pointer_cast<DefinitionStatement_ASTNode>(ch) == nullptr) { if (std::dynamic_pointer_cast<DefinitionStatement_ASTNode>(ch) == nullptr) {
ch->accept(this); ch->accept(this);
} }
} }
is_in_class_def = false;
} }
void ASTSemanticCheckVisitor::ActuralVisit(Program_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(Program_ASTNode *node) {
@ -109,8 +126,9 @@ void ASTSemanticCheckVisitor::ActuralVisit(JmpStatement_ASTNode *node) {
if (loop_level == 0 && node->jmp_type > 0) throw SemanticError("Jump statement outside loop", 1); if (loop_level == 0 && node->jmp_type > 0) throw SemanticError("Jump statement outside loop", 1);
if (node->jmp_type == 0) { if (node->jmp_type == 0) {
if (node->return_value) { if (node->return_value) {
has_return = true;
node->return_value->accept(this); node->return_value->accept(this);
if (node->return_value->expr_type_info != cur_func_schema.return_type) { if (cur_func_schema.return_type != node->return_value->expr_type_info) {
throw SemanticError("Return type mismatch", 1); throw SemanticError("Return type mismatch", 1);
} }
} else { } else {
@ -189,11 +207,11 @@ void ASTSemanticCheckVisitor::ActuralVisit(AccessExpr_ASTNode *node) {
} }
} }
node->expr_type_info = schema.return_type; node->expr_type_info = schema.return_type;
node->assignable = true; // node->assignable = true;
if (std::holds_alternative<IdentifierType>(node->expr_type_info)) { // if (std::holds_alternative<IdentifierType>(node->expr_type_info)) {
std::string type = std::get<IdentifierType>(node->expr_type_info); // std::string type = std::get<IdentifierType>(node->expr_type_info);
if (type == "int" || type == "bool" || type == "void") node->assignable = false; // if (type == "int" || type == "bool" || type == "void") node->assignable = false;
} // }
} else { } else {
node->expr_type_info = global_scope->FetchClassMemberVariable(base_type, node->member); node->expr_type_info = global_scope->FetchClassMemberVariable(base_type, node->member);
node->assignable = true; node->assignable = true;
@ -405,7 +423,8 @@ void ASTSemanticCheckVisitor::ActuralVisit(TernaryExpr_ASTNode *node) {
const static ExprTypeInfo standard = "bool"; const static ExprTypeInfo standard = "bool";
node->src1->accept(this); node->src1->accept(this);
node->src2->accept(this); node->src2->accept(this);
if (node->src1->expr_type_info != node->src2->expr_type_info) { if (node->src1->expr_type_info != node->src2->expr_type_info &&
node->src2->expr_type_info != node->src1->expr_type_info) {
throw SemanticError("Ternary operation on different type", 1); throw SemanticError("Ternary operation on different type", 1);
} }
node->expr_type_info = node->src1->expr_type_info; node->expr_type_info = node->src1->expr_type_info;
@ -444,7 +463,18 @@ void ASTSemanticCheckVisitor::ActuralVisit(IDExpr_ASTNode *node) {
void ASTSemanticCheckVisitor::ActuralVisit(FunctionCallExpr_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(FunctionCallExpr_ASTNode *node) {
// TODO: Implement this method // TODO: Implement this method
// TODO: check function existence and arg number // TODO: check function existence and arg number
auto schema = global_scope->FetchFunction(node->func_name); FunctionSchema schema;
bool schema_ready = false;
if (is_in_class_def) {
try {
schema = global_scope->FetchClassMemberFunction(cur_class_name, node->func_name);
schema_ready = true;
} catch (...) {
}
}
if (!schema_ready) {
schema = global_scope->FetchFunction(node->func_name);
}
std::cerr << "function to call is " << node->func_name << std::endl; std::cerr << "function to call is " << node->func_name << std::endl;
if (schema.arguments.size() != node->arguments.size()) { if (schema.arguments.size() != node->arguments.size()) {
throw SemanticError("Argument number mismatch", 1); throw SemanticError("Argument number mismatch", 1);
@ -452,16 +482,16 @@ void ASTSemanticCheckVisitor::ActuralVisit(FunctionCallExpr_ASTNode *node) {
for (auto &arg : node->arguments) { for (auto &arg : node->arguments) {
arg->accept(this); arg->accept(this);
int idx = &arg - &node->arguments[0]; // for debug; int idx = &arg - &node->arguments[0]; // for debug;
if (arg->expr_type_info != schema.arguments[&arg - &node->arguments[0]].first) { if (schema.arguments[&arg - &node->arguments[0]].first != arg->expr_type_info) {
throw SemanticError("Argument type mismatch", 1); throw SemanticError("Argument type mismatch", 1);
} }
} }
node->expr_type_info = schema.return_type; node->expr_type_info = schema.return_type;
node->assignable = true; // node->assignable = true;
if (std::holds_alternative<IdentifierType>(node->expr_type_info)) { // if (std::holds_alternative<IdentifierType>(node->expr_type_info)) {
std::string type = std::get<IdentifierType>(node->expr_type_info); // std::string type = std::get<IdentifierType>(node->expr_type_info);
if (type == "int" || type == "bool" || type == "void") node->assignable = false; // if (type == "int" || type == "bool" || type == "void") node->assignable = false;
} // }
} }
void ASTSemanticCheckVisitor::ActuralVisit(FormattedStringExpr_ASTNode *node) { void ASTSemanticCheckVisitor::ActuralVisit(FormattedStringExpr_ASTNode *node) {

View File

@ -257,6 +257,7 @@ std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *con
auto cur_scope = std::make_shared<FunctionScope>(); auto cur_scope = std::make_shared<FunctionScope>();
cur_scope->parent = nodetype_stk.back().second.get(); cur_scope->parent = nodetype_stk.back().second.get();
construct_func->current_scope = cur_scope; construct_func->current_scope = cur_scope;
construct_func->func_name = context->ID()->getText();
cur_scope->schema.return_type = "void"; cur_scope->schema.return_type = "void";
nodetype_stk.push_back({ASTNodeType::Constructor, construct_func->current_scope}); nodetype_stk.push_back({ASTNodeType::Constructor, construct_func->current_scope});