From c4ca99d7b138e07b0ce9395f58a165f136e0859d Mon Sep 17 00:00:00 2001 From: ZhuangYumin Date: Fri, 18 Oct 2024 10:14:04 +0000 Subject: [PATCH] basically write mem2reg --- include/IR/IR_basic.h | 107 ++++----------------- src/IR/IRBuilder.cpp | 2 +- src/opt/cfg.cpp | 11 ++- src/opt/mem2reg.cpp | 214 +++++++++++++++++++++++++++++++++++++++++- 4 files changed, 233 insertions(+), 101 deletions(-) diff --git a/include/IR/IR_basic.h b/include/IR/IR_basic.h index 5c52d78..98848a1 100644 --- a/include/IR/IR_basic.h +++ b/include/IR/IR_basic.h @@ -78,18 +78,11 @@ void GenerateASM(std::shared_ptr act, std::vector &code const std::unordered_map &low_level_class_info, bool process_phi); } // namespace NaiveBackend class BRAction : public JMPActionItem { - friend class IRBuilder; - friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); - friend class CFGType BuildCFGForFunction(const std::shared_ptr &func); + public: std::string cond; std::string true_label_full; std::string false_label_full; - public: BRAction() = default; void RecursivePrint(std::ostream &os) const { os << "br i1 " << cond << ", label %" << true_label_full << ", label %" << false_label_full << "\n"; @@ -111,17 +104,10 @@ class UNConditionJMPAction : public JMPActionItem { void RecursivePrint(std::ostream &os) const { os << "br label %" << label_full << "\n"; } }; class RETAction : public JMPActionItem { - friend class IRBuilder; - friend class FunctionDefItem; - friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: LLVMType type; std::string value; - public: RETAction() = default; void RecursivePrint(std::ostream &os) const { if (std::holds_alternative(type)) { @@ -137,20 +123,13 @@ class RETAction : public JMPActionItem { }; class BinaryOperationAction : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string op; std::string operand1_full; std::string operand2_full; std::string result_full; LLVMType type; - public: BinaryOperationAction() = default; void RecursivePrint(std::ostream &os) const { os << result_full << " = " << op << " "; @@ -167,14 +146,11 @@ class BinaryOperationAction : public ActionItem { } }; class AllocaAction : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); + public: std::string name_full; LLVMType type; size_t num; - public: AllocaAction() : num(1){}; void RecursivePrint(std::ostream &os) const { os << name_full << " = alloca "; @@ -192,18 +168,11 @@ class AllocaAction : public ActionItem { } }; class LoadAction : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string result_full; LLVMType ty; std::string ptr_full; - public: LoadAction() = default; void RecursivePrint(std::ostream &os) const { os << result_full << " = load "; @@ -218,18 +187,11 @@ class LoadAction : public ActionItem { } }; class StoreAction : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: LLVMType ty; std::string value_full; std::string ptr_full; - public: StoreAction() = default; void RecursivePrint(std::ostream &os) const { os << "store "; @@ -244,19 +206,12 @@ class StoreAction : public ActionItem { } }; class GetElementPtrAction : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string result_full; LLVMType ty; std::string ptr_full; std::vector indices; - public: GetElementPtrAction() = default; void RecursivePrint(std::ostream &os) const { os << result_full << " = getelementptr "; @@ -277,20 +232,13 @@ class GetElementPtrAction : public ActionItem { } }; class ICMPAction : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string op; std::string operand1_full; std::string operand2_full; std::string result_full; LLVMType type; - public: ICMPAction() = default; void RecursivePrint(std::ostream &os) const { os << result_full << " = icmp " << op << " "; @@ -305,20 +253,18 @@ class ICMPAction : public ActionItem { } }; class BlockItem : public LLVMIRItemBase { - friend class IRBuilder; - friend class FunctionDefItem; - friend void GenerateNaiveASM(std::ostream &os, std::shared_ptr prog); - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend class CFGType BuildCFGForFunction(const std::shared_ptr &func); + public: std::string label_full; + std::unordered_map> phi_map; // this is used to store phi items when optimizing std::vector> actions; std::shared_ptr exit_action; - public: BlockItem() = default; void RecursivePrint(std::ostream &os) const { os << label_full << ":\n"; + for (auto &kv : phi_map) { + std::static_pointer_cast(kv.second)->RecursivePrint(os); + } for (auto &action : actions) { action->RecursivePrint(os); } @@ -326,20 +272,13 @@ class BlockItem : public LLVMIRItemBase { } }; class CallItem : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string result_full; LLVMType return_type; std::string func_name_raw; std::vector args_ty; std::vector args_val_full; - public: CallItem() = default; void RecursivePrint(std::ostream &os) const { if (std::holds_alternative(return_type)) { @@ -380,17 +319,10 @@ class CallItem : public ActionItem { }; class PhiItem : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string result_full; LLVMType ty; std::vector> values; // (val_i_full, label_i_full) - public: PhiItem() = default; void RecursivePrint(std::ostream &os) const { os << result_full << " = phi "; @@ -412,20 +344,13 @@ class PhiItem : public ActionItem { } }; class SelectItem : public ActionItem { - friend class IRBuilder; - friend void NaiveBackend::ScanForVar(class NaiveBackend::FuncLayout &layout, std::shared_ptr action, - const std::unordered_map &low_level_class_info); - friend void NaiveBackend::GenerateASM(std::shared_ptr act, std::vector &code_lines, - NaiveBackend::FuncLayout &layout, - const std::unordered_map &low_level_class_info, - bool process_phi); + public: std::string result_full; std::string cond_full; std::string true_val_full; std::string false_val_full; LLVMType ty; - public: SelectItem() = default; void RecursivePrint(std::ostream &os) const { os << result_full << " = select i1 " << cond_full << ", "; diff --git a/src/IR/IRBuilder.cpp b/src/IR/IRBuilder.cpp index 724c16c..3d69f87 100644 --- a/src/IR/IRBuilder.cpp +++ b/src/IR/IRBuilder.cpp @@ -430,7 +430,7 @@ void IRBuilder::ActuralVisit(NewArrayExpr_ASTNode *node) { } auto dim_info = std::make_shared(); - std::string dim_info_var = "%.var.local.tmp." + std::to_string(tmp_var_counter++); + std::string dim_info_var = "%.var.tmp." + std::to_string(tmp_var_counter++); cur_alloca_block->actions.push_back(dim_info); dim_info->num = dims_with_size; dim_info->name_full = dim_info_var; diff --git a/src/opt/cfg.cpp b/src/opt/cfg.cpp index 0c16ef9..21e1e46 100644 --- a/src/opt/cfg.cpp +++ b/src/opt/cfg.cpp @@ -86,14 +86,17 @@ bool CFGNodeCollectionIsSame(const CFGNodeCollection &a, const CFGNodeCollection CFGType BuildCFGForFunction(const std::shared_ptr &func) { CFGType res; + auto init_block=func->init_block; if (!func->init_block) { - throw std::runtime_error("Function does not have an init block"); + // throw std::runtime_error("Function does not have an init block"); + if(func->basic_blocks.size()==0) throw std::runtime_error("Function does not have any block"); + init_block = func->basic_blocks[0]; } - res.label_to_block[func->init_block->label_full] = func->init_block.get(); + res.label_to_block[init_block->label_full] = init_block.get(); res.nodes.push_back(std::make_shared()); res.entry = res.nodes.back().get(); - res.entry->corresponding_block = func->init_block.get(); - res.block_to_node[func->init_block.get()] = res.entry; + res.entry->corresponding_block = init_block.get(); + res.block_to_node[init_block.get()] = res.entry; for (auto block_ptr : func->basic_blocks) { res.label_to_block[block_ptr->label_full] = block_ptr.get(); res.nodes.push_back(std::make_shared()); diff --git a/src/opt/mem2reg.cpp b/src/opt/mem2reg.cpp index 21f3afe..2d03eea 100644 --- a/src/opt/mem2reg.cpp +++ b/src/opt/mem2reg.cpp @@ -1,8 +1,10 @@ #include "mem2reg.h" +#include #include +#include "IR/IR_basic.h" #include "cfg.h" -void ConductMem2RegForFunction(const std::shared_ptr &func, const CFGType &cfg) { +void BuildDomForFunction(const std::shared_ptr &func, const CFGType &cfg) { bool all_dom_unchanged; CFGNodeCollection all_nodes; for (auto &node : cfg.nodes) { @@ -43,10 +45,10 @@ void ConductMem2RegForFunction(const std::shared_ptr &func, con cur->dom = new_dom; } } - } while (all_dom_unchanged); + } while (!all_dom_unchanged); for (auto node : cfg.nodes) { if (node.get() == cfg.entry) continue; - for (auto potential_predecessor : node->predecessors) { + for (auto potential_predecessor : node->dom) { if (potential_predecessor->dom.size() + 1 == node->dom.size()) { node->idom = potential_predecessor; node->idom->successors_in_dom_tree.push_back(node.get()); @@ -66,11 +68,213 @@ void ConductMem2RegForFunction(const std::shared_ptr &func, con frontier_node->dom_frontier.push_back(node.get()); } } + // debug + for (auto &node : cfg.nodes) { + std::cerr << node->corresponding_block->label_full << ":\n"; + std::cerr << "\tdom:"; + for (auto &dom_node : node->dom) { + std::cerr << ' ' << dom_node->corresponding_block->label_full; + } + if (node->idom) std::cerr << "\n\tidom: " << node->idom->corresponding_block->label_full; + std::cerr << "\n\tdom_frontier:"; + for (auto &frontier_node : node->dom_frontier) { + std::cerr << ' ' << frontier_node->corresponding_block->label_full; + } + std::cerr << "\n\tcfg pred:"; + for (auto &pred : node->predecessors) { + std::cerr << ' ' << pred->corresponding_block->label_full; + } + std::cerr << "\n\tsuccessors_in_dom_tree:"; + for (auto &succ : node->successors_in_dom_tree) { + std::cerr << ' ' << succ->corresponding_block->label_full; + } + std::cerr << std::endl; + } +} + +size_t InNodeReplace(CFGNodeType *cur_node, std::string origin_var_name, size_t &cur_version, + std::stack &name_stk) { + size_t versions_pushed = 0; + BlockItem *cur_block = cur_node->corresponding_block; + std::vector> new_actions; + std::unordered_set is_an_alias_generated_by_load; + if (cur_block->phi_map.find(origin_var_name) != cur_block->phi_map.end()) { + name_stk.push(cur_block->phi_map[origin_var_name]->result_full); + versions_pushed++; + } + for (auto act : cur_block->actions) { + if (std::dynamic_pointer_cast(act)) + throw std::runtime_error("JMPActionItem should not appear in actions"); + if (auto alloca_act = std::dynamic_pointer_cast(act)) { + if (alloca_act->name_full == origin_var_name) { + // do nothing, just erase it + } else { + new_actions.push_back(alloca_act); + } + } else if (auto bin_act = std::dynamic_pointer_cast(act)) { + if (is_an_alias_generated_by_load.find(bin_act->operand1_full) != is_an_alias_generated_by_load.end()) { + bin_act->operand1_full = name_stk.top(); + } + if (is_an_alias_generated_by_load.find(bin_act->operand2_full) != is_an_alias_generated_by_load.end()) { + bin_act->operand2_full = name_stk.top(); + } + new_actions.push_back(bin_act); + } else if (auto load_act = std::dynamic_pointer_cast(act)) { + if (load_act->ptr_full == origin_var_name) { + // remove it + is_an_alias_generated_by_load.insert(load_act->result_full); + } else { + new_actions.push_back(load_act); + } + } else if (auto store_act = std::dynamic_pointer_cast(act)) { + if (is_an_alias_generated_by_load.find(store_act->value_full) != is_an_alias_generated_by_load.end()) { + store_act->value_full = name_stk.top(); + } + if (store_act->ptr_full == origin_var_name) { + // remove it + name_stk.push(store_act->value_full); + versions_pushed++; + } else { + new_actions.push_back(store_act); + } + } else if (auto get_act = std::dynamic_pointer_cast(act)) { + new_actions.push_back(get_act); + if (is_an_alias_generated_by_load.find(get_act->ptr_full) != is_an_alias_generated_by_load.end()) { + get_act->ptr_full = name_stk.top(); + } + for (auto &idx : get_act->indices) { + if (is_an_alias_generated_by_load.find(idx) != is_an_alias_generated_by_load.end()) { + idx = name_stk.top(); + } + } + } else if (auto icmp_act = std::dynamic_pointer_cast(act)) { + if (is_an_alias_generated_by_load.find(icmp_act->operand1_full) != is_an_alias_generated_by_load.end()) { + icmp_act->operand1_full = name_stk.top(); + } + if (is_an_alias_generated_by_load.find(icmp_act->operand2_full) != is_an_alias_generated_by_load.end()) { + icmp_act->operand2_full = name_stk.top(); + } + new_actions.push_back(icmp_act); + } else if (auto phi_act = std::dynamic_pointer_cast(act)) { + new_actions.push_back(phi_act); + for (auto &val : phi_act->values) { + if (is_an_alias_generated_by_load.find(val.first) != is_an_alias_generated_by_load.end()) { + val.first = name_stk.top(); + } + } + } else if (auto call_act = std::dynamic_pointer_cast(act)) { + for (size_t i = 0; i < call_act->args_val_full.size(); i++) { + if (is_an_alias_generated_by_load.find(call_act->args_val_full[i]) != is_an_alias_generated_by_load.end()) { + call_act->args_val_full[i] = name_stk.top(); + } + } + new_actions.push_back(call_act); + } else if (auto select_act = std::dynamic_pointer_cast(act)) { + if (is_an_alias_generated_by_load.find(select_act->cond_full) != is_an_alias_generated_by_load.end()) { + select_act->cond_full = name_stk.top(); + } + if (is_an_alias_generated_by_load.find(select_act->true_val_full) != is_an_alias_generated_by_load.end()) { + select_act->true_val_full = name_stk.top(); + } + if (is_an_alias_generated_by_load.find(select_act->false_val_full) != is_an_alias_generated_by_load.end()) { + select_act->false_val_full = name_stk.top(); + } + new_actions.push_back(select_act); + } else { + throw std::runtime_error("Unknown action type"); + } + } + if (auto br_act = std::dynamic_pointer_cast(cur_block->exit_action)) { + if (is_an_alias_generated_by_load.find(br_act->cond) != is_an_alias_generated_by_load.end()) { + br_act->cond = name_stk.top(); + } + } else if (auto ret_act = std::dynamic_pointer_cast(cur_block->exit_action)) { + if (is_an_alias_generated_by_load.find(ret_act->value) != is_an_alias_generated_by_load.end()) { + ret_act->value = name_stk.top(); + } + } + cur_block->actions = new_actions; + return versions_pushed; +} +void DFSReplace(CFGNodeType *cur_node, std::string origin_var_name, size_t &cur_version, + std::stack &name_stk) { + // std::cerr << "DFSReplace: " << cur_node->corresponding_block->label_full << std::endl; + size_t versions_pushed = 0; + // step 1: process current node + versions_pushed = InNodeReplace(cur_node, origin_var_name, cur_version, name_stk); + // step 2: process the phi commands in the successors in cfg + for (auto succ : cur_node->successors) { + if (succ->corresponding_block->phi_map.find(origin_var_name) != succ->corresponding_block->phi_map.end()) { + auto phi = succ->corresponding_block->phi_map[origin_var_name]; + phi->values.push_back(std::make_pair(name_stk.top(), cur_node->corresponding_block->label_full)); + } + } + // step 3: process the successors in dom tree + for (auto succ : cur_node->successors_in_dom_tree) { + DFSReplace(succ, origin_var_name, cur_version, name_stk); + } + // step 4: restore the stack + for (size_t i = 0; i < versions_pushed; i++) { + name_stk.pop(); + } +} +void ConductMem2RegForFunction(const std::shared_ptr &func, const CFGType &cfg) { + BuildDomForFunction(func, cfg); + std::vector all_local_vars; + std::unordered_map var_to_version; + std::unordered_map var_to_type; + for (auto act : cfg.entry->corresponding_block->actions) { + if (auto alloca_act = std::dynamic_pointer_cast(act)) { + if (alloca_act->num == 1 && alloca_act->name_full.substr(0, 12) == "%.var.local.") { + all_local_vars.push_back(alloca_act->name_full); + var_to_version[alloca_act->name_full] = 0; + var_to_type[alloca_act->name_full] = alloca_act->type; + } + } + } + std::unordered_map> var_to_def_sites; + for (auto node : cfg.nodes) { + for (auto act : node->corresponding_block->actions) { + if (auto store_act = std::dynamic_pointer_cast(act)) { + if (var_to_version.find(store_act->ptr_full) != var_to_version.end()) { + var_to_def_sites[store_act->ptr_full].push_back(node.get()); + break; + } + } + } + } + for (const auto &var : all_local_vars) { + size_t &cur_version = var_to_version[var]; + std::queue Q; + for (auto def_site : var_to_def_sites[var]) { + Q.push(def_site); + } + while (Q.size() > 0) { + CFGNodeType *cur_node = Q.front(); + Q.pop(); + for (auto frontier : cur_node->dom_frontier) { + if (frontier->corresponding_block->phi_map.find(var) != frontier->corresponding_block->phi_map.end()) continue; + auto new_phi = std::make_shared(); + frontier->corresponding_block->phi_map[var] = new_phi; + new_phi->result_full = var + ".v" + std::to_string(++cur_version); + new_phi->ty = var_to_type[var]; + Q.push(frontier); + } + } + } + for (const auto &var : all_local_vars) { + size_t &cur_version = var_to_version[var]; + std::stack name_stk; + name_stk.push("0"); + // std::cerr << "processing " << var << std::endl; + DFSReplace(cfg.entry, var, cur_version, name_stk); + } } std::shared_ptr Mem2Reg(std::shared_ptr src) { - auto res = std::make_shared(*src); + // auto res = std::make_shared(*src); + auto res = src; for (auto &func : res->function_defs) { - func = std::make_shared(*func); + // func = std::make_shared(*func); auto cfg = BuildCFGForFunction(func); ConductMem2RegForFunction(func, cfg); }