From c715ae3f60cbbff2f36813bec868f815abdd317a Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Wed, 23 Oct 2024 16:18:43 +0000 Subject: [PATCH] add imm support for mv, add and sub --- src/opt/gen.cpp | 94 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 84 insertions(+), 10 deletions(-) diff --git a/src/opt/gen.cpp b/src/opt/gen.cpp index 928f872..af8bb17 100644 --- a/src/opt/gen.cpp +++ b/src/opt/gen.cpp @@ -1,4 +1,5 @@ #include "gen.h" +#include namespace OptBackend { void FetchValueToReg(std::string original_val, std::string &out_reg, FuncLayout &layout, @@ -61,9 +62,13 @@ size_t CalcSize(const LLVMType &tp) { throw std::runtime_error("Unknown type"); } - +bool CanActivateImmSupport(const std::string &val, int64_t min_val, int64_t max_val) { + if (val[0] != '-' && !std::isdigit(val[0])) return false; + int64_t num = std::stoll(val); + return num >= min_val && num <= max_val; +} void GenerateASM(std::shared_ptr act, std::vector &code_lines, FuncLayout &layout, - const std::unordered_map &low_level_class_info) { + const std::unordered_map &low_level_class_info) { std::vector available_tmp_regs = held_tmp_regs; if (auto br_act = std::dynamic_pointer_cast(act)) { std::string cond_reg; @@ -81,9 +86,8 @@ void GenerateASM(std::shared_ptr act, std::vector &code // size_t sz = CalcSize(binary_act->type); // IRVar2RISCVReg(binary_act->operand1_full, sz, "t0", layout, code_lines); // IRVar2RISCVReg(binary_act->operand2_full, sz, "t1", layout, code_lines); - std::string operand1_reg, operand2_reg; - FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); - FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); + if (binary_act->operand1_full == "null") binary_act->operand1_full = "0"; + if (binary_act->operand2_full == "null") binary_act->operand2_full = "0"; std::string res_reg; bool need_extra_store = false; if (binary_act->result_full[0] == '$') { @@ -95,24 +99,87 @@ void GenerateASM(std::shared_ptr act, std::vector &code throw std::runtime_error("Unknown result type"); } if (binary_act->op == "add") { - code_lines.push_back("add " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + bool ope1_is_imm = CanActivateImmSupport(binary_act->operand1_full, -2048, 2047); + bool ope2_is_imm = CanActivateImmSupport(binary_act->operand2_full, -2048, 2047); + bool is_constant = CanActivateImmSupport(binary_act->operand1_full, std::numeric_limits::min(), + std::numeric_limits::max()) && + CanActivateImmSupport(binary_act->operand2_full, std::numeric_limits::min(), + std::numeric_limits::max()); + if (is_constant) { + int32_t result = std::stoi(binary_act->operand1_full) + std::stoi(binary_act->operand2_full); + StoreImmToReg(result, res_reg, code_lines); + } else if (ope1_is_imm || ope2_is_imm) { + if (!ope1_is_imm) { + std::swap(binary_act->operand1_full, binary_act->operand2_full); + } + std::string operand2_reg; + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("addi " + res_reg + ", " + operand2_reg + ", " + binary_act->operand1_full); + } else { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("add " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } } else if (binary_act->op == "sub") { - code_lines.push_back("sub " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + bool ope2_is_imm = CanActivateImmSupport(binary_act->operand2_full, -2048, 2047); + bool is_constant = CanActivateImmSupport(binary_act->operand1_full, std::numeric_limits::min(), + std::numeric_limits::max()) && + CanActivateImmSupport(binary_act->operand2_full, std::numeric_limits::min(), + std::numeric_limits::max()); + if (is_constant) { + int32_t result = std::stoi(binary_act->operand1_full) - std::stoi(binary_act->operand2_full); + StoreImmToReg(result, res_reg, code_lines); + } else if (ope2_is_imm) { + std::string operand1_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("addi " + res_reg + ", " + operand1_reg + ", " + + std::to_string(-std::stoi(binary_act->operand2_full))); + } else { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("sub " + res_reg + ", " + operand1_reg + ", " + operand2_reg); + } } else if (binary_act->op == "mul") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("mul " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "sdiv") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("div " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "srem") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("rem " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "and") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("and " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "or") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("or " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "xor") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("xor " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "shl") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("sll " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else if (binary_act->op == "ashr") { + std::string operand1_reg, operand2_reg; + FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs); code_lines.push_back("sra " + res_reg + ", " + operand1_reg + ", " + operand2_reg); } else { throw std::runtime_error("Unknown binary operation"); @@ -339,12 +406,19 @@ void GenerateASM(std::shared_ptr act, std::vector &code size_t offset = 4 * (store_spilled_args_act->arg_id - 8); code_lines.push_back("sw " + val_reg + ", " + std::to_string(offset) + "(sp)"); } else if (auto move_act = std::dynamic_pointer_cast(act)) { + if (move_act->src_full == "null") move_act->src_full = "0"; std::string src_reg; - FetchValueToReg(move_act->src_full, src_reg, layout, code_lines, available_tmp_regs); if (move_act->dest_full[0] == '$') { - std::string dest_reg = ExtractRegName(move_act->dest_full); - code_lines.push_back("mv " + dest_reg + ", " + src_reg); + if (move_act->src_full[0] == '-' || std::isdigit(move_act->src_full[0])) { + std::string dest_reg = ExtractRegName(move_act->dest_full); + StoreImmToReg(std::stoi(move_act->src_full), dest_reg, code_lines); + } else { + FetchValueToReg(move_act->src_full, src_reg, layout, code_lines, available_tmp_regs); + std::string dest_reg = ExtractRegName(move_act->dest_full); + code_lines.push_back("mv " + dest_reg + ", " + src_reg); + } } else if (move_act->dest_full[0] == '#') { + FetchValueToReg(move_act->src_full, src_reg, layout, code_lines, available_tmp_regs); WriteToSpilledVar(move_act->dest_full, src_reg, layout, code_lines, available_tmp_regs); } else { throw std::runtime_error("Unknown dest type");