can rounghly codegen

This commit is contained in:
2024-10-22 12:09:09 +00:00
parent 4c9d010ff7
commit 17816f2bf0
9 changed files with 597 additions and 27 deletions

View File

@ -25,6 +25,16 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
endif()
endif()
# 设置 Release 模式下开启优化
if(CMAKE_BUILD_TYPE STREQUAL "Release")
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
add_compile_options(-Wall -Wextra -Wpedantic -fsanitize=address,undefined -O2)
add_link_options(-fsanitize=address,undefined)
elseif(MSVC)
add_compile_options(/O2)
endif()
endif()
include(FetchContent)
FetchContent_Declare(
googletest

View File

@ -47,12 +47,10 @@ class TypeDefItem : public LLVMIRItemBase {
}
};
class GlobalVarDefItem : public LLVMIRItemBase {
friend class IRBuilder;
friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr<ModuleItem> prog);
public:
LLVMType type;
std::string name_raw;
public:
GlobalVarDefItem() = default;
void RecursivePrint(std::ostream &os) const {
std::string name_full = "@.var.global." + name_raw + ".addrkp";
@ -449,9 +447,7 @@ class FunctionDeclareItem : public LLVMIRItemBase {
}
};
class ConstStrItem : public LLVMIRItemBase {
friend std::shared_ptr<ModuleItem> BuildIR(std::shared_ptr<Program_ASTNode> src);
friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr<ModuleItem> prog);
friend class IRBuilder;
public:
std::string string_raw;
size_t const_str_id;
static std::string Escape(const std::string &src) {
@ -474,7 +470,6 @@ class ConstStrItem : public LLVMIRItemBase {
return ss.str();
}
public:
ConstStrItem() = default;
void RecursivePrint(std::ostream &os) const {
os << "@.str." << const_str_id << " = private unnamed_addr constant [" << string_raw.size() + 1 << " x i8] c\""

View File

@ -54,8 +54,6 @@ class RISCVGlobalVarItem : public RISCVAsmItemBase {
}
};
class RISCVFuncItem : public RISCVAsmItemBase {
friend void ::GenerateNaiveASM(std::ostream &os, std::shared_ptr<ModuleItem> prog);
public:
std::string full_label;
std::vector<std::string> code_lines;
@ -93,20 +91,12 @@ class RISCVProgItem : public RISCVAsmItemBase {
}
};
class FuncLayout {
friend void ::GenerateNaiveASM(std::ostream &os, std::shared_ptr<ModuleItem> prog);
friend void GenerateReadAccess(std::string val, size_t bytes, std::string output_reg, FuncLayout &layout,
std::vector<std::string> &code_lines);
friend void GenerateWriteAccess(std::string val, size_t bytes, std::string data_reg, FuncLayout &layout,
std::vector<std::string> &code_lines);
friend void NaiveBackend::GenerateASM(std::shared_ptr<ActionItem> act, std::vector<std::string> &code_lines,
FuncLayout &layout,
const std::unordered_map<std::string, IRClassInfo> &low_level_class_info,
bool process_phi);
public:
std::unordered_map<std::string, size_t> local_items;
std::unordered_map<std::string, size_t> arg_offset;
size_t cur_pos;
size_t total_frame_size; // should align to 16 bytes
public:
FuncLayout() : cur_pos(8), total_frame_size(16) {}
void AllocateItem(const std::string &name, size_t sz, size_t num = 1) {
if (local_items.find(name) != local_items.end()) throw std::runtime_error("Local item already exists");

View File

@ -173,6 +173,6 @@ const static std::vector<std::string> allocating_regs = {"x3", "x4", "x9", "x
"x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17"};
inline bool VRegCheck(const std::string &s) {
if (s[0] != '%') return false;
if (s[0] != '%' && s[0] != '$' && s[0] != '#') return false;
return true;
}

461
include/opt/gen.h Normal file
View File

@ -0,0 +1,461 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "IR/IR_basic.h"
#include "cfg.h"
#include "liveanalysis.h"
namespace OptBackend {
class RISCVAsmItemBase {
public:
RISCVAsmItemBase() = default;
virtual ~RISCVAsmItemBase() = default;
virtual void RecursivePrint(std::ostream &os) const = 0;
};
class RISCVConstStrItem : public RISCVAsmItemBase {
public:
std::string full_label;
std::string content;
RISCVConstStrItem() = default;
~RISCVConstStrItem() = default;
void RecursivePrint(std::ostream &os) const override {
os << full_label << ":\n";
os << " .asciz \"";
for (auto c : content) {
if (c == '\n') {
os << "\\n";
} else if (c == '\t') {
os << "\\t";
} else if (c == '\"') {
os << "\\\"";
} else if (c == '\\') {
os << "\\\\";
} else {
os << c;
}
}
os << "\"\n";
}
};
class RISCVGlobalVarItem : public RISCVAsmItemBase {
public:
std::string full_label;
RISCVGlobalVarItem() = default;
~RISCVGlobalVarItem() = default;
void RecursivePrint(std::ostream &os) const override {
os << ".globl " << full_label << "\n";
os << ".p2align 2, 0x0\n";
os << full_label << ":\n";
os << " .word 0\n";
}
};
class RISCVFuncItem : public RISCVAsmItemBase {
public:
std::string full_label;
std::vector<std::string> code_lines;
RISCVFuncItem() = default;
~RISCVFuncItem() = default;
void RecursivePrint(std::ostream &os) const override {
os << ".globl " << full_label << "\n";
os << ".p2align 2, 0x0\n";
os << full_label << ":\n";
for (auto &line : code_lines) {
os << line << "\n";
}
}
};
class RISCVProgItem : public RISCVAsmItemBase {
public:
std::vector<std::shared_ptr<RISCVConstStrItem>> const_strs;
std::vector<std::shared_ptr<RISCVGlobalVarItem>> global_vars;
std::vector<std::shared_ptr<RISCVFuncItem>> funcs;
RISCVProgItem() = default;
~RISCVProgItem() = default;
void RecursivePrint(std::ostream &os) const override {
os << ".section .rodata\n";
for (auto &item : const_strs) {
item->RecursivePrint(os);
}
os << ".section .sbss\n";
for (auto &item : global_vars) {
item->RecursivePrint(os);
}
os << ".section .text\n";
for (auto &item : funcs) {
item->RecursivePrint(os);
}
}
};
class FuncLayout {
public:
std::unordered_map<std::string, size_t> local_items;
std::unordered_map<std::string, size_t> arg_offset;
size_t cur_pos;
size_t total_frame_size; // should align to 16 bytes
FuncLayout() : cur_pos(12), total_frame_size(16) {}
void AllocateItem(const std::string &name, size_t sz, size_t num = 1) {
if (local_items.find(name) != local_items.end()) throw std::runtime_error("Local item already exists");
if (cur_pos % sz != 0) {
cur_pos += sz - cur_pos % sz;
}
cur_pos += sz * num;
local_items[name] = cur_pos;
total_frame_size = ((cur_pos + 15) / 16) * 16;
std::cerr << "allocating stack memory for " << name << " at " << cur_pos << std::endl;
}
size_t QueryOffeset(const std::string &name) {
if (local_items.find(name) == local_items.end()) throw std::runtime_error("Local item not found");
return local_items[name];
}
size_t QueryFrameSize() const { return total_frame_size; }
};
inline void StoreImmToReg(int imm, std::string reg, std::vector<std::string> &code_lines) {
code_lines.push_back("li " + reg + ", " + std::to_string(imm));
}
void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog);
extern std::string cur_block_label_for_phi;
inline std::string AllocateTmpReg(std::vector<std::string> &available_tmp_regs) {
if (available_tmp_regs.size() == 0) throw std::runtime_error("No available tmp register");
std::string res;
res = available_tmp_regs.back();
available_tmp_regs.pop_back();
return res;
}
inline std::string ExtractRegName(const std::string &raw) {
if (raw[0] != '$') throw std::runtime_error("Not a register");
size_t reg_id = std::stoull(raw.substr(5));
return "x" + std::to_string(reg_id);
}
inline void FetchValueToReg(std::string original_val, std::string &out_reg, FuncLayout &layout,
std::vector<std::string> &code_lines, std::vector<std::string> &available_tmp_regs) {
if (original_val[0] == '$') {
// already assigned to a register, such as `$reg.10`
out_reg = ExtractRegName(original_val);
} else if (original_val[0] == '#') {
// spilled variable, we need find it in the layout
size_t offset = layout.QueryOffeset(original_val);
out_reg = AllocateTmpReg(available_tmp_regs);
if (offset < 2048) {
code_lines.push_back("lw " + out_reg + ", -" + std::to_string(offset) + "(s0)");
} else {
code_lines.push_back("li " + out_reg + ", -" + std::to_string(offset));
code_lines.push_back("add " + out_reg + ", s0, " + out_reg);
code_lines.push_back("lw " + out_reg + ", 0(" + out_reg + ")");
}
} else if (original_val[0] == '@') {
// global variable address keeper
out_reg = AllocateTmpReg(available_tmp_regs);
std::string label_in_asm = original_val.substr(1, original_val.size() - 1);
code_lines.push_back("la " + out_reg + ", " + label_in_asm);
} else if (original_val[0] == '-' || std::isdigit(original_val[0])) {
// immediate value
out_reg = AllocateTmpReg(available_tmp_regs);
StoreImmToReg(std::stoi(original_val), out_reg, code_lines);
} else {
throw std::runtime_error("Unknown value type");
}
}
inline void WriteToSpilledVar(std::string val, std::string reg, FuncLayout &layout,
std::vector<std::string> &code_lines, std::vector<std::string> &available_tmp_regs) {
if (val[0] != '#') throw std::runtime_error("Not a spilled variable");
size_t offset = layout.QueryOffeset(val);
if (offset < 2048) {
code_lines.push_back("sw " + reg + ", -" + std::to_string(offset) + "(s0)");
} else {
std::string tmp_reg = AllocateTmpReg(available_tmp_regs);
code_lines.push_back("li " + tmp_reg + ", -" + std::to_string(offset));
code_lines.push_back("add " + tmp_reg + ", s0, " + tmp_reg);
code_lines.push_back("sw " + reg + ", 0(" + tmp_reg + ")");
}
}
inline size_t CalcSize(const LLVMType &tp) {
if (std::holds_alternative<LLVMIRIntType>(tp)) {
auto &int_tp = std::get<LLVMIRIntType>(tp);
return (int_tp.bits + 7) / 8;
} else if (std::holds_alternative<LLVMIRPTRType>(tp)) {
return 4;
} else if (std::holds_alternative<LLVMVOIDType>(tp)) {
throw std::runtime_error("Cannot calculate size of void type");
return 0;
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(tp)) {
throw std::runtime_error("Cannot calculate size of class type");
} else
throw std::runtime_error("Unknown type");
}
inline void GenerateASM(std::shared_ptr<ActionItem> act, std::vector<std::string> &code_lines, FuncLayout &layout,
const std::unordered_map<std::string, IRClassInfo> &low_level_class_info) {
std::vector<std::string> available_tmp_regs = held_tmp_regs;
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
std::string cond_reg;
FetchValueToReg(br_act->cond, cond_reg, layout, code_lines, available_tmp_regs);
code_lines.push_back("bnez " + cond_reg + ", .entrylabel." + br_act->true_label_full);
code_lines.push_back("j .entrylabel." + br_act->false_label_full);
} else if (auto jmp_act = std::dynamic_pointer_cast<UNConditionJMPAction>(act)) {
code_lines.push_back("j .entrylabel." + jmp_act->label_full);
} else if (auto ret_act = std::dynamic_pointer_cast<RETAction>(act)) {
code_lines.push_back("lw ra, -4(s0)");
code_lines.push_back("lw sp, -12(s0)");
code_lines.push_back("lw s0, -8(s0)");
code_lines.push_back("ret");
} else if (auto binary_act = std::dynamic_pointer_cast<BinaryOperationAction>(act)) {
// size_t sz = CalcSize(binary_act->type);
// IRVar2RISCVReg(binary_act->operand1_full, sz, "t0", layout, code_lines);
// IRVar2RISCVReg(binary_act->operand2_full, sz, "t1", layout, code_lines);
std::string operand1_reg, operand2_reg;
FetchValueToReg(binary_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(binary_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs);
std::string res_reg;
bool need_extra_store = false;
if (binary_act->result_full[0] == '$') {
res_reg = ExtractRegName(binary_act->result_full);
} else if (binary_act->result_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else {
throw std::runtime_error("Unknown result type");
}
if (binary_act->op == "add") {
code_lines.push_back("add " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "sub") {
code_lines.push_back("sub " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "mul") {
code_lines.push_back("mul " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "sdiv") {
code_lines.push_back("div " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "srem") {
code_lines.push_back("rem " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "and") {
code_lines.push_back("and " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "or") {
code_lines.push_back("or " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "xor") {
code_lines.push_back("xor " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "shl") {
code_lines.push_back("sll " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (binary_act->op == "ashr") {
code_lines.push_back("sra " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else {
throw std::runtime_error("Unknown binary operation");
}
if (need_extra_store) {
WriteToSpilledVar(binary_act->result_full, res_reg, layout, code_lines, available_tmp_regs);
}
} else if (auto alloca_act = std::dynamic_pointer_cast<AllocaAction>(act)) {
std::string res_reg;
bool need_extra_store = false;
if (alloca_act->name_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else if (alloca_act->name_full[0] == '$') {
res_reg = ExtractRegName(alloca_act->name_full);
} else {
throw std::runtime_error("Unknown result type");
}
size_t sz = CalcSize(alloca_act->type) * alloca_act->num;
sz = (sz + 15) / 16 * 16;
code_lines.push_back("addi sp, sp, -" + std::to_string(sz));
if (!need_extra_store) {
code_lines.push_back("mv " + res_reg + ", sp");
} else {
WriteToSpilledVar(alloca_act->name_full, "sp", layout, code_lines, available_tmp_regs);
}
} else if (auto load_act = std::dynamic_pointer_cast<LoadAction>(act)) {
std::string res_reg;
bool need_extra_store = false;
if (load_act->result_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else if (load_act->result_full[0] == '$') {
res_reg = ExtractRegName(load_act->result_full);
} else {
throw std::runtime_error("Unknown result type");
}
std::string ptr_reg;
FetchValueToReg(load_act->ptr_full, ptr_reg, layout, code_lines, available_tmp_regs);
if (CalcSize(load_act->ty) == 4) {
code_lines.push_back("lw " + res_reg + ", 0(" + ptr_reg + ")");
} else if (CalcSize(load_act->ty) == 1) {
code_lines.push_back("lb " + res_reg + ", 0(" + ptr_reg + ")");
} else {
throw std::runtime_error("Unknown bytes");
}
if (need_extra_store) {
WriteToSpilledVar(load_act->result_full, res_reg, layout, code_lines, available_tmp_regs);
}
} else if (auto store_act = std::dynamic_pointer_cast<StoreAction>(act)) {
std::string val_reg;
std::string ptr_reg;
FetchValueToReg(store_act->value_full, val_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(store_act->ptr_full, ptr_reg, layout, code_lines, available_tmp_regs);
if (CalcSize(store_act->ty) == 4) {
code_lines.push_back("sw " + val_reg + ", 0(" + ptr_reg + ")");
} else if (CalcSize(store_act->ty) == 1) {
code_lines.push_back("sb " + val_reg + ", 0(" + ptr_reg + ")");
} else {
throw std::runtime_error("Unknown bytes");
}
} else if (auto get_element_act = std::dynamic_pointer_cast<GetElementPtrAction>(act)) {
if (get_element_act->indices.size() == 1) {
// array access
std::string res_reg;
bool need_extra_store = false;
if (get_element_act->result_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else if (get_element_act->result_full[0] == '$') {
res_reg = ExtractRegName(get_element_act->result_full);
} else {
throw std::runtime_error("Unknown result type");
}
std::string ptr_reg;
std::string idx_reg;
FetchValueToReg(get_element_act->ptr_full, ptr_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(get_element_act->indices[0], idx_reg, layout, code_lines, available_tmp_regs);
std::string tmp_reg = AllocateTmpReg(available_tmp_regs);
size_t element_sz = CalcSize(get_element_act->ty);
code_lines.push_back("slli " + tmp_reg + ", " + idx_reg + ", " + std::to_string(std::countr_zero(element_sz)));
code_lines.push_back("add " + res_reg + ", " + ptr_reg + ", " + tmp_reg);
if (need_extra_store) {
WriteToSpilledVar(get_element_act->result_full, res_reg, layout, code_lines, available_tmp_regs);
}
} else if (get_element_act->indices.size() == 2) {
// // struct access
if (get_element_act->indices[0] != "0") {
throw std::runtime_error("struct access with non-zero offset is not supported");
}
size_t element_idx = std::stoull(get_element_act->indices[1]);
auto class_ty = std::get<LLVMIRCLASSTYPE>(get_element_act->ty);
const IRClassInfo &class_info = low_level_class_info.at(class_ty.class_name_full);
size_t offset = class_info.member_var_pos_after_align[element_idx];
std::string res_reg;
bool need_extra_store = false;
if (get_element_act->result_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else if (get_element_act->result_full[0] == '$') {
res_reg = ExtractRegName(get_element_act->result_full);
} else {
throw std::runtime_error("Unknown result type");
}
std::string base_ptr_reg;
FetchValueToReg(get_element_act->ptr_full, base_ptr_reg, layout, code_lines, available_tmp_regs);
if (offset < 2048) {
code_lines.push_back("addi " + res_reg + ", " + base_ptr_reg + ", " + std::to_string(offset));
} else {
std::string tmp_reg = AllocateTmpReg(available_tmp_regs);
code_lines.push_back("li " + tmp_reg + ", " + std::to_string(offset));
code_lines.push_back("add " + res_reg + ", " + base_ptr_reg + ", " + tmp_reg);
}
if (need_extra_store) {
WriteToSpilledVar(get_element_act->result_full, res_reg, layout, code_lines, available_tmp_regs);
}
} else {
throw std::runtime_error("Unknown getelementptr indices size");
}
} else if (auto icmp_act = std::dynamic_pointer_cast<ICMPAction>(act)) {
std::string operand1_reg, operand2_reg;
FetchValueToReg(icmp_act->operand1_full, operand1_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(icmp_act->operand2_full, operand2_reg, layout, code_lines, available_tmp_regs);
std::string res_reg;
bool need_extra_store = false;
if (icmp_act->result_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else if (icmp_act->result_full[0] == '$') {
res_reg = ExtractRegName(icmp_act->result_full);
} else {
throw std::runtime_error("Unknown result type");
}
std::string tmp_reg = AllocateTmpReg(available_tmp_regs);
if (icmp_act->op == "eq") {
// code_lines.push_back("xor t2, t0, t1");
// code_lines.push_back("seqz t2, t2");
code_lines.push_back("xor " + tmp_reg + ", " + operand1_reg + ", " + operand2_reg);
code_lines.push_back("seqz " + res_reg + ", " + tmp_reg);
} else if (icmp_act->op == "ne") {
// code_lines.push_back("xor t2, t0, t1");
// code_lines.push_back("snez t2, t2");
code_lines.push_back("xor " + tmp_reg + ", " + operand1_reg + ", " + operand2_reg);
code_lines.push_back("snez " + res_reg + ", " + tmp_reg);
} else if (icmp_act->op == "slt") {
// code_lines.push_back("slt t2, t0, t1");
code_lines.push_back("slt " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
} else if (icmp_act->op == "sle") {
// code_lines.push_back("slt t2, t1, t0");
// code_lines.push_back("xori t2, t2, 1");
code_lines.push_back("slt " + res_reg + ", " + operand2_reg + ", " + operand1_reg);
code_lines.push_back("xori " + res_reg + ", " + res_reg + ", 1");
} else if (icmp_act->op == "sgt") {
// code_lines.push_back("slt t2, t1, t0");
code_lines.push_back("slt " + res_reg + ", " + operand2_reg + ", " + operand1_reg);
} else if (icmp_act->op == "sge") {
// code_lines.push_back("slt t2, t0, t1");
// code_lines.push_back("xori t2, t2, 1");
code_lines.push_back("slt " + res_reg + ", " + operand1_reg + ", " + operand2_reg);
code_lines.push_back("xori " + res_reg + ", " + res_reg + ", 1");
} else {
throw std::runtime_error("Unknown icmp operation");
}
if (need_extra_store) {
WriteToSpilledVar(icmp_act->result_full, res_reg, layout, code_lines, available_tmp_regs);
}
} else if (auto call_act = std::dynamic_pointer_cast<CallItem>(act)) {
// no need to to further process, as callling convention is handled in reg alloc
code_lines.push_back("call " + call_act->func_name_raw);
} else if (auto phi_act = std::dynamic_pointer_cast<PhiItem>(act)) {
throw std::runtime_error("Phi should not be in the layout");
} else if (auto select_act = std::dynamic_pointer_cast<SelectItem>(act)) {
std::string res_reg;
bool need_extra_store = false;
if (select_act->result_full[0] == '#') {
need_extra_store = true;
res_reg = AllocateTmpReg(available_tmp_regs);
} else if (select_act->result_full[0] == '$') {
res_reg = ExtractRegName(select_act->result_full);
} else {
throw std::runtime_error("Unknown result type");
}
std::string operand1_reg, operand2_reg, cond_reg;
FetchValueToReg(select_act->cond_full, cond_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(select_act->true_val_full, operand1_reg, layout, code_lines, available_tmp_regs);
FetchValueToReg(select_act->false_val_full, operand2_reg, layout, code_lines, available_tmp_regs);
std::string tmp1_reg = AllocateTmpReg(available_tmp_regs);
std::string tmp2_reg = AllocateTmpReg(available_tmp_regs);
code_lines.push_back("slli " + tmp1_reg + ", " + cond_reg + ", 31");
code_lines.push_back("srai " + tmp1_reg + ", " + tmp1_reg + ", 31");
code_lines.push_back("xor " + tmp2_reg + ", " + operand1_reg + ", " + operand2_reg);
code_lines.push_back("and " + tmp2_reg + ", " + tmp2_reg + ", " + tmp1_reg);
code_lines.push_back("xor " + res_reg + ", " + tmp2_reg + ", " + operand2_reg);
if (need_extra_store) {
WriteToSpilledVar(select_act->result_full, res_reg, layout, code_lines, available_tmp_regs);
}
} else if (auto load_spilled_args_act = std::dynamic_pointer_cast<opt::LoadSpilledArgs>(act)) {
throw std::runtime_error("Not implemented");
} else if (auto store_spilled_args_act = std::dynamic_pointer_cast<opt::StoreSpilledArgs>(act)) {
throw std::runtime_error("Not implemented");
} else if (auto move_act = std::dynamic_pointer_cast<opt::MoveInstruct>(act)) {
std::string src_reg;
FetchValueToReg(move_act->src_full, src_reg, layout, code_lines, available_tmp_regs);
if (move_act->dest_full[0] == '$') {
std::string dest_reg = ExtractRegName(move_act->dest_full);
code_lines.push_back("mv " + dest_reg + ", " + src_reg);
} else if (move_act->dest_full[0] == '#') {
WriteToSpilledVar(move_act->dest_full, src_reg, layout, code_lines, available_tmp_regs);
} else {
throw std::runtime_error("Unknown dest type");
}
} else {
throw std::runtime_error("Unknown action type");
}
}
} // namespace OptBackend

View File

@ -1,3 +1,4 @@
#include "cfg.h"
#include "gen.h"
#include "mem2reg.h"
#include "regalloc.h"

View File

@ -46,10 +46,11 @@ int main(int argc, char **argv) {
GenerateNaiveASM(fout, IR);
} else {
auto IR_with_out_allocas = Mem2Reg(IR);
IR_with_out_allocas->RecursivePrint(fout);
// IR_with_out_allocas->RecursivePrint(fout);
auto IR_with_out_phis = PhiEliminate(IR_with_out_allocas);
// IR_with_out_phis->RecursivePrint(fout);
auto alloced_code = RegAlloc(IR_with_out_phis);
OptBackend::GenerateOptASM(fout, alloced_code);
}
} catch (const SemanticError &err) {
std::cout << err.what() << std::endl;

95
src/opt/optbackend.cpp Normal file
View File

@ -0,0 +1,95 @@
#include "opt/gen.h"
namespace OptBackend {
std::string cur_block_label_for_phi;
void GenerateOptASM(std::ostream &os, std::shared_ptr<ModuleItem> prog) {
auto riscv = std::make_shared<RISCVProgItem>();
for (auto conststr : prog->const_strs) {
auto asm_item = std::make_shared<RISCVConstStrItem>();
riscv->const_strs.push_back(asm_item);
asm_item->content = conststr->string_raw;
asm_item->full_label = ".str." + std::to_string(conststr->const_str_id);
}
for (auto global_var : prog->global_var_defs) {
auto asm_item = std::make_shared<RISCVGlobalVarItem>();
riscv->global_vars.push_back(asm_item);
asm_item->full_label = ".var.global." + global_var->name_raw + ".addrkp";
}
std::unordered_map<std::string, FuncLayout> func_layouts;
for (auto func_def : prog->function_defs) {
// if (func_def->init_block) {
// for (auto act : func_def->init_block->actions) {
// ScanForVar(func_layouts[func_def->func_name_raw], act, prog->low_level_class_info);
// }
// }
// for (auto block : func_def->basic_blocks) {
// for (auto act : block->actions) {
// ScanForVar(func_layouts[func_def->func_name_raw], act, prog->low_level_class_info);
// }
// }
FuncLayout &layout = func_layouts[func_def->func_name_raw];
// for (size_t i = 0; i < func_def->args_full_name.size(); i++) {
// layout.arg_offset[func_def->args_full_name[i]] = i;
// }
for (size_t i = 0; i < func_def->spilled_vars; i++) {
layout.AllocateItem("#" + std::to_string(i), 4, 1);
}
// debug:
// std::cerr << "layout info of function " << func_def->func_name_raw << std::endl;
// std::cerr << "\tcur_pos=" << layout.cur_pos << std::endl;
// std::cerr << "\ttotal_frame_size=" << layout.total_frame_size << std::endl;
// for (const auto &item : layout.local_items) {
// std::cerr << "\t" << item.first << " " << item.second << std::endl;
// }
}
for (auto func_def : prog->function_defs) {
std::cerr << "generating asm for function " << func_def->func_name_raw << std::endl;
auto func_asm = std::make_shared<RISCVFuncItem>();
riscv->funcs.push_back(func_asm);
func_asm->full_label = func_def->func_name_raw;
FuncLayout &layout = func_layouts[func_def->func_name_raw];
if (layout.total_frame_size < 2048) {
func_asm->code_lines.push_back("addi sp, sp, -" + std::to_string(layout.total_frame_size));
func_asm->code_lines.push_back("sw ra, " + std::to_string(layout.total_frame_size - 4) + "(sp)");
func_asm->code_lines.push_back("sw s0, " + std::to_string(layout.total_frame_size - 8) + "(sp)");
func_asm->code_lines.push_back("addi s0, sp, " + std::to_string(layout.total_frame_size));
func_asm->code_lines.push_back("sw s0, " + std::to_string(layout.total_frame_size - 12) + "(sp)");
} else {
func_asm->code_lines.push_back("li x31, " + std::to_string(layout.total_frame_size));
func_asm->code_lines.push_back("sub sp, sp, x31");
func_asm->code_lines.push_back("add x31, x31, sp");
func_asm->code_lines.push_back("sw ra, -4(x31)");
func_asm->code_lines.push_back("sw s0, -8(x31)");
func_asm->code_lines.push_back("sw x31, -12(x31)");
func_asm->code_lines.push_back("mv s0, t0");
}
if (func_def->init_block) {
func_asm->code_lines.push_back(".entrylabel." + func_def->init_block->label_full + ":");
for (auto act : func_def->init_block->actions) {
OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw],
prog->low_level_class_info);
}
if (func_def->init_block->exit_action->corresponding_phi) {
OptBackend::cur_block_label_for_phi = func_def->init_block->label_full;
OptBackend::GenerateASM(func_def->init_block->exit_action->corresponding_phi, func_asm->code_lines,
func_layouts[func_def->func_name_raw], prog->low_level_class_info);
}
OptBackend::GenerateASM(func_def->init_block->exit_action, func_asm->code_lines,
func_layouts[func_def->func_name_raw], prog->low_level_class_info);
}
for (auto block : func_def->basic_blocks) {
func_asm->code_lines.push_back(".entrylabel." + block->label_full + ":");
for (auto act : block->actions) {
OptBackend::GenerateASM(act, func_asm->code_lines, func_layouts[func_def->func_name_raw],
prog->low_level_class_info);
}
OptBackend::GenerateASM(block->exit_action, func_asm->code_lines, func_layouts[func_def->func_name_raw],
prog->low_level_class_info);
}
}
riscv->RecursivePrint(os);
}
} // namespace OptBackend

View File

@ -154,12 +154,12 @@ void TranslateColorResult(std::shared_ptr<FunctionDefItem> func, CFGType &cfg, C
var = "$reg." + std::to_string(confnode->color);
}
};
func->spilled_vars = 0;
for (auto node : cfg.nodes) {
auto block = node->corresponding_block;
std::vector<size_t> cur_node_use;
std::vector<size_t> cur_node_def;
bool use_def_init = false;
func->spilled_vars = 0;
for (auto act : block->actions) {
if (auto br_act = std::dynamic_pointer_cast<BRAction>(act)) {
if (var_to_id.find(br_act->cond) != var_to_id.end()) {
@ -284,11 +284,12 @@ void TranslateColorResult(std::shared_ptr<FunctionDefItem> func, CFGType &cfg, C
if (move_act->src_full == move_act->dest_full) {
need_remove = true;
}
} else if (auto force_def_act = std::dynamic_pointer_cast<ForceDef>(*act_it)) {
need_remove = true;
} else if (auto force_use_act = std::dynamic_pointer_cast<ForceUse>(*act_it)) {
need_remove = true;
}
// else if (auto force_def_act = std::dynamic_pointer_cast<ForceDef>(*act_it)) {
// need_remove = true;
// } else if (auto force_use_act = std::dynamic_pointer_cast<ForceUse>(*act_it)) {
// need_remove = true;
// }
if (need_remove) {
auto it_next = act_it;
++it_next;
@ -327,6 +328,22 @@ void PairMoveEliminate(std::shared_ptr<FunctionDefItem> func, CFGType &cfg, Conf
}
}
}
void RemoveCallingConventionKeeper(std::shared_ptr<FunctionDefItem> func, CFGType &cfg, ConfGraph &confgraph) {
for (auto node : cfg.nodes) {
auto block = node->corresponding_block;
std::vector<std::list<std::shared_ptr<ActionItem>>::iterator> act_to_move;
for (auto it = block->actions.begin(); it != block->actions.end(); ++it) {
if (std::dynamic_pointer_cast<opt::ForceDef>(*it) || std::dynamic_pointer_cast<opt::ForceUse>(*it)) {
act_to_move.push_back(it);
}
}
for (auto it : act_to_move) {
block->actions.erase(it);
}
}
}
void ConductRegAllocForFunction(std::shared_ptr<FunctionDefItem> func) {
std::cerr << "processing function " << func->func_name_raw << std::endl;
CFGType cfg;
@ -342,7 +359,7 @@ void ConductRegAllocForFunction(std::shared_ptr<FunctionDefItem> func) {
confgraph = BuildConfGraph(cfg);
} while (ConductColoring(func, cfg, confgraph));
TranslateColorResult(func, cfg, confgraph);
// PairMoveEliminate(func, cfg, confgraph);
RemoveCallingConventionKeeper(func, cfg, confgraph);
func->RecursivePrint(std::cerr);
}