remove stupid jumps

This commit is contained in:
2024-10-24 02:02:07 +00:00
parent c715ae3f60
commit 3349e934c5
2 changed files with 60 additions and 8 deletions

View File

@ -119,8 +119,6 @@ inline void StoreImmToReg(int imm, std::string reg, std::vector<std::string> &co
void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog); void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog);
extern std::string cur_block_label_for_phi;
std::string AllocateTmpReg(std::vector<std::string> &available_tmp_regs); std::string AllocateTmpReg(std::vector<std::string> &available_tmp_regs);
std::string ExtractRegName(const std::string &raw); std::string ExtractRegName(const std::string &raw);

View File

@ -1,6 +1,61 @@
#include "IR/IR_basic.h"
#include "opt/gen.h" #include "opt/gen.h"
namespace OptBackend { namespace OptBackend {
std::string cur_block_label_for_phi; void RemoveJumpOnlyBlock(FunctionDefItem *func);
std::string FindFinalDestination(std::string cur, std::unordered_map<std::string, std::string> &rerouted_to,
std::unordered_set<std::string> &visited) {
if (visited.find(cur) != visited.end()) return cur;
if (rerouted_to.find(cur) == rerouted_to.end()) throw std::runtime_error("reroute not found");
std::string next = rerouted_to[cur];
if (rerouted_to.find(next) == rerouted_to.end()) return next;
visited.insert(cur);
return rerouted_to[cur] = FindFinalDestination(next, rerouted_to, visited);
}
void RemoveJumpOnlyBlock(FunctionDefItem *func) {
std::unordered_map<std::string, std::vector<std::string *>> reverse_lookup;
if (auto block = func->init_block) {
if (auto br_act = std::dynamic_pointer_cast<BRAction>(block->exit_action)) {
reverse_lookup[br_act->true_label_full].push_back(&br_act->true_label_full);
reverse_lookup[br_act->false_label_full].push_back(&br_act->false_label_full);
} else if (auto uncond = std::dynamic_pointer_cast<UNConditionJMPAction>(block->exit_action)) {
reverse_lookup[uncond->label_full].push_back(&uncond->label_full);
}
}
for (auto block : func->basic_blocks) {
if (auto br_act = std::dynamic_pointer_cast<BRAction>(block->exit_action)) {
reverse_lookup[br_act->true_label_full].push_back(&br_act->true_label_full);
reverse_lookup[br_act->false_label_full].push_back(&br_act->false_label_full);
} else if (auto uncond = std::dynamic_pointer_cast<UNConditionJMPAction>(block->exit_action)) {
reverse_lookup[uncond->label_full].push_back(&uncond->label_full);
}
}
std::unordered_set<std::string> block_need_remove;
std::unordered_map<std::string, std::string> rerouted_to;
for (auto block : func->basic_blocks) {
if (block->actions.size() == 0 && std::dynamic_pointer_cast<UNConditionJMPAction>(block->exit_action)) {
auto uncond = std::dynamic_pointer_cast<UNConditionJMPAction>(block->exit_action);
// for (auto label : reverse_lookup[uncond->label_full]) {
// *label = uncond->label_full;
// }
rerouted_to[block->label_full] = uncond->label_full;
block_need_remove.insert(block->label_full);
}
}
for (auto [src, _] : rerouted_to) {
for (auto label : reverse_lookup[src]) {
std::unordered_set<std::string> tmp;
*label = FindFinalDestination(*label, rerouted_to, tmp);
}
}
auto tmp = func->basic_blocks;
func->basic_blocks.clear();
for (auto block : tmp) {
if (block_need_remove.find(block->label_full) == block_need_remove.end()) {
func->basic_blocks.push_back(block);
}
}
}
void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog) { void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog) {
auto riscv = std::make_shared<RISCVProgItem>(); auto riscv = std::make_shared<RISCVProgItem>();
@ -46,6 +101,7 @@ void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog) {
for (auto func_def : prog->function_defs) { for (auto func_def : prog->function_defs) {
std::cerr << "generating asm for function " << func_def->func_name_raw << std::endl; std::cerr << "generating asm for function " << func_def->func_name_raw << std::endl;
RemoveJumpOnlyBlock(func_def.get());
auto func_asm = std::make_shared<RISCVFuncItem>(); auto func_asm = std::make_shared<RISCVFuncItem>();
riscv->funcs.push_back(func_asm); riscv->funcs.push_back(func_asm);
func_asm->full_label = func_def->func_name_raw; func_asm->full_label = func_def->func_name_raw;
@ -71,15 +127,13 @@ void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog) {
OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw], OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw],
prog->low_level_class_info); prog->low_level_class_info);
} }
if (func_def->init_block->exit_action->corresponding_phi) {
OptBackend::cur_block_label_for_phi = func_def->init_block->label_full;
OptBackend::GenerateASM(func_def->init_block->exit_action->corresponding_phi, func_asm->code_lines,
func_layouts[func_def->func_name_raw], prog->low_level_class_info);
}
OptBackend::GenerateASM(func_def->init_block->exit_action, func_asm->code_lines, OptBackend::GenerateASM(func_def->init_block->exit_action, func_asm->code_lines,
func_layouts[func_def->func_name_raw], prog->low_level_class_info); func_layouts[func_def->func_name_raw], prog->low_level_class_info);
} }
for (auto block : func_def->basic_blocks) { for (auto block : func_def->basic_blocks) {
if (func_asm->code_lines.size() > 0 && func_asm->code_lines.back() == "j .entrylabel." + block->label_full) {
func_asm->code_lines.pop_back(); // remove redundant jump
}
func_asm->code_lines.push_back(".entrylabel." + block->label_full + ":"); func_asm->code_lines.push_back(".entrylabel." + block->label_full + ":");
for (auto act : block->actions) { for (auto act : block->actions) {
OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw], OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw],