diff --git a/CMakeLists.txt b/CMakeLists.txt index b00979d..ea44928 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,16 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug") endif() endif() +# 设置 Release 模式下开启优化 +if(CMAKE_BUILD_TYPE STREQUAL "Release") + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + add_compile_options(-Wall -Wextra -Wpedantic -fsanitize=address,undefined -O2) + add_link_options(-fsanitize=address,undefined) + elseif(MSVC) + add_compile_options(/O2) + endif() +endif() + include(FetchContent) FetchContent_Declare( googletest diff --git a/include/IR/IR_basic.h b/include/IR/IR_basic.h index 129eedf..9bf5415 100644 --- a/include/IR/IR_basic.h +++ b/include/IR/IR_basic.h @@ -47,12 +47,10 @@ class TypeDefItem : public LLVMIRItemBase { } }; class GlobalVarDefItem : public LLVMIRItemBase { - friend class IRBuilder; - friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); + public: LLVMType type; std::string name_raw; - public: GlobalVarDefItem() = default; void RecursivePrint(std::ostream &os) const { std::string name_full = "@.var.global." + name_raw + ".addrkp"; @@ -449,9 +447,7 @@ class FunctionDeclareItem : public LLVMIRItemBase { } }; class ConstStrItem : public LLVMIRItemBase { - friend std::shared_ptr BuildIR(std::shared_ptr src); - friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); - friend class IRBuilder; + public: std::string string_raw; size_t const_str_id; static std::string Escape(const std::string &src) { @@ -474,7 +470,6 @@ class ConstStrItem : public LLVMIRItemBase { return ss.str(); } - public: ConstStrItem() = default; void RecursivePrint(std::ostream &os) const { os << "@.str." 
<< const_str_id << " = private unnamed_addr constant [" << string_raw.size() + 1 << " x i8] c\"" diff --git a/include/naivebackend/naivebackend.h b/include/naivebackend/naivebackend.h index 3723554..b14a91c 100644 --- a/include/naivebackend/naivebackend.h +++ b/include/naivebackend/naivebackend.h @@ -54,8 +54,6 @@ class RISCVGlobalVarItem : public RISCVAsmItemBase { } }; class RISCVFuncItem : public RISCVAsmItemBase { - friend void ::GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); - public: std::string full_label; std::vector code_lines; @@ -93,20 +91,12 @@ class RISCVProgItem : public RISCVAsmItemBase { } }; class FuncLayout { - friend void ::GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); - friend void GenerateReadAccess(std::string val, size_t bytes, std::string output_reg, FuncLayout &layout, - std::vector &code_lines); - friend void GenerateWriteAccess(std::string val, size_t bytes, std::string data_reg, FuncLayout &layout, - std::vector &code_lines); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::unordered_map local_items; std::unordered_map arg_offset; size_t cur_pos; size_t total_frame_size; // should align to 16 bytes - public: + FuncLayout() : cur_pos(8), total_frame_size(16) {} void AllocateItem(const std::string &name, size_t sz, size_t num = 1) { if (local_items.find(name) != local_items.end()) throw std::runtime_error("Local item already exists"); diff --git a/include/opt/cfg.h b/include/opt/cfg.h index c142057..965e0d5 100644 --- a/include/opt/cfg.h +++ b/include/opt/cfg.h @@ -173,6 +173,6 @@ const static std::vector allocating_regs = {"x3", "x4", "x9", "x "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17"}; inline bool VRegCheck(const std::string &s) { - if (s[0] != '%') return false; + if (s[0] != '%' && s[0] != '$' && s[0] != '#') return false; return true; } \ No newline at end 
of file diff --git a/include/opt/gen.h b/include/opt/gen.h new file mode 100644 index 0000000..f7ccd52 --- /dev/null +++ b/include/opt/gen.h @@ -0,0 +1,461 @@ +#pragma once +#include +#include +#include +#include +#include "IR/IR_basic.h" +#include "cfg.h" +#include "liveanalysis.h" + +namespace OptBackend { +class RISCVAsmItemBase { + public: + RISCVAsmItemBase() = default; + virtual ~RISCVAsmItemBase() = default; + virtual void RecursivePrint(std::ostream &os) const = 0; +}; +class RISCVConstStrItem : public RISCVAsmItemBase { + public: + std::string full_label; + std::string content; + RISCVConstStrItem() = default; + ~RISCVConstStrItem() = default; + void RecursivePrint(std::ostream &os) const override { + os << full_label << ":\n"; + os << " .asciz \""; + for (auto c : content) { + if (c == '\n') { + os << "\\n"; + } else if (c == '\t') { + os << "\\t"; + } else if (c == '\"') { + os << "\\\""; + } else if (c == '\\') { + os << "\\\\"; + } else { + os << c; + } + } + os << "\"\n"; + } +}; +class RISCVGlobalVarItem : public RISCVAsmItemBase { + public: + std::string full_label; + RISCVGlobalVarItem() = default; + ~RISCVGlobalVarItem() = default; + void RecursivePrint(std::ostream &os) const override { + os << ".globl " << full_label << "\n"; + os << ".p2align 2, 0x0\n"; + os << full_label << ":\n"; + os << " .word 0\n"; + } +}; +class RISCVFuncItem : public RISCVAsmItemBase { + public: + std::string full_label; + std::vector code_lines; + RISCVFuncItem() = default; + ~RISCVFuncItem() = default; + void RecursivePrint(std::ostream &os) const override { + os << ".globl " << full_label << "\n"; + os << ".p2align 2, 0x0\n"; + os << full_label << ":\n"; + for (auto &line : code_lines) { + os << line << "\n"; + } + } +}; +class RISCVProgItem : public RISCVAsmItemBase { + public: + std::vector> const_strs; + std::vector> global_vars; + std::vector> funcs; + RISCVProgItem() = default; + ~RISCVProgItem() = default; + void RecursivePrint(std::ostream &os) const override { 
+ os << ".section .rodata\n"; + for (auto &item : const_strs) { + item->RecursivePrint(os); + } + os << ".section .sbss\n"; + for (auto &item : global_vars) { + item->RecursivePrint(os); + } + os << ".section .text\n"; + for (auto &item : funcs) { + item->RecursivePrint(os); + } + } +}; +class FuncLayout { + public: + std::unordered_map local_items; + std::unordered_map arg_offset; + size_t cur_pos; + size_t total_frame_size; // should align to 16 bytes + + FuncLayout() : cur_pos(12), total_frame_size(16) {} + void AllocateItem(const std::string &name, size_t sz, size_t num = 1) { + if (local_items.find(name) != local_items.end()) throw std::runtime_error("Local item already exists"); + if (cur_pos % sz != 0) { + cur_pos += sz - cur_pos % sz; + } + cur_pos += sz * num; + local_items[name] = cur_pos; + total_frame_size = ((cur_pos + 15) / 16) * 16; + std::cerr << "allocating stack memory for " << name << " at " << cur_pos << std::endl; + } + size_t QueryOffeset(const std::string &name) { + if (local_items.find(name) == local_items.end()) throw std::runtime_error("Local item not found"); + return local_items[name]; + } + size_t QueryFrameSize() const { return total_frame_size; } +}; + +inline void StoreImmToReg(int imm, std::string reg, std::vector &code_lines) { + code_lines.push_back("li " + reg + ", " + std::to_string(imm)); +} + +void GenerateOptASM(std::ostream &os, std::shared_ptr prog); + +extern std::string cur_block_label_for_phi; + +inline std::string AllocateTmpReg(std::vector &available_tmp_regs) { + if (available_tmp_regs.size() == 0) throw std::runtime_error("No available tmp register"); + std::string res; + res = available_tmp_regs.back(); + available_tmp_regs.pop_back(); + return res; +} + +inline std::string ExtractRegName(const std::string &raw) { + if (raw[0] != '$') throw std::runtime_error("Not a register"); + size_t reg_id = std::stoull(raw.substr(5)); + return "x" + std::to_string(reg_id); +} + +inline void FetchValueToReg(std::string 
original_val, std::string &out_reg, FuncLayout &layout, + std::vector &code_lines, std::vector &available_tmp_regs) { + if (original_val[0] == '$') { + // already assigned to a register, such as `$reg.10` + out_reg = ExtractRegName(original_val); + } else if (original_val[0] == '#') { + // spilled variable, we need find it in the layout + size_t offset = layout.QueryOffeset(original_val); + out_reg = AllocateTmpReg(available_tmp_regs); + if (offset < 2048) { + code_lines.push_back("lw " + out_reg + ", -" + std::to_string(offset) + "(s0)"); + } else { + code_lines.push_back("li " + out_reg + ", -" + std::to_string(offset)); + code_lines.push_back("add " + out_reg + ", s0, " + out_reg); + code_lines.push_back("lw " + out_reg + ", 0(" + out_reg + ")"); + } + } else if (original_val[0] == '@') { + // global variable address keeper + out_reg = AllocateTmpReg(available_tmp_regs); + std::string label_in_asm = original_val.substr(1, original_val.size() - 1); + code_lines.push_back("la " + out_reg + ", " + label_in_asm); + } else if (original_val[0] == '-' || std::isdigit(original_val[0])) { + // immediate value + out_reg = AllocateTmpReg(available_tmp_regs); + StoreImmToReg(std::stoi(original_val), out_reg, code_lines); + } else { + throw std::runtime_error("Unknown value type"); + } +} + +inline void WriteToSpilledVar(std::string val, std::string reg, FuncLayout &layout, + std::vector &code_lines, std::vector &available_tmp_regs) { + if (val[0] != '#') throw std::runtime_error("Not a spilled variable"); + size_t offset = layout.QueryOffeset(val); + if (offset < 2048) { + code_lines.push_back("sw " + reg + ", -" + std::to_string(offset) + "(s0)"); + } else { + std::string tmp_reg = AllocateTmpReg(available_tmp_regs); + code_lines.push_back("li " + tmp_reg + ", -" + std::to_string(offset)); + code_lines.push_back("add " + tmp_reg + ", s0, " + tmp_reg); + code_lines.push_back("sw " + reg + ", 0(" + tmp_reg + ")"); + } +} +inline size_t CalcSize(const LLVMType &tp) { + if 
(std::holds_alternative(tp)) { + auto &int_tp = std::get(tp); + return (int_tp.bits + 7) / 8; + } else if (std::holds_alternative(tp)) { + return 4; + } else if (std::holds_alternative(tp)) { + throw std::runtime_error("Cannot calculate size of void type"); + return 0; + } else if (std::holds_alternative(tp)) { + throw std::runtime_error("Cannot calculate size of class type"); + } else + throw std::runtime_error("Unknown type"); +} +inline void GenerateASM(std::shared_ptr act, std::vector &code_lines, FuncLayout &layout, + const std::unordered_map &low_level_class_info) { + std::vector available_tmp_regs = held_tmp_regs; + if (auto br_act = std::dynamic_pointer_cast(act)) { + std::string cond_reg; + FetchValueToReg(br_act->cond, cond_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("bnez " + cond_reg + ", .entrylabel." + br_act->true_label_full); + code_lines.push_back("j .entrylabel." + br_act->false_label_full); + } else if (auto jmp_act = std::dynamic_pointer_cast(act)) { + code_lines.push_back("j .entrylabel." 
+ jmp_act->label_full); + } else if (auto ret_act = std::dynamic_pointer_cast(act)) { + code_lines.push_back("lw ra, -4(s0)"); + code_lines.push_back("lw sp, -12(s0)"); + code_lines.push_back("lw s0, -8(s0)"); + code_lines.push_back("ret"); + } else if (auto binary_act = std::dynamic_pointer_cast(act)) { + // size_t sz = CalcSize(binary_act->type); + // IRVar2RISCVReg(binary_act->operand1_full, sz, "t0", layout, code_lines); + // IRVar2RISCVReg(binary_act->operand2_full, sz, "t1", layout, code_lines); + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); + std::string res_reg; + bool need_extra_store = false; + if (binary_act->result_full[0] == '$') { + res_reg = ExtractRegName(binary_act->result_full); + } else if (binary_act->result_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else { + throw std::runtime_error("Unknown result type"); + } + if (binary_act->op == "add") { + code_lines.push_back("add " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "sub") { + code_lines.push_back("sub " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "mul") { + code_lines.push_back("mul " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "sdiv") { + code_lines.push_back("div " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "srem") { + code_lines.push_back("rem " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "and") { + code_lines.push_back("and " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "or") { + code_lines.push_back("or " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == 
"xor") { + code_lines.push_back("xor " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "shl") { + code_lines.push_back("sll " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (binary_act->op == "ashr") { + code_lines.push_back("sra " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else { + throw std::runtime_error("Unknown binary operation"); + } + if (need_extra_store) { + WriteToSpilledVar(binary_act->result_full, res_reg, layout, code_lines, available_tmp_regs); + } + } else if (auto alloca_act = std::dynamic_pointer_cast(act)) { + std::string res_reg; + bool need_extra_store = false; + if (alloca_act->name_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else if (alloca_act->name_full[0] == '$') { + res_reg = ExtractRegName(alloca_act->name_full); + } else { + throw std::runtime_error("Unknown result type"); + } + size_t sz = CalcSize(alloca_act->type) * alloca_act->num; + sz = (sz + 15) / 16 * 16; + code_lines.push_back("addi sp, sp, -" + std::to_string(sz)); + if (!need_extra_store) { + code_lines.push_back("mv " + res_reg + ", sp"); + } else { + WriteToSpilledVar(alloca_act->name_full, "sp", layout, code_lines, available_tmp_regs); + } + } else if (auto load_act = std::dynamic_pointer_cast(act)) { + std::string res_reg; + bool need_extra_store = false; + if (load_act->result_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else if (load_act->result_full[0] == '$') { + res_reg = ExtractRegName(load_act->result_full); + } else { + throw std::runtime_error("Unknown result type"); + } + std::string ptr_reg; + FetchValueToReg(load_act->ptr_full, ptr_reg, layout, code_lines, available_tmp_regs); + if (CalcSize(load_act->ty) == 4) { + code_lines.push_back("lw " + res_reg + ", 0(" + ptr_reg + ")"); + } else if (CalcSize(load_act->ty) == 1) { + code_lines.push_back("lb " + res_reg + ", 0(" + ptr_reg + 
")"); + } else { + throw std::runtime_error("Unknown bytes"); + } + if (need_extra_store) { + WriteToSpilledVar(load_act->result_full, res_reg, layout, code_lines, available_tmp_regs); + } + } else if (auto store_act = std::dynamic_pointer_cast(act)) { + std::string val_reg; + std::string ptr_reg; + FetchValueToReg(store_act->value_full, val_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(store_act->ptr_full, ptr_reg, layout, code_lines, available_tmp_regs); + if (CalcSize(store_act->ty) == 4) { + code_lines.push_back("sw " + val_reg + ", 0(" + ptr_reg + ")"); + } else if (CalcSize(store_act->ty) == 1) { + code_lines.push_back("sb " + val_reg + ", 0(" + ptr_reg + ")"); + } else { + throw std::runtime_error("Unknown bytes"); + } + } else if (auto get_element_act = std::dynamic_pointer_cast(act)) { + if (get_element_act->indices.size() == 1) { + // array access + std::string res_reg; + bool need_extra_store = false; + if (get_element_act->result_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else if (get_element_act->result_full[0] == '$') { + res_reg = ExtractRegName(get_element_act->result_full); + } else { + throw std::runtime_error("Unknown result type"); + } + std::string ptr_reg; + std::string idx_reg; + FetchValueToReg(get_element_act->ptr_full, ptr_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(get_element_act->indices[0], idx_reg, layout, code_lines, available_tmp_regs); + std::string tmp_reg = AllocateTmpReg(available_tmp_regs); + size_t element_sz = CalcSize(get_element_act->ty); + code_lines.push_back("slli " + tmp_reg + ", " + idx_reg + ", " + std::to_string(std::countr_zero(element_sz))); + code_lines.push_back("add " + res_reg + ", " + ptr_reg + ", " + tmp_reg); + if (need_extra_store) { + WriteToSpilledVar(get_element_act->result_full, res_reg, layout, code_lines, available_tmp_regs); + } + } else if (get_element_act->indices.size() == 2) { + // // struct access + if 
(get_element_act->indices[0] != "0") { + throw std::runtime_error("struct access with non-zero offset is not supported"); + } + size_t element_idx = std::stoull(get_element_act->indices[1]); + auto class_ty = std::get(get_element_act->ty); + const IRClassInfo &class_info = low_level_class_info.at(class_ty.class_name_full); + size_t offset = class_info.member_var_pos_after_align[element_idx]; + std::string res_reg; + bool need_extra_store = false; + if (get_element_act->result_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else if (get_element_act->result_full[0] == '$') { + res_reg = ExtractRegName(get_element_act->result_full); + } else { + throw std::runtime_error("Unknown result type"); + } + std::string base_ptr_reg; + FetchValueToReg(get_element_act->ptr_full, base_ptr_reg, layout, code_lines, available_tmp_regs); + if (offset < 2048) { + code_lines.push_back("addi " + res_reg + ", " + base_ptr_reg + ", " + std::to_string(offset)); + } else { + std::string tmp_reg = AllocateTmpReg(available_tmp_regs); + code_lines.push_back("li " + tmp_reg + ", " + std::to_string(offset)); + code_lines.push_back("add " + res_reg + ", " + base_ptr_reg + ", " + tmp_reg); + } + if (need_extra_store) { + WriteToSpilledVar(get_element_act->result_full, res_reg, layout, code_lines, available_tmp_regs); + } + } else { + throw std::runtime_error("Unknown getelementptr indices size"); + } + } else if (auto icmp_act = std::dynamic_pointer_cast(act)) { + std::string operand1_reg, operand2_reg; + FetchValueToReg(icmp_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(icmp_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); + std::string res_reg; + bool need_extra_store = false; + if (icmp_act->result_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else if (icmp_act->result_full[0] == '$') { + res_reg = 
ExtractRegName(icmp_act->result_full); + } else { + throw std::runtime_error("Unknown result type"); + } + std::string tmp_reg = AllocateTmpReg(available_tmp_regs); + if (icmp_act->op == "eq") { + // code_lines.push_back("xor t2, t0, t1"); + // code_lines.push_back("seqz t2, t2"); + code_lines.push_back("xor " + tmp_reg + ", " + operand1_reg + ", " + operand2_reg); + code_lines.push_back("seqz " + res_reg + ", " + tmp_reg); + } else if (icmp_act->op == "ne") { + // code_lines.push_back("xor t2, t0, t1"); + // code_lines.push_back("snez t2, t2"); + code_lines.push_back("xor " + tmp_reg + ", " + operand1_reg + ", " + operand2_reg); + code_lines.push_back("snez " + res_reg + ", " + tmp_reg); + } else if (icmp_act->op == "slt") { + // code_lines.push_back("slt t2, t0, t1"); + code_lines.push_back("slt " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } else if (icmp_act->op == "sle") { + // code_lines.push_back("slt t2, t1, t0"); + // code_lines.push_back("xori t2, t2, 1"); + code_lines.push_back("slt " + res_reg + ", " + operand2_reg + ", " + operand1_reg); + code_lines.push_back("xori " + res_reg + ", " + res_reg + ", 1"); + } else if (icmp_act->op == "sgt") { + // code_lines.push_back("slt t2, t1, t0"); + code_lines.push_back("slt " + res_reg + ", " + operand2_reg + ", " + operand1_reg); + } else if (icmp_act->op == "sge") { + // code_lines.push_back("slt t2, t0, t1"); + // code_lines.push_back("xori t2, t2, 1"); + code_lines.push_back("slt " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + code_lines.push_back("xori " + res_reg + ", " + res_reg + ", 1"); + } else { + throw std::runtime_error("Unknown icmp operation"); + } + if (need_extra_store) { + WriteToSpilledVar(icmp_act->result_full, res_reg, layout, code_lines, available_tmp_regs); + } + } else if (auto call_act = std::dynamic_pointer_cast(act)) { + // no need to to further process, as callling convention is handled in reg alloc + code_lines.push_back("call " + call_act->func_name_raw); + 
} else if (auto phi_act = std::dynamic_pointer_cast(act)) { + throw std::runtime_error("Phi should not be in the layout"); + } else if (auto select_act = std::dynamic_pointer_cast(act)) { + std::string res_reg; + bool need_extra_store = false; + if (select_act->result_full[0] == '#') { + need_extra_store = true; + res_reg = AllocateTmpReg(available_tmp_regs); + } else if (select_act->result_full[0] == '$') { + res_reg = ExtractRegName(select_act->result_full); + } else { + throw std::runtime_error("Unknown result type"); + } + std::string operand1_reg, operand2_reg, cond_reg; + FetchValueToReg(select_act->cond_full, cond_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(select_act->true_val_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(select_act->false_val_full, operand2_reg, layout, code_lines, available_tmp_regs); + std::string tmp1_reg = AllocateTmpReg(available_tmp_regs); + std::string tmp2_reg = AllocateTmpReg(available_tmp_regs); + code_lines.push_back("slli " + tmp1_reg + ", " + cond_reg + ", 31"); + code_lines.push_back("srai " + tmp1_reg + ", " + tmp1_reg + ", 31"); + code_lines.push_back("xor " + tmp2_reg + ", " + operand1_reg + ", " + operand2_reg); + code_lines.push_back("and " + tmp2_reg + ", " + tmp2_reg + ", " + tmp1_reg); + code_lines.push_back("xor " + res_reg + ", " + tmp2_reg + ", " + operand2_reg); + if (need_extra_store) { + WriteToSpilledVar(select_act->result_full, res_reg, layout, code_lines, available_tmp_regs); + } + } else if (auto load_spilled_args_act = std::dynamic_pointer_cast(act)) { + throw std::runtime_error("Not implemented"); + } else if (auto store_spilled_args_act = std::dynamic_pointer_cast(act)) { + throw std::runtime_error("Not implemented"); + } else if (auto move_act = std::dynamic_pointer_cast(act)) { + std::string src_reg; + FetchValueToReg(move_act->src_full, src_reg, layout, code_lines, available_tmp_regs); + if (move_act->dest_full[0] == '$') { + std::string dest_reg = 
ExtractRegName(move_act->dest_full); + code_lines.push_back("mv " + dest_reg + ", " + src_reg); + } else if (move_act->dest_full[0] == '#') { + WriteToSpilledVar(move_act->dest_full, src_reg, layout, code_lines, available_tmp_regs); + } else { + throw std::runtime_error("Unknown dest type"); + } + } else { + throw std::runtime_error("Unknown action type"); + } +} +} // namespace OptBackend \ No newline at end of file diff --git a/include/opt/opt.h b/include/opt/opt.h index 503f478..512f418 100644 --- a/include/opt/opt.h +++ b/include/opt/opt.h @@ -1,3 +1,4 @@ #include "cfg.h" +#include "gen.h" #include "mem2reg.h" #include "regalloc.h" \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index d070a68..7c24bc2 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -46,10 +46,11 @@ int main(int argc, char **argv) { GenerateNaiveASM(fout, IR); } else { auto IR_with_out_allocas = Mem2Reg(IR); - IR_with_out_allocas->RecursivePrint(fout); + // IR_with_out_allocas->RecursivePrint(fout); auto IR_with_out_phis = PhiEliminate(IR_with_out_allocas); // IR_with_out_phis->RecursivePrint(fout); auto alloced_code = RegAlloc(IR_with_out_phis); + OptBackend::GenerateOptASM(fout, alloced_code); } } catch (const SemanticError &err) { std::cout << err.what() << std::endl; diff --git a/src/opt/optbackend.cpp b/src/opt/optbackend.cpp new file mode 100644 index 0000000..e7fea4f --- /dev/null +++ b/src/opt/optbackend.cpp @@ -0,0 +1,95 @@ +#include "opt/gen.h" +namespace OptBackend { +std::string cur_block_label_for_phi; +void GenerateOptASM(std::ostream &os, std::shared_ptr prog) { + auto riscv = std::make_shared(); + + for (auto conststr : prog->const_strs) { + auto asm_item = std::make_shared(); + riscv->const_strs.push_back(asm_item); + asm_item->content = conststr->string_raw; + asm_item->full_label = ".str." 
+ std::to_string(conststr->const_str_id); + } + for (auto global_var : prog->global_var_defs) { + auto asm_item = std::make_shared(); + riscv->global_vars.push_back(asm_item); + asm_item->full_label = ".var.global." + global_var->name_raw + ".addrkp"; + } + std::unordered_map func_layouts; + for (auto func_def : prog->function_defs) { + // if (func_def->init_block) { + // for (auto act : func_def->init_block->actions) { + // ScanForVar(func_layouts[func_def->func_name_raw], act, prog->low_level_class_info); + // } + // } + // for (auto block : func_def->basic_blocks) { + // for (auto act : block->actions) { + // ScanForVar(func_layouts[func_def->func_name_raw], act, prog->low_level_class_info); + // } + // } + FuncLayout &layout = func_layouts[func_def->func_name_raw]; + // for (size_t i = 0; i < func_def->args_full_name.size(); i++) { + // layout.arg_offset[func_def->args_full_name[i]] = i; + // } + for (size_t i = 0; i < func_def->spilled_vars; i++) { + layout.AllocateItem("#" + std::to_string(i), 4, 1); + } + // debug: + + // std::cerr << "layout info of function " << func_def->func_name_raw << std::endl; + // std::cerr << "\tcur_pos=" << layout.cur_pos << std::endl; + // std::cerr << "\ttotal_frame_size=" << layout.total_frame_size << std::endl; + // for (const auto &item : layout.local_items) { + // std::cerr << "\t" << item.first << " " << item.second << std::endl; + // } + } + + for (auto func_def : prog->function_defs) { + std::cerr << "generating asm for function " << func_def->func_name_raw << std::endl; + auto func_asm = std::make_shared(); + riscv->funcs.push_back(func_asm); + func_asm->full_label = func_def->func_name_raw; + FuncLayout &layout = func_layouts[func_def->func_name_raw]; + if (layout.total_frame_size < 2048) { + func_asm->code_lines.push_back("addi sp, sp, -" + std::to_string(layout.total_frame_size)); + func_asm->code_lines.push_back("sw ra, " + std::to_string(layout.total_frame_size - 4) + "(sp)"); + func_asm->code_lines.push_back("sw 
s0, " + std::to_string(layout.total_frame_size - 8) + "(sp)"); + func_asm->code_lines.push_back("addi s0, sp, " + std::to_string(layout.total_frame_size)); + func_asm->code_lines.push_back("sw s0, " + std::to_string(layout.total_frame_size - 12) + "(sp)"); + } else { + func_asm->code_lines.push_back("li x31, " + std::to_string(layout.total_frame_size)); + func_asm->code_lines.push_back("sub sp, sp, x31"); + func_asm->code_lines.push_back("add x31, x31, sp"); + func_asm->code_lines.push_back("sw ra, -4(x31)"); + func_asm->code_lines.push_back("sw s0, -8(x31)"); + func_asm->code_lines.push_back("sw x31, -12(x31)"); + func_asm->code_lines.push_back("mv s0, x31"); // s0 = frame top (old sp), which x31 holds here; t0 is never set in this prologue + } + if (func_def->init_block) { + func_asm->code_lines.push_back(".entrylabel." + func_def->init_block->label_full + ":"); + for (auto act : func_def->init_block->actions) { + OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw], + prog->low_level_class_info); + } + if (func_def->init_block->exit_action->corresponding_phi) { + OptBackend::cur_block_label_for_phi = func_def->init_block->label_full; + OptBackend::GenerateASM(func_def->init_block->exit_action->corresponding_phi, func_asm->code_lines, + func_layouts[func_def->func_name_raw], prog->low_level_class_info); + } + OptBackend::GenerateASM(func_def->init_block->exit_action, func_asm->code_lines, + func_layouts[func_def->func_name_raw], prog->low_level_class_info); + } + for (auto block : func_def->basic_blocks) { + func_asm->code_lines.push_back(".entrylabel." 
+ block->label_full + ":"); + for (auto act : block->actions) { + OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw], + prog->low_level_class_info); + } + OptBackend::GenerateASM(block->exit_action, func_asm->code_lines, func_layouts[func_def->func_name_raw], + prog->low_level_class_info); + } + } + + riscv->RecursivePrint(os); +} +} // namespace OptBackend \ No newline at end of file diff --git a/src/opt/regalloc.cpp b/src/opt/regalloc.cpp index b03ae45..a9ebb10 100644 --- a/src/opt/regalloc.cpp +++ b/src/opt/regalloc.cpp @@ -154,12 +154,12 @@ void TranslateColorResult(std::shared_ptr func, CFGType &cfg, C var = "$reg." + std::to_string(confnode->color); } }; + func->spilled_vars = 0; for (auto node : cfg.nodes) { auto block = node->corresponding_block; std::vector cur_node_use; std::vector cur_node_def; bool use_def_init = false; - func->spilled_vars = 0; for (auto act : block->actions) { if (auto br_act = std::dynamic_pointer_cast(act)) { if (var_to_id.find(br_act->cond) != var_to_id.end()) { @@ -284,11 +284,12 @@ void TranslateColorResult(std::shared_ptr func, CFGType &cfg, C if (move_act->src_full == move_act->dest_full) { need_remove = true; } - } else if (auto force_def_act = std::dynamic_pointer_cast(*act_it)) { - need_remove = true; - } else if (auto force_use_act = std::dynamic_pointer_cast(*act_it)) { - need_remove = true; } + // else if (auto force_def_act = std::dynamic_pointer_cast(*act_it)) { + // need_remove = true; + // } else if (auto force_use_act = std::dynamic_pointer_cast(*act_it)) { + // need_remove = true; + // } if (need_remove) { auto it_next = act_it; ++it_next; @@ -327,6 +328,22 @@ void PairMoveEliminate(std::shared_ptr func, CFGType &cfg, Conf } } } + +void RemoveCallingConventionKeeper(std::shared_ptr func, CFGType &cfg, ConfGraph &confgraph) { + for (auto node : cfg.nodes) { + auto block = node->corresponding_block; + std::vector>::iterator> act_to_move; + for (auto it = block->actions.begin(); it 
!= block->actions.end(); ++it) { + if (std::dynamic_pointer_cast(*it) || std::dynamic_pointer_cast(*it)) { + act_to_move.push_back(it); + } + } + for (auto it : act_to_move) { + block->actions.erase(it); + } + } +} + void ConductRegAllocForFunction(std::shared_ptr func) { std::cerr << "processing function " << func->func_name_raw << std::endl; CFGType cfg; @@ -342,7 +359,7 @@ void ConductRegAllocForFunction(std::shared_ptr func) { confgraph = BuildConfGraph(cfg); } while (ConductColoring(func, cfg, confgraph)); TranslateColorResult(func, cfg, confgraph); - // PairMoveEliminate(func, cfg, confgraph); + RemoveCallingConventionKeeper(func, cfg, confgraph); func->RecursivePrint(std::cerr); }