From cb56e7a184dd01428c262ba515ebddb49eb09711 Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Fri, 23 Aug 2024 02:20:27 +0000 Subject: [PATCH] add explicit member func call, ready to merge new testcases --- include/IR/IR_basic.h | 1 - include/ast/scope.hpp | 21 +++++++ include/tools.h | 11 +++- src/IR/IRBuilder.cpp | 129 ++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 152 insertions(+), 10 deletions(-) diff --git a/include/IR/IR_basic.h b/include/IR/IR_basic.h index 090dc2e..a0b276d 100644 --- a/include/IR/IR_basic.h +++ b/include/IR/IR_basic.h @@ -244,7 +244,6 @@ class ICMPAction : public ActionItem { }; class BlockItem : public LLVMIRItemBase { friend class IRBuilder; - friend void ArrangeConstArr(BlockItem &blk, class ConstantExpr_ASTNode *node, size_t &tmp_var_counter); std::string label_full; std::vector> actions; std::shared_ptr exit_action; diff --git a/include/ast/scope.hpp b/include/ast/scope.hpp index 5db8ab9..36134a9 100644 --- a/include/ast/scope.hpp +++ b/include/ast/scope.hpp @@ -138,6 +138,7 @@ class ClassDefScope : public ScopeBase { std::unordered_map member_variables; std::unordered_map> member_functions; IRClassInfo llvm_class_info; + size_t arrange_counter; bool add_variable(const std::string &name, const ExprTypeInfo &type) override { if (!VariableNameAvailable(name, 0)) { return false; @@ -150,6 +151,14 @@ class ClassDefScope : public ScopeBase { throw SemanticError("Invalid Type", 1); } member_variables[name] = type; + llvm_class_info.member_var_offset[name] = arrange_counter++; + size_t cur_element_size = 4; + ExprTypeInfo bool_std = IdentifierType("bool"); + if (type == bool_std) { + cur_element_size = 1; + } + llvm_class_info.member_var_size.push_back(cur_element_size); + llvm_class_info.member_var_type.push_back(Type_AST2LLVM(type)); return true; } virtual ExprTypeInfo fetch_varaible(const std::string &name) override { @@ -196,6 +205,9 @@ class ClassDefScope : public ScopeBase { } return parent->VariableNameAvailable(name, ttl + 1); } + + public: + ClassDefScope() : arrange_counter(0) {} }; class GlobalScope : public ScopeBase { friend class Visitor; @@ -307,6 +319,15 @@ class GlobalScope : public ScopeBase { } return parent->fetch_variable_for_IR(name); } + IRClassInfo fetch_class_info(const std::string &name) { + if (classes.find(name) == classes.end()) { + throw SemanticError("Undefined Identifier", 1); + } + auto &tmp = classes[name]->llvm_class_info; + tmp.class_name_raw = name; + tmp.ArrangeSpace(); + return tmp; + } public: GlobalScope() { parent = nullptr; } diff --git a/include/tools.h b/include/tools.h index 98bd277..9b3c968 100644 --- a/include/tools.h +++ b/include/tools.h @@ -117,13 +117,17 @@ struct LLVMIRCLASSTYPE { using LLVMType = std::variant; class IRClassInfo { public: - std::string class_name_raw; // This data must be provided by user - std::vector member_var_size; // This data must be provided by user. Each of them is the size of a member - // variable, which must be in [1,4] + std::string class_name_raw; // This data must be provided by user + std::vector member_var_size; // This data must be provided by user. Each of them is the size of a member + // variable, which must be in [1,4] + std::vector member_var_type; // This data must be provided by user std::unordered_map member_var_offset; // This data must be provided by user std::vector member_var_pos_after_align; size_t class_size_after_align; + bool alread_arranged; void ArrangeSpace() { + if (alread_arranged) return; + alread_arranged = true; size_t cur_pos = 0; size_t align_size = 1; for (size_t cur_size : member_var_size) { @@ -142,6 +146,7 @@ class IRClassInfo { } } std::string GenerateFullName() { return "%.class." + class_name_raw; } + IRClassInfo() : alread_arranged(false) {} }; class IRVariableInfo { public: diff --git a/src/IR/IRBuilder.cpp b/src/IR/IRBuilder.cpp index 8fcd478..1b42227 100644 --- a/src/IR/IRBuilder.cpp +++ b/src/IR/IRBuilder.cpp @@ -397,7 +397,56 @@ void IRBuilder::ActuralVisit(NewExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(AccessExpr_ASTNode *node) { - // TODO: Implement function body + if (!node->is_function) { + node->base->accept(this); + std::string type_of_base = std::get(node->base->expr_type_info); + IRClassInfo class_info = global_scope->fetch_class_info(type_of_base); + std::string base_ptr = node->base->IR_result_full; + size_t idx = class_info.member_var_offset[node->member]; + auto member_addr_cal = std::make_shared(); + cur_block->actions.push_back(member_addr_cal); + member_addr_cal->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + member_addr_cal->ty = LLVMIRCLASSTYPE{"%.class." + type_of_base}; + member_addr_cal->ptr_full = base_ptr; + member_addr_cal->indices.push_back("0"); + member_addr_cal->indices.push_back(std::to_string(idx)); + if (node->is_requiring_lvalue) { + node->IR_result_full = member_addr_cal->result_full; + } else { + auto member_val_load = std::make_shared(); + cur_block->actions.push_back(member_val_load); + member_val_load->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + member_val_load->ty = class_info.member_var_type[idx]; + member_val_load->ptr_full = member_addr_cal->result_full; + node->IR_result_full = member_val_load->result_full; + } + } else { + node->base->accept(this); + std::string type_of_base = std::get(node->base->expr_type_info); + std::string base_ptr = node->base->IR_result_full; + std::string func_name = type_of_base + "." + node->member; + std::vector arg_val; + for (size_t i = 0; i < node->arguments.size(); i++) { + node->arguments[i]->accept(this); + arg_val.push_back(node->arguments[i]->IR_result_full); + } + auto call_act = std::make_shared(); + cur_block->actions.push_back(call_act); + call_act->return_type = Type_AST2LLVM(node->expr_type_info); + if (!std::holds_alternative(call_act->return_type)) { + call_act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + } + call_act->func_name_raw = func_name; + call_act->args_ty.push_back(LLVMIRPTRType()); + call_act->args_val_full.push_back(base_ptr); + for (size_t i = 0; i < arg_val.size(); i++) { + call_act->args_ty.push_back(Type_AST2LLVM(node->arguments[i]->expr_type_info)); + call_act->args_val_full.push_back(arg_val[i]); + } + if (!std::holds_alternative(call_act->return_type)) { + node->IR_result_full = call_act->result_full; + } + } } void IRBuilder::ActuralVisit(IndexExpr_ASTNode *node) { @@ -740,6 +789,7 @@ void IRBuilder::ActuralVisit(ThisExpr_ASTNode *node) { void IRBuilder::ActuralVisit(ParenExpr_ASTNode *node) { node->expr->accept(this); // just visit it + node->IR_result_full = node->expr->IR_result_full; } void IRBuilder::ActuralVisit(IDExpr_ASTNode *node) { @@ -846,34 +896,101 @@ void IRBuilder::ActuralVisit(ConstantExpr_ASTNode *node) { std::shared_ptr BuildIR(std::shared_ptr src) { IRBuilder visitor; visitor.prog = std::make_shared(); - auto tmp = std::make_shared(); + auto tmp = std::make_shared(); // void* malloc(unsigned int size) tmp->func_name_raw = "malloc"; tmp->return_type = LLVMIRPTRType(); tmp->args.push_back(LLVMIRIntType(32)); visitor.prog->function_declares.push_back(tmp); - tmp = std::make_shared(); + + tmp = std::make_shared(); // void printInt(int n) tmp->func_name_raw = "printInt"; tmp->return_type = LLVMVOIDType(); tmp->args.push_back(LLVMIRIntType(32)); visitor.prog->function_declares.push_back(tmp); - tmp = std::make_shared(); + + tmp = std::make_shared(); // void print(char *str) tmp->func_name_raw = "print"; tmp->return_type = LLVMVOIDType(); tmp->args.push_back(LLVMIRPTRType()); visitor.prog->function_declares.push_back(tmp); - tmp = std::make_shared(); + + tmp = std::make_shared(); // void* _builtin_AllocateArray(int element_size, int element_num) tmp->func_name_raw = ".builtin.AllocateArray"; tmp->return_type = LLVMIRPTRType(); tmp->args.push_back(LLVMIRIntType(32)); tmp->args.push_back(LLVMIRIntType(32)); visitor.prog->function_declares.push_back(tmp); - tmp = std::make_shared(); + + tmp = std::make_shared(); // void* _builtin_RecursiveAllocateArray(int dims_with_size, + // int element_size, int* dim_size) tmp->func_name_raw = ".builtin.RecursiveAllocateArray"; tmp->return_type = LLVMIRPTRType(); tmp->args.push_back(LLVMIRIntType(32)); tmp->args.push_back(LLVMIRIntType(32)); tmp->args.push_back(LLVMIRPTRType()); visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // int _builtin_GetArrayLength(void* array) + tmp->func_name_raw = ".builtin.GetArrayLength"; + tmp->return_type = LLVMIRIntType(32); + tmp->args.push_back(LLVMIRPTRType()); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // int getInt() + tmp->func_name_raw = "getInt"; + tmp->return_type = LLVMIRIntType(32); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // char* getString() + tmp->func_name_raw = "getString"; + tmp->return_type = LLVMIRPTRType(); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // char* toString(int n) + tmp->func_name_raw = "toString"; + tmp->return_type = LLVMIRPTRType(); + tmp->args.push_back(LLVMIRIntType(32)); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // void printlnInt(int n) + tmp->func_name_raw = "printlnInt"; + tmp->return_type = LLVMVOIDType(); + tmp->args.push_back(LLVMIRIntType(32)); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // void println(char *str) + tmp->func_name_raw = "println"; + tmp->return_type = LLVMVOIDType(); + tmp->args.push_back(LLVMIRPTRType()); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // int string_length(char *self) + tmp->func_name_raw = "string.length"; + tmp->return_type = LLVMIRIntType(32); + tmp->args.push_back(LLVMIRPTRType()); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // char* string_substring(char *self,int left, int right) + tmp->func_name_raw = "string.substring"; + tmp->return_type = LLVMIRPTRType(); + tmp->args.push_back(LLVMIRPTRType()); + tmp->args.push_back(LLVMIRIntType(32)); + tmp->args.push_back(LLVMIRIntType(32)); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // int string_parseInt(char *self) + tmp->func_name_raw = "string.parseInt"; + tmp->return_type = LLVMIRIntType(32); + tmp->args.push_back(LLVMIRPTRType()); + visitor.prog->function_declares.push_back(tmp); + + tmp = std::make_shared(); // int string_ord(char *self, int index) + tmp->func_name_raw = "string.ord"; + tmp->return_type = LLVMIRIntType(32); + tmp->args.push_back(LLVMIRPTRType()); + tmp->args.push_back(LLVMIRIntType(32)); + visitor.prog->function_declares.push_back(tmp); + visitor.global_scope = std::dynamic_pointer_cast(src->current_scope); if (!(visitor.global_scope)) throw std::runtime_error("global scope not found"); visitor.visit(src.get());