diff --git a/demo/alu.cpp b/demo/alu.cpp index 08c4f00..563962d 100644 --- a/demo/alu.cpp +++ b/demo/alu.cpp @@ -1,131 +1,111 @@ -#include "../include/tools.h" +#include "tools.h" #include - -bool reset; -bool ready; +#include // RISC-V enum class Opcode : dark::max_size_t { - ADD, - SUB, - SLL, - SRL, - SRA, - AND, - OR, - XOR, - SLT, - SLTU, - SGE, - SGEU, - SEQ, - SNEQ + ADD, + SUB, + SLL, + SRL, + SRA, + AND, + OR, + XOR, + SLT, + SLTU, + SGE, + SGEU, + SEQ, + SNEQ }; - // Normally, only wire can be used in the input. struct AluInput { - Wire <8> opcode; - Wire <1> issue; - Wire <32> rs1; - Wire <32> rs2; + Wire<8> opcode; + Wire<1> issue; + Wire<32> rs1; + Wire<32> rs2; }; struct AluOutput { - Register <32> out; - Register <1> done; + Register<32> out; + Register<1> done; }; -struct AluModule : AluInput, AluOutput { - using Tags = SyncTags; - - void work() { - if (reset) { - done <= 0; - } else if(ready && issue) { - switch (static_cast (static_cast (opcode))) { - using enum Opcode; - case ADD: out <= (rs1 + rs2); break; - case SUB: out <= (rs1 - rs2); break; - case SLL: out <= (rs1 << rs2); break; - case SRL: out <= (rs1 >> rs2); break; - case SRA: out <= (to_signed(rs1) >> to_unsigned(rs2)); - case AND: out <= (rs1 & rs2); break; - case OR: out <= (rs1 | rs2); break; - case XOR: out <= (rs1 ^ rs2); break; - case SLT: out <= (to_signed(rs1) < to_signed(rs2)); break; - case SLTU: out <= (rs1 < rs2); break; - case SGE: out <= (to_signed(rs1) >= to_signed(rs2)); break; - case SGEU: out <= (rs1 >= rs2); break; - case SEQ: out <= (rs1 == rs2); break; - case SNEQ: out <= (rs1 != rs2); break; - default: dark::debug::assert(false, "Invalid opcode"); - } - done <= 1; - } else { - done <= 0; - } - } - +struct AluModule : dark::Module { + void work() override { + if (issue) { + switch (static_cast(static_cast(opcode))) { + using enum Opcode; + case ADD: out <= (rs1 + rs2); break; + case SUB: out <= (rs1 - rs2); break; + case SLL: out <= (rs1 << rs2); break; + case SRL: out <= (rs1 >> rs2); break; + case SRA: out <= (to_signed(rs1) >> to_unsigned(rs2)); + case AND: out <= (rs1 & rs2); break; + case OR: out <= (rs1 | rs2); break; + case XOR: out <= (rs1 ^ rs2); break; + case SLT: out <= (to_signed(rs1) < to_signed(rs2)); break; + case SLTU: out <= (rs1 < rs2); break; + case SGE: out <= (to_signed(rs1) >= to_signed(rs2)); break; + case SGEU: out <= (rs1 >= rs2); break; + case SEQ: out <= (rs1 == rs2); break; + case SNEQ: out <= (rs1 != rs2); break; + default: dark::debug::assert(false, "Invalid opcode"); + } + done <= 1; + } + else { + done <= 0; + } + } }; -signed main() { - AluModule alu; +int main() { + std::string opstring; - std::string opstring; + max_size_t opcode; + max_size_t issue; + max_size_t rs1; + max_size_t rs2; - max_size_t opcode; - max_size_t issue; - max_size_t rs1; - max_size_t rs2; - - ready = 1; - - alu.opcode = [&]() { return opcode; }; - alu.issue = [&]() { return issue; }; - alu.rs1 = [&]() { return rs1; }; - alu.rs2 = [&]() { return rs2; }; - - while (std::cin >> opstring) { - issue = 1; - std::cin >> rs1 >> rs2; - if (opstring == "add") { - opcode = static_cast (Opcode::ADD); - } else if (opstring == "sub") { - opcode = static_cast (Opcode::SUB); - } else if (opstring == "sll") { - opcode = static_cast (Opcode::SLL); - } else if (opstring == "srl") { - opcode = static_cast (Opcode::SRL); - } else if (opstring == "sra") { - opcode = static_cast (Opcode::SRA); - } else if (opstring == "and") { - opcode = static_cast (Opcode::AND); - } else if (opstring == "or") { - opcode = static_cast (Opcode::OR); - } else if (opstring == "xor") { - opcode = static_cast (Opcode::XOR); - } else if (opstring == "slt") { - opcode = static_cast (Opcode::SLT); - } else if (opstring == "sltu") { - opcode = static_cast (Opcode::SLTU); - } else if (opstring == "sge") { - opcode = static_cast (Opcode::SGE); - } else if (opstring == "sgeu") { - opcode = static_cast (Opcode::SGEU); - } else if (opstring == "seq") { - opcode = static_cast (Opcode::SEQ); - } else if (opstring == "sneq") { - opcode = static_cast (Opcode::SNEQ); - } else { - std::cout << "Invalid opcode" << std::endl; - issue = 0; - } - - alu.work(); - - std::cout << to_unsigned(alu.out) << std::endl; - sync_member(alu); - } + dark::CPU cpu; + AluModule alu; + alu.opcode = [&]() { return opcode; }; + alu.issue = [&]() { return issue; }; + alu.rs1 = [&]() { return rs1; }; + alu.rs2 = [&]() { return rs2; }; + cpu.add_module(&alu); + std::unordered_map cmd2op = { + {"add", Opcode::ADD}, + {"sub", Opcode::SUB}, + {"sll", Opcode::SLL}, + {"src", Opcode::SRL}, + {"sra", Opcode::SRA}, + {"and", Opcode::AND}, + {"or", Opcode::OR}, + {"xor", Opcode::XOR}, + {"slt", Opcode::SLT}, + {"sltu", Opcode::SLTU}, + {"sge", Opcode::SGE}, + {"sgeu", Opcode::SGEU}, + {"seq", Opcode::SEQ}, + {"sneq", Opcode::SNEQ}}; + while (std::cin >> opstring) { + if (cmd2op.find(opstring) == cmd2op.end()) { + std::cout << "Invalid opcode" << std::endl; + issue = 0; + } + else { + issue = 1; + std::cin >> rs1 >> rs2; + } + opcode = static_cast(cmd2op[opstring]); + cpu.run_once(); + std::cout << "out: " << static_cast(alu.out) << std::endl; + std::cout << "done: " << static_cast(alu.done) << std::endl; + } + return 0; } \ No newline at end of file diff --git a/include/concept.h b/include/concept.h index 6e1ef93..149cdb4 100644 --- a/include/concept.h +++ b/include/concept.h @@ -43,7 +43,9 @@ concept int_type = !has_length<_Tp> && implicit_convertible_to<_Tp, max_size_t>; template concept bit_match = - (bit_type<_Lhs> && bit_type<_Rhs> && _Lhs::_Bit_Len == _Rhs::_Bit_Len) || (int_type<_Lhs> || int_type<_Rhs>); + (bit_type<_Lhs> && bit_type<_Rhs> && _Lhs::_Bit_Len == _Rhs::_Bit_Len) // prevent format + || (int_type<_Lhs> && bit_type<_Rhs>) // + || (bit_type<_Lhs> && int_type<_Rhs>); template concept bit_convertible = diff --git a/include/cpu.h b/include/cpu.h index 7d63df6..732d4b8 100644 --- a/include/cpu.h +++ b/include/cpu.h @@ -1,4 +1,6 @@ +#pragma once #include "module.h" +#include #include #include #include @@ -7,7 +9,8 @@ namespace dark { class CPU { private: - std::vector> modules; + std::vector> mod_owned; + std::vector modules; public: unsigned long long cycles = 0; @@ -19,9 +22,21 @@ private: } public: - void add_module(std::unique_ptr module) { - modules.push_back(std::move(module)); + /// @attention the pointer will be moved. you SHOULD NOT use it after calling this function. + template + requires std::derived_from<_Tp, ModuleBase> + void add_module(std::unique_ptr<_Tp> &module) { + modules.push_back(module.get()); + mod_owned.emplace_back(std::move(module)); } + void add_module(std::unique_ptr module) { + modules.push_back(module.get()); + mod_owned.emplace_back(std::move(module)); + } + void add_module(ModuleBase *module) { + modules.push_back(module); + } + void run_once() { ++cycles; for (auto &module: modules) @@ -30,10 +45,7 @@ public: } void run_once_shuffle() { static std::default_random_engine engine; - std::vector shuffled; - shuffled.reserve(modules.size()); - for (auto &module: modules) - shuffled.push_back(module.get()); + std::vector shuffled = modules; std::shuffle(shuffled.begin(), shuffled.end(), engine); ++cycles; diff --git a/include/module.h b/include/module.h index a9a7a46..6ed7ea5 100644 --- a/include/module.h +++ b/include/module.h @@ -1,13 +1,21 @@ +#pragma once #include "synchronize.h" namespace dark { +namespace details { + struct empty_class { + void sync() { /* do nothing */ } + }; +} // namespace details + struct ModuleBase { virtual void work() = 0; virtual void sync() = 0; virtual ~ModuleBase() = default; }; -template +template + requires std::is_aggregate_v<_Tinput> && std::is_aggregate_v<_Toutput> && std::is_aggregate_v<_Tprivate> struct Module : public ModuleBase, public _Tinput, public _Toutput, protected _Tprivate { void sync() override final { sync_member(static_cast<_Tinput &>(*this)); diff --git a/include/register.h b/include/register.h index 136bf31..aa27df5 100644 --- a/include/register.h +++ b/include/register.h @@ -41,6 +41,7 @@ public: } explicit operator max_size_t() const { return this->_M_old; } + explicit operator bool() const { return this->_M_old; } }; } // namespace dark diff --git a/include/tools.h b/include/tools.h index a757337..3c10d63 100644 --- a/include/tools.h +++ b/include/tools.h @@ -5,6 +5,8 @@ #include "register.h" #include "synchronize.h" #include "wire.h" +#include "module.h" +#include "cpu.h" using dark::Bit; using dark::sign_extend; diff --git a/include/wire.h b/include/wire.h index 92dc7cc..5367df4 100644 --- a/include/wire.h +++ b/include/wire.h @@ -107,6 +107,10 @@ public: this->_M_func.reset(_M_new_func(std::forward<_Fn>(fn))); this->sync(); } + + explicit operator bool() const { + return static_cast(*this); + } };