diff --git a/include/IR/IRBuilder.h b/include/IR/IRBuilder.h index 4001b9f..fa0816d 100644 --- a/include/IR/IRBuilder.h +++ b/include/IR/IRBuilder.h @@ -1,10 +1,31 @@ +#pragma once #include #include "IR_basic.h" #include "ast/astnode_visitor.h" class IRBuilder : public ASTNodeVirturalVisitor { + friend std::shared_ptr BuildIR(std::shared_ptr src); std::shared_ptr prog; + std::shared_ptr cur_class; + std::shared_ptr cur_func; + std::shared_ptr cur_block; + std::string cur_class_name; + bool is_in_class_def; + bool is_in_func_def; + size_t tmp_var_counter; + size_t block_counter; + std::string cur_break_target; + std::string cur_continue_target; + bool just_encountered_jmp; + std::shared_ptr global_scope; public: + IRBuilder() { + tmp_var_counter = 0; + block_counter = 0; + is_in_class_def = false; + is_in_func_def = false; + just_encountered_jmp = false; + } // Structural AST Nodes void ActuralVisit(FuncDef_ASTNode *node) override; void ActuralVisit(ClassDef_ASTNode *node) override; diff --git a/include/IR/IR_basic.h b/include/IR/IR_basic.h index d9f6fe5..bf5b330 100644 --- a/include/IR/IR_basic.h +++ b/include/IR/IR_basic.h @@ -1,132 +1,458 @@ #pragma once +#include #include +#include #include +#include #include #include #include #include "ast/astnode.h" -struct LLVMIRIntType { - size_t bits; -}; -struct LLVMIRPTRType {}; -struct LLVMIRCLASSTYPE {}; -using LLVMType = std::variant; +#include "tools.h" + class LLVMIRItemBase { public: LLVMIRItemBase() = default; virtual ~LLVMIRItemBase() = default; - virtual void RecursivePrint(std::ostream &os) const; + virtual void RecursivePrint(std::ostream &os) const = 0; }; class TypeDefItem : public LLVMIRItemBase { + friend class IRBuilder; + std::string class_name_raw; + std::vector elements; + public: - void RecursivePrint(std::ostream &os) const { ; } + TypeDefItem() = default; + void RecursivePrint(std::ostream &os) const { + os << "%.class." << class_name_raw; + os << " = type {"; + for (size_t i = 0; i < elements.size(); i++) { + if (std::holds_alternative(elements[i])) { + os << "i" << std::get(elements[i]).bits; + } else if (std::holds_alternative(elements[i])) { + os << "ptr"; + } else if (std::holds_alternative(elements[i])) { + os << "void"; + } else if (std::holds_alternative(elements[i])) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + if (i != elements.size() - 1) { + os << ","; + } + } + os << "}\n"; + } }; class GlobalVarDefItem : public LLVMIRItemBase { + friend class IRBuilder; LLVMType type; - std::string name; - - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class ActionItem : public LLVMIRItemBase { - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class JMPActionItem : public ActionItem { - std::string label; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class BRAction: public JMPActionItem { - std::string cond; - std::string true_label; - std::string false_label; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class UNConditionJMPAction: public JMPActionItem { - std::string label; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class RETAction : public JMPActionItem { - std::string value; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class BinaryOperationAction : public ActionItem { - std::string op; - std::string lhs; - std::string rhs; - std::string result; - LLVMType type; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class AllocaAction : public ActionItem { - std::string name; - LLVMType type; - size_t num; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class LoadAction : public ActionItem { - std::string result; - LLVMType ty; - std::string ptr; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class StoreAction : public ActionItem { - LLVMType ty; - std::string value; - std::string ptr; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class GetElementPtrAction : public ActionItem { - std::string result; - LLVMType ty; - std::string ptr; - std::vector indices; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class ICMPAction : public ActionItem { - std::string op; - std::string lhs; - std::string rhs; - std::string result; - LLVMType ty; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class BlockItem : public LLVMIRItemBase { - std::string label; - std::vector> actions; - std::shared_ptr exit_action; - public: - void RecursivePrint(std::ostream &os) const { ; } -}; -class FunctionDefItem : public LLVMIRItemBase { - std::vector> basic_blocks; + std::string name_raw; public: + GlobalVarDefItem() = default; void RecursivePrint(std::ostream &os) const { - for (auto &item : basic_blocks) { - item->RecursivePrint(os); - os << '\n'; + std::string name_full = "@.var.global." + name_raw + ".addrkp"; + os << name_full << " = global "; + if (std::holds_alternative(type)) { + os << "i" << std::get(type).bits << " 0\n"; + } else if (std::holds_alternative(type)) { + os << "ptr null\n"; + } else { + throw std::runtime_error("something strange happened"); } } }; +class ActionItem : public LLVMIRItemBase {}; +class JMPActionItem : public ActionItem {}; +class BRAction : public JMPActionItem { + friend class IRBuilder; + std::string cond; + std::string true_label_full; + std::string false_label_full; + + public: + BRAction() = default; + void RecursivePrint(std::ostream &os) const { + os << "br i1 " << cond << ", label %" << true_label_full << ", label %" << false_label_full << "\n"; + } +}; +class UNConditionJMPAction : public JMPActionItem { + friend class IRBuilder; + std::string label_full; + + public: + UNConditionJMPAction() = default; + void RecursivePrint(std::ostream &os) const { os << "br label %" << label_full << "\n"; } +}; +class RETAction : public JMPActionItem { + friend class IRBuilder; + LLVMType type; + std::string value; + + public: + RETAction() = default; + void RecursivePrint(std::ostream &os) const { + if (std::holds_alternative(type)) { + os << "ret void\n"; + } else if (std::holds_alternative(type)) { + os << "ret i" << std::get(type).bits << " " << value << "\n"; + } else if (std::holds_alternative(type)) { + os << "ret ptr " << value << "\n"; + } else { + throw std::runtime_error("something strange happened"); + } + } +}; +class BinaryOperationAction : public ActionItem { + friend class IRBuilder; + std::string op; + std::string operand1_full; + std::string operand2_full; + std::string result_full; + LLVMType type; + + public: + BinaryOperationAction() = default; + void RecursivePrint(std::ostream &os) const { + os << result_full << " = " << op << " "; + if (std::holds_alternative(type)) { + os << "i" << std::get(type).bits; + } else if (std::holds_alternative(type)) { + os << "ptr"; + } else if (std::holds_alternative(type)) { + os << "void"; + } else if (std::holds_alternative(type)) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + os << " " << operand1_full << ", " << operand2_full << "\n"; + } +}; +class AllocaAction : public ActionItem { + friend class IRBuilder; + std::string name_full; + LLVMType type; + size_t num; + + public: + AllocaAction() : num(1){}; + void RecursivePrint(std::ostream &os) const { + os << name_full << " = alloca "; + if (std::holds_alternative(type)) { + os << "i" << std::get(type).bits; + } else if (std::holds_alternative(type)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + if (num > 1) { + os << ", i32 " << num; + } + os << "\n"; + } +}; +class LoadAction : public ActionItem { + friend class IRBuilder; + std::string result_full; + LLVMType ty; + std::string ptr_full; + + public: + LoadAction() = default; + void RecursivePrint(std::ostream &os) const { + os << result_full << " = load "; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + os << ", ptr " << ptr_full << '\n'; + } +}; +class StoreAction : public ActionItem { + friend class IRBuilder; + LLVMType ty; + std::string value_full; + std::string ptr_full; + + public: + StoreAction() = default; + void RecursivePrint(std::ostream &os) const { + os << "store "; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + os << ' ' << value_full << ", ptr " << ptr_full << '\n'; + } +}; +class GetElementPtrAction : public ActionItem { + std::string result_full; + LLVMType ty; + std::string ptr_full; + std::vector indices; + + public: + GetElementPtrAction() = default; + void RecursivePrint(std::ostream &os) const { + os << result_full << " = getelementptr "; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else if (std::holds_alternative(ty)) { + os << std::get(ty).class_name_full; + } else { + throw std::runtime_error("something strange happened"); + } + os << ", ptr " << ptr_full; + for (auto &index : indices) { + os << ", i32 " << index; + } + os << '\n'; + } +}; +class ICMPAction : public ActionItem { + friend class IRBuilder; + std::string op; + std::string operand1_full; + std::string operand2_full; + std::string result_full; + LLVMType type; + + public: + ICMPAction() = default; + void RecursivePrint(std::ostream &os) const { + os << result_full << " = icmp " << op << " "; + if (std::holds_alternative(type)) { + os << "i" << std::get(type).bits; + } else if (std::holds_alternative(type)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + os << ' ' << operand1_full << ", " << operand2_full << '\n'; + } +}; +class BlockItem : public LLVMIRItemBase { + friend class IRBuilder; + std::string label_full; + std::vector> actions; + std::shared_ptr exit_action; + + public: + BlockItem() = default; + void RecursivePrint(std::ostream &os) const { + os << label_full << ":\n"; + for (auto &action : actions) { + action->RecursivePrint(os); + } + if (exit_action) exit_action->RecursivePrint(os); + } +}; +class CallItem : public ActionItem { + friend class IRBuilder; + std::string result_full; + LLVMType return_type; + std::string func_name_raw; + std::vector args_ty; + std::vector args_val_full; + + public: + CallItem() = default; + void RecursivePrint(std::ostream &os) const { + if (std::holds_alternative(return_type)) { + os << "call "; + } else { + os << result_full << " = call "; + } + if (std::holds_alternative(return_type)) { + os << "i" << std::get(return_type).bits; + } else if (std::holds_alternative(return_type)) { + os << "ptr"; + } else if (std::holds_alternative(return_type)) { + os << "void"; + } else if (std::holds_alternative(return_type)) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + os << " @" << func_name_raw << "("; + for (size_t i = 0; i < args_val_full.size(); i++) { + auto &ty = args_ty[i]; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else if (std::holds_alternative(ty)) { + throw std::runtime_error("void type is not allowed in function call"); + } else if (std::holds_alternative(ty)) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } else { + throw std::runtime_error("something strange happened"); + } + os << ' ' << args_val_full[i]; + if (i != args_val_full.size() - 1) { + os << ", "; + } + } + os << ")\n"; + } +}; + +class PhiItem : public ActionItem { + std::string result_full; + LLVMType ty; + std::vector> values; // (val_i_full, label_i_full) + public: + PhiItem() = default; + void RecursivePrint(std::ostream &os) const { + os << result_full << " = phi "; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + os << " "; + for (size_t i = 0; i < values.size(); i++) { + os << " [" << values[i].first << ", " << values[i].second << "]"; + if (i != values.size() - 1) { + os << ", "; + } + } + os << "\n"; + } +}; +class SelectItem : public ActionItem { + std::string result_full; + std::string cond_full; + std::string true_val_full; + std::string false_val_full; + LLVMType ty; + + public: + SelectItem() = default; + void RecursivePrint(std::ostream &os) const { + os << result_full << " = select i1 " << cond_full << ", "; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + os << " " << true_val_full << ", "; + if (std::holds_alternative(ty)) { + os << "i" << std::get(ty).bits; + } else if (std::holds_alternative(ty)) { + os << "ptr"; + } else { + throw std::runtime_error("something strange happened"); + } + os << false_val_full << "\n"; + } +}; +class FunctionDefItem : public LLVMIRItemBase { + friend class IRBuilder; + LLVMType return_type; + std::string func_name_raw; + std::vector args; + std::vector args_full_name; + std::vector> basic_blocks; + + public: + FunctionDefItem() = default; + void RecursivePrint(std::ostream &os) const { + os << "define "; + if (std::holds_alternative(return_type)) { + os << "i" << std::get(return_type).bits; + } else if (std::holds_alternative(return_type)) { + os << "ptr"; + } else if (std::holds_alternative(return_type)) { + os << "void"; + } else if (std::holds_alternative(return_type)) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + os << " @" << func_name_raw << "("; + for (size_t i = 0; i < args.size(); i++) { + if (std::holds_alternative(args[i])) { + os << "i" << std::get(args[i]).bits; + } else if (std::holds_alternative(args[i])) { + os << "ptr"; + } else if (std::holds_alternative(args[i])) { + os << "void"; + } else if (std::holds_alternative(args[i])) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + os << ' ' << args_full_name[i]; + if (i != args.size() - 1) { + os << ","; + } + } + os << ")\n{\n"; + for (auto &item : basic_blocks) { + item->RecursivePrint(os); + } + os << "}\n"; + } +}; +class FunctionDeclareItem : public LLVMIRItemBase { + friend class IRBuilder; + friend std::shared_ptr BuildIR(std::shared_ptr src); + LLVMType return_type; + std::string func_name_raw; + std::vector args; + + public: + FunctionDeclareItem() = default; + void RecursivePrint(std::ostream &os) const { + os << "declare "; + if (std::holds_alternative(return_type)) { + os << "i" << std::get(return_type).bits; + } else if (std::holds_alternative(return_type)) { + os << "ptr"; + } else if (std::holds_alternative(return_type)) { + os << "void"; + } else if (std::holds_alternative(return_type)) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + os << " @" << func_name_raw << "("; + for (size_t i = 0; i < args.size(); i++) { + if (std::holds_alternative(args[i])) { + os << "i" << std::get(args[i]).bits; + } else if (std::holds_alternative(args[i])) { + os << "ptr"; + } else if (std::holds_alternative(args[i])) { + os << "void"; + } else if (std::holds_alternative(args[i])) { + throw std::runtime_error("In MX* language, class types are referenced by pointers"); + } + if (i != args.size() - 1) { + os << ","; + } + } + os << ")\n"; + } +}; class ModuleItem : public LLVMIRItemBase { + friend class IRBuilder; + friend std::shared_ptr BuildIR(std::shared_ptr src); + std::vector> function_declares; std::vector> type_defs; std::vector> global_var_defs; std::vector> function_defs; public: + ModuleItem() = default; void RecursivePrint(std::ostream &os) const { + for (auto &item : function_declares) { + item->RecursivePrint(os); + } for (auto &item : type_defs) { item->RecursivePrint(os); os << '\n'; diff --git a/include/ast/astnode.h b/include/ast/astnode.h index 546f106..e90edfe 100644 --- a/include/ast/astnode.h +++ b/include/ast/astnode.h @@ -10,6 +10,7 @@ class ASTNodeBase { friend Visitor; friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); + friend std::shared_ptr BuildIR(std::shared_ptr src); protected: std::shared_ptr current_scope; diff --git a/include/ast/expr_astnode.h b/include/ast/expr_astnode.h index 9b193a8..1f5473c 100644 --- a/include/ast/expr_astnode.h +++ b/include/ast/expr_astnode.h @@ -8,11 +8,14 @@ class Expr_ASTNode : public ASTNodeBase { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; ExprTypeInfo expr_type_info; bool assignable; + std::string IR_result_full; + bool is_requiring_lvalue; public: - Expr_ASTNode() : assignable(false){}; + Expr_ASTNode() : assignable(false), is_requiring_lvalue(false){}; virtual ~Expr_ASTNode() = default; }; @@ -21,6 +24,7 @@ class BasicExpr_ASTNode : public Expr_ASTNode {}; // This is a virtual class class NewArrayExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; bool has_initial_value; std::vector> dim_size; std::shared_ptr initial_value; @@ -33,6 +37,7 @@ class NewArrayExpr_ASTNode : public Expr_ASTNode { class NewConstructExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; public: NewConstructExpr_ASTNode() = default; @@ -42,6 +47,7 @@ class NewConstructExpr_ASTNode : public Expr_ASTNode { class NewExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; public: NewExpr_ASTNode() = default; @@ -51,6 +57,7 @@ class NewExpr_ASTNode : public Expr_ASTNode { class AccessExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr base; IdentifierType member; bool is_function; @@ -64,6 +71,7 @@ class AccessExpr_ASTNode : public Expr_ASTNode { class IndexExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr base; std::vector> indices; @@ -75,6 +83,7 @@ class IndexExpr_ASTNode : public Expr_ASTNode { class SuffixExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr base; @@ -86,6 +95,7 @@ class SuffixExpr_ASTNode : public Expr_ASTNode { class PrefixExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr base; @@ -97,6 +107,7 @@ class PrefixExpr_ASTNode : public Expr_ASTNode { class OppositeExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr base; public: @@ -107,6 +118,7 @@ class OppositeExpr_ASTNode : public Expr_ASTNode { class LNotExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr base; public: @@ -117,6 +129,7 @@ class LNotExpr_ASTNode : public Expr_ASTNode { class BNotExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr base; public: @@ -127,6 +140,7 @@ class BNotExpr_ASTNode : public Expr_ASTNode { class MDMExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -139,6 +153,7 @@ class MDMExpr_ASTNode : public Expr_ASTNode { class PMExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -151,6 +166,7 @@ class PMExpr_ASTNode : public Expr_ASTNode { class RLExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -163,6 +179,7 @@ class RLExpr_ASTNode : public Expr_ASTNode { class GGLLExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -175,6 +192,7 @@ class GGLLExpr_ASTNode : public Expr_ASTNode { class NEExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -187,6 +205,7 @@ class NEExpr_ASTNode : public Expr_ASTNode { class BAndExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -199,6 +218,7 @@ class BAndExpr_ASTNode : public Expr_ASTNode { class BXorExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -211,6 +231,7 @@ class BXorExpr_ASTNode : public Expr_ASTNode { class BOrExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -223,6 +244,7 @@ class BOrExpr_ASTNode : public Expr_ASTNode { class LAndExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -235,6 +257,7 @@ class LAndExpr_ASTNode : public Expr_ASTNode { class LOrExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr left; std::shared_ptr right; @@ -247,6 +270,7 @@ class LOrExpr_ASTNode : public Expr_ASTNode { class TernaryExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr condition; std::shared_ptr src1; std::shared_ptr src2; @@ -259,6 +283,7 @@ class TernaryExpr_ASTNode : public Expr_ASTNode { class AssignExpr_ASTNode : public Expr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::string op; std::shared_ptr dest; std::shared_ptr src; @@ -271,6 +296,7 @@ class AssignExpr_ASTNode : public Expr_ASTNode { class ThisExpr_ASTNode : public BasicExpr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; public: ThisExpr_ASTNode() = default; @@ -280,6 +306,7 @@ class ThisExpr_ASTNode : public BasicExpr_ASTNode { class ParenExpr_ASTNode : public BasicExpr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr expr; public: @@ -290,6 +317,7 @@ class ParenExpr_ASTNode : public BasicExpr_ASTNode { class IDExpr_ASTNode : public BasicExpr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; IdentifierType id; public: @@ -300,6 +328,7 @@ class IDExpr_ASTNode : public BasicExpr_ASTNode { class FunctionCallExpr_ASTNode : public BasicExpr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; IdentifierType func_name; std::vector> arguments; @@ -311,6 +340,7 @@ class FunctionCallExpr_ASTNode : public BasicExpr_ASTNode { class FormattedStringExpr_ASTNode : public BasicExpr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::vector literals; std::vector> exprs; @@ -326,6 +356,7 @@ using AtomicConstantType = std::variant; class ConstantExpr_ASTNode : public BasicExpr_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; size_t level; std::variant>> value; diff --git a/include/ast/scope.hpp b/include/ast/scope.hpp index 4899078..5db8ab9 100644 --- a/include/ast/scope.hpp +++ b/include/ast/scope.hpp @@ -14,6 +14,7 @@ class ScopeBase { friend class ClassDefScope; friend class GlobalScope; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; protected: ScopeBase *parent; // cannot use std::shared_ptr because of circular dependency @@ -21,15 +22,17 @@ class ScopeBase { virtual bool VariableNameAvailable(const std::string &name, int ttl) = 0; virtual bool add_variable(const std::string &name, const ExprTypeInfo &type) = 0; virtual ExprTypeInfo fetch_varaible(const std::string &name) = 0; + virtual IRVariableInfo fetch_variable_for_IR(const std::string &name) = 0; static inline bool IsKeyWord(const std::string &name) { static const std::unordered_set keywords = {"void", "bool", "int", "string", "new", "class", "null", "true", "false", "this", "if", "else", "for", "while", "break", "continue", "return"}; return keywords.find(name) != keywords.end(); } - public: + + public: ScopeBase() { - static size_t scope_counter=0; + static size_t scope_counter = 0; scope_id = scope_counter++; } }; @@ -56,6 +59,17 @@ class LocalScope : public ScopeBase { } return parent->fetch_varaible(name); } + IRVariableInfo fetch_variable_for_IR(const std::string &name) override { + if (local_variables.find(name) != local_variables.end()) { + IRVariableInfo res; + res.variable_name_raw = name; + res.scope_id = scope_id; + res.variable_type = 1; + res.ty = Type_AST2LLVM(local_variables[name]); + return res; + } + return parent->fetch_variable_for_IR(name); + } bool VariableNameAvailable(const std::string &name, int ttl) override { if (ttl == 0 && IsKeyWord(name)) { return false; @@ -90,6 +104,19 @@ class FunctionScope : public ScopeBase { } return parent->fetch_varaible(name); } + IRVariableInfo fetch_variable_for_IR(const std::string &name) override { + for (const auto &arg : schema.arguments) { + if (arg.second == name) { + IRVariableInfo res; + res.variable_name_raw = name; + res.scope_id = scope_id; + res.variable_type = 3; + res.ty = Type_AST2LLVM(arg.first); + return res; + } + } + return parent->fetch_variable_for_IR(name); + } bool VariableNameAvailable(const std::string &name, int ttl) override { if (ttl == 0 && IsKeyWord(name)) { return false; @@ -131,6 +158,17 @@ class ClassDefScope : public ScopeBase { } return parent->fetch_varaible(name); } + IRVariableInfo fetch_variable_for_IR(const std::string &name) override { + if (member_variables.find(name) != member_variables.end()) { + IRVariableInfo res; + res.variable_name_raw = name; + res.scope_id = scope_id; + res.variable_type = 2; + res.ty = Type_AST2LLVM(member_variables[name]); + return res; + } + return parent->fetch_variable_for_IR(name); + } bool add_function(const std::string &name, std::shared_ptr ptr) { if (IsKeyWord(name)) return false; if (member_variables.find(name) != member_variables.end()) { @@ -162,6 +200,7 @@ class ClassDefScope : public ScopeBase { class GlobalScope : public ScopeBase { friend class Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; friend std::shared_ptr CheckAndDecorate(std::shared_ptr src); std::unordered_map global_variables; std::unordered_map> global_functions; @@ -257,6 +296,17 @@ class GlobalScope : public ScopeBase { } throw SemanticError("Undefined Identifier", 1); } + IRVariableInfo fetch_variable_for_IR(const std::string &name) override { + if (global_variables.find(name) != global_variables.end()) { + IRVariableInfo res; + res.variable_name_raw = name; + res.scope_id = scope_id; + res.variable_type = 0; + res.ty = Type_AST2LLVM(global_variables[name]); + return res; + } + return parent->fetch_variable_for_IR(name); + } public: GlobalScope() { parent = nullptr; } diff --git a/include/ast/statement_astnode.h b/include/ast/statement_astnode.h index 74325d8..d273ba3 100644 --- a/include/ast/statement_astnode.h +++ b/include/ast/statement_astnode.h @@ -11,6 +11,7 @@ class Statement_ASTNode : public ASTNodeBase { class EmptyStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; public: EmptyStatement_ASTNode() = default; @@ -19,6 +20,7 @@ class EmptyStatement_ASTNode : public Statement_ASTNode { class DefinitionStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; ExprTypeInfo var_type; std::vector>> vars; @@ -29,6 +31,7 @@ class DefinitionStatement_ASTNode : public Statement_ASTNode { class ExprStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr expr; public: @@ -38,6 +41,7 @@ class ExprStatement_ASTNode : public Statement_ASTNode { class IfStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; bool has_else_clause; std::shared_ptr condition; std::shared_ptr if_clause; @@ -50,6 +54,7 @@ class IfStatement_ASTNode : public Statement_ASTNode { class WhileStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr condition; std::shared_ptr loop_body; @@ -60,6 +65,7 @@ class WhileStatement_ASTNode : public Statement_ASTNode { class ForStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::shared_ptr initial; std::shared_ptr condition; std::shared_ptr update; @@ -72,6 +78,7 @@ class ForStatement_ASTNode : public Statement_ASTNode { class JmpStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; uint8_t jmp_type; // 0: return, 1: break, 2: continue std::shared_ptr return_value; @@ -82,6 +89,7 @@ class JmpStatement_ASTNode : public Statement_ASTNode { class SuiteStatement_ASTNode : public Statement_ASTNode { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::vector> statements; public: diff --git a/include/ast/structural_astnode.h b/include/ast/structural_astnode.h index 51a7908..35d250e 100644 --- a/include/ast/structural_astnode.h +++ b/include/ast/structural_astnode.h @@ -10,6 +10,7 @@ class FuncDef_ASTNode : public ASTNodeBase { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; bool is_constructor; IdentifierType func_name; ExprTypeInfo return_type; @@ -23,6 +24,7 @@ class FuncDef_ASTNode : public ASTNodeBase { class ClassDef_ASTNode : public ASTNodeBase { friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; private: std::string class_name; @@ -36,8 +38,10 @@ class ClassDef_ASTNode : public ASTNodeBase { virtual void accept(class ASTNodeVisitorBase *visitor) override; }; class Program_ASTNode : public ASTNodeBase { + friend std::shared_ptr BuildIR(std::shared_ptr src); friend Visitor; friend class ASTSemanticCheckVisitor; + friend class IRBuilder; std::vector> global_variables; std::vector> classes; std::vector> functions; diff --git a/include/tools.h b/include/tools.h index 4491ae2..8ee8d4b 100644 --- a/include/tools.h +++ b/include/tools.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -101,10 +102,21 @@ inline bool operator==(const ExprTypeInfo &l, const ExprTypeInfo &r) { throw std::runtime_error("something strange happened"); } inline bool operator!=(const ExprTypeInfo &l, const ExprTypeInfo &r) { return !(l == r); } - +class LLVMIRIntType { + public: + size_t bits; + LLVMIRIntType() = default; + LLVMIRIntType(size_t bits) : bits(bits) {} +}; +struct LLVMIRPTRType {}; +struct LLVMVOIDType {}; +struct LLVMIRCLASSTYPE { + std::string class_name_full; +}; +using LLVMType = std::variant; class IRClassInfo { public: - std::string class_name; // This data must be provided by user + std::string class_name_raw; // This data must be provided by user std::vector member_var_size; // This data must be provided by user. Each of them is the size of a member // variable, which must be in [1,4] std::unordered_map member_var_offset; // This data must be provided by user @@ -128,10 +140,36 @@ class IRClassInfo { class_size_after_align = cur_pos; } } + std::string GenerateFullName() { return "%.class." + class_name_raw; } }; class IRVariableInfo { public: enum class VariableType { global_variable, local_variable, member_variable }; std::string class_name; - std::string variable_name; -}; \ No newline at end of file + std::string variable_name_raw; + size_t scope_id; + uint8_t variable_type; // 0: global, 1: local, 2: member, 3: argument + LLVMType ty; + std::string GenerateFullName() { + if (variable_type == 2) { + throw std::runtime_error("Member variable should not be used in this function"); + } else if (variable_type == 0) { + return "@.var.global." + variable_name_raw + ".addrkp"; + } else if (variable_type == 1) { + return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".addrkp"; + } else if (variable_type == 3) { + return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".val"; + } else { + throw std::runtime_error("Invalid scope id"); + } + } +}; + +inline LLVMType Type_AST2LLVM(const ExprTypeInfo &src) { + if (std::holds_alternative(src)) return LLVMIRPTRType(); + std::string tname = std::get(src); + if (tname == "bool") return LLVMIRIntType(1); + if (tname == "int") return LLVMIRIntType(32); + if (tname == "void") return LLVMVOIDType(); + return LLVMIRPTRType(); +} \ No newline at end of file diff --git a/src/IR/IRBuilder.cpp b/src/IR/IRBuilder.cpp index 9f37587..0853c3d 100644 --- a/src/IR/IRBuilder.cpp +++ b/src/IR/IRBuilder.cpp @@ -1,51 +1,310 @@ #include "IRBuilder.h" #include +#include "IR.h" +#include "IR_basic.h" +#include "tools.h" // Structural AST Nodes void IRBuilder::ActuralVisit(FuncDef_ASTNode *node) { - // TODO: Implement function body + auto declare = std::make_shared(); + auto func_def = std::make_shared(); + // prog->function_declares.push_back(declare); + prog->function_defs.push_back(func_def); + if (is_in_class_def) { + declare->func_name_raw = cur_class_name + "." + node->func_name; + declare->args.push_back(LLVMIRPTRType()); // The `self` pointer + } else { + declare->func_name_raw = node->func_name; + } + declare->return_type = Type_AST2LLVM(node->return_type); + for (auto &arg : node->params) { + declare->args.push_back(Type_AST2LLVM(arg.second)); + } + func_def->func_name_raw = declare->func_name_raw; + func_def->return_type = declare->return_type; + func_def->args = declare->args; + size_t scope_id = node->current_scope->scope_id; + if (is_in_class_def) { + std::string arg_name_raw = "this"; + std::string full_name = "%.var.local." + std::to_string(scope_id) + "." + arg_name_raw + ".val"; + func_def->args_full_name.push_back(full_name); + } + for (auto &arg : node->params) { + std::string &arg_name_raw = arg.first; + std::string full_name = "%.var.local." + std::to_string(scope_id) + "." + arg_name_raw + ".val"; + func_def->args_full_name.push_back(full_name); + } + if (func_def->args.size() != func_def->args_full_name.size()) throw std::runtime_error("args size not match"); + + cur_func = func_def; + auto current_block = std::make_shared(); + cur_block = current_block; + cur_func->basic_blocks.push_back(current_block); + size_t block_id = block_counter++; + current_block->label_full = "label_" + std::to_string(block_id); + + is_in_func_def = true; + node->func_body->accept(this); + is_in_func_def = false; + + if (!(cur_block->exit_action)) { + auto default_ret = std::make_shared(); + cur_block->exit_action = default_ret; + default_ret->type = cur_func->return_type; + if (std::holds_alternative(default_ret->type)) { + default_ret->value = "0"; + } else { + default_ret->value = "null"; + } + } } void IRBuilder::ActuralVisit(ClassDef_ASTNode *node) { - // TODO: Implement function body + is_in_class_def = true; + cur_class_name = node->class_name; + auto tpdf = std::make_shared(); + tpdf->class_name_raw = node->class_name; + cur_class = tpdf; + prog->type_defs.push_back(tpdf); + for (auto ch : node->sorted_children) { + ch->accept(this); + } + is_in_class_def = false; } void IRBuilder::ActuralVisit(Program_ASTNode *node) { - // TODO: Implement function body - throw std::runtime_error("IRBuilder not implemented"); + for (auto ch : node->sorted_children) { + ch->accept(this); + } } // Statement AST Nodes void IRBuilder::ActuralVisit(EmptyStatement_ASTNode *node) { - // TODO: Implement function body + // do nothing } void IRBuilder::ActuralVisit(DefinitionStatement_ASTNode *node) { - // TODO: Implement function body + if (is_in_class_def) { + cur_class->elements.push_back(Type_AST2LLVM(node->var_type)); + } else if (!is_in_func_def) { + for (const auto &var : node->vars) { + auto var_def = std::make_shared(); + prog->global_var_defs.push_back(var_def); + var_def->type = Type_AST2LLVM(node->var_type); + var_def->name_raw = var.first; + // TODO: initial value + } + } else { + for (const auto &var : node->vars) { + auto var_def = std::make_shared(); + cur_block->actions.push_back(var_def); + var_def->num = 1; + var_def->type = Type_AST2LLVM(node->var_type); + var_def->name_full = "%.var.local." + std::to_string(node->current_scope->scope_id) + "." + var.first + ".addrkp"; + if (var.second) { + var.second->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->ty = var_def->type; + act->ptr_full = var_def->name_full; + act->value_full = var.second->IR_result_full; + } + } + } + just_encountered_jmp = false; } void IRBuilder::ActuralVisit(ExprStatement_ASTNode *node) { - // TODO: Implement function body + // just visit it + node->expr->accept(this); + just_encountered_jmp = false; } void IRBuilder::ActuralVisit(IfStatement_ASTNode *node) { - // TODO: Implement function body + node->condition->accept(this); + cur_block->exit_action = std::make_shared(); + if (node->else_clause) { + size_t if_block_start_id = block_counter++; + size_t else_block_start_id = block_counter++; + size_t following_block_id = block_counter++; + auto first_if_block = std::make_shared(); + first_if_block->label_full = "label_" + std::to_string(if_block_start_id); + auto first_else_block = std::make_shared(); + first_else_block->label_full = "label_" + std::to_string(else_block_start_id); + auto following_block = std::make_shared(); + following_block->label_full = "label_" + std::to_string(following_block_id); + std::dynamic_pointer_cast(cur_block->exit_action)->cond = node->condition->IR_result_full; + std::dynamic_pointer_cast(cur_block->exit_action)->true_label_full = first_if_block->label_full; + std::dynamic_pointer_cast(cur_block->exit_action)->false_label_full = first_else_block->label_full; + cur_block = first_if_block; + cur_func->basic_blocks.push_back(first_if_block); + node->if_clause->accept(this); + if (!(cur_block->exit_action)) { + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = following_block->label_full; + } + cur_block = first_else_block; + cur_func->basic_blocks.push_back(first_else_block); + node->else_clause->accept(this); + if (!(cur_block->exit_action)) { + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = following_block->label_full; + } + cur_block = following_block; + cur_func->basic_blocks.push_back(following_block); + } else { + size_t if_block_start_id = block_counter++; + size_t following_block_id = block_counter++; + auto first_if_block = std::make_shared(); + first_if_block->label_full = "label_" + std::to_string(if_block_start_id); + auto following_block = std::make_shared(); + following_block->label_full = "label_" + std::to_string(following_block_id); + std::dynamic_pointer_cast(cur_block->exit_action)->cond = node->condition->IR_result_full; + std::dynamic_pointer_cast(cur_block->exit_action)->true_label_full = first_if_block->label_full; + std::dynamic_pointer_cast(cur_block->exit_action)->false_label_full = following_block->label_full; + cur_block = first_if_block; + cur_func->basic_blocks.push_back(first_if_block); + node->if_clause->accept(this); + if (!(cur_block->exit_action)) { + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = following_block->label_full; + } + cur_block = following_block; + cur_func->basic_blocks.push_back(following_block); + } + just_encountered_jmp = false; } void IRBuilder::ActuralVisit(WhileStatement_ASTNode *node) { - // TODO: Implement function body + std::string break_target_backup = cur_break_target; + std::string continue_target_backup = cur_continue_target; + + size_t checker_block_id = block_counter++; + size_t first_loop_body_block_id = block_counter++; + size_t following_block_id = block_counter++; + auto checker_block = std::make_shared(); + auto first_loop_body_block = std::make_shared(); + auto following_block = std::make_shared(); + checker_block->label_full = "label_" + std::to_string(checker_block_id); + first_loop_body_block->label_full = "label_" + std::to_string(first_loop_body_block_id); + following_block->label_full = "label_" + std::to_string(following_block_id); + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = checker_block->label_full; + cur_block = checker_block; + cur_func->basic_blocks.push_back(checker_block); + node->condition->accept(this); + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->cond = node->condition->IR_result_full; + std::dynamic_pointer_cast(cur_block->exit_action)->true_label_full = first_loop_body_block->label_full; + std::dynamic_pointer_cast(cur_block->exit_action)->false_label_full = following_block->label_full; + + cur_block = first_loop_body_block; + cur_func->basic_blocks.push_back(first_loop_body_block); + cur_break_target = following_block->label_full; + cur_continue_target = checker_block->label_full; + node->loop_body->accept(this); + if (!(cur_block->exit_action)) { + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = checker_block->label_full; + } + + cur_block = following_block; + cur_func->basic_blocks.push_back(following_block); + + cur_break_target = break_target_backup; + cur_continue_target = continue_target_backup; + just_encountered_jmp = false; } void IRBuilder::ActuralVisit(ForStatement_ASTNode *node) { - // TODO: Implement function body + std::string break_target_backup = cur_break_target; + std::string continue_target_backup = cur_continue_target; + size_t checker_block_id = block_counter++; + size_t first_loop_body_block_id = block_counter++; + size_t step_block_id = block_counter++; + size_t following_block_id = block_counter++; + auto checker_block = std::make_shared(); + auto first_loop_body_block = std::make_shared(); + auto step_block = std::make_shared(); + auto following_block = std::make_shared(); + checker_block->label_full = "label_" + std::to_string(checker_block_id); + first_loop_body_block->label_full = "label_" + std::to_string(first_loop_body_block_id); + step_block->label_full = "label_" + std::to_string(step_block_id); + following_block->label_full = "label_" + std::to_string(following_block_id); + if (node->initial) { + node->initial->accept(this); // just finish initialization worker in current block + } + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = checker_block->label_full; + cur_block = checker_block; + cur_func->basic_blocks.push_back(checker_block); + if (node->condition) { + node->condition->accept(this); + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->cond = node->condition->IR_result_full; + std::dynamic_pointer_cast(cur_block->exit_action)->true_label_full = first_loop_body_block->label_full; + std::dynamic_pointer_cast(cur_block->exit_action)->false_label_full = following_block->label_full; + } else { + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = + first_loop_body_block->label_full; + } + cur_block = first_loop_body_block; + cur_func->basic_blocks.push_back(first_loop_body_block); + cur_break_target = following_block->label_full; + cur_continue_target = step_block->label_full; + node->loop_body->accept(this); + if (!(cur_block->exit_action)) { + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = step_block->label_full; + } + cur_block = step_block; + cur_func->basic_blocks.push_back(step_block); + if (node->update) { + node->update->accept(this); + } + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = checker_block->label_full; + cur_block = following_block; + cur_func->basic_blocks.push_back(following_block); + cur_break_target = break_target_backup; + cur_continue_target = continue_target_backup; + just_encountered_jmp = false; } void IRBuilder::ActuralVisit(JmpStatement_ASTNode *node) { - // TODO: Implement function body + if (node->jmp_type == 1) { + // break + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = cur_break_target; + just_encountered_jmp = true; + } else if (node->jmp_type == 2) { + // continue + cur_block->exit_action = std::make_shared(); + std::dynamic_pointer_cast(cur_block->exit_action)->label_full = cur_continue_target; + just_encountered_jmp = true; + } else if (node->jmp_type == 0) { + // return + cur_block->exit_action = std::make_shared(); + if (node->return_value) { + node->return_value->accept(this); + std::dynamic_pointer_cast(cur_block->exit_action)->value = node->return_value->IR_result_full; + } + std::dynamic_pointer_cast(cur_block->exit_action)->type = cur_func->return_type; + just_encountered_jmp = true; + } else { + throw std::runtime_error("unknown jmp type"); + } } void IRBuilder::ActuralVisit(SuiteStatement_ASTNode *node) { - // TODO: Implement function body + for (auto &stmt : node->statements) { + stmt->accept(this); + if (just_encountered_jmp) { + just_encountered_jmp = false; + break; // no need to continue + } + } } // Expression AST Nodes @@ -70,71 +329,286 @@ void IRBuilder::ActuralVisit(IndexExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(SuffixExpr_ASTNode *node) { - // TODO: Implement function body + node->base->is_requiring_lvalue = true; + node->base->accept(this); + std::string val_backup = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto backup_act = std::make_shared(); + cur_block->actions.push_back(backup_act); + backup_act->ty = Type_AST2LLVM(node->base->expr_type_info); + backup_act->ptr_full = node->base->IR_result_full; + backup_act->result_full = val_backup; + auto op_act = std::make_shared(); + cur_block->actions.push_back(op_act); + if (node->op == "++") { + op_act->op = "add"; + } else { + op_act->op = "sub"; + } + op_act->type = LLVMIRIntType(32); + op_act->operand1_full = val_backup; + op_act->operand2_full = "1"; + op_act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto store_act = std::make_shared(); + cur_block->actions.push_back(store_act); + store_act->ty = Type_AST2LLVM(node->base->expr_type_info); + store_act->ptr_full = node->base->IR_result_full; + store_act->value_full = op_act->result_full; + node->IR_result_full = val_backup; } void IRBuilder::ActuralVisit(PrefixExpr_ASTNode *node) { - // TODO: Implement function body + node->base->is_requiring_lvalue = true; + node->base->accept(this); + std::string val_backup = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto backup_act = std::make_shared(); + cur_block->actions.push_back(backup_act); + backup_act->ty = Type_AST2LLVM(node->base->expr_type_info); + backup_act->ptr_full = node->base->IR_result_full; + backup_act->result_full = val_backup; + auto op_act = std::make_shared(); + cur_block->actions.push_back(op_act); + if (node->op == "++") { + op_act->op = "add"; + } else { + op_act->op = "sub"; + } + op_act->type = LLVMIRIntType(32); + op_act->operand1_full = val_backup; + op_act->operand2_full = "1"; + op_act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto store_act = std::make_shared(); + cur_block->actions.push_back(store_act); + store_act->ty = Type_AST2LLVM(node->base->expr_type_info); + store_act->ptr_full = node->base->IR_result_full; + store_act->value_full = op_act->result_full; + if (node->is_requiring_lvalue) { + node->IR_result_full = node->base->IR_result_full; + } else { + node->IR_result_full = op_act->result_full; + } } void IRBuilder::ActuralVisit(OppositeExpr_ASTNode *node) { - // TODO: Implement function body + node->base->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "sub"; + act->operand1_full = "0"; + act->operand2_full = node->base->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(LNotExpr_ASTNode *node) { - // TODO: Implement function body + node->base->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "xor"; + act->operand1_full = node->base->IR_result_full; + act->operand2_full = "1"; + act->type = LLVMIRIntType(1); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(BNotExpr_ASTNode *node) { - // TODO: Implement function body + node->base->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "xor"; + act->operand1_full = node->base->IR_result_full; + act->operand2_full = "-1"; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(MDMExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + if (node->op == "*") { + act->op = "mul"; + } else if (node->op == "/") { + act->op = "sdiv"; + } else if (node->op == "%") { + act->op = "srem"; + } else { + throw std::runtime_error("unknown MDM operator"); + } + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(PMExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + if (node->op == "+") { + act->op = "add"; + } else if (node->op == "-") { + act->op = "sub"; + } else { + throw std::runtime_error("unknown PM operator"); + } + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; + // TODO: support for string concatenation } void IRBuilder::ActuralVisit(RLExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + if (node->op == ">>") { + act->op = "ashr"; + } else if (node->op == "<<") { + act->op = "shl"; + } else { + throw std::runtime_error("unknown RL operator"); + } + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(GGLLExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + if (node->op == ">=") { + act->op = "sge"; + } else if (node->op == ">") { + act->op = "sgt"; + } else if (node->op == "<=") { + act->op = "sle"; + } else if (node->op == "<") { + act->op = "slt"; + } else { + throw std::runtime_error("unknown GGLL operator"); + } + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = Type_AST2LLVM(node->left->expr_type_info); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; + // TODO: string comparison } void IRBuilder::ActuralVisit(NEExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + if (node->op == "!=") { + act->op = "ne"; + } else if (node->op == "==") { + act->op = "eq"; + } else { + throw std::runtime_error("unknown NE operator"); + } + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = Type_AST2LLVM(node->left->expr_type_info); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; + // TODO: string comparison } void IRBuilder::ActuralVisit(BAndExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "and"; + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(BXorExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "xor"; + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(BOrExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "or"; + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(32); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; } void IRBuilder::ActuralVisit(LAndExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "and"; + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(1); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; + // TODO: short-circuit } void IRBuilder::ActuralVisit(LOrExpr_ASTNode *node) { - // TODO: Implement function body + node->left->accept(this); + node->right->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->op = "or"; + act->operand1_full = node->left->IR_result_full; + act->operand2_full = node->right->IR_result_full; + act->type = LLVMIRIntType(1); + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; + // TODO: short-circuit } void IRBuilder::ActuralVisit(TernaryExpr_ASTNode *node) { // TODO: Implement function body + throw std::runtime_error("ternary operator not supported"); } void IRBuilder::ActuralVisit(AssignExpr_ASTNode *node) { - // TODO: Implement function body + node->dest->is_requiring_lvalue = true; + node->dest->accept(this); + node->src->accept(this); + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->ptr_full = node->dest->IR_result_full; + act->value_full = node->src->IR_result_full; + act->ty = Type_AST2LLVM(node->src->expr_type_info); } void IRBuilder::ActuralVisit(ThisExpr_ASTNode *node) { @@ -142,15 +616,62 @@ void IRBuilder::ActuralVisit(ThisExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(ParenExpr_ASTNode *node) { - // TODO: Implement function body + node->expr->accept(this); // just visit it } void IRBuilder::ActuralVisit(IDExpr_ASTNode *node) { - // TODO: Implement function body + IRVariableInfo var_info = node->current_scope->fetch_variable_for_IR(node->id); + if (var_info.variable_type == 0 || var_info.variable_type == 1) { + if (node->is_requiring_lvalue) { + node->IR_result_full = var_info.GenerateFullName(); + return; + } + auto act = std::make_shared(); + cur_block->actions.push_back(act); + act->ptr_full = var_info.GenerateFullName(); + act->ty = var_info.ty; + act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = act->result_full; + } else if (var_info.variable_type == 2) { + throw std::runtime_error("not support for class member access"); + } else if (var_info.variable_type == 3) { + if (node->is_requiring_lvalue) { + throw std::runtime_error("for argument, lvalue support need additional work"); + } + node->IR_result_full = var_info.GenerateFullName(); + return; + } else { + throw std::runtime_error("unknown variable type"); + } } void IRBuilder::ActuralVisit(FunctionCallExpr_ASTNode *node) { - // TODO: Implement function body + bool is_member_func = false; + if (is_in_class_def) { + try { + auto schema = global_scope->FetchClassMemberFunction(cur_class_name, node->func_name); + is_member_func = true; + } catch (...) { + } + } + if (is_member_func) { + // TODO: member function call + throw std::runtime_error("not support for class member function call"); + } else { + auto call = std::make_shared(); + call->return_type = Type_AST2LLVM(node->expr_type_info); + call->func_name_raw = node->func_name; + for (auto &arg : node->arguments) { + arg->accept(this); + call->args_val_full.push_back(arg->IR_result_full); + call->args_ty.push_back(Type_AST2LLVM(arg->expr_type_info)); + } + cur_block->actions.push_back(call); + if (!std::holds_alternative(call->return_type)) { + call->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = call->result_full; + } + } } void IRBuilder::ActuralVisit(FormattedStringExpr_ASTNode *node) { @@ -158,10 +679,49 @@ void IRBuilder::ActuralVisit(FormattedStringExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(ConstantExpr_ASTNode *node) { - // TODO: Implement function body + if ((node->level == 0) != std::holds_alternative(node->value)) { + throw std::runtime_error("ConstantExpr_ASTNode level not match"); + } + if (std::holds_alternative(node->value)) { + auto &val = std::get(node->value); + if (std::holds_alternative(val)) { + node->IR_result_full = std::to_string(std::get(val)); + } else if (std::holds_alternative(val)) { + node->IR_result_full = std::to_string(int(std::get(val))); + } else if (std::holds_alternative(val)) { + // TODO: string constant + throw std::runtime_error("String constant not supported"); + } else if (std::holds_alternative(val)) { + node->IR_result_full = "null"; + } else { + throw std::runtime_error("unknown constant type"); + } + } else { + // TODO: array constant + throw std::runtime_error("Array constant not supported"); + } } std::shared_ptr BuildIR(std::shared_ptr src) { IRBuilder visitor; + visitor.prog = std::make_shared(); + auto tmp = std::make_shared(); + tmp->func_name_raw = "malloc"; + tmp->return_type = LLVMIRPTRType(); + tmp->args.push_back(LLVMIRIntType(32)); + visitor.prog->function_declares.push_back(tmp); + tmp = std::make_shared(); + tmp->func_name_raw = "printInt"; + tmp->return_type = LLVMVOIDType(); + tmp->args.push_back(LLVMIRIntType(32)); + visitor.prog->function_declares.push_back(tmp); + tmp = std::make_shared(); + tmp->func_name_raw = "print"; + tmp->return_type = LLVMVOIDType(); + tmp->args.push_back(LLVMIRPTRType()); + visitor.prog->function_declares.push_back(tmp); + visitor.global_scope = std::dynamic_pointer_cast(src->current_scope); + if (!(visitor.global_scope)) throw std::runtime_error("global scope not found"); visitor.visit(src.get()); + return visitor.prog; } \ No newline at end of file diff --git a/src/IR/build.sh b/src/IR/build.sh index e02155f..65b84b8 100755 --- a/src/IR/build.sh +++ b/src/IR/build.sh @@ -7,4 +7,4 @@ clang-18 -S -emit-llvm --target=riscv32-unknown-elf -O2 -fno-builtin-printf -fno builtin.c -o builtin_intermediate.ll sed 's/_builtin_/.builtin./g;s/string_/string./g;s/array_/array./g' builtin_intermediate.ll > builtin.ll rm builtin_intermediate.ll -llc-18 -march=riscv32 builtin.ll -o builtin.s -O2 \ No newline at end of file +llc-18 -march=riscv32 -mattr=+m builtin.ll -o builtin.s -O2 \ No newline at end of file diff --git a/src/IR/builtin.s b/src/IR/builtin.s index 2324839..7a27c65 100644 --- a/src/IR/builtin.s +++ b/src/IR/builtin.s @@ -3,7 +3,7 @@ .attribute 5, "rv32i2p1_m2p0_a2p1_c2p0" .file "builtin.c" .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl string.length # -- Begin function string.length .p2align 1 .type string.length,@function @@ -23,7 +23,7 @@ string.length: # @string.length # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl string.substring # -- Begin function string.substring .p2align 1 .type string.substring,@function @@ -63,7 +63,7 @@ string.substring: # @string.substring # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl string.parseInt # -- Begin function string.parseInt .p2align 1 .type string.parseInt,@function @@ -84,7 +84,7 @@ string.parseInt: # @string.parseInt # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl string.ord # -- Begin function string.ord .p2align 1 .type string.ord,@function @@ -98,7 +98,7 @@ string.ord: # @string.ord # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl print # -- Begin function print .p2align 1 .type print,@function @@ -115,7 +115,7 @@ print: # @print # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl println # -- Begin function println .p2align 1 .type println,@function @@ -132,7 +132,7 @@ println: # @println # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl printInt # -- Begin function printInt .p2align 1 .type printInt,@function @@ -149,7 +149,7 @@ printInt: # @printInt # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl printlnInt # -- Begin function printlnInt .p2align 1 .type printlnInt,@function @@ -166,7 +166,7 @@ printlnInt: # @printlnInt # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl toString # -- Begin function toString .p2align 1 .type toString,@function @@ -196,7 +196,7 @@ toString: # @toString # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl getString # -- Begin function getString .p2align 1 .type getString,@function @@ -300,7 +300,7 @@ getString: # @getString # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl getInt # -- Begin function getInt .p2align 1 .type getInt,@function @@ -321,7 +321,7 @@ getInt: # @getInt # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl .builtin.AllocateClassBody # -- Begin function .builtin.AllocateClassBody .p2align 1 .type .builtin.AllocateClassBody,@function @@ -333,7 +333,7 @@ getInt: # @getInt # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl .builtin.GetArrayLength # -- Begin function .builtin.GetArrayLength .p2align 1 .type .builtin.GetArrayLength,@function @@ -355,7 +355,7 @@ getInt: # @getInt # -- End function .option pop .option push - .option arch, +a, +c, +m + .option arch, +a, +c .globl .builtin.RecursiveAllocateArray # -- Begin function .builtin.RecursiveAllocateArray .p2align 1 .type .builtin.RecursiveAllocateArray,@function diff --git a/src/main.cpp b/src/main.cpp index 4cd13d6..69bf237 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -26,6 +26,7 @@ int main(int argc, char **argv) { try { SemanticCheck(fin, ast); auto IR = BuildIR(ast); + IR->RecursivePrint(std::cout); } catch (const SemanticError &err) { std::cout << err.what() << std::endl; return err.GetErrorCode(); diff --git a/src/semantic/visitor.cpp b/src/semantic/visitor.cpp index 8bffd84..a727261 100644 --- a/src/semantic/visitor.cpp +++ b/src/semantic/visitor.cpp @@ -249,7 +249,7 @@ std::any Visitor::visitClass_var_def(MXParser::Class_var_defContext *context) { std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *context) { auto construct_func = std::make_shared(); construct_func->type = ASTNodeType::Constructor; - construct_func->is_constructor = false; + construct_func->is_constructor = true; construct_func->start_line = context->getStart()->getLine(); construct_func->start_char_pos = context->getStart()->getCharPositionInLine(); construct_func->end_line = context->getStop()->getLine(); @@ -259,6 +259,7 @@ std::any Visitor::visitClass_constructor(MXParser::Class_constructorContext *con construct_func->current_scope = cur_scope; construct_func->func_name = context->ID()->getText(); cur_scope->schema.return_type = "void"; + construct_func->return_type = "void"; nodetype_stk.push_back({ASTNodeType::Constructor, construct_func->current_scope}); construct_func->func_body = std::dynamic_pointer_cast( @@ -476,8 +477,7 @@ std::any Visitor::visitDefine_statement(MXParser::Define_statementContext *conte assert(nodetype_stk.size() > 0); def_stmt->current_scope = nodetype_stk.back().second; if (nodetype_stk.size() > 0 && (nodetype_stk.back().first == ASTNodeType::IfStatement || - nodetype_stk.back().first == ASTNodeType::WhileStatement || - nodetype_stk.back().first == ASTNodeType::ForStatement)) { + nodetype_stk.back().first == ASTNodeType::WhileStatement)) { def_stmt->current_scope = std::make_shared(); def_stmt->current_scope->parent = nodetype_stk.back().second.get(); }