finish basic design for BPT

This commit is contained in:
2024-04-27 12:52:45 +00:00
parent 460245ff5e
commit 3d7b616dc7
7 changed files with 192 additions and 52 deletions

1
.gitignore vendored
View File

@ -3,4 +3,5 @@
/.github
/.cache
/build
/mbuild
/.clang-format

View File

@ -1,25 +1,99 @@
#ifndef BPT_HPP
#define BPT_HPP
#include <algorithm>
#include <shared_mutex>
#include <vector>
#include "bpt/bpt_page.hpp"
#include "bpt/buffer_pool_manager.h"
#include "bpt/config.h"
/**
* @brief B+ Tree Indexer
* @warning KeyType must be storable byte by byte. Since this class is only the indexer, the value type is always
* b_plus_tree_value_index_t.
*/
template <typename KeyType, typename KeyComparator>
class BPlusTreeIndexer {
private:
// TODO : insert ?
public:
typedef BPlusTreePage<KeyType> PageType;
typedef std::pair<KeyType, default_numeric_index_t> key_index_pair_t;
typedef std::pair<KeyType, b_plus_tree_value_index_t> value_type;
private:
// Result of a search performed by FindPosition.
struct PositionSignType {
// Guards of the pages pushed during the search, each paired with the lower_bound slot computed inside that page.
// The last element describes the page where the search stopped.
std::vector<std::pair<BasicPageGuard, in_page_key_count_t>> path;
// Set by FindPosition when the tree has no root page, or when the slot in the last page equals its key_count.
bool is_end{false};
};
/**
 * @brief Descend from the root towards the page that should hold `key`, recording every page visited.
 *
 * For each visited page the slot returned by std::lower_bound over its keys is stored next to the
 * page guard in `path`. The descent continues while the current page's page_status is non-zero.
 *
 * @param key key to locate.
 * @return The visited path. `is_end` is true when the tree has no root page (path is then empty)
 *         or when the slot chosen in the last page equals that page's key_count.
 */
PositionSignType FindPosition(const KeyType &key) {  // Finish Design
  if (root_page_id == 0) {
    // No root page: nothing to descend into, report "end" with an empty path.
    return PositionSignType{.is_end = true};
  }
  static auto comparer_for_key_index_pair = [](const key_index_pair_t &a, const KeyType &b) {
    return key_cmp(a.first, b);
  };
  // Slot of the first entry in `page` that is not ordered before `key` (key_count when there is none).
  auto slot_in_page = [&key](const PageType *page) -> in_page_key_count_t {
    const auto *first = page->data.p_data;
    return static_cast<in_page_key_count_t>(
        std::lower_bound(first, first + page->data.key_count, key, comparer_for_key_index_pair) - first);
  };
  PositionSignType res;
  BasicPageGuard current_page_guard(bpm->FetchPageBasic(root_page_id));
  in_page_key_count_t nxt = slot_in_page(current_page_guard.As<PageType>());
  // The root must be on the path before the loop: every later step reads res.path.back().
  res.path.push_back(std::make_pair(std::move(current_page_guard), nxt));
  // NOTE(review): the loop stops only at page_status == 0; confirm which status a root that is
  // also the only page carries, otherwise a single-page tree would try to descend further.
  while (res.path.back().first.template As<PageType>()->data.page_status != 0) {
    const PageType *page = res.path.back().first.template As<PageType>();
    const in_page_key_count_t internal_id = res.path.back().second;
    // Slots below key_count use the child stored next to the key; otherwise follow p_n.
    const default_numeric_index_t nxt_page_id =
        internal_id < page->data.key_count ? page->data.p_data[internal_id].second : page->data.p_n;
    BasicPageGuard next_page_guard(bpm->FetchPageBasic(nxt_page_id));
    nxt = slot_in_page(next_page_guard.As<PageType>());
    res.path.push_back(std::make_pair(std::move(next_page_guard), nxt));
  }
  if (nxt == res.path.back().first.template As<PageType>()->data.key_count) res.is_end = true;
  return res;
}
// Insert (key, value) at the position described by `pos`. Put calls this when the key was not found.
// Not implemented yet: the body is empty.
void InsertEntryAt(PositionSignType &pos, const KeyType &key, b_plus_tree_value_index_t value) {
// TODO
}
// Remove the entry at the position described by `pos`. Remove calls this after confirming the key matches.
// Not implemented yet: the body is empty.
void RemoveEntryAt(PositionSignType &pos) {
// TODO
}
public:
// note that for safety, the iterator is not copyable, and the const_iterator is not copyable
class iterator {
BPlusTreeIndexer *domain;
size_t internal_offset;
bool is_end;
WritePageGuard guard;
const KeyType &GetKey() const {
// TODO
friend class BPlusTreeIndexer;
public:
// Key stored in the slot of the guarded page that this iterator points at.
const KeyType &GetKey() const {
  const PageType *page = guard.As<PageType>();
  return page->data.p_data[internal_offset].first;
}
// Value index stored in the slot of the guarded page that this iterator points at.
const b_plus_tree_value_index_t &GetValue() {
  const PageType *page = guard.As<PageType>();
  return page->data.p_data[internal_offset].second;
}
// Two iterators are equal when they belong to the same tree and either both are end iterators
// or both point at the same slot of the same page.
// End iterators are compared by the flag alone: the one built by end() carries no page and no
// meaningful offset, so its guard and internal_offset must not be inspected.
// NOTE(review): the parameter is a non-const lvalue reference, so a temporary such as `end()`
// cannot be passed directly; widening it to `const iterator &` needs PageId() to be const — confirm.
bool operator==(iterator &that) {
  if (domain != that.domain) return false;
  if (is_end || that.is_end) return is_end == that.is_end;
  return guard.PageId() == that.guard.PageId() && internal_offset == that.internal_offset;
}
const b_plus_tree_value_index_t &GetValue() const {
// TODO
// Overwrite the value index of the entry this iterator points at.
// Does nothing on an end iterator: there is no entry there (the slot is past the last key, or no
// page is held at all), so writing would corrupt the page or dereference an empty guard.
void SetValue(b_plus_tree_value_index_t new_value) {
  if (is_end) return;
  PageType *page = guard.AsMut<PageType>();
  page->data.p_data[internal_offset].second = new_value;
}
// only support ++it
// Pre-increment: step to the next entry, moving to the page named by p_n when the current page
// is exhausted. A p_n of 0 is treated as "no following page" and turns this into an end iterator.
// Incrementing an end iterator is a no-op.
iterator &operator++() {
  if (!is_end) {
    ++internal_offset;
    const bool page_exhausted = (internal_offset == guard.As<PageType>()->data.key_count);
    if (page_exhausted) {
      const default_numeric_index_t following_page_id = guard.As<PageType>()->data.p_n;
      if (following_page_id != 0) {
        guard = domain->bpm->FetchPageWrite(following_page_id);
        internal_offset = 0;
      } else {
        is_end = true;
      }
    }
  }
  return *this;
}
};
class const_iterator {
@ -27,11 +101,29 @@ class BPlusTreeIndexer {
size_t internal_offset;
bool is_end;
ReadPageGuard guard;
const KeyType &GetKey() const {
// TODO
friend class BPlusTreeIndexer;
public:
// Key stored in the slot of the guarded page that this iterator points at.
const KeyType &GetKey() {
  const PageType *page = guard.As<PageType>();
  return page->data.p_data[internal_offset].first;
}
// Value index stored in the slot of the guarded page that this iterator points at.
const b_plus_tree_value_index_t &GetValue() {
  const PageType *page = guard.As<PageType>();
  return page->data.p_data[internal_offset].second;
}
// Two iterators are equal when they belong to the same tree and either both are end iterators
// or both point at the same slot of the same page.
// End iterators are compared by the flag alone: the one built by end_const() carries no page and
// no meaningful offset, so its guard and internal_offset must not be inspected.
// NOTE(review): the parameter is a non-const lvalue reference, so a temporary such as
// `end_const()` cannot be passed directly; widening it to `const const_iterator &` needs
// PageId() to be const — confirm.
bool operator==(const_iterator &that) {
  if (domain != that.domain) return false;
  if (is_end || that.is_end) return is_end == that.is_end;
  return guard.PageId() == that.guard.PageId() && internal_offset == that.internal_offset;
}
const b_plus_tree_value_index_t &GetValue() const {
// TODO
// only support ++it
// Pre-increment: step to the next entry, moving to the page named by p_n when the current page
// is exhausted. A p_n of 0 is treated as "no following page" and turns this into an end iterator.
// Incrementing an end iterator is a no-op.
const_iterator &operator++() {
  if (!is_end) {
    ++internal_offset;
    const bool page_exhausted = (internal_offset == guard.As<PageType>()->data.key_count);
    if (page_exhausted) {
      const default_numeric_index_t following_page_id = guard.As<PageType>()->data.p_n;
      if (following_page_id != 0) {
        guard = domain->bpm->FetchPageRead(following_page_id);
        internal_offset = 0;
      } else {
        is_end = true;
      }
    }
  }
  return *this;
}
};
BPlusTreeIndexer() = delete;
@ -39,61 +131,84 @@ class BPlusTreeIndexer {
BPlusTreeIndexer(BPlusTreeIndexer &&) = delete;
BPlusTreeIndexer &operator=(const BPlusTreeIndexer &) = delete;
BPlusTreeIndexer &operator=(BPlusTreeIndexer &&) = delete;
iterator end() {
// TODO
// Attach the indexer to a buffer pool and restore its header from the pool's raw data area:
// root_page_id is read from offset 0 and siz from the bytes immediately after it.
BPlusTreeIndexer(BufferPoolManager *bpm_) : bpm(bpm_), raw_data_memory(bpm_->RawDataMemory()) {
  const char *header = raw_data_memory;
  memcpy(&root_page_id, header, sizeof(page_id_t));
  header += sizeof(page_id_t);
  memcpy(&siz, header, sizeof(bpt_size_t));
}
iterator lower_bound(const KeyType &key) {
// Build the past-the-end iterator of this tree. It holds no page.
iterator end() {  // Finish Design
  iterator res;
  res.domain = this;
  // Give the offset a definite value: a default-initialized size_t is indeterminate, and the
  // iterator is copied/moved on return and may be compared afterwards.
  res.internal_offset = 0;
  res.is_end = true;
  return res;
}
// Build the past-the-end const_iterator of this tree. It holds no page.
const_iterator end_const() {  // Finish Design
  const_iterator res;
  res.domain = this;
  // Give the offset a definite value: a default-initialized size_t is indeterminate, and the
  // iterator is copied/moved on return and may be compared afterwards.
  res.internal_offset = 0;
  res.is_end = true;
  return res;
}
/**
 * @brief Locate the position FindPosition reports for `key` and return a writable iterator to it.
 *
 * The shared latch is held for the duration of the lookup. The returned iterator holds a write
 * guard on the page where the search stopped.
 *
 * @param key key to search for.
 * @return Iterator at the found slot; an end iterator when FindPosition reports `is_end`,
 *         including the empty-tree case where no page exists.
 */
iterator lower_bound(const KeyType &key) {  // Finish Design
  std::shared_lock<std::shared_mutex> guard(latch);
  PositionSignType pos = FindPosition(key);
  if (pos.path.empty()) {
    // Empty tree: there is no page to guard, and pos.path.back() would be undefined behaviour.
    return end();
  }
  iterator res;
  res.domain = this;
  res.guard = bpm->FetchPageWrite(pos.path.back().first.PageId());
  res.is_end = pos.is_end;
  res.internal_offset = pos.path.back().second;
  return res;
}
const_iterator lower_bound_const(const KeyType &key) {
/**
 * @brief Locate the position FindPosition reports for `key` and return a read-only iterator to it.
 *
 * The shared latch is held for the duration of the lookup. The returned iterator holds a read
 * guard on the page where the search stopped.
 *
 * @param key key to search for.
 * @return Iterator at the found slot; an end iterator when FindPosition reports `is_end`,
 *         including the empty-tree case where no page exists.
 */
const_iterator lower_bound_const(const KeyType &key) {  // Finish Design
  std::shared_lock<std::shared_mutex> guard(latch);
  PositionSignType pos = FindPosition(key);
  if (pos.path.empty()) {
    // Empty tree: there is no page to guard, and pos.path.back() would be undefined behaviour.
    return end_const();
  }
  const_iterator res;
  res.domain = this;
  res.guard = bpm->FetchPageRead(pos.path.back().first.PageId());
  res.is_end = pos.is_end;
  res.internal_offset = pos.path.back().second;
  return res;
}
// Presumably overwrites the value at `iter`; takes the exclusive latch. Not implemented yet.
// NOTE(review): the body has no return statement, so calling this non-void function is undefined
// behaviour until it is implemented.
bool Set(const iterator &iter, b_plus_tree_value_index_t value) {
std::unique_lock<std::shared_mutex> guard(latch);
// TODO
}
// Presumably removes the entry at `iter`; takes the exclusive latch. Not implemented yet.
// NOTE(review): the body has no return statement, so calling this non-void function is undefined
// behaviour until it is implemented.
bool Erase(const iterator &iter) {
std::unique_lock<std::shared_mutex> guard(latch);
// TODO
}
b_plus_tree_value_index_t Get(const KeyType &key) {
b_plus_tree_value_index_t Get(const KeyType &key) { // Finish Design
auto it = lower_bound_const(key);
if (it == end()) return kInvalidValueIndex;
if (it == end_const()) return kInvalidValueIndex;
if (key_cmp(key, it.GetKey())) return kInvalidValueIndex;
return it->second;
return it.GetValue();
}
bool Put(const KeyType &key, b_plus_tree_value_index_t value) {
auto it = lower_bound(key);
if (it != end() && !key_cmp(key, it.GetKey())) {
Set(it, value);
/**
 * @brief Store `value` under `key`, overwriting the existing value when the key is present.
 *
 * @param key key to insert or update.
 * @param value value index to associate with the key.
 * @return false when an existing entry was overwritten, true when InsertEntryAt was called for a
 *         new entry.
 */
// NOTE(review): unlike lower_bound, this path does not take `latch`; confirm whether callers or
// InsertEntryAt are expected to provide the exclusion.
bool Put(const KeyType &key, b_plus_tree_value_index_t value) {  // Finish Design
  // Initialize directly from the returned prvalue; wrapping it in std::move would block elision.
  PositionSignType pos = FindPosition(key);
  if (!pos.is_end) {
    auto &hit = pos.path.back();
    // lower_bound guarantees the found key is not ordered before `key`, so "key is not ordered
    // before found" means the two are equivalent.
    if (!key_cmp(key, hit.first.template As<PageType>()->data.p_data[hit.second].first)) {
      hit.first.template AsMut<PageType>()->data.p_data[hit.second].second = value;
      return false;
    }
  }
  InsertEntryAt(pos, key, value);
  return true;
}
bool Remove(const KeyType &key) {
auto it = lower_bound(key);
if (it == end()) return false;
if (key_cmp(key, it.GetKey())) return false;
Erase(it);
/**
 * @brief Remove the entry stored under `key`, if any.
 *
 * @param key key to remove.
 * @return false when the key is not present (search hit the end, or the found key is ordered
 *         after `key`); true after RemoveEntryAt was called.
 */
// NOTE(review): unlike lower_bound, this path does not take `latch`; confirm whether callers or
// RemoveEntryAt are expected to provide the exclusion.
bool Remove(const KeyType &key) {  // Finish Design
  // Initialize directly from the returned prvalue; wrapping it in std::move would block elision.
  PositionSignType pos = FindPosition(key);
  if (pos.is_end) return false;
  auto &hit = pos.path.back();
  if (key_cmp(key, hit.first.template As<PageType>()->data.p_data[hit.second].first)) return false;
  RemoveEntryAt(pos);
  return true;
}
size_t Size() { return siz; }
void Flush() {
// TODO: do some recording
// Entry count as held in the in-memory `siz` field.
size_t Size() {  // Finish Design
  return static_cast<size_t>(siz);
}
// Write root_page_id and siz back into the raw data area (same layout the constructor reads:
// root_page_id at offset 0, siz right after it), then flush every page of the buffer pool.
void Flush() {  // Finish Design
  char *const header = raw_data_memory;
  memcpy(header, &root_page_id, sizeof(page_id_t));
  memcpy(header + sizeof(page_id_t), &siz, sizeof(bpt_size_t));
  bpm->FlushAllPages();
}
private:
page_id_t root_page_id; // stored in the first 4 (0-3) bytes of RawDatMemory, this directly operates on the buf
// maintained by DiskManager, BufferPoolManager only passes the pointer to it
uint64_t siz; // stored in the next 8 (4-11) bytes of RawDatMemory, this directly operates on the buf
bpt_size_t siz; // stored in the next 8 (4-11) bytes of RawDatMemory, this directly operates on the buf
// maintained by DiskManager, BufferPoolManager only passes the pointer to it
static KeyComparator key_cmp;
std::shared_mutex latch;
BufferPoolManager *bpm;
char *raw_data_memory;
};
template <typename KeyType, typename KeyComparator>
KeyComparator BPlusTreeIndexer<KeyType, KeyComparator>::key_cmp = KeyComparator();

View File

@ -6,12 +6,12 @@ template <typename KeyType, size_t kPageSize = 4096>
struct ActualDataType {
typedef std::pair<KeyType, default_numeric_index_t> value_type;
page_id_t p_n;
page_id_t p_parent;
uint8_t is_leaf;
uint16_t key_count;
page_status_t page_status; // root(2) / internal(1) / leaf(0)
in_page_key_count_t key_count;
const static size_t kMaxKeyCount =
(kPageSize - sizeof(page_id_t) * 2 - sizeof(uint8_t) - sizeof(uint16_t)) / sizeof(value_type);
(kPageSize - sizeof(page_id_t) - sizeof(page_status_t) - sizeof(in_page_key_count_t)) / sizeof(value_type);
value_type p_data[kMaxKeyCount];
static_assert(kMaxKeyCount >= 2, "kMaxKeyCount must be greater than or equal to 2");
};
template <typename KeyType, size_t kPageSize = 4096>
union BPlusTreePage {

View File

@ -8,4 +8,7 @@ typedef default_numeric_index_t page_id_t;
typedef default_numeric_index_t frame_id_t;
typedef default_numeric_index_t b_plus_tree_value_index_t;
extern const b_plus_tree_value_index_t kInvalidValueIndex;
typedef uint8_t page_status_t;
typedef uint16_t in_page_key_count_t;
typedef uint64_t bpt_size_t;
#endif

View File

@ -17,6 +17,12 @@
- 一个火车票系统执行引擎
- 一个直接的命令行交互系统用于OJ测试/单会话模式/快照管理/数据错误检查
- 一个Socket服务端用于对接服务端
## B+树
基本参考:<https://en.wikipedia.org/wiki/B%2B_tree>
- p[i] 子树中的所有 key K 都满足 k[i-1] \< K \<= k[i],且 k[i] 一定能取到,即可以直接无缝对接 lower_bound。
- 对外接口提供类似于迭代器的东西,但该迭代器只支持向后单向移动、读取value值、修改value值,并且迭代器会保留PageGuard,因此如果B+树在迭代器之前析构,会出现访问越界。
# UI设计
- 语言Python
- 与内核的交互Socket

View File

@ -1,7 +1,9 @@
#include <gtest/gtest.h>
#include <map>
#include "bpt/bpt.hpp"
#include "bpt/buffer_pool_manager.h"
#include "bpt/config.h"
#include "bpt/disk_manager.h"
namespace bpt_basic_test {
template <size_t length>
class FixLengthString {
@ -9,7 +11,7 @@ class FixLengthString {
char data[length];
};
} // namespace bpt_basic_test
TEST(BasicTest, Compile) {
TEST(BasicTest, Compile) { // This Test only test the compile of the code
// test for long long, int, char, long double
BPlusTreePage<long long> page_long_long;
static_assert(sizeof(page_long_long) == 4096, "BPlusTreePage size mismatch");
@ -37,4 +39,17 @@ TEST(BasicTest, Compile) {
static_assert(sizeof(page_35) == 4096, "BPlusTreePage size mismatch");
BPlusTreePage<bpt_basic_test::FixLengthString<40>> page_40;
static_assert(sizeof(page_40) == 4096, "BPlusTreePage size mismatch");
remove("/tmp/bpt1.db");
DiskManager *dm = new DiskManager("/tmp/bpt1.db");
BufferPoolManager *bpm = new BufferPoolManager(10, 3, dm);
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
auto it = bpt.lower_bound(1);
bpt.Flush();
bpt.Get(1);
it.SetValue(2);
bpt.Put(1, 2);
bpt.Remove(1);
delete bpm;
delete dm;
}

View File

@ -177,7 +177,7 @@ TEST(StoreTest, Test1) {
PageType c;
c.data.p_n = 0x1f2f3f4f;
c.data.key_count = 0x1f2a;
c.data.is_leaf = 0x3e;
c.data.page_status = 0x3e;
c.data.p_data[17].first = 0x8f7f6f5f4f3f2f1f;
c.filler[0] = 0x1f;
*basic_guard.AsMut<PageType>() = c;
@ -263,13 +263,13 @@ TEST(MemoryRiver, T2) {
size_t interal_id_tot = 0;
const unsigned int RndSeed = testing::GTEST_FLAG(random_seed);
std::mt19937 rnd(RndSeed);
remove("/tmp/T2.std");
remove("/tmp/T2.dat");
remove("T2.std");
remove("T2.dat");
const int kInfoLength = 100;
{
sol::MemoryRiver<DataType, kInfoLength> STD("/tmp/T2.std");
MemoryRiver<DataType, kInfoLength> mr("/tmp/T2.dat");
int total_opts = 1000;
int total_opts = 1000000;
while (total_opts-- > 0) {
int opt = rnd() % 6;
switch (opt) {