From 1b75ca5e7580211d7ff129a294debbde6df8e0df Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Tue, 29 Oct 2024 15:25:15 +0000 Subject: [PATCH] use better BR --- include/opt/betterbr.h | 4 ++ include/opt/cfg.h | 39 ++++++++++++++++++- include/opt/opt.h | 3 +- src/main.cpp | 1 + src/opt/betterbr.cpp | 81 ++++++++++++++++++++++++++++++++++++++++ src/opt/gen.cpp | 34 +++++++++++++++-- src/opt/liveanalysis.cpp | 61 ++++++++++++++++++++++++++---- src/opt/regalloc.cpp | 42 +++++++++++++++++---- 8 files changed, 243 insertions(+), 22 deletions(-) create mode 100644 include/opt/betterbr.h create mode 100644 src/opt/betterbr.cpp diff --git a/include/opt/betterbr.h b/include/opt/betterbr.h new file mode 100644 index 0000000..b091a51 --- /dev/null +++ b/include/opt/betterbr.h @@ -0,0 +1,4 @@ +#pragma once +#include "cfg.h" + +std::shared_ptr GenerateBetterBR(std::shared_ptr src); \ No newline at end of file diff --git a/include/opt/cfg.h b/include/opt/cfg.h index f49807c..59081ea 100644 --- a/include/opt/cfg.h +++ b/include/opt/cfg.h @@ -5,6 +5,7 @@ #include #include #include "IR/IR_basic.h" +#include "ast/expr_astnode.h" using CFGNodeCollection = std::list; class CFGNodeType { public: @@ -175,4 +176,40 @@ const static std::vector allocating_regs = {"x3", "x4", "x9", "x inline bool VRegCheck(const std::string &s) { if (s[0] != '%' && s[0] != '$' && s[0] != '#') return false; return true; -} \ No newline at end of file +} + +class BEQAction : public BRAction { + public: + std::string rs1, rs2; + void RecursivePrint(std::ostream &os) const { + os << "beq " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %" + << false_label_full << "\n"; + } +}; + +class BNEAction : public BRAction { + public: + std::string rs1, rs2; + void RecursivePrint(std::ostream &os) const { + os << "bne " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %" + << false_label_full << "\n"; + } +}; + +class BLTAction : public BRAction { + public: + std::string rs1, rs2; + void RecursivePrint(std::ostream &os) const { + os << "blt " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %" + << false_label_full << "\n"; + } +}; + +class BGEAction : public BRAction { + public: + std::string rs1, rs2; + void RecursivePrint(std::ostream &os) const { + os << "bge " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %" + << false_label_full << "\n"; + } +}; \ No newline at end of file diff --git a/include/opt/opt.h b/include/opt/opt.h index 23353c6..c969236 100644 --- a/include/opt/opt.h +++ b/include/opt/opt.h @@ -3,4 +3,5 @@ #include "global_var_cache.h" #include "mem2reg.h" #include "regalloc.h" -#include "dce.h" \ No newline at end of file +#include "dce.h" +#include "betterbr.h" \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 1c2b87a..2d6af18 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -49,6 +49,7 @@ int main(int argc, char **argv) { auto IR_with_out_allocas = Mem2Reg(IR); // IR_with_out_allocas->RecursivePrint(std::cerr); IR_with_out_allocas = GloabalVarCache(IR_with_out_allocas); + IR_with_out_allocas = GenerateBetterBR(IR_with_out_allocas); IR_with_out_allocas = DCE(IR_with_out_allocas); // IR_with_out_allocas->RecursivePrint(std::cerr); auto IR_with_out_phis = PhiEliminate(IR_with_out_allocas); diff --git a/src/opt/betterbr.cpp b/src/opt/betterbr.cpp new file mode 100644 index 0000000..7e390c0 --- /dev/null +++ b/src/opt/betterbr.cpp @@ -0,0 +1,81 @@ +#include "betterbr.h" +#include "IR/IR_basic.h" +#include "cfg.h" + +void GenerateBetterBRForFunction(std::shared_ptr func); + +void GenerateBetterBRForBlock(std::shared_ptr block); + +void GenerateBetterBRForBlock(std::shared_ptr block) { + std::unordered_map> icmp_map; + for (auto act : block->actions) { + if (auto icmp_act = std::dynamic_pointer_cast(act)) { + icmp_map[icmp_act->result_full] = icmp_act; + } + } + if (auto br_act = std::dynamic_pointer_cast(block->exit_action)) { + if (icmp_map.find(br_act->cond) != icmp_map.end()) { + auto icmp_act = icmp_map[br_act->cond]; + if (icmp_act->op == "eq") { + auto new_br = std::make_shared(); + new_br->true_label_full = br_act->true_label_full; + new_br->false_label_full = br_act->false_label_full; + new_br->rs1 = icmp_act->operand1_full; + new_br->rs2 = icmp_act->operand2_full; + block->exit_action = new_br; + } else if (icmp_act->op == "ne") { + auto new_br = std::make_shared(); + new_br->true_label_full = br_act->true_label_full; + new_br->false_label_full = br_act->false_label_full; + new_br->rs1 = icmp_act->operand1_full; + new_br->rs2 = icmp_act->operand2_full; + block->exit_action = new_br; + } else if (icmp_act->op == "slt") { + auto new_br = std::make_shared(); + new_br->true_label_full = br_act->true_label_full; + new_br->false_label_full = br_act->false_label_full; + new_br->rs1 = icmp_act->operand1_full; + new_br->rs2 = icmp_act->operand2_full; + block->exit_action = new_br; + } else if (icmp_act->op == "sle") { + auto new_br = std::make_shared(); + new_br->true_label_full = br_act->true_label_full; + new_br->false_label_full = br_act->false_label_full; + new_br->rs1 = icmp_act->operand2_full; + new_br->rs2 = icmp_act->operand1_full; + block->exit_action = new_br; + } else if (icmp_act->op == "sgt") { + auto new_br = std::make_shared(); + new_br->true_label_full = br_act->true_label_full; + new_br->false_label_full = br_act->false_label_full; + new_br->rs1 = icmp_act->operand2_full; + new_br->rs2 = icmp_act->operand1_full; + block->exit_action = new_br; + } else if (icmp_act->op == "sge") { + auto new_br = std::make_shared(); + new_br->true_label_full = br_act->true_label_full; + new_br->false_label_full = br_act->false_label_full; + new_br->rs1 = icmp_act->operand1_full; + new_br->rs2 = icmp_act->operand2_full; + block->exit_action = new_br; + } else { + throw std::runtime_error("Unknown icmp op"); + } + } + } +} +void GenerateBetterBRForFunction(std::shared_ptr func) { + if (func->init_block) { + GenerateBetterBRForBlock(func->init_block); + } + for (auto block : func->basic_blocks) { + GenerateBetterBRForBlock(block); + } +} + +std::shared_ptr GenerateBetterBR(std::shared_ptr src) { + for (auto func : src->function_defs) { + GenerateBetterBRForFunction(func); + } + return src; +} \ No newline at end of file diff --git a/src/opt/gen.cpp b/src/opt/gen.cpp index af8bb17..5d30add 100644 --- a/src/opt/gen.cpp +++ b/src/opt/gen.cpp @@ -71,10 +71,36 @@ void GenerateASM(std::shared_ptr act, std::vector &code const std::unordered_map &low_level_class_info) { std::vector available_tmp_regs = held_tmp_regs; if (auto br_act = std::dynamic_pointer_cast(act)) { - std::string cond_reg; - FetchValueToReg(br_act->cond, cond_reg, layout, code_lines, available_tmp_regs); - code_lines.push_back("bnez " + cond_reg + ", .entrylabel." + br_act->true_label_full); - code_lines.push_back("j .entrylabel." + br_act->false_label_full); + if (auto beq_act = std::dynamic_pointer_cast(br_act)) { + std::string rs1_reg, rs2_reg; + FetchValueToReg(beq_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(beq_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("beq " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + beq_act->true_label_full); + code_lines.push_back("j .entrylabel." + beq_act->false_label_full); + } else if (auto bne_act = std::dynamic_pointer_cast(br_act)) { + std::string rs1_reg, rs2_reg; + FetchValueToReg(bne_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(bne_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("bne " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + bne_act->true_label_full); + code_lines.push_back("j .entrylabel." + bne_act->false_label_full); + } else if (auto blt_act = std::dynamic_pointer_cast(br_act)) { + std::string rs1_reg, rs2_reg; + FetchValueToReg(blt_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(blt_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("blt " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + blt_act->true_label_full); + code_lines.push_back("j .entrylabel." + blt_act->false_label_full); + } else if (auto bge_act = std::dynamic_pointer_cast(br_act)) { + std::string rs1_reg, rs2_reg; + FetchValueToReg(bge_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs); + FetchValueToReg(bge_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("bge " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + bge_act->true_label_full); + code_lines.push_back("j .entrylabel." + bge_act->false_label_full); + } else { + std::string cond_reg; + FetchValueToReg(br_act->cond, cond_reg, layout, code_lines, available_tmp_regs); + code_lines.push_back("bnez " + cond_reg + ", .entrylabel." + br_act->true_label_full); + code_lines.push_back("j .entrylabel." + br_act->false_label_full); + } } else if (auto jmp_act = std::dynamic_pointer_cast(act)) { code_lines.push_back("j .entrylabel." + jmp_act->label_full); } else if (auto ret_act = std::dynamic_pointer_cast(act)) { diff --git a/src/opt/liveanalysis.cpp b/src/opt/liveanalysis.cpp index 508a8e1..f1c8852 100644 --- a/src/opt/liveanalysis.cpp +++ b/src/opt/liveanalysis.cpp @@ -113,13 +113,9 @@ void UseDefCollect(CFGType &cfg, [[maybe_unused]] std::vector &id_t std::vector &cur_act_use = node->action_use_vars[act.get()]; std::vector &cur_act_def = node->action_def_vars[act.get()]; if (auto br_act = std::dynamic_pointer_cast(act)) { - if (var_to_id.find(br_act->cond) != var_to_id.end()) { - cur_act_use.push_back(var_to_id[br_act->cond]); - } + throw std::runtime_error("BRAction should not appear in action list"); } else if (auto ret_act = std::dynamic_pointer_cast(act)) { - if (!std::holds_alternative(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { - cur_act_use.push_back(var_to_id[ret_act->value]); - } + throw std::runtime_error("RETAction should not appear in action list"); } else if (auto bin_act = std::dynamic_pointer_cast(act)) { if (var_to_id.find(bin_act->operand1_full) != var_to_id.end()) { cur_act_use.push_back(var_to_id[bin_act->operand1_full]); @@ -273,14 +269,46 @@ void UseDefCollect(CFGType &cfg, [[maybe_unused]] std::vector &id_t std::vector &cur_act_use = node->action_use_vars[act.get()]; std::vector &cur_act_def = node->action_def_vars[act.get()]; if (auto br_act = std::dynamic_pointer_cast(act)) { - if (var_to_id.find(br_act->cond) != var_to_id.end()) { - cur_act_use.push_back(var_to_id[br_act->cond]); + if (auto beq_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(beq_act->rs1) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[beq_act->rs1]); + } + if (beq_act->rs1 != beq_act->rs2 && var_to_id.find(beq_act->rs2) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[beq_act->rs2]); + } + } else if (auto bne_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(bne_act->rs1) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[bne_act->rs1]); + } + if (bne_act->rs1 != bne_act->rs2 && var_to_id.find(bne_act->rs2) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[bne_act->rs2]); + } + } else if (auto blt_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(blt_act->rs1) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[blt_act->rs1]); + } + if (blt_act->rs1 != blt_act->rs2 && var_to_id.find(blt_act->rs2) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[blt_act->rs2]); + } + } else if (auto bge_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(bge_act->rs1) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[bge_act->rs1]); + } + if (bge_act->rs1 != bge_act->rs2 && var_to_id.find(bge_act->rs2) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[bge_act->rs2]); + } + } else { + if (var_to_id.find(br_act->cond) != var_to_id.end()) { + cur_act_use.push_back(var_to_id[br_act->cond]); + } } } else if (auto ret_act = std::dynamic_pointer_cast(act)) { if (!std::holds_alternative(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { cur_act_use.push_back(var_to_id[ret_act->value]); } } + std::sort(cur_act_use.begin(), cur_act_use.end()); + std::sort(cur_act_def.begin(), cur_act_def.end()); if (!use_def_init) { use_def_init = true; cur_node_use = cur_act_use; @@ -377,7 +405,24 @@ void ActionLevelTracking(CFGType &cfg, CFGNodeType *node) { void LiveAnalysis(CFGType &cfg) { VarCollect(cfg, cfg.id_to_var, cfg.var_to_id); + // std::cerr << "all vars collected" << std::endl; + // for (auto var : cfg.id_to_var) { + // std::cerr << "\tid=" << cfg.var_to_id[var] << " var=" << var << std::endl; + // } UseDefCollect(cfg, cfg.id_to_var, cfg.var_to_id); + // for (auto node : cfg.nodes) { + // std::cerr << "block " << node->corresponding_block->label_full << std::endl; + // std::cerr << "\tblock_use_vars="; + // for (auto i : node->block_use_vars) { + // std::cerr << i << " "; + // } + // std::cerr << std::endl; + // std::cerr << "\tblock_def_vars="; + // for (auto i : node->block_def_vars) { + // std::cerr << i << " "; + // } + // std::cerr << std::endl; + // } BlockLevelTracking(cfg, cfg.id_to_var, cfg.var_to_id); for (auto node : cfg.nodes) { ActionLevelTracking(cfg, node.get()); diff --git a/src/opt/regalloc.cpp b/src/opt/regalloc.cpp index d6af305..b9b43d5 100644 --- a/src/opt/regalloc.cpp +++ b/src/opt/regalloc.cpp @@ -174,13 +174,9 @@ void TranslateColorResult(std::shared_ptr func, CFGType &cfg, C bool use_def_init = false; for (auto act : block->actions) { if (auto br_act = std::dynamic_pointer_cast(act)) { - if (var_to_id.find(br_act->cond) != var_to_id.end()) { - ApplyColoringResult(br_act->cond); - } + throw std::runtime_error("BRAction should not exist in the middle of a block"); } else if (auto ret_act = std::dynamic_pointer_cast(act)) { - if (!std::holds_alternative(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { - ApplyColoringResult(ret_act->value); - } + throw std::runtime_error("RETAction should not exist in the middle of a block"); } else if (auto bin_act = std::dynamic_pointer_cast(act)) { if (var_to_id.find(bin_act->operand1_full) != var_to_id.end()) { ApplyColoringResult(bin_act->operand1_full); @@ -281,8 +277,38 @@ void TranslateColorResult(std::shared_ptr func, CFGType &cfg, C { auto act = block->exit_action; if (auto br_act = std::dynamic_pointer_cast(act)) { - if (var_to_id.find(br_act->cond) != var_to_id.end()) { - ApplyColoringResult(br_act->cond); + if (auto beq_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(beq_act->rs1) != var_to_id.end()) { + ApplyColoringResult(beq_act->rs1); + } + if (var_to_id.find(beq_act->rs2) != var_to_id.end()) { + ApplyColoringResult(beq_act->rs2); + } + } else if (auto bne_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(bne_act->rs1) != var_to_id.end()) { + ApplyColoringResult(bne_act->rs1); + } + if (var_to_id.find(bne_act->rs2) != var_to_id.end()) { + ApplyColoringResult(bne_act->rs2); + } + } else if (auto blt_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(blt_act->rs1) != var_to_id.end()) { + ApplyColoringResult(blt_act->rs1); + } + if (var_to_id.find(blt_act->rs2) != var_to_id.end()) { + ApplyColoringResult(blt_act->rs2); + } + } else if (auto bge_act = std::dynamic_pointer_cast(br_act)) { + if (var_to_id.find(bge_act->rs1) != var_to_id.end()) { + ApplyColoringResult(bge_act->rs1); + } + if (var_to_id.find(bge_act->rs2) != var_to_id.end()) { + ApplyColoringResult(bge_act->rs2); + } + } else { + if (var_to_id.find(br_act->cond) != var_to_id.end()) { + ApplyColoringResult(br_act->cond); + } } } else if (auto ret_act = std::dynamic_pointer_cast(act)) { if (!std::holds_alternative(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) {