// NOTE(review): as written here, the #include directives below carry no header
// names and all share a single line with `#pragma once`, and template argument
// lists (std::vector, std::holds_alternative, std::get, std::shared_ptr) are
// absent throughout the file. The code relies on std::ostream, std::string,
// std::vector, std::variant access, std::shared_ptr, std::stringstream and
// std::runtime_error -- confirm the intended headers and template arguments.
#pragma once #include #include #include #include #include #include #include #include #include #include #include "ast/astnode.h" #include "tools.h"

// Object model for emitting LLVM textual IR. Every item renders itself through
// RecursivePrint; the container items (ModuleItem, FunctionDefItem, BlockItem)
// forward the call to their children. Fields are private and have no setters;
// most classes befriend IRBuilder, which presumably populates them.

// Abstract base for every IR item.
class LLVMIRItemBase {
 public:
  LLVMIRItemBase() = default;
  virtual ~LLVMIRItemBase() = default;
  // Writes this item's LLVM IR text to `os`.
  virtual void RecursivePrint(std::ostream &os) const = 0;
};

// Named struct type definition, printed as
//   %.class.<class_name_raw> = type {e0,e1,...}
// Each element prints as `i<bits>`, `ptr` or `void`; an element held by value
// as a class type throws, because MX* classes are referenced through pointers.
class TypeDefItem : public LLVMIRItemBase {
  friend class IRBuilder;
  std::string class_name_raw;
  std::vector elements;

 public:
  TypeDefItem() = default;
  void RecursivePrint(std::ostream &os) const {
    os << "%.class." << class_name_raw;
    os << " = type {";
    for (size_t i = 0; i < elements.size(); i++) {
      if (std::holds_alternative(elements[i])) {
        os << "i" << std::get(elements[i]).bits;
      } else if (std::holds_alternative(elements[i])) {
        os << "ptr";
      } else if (std::holds_alternative(elements[i])) {
        os << "void";
      } else if (std::holds_alternative(elements[i])) {
        throw std::runtime_error("In MX* language, class types are referenced by pointers");
      }
      // Comma between elements, none after the last one.
      if (i != elements.size() - 1) {
        os << ",";
      }
    }
    os << "}\n";
  }
};

// Module-level global variable, always zero-initialized:
//   @.var.global.<name_raw>.addrkp = global i<bits> 0
//   @.var.global.<name_raw>.addrkp = global ptr null
// Only integer and pointer types are accepted; anything else throws.
class GlobalVarDefItem : public LLVMIRItemBase {
  friend class IRBuilder;
  LLVMType type;
  std::string name_raw;

 public:
  GlobalVarDefItem() = default;
  void RecursivePrint(std::ostream &os) const {
    std::string name_full = "@.var.global." + name_raw + ".addrkp";
    os << name_full << " = global ";
    if (std::holds_alternative(type)) {
      os << "i" << std::get(type).bits << " 0\n";
    } else if (std::holds_alternative(type)) {
      os << "ptr null\n";
    } else {
      throw std::runtime_error("something strange happened");
    }
  }
};

// Base for instruction-level items; adds nothing to LLVMIRItemBase.
class ActionItem : public LLVMIRItemBase {};
// Base for block terminators (BRAction, UNConditionJMPAction, RETAction).
class JMPActionItem : public ActionItem {};

// Conditional branch: `br i1 <cond>, label %<true>, label %<false>`.
// The '%' sigil is added here, so the *_label_full fields hold bare label
// names (the same form BlockItem prints before ':').
class BRAction : public JMPActionItem {
  friend class IRBuilder;
  std::string cond;
  std::string true_label_full;
  std::string false_label_full;

 public:
  BRAction() = default;
  void RecursivePrint(std::ostream &os) const {
    os << "br i1 " << cond << ", label %" << true_label_full << ", label %" << false_label_full << "\n";
  }
};

// Unconditional branch: `br label %<label_full>`; label_full is a bare name.
class UNConditionJMPAction : public JMPActionItem {
  friend class IRBuilder;
  std::string label_full;

 public:
  UNConditionJMPAction() = default;
  void RecursivePrint(std::ostream &os) const { os << "br label %" << label_full << "\n"; }
};

// Return terminator: `ret void`, `ret i<bits> <value>` or `ret ptr <value>`,
// chosen by `type`; `value` is ignored in the void case. Other types throw.
class RETAction : public JMPActionItem {
  friend class IRBuilder;
  LLVMType type;
  std::string value;

 public:
  RETAction() = default;
  void RecursivePrint(std::ostream &os) const {
    if (std::holds_alternative(type)) {
      os << "ret void\n";
    } else if (std::holds_alternative(type)) {
      os << "ret i" << std::get(type).bits << " " << value << "\n";
    } else if (std::holds_alternative(type)) {
      os << "ret ptr " << value << "\n";
    } else {
      throw std::runtime_error("something strange happened");
    }
  }
};

// Binary instruction:
//   <result_full> = <op> <type> <operand1_full>, <operand2_full>
// `op` is emitted verbatim (presumably an LLVM opcode such as add or sdiv).
// NOTE(review): a void `type` is printed as `void`, which LLVM does not accept
// as a binary-operation type; presumably that branch is never reached.
class BinaryOperationAction : public ActionItem {
  friend class IRBuilder;
  std::string op;
  std::string operand1_full;
  std::string operand2_full;
  std::string result_full;
  LLVMType type;

 public:
  BinaryOperationAction() = default;
  void RecursivePrint(std::ostream &os) const {
    os << result_full << " = " << op << " ";
    if (std::holds_alternative(type)) {
      os << "i" << std::get(type).bits;
    } else if (std::holds_alternative(type)) {
      os << "ptr";
    } else if (std::holds_alternative(type)) {
      os << "void";
    } else if (std::holds_alternative(type)) {
      throw std::runtime_error("In MX* language, class types are referenced by pointers");
    }
    os << " " << operand1_full << ", " << operand2_full << "\n";
  }
};

// Stack allocation: `<name_full> = alloca <type>` plus `, i32 <num>` when
// num > 1. num defaults to 1. Only integer and pointer types are accepted.
class AllocaAction : public ActionItem {
  friend class IRBuilder;
  std::string name_full;
  LLVMType type;
  size_t num;

 public:
  AllocaAction() : num(1){};
  void RecursivePrint(std::ostream &os) const {
    os << name_full << " = alloca ";
    if (std::holds_alternative(type)) {
      os << "i" << std::get(type).bits;
    } else if (std::holds_alternative(type)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    // The element count is optional in LLVM and only written when > 1.
    if (num > 1) {
      os << ", i32 " << num;
    }
    os << "\n";
  }
};

// Load: `<result_full> = load <ty>, ptr <ptr_full>`; integer or pointer ty only.
class LoadAction : public ActionItem {
  friend class IRBuilder;
  std::string result_full;
  LLVMType ty;
  std::string ptr_full;

 public:
  LoadAction() = default;
  void RecursivePrint(std::ostream &os) const {
    os << result_full << " = load ";
    if (std::holds_alternative(ty)) {
      os << "i" << std::get(ty).bits;
    } else if (std::holds_alternative(ty)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << ", ptr " << ptr_full << '\n';
  }
};

// Store: `store <ty> <value_full>, ptr <ptr_full>`; integer or pointer ty only.
class StoreAction : public ActionItem {
  friend class IRBuilder;
  LLVMType ty;
  std::string value_full;
  std::string ptr_full;

 public:
  StoreAction() = default;
  void RecursivePrint(std::ostream &os) const {
    os << "store ";
    if (std::holds_alternative(ty)) {
      os << "i" << std::get(ty).bits;
    } else if (std::holds_alternative(ty)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << ' ' << value_full << ", ptr " << ptr_full << '\n';
  }
};

// Address computation:
//   <result_full> = getelementptr <ty>, ptr <ptr_full>, i32 <idx>, ...
// Unlike most items, `ty` may be a class type, printed via its class_name_full.
// Every index is emitted as an i32 operand.
class GetElementPtrAction : public ActionItem {
  friend class IRBuilder;
  std::string result_full;
  LLVMType ty;
  std::string ptr_full;
  std::vector indices;

 public:
  GetElementPtrAction() = default;
  void RecursivePrint(std::ostream &os) const {
    os << result_full << " = getelementptr ";
    if (std::holds_alternative(ty)) {
      os << "i" << std::get(ty).bits;
    } else if (std::holds_alternative(ty)) {
      os << "ptr";
    } else if (std::holds_alternative(ty)) {
      os << std::get(ty).class_name_full;
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << ", ptr " << ptr_full;
    for (auto &index : indices) {
      os << ", i32 " << index;
    }
    os << '\n';
  }
};

// Integer comparison:
//   <result_full> = icmp <op> <type> <operand1_full>, <operand2_full>
// `op` is the predicate text, emitted verbatim (presumably eq, ne, slt, ...).
class ICMPAction : public ActionItem {
  friend class IRBuilder;
  std::string op;
  std::string operand1_full;
  std::string operand2_full;
  std::string result_full;
  LLVMType type;

 public:
  ICMPAction() = default;
  void RecursivePrint(std::ostream &os) const {
    os << result_full << " = icmp " << op << " ";
    if (std::holds_alternative(type)) {
      os << "i" << std::get(type).bits;
    } else if (std::holds_alternative(type)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << ' ' << operand1_full << ", " << operand2_full << '\n';
  }
};

// Basic block. Prints `<label_full>:` (bare name, no '%'), then every action in
// order, then the terminator exit_action if one is set.
// NOTE(review): when exit_action is null the block is printed without a
// terminator, which LLVM rejects; confirm IRBuilder always sets it.
class BlockItem : public LLVMIRItemBase {
  friend class IRBuilder;
  std::string label_full;
  std::vector> actions;
  std::shared_ptr exit_action;

 public:
  BlockItem() = default;
  void RecursivePrint(std::ostream &os) const {
    os << label_full << ":\n";
    for (auto &action : actions) {
      action->RecursivePrint(os);
    }
    if (exit_action) exit_action->RecursivePrint(os);
  }
};

// Function call:
//   [<result_full> = ]call <ret> @<func_name_raw>(<ty> <val>, ...)
// The `<result_full> =` prefix is skipped for the return-type alternative
// tested first (presumably void, whose result LLVM does not allow to be named).
// args_ty[i] is the type of args_val_full[i]; void- or class-typed arguments
// throw.
// NOTE(review): the loop is bounded by args_val_full.size() and indexes
// args_ty unchecked, so the two vectors must have equal length.
class CallItem : public ActionItem {
  friend class IRBuilder;
  std::string result_full;
  LLVMType return_type;
  std::string func_name_raw;
  std::vector args_ty;
  std::vector args_val_full;

 public:
  CallItem() = default;
  void RecursivePrint(std::ostream &os) const {
    if (std::holds_alternative(return_type)) {
      os << "call ";
    } else {
      os << result_full << " = call ";
    }
    if (std::holds_alternative(return_type)) {
      os << "i" << std::get(return_type).bits;
    } else if (std::holds_alternative(return_type)) {
      os << "ptr";
    } else if (std::holds_alternative(return_type)) {
      os << "void";
    } else if (std::holds_alternative(return_type)) {
      throw std::runtime_error("In MX* language, class types are referenced by pointers");
    }
    os << " @" << func_name_raw << "(";
    for (size_t i = 0; i < args_val_full.size(); i++) {
      auto &ty = args_ty[i];
      if (std::holds_alternative(ty)) {
        os << "i" << std::get(ty).bits;
      } else if (std::holds_alternative(ty)) {
        os << "ptr";
      } else if (std::holds_alternative(ty)) {
        throw std::runtime_error("void type is not allowed in function call");
      } else if (std::holds_alternative(ty)) {
        throw std::runtime_error("In MX* language, class types are referenced by pointers");
      } else {
        throw std::runtime_error("something strange happened");
      }
      os << ' ' << args_val_full[i];
      if (i != args_val_full.size() - 1) {
        os << ", ";
      }
    }
    os << ")\n";
  }
};

// SSA merge: `<result_full> = phi <ty>  [<val>, <label>], ...`, where `values`
// holds (incoming value, incoming label) pairs.
// NOTE(review): BRAction/UNConditionJMPAction prepend '%' to label names, but
// the incoming label here is printed verbatim; if label_i_full follows that
// bare-name convention the emitted phi is invalid (LLVM expects
// `[ %val, %label ]`) -- confirm.
// NOTE(review): unlike its siblings this class does not declare
// `friend class IRBuilder` and has no setters, so nothing visible here can
// populate its private fields.
class PhiItem : public ActionItem {
  std::string result_full;
  LLVMType ty;
  std::vector> values;  // (val_i_full, label_i_full)

 public:
  PhiItem() = default;
  void RecursivePrint(std::ostream &os) const {
    os << result_full << " = phi ";
    if (std::holds_alternative(ty)) {
      os << "i" << std::get(ty).bits;
    } else if (std::holds_alternative(ty)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << " ";
    for (size_t i = 0; i < values.size(); i++) {
      os << " [" << values[i].first << ", " << values[i].second << "]";
      if (i != values.size() - 1) {
        os << ", ";
      }
    }
    os << "\n";
  }
};

// Value selection:
//   <result_full> = select i1 <cond_full>, <ty> <true_val_full>, <ty><false_val_full>
// The type is printed once for each of the two value operands.
// NOTE(review): no space is written between the second type and
// false_val_full (yields e.g. `i32%x`); confirm the consumer accepts it.
// NOTE(review): like PhiItem, this class has no `friend class IRBuilder`.
class SelectItem : public ActionItem {
  std::string result_full;
  std::string cond_full;
  std::string true_val_full;
  std::string false_val_full;
  LLVMType ty;

 public:
  SelectItem() = default;
  void RecursivePrint(std::ostream &os) const {
    os << result_full << " = select i1 " << cond_full << ", ";
    if (std::holds_alternative(ty)) {
      os << "i" << std::get(ty).bits;
    } else if (std::holds_alternative(ty)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << " " << true_val_full << ", ";
    if (std::holds_alternative(ty)) {
      os << "i" << std::get(ty).bits;
    } else if (std::holds_alternative(ty)) {
      os << "ptr";
    } else {
      throw std::runtime_error("something strange happened");
    }
    os << false_val_full << "\n";
  }
};

// Function definition: `define <ret> @<func_name_raw>(<ty> <name>,...)`, then
// the basic blocks between `{` and `}` lines. args[i] is the type of the
// parameter named args_full_name[i].
// NOTE(review): args_full_name is indexed by args' index unchecked, so both
// vectors must have equal length.
class FunctionDefItem : public LLVMIRItemBase {
  friend class IRBuilder;
  LLVMType return_type;
  std::string func_name_raw;
  std::vector args;
  std::vector args_full_name;
  std::vector> basic_blocks;

 public:
  FunctionDefItem() = default;
  void RecursivePrint(std::ostream &os) const {
    os << "define ";
    if (std::holds_alternative(return_type)) {
      os << "i" << std::get(return_type).bits;
    } else if (std::holds_alternative(return_type)) {
      os << "ptr";
    } else if (std::holds_alternative(return_type)) {
      os << "void";
    } else if (std::holds_alternative(return_type)) {
      throw std::runtime_error("In MX* language, class types are referenced by pointers");
    }
    os << " @" << func_name_raw << "(";
    for (size_t i = 0; i < args.size(); i++) {
      if (std::holds_alternative(args[i])) {
        os << "i" << std::get(args[i]).bits;
      } else if (std::holds_alternative(args[i])) {
        os << "ptr";
      } else if (std::holds_alternative(args[i])) {
        os << "void";
      } else if (std::holds_alternative(args[i])) {
        throw std::runtime_error("In MX* language, class types are referenced by pointers");
      }
      os << ' ' << args_full_name[i];
      if (i != args.size() - 1) {
        os << ",";
      }
    }
    os << ")\n{\n";
    for (auto &item : basic_blocks) {
      item->RecursivePrint(os);
    }
    os << "}\n";
  }
};

// External function declaration: `declare <ret> @<func_name_raw>(<ty>,...)`,
// with parameter types only and no names. BuildIR is befriended in addition
// to IRBuilder.
class FunctionDeclareItem : public LLVMIRItemBase {
  friend class IRBuilder;
  friend std::shared_ptr BuildIR(std::shared_ptr src);
  LLVMType return_type;
  std::string func_name_raw;
  std::vector args;

 public:
  FunctionDeclareItem() = default;
  void RecursivePrint(std::ostream &os) const {
    os << "declare ";
    if (std::holds_alternative(return_type)) {
      os << "i" << std::get(return_type).bits;
    } else if (std::holds_alternative(return_type)) {
      os << "ptr";
    } else if (std::holds_alternative(return_type)) {
      os << "void";
    } else if (std::holds_alternative(return_type)) {
      throw std::runtime_error("In MX* language, class types are referenced by pointers");
    }
    os << " @" << func_name_raw << "(";
    for (size_t i = 0; i < args.size(); i++) {
      if (std::holds_alternative(args[i])) {
        os << "i" << std::get(args[i]).bits;
      } else if (std::holds_alternative(args[i])) {
        os << "ptr";
      } else if (std::holds_alternative(args[i])) {
        os << "void";
      } else if (std::holds_alternative(args[i])) {
        throw std::runtime_error("In MX* language, class types are referenced by pointers");
      }
      if (i != args.size() - 1) {
        os << ",";
      }
    }
    os << ")\n";
  }
};

// Private, NUL-terminated string constant named `@.str.<const_str_id>`.
class ConstStrItem : public LLVMIRItemBase {
  friend class IRBuilder;
  std::string string_raw;
  size_t const_str_id;
  // Rewrites newline, tab, carriage return, double quote and backslash as
  // LLVM hex escapes (for example a newline becomes 0A after a backslash);
  // every other byte is copied unchanged.
  static std::string Escape(const std::string &src) {
    std::stringstream ss;
    for (auto &ch : src) {
      if (ch == '\n') {
        ss << "\\0A";
      } else if (ch == '\t') {
        ss << "\\09";
      } else if (ch == '\r') {
        ss << "\\0D";
      } else if (ch == '\"') {
        ss << "\\22";
      } else if (ch == '\\') {
        ss << "\\5C";
      } else {
        ss << ch;
      }
    }
    return ss.str();
  }

 public:
  ConstStrItem() = default;
  // The array length is the raw byte count plus one for the trailing NUL.
  // Escaping changes only the spelling, not the number of bytes, so the
  // length is computed from string_raw rather than from the escaped text.
  void RecursivePrint(std::ostream &os) const {
    os << "@.str." << const_str_id << " = private unnamed_addr constant [" << string_raw.size() + 1 << " x i8] c\""
       << Escape(string_raw) << "\\00\"\n";
  }
};

// Whole module. Prints, in this order: string constants, function
// declarations, struct type definitions, global variables and function
// definitions. Each type definition, global and function definition is
// followed by an extra newline.
class ModuleItem : public LLVMIRItemBase {
  friend class IRBuilder;
  friend std::shared_ptr BuildIR(std::shared_ptr src);
  std::vector> const_strs;
  std::vector> function_declares;
  std::vector> type_defs;
  std::vector> global_var_defs;
  std::vector> function_defs;

 public:
  ModuleItem() = default;
  void RecursivePrint(std::ostream &os) const {
    for (auto &item : const_strs) {
      item->RecursivePrint(os);
    }
    for (auto &item : function_declares) {
      item->RecursivePrint(os);
    }
    for (auto &item : type_defs) {
      item->RecursivePrint(os);
      os << '\n';
    }
    for (auto &item : global_var_defs) {
      item->RecursivePrint(os);
      os << '\n';
    }
    for (auto &item : function_defs) {
      item->RecursivePrint(os);
      os << '\n';
    }
  }
};