diff --git a/src/opt/optbackend.cpp b/src/opt/optbackend.cpp index 3919cfc..a7e1be9 100644 --- a/src/opt/optbackend.cpp +++ b/src/opt/optbackend.cpp @@ -1,3 +1,5 @@ +#include +#include #include "IR/IR_basic.h" #include "opt/gen.h" namespace OptBackend { @@ -99,6 +101,7 @@ void GenerateOptASM(std::ostream &os, std::shared_ptr prog) { // } } + size_t tmp_label_counter = 0; for (auto func_def : prog->function_defs) { std::cerr << "generating asm for function " << func_def->func_name_raw << std::endl; RemoveJumpOnlyBlock(func_def.get()); @@ -106,42 +109,114 @@ void GenerateOptASM(std::ostream &os, std::shared_ptr prog) { riscv->funcs.push_back(func_asm); func_asm->full_label = func_def->func_name_raw; FuncLayout &layout = func_layouts[func_def->func_name_raw]; + std::vector code_line_tmp; if (layout.total_frame_size < 2048) { - func_asm->code_lines.push_back("addi sp, sp, -" + std::to_string(layout.total_frame_size)); - func_asm->code_lines.push_back("sw ra, " + std::to_string(layout.total_frame_size - 4) + "(sp)"); - func_asm->code_lines.push_back("sw s0, " + std::to_string(layout.total_frame_size - 8) + "(sp)"); - func_asm->code_lines.push_back("addi s0, sp, " + std::to_string(layout.total_frame_size)); - func_asm->code_lines.push_back("sw s0, " + std::to_string(layout.total_frame_size - 12) + "(sp)"); + code_line_tmp.push_back("addi sp, sp, -" + std::to_string(layout.total_frame_size)); + code_line_tmp.push_back("sw ra, " + std::to_string(layout.total_frame_size - 4) + "(sp)"); + code_line_tmp.push_back("sw s0, " + std::to_string(layout.total_frame_size - 8) + "(sp)"); + code_line_tmp.push_back("addi s0, sp, " + std::to_string(layout.total_frame_size)); + code_line_tmp.push_back("sw s0, " + std::to_string(layout.total_frame_size - 12) + "(sp)"); } else { - func_asm->code_lines.push_back("li x31, " + std::to_string(layout.total_frame_size)); - func_asm->code_lines.push_back("sub sp, sp, x31"); - func_asm->code_lines.push_back("add x31, x31, sp"); - func_asm->code_lines.push_back("sw ra, -4(x31)"); - func_asm->code_lines.push_back("sw s0, -8(x31)"); - func_asm->code_lines.push_back("sw x31, -12(x31)"); - func_asm->code_lines.push_back("mv s0, x31"); + code_line_tmp.push_back("li x31, " + std::to_string(layout.total_frame_size)); + code_line_tmp.push_back("sub sp, sp, x31"); + code_line_tmp.push_back("add x31, x31, sp"); + code_line_tmp.push_back("sw ra, -4(x31)"); + code_line_tmp.push_back("sw s0, -8(x31)"); + code_line_tmp.push_back("sw x31, -12(x31)"); + code_line_tmp.push_back("mv s0, x31"); } if (func_def->init_block) { - func_asm->code_lines.push_back(".entrylabel." + func_def->init_block->label_full + ":"); + code_line_tmp.push_back(".entrylabel." + func_def->init_block->label_full + ":"); for (auto act : func_def->init_block->actions) { - OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw], - prog->low_level_class_info); + OptBackend::GenerateASM(act, code_line_tmp, func_layouts[func_def->func_name_raw], prog->low_level_class_info); } - OptBackend::GenerateASM(func_def->init_block->exit_action, func_asm->code_lines, - func_layouts[func_def->func_name_raw], prog->low_level_class_info); + OptBackend::GenerateASM(func_def->init_block->exit_action, code_line_tmp, func_layouts[func_def->func_name_raw], + prog->low_level_class_info); } for (auto block : func_def->basic_blocks) { - if (func_asm->code_lines.size() > 0 && func_asm->code_lines.back() == "j .entrylabel." + block->label_full) { - func_asm->code_lines.pop_back(); // remove redundant jump + if (code_line_tmp.size() > 0 && code_line_tmp.back() == "j .entrylabel." + block->label_full) { + code_line_tmp.pop_back(); // remove redundant jump } - func_asm->code_lines.push_back(".entrylabel." + block->label_full + ":"); + code_line_tmp.push_back(".entrylabel." + block->label_full + ":"); for (auto act : block->actions) { - OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw], - prog->low_level_class_info); + OptBackend::GenerateASM(act, code_line_tmp, func_layouts[func_def->func_name_raw], prog->low_level_class_info); } - OptBackend::GenerateASM(block->exit_action, func_asm->code_lines, func_layouts[func_def->func_name_raw], + OptBackend::GenerateASM(block->exit_action, code_line_tmp, func_layouts[func_def->func_name_raw], prog->low_level_class_info); } + std::unordered_map label_to_line; + std::vector code_line_tmp2; + bool branch_replaced = false; + std::unordered_set is_branch; + is_branch.insert("beq"); + is_branch.insert("bne"); + is_branch.insert("blt"); + is_branch.insert("bge"); + is_branch.insert("bnez"); + is_branch.insert("beqz"); + do { + branch_replaced = false; + code_line_tmp2.clear(); + label_to_line.clear(); + for (size_t i = 0; i < code_line_tmp.size(); i++) { + if (code_line_tmp[i].substr(0, 12) == ".entrylabel.") { + label_to_line[code_line_tmp[i].substr(0, code_line_tmp[i].size() - 1)] = i; + } + } + for (size_t i = 0; i < code_line_tmp.size(); i++) { + std::stringstream ss(code_line_tmp[i]); + std::string tmp; + std::string label; + std::vector tokens; + while (ss >> tmp) { + label = tmp; + tokens.push_back(tmp); + } + if (is_branch.find(tokens[0]) != is_branch.end()) { + if (label_to_line.find(label) == label_to_line.end()) { + goto write; + throw std::runtime_error("label " + label + " not found"); + } + int64_t delta = label_to_line[label] - static_cast(i); + if (delta < -2000 || delta > 2000) { + std::string tmp_label_name = ".entrylabel.tmp_label." + std::to_string(tmp_label_counter++); + ss.clear(); + if (tokens[0] == "beq") { + tokens[0] = "bne"; + } else if (tokens[0] == "bne") { + tokens[0] = "beq"; + } else if (tokens[0] == "blt") { + tokens[0] = "bge"; + } else if (tokens[0] == "bge") { + tokens[0] = "blt"; + } else if (tokens[0] == "bnez") { + tokens[0] = "beqz"; + } else if (tokens[0] == "beqz") { + tokens[0] = "bnez"; + } else { + throw std::runtime_error("unknown branch type"); + } + for (size_t j = 0; j < tokens.size() - 1; j++) { + ss << tokens[j] << " "; + } + ss << tmp_label_name; + code_line_tmp2.push_back(ss.str()); + code_line_tmp2.push_back("j " + label); + code_line_tmp2.push_back(tmp_label_name + ":"); + branch_replaced = true; + } else { + code_line_tmp2.push_back(code_line_tmp[i]); + } + } else { + code_line_tmp2.push_back(code_line_tmp[i]); + } + } + code_line_tmp = code_line_tmp2; + } while (branch_replaced); + write: + for (const auto &code : code_line_tmp) { + func_asm->code_lines.push_back(code); + } } riscv->RecursivePrint(os);