write basic functions for LLVM IR

This commit is contained in:
2024-08-22 06:39:38 +00:00
parent ed1ba4b59a
commit 4f4113f16a
13 changed files with 1195 additions and 155 deletions

View File

@ -1,10 +1,31 @@
#pragma once
#include <memory>
#include "IR_basic.h"
#include "ast/astnode_visitor.h"
class IRBuilder : public ASTNodeVirturalVisitor {
friend std::shared_ptr<ModuleItem> BuildIR(std::shared_ptr<Program_ASTNode> src);
std::shared_ptr<ModuleItem> prog;
std::shared_ptr<TypeDefItem> cur_class;
std::shared_ptr<FunctionDefItem> cur_func;
std::shared_ptr<BlockItem> cur_block;
std::string cur_class_name;
bool is_in_class_def;
bool is_in_func_def;
size_t tmp_var_counter;
size_t block_counter;
std::string cur_break_target;
std::string cur_continue_target;
bool just_encountered_jmp;
std::shared_ptr<GlobalScope> global_scope;
public:
IRBuilder() {
tmp_var_counter = 0;
block_counter = 0;
is_in_class_def = false;
is_in_func_def = false;
just_encountered_jmp = false;
}
// Structural AST Nodes
void ActuralVisit(FuncDef_ASTNode *node) override;
void ActuralVisit(ClassDef_ASTNode *node) override;

View File

@ -1,132 +1,458 @@
#pragma once
#include <cstddef>
#include <ios>
#include <list>
#include <memory>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>
#include "ast/astnode.h"
struct LLVMIRIntType {
size_t bits;
};
struct LLVMIRPTRType {};
struct LLVMIRCLASSTYPE {};
using LLVMType = std::variant<LLVMIRIntType, LLVMIRPTRType, LLVMIRCLASSTYPE>;
#include "tools.h"
class LLVMIRItemBase {
public:
LLVMIRItemBase() = default;
virtual ~LLVMIRItemBase() = default;
virtual void RecursivePrint(std::ostream &os) const;
virtual void RecursivePrint(std::ostream &os) const = 0;
};
class TypeDefItem : public LLVMIRItemBase {
friend class IRBuilder;
std::string class_name_raw;
std::vector<LLVMType> elements;
public:
void RecursivePrint(std::ostream &os) const { ; }
TypeDefItem() = default;
void RecursivePrint(std::ostream &os) const {
os << "%.class." << class_name_raw;
os << " = type {";
for (size_t i = 0; i < elements.size(); i++) {
if (std::holds_alternative<LLVMIRIntType>(elements[i])) {
os << "i" << std::get<LLVMIRIntType>(elements[i]).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(elements[i])) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(elements[i])) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(elements[i])) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
if (i != elements.size() - 1) {
os << ",";
}
}
os << "}\n";
}
};
class GlobalVarDefItem : public LLVMIRItemBase {
friend class IRBuilder;
LLVMType type;
std::string name;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class ActionItem : public LLVMIRItemBase {
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class JMPActionItem : public ActionItem {
std::string label;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class BRAction: public JMPActionItem {
std::string cond;
std::string true_label;
std::string false_label;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class UNConditionJMPAction: public JMPActionItem {
std::string label;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class RETAction : public JMPActionItem {
std::string value;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class BinaryOperationAction : public ActionItem {
std::string op;
std::string lhs;
std::string rhs;
std::string result;
LLVMType type;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class AllocaAction : public ActionItem {
std::string name;
LLVMType type;
size_t num;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class LoadAction : public ActionItem {
std::string result;
LLVMType ty;
std::string ptr;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class StoreAction : public ActionItem {
LLVMType ty;
std::string value;
std::string ptr;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class GetElementPtrAction : public ActionItem {
std::string result;
LLVMType ty;
std::string ptr;
std::vector<std::string> indices;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class ICMPAction : public ActionItem {
std::string op;
std::string lhs;
std::string rhs;
std::string result;
LLVMType ty;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class BlockItem : public LLVMIRItemBase {
std::string label;
std::vector<std::shared_ptr<ActionItem>> actions;
std::shared_ptr<JMPActionItem> exit_action;
public:
void RecursivePrint(std::ostream &os) const { ; }
};
class FunctionDefItem : public LLVMIRItemBase {
std::vector<std::shared_ptr<BlockItem>> basic_blocks;
std::string name_raw;
public:
GlobalVarDefItem() = default;
void RecursivePrint(std::ostream &os) const {
for (auto &item : basic_blocks) {
item->RecursivePrint(os);
os << '\n';
std::string name_full = "@.var.global." + name_raw + ".addrkp";
os << name_full << " = global ";
if (std::holds_alternative<LLVMIRIntType>(type)) {
os << "i" << std::get<LLVMIRIntType>(type).bits << " 0\n";
} else if (std::holds_alternative<LLVMIRPTRType>(type)) {
os << "ptr null\n";
} else {
throw std::runtime_error("something strange happened");
}
}
};
class ActionItem : public LLVMIRItemBase {};
class JMPActionItem : public ActionItem {};
class BRAction : public JMPActionItem {
friend class IRBuilder;
std::string cond;
std::string true_label_full;
std::string false_label_full;
public:
BRAction() = default;
void RecursivePrint(std::ostream &os) const {
os << "br i1 " << cond << ", label %" << true_label_full << ", label %" << false_label_full << "\n";
}
};
class UNConditionJMPAction : public JMPActionItem {
friend class IRBuilder;
std::string label_full;
public:
UNConditionJMPAction() = default;
void RecursivePrint(std::ostream &os) const { os << "br label %" << label_full << "\n"; }
};
class RETAction : public JMPActionItem {
friend class IRBuilder;
LLVMType type;
std::string value;
public:
RETAction() = default;
void RecursivePrint(std::ostream &os) const {
if (std::holds_alternative<LLVMVOIDType>(type)) {
os << "ret void\n";
} else if (std::holds_alternative<LLVMIRIntType>(type)) {
os << "ret i" << std::get<LLVMIRIntType>(type).bits << " " << value << "\n";
} else if (std::holds_alternative<LLVMIRPTRType>(type)) {
os << "ret ptr " << value << "\n";
} else {
throw std::runtime_error("something strange happened");
}
}
};
class BinaryOperationAction : public ActionItem {
friend class IRBuilder;
std::string op;
std::string operand1_full;
std::string operand2_full;
std::string result_full;
LLVMType type;
public:
BinaryOperationAction() = default;
void RecursivePrint(std::ostream &os) const {
os << result_full << " = " << op << " ";
if (std::holds_alternative<LLVMIRIntType>(type)) {
os << "i" << std::get<LLVMIRIntType>(type).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(type)) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(type)) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(type)) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
os << " " << operand1_full << ", " << operand2_full << "\n";
}
};
class AllocaAction : public ActionItem {
friend class IRBuilder;
std::string name_full;
LLVMType type;
size_t num;
public:
AllocaAction() : num(1){};
void RecursivePrint(std::ostream &os) const {
os << name_full << " = alloca ";
if (std::holds_alternative<LLVMIRIntType>(type)) {
os << "i" << std::get<LLVMIRIntType>(type).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(type)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
if (num > 1) {
os << ", i32 " << num;
}
os << "\n";
}
};
class LoadAction : public ActionItem {
friend class IRBuilder;
std::string result_full;
LLVMType ty;
std::string ptr_full;
public:
LoadAction() = default;
void RecursivePrint(std::ostream &os) const {
os << result_full << " = load ";
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
os << ", ptr " << ptr_full << '\n';
}
};
class StoreAction : public ActionItem {
friend class IRBuilder;
LLVMType ty;
std::string value_full;
std::string ptr_full;
public:
StoreAction() = default;
void RecursivePrint(std::ostream &os) const {
os << "store ";
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
os << ' ' << value_full << ", ptr " << ptr_full << '\n';
}
};
class GetElementPtrAction : public ActionItem {
std::string result_full;
LLVMType ty;
std::string ptr_full;
std::vector<std::string> indices;
public:
GetElementPtrAction() = default;
void RecursivePrint(std::ostream &os) const {
os << result_full << " = getelementptr ";
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(ty)) {
os << std::get<LLVMIRCLASSTYPE>(ty).class_name_full;
} else {
throw std::runtime_error("something strange happened");
}
os << ", ptr " << ptr_full;
for (auto &index : indices) {
os << ", i32 " << index;
}
os << '\n';
}
};
class ICMPAction : public ActionItem {
friend class IRBuilder;
std::string op;
std::string operand1_full;
std::string operand2_full;
std::string result_full;
LLVMType type;
public:
ICMPAction() = default;
void RecursivePrint(std::ostream &os) const {
os << result_full << " = icmp " << op << " ";
if (std::holds_alternative<LLVMIRIntType>(type)) {
os << "i" << std::get<LLVMIRIntType>(type).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(type)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
os << ' ' << operand1_full << ", " << operand2_full << '\n';
}
};
class BlockItem : public LLVMIRItemBase {
friend class IRBuilder;
std::string label_full;
std::vector<std::shared_ptr<ActionItem>> actions;
std::shared_ptr<JMPActionItem> exit_action;
public:
BlockItem() = default;
void RecursivePrint(std::ostream &os) const {
os << label_full << ":\n";
for (auto &action : actions) {
action->RecursivePrint(os);
}
if (exit_action) exit_action->RecursivePrint(os);
}
};
class CallItem : public ActionItem {
friend class IRBuilder;
std::string result_full;
LLVMType return_type;
std::string func_name_raw;
std::vector<LLVMType> args_ty;
std::vector<std::string> args_val_full;
public:
CallItem() = default;
void RecursivePrint(std::ostream &os) const {
if (std::holds_alternative<LLVMVOIDType>(return_type)) {
os << "call ";
} else {
os << result_full << " = call ";
}
if (std::holds_alternative<LLVMIRIntType>(return_type)) {
os << "i" << std::get<LLVMIRIntType>(return_type).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(return_type)) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(return_type)) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(return_type)) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
os << " @" << func_name_raw << "(";
for (size_t i = 0; i < args_val_full.size(); i++) {
auto &ty = args_ty[i];
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(ty)) {
throw std::runtime_error("void type is not allowed in function call");
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(ty)) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
} else {
throw std::runtime_error("something strange happened");
}
os << ' ' << args_val_full[i];
if (i != args_val_full.size() - 1) {
os << ", ";
}
}
os << ")\n";
}
};
class PhiItem : public ActionItem {
std::string result_full;
LLVMType ty;
std::vector<std::pair<std::string, std::string>> values; // (val_i_full, label_i_full)
public:
PhiItem() = default;
void RecursivePrint(std::ostream &os) const {
os << result_full << " = phi ";
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
os << " ";
for (size_t i = 0; i < values.size(); i++) {
os << " [" << values[i].first << ", " << values[i].second << "]";
if (i != values.size() - 1) {
os << ", ";
}
}
os << "\n";
}
};
class SelectItem : public ActionItem {
std::string result_full;
std::string cond_full;
std::string true_val_full;
std::string false_val_full;
LLVMType ty;
public:
SelectItem() = default;
void RecursivePrint(std::ostream &os) const {
os << result_full << " = select i1 " << cond_full << ", ";
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
os << " " << true_val_full << ", ";
if (std::holds_alternative<LLVMIRIntType>(ty)) {
os << "i" << std::get<LLVMIRIntType>(ty).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(ty)) {
os << "ptr";
} else {
throw std::runtime_error("something strange happened");
}
os << false_val_full << "\n";
}
};
class FunctionDefItem : public LLVMIRItemBase {
friend class IRBuilder;
LLVMType return_type;
std::string func_name_raw;
std::vector<LLVMType> args;
std::vector<std::string> args_full_name;
std::vector<std::shared_ptr<BlockItem>> basic_blocks;
public:
FunctionDefItem() = default;
void RecursivePrint(std::ostream &os) const {
os << "define ";
if (std::holds_alternative<LLVMIRIntType>(return_type)) {
os << "i" << std::get<LLVMIRIntType>(return_type).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(return_type)) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(return_type)) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(return_type)) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
os << " @" << func_name_raw << "(";
for (size_t i = 0; i < args.size(); i++) {
if (std::holds_alternative<LLVMIRIntType>(args[i])) {
os << "i" << std::get<LLVMIRIntType>(args[i]).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(args[i])) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(args[i])) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(args[i])) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
os << ' ' << args_full_name[i];
if (i != args.size() - 1) {
os << ",";
}
}
os << ")\n{\n";
for (auto &item : basic_blocks) {
item->RecursivePrint(os);
}
os << "}\n";
}
};
class FunctionDeclareItem : public LLVMIRItemBase {
friend class IRBuilder;
friend std::shared_ptr<class ModuleItem> BuildIR(std::shared_ptr<Program_ASTNode> src);
LLVMType return_type;
std::string func_name_raw;
std::vector<LLVMType> args;
public:
FunctionDeclareItem() = default;
void RecursivePrint(std::ostream &os) const {
os << "declare ";
if (std::holds_alternative<LLVMIRIntType>(return_type)) {
os << "i" << std::get<LLVMIRIntType>(return_type).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(return_type)) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(return_type)) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(return_type)) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
os << " @" << func_name_raw << "(";
for (size_t i = 0; i < args.size(); i++) {
if (std::holds_alternative<LLVMIRIntType>(args[i])) {
os << "i" << std::get<LLVMIRIntType>(args[i]).bits;
} else if (std::holds_alternative<LLVMIRPTRType>(args[i])) {
os << "ptr";
} else if (std::holds_alternative<LLVMVOIDType>(args[i])) {
os << "void";
} else if (std::holds_alternative<LLVMIRCLASSTYPE>(args[i])) {
throw std::runtime_error("In MX* language, class types are referenced by pointers");
}
if (i != args.size() - 1) {
os << ",";
}
}
os << ")\n";
}
};
class ModuleItem : public LLVMIRItemBase {
friend class IRBuilder;
friend std::shared_ptr<ModuleItem> BuildIR(std::shared_ptr<Program_ASTNode> src);
std::vector<std::shared_ptr<FunctionDeclareItem>> function_declares;
std::vector<std::shared_ptr<TypeDefItem>> type_defs;
std::vector<std::shared_ptr<GlobalVarDefItem>> global_var_defs;
std::vector<std::shared_ptr<FunctionDefItem>> function_defs;
public:
ModuleItem() = default;
void RecursivePrint(std::ostream &os) const {
for (auto &item : function_declares) {
item->RecursivePrint(os);
}
for (auto &item : type_defs) {
item->RecursivePrint(os);
os << '\n';

View File

@ -10,6 +10,7 @@
class ASTNodeBase {
friend Visitor;
friend std::shared_ptr<Program_ASTNode> CheckAndDecorate(std::shared_ptr<Program_ASTNode> src);
friend std::shared_ptr<class ModuleItem> BuildIR(std::shared_ptr<Program_ASTNode> src);
protected:
std::shared_ptr<ScopeBase> current_scope;

View File

@ -8,11 +8,14 @@
class Expr_ASTNode : public ASTNodeBase {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
ExprTypeInfo expr_type_info;
bool assignable;
std::string IR_result_full;
bool is_requiring_lvalue;
public:
Expr_ASTNode() : assignable(false){};
Expr_ASTNode() : assignable(false), is_requiring_lvalue(false){};
virtual ~Expr_ASTNode() = default;
};
@ -21,6 +24,7 @@ class BasicExpr_ASTNode : public Expr_ASTNode {}; // This is a virtual class
class NewArrayExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
bool has_initial_value;
std::vector<std::shared_ptr<Expr_ASTNode>> dim_size;
std::shared_ptr<class ConstantExpr_ASTNode> initial_value;
@ -33,6 +37,7 @@ class NewArrayExpr_ASTNode : public Expr_ASTNode {
class NewConstructExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
public:
NewConstructExpr_ASTNode() = default;
@ -42,6 +47,7 @@ class NewConstructExpr_ASTNode : public Expr_ASTNode {
class NewExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
public:
NewExpr_ASTNode() = default;
@ -51,6 +57,7 @@ class NewExpr_ASTNode : public Expr_ASTNode {
class AccessExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> base;
IdentifierType member;
bool is_function;
@ -64,6 +71,7 @@ class AccessExpr_ASTNode : public Expr_ASTNode {
class IndexExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> base;
std::vector<std::shared_ptr<Expr_ASTNode>> indices;
@ -75,6 +83,7 @@ class IndexExpr_ASTNode : public Expr_ASTNode {
class SuffixExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> base;
@ -86,6 +95,7 @@ class SuffixExpr_ASTNode : public Expr_ASTNode {
class PrefixExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> base;
@ -97,6 +107,7 @@ class PrefixExpr_ASTNode : public Expr_ASTNode {
class OppositeExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> base;
public:
@ -107,6 +118,7 @@ class OppositeExpr_ASTNode : public Expr_ASTNode {
class LNotExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> base;
public:
@ -117,6 +129,7 @@ class LNotExpr_ASTNode : public Expr_ASTNode {
class BNotExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> base;
public:
@ -127,6 +140,7 @@ class BNotExpr_ASTNode : public Expr_ASTNode {
class MDMExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -139,6 +153,7 @@ class MDMExpr_ASTNode : public Expr_ASTNode {
class PMExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -151,6 +166,7 @@ class PMExpr_ASTNode : public Expr_ASTNode {
class RLExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -163,6 +179,7 @@ class RLExpr_ASTNode : public Expr_ASTNode {
class GGLLExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -175,6 +192,7 @@ class GGLLExpr_ASTNode : public Expr_ASTNode {
class NEExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -187,6 +205,7 @@ class NEExpr_ASTNode : public Expr_ASTNode {
class BAndExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -199,6 +218,7 @@ class BAndExpr_ASTNode : public Expr_ASTNode {
class BXorExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -211,6 +231,7 @@ class BXorExpr_ASTNode : public Expr_ASTNode {
class BOrExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -223,6 +244,7 @@ class BOrExpr_ASTNode : public Expr_ASTNode {
class LAndExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -235,6 +257,7 @@ class LAndExpr_ASTNode : public Expr_ASTNode {
class LOrExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> left;
std::shared_ptr<Expr_ASTNode> right;
@ -247,6 +270,7 @@ class LOrExpr_ASTNode : public Expr_ASTNode {
class TernaryExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> condition;
std::shared_ptr<Expr_ASTNode> src1;
std::shared_ptr<Expr_ASTNode> src2;
@ -259,6 +283,7 @@ class TernaryExpr_ASTNode : public Expr_ASTNode {
class AssignExpr_ASTNode : public Expr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::string op;
std::shared_ptr<Expr_ASTNode> dest;
std::shared_ptr<Expr_ASTNode> src;
@ -271,6 +296,7 @@ class AssignExpr_ASTNode : public Expr_ASTNode {
class ThisExpr_ASTNode : public BasicExpr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
public:
ThisExpr_ASTNode() = default;
@ -280,6 +306,7 @@ class ThisExpr_ASTNode : public BasicExpr_ASTNode {
class ParenExpr_ASTNode : public BasicExpr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> expr;
public:
@ -290,6 +317,7 @@ class ParenExpr_ASTNode : public BasicExpr_ASTNode {
class IDExpr_ASTNode : public BasicExpr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
IdentifierType id;
public:
@ -300,6 +328,7 @@ class IDExpr_ASTNode : public BasicExpr_ASTNode {
class FunctionCallExpr_ASTNode : public BasicExpr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
IdentifierType func_name;
std::vector<std::shared_ptr<Expr_ASTNode>> arguments;
@ -311,6 +340,7 @@ class FunctionCallExpr_ASTNode : public BasicExpr_ASTNode {
class FormattedStringExpr_ASTNode : public BasicExpr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::vector<std::string> literals;
std::vector<std::shared_ptr<Expr_ASTNode>> exprs;
@ -326,6 +356,7 @@ using AtomicConstantType = std::variant<uint32_t, bool, std::string, NullType>;
class ConstantExpr_ASTNode : public BasicExpr_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
size_t level;
std::variant<AtomicConstantType, std::vector<std::shared_ptr<ConstantExpr_ASTNode>>> value;

View File

@ -14,6 +14,7 @@ class ScopeBase {
friend class ClassDefScope;
friend class GlobalScope;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
protected:
ScopeBase *parent; // cannot use std::shared_ptr<ScopeBase> because of circular dependency
@ -21,15 +22,17 @@ class ScopeBase {
virtual bool VariableNameAvailable(const std::string &name, int ttl) = 0;
virtual bool add_variable(const std::string &name, const ExprTypeInfo &type) = 0;
virtual ExprTypeInfo fetch_varaible(const std::string &name) = 0;
virtual IRVariableInfo fetch_variable_for_IR(const std::string &name) = 0;
static inline bool IsKeyWord(const std::string &name) {
static const std::unordered_set<std::string> keywords = {"void", "bool", "int", "string", "new", "class",
"null", "true", "false", "this", "if", "else",
"for", "while", "break", "continue", "return"};
return keywords.find(name) != keywords.end();
}
public:
public:
ScopeBase() {
static size_t scope_counter=0;
static size_t scope_counter = 0;
scope_id = scope_counter++;
}
};
@ -56,6 +59,17 @@ class LocalScope : public ScopeBase {
}
return parent->fetch_varaible(name);
}
IRVariableInfo fetch_variable_for_IR(const std::string &name) override {
if (local_variables.find(name) != local_variables.end()) {
IRVariableInfo res;
res.variable_name_raw = name;
res.scope_id = scope_id;
res.variable_type = 1;
res.ty = Type_AST2LLVM(local_variables[name]);
return res;
}
return parent->fetch_variable_for_IR(name);
}
bool VariableNameAvailable(const std::string &name, int ttl) override {
if (ttl == 0 && IsKeyWord(name)) {
return false;
@ -90,6 +104,19 @@ class FunctionScope : public ScopeBase {
}
return parent->fetch_varaible(name);
}
IRVariableInfo fetch_variable_for_IR(const std::string &name) override {
for (const auto &arg : schema.arguments) {
if (arg.second == name) {
IRVariableInfo res;
res.variable_name_raw = name;
res.scope_id = scope_id;
res.variable_type = 3;
res.ty = Type_AST2LLVM(arg.first);
return res;
}
}
return parent->fetch_variable_for_IR(name);
}
bool VariableNameAvailable(const std::string &name, int ttl) override {
if (ttl == 0 && IsKeyWord(name)) {
return false;
@ -131,6 +158,17 @@ class ClassDefScope : public ScopeBase {
}
return parent->fetch_varaible(name);
}
IRVariableInfo fetch_variable_for_IR(const std::string &name) override {
if (member_variables.find(name) != member_variables.end()) {
IRVariableInfo res;
res.variable_name_raw = name;
res.scope_id = scope_id;
res.variable_type = 2;
res.ty = Type_AST2LLVM(member_variables[name]);
return res;
}
return parent->fetch_variable_for_IR(name);
}
bool add_function(const std::string &name, std::shared_ptr<FunctionScope> ptr) {
if (IsKeyWord(name)) return false;
if (member_variables.find(name) != member_variables.end()) {
@ -162,6 +200,7 @@ class ClassDefScope : public ScopeBase {
class GlobalScope : public ScopeBase {
friend class Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
friend std::shared_ptr<class Program_ASTNode> CheckAndDecorate(std::shared_ptr<class Program_ASTNode> src);
std::unordered_map<std::string, ExprTypeInfo> global_variables;
std::unordered_map<std::string, std::shared_ptr<FunctionScope>> global_functions;
@ -257,6 +296,17 @@ class GlobalScope : public ScopeBase {
}
throw SemanticError("Undefined Identifier", 1);
}
IRVariableInfo fetch_variable_for_IR(const std::string &name) override {
if (global_variables.find(name) != global_variables.end()) {
IRVariableInfo res;
res.variable_name_raw = name;
res.scope_id = scope_id;
res.variable_type = 0;
res.ty = Type_AST2LLVM(global_variables[name]);
return res;
}
return parent->fetch_variable_for_IR(name);
}
public:
GlobalScope() { parent = nullptr; }

View File

@ -11,6 +11,7 @@ class Statement_ASTNode : public ASTNodeBase {
class EmptyStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
public:
EmptyStatement_ASTNode() = default;
@ -19,6 +20,7 @@ class EmptyStatement_ASTNode : public Statement_ASTNode {
class DefinitionStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
ExprTypeInfo var_type;
std::vector<std::pair<IdentifierType, std::shared_ptr<Expr_ASTNode>>> vars;
@ -29,6 +31,7 @@ class DefinitionStatement_ASTNode : public Statement_ASTNode {
class ExprStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> expr;
public:
@ -38,6 +41,7 @@ class ExprStatement_ASTNode : public Statement_ASTNode {
class IfStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
bool has_else_clause;
std::shared_ptr<Expr_ASTNode> condition;
std::shared_ptr<Statement_ASTNode> if_clause;
@ -50,6 +54,7 @@ class IfStatement_ASTNode : public Statement_ASTNode {
class WhileStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Expr_ASTNode> condition;
std::shared_ptr<Statement_ASTNode> loop_body;
@ -60,6 +65,7 @@ class WhileStatement_ASTNode : public Statement_ASTNode {
class ForStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::shared_ptr<Statement_ASTNode> initial;
std::shared_ptr<Expr_ASTNode> condition;
std::shared_ptr<Statement_ASTNode> update;
@ -72,6 +78,7 @@ class ForStatement_ASTNode : public Statement_ASTNode {
class JmpStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
uint8_t jmp_type; // 0: return, 1: break, 2: continue
std::shared_ptr<Expr_ASTNode> return_value;
@ -82,6 +89,7 @@ class JmpStatement_ASTNode : public Statement_ASTNode {
class SuiteStatement_ASTNode : public Statement_ASTNode {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::vector<std::shared_ptr<Statement_ASTNode>> statements;
public:

View File

@ -10,6 +10,7 @@
class FuncDef_ASTNode : public ASTNodeBase {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
bool is_constructor;
IdentifierType func_name;
ExprTypeInfo return_type;
@ -23,6 +24,7 @@ class FuncDef_ASTNode : public ASTNodeBase {
class ClassDef_ASTNode : public ASTNodeBase {
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
private:
std::string class_name;
@ -36,8 +38,10 @@ class ClassDef_ASTNode : public ASTNodeBase {
virtual void accept(class ASTNodeVisitorBase *visitor) override;
};
class Program_ASTNode : public ASTNodeBase {
friend std::shared_ptr<class ModuleItem> BuildIR(std::shared_ptr<Program_ASTNode> src);
friend Visitor;
friend class ASTSemanticCheckVisitor;
friend class IRBuilder;
std::vector<std::shared_ptr<DefinitionStatement_ASTNode>> global_variables;
std::vector<std::shared_ptr<ClassDef_ASTNode>> classes;
std::vector<std::shared_ptr<FuncDef_ASTNode>> functions;

View File

@ -1,4 +1,5 @@
#pragma once
#include <bits/types/struct_sched_param.h>
#include <stdexcept>
#include <unordered_map>
#include <variant>
@ -101,10 +102,21 @@ inline bool operator==(const ExprTypeInfo &l, const ExprTypeInfo &r) {
throw std::runtime_error("something strange happened");
}
inline bool operator!=(const ExprTypeInfo &l, const ExprTypeInfo &r) { return !(l == r); }
class LLVMIRIntType {
public:
size_t bits;
LLVMIRIntType() = default;
LLVMIRIntType(size_t bits) : bits(bits) {}
};
struct LLVMIRPTRType {};
struct LLVMVOIDType {};
struct LLVMIRCLASSTYPE {
std::string class_name_full;
};
using LLVMType = std::variant<LLVMIRIntType, LLVMIRPTRType, LLVMVOIDType, LLVMIRCLASSTYPE>;
class IRClassInfo {
public:
std::string class_name; // This data must be provided by user
std::string class_name_raw; // This data must be provided by user
std::vector<size_t> member_var_size; // This data must be provided by user. Each of them is the size of a member
// variable, which must be in [1,4]
std::unordered_map<std::string, size_t> member_var_offset; // This data must be provided by user
@ -128,10 +140,36 @@ class IRClassInfo {
class_size_after_align = cur_pos;
}
}
std::string GenerateFullName() { return "%.class." + class_name_raw; }
};
class IRVariableInfo {
public:
enum class VariableType { global_variable, local_variable, member_variable };
std::string class_name;
std::string variable_name;
};
std::string variable_name_raw;
size_t scope_id;
uint8_t variable_type; // 0: global, 1: local, 2: member, 3: argument
LLVMType ty;
std::string GenerateFullName() {
if (variable_type == 2) {
throw std::runtime_error("Member variable should not be used in this function");
} else if (variable_type == 0) {
return "@.var.global." + variable_name_raw + ".addrkp";
} else if (variable_type == 1) {
return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".addrkp";
} else if (variable_type == 3) {
return "%.var.local." + std::to_string(scope_id) + "." + variable_name_raw + ".val";
} else {
throw std::runtime_error("Invalid scope id");
}
}
};
inline LLVMType Type_AST2LLVM(const ExprTypeInfo &src) {
if (std::holds_alternative<ArrayType>(src)) return LLVMIRPTRType();
std::string tname = std::get<IdentifierType>(src);
if (tname == "bool") return LLVMIRIntType(1);
if (tname == "int") return LLVMIRIntType(32);
if (tname == "void") return LLVMVOIDType();
return LLVMIRPTRType();
}