use better BR

This commit is contained in:
2024-10-29 15:25:15 +00:00
parent 5634022bd9
commit 1b75ca5e75
8 changed files with 243 additions and 22 deletions

4
include/opt/betterbr.h Normal file
View File

@ -0,0 +1,4 @@
#pragma once
#include "cfg.h"
std::shared_ptr<ModuleItem> GenerateBetterBR(std::shared_ptr<ModuleItem> src);

View File

@ -5,6 +5,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "IR/IR_basic.h" #include "IR/IR_basic.h"
#include "ast/expr_astnode.h"
using CFGNodeCollection = std::list<class CFGNodeType *>; using CFGNodeCollection = std::list<class CFGNodeType *>;
class CFGNodeType { class CFGNodeType {
public: public:
@ -175,4 +176,40 @@ const static std::vector<std::string> allocating_regs = {"x3", "x4", "x9", "x
inline bool VRegCheck(const std::string &s) { inline bool VRegCheck(const std::string &s) {
if (s[0] != '%' && s[0] != '$' && s[0] != '#') return false; if (s[0] != '%' && s[0] != '$' && s[0] != '#') return false;
return true; return true;
} }
class BEQAction : public BRAction {
public:
std::string rs1, rs2;
void RecursivePrint(std::ostream &os) const {
os << "beq " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %"
<< false_label_full << "\n";
}
};
class BNEAction : public BRAction {
public:
std::string rs1, rs2;
void RecursivePrint(std::ostream &os) const {
os << "bne " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %"
<< false_label_full << "\n";
}
};
class BLTAction : public BRAction {
public:
std::string rs1, rs2;
void RecursivePrint(std::ostream &os) const {
os << "blt " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %"
<< false_label_full << "\n";
}
};
class BGEAction : public BRAction {
public:
std::string rs1, rs2;
void RecursivePrint(std::ostream &os) const {
os << "bge " << rs1 << ' ' << rs2 << ' ' << cond << ", label %" << true_label_full << ", label %"
<< false_label_full << "\n";
}
};

View File

@ -3,4 +3,5 @@
#include "global_var_cache.h" #include "global_var_cache.h"
#include "mem2reg.h" #include "mem2reg.h"
#include "regalloc.h" #include "regalloc.h"
#include "dce.h" #include "dce.h"
#include "betterbr.h"

View File

@ -49,6 +49,7 @@ int main(int argc, char **argv) {
auto IR_with_out_allocas = Mem2Reg(IR); auto IR_with_out_allocas = Mem2Reg(IR);
// IR_with_out_allocas->RecursivePrint(std::cerr); // IR_with_out_allocas->RecursivePrint(std::cerr);
IR_with_out_allocas = GloabalVarCache(IR_with_out_allocas); IR_with_out_allocas = GloabalVarCache(IR_with_out_allocas);
IR_with_out_allocas = GenerateBetterBR(IR_with_out_allocas);
IR_with_out_allocas = DCE(IR_with_out_allocas); IR_with_out_allocas = DCE(IR_with_out_allocas);
// IR_with_out_allocas->RecursivePrint(std::cerr); // IR_with_out_allocas->RecursivePrint(std::cerr);
auto IR_with_out_phis = PhiEliminate(IR_with_out_allocas); auto IR_with_out_phis = PhiEliminate(IR_with_out_allocas);

81
src/opt/betterbr.cpp Normal file
View File

@ -0,0 +1,81 @@
#include "betterbr.h"
#include "IR/IR_basic.h"
#include "cfg.h"
void GenerateBetterBRForFunction(std::shared_ptr<FunctionDefItem> func);
void GenerateBetterBRForBlock(std::shared_ptr<BlockItem> block);
void GenerateBetterBRForBlock(std::shared_ptr<BlockItem> block) {
std::unordered_map<std::string, std::shared_ptr<ICMPAction>> icmp_map;
for (auto act : block->actions) {
if (auto icmp_act = std::dynamic_pointer_cast<ICMPAction>(act)) {
icmp_map[icmp_act->result_full] = icmp_act;
}
}
if (auto br_act = std::dynamic_pointer_cast<BRAction>(block->exit_action)) {
if (icmp_map.find(br_act->cond) != icmp_map.end()) {
auto icmp_act = icmp_map[br_act->cond];
if (icmp_act->op == "eq") {
auto new_br = std::make_shared<BEQAction>();
new_br->true_label_full = br_act->true_label_full;
new_br->false_label_full = br_act->false_label_full;
new_br->rs1 = icmp_act->operand1_full;
new_br->rs2 = icmp_act->operand2_full;
block->exit_action = new_br;
} else if (icmp_act->op == "ne") {
auto new_br = std::make_shared<BNEAction>();
new_br->true_label_full = br_act->true_label_full;
new_br->false_label_full = br_act->false_label_full;
new_br->rs1 = icmp_act->operand1_full;
new_br->rs2 = icmp_act->operand2_full;
block->exit_action = new_br;
} else if (icmp_act->op == "slt") {
auto new_br = std::make_shared<BLTAction>();
new_br->true_label_full = br_act->true_label_full;
new_br->false_label_full = br_act->false_label_full;
new_br->rs1 = icmp_act->operand1_full;
new_br->rs2 = icmp_act->operand2_full;
block->exit_action = new_br;
} else if (icmp_act->op == "sle") {
auto new_br = std::make_shared<BGEAction>();
new_br->true_label_full = br_act->true_label_full;
new_br->false_label_full = br_act->false_label_full;
new_br->rs1 = icmp_act->operand2_full;
new_br->rs2 = icmp_act->operand1_full;
block->exit_action = new_br;
} else if (icmp_act->op == "sgt") {
auto new_br = std::make_shared<BLTAction>();
new_br->true_label_full = br_act->true_label_full;
new_br->false_label_full = br_act->false_label_full;
new_br->rs1 = icmp_act->operand2_full;
new_br->rs2 = icmp_act->operand1_full;
block->exit_action = new_br;
} else if (icmp_act->op == "sge") {
auto new_br = std::make_shared<BGEAction>();
new_br->true_label_full = br_act->true_label_full;
new_br->false_label_full = br_act->false_label_full;
new_br->rs1 = icmp_act->operand1_full;
new_br->rs2 = icmp_act->operand2_full;
block->exit_action = new_br;
} else {
throw std::runtime_error("Unknown icmp op");
}
}
}
}
void GenerateBetterBRForFunction(std::shared_ptr<FunctionDefItem> func) {
if (func->init_block) {
GenerateBetterBRForBlock(func->init_block);
}
for (auto block : func->basic_blocks) {
GenerateBetterBRForBlock(block);
}
}
std::shared_ptr<ModuleItem> GenerateBetterBR(std::shared_ptr<ModuleItem> src) {
for (auto func : src->function_defs) {
GenerateBetterBRForFunction(func);
}
return src;
}

View File

@ -71,10 +71,36 @@ void GenerateASM(std::shared_ptr<ActionItem> act, std::vector<std::string> &code
const std::unordered_map<std::string, IRClassInfo> &low_level_class_info) { const std::unordered_map<std::string, IRClassInfo> &low_level_class_info) {
std::vector<std::string> available_tmp_regs = held_tmp_regs; std::vector<std::string> available_tmp_regs = held_tmp_regs;
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) { if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
std::string cond_reg; if (auto beq_act = std::dynamic_pointer_cast<BEQAction>(br_act)) {
FetchValueToReg(br_act->cond, cond_reg, layout, code_lines, available_tmp_regs); std::string rs1_reg, rs2_reg;
code_lines.push_back("bnez " + cond_reg + ", .entrylabel." + br_act->true_label_full); FetchValueToReg(beq_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("j .entrylabel." + br_act->false_label_full); FetchValueToReg(beq_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("beq " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + beq_act->true_label_full);
code_lines.push_back("j .entrylabel." + beq_act->false_label_full);
} else if (auto bne_act = std::dynamic_pointer_cast<BNEAction>(br_act)) {
std::string rs1_reg, rs2_reg;
FetchValueToReg(bne_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(bne_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("bne " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + bne_act->true_label_full);
code_lines.push_back("j .entrylabel." + bne_act->false_label_full);
} else if (auto blt_act = std::dynamic_pointer_cast<BLTAction>(br_act)) {
std::string rs1_reg, rs2_reg;
FetchValueToReg(blt_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(blt_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("blt " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + blt_act->true_label_full);
code_lines.push_back("j .entrylabel." + blt_act->false_label_full);
} else if (auto bge_act = std::dynamic_pointer_cast<BGEAction>(br_act)) {
std::string rs1_reg, rs2_reg;
FetchValueToReg(bge_act->rs1, rs1_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(bge_act->rs2, rs2_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("bge " + rs1_reg + ", " + rs2_reg + ", .entrylabel." + bge_act->true_label_full);
code_lines.push_back("j .entrylabel." + bge_act->false_label_full);
} else {
std::string cond_reg;
FetchValueToReg(br_act->cond, cond_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("bnez " + cond_reg + ", .entrylabel." + br_act->true_label_full);
code_lines.push_back("j .entrylabel." + br_act->false_label_full);
}
} else if (auto jmp_act = std::dynamic_pointer_cast<UNConditionJMPAction>(act)) { } else if (auto jmp_act = std::dynamic_pointer_cast<UNConditionJMPAction>(act)) {
code_lines.push_back("j .entrylabel." + jmp_act->label_full); code_lines.push_back("j .entrylabel." + jmp_act->label_full);
} else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) { } else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) {

View File

@ -113,13 +113,9 @@ void UseDefCollect(CFGType &cfg, [[maybe_unused]] std::vector<std::string> &id_t
std::vector<size_t> &cur_act_use = node->action_use_vars[act.get()]; std::vector<size_t> &cur_act_use = node->action_use_vars[act.get()];
std::vector<size_t> &cur_act_def = node->action_def_vars[act.get()]; std::vector<size_t> &cur_act_def = node->action_def_vars[act.get()];
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) { if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
if (var_to_id.find(br_act->cond) != var_to_id.end()) { throw std::runtime_error("BRAction should not appear in action list");
cur_act_use.push_back(var_to_id[br_act->cond]);
}
} else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) { } else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) {
if (!std::holds_alternative<LLVMVOIDType>(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { throw std::runtime_error("RETAction should not appear in action list");
cur_act_use.push_back(var_to_id[ret_act->value]);
}
} else if (auto bin_act = std::dynamic_pointer_cast<BinaryOperationAction>(act)) { } else if (auto bin_act = std::dynamic_pointer_cast<BinaryOperationAction>(act)) {
if (var_to_id.find(bin_act->operand1_full) != var_to_id.end()) { if (var_to_id.find(bin_act->operand1_full) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[bin_act->operand1_full]); cur_act_use.push_back(var_to_id[bin_act->operand1_full]);
@ -273,14 +269,46 @@ void UseDefCollect(CFGType &cfg, [[maybe_unused]] std::vector<std::string> &id_t
std::vector<size_t> &cur_act_use = node->action_use_vars[act.get()]; std::vector<size_t> &cur_act_use = node->action_use_vars[act.get()];
std::vector<size_t> &cur_act_def = node->action_def_vars[act.get()]; std::vector<size_t> &cur_act_def = node->action_def_vars[act.get()];
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) { if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
if (var_to_id.find(br_act->cond) != var_to_id.end()) { if (auto beq_act = std::dynamic_pointer_cast<BEQAction>(br_act)) {
cur_act_use.push_back(var_to_id[br_act->cond]); if (var_to_id.find(beq_act->rs1) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[beq_act->rs1]);
}
if (beq_act->rs1 != beq_act->rs2 && var_to_id.find(beq_act->rs2) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[beq_act->rs2]);
}
} else if (auto bne_act = std::dynamic_pointer_cast<BNEAction>(br_act)) {
if (var_to_id.find(bne_act->rs1) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[bne_act->rs1]);
}
if (bne_act->rs1 != bne_act->rs2 && var_to_id.find(bne_act->rs2) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[bne_act->rs2]);
}
} else if (auto blt_act = std::dynamic_pointer_cast<BLTAction>(br_act)) {
if (var_to_id.find(blt_act->rs1) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[blt_act->rs1]);
}
if (blt_act->rs1 != blt_act->rs2 && var_to_id.find(blt_act->rs2) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[blt_act->rs2]);
}
} else if (auto bge_act = std::dynamic_pointer_cast<BGEAction>(br_act)) {
if (var_to_id.find(bge_act->rs1) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[bge_act->rs1]);
}
if (bge_act->rs1 != bge_act->rs2 && var_to_id.find(bge_act->rs2) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[bge_act->rs2]);
}
} else {
if (var_to_id.find(br_act->cond) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[br_act->cond]);
}
} }
} else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) { } else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) {
if (!std::holds_alternative<LLVMVOIDType>(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { if (!std::holds_alternative<LLVMVOIDType>(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) {
cur_act_use.push_back(var_to_id[ret_act->value]); cur_act_use.push_back(var_to_id[ret_act->value]);
} }
} }
std::sort(cur_act_use.begin(), cur_act_use.end());
std::sort(cur_act_def.begin(), cur_act_def.end());
if (!use_def_init) { if (!use_def_init) {
use_def_init = true; use_def_init = true;
cur_node_use = cur_act_use; cur_node_use = cur_act_use;
@ -377,7 +405,24 @@ void ActionLevelTracking(CFGType &cfg, CFGNodeType *node) {
void LiveAnalysis(CFGType &cfg) { void LiveAnalysis(CFGType &cfg) {
VarCollect(cfg, cfg.id_to_var, cfg.var_to_id); VarCollect(cfg, cfg.id_to_var, cfg.var_to_id);
// std::cerr << "all vars collected" << std::endl;
// for (auto var : cfg.id_to_var) {
// std::cerr << "\tid=" << cfg.var_to_id[var] << " var=" << var << std::endl;
// }
UseDefCollect(cfg, cfg.id_to_var, cfg.var_to_id); UseDefCollect(cfg, cfg.id_to_var, cfg.var_to_id);
// for (auto node : cfg.nodes) {
// std::cerr << "block " << node->corresponding_block->label_full << std::endl;
// std::cerr << "\tblock_use_vars=";
// for (auto i : node->block_use_vars) {
// std::cerr << i << " ";
// }
// std::cerr << std::endl;
// std::cerr << "\tblock_def_vars=";
// for (auto i : node->block_def_vars) {
// std::cerr << i << " ";
// }
// std::cerr << std::endl;
// }
BlockLevelTracking(cfg, cfg.id_to_var, cfg.var_to_id); BlockLevelTracking(cfg, cfg.id_to_var, cfg.var_to_id);
for (auto node : cfg.nodes) { for (auto node : cfg.nodes) {
ActionLevelTracking(cfg, node.get()); ActionLevelTracking(cfg, node.get());

View File

@ -174,13 +174,9 @@ void TranslateColorResult(std::shared_ptr<FunctionDefItem> func, CFGType &cfg, C
bool use_def_init = false; bool use_def_init = false;
for (auto act : block->actions) { for (auto act : block->actions) {
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) { if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
if (var_to_id.find(br_act->cond) != var_to_id.end()) { throw std::runtime_error("BRAction should not exist in the middle of a block");
ApplyColoringResult(br_act->cond);
}
} else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) { } else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) {
if (!std::holds_alternative<LLVMVOIDType>(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { throw std::runtime_error("RETAction should not exist in the middle of a block");
ApplyColoringResult(ret_act->value);
}
} else if (auto bin_act = std::dynamic_pointer_cast<BinaryOperationAction>(act)) { } else if (auto bin_act = std::dynamic_pointer_cast<BinaryOperationAction>(act)) {
if (var_to_id.find(bin_act->operand1_full) != var_to_id.end()) { if (var_to_id.find(bin_act->operand1_full) != var_to_id.end()) {
ApplyColoringResult(bin_act->operand1_full); ApplyColoringResult(bin_act->operand1_full);
@ -281,8 +277,38 @@ void TranslateColorResult(std::shared_ptr<FunctionDefItem> func, CFGType &cfg, C
{ {
auto act = block->exit_action; auto act = block->exit_action;
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) { if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
if (var_to_id.find(br_act->cond) != var_to_id.end()) { if (auto beq_act = std::dynamic_pointer_cast<BEQAction>(br_act)) {
ApplyColoringResult(br_act->cond); if (var_to_id.find(beq_act->rs1) != var_to_id.end()) {
ApplyColoringResult(beq_act->rs1);
}
if (var_to_id.find(beq_act->rs2) != var_to_id.end()) {
ApplyColoringResult(beq_act->rs2);
}
} else if (auto bne_act = std::dynamic_pointer_cast<BNEAction>(br_act)) {
if (var_to_id.find(bne_act->rs1) != var_to_id.end()) {
ApplyColoringResult(bne_act->rs1);
}
if (var_to_id.find(bne_act->rs2) != var_to_id.end()) {
ApplyColoringResult(bne_act->rs2);
}
} else if (auto blt_act = std::dynamic_pointer_cast<BLTAction>(br_act)) {
if (var_to_id.find(blt_act->rs1) != var_to_id.end()) {
ApplyColoringResult(blt_act->rs1);
}
if (var_to_id.find(blt_act->rs2) != var_to_id.end()) {
ApplyColoringResult(blt_act->rs2);
}
} else if (auto bge_act = std::dynamic_pointer_cast<BGEAction>(br_act)) {
if (var_to_id.find(bge_act->rs1) != var_to_id.end()) {
ApplyColoringResult(bge_act->rs1);
}
if (var_to_id.find(bge_act->rs2) != var_to_id.end()) {
ApplyColoringResult(bge_act->rs2);
}
} else {
if (var_to_id.find(br_act->cond) != var_to_id.end()) {
ApplyColoringResult(br_act->cond);
}
} }
} else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) { } else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) {
if (!std::holds_alternative<LLVMVOIDType>(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) { if (!std::holds_alternative<LLVMVOIDType>(ret_act->type) && var_to_id.find(ret_act->value) != var_to_id.end()) {