diff --git a/include/IR/IR_basic.h b/include/IR/IR_basic.h index 9bfcd26..3edfa0a 100644 --- a/include/IR/IR_basic.h +++ b/include/IR/IR_basic.h @@ -336,6 +336,7 @@ class PhiItem : public ActionItem { } }; class SelectItem : public ActionItem { + friend class IRBuilder; std::string result_full; std::string cond_full; std::string true_val_full; @@ -449,6 +450,7 @@ class FunctionDeclareItem : public LLVMIRItemBase { } }; class ConstStrItem : public LLVMIRItemBase { + friend std::shared_ptr BuildIR(std::shared_ptr src); friend class IRBuilder; std::string string_raw; size_t const_str_id; diff --git a/include/tools.h b/include/tools.h index ea3f379..df51a49 100644 --- a/include/tools.h +++ b/include/tools.h @@ -181,7 +181,7 @@ inline LLVMType Type_AST2LLVM(const ExprTypeInfo &src) { return LLVMIRPTRType(); } -inline std::string StringLiteralDeEscape(const std::string src) { +inline std::string StringLiteralDeEscape(const std::string &src) { std::stringstream ss; for (size_t i = 1; i < src.size() - 1; i++) { if (src[i] != '\\') @@ -207,4 +207,37 @@ inline std::string StringLiteralDeEscape(const std::string src) { } } return ss.str(); +} + +inline std::string FmtStrLiteralDeEscape(const std::string &src) { + std::stringstream ss; + for (size_t i = 0; i < src.size(); i++) { + if (src[i] == '\\') { + i++; + if (src[i] == 'n') + ss << '\n'; + else if (src[i] == 'r') + ss << '\r'; + else if (src[i] == 't') + ss << '\t'; + else if (src[i] == '\\') + ss << '\\'; + else if (src[i] == '\'') + ss << '\''; + else if (src[i] == '\"') + ss << '\"'; + else if (src[i] == '0') + ss << '\0'; + else + throw std::runtime_error("Invalid escape character"); + } else if (src[i] == '$') { + i++; + if (src[i] == '$') + ss << '$'; + else + throw std::runtime_error("Invalid escape character"); + } else + ss << src[i]; + } + return ss.str(); } \ No newline at end of file diff --git a/src/IR/IRBuilder.cpp b/src/IR/IRBuilder.cpp index 721fabc..2e4e1a2 100644 --- a/src/IR/IRBuilder.cpp +++ b/src/IR/IRBuilder.cpp @@ -798,6 +798,37 @@ void IRBuilder::ActuralVisit(RLExpr_ASTNode *node) { void IRBuilder::ActuralVisit(GGLLExpr_ASTNode *node) { node->left->accept(this); node->right->accept(this); + ExprTypeInfo string_std = IdentifierType("string"); + if (node->left->expr_type_info == string_std) { + auto strcmp_call = std::make_shared(); + cur_block->actions.push_back(strcmp_call); + strcmp_call->func_name_raw = "strcmp"; + strcmp_call->return_type = LLVMIRIntType(32); + strcmp_call->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + strcmp_call->args_ty.push_back(LLVMIRPTRType()); + strcmp_call->args_val_full.push_back(node->left->IR_result_full); + strcmp_call->args_ty.push_back(LLVMIRPTRType()); + strcmp_call->args_val_full.push_back(node->right->IR_result_full); + auto cmp_act = std::make_shared(); + cur_block->actions.push_back(cmp_act); + if (node->op == ">=") { + cmp_act->op = "sge"; + } else if (node->op == ">") { + cmp_act->op = "sgt"; + } else if (node->op == "<=") { + cmp_act->op = "sle"; + } else if (node->op == "<") { + cmp_act->op = "slt"; + } else { + throw std::runtime_error("unknown GGLL operator"); + } + cmp_act->operand1_full = strcmp_call->result_full; + cmp_act->operand2_full = "0"; + cmp_act->type = LLVMIRIntType(32); + cmp_act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = cmp_act->result_full; + return; + } auto act = std::make_shared(); cur_block->actions.push_back(act); if (node->op == ">=") { @@ -816,12 +847,36 @@ void IRBuilder::ActuralVisit(GGLLExpr_ASTNode *node) { act->type = Type_AST2LLVM(node->left->expr_type_info); act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); node->IR_result_full = act->result_full; - // TODO: string comparison } void IRBuilder::ActuralVisit(NEExpr_ASTNode *node) { node->left->accept(this); node->right->accept(this); + ExprTypeInfo string_std = IdentifierType("string"); + if (node->left->expr_type_info == string_std) { + auto strcmp_call = std::make_shared(); + cur_block->actions.push_back(strcmp_call); + strcmp_call->func_name_raw = "strcmp"; + strcmp_call->return_type = LLVMIRIntType(32); + strcmp_call->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + strcmp_call->args_ty.push_back(LLVMIRPTRType()); + strcmp_call->args_val_full.push_back(node->left->IR_result_full); + strcmp_call->args_ty.push_back(LLVMIRPTRType()); + strcmp_call->args_val_full.push_back(node->right->IR_result_full); + auto cmp_act = std::make_shared(); + cur_block->actions.push_back(cmp_act); + if (node->op == "!=") { + cmp_act->op = "ne"; + } else { + cmp_act->op = "eq"; + } + cmp_act->operand1_full = strcmp_call->result_full; + cmp_act->operand2_full = "0"; + cmp_act->type = LLVMIRIntType(32); + cmp_act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + node->IR_result_full = cmp_act->result_full; + return; + } auto act = std::make_shared(); cur_block->actions.push_back(act); if (node->op == "!=") { @@ -836,7 +891,6 @@ void IRBuilder::ActuralVisit(NEExpr_ASTNode *node) { act->type = Type_AST2LLVM(node->left->expr_type_info); act->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); node->IR_result_full = act->result_full; - // TODO: string comparison } void IRBuilder::ActuralVisit(BAndExpr_ASTNode *node) { @@ -1123,8 +1177,81 @@ void IRBuilder::ActuralVisit(FunctionCallExpr_ASTNode *node) { } void IRBuilder::ActuralVisit(FormattedStringExpr_ASTNode *node) { - // TODO: Implement function body - throw std::runtime_error("formatted string not supported"); + for (size_t i = 0; i < node->literals.size(); i++) { + std::string str; + if (i == 0) + str = node->literals[i].substr(2, node->literals[i].size() - 3); + else + str = node->literals[i].substr(1, node->literals[i].size() - 2); + str = FmtStrLiteralDeEscape(str); + if (const_str_dict.find(str) == const_str_dict.end()) { + const_str_dict[str] = const_str_counter++; + auto const_str_item = std::make_shared(); + prog->const_strs.push_back(const_str_item); + const_str_item->string_raw = str; + const_str_item->const_str_id = const_str_dict[str]; + } + } + ExprTypeInfo string_std = IdentifierType("string"); + ExprTypeInfo int_std = IdentifierType("int"); + ExprTypeInfo bool_std = IdentifierType("bool"); + std::string res = + "@.str." + + std::to_string(const_str_dict[FmtStrLiteralDeEscape(node->literals[0].substr(2, node->literals[0].size() - 3))]); + if (node->exprs.size() + 1 != node->literals.size()) throw std::runtime_error("formatted string error"); + for (size_t i = 0; i < node->exprs.size(); i++) { + node->exprs[i]->accept(this); + std::string str_res = node->exprs[i]->IR_result_full; + if (node->exprs[i]->expr_type_info == string_std) { + ; // just do nothing + } else if (node->exprs[i]->expr_type_info == int_std) { + auto call = std::make_shared(); + cur_block->actions.push_back(call); + call->func_name_raw = "toString"; + call->return_type = LLVMIRPTRType(); + call->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + call->args_ty.push_back(LLVMIRIntType(32)); + call->args_val_full.push_back(str_res); + str_res = call->result_full; + } else if (node->exprs[i]->expr_type_info == bool_std) { + auto select = std::make_shared(); + cur_block->actions.push_back(select); + select->result_full = "%.var.tmp." + std::to_string(tmp_var_counter++); + select->ty = LLVMIRPTRType(); + select->cond_full = str_res; + select->true_val_full = "@.str." + std::to_string(const_str_dict["true"]); + select->false_val_full = "@.str." + std::to_string(const_str_dict["false"]); + str_res = select->result_full; + } else { + throw std::runtime_error("formatted string error"); + } + std::string tmp = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto strcat1 = std::make_shared(); + cur_block->actions.push_back(strcat1); + strcat1->func_name_raw = ".builtin.strcat"; + strcat1->return_type = LLVMIRPTRType(); + strcat1->result_full = tmp; + strcat1->args_ty.push_back(LLVMIRPTRType()); + strcat1->args_val_full.push_back(res); + strcat1->args_ty.push_back(LLVMIRPTRType()); + strcat1->args_val_full.push_back(str_res); + res = tmp; + tmp = "%.var.tmp." + std::to_string(tmp_var_counter++); + auto strcat2 = std::make_shared(); + cur_block->actions.push_back(strcat2); + strcat2->func_name_raw = ".builtin.strcat"; + strcat2->return_type = LLVMIRPTRType(); + strcat2->result_full = tmp; + strcat2->args_ty.push_back(LLVMIRPTRType()); + strcat2->args_val_full.push_back(res); + strcat2->args_ty.push_back(LLVMIRPTRType()); + strcat2->args_val_full.push_back( + "@.str." + + std::to_string( + const_str_dict[FmtStrLiteralDeEscape(node->literals[i + 1].substr(1, node->literals[i + 1].size() - 2))])); + res = tmp; + } + node->IR_result_full = res; } void IRBuilder::ActuralVisit(ConstantExpr_ASTNode *node) { @@ -1273,6 +1400,25 @@ std::shared_ptr BuildIR(std::shared_ptr src) { tmp->args.push_back(LLVMIRPTRType()); visitor.prog->function_declares.push_back(tmp); + tmp = std::make_shared(); // int strcmp(const char *str1, const char *str2); + tmp->func_name_raw = "strcmp"; + tmp->return_type = LLVMIRIntType(32); + tmp->args.push_back(LLVMIRPTRType()); + tmp->args.push_back(LLVMIRPTRType()); + visitor.prog->function_declares.push_back(tmp); + + visitor.const_str_dict["true"] = visitor.const_str_counter++; + auto const_str_item = std::make_shared(); + visitor.prog->const_strs.push_back(const_str_item); + const_str_item->string_raw = "true"; + const_str_item->const_str_id = visitor.const_str_dict["true"]; + + visitor.const_str_dict["false"] = visitor.const_str_counter++; + const_str_item = std::make_shared(); + visitor.prog->const_strs.push_back(const_str_item); + const_str_item->string_raw = "false"; + const_str_item->const_str_id = visitor.const_str_dict["false"]; + visitor.global_scope = std::dynamic_pointer_cast(src->current_scope); if (!(visitor.global_scope)) throw std::runtime_error("global scope not found"); visitor.visit(src.get());