From 1f9e7acadec51f3b74a8df07123c35e9e3617a66 Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Fri, 23 Aug 2024 04:10:35 +0000 Subject: [PATCH] support argument as lvalue --- include/IR/IRBuilder.h | 39 +++++---- include/IR/IR_basic.h | 5 ++ include/ast/scope.hpp | 3 + include/tools.h | 5 +- src/IR/IRBuilder.cpp | 177 +++++++++++++++++++++++++++++++++++---- src/semantic/visitor.cpp | 4 + 6 files changed, 200 insertions(+), 33 deletions(-) diff --git a/include/IR/IRBuilder.h b/include/IR/IRBuilder.h index 74a7c74..fe5efd5 100644 --- a/include/IR/IRBuilder.h +++ b/include/IR/IRBuilder.h @@ -1,8 +1,10 @@ #pragma once #include #include +#include #include "IR_basic.h" #include "ast/astnode_visitor.h" +#include "ast/scope.hpp" #include "tools.h" class IRBuilder : public ASTNodeVirturalVisitor { friend std::shared_ptr BuildIR(std::shared_ptr src); @@ -10,6 +12,7 @@ class IRBuilder : public ASTNodeVirturalVisitor { std::shared_ptr cur_class; std::shared_ptr cur_func; std::shared_ptr cur_block; + std::shared_ptr main_init_block; std::string cur_class_name; bool is_in_class_def; bool is_in_func_def; @@ -19,9 +22,11 @@ class IRBuilder : public ASTNodeVirturalVisitor { std::string cur_continue_target; bool just_encountered_jmp; std::shared_ptr global_scope; + std::shared_ptr cur_func_scope; size_t const_str_counter; std::unordered_map const_str_dict; size_t const_arr_counter; + std::unordered_set already_set_constructor; public: IRBuilder() { @@ -32,6 +37,8 @@ class IRBuilder : public ASTNodeVirturalVisitor { just_encountered_jmp = false; const_str_counter = 0; const_arr_counter = 0; + main_init_block = std::make_shared(); + main_init_block->label_full = "main_init"; } // Structural AST Nodes void ActuralVisit(FuncDef_ASTNode *node) override; @@ -98,31 +105,31 @@ class IRBuilder : public ASTNodeVirturalVisitor { ty = LLVMIRPTRType(); elem_size = 4; } - auto& sub_nodes=std::get>>(node->value); + auto &sub_nodes = std::get>>(node->value); std::string array_head = "%.var.tmp." + std::to_string(tmp_var_counter++); - auto allocate_action=std::make_shared(); + auto allocate_action = std::make_shared(); blk.actions.push_back(allocate_action); - allocate_action->func_name_raw=".builtin.AllocateArray"; - allocate_action->result_full=array_head; - allocate_action->return_type=LLVMIRPTRType(); + allocate_action->func_name_raw = ".builtin.AllocateArray"; + allocate_action->result_full = array_head; + allocate_action->return_type = LLVMIRPTRType(); allocate_action->args_ty.push_back(LLVMIRIntType(32)); allocate_action->args_val_full.push_back(std::to_string(elem_size)); allocate_action->args_ty.push_back(LLVMIRIntType(32)); allocate_action->args_val_full.push_back(std::to_string(sub_nodes.size())); - for(size_t i=0;i(); + for (size_t i = 0; i < sub_nodes.size(); i++) { + std::string ret = ArrangeConstArrDfs(blk, sub_nodes[i].get(), depth + 1, total_level, basetype); + std::string addr = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto ptr_cal = std::make_shared(); blk.actions.push_back(ptr_cal); - ptr_cal->result_full=addr; - ptr_cal->ty=ty; - ptr_cal->ptr_full=array_head; + ptr_cal->result_full = addr; + ptr_cal->ty = ty; + ptr_cal->ptr_full = array_head; ptr_cal->indices.push_back(std::to_string(i)); - auto store_action=std::make_shared(); + auto store_action = std::make_shared(); blk.actions.push_back(store_action); - store_action->ty=ty; - store_action->ptr_full=addr; - store_action->value_full=ret; + store_action->ty = ty; + store_action->ptr_full = addr; + store_action->value_full = ret; } return allocate_action->result_full; diff --git a/include/IR/IR_basic.h b/include/IR/IR_basic.h index a0b276d..f0ba5a0 100644 --- a/include/IR/IR_basic.h +++ b/include/IR/IR_basic.h @@ -81,6 +81,7 @@ class BRAction : public JMPActionItem { }; class UNConditionJMPAction : public JMPActionItem { friend class IRBuilder; + friend class FunctionDefItem; std::string label_full; public: @@ -89,6 +90,7 @@ class UNConditionJMPAction : public JMPActionItem { }; class RETAction : public JMPActionItem { friend class IRBuilder; + friend class FunctionDefItem; LLVMType type; std::string value; @@ -244,6 +246,7 @@ class ICMPAction : public ActionItem { }; class BlockItem : public LLVMIRItemBase { friend class IRBuilder; + friend class FunctionDefItem; std::string label_full; std::vector> actions; std::shared_ptr exit_action; @@ -366,6 +369,7 @@ class FunctionDefItem : public LLVMIRItemBase { std::string func_name_raw; std::vector args; std::vector args_full_name; + std::shared_ptr init_block; std::vector> basic_blocks; public: @@ -398,6 +402,7 @@ class FunctionDefItem : public LLVMIRItemBase { } } os << ")\n{\n"; + if (init_block) init_block->RecursivePrint(os); for (auto &item : basic_blocks) { item->RecursivePrint(os); } diff --git a/include/ast/scope.hpp b/include/ast/scope.hpp index 36134a9..bc34b5f 100644 --- a/include/ast/scope.hpp +++ b/include/ast/scope.hpp @@ -326,6 +326,9 @@ class GlobalScope : public ScopeBase { auto &tmp = classes[name]->llvm_class_info; tmp.class_name_raw = name; tmp.ArrangeSpace(); + if (classes[name]->member_functions.find(name) != classes[name]->member_functions.end()) { + tmp.has_user_specified_constructor = true; + } return tmp; } diff --git a/include/tools.h b/include/tools.h index 9b3c968..ea3f379 100644 --- a/include/tools.h +++ b/include/tools.h @@ -125,6 +125,7 @@ class IRClassInfo { std::vector member_var_pos_after_align; size_t class_size_after_align; bool alread_arranged; + bool has_user_specified_constructor; void ArrangeSpace() { if (alread_arranged) return; alread_arranged = true; @@ -146,7 +147,7 @@ class IRClassInfo { } } std::string GenerateFullName() { return "%.class." + class_name_raw; } - IRClassInfo() : alread_arranged(false) {} + IRClassInfo() : alread_arranged(false), has_user_specified_constructor(false) {} }; class IRVariableInfo { public: @@ -164,7 +165,7 @@ class IRVariableInfo { } else if (variable_type == 1) { return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".addrkp"; } else if (variable_type == 3) { - return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".val"; + return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".addrkp"; } else { throw std::runtime_error("Invalid scope id"); } diff --git a/src/IR/IRBuilder.cpp b/src/IR/IRBuilder.cpp index 1b42227..8ebc7d3 100644 --- a/src/IR/IRBuilder.cpp +++ b/src/IR/IRBuilder.cpp @@ -1,7 +1,9 @@ #include "IRBuilder.h" #include +#include #include "IR.h" #include "IR_basic.h" +#include "ast/scope.hpp" #include "tools.h" // Structural AST Nodes @@ -37,12 +39,38 @@ void IRBuilder::ActuralVisit(FuncDef_ASTNode *node) { if (func_def->args.size() != func_def->args_full_name.size()) throw std::runtime_error("args size not match"); cur_func = func_def; + cur_func_scope = std::dynamic_pointer_cast(node->current_scope); + if (!(cur_func_scope)) throw std::runtime_error("Function scope not found"); auto current_block = std::make_shared(); cur_block = current_block; cur_func->basic_blocks.push_back(current_block); size_t block_id = block_counter++; current_block->label_full = "label_" + std::to_string(block_id); + if (node->func_name == "main") { + func_def->init_block = main_init_block; + } else { + func_def->init_block = std::make_shared(); + func_def->init_block->label_full = "label_init_" + node->func_name; + for (auto &arg : node->params) { + std::string &arg_name_raw = arg.first; + std::string rvalue_full_name = "%.var.local." + std::to_string(scope_id) + "." + arg_name_raw + ".val"; + std::string lvalue_full_name = "%.var.local." + std::to_string(scope_id) + "." + arg_name_raw + ".addrkp"; + auto allocact = std::make_shared(); + func_def->init_block->actions.push_back(allocact); + allocact->num = 1; + allocact->type = Type_AST2LLVM(arg.second); + allocact->name_full = lvalue_full_name; + auto storeact = std::make_shared(); + func_def->init_block->actions.push_back(storeact); + storeact->ty = Type_AST2LLVM(arg.second); + storeact->ptr_full = lvalue_full_name; + storeact->value_full = rvalue_full_name; + } + } + cur_func->init_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_func->init_block->exit_action)->label_full = + current_block->label_full; is_in_func_def = true; node->func_body->accept(this); is_in_func_def = false; @@ -84,15 +112,31 @@ void IRBuilder::ActuralVisit(EmptyStatement_ASTNode *node) { } void IRBuilder::ActuralVisit(DefinitionStatement_ASTNode *node) { - if (is_in_class_def) { - cur_class->elements.push_back(Type_AST2LLVM(node->var_type)); + if (is_in_class_def && !is_in_func_def) { + for (size_t i = 0; i < node->vars.size(); i++) cur_class->elements.push_back(Type_AST2LLVM(node->var_type)); } else if (!is_in_func_def) { for (const auto &var : node->vars) { auto var_def = std::make_shared(); prog->global_var_defs.push_back(var_def); var_def->type = Type_AST2LLVM(node->var_type); var_def->name_raw = var.first; - // TODO: initial value + if (var.second) { + var.second->accept(this); + std::string init_var = var.second->IR_result_full; + if (init_var[0] == '#') { + init_var = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto const_array_construct_call = std::make_shared(); + main_init_block->actions.push_back(const_array_construct_call); + const_array_construct_call->result_full = init_var; + const_array_construct_call->return_type = LLVMIRPTRType(); + const_array_construct_call->func_name_raw = var.second->IR_result_full.substr(1); + } + auto act = std::make_shared(); + main_init_block->actions.push_back(act); + act->ty = var_def->type; + act->ptr_full = "@.var.global." + var_def->name_raw + ".addrkp"; + act->value_full = init_var; + } } } else { for (const auto &var : node->vars) { @@ -389,11 +433,81 @@ void IRBuilder::ActuralVisit(NewArrayExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(NewConstructExpr_ASTNode *node) { - // TODO: Implement function body + std::string class_name = std::get(node->expr_type_info); + if (already_set_constructor.find(class_name) != already_set_constructor.end()) { + node->IR_result_full = "#.classconstruct." + class_name; + return; + } + already_set_constructor.insert(class_name); + std::string construct_func_name = ".classconstruct." + class_name; + auto construct_func = std::make_shared(); + prog->function_defs.push_back(construct_func); + construct_func->return_type = LLVMIRPTRType(); + construct_func->func_name_raw = construct_func_name; + auto block = std::make_shared(); + construct_func->basic_blocks.push_back(block); + block->label_full = "label_construct_" + class_name; + IRClassInfo class_info = global_scope->fetch_class_info(class_name); + std::string allocated_addr = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto malloc_call = std::make_shared(); + block->actions.push_back(malloc_call); + malloc_call->result_full = allocated_addr; + malloc_call->return_type = LLVMIRPTRType(); + malloc_call->func_name_raw = "malloc"; // just use libc malloc + malloc_call->args_ty.push_back(LLVMIRIntType(32)); + malloc_call->args_val_full.push_back(std::to_string(class_info.class_size_after_align)); + if (class_info.has_user_specified_constructor) { + auto call_act = std::make_shared(); + block->actions.push_back(call_act); + call_act->return_type = LLVMVOIDType(); + call_act->func_name_raw = class_name + "." + class_name; + call_act->args_ty.push_back(LLVMIRPTRType()); + call_act->args_val_full.push_back(allocated_addr); + } + auto ret_act = std::make_shared(); + block->exit_action = ret_act; + ret_act->type = LLVMIRPTRType(); + ret_act->value = allocated_addr; + node->IR_result_full = "#" + construct_func_name; } void IRBuilder::ActuralVisit(NewExpr_ASTNode *node) { - // TODO: Implement function body + std::string class_name = std::get(node->expr_type_info); + if (already_set_constructor.find(class_name) != already_set_constructor.end()) { + node->IR_result_full = "#.classconstruct." + class_name; + return; + } + already_set_constructor.insert(class_name); + std::string construct_func_name = ".classconstruct." + class_name; + auto construct_func = std::make_shared(); + prog->function_defs.push_back(construct_func); + construct_func->return_type = LLVMIRPTRType(); + construct_func->func_name_raw = construct_func_name; + auto block = std::make_shared(); + construct_func->basic_blocks.push_back(block); + block->label_full = "label_construct_" + class_name; + IRClassInfo class_info = global_scope->fetch_class_info(class_name); + std::string allocated_addr = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto malloc_call = std::make_shared(); + block->actions.push_back(malloc_call); + malloc_call->result_full = allocated_addr; + malloc_call->return_type = LLVMIRPTRType(); + malloc_call->func_name_raw = "malloc"; // just use libc malloc + malloc_call->args_ty.push_back(LLVMIRIntType(32)); + malloc_call->args_val_full.push_back(std::to_string(class_info.class_size_after_align)); + if (class_info.has_user_specified_constructor) { + auto call_act = std::make_shared(); + block->actions.push_back(call_act); + call_act->return_type = LLVMVOIDType(); + call_act->func_name_raw = class_name + "." + class_name; + call_act->args_ty.push_back(LLVMIRPTRType()); + call_act->args_val_full.push_back(allocated_addr); + } + auto ret_act = std::make_shared(); + block->exit_action = ret_act; + ret_act->type = LLVMIRPTRType(); + ret_act->value = allocated_addr; + node->IR_result_full = "#" + construct_func_name; } void IRBuilder::ActuralVisit(AccessExpr_ASTNode *node) { @@ -783,8 +897,8 @@ void IRBuilder::ActuralVisit(AssignExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(ThisExpr_ASTNode *node) { - // TODO: Implement function body - throw std::runtime_error("this not supported"); + size_t scope_id = cur_func_scope->scope_id; + node->IR_result_full = "%.var.local." + std::to_string(scope_id) + ".this.val"; } void IRBuilder::ActuralVisit(ParenExpr_ASTNode *node) { @@ -794,7 +908,7 @@ void IRBuilder::ActuralVisit(ParenExpr_ASTNode *node) { void IRBuilder::ActuralVisit(IDExpr_ASTNode *node) { IRVariableInfo var_info = node->current_scope->fetch_variable_for_IR(node->id); - if (var_info.variable_type == 0 || var_info.variable_type == 1) { + if (var_info.variable_type == 0 || var_info.variable_type == 1 || var_info.variable_type == 3) { if (node->is_requiring_lvalue) { node->IR_result_full = var_info.GenerateFullName(); return; @@ -806,13 +920,29 @@ void IRBuilder::ActuralVisit(IDExpr_ASTNode *node) { act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); node->IR_result_full = act->result_full; } else if (var_info.variable_type == 2) { - throw std::runtime_error("not support for class member access"); - } else if (var_info.variable_type == 3) { + size_t scope_id = cur_func_scope->scope_id; + std::string this_ptr = "%.var.local." + std::to_string(scope_id) + ".this.val"; + std::string class_name = cur_class_name; + std::string member_var_name = var_info.variable_name_raw; + IRClassInfo class_info = global_scope->fetch_class_info(class_name); + size_t idx = class_info.member_var_offset[member_var_name]; + auto member_addr_cal = std::make_shared(); + cur_block->actions.push_back(member_addr_cal); + member_addr_cal->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + member_addr_cal->ty = LLVMIRCLASSTYPE{"%.class." + class_name}; + member_addr_cal->ptr_full = this_ptr; + member_addr_cal->indices.push_back("0"); + member_addr_cal->indices.push_back(std::to_string(idx)); if (node->is_requiring_lvalue) { - throw std::runtime_error("for argument, lvalue support need additional work"); + node->IR_result_full = member_addr_cal->result_full; + } else { + auto member_val_load = std::make_shared(); + cur_block->actions.push_back(member_val_load); + member_val_load->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + member_val_load->ty = class_info.member_var_type[idx]; + member_val_load->ptr_full = member_addr_cal->result_full; + node->IR_result_full = member_val_load->result_full; } - node->IR_result_full = var_info.GenerateFullName(); - return; } else { throw std::runtime_error("unknown variable type"); } @@ -828,8 +958,25 @@ void IRBuilder::ActuralVisit(FunctionCallExpr_ASTNode *node) { } } if (is_member_func) { - // TODO: member function call - throw std::runtime_error("not support for class member function call"); + size_t scope_id = cur_func_scope->scope_id; + std::string this_ptr = "%.var.local." + std::to_string(scope_id) + ".this.val"; + std::string class_name = cur_class_name; + std::string full_func_name = class_name + "." + node->func_name; + auto call = std::make_shared(); + call->return_type = Type_AST2LLVM(node->expr_type_info); + call->func_name_raw = full_func_name; + call->args_val_full.push_back(this_ptr); + call->args_ty.push_back(LLVMIRPTRType()); + for (auto &arg : node->arguments) { + arg->accept(this); + call->args_val_full.push_back(arg->IR_result_full); + call->args_ty.push_back(Type_AST2LLVM(arg->expr_type_info)); + } + cur_block->actions.push_back(call); + if (!std::holds_alternative(call->return_type)) { + call->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = call->result_full; + } } else { auto call = std::make_shared(); call->return_type = Type_AST2LLVM(node->expr_type_info); diff --git a/src/semantic/visitor.cpp b/src/semantic/visitor.cpp index a727261..8ef27e5 100644 --- a/src/semantic/visitor.cpp +++ b/src/semantic/visitor.cpp @@ -265,6 +265,10 @@ std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *con construct_func->func_body = std::dynamic_pointer_cast( std::any_cast>(visit(context->suite()))); std::cerr << std::string(nodetype_stk.size() * 2, ' ') << "constructor body visited" << std::endl; + ClassDefScope *cparent = dynamic_cast(cur_scope->parent); + if (!cparent->add_function(construct_func->func_name, cur_scope)) { + throw SemanticError("Multiple Definitions", 1); + } nodetype_stk.pop_back(); return construct_func;