finish simple case

This commit is contained in:
2024-04-27 13:46:33 +00:00
parent 3d7b616dc7
commit 87e19fb96c
4 changed files with 174 additions and 18 deletions

View File

@ -1,6 +1,7 @@
#ifndef BPT_HPP #ifndef BPT_HPP
#define BPT_HPP #define BPT_HPP
#include <algorithm> #include <algorithm>
#include <cstring>
#include <shared_mutex> #include <shared_mutex>
#include <vector> #include <vector>
#include "bpt/bpt_page.hpp" #include "bpt/bpt_page.hpp"
@ -9,11 +10,13 @@
/** /**
* @brief B+ Tree Indexer * @brief B+ Tree Indexer
* @warning The KeyType must can be stored byte by byte. As this is only the indexer, the type of value is always * @warning The KeyType must can be stored byte by byte. As this is only the indexer, the type of value is always
* b_plus_tree_value_index_t. * b_plus_tree_value_index_t. And also, this is only the indexer, the value is not stored in the indexer, the value is
* stored in the value file, and the BPlusTreeIndexer should not be used directly.
*/ */
template <typename KeyType, typename KeyComparator> template <typename KeyType, typename KeyComparator>
class BPlusTreeIndexer { class BPlusTreeIndexer {
typedef BPlusTreePage<KeyType> PageType; typedef BPlusTreePage<KeyType> PageType;
typedef ActualDataType<KeyType> _ActualDataType;
typedef std::pair<KeyType, default_numeric_index_t> key_index_pair_t; typedef std::pair<KeyType, default_numeric_index_t> key_index_pair_t;
typedef std::pair<KeyType, b_plus_tree_value_index_t> value_type; typedef std::pair<KeyType, b_plus_tree_value_index_t> value_type;
@ -24,6 +27,7 @@ class BPlusTreeIndexer {
}; };
PositionSignType FindPosition(const KeyType &key) { // Finish Design PositionSignType FindPosition(const KeyType &key) { // Finish Design
if (root_page_id == 0) { if (root_page_id == 0) {
// special case for the empty tree
return PositionSignType{.is_end = true}; return PositionSignType{.is_end = true};
} }
BasicPageGuard current_page_guard(bpm->FetchPageBasic(root_page_id)); BasicPageGuard current_page_guard(bpm->FetchPageBasic(root_page_id));
@ -36,7 +40,8 @@ class BPlusTreeIndexer {
key, comparer_for_key_index_pair) - key, comparer_for_key_index_pair) -
current_page_guard.As<PageType>()->data.p_data; current_page_guard.As<PageType>()->data.p_data;
PositionSignType res; PositionSignType res;
while (res.path.back().first.template As<PageType>()->data.page_status != 0) { res.path.push_back(std::make_pair(std::move(current_page_guard), nxt));
while ((res.path.back().first.template As<PageType>()->data.page_status & PageStatusType::LEAF) == 0) {
default_numeric_index_t nxt_page_id; default_numeric_index_t nxt_page_id;
in_page_key_count_t internal_id = res.path.back().second; in_page_key_count_t internal_id = res.path.back().second;
if (internal_id < res.path.back().first.template As<PageType>()->data.key_count) if (internal_id < res.path.back().first.template As<PageType>()->data.key_count)
@ -55,9 +60,50 @@ class BPlusTreeIndexer {
return res; return res;
} }
void InsertEntryAt(PositionSignType &pos, const KeyType &key, b_plus_tree_value_index_t value) { void InsertEntryAt(PositionSignType &pos, const KeyType &key, b_plus_tree_value_index_t value) {
if (siz == 0) {
// special case for the first entry
BasicPageGuard new_page_guard = bpm->NewPageGuarded(&root_page_id);
new_page_guard.AsMut<PageType>()->data.page_status = PageStatusType::ROOT | PageStatusType::LEAF;
new_page_guard.AsMut<PageType>()->data.key_count = 1;
new_page_guard.AsMut<PageType>()->data.p_data[0] = std::make_pair(key, value);
new_page_guard.AsMut<PageType>()->data.p_n = 0;
++siz;
return;
}
auto &page_guard = pos.path.back().first;
if (page_guard.template As<PageType>()->data.key_count < _ActualDataType::kMaxKeyCount) {
// case 1: the page has enough space
memmove(page_guard.template AsMut<PageType>()->data.p_data + pos.path.back().second + 1,
page_guard.template As<PageType>()->data.p_data + pos.path.back().second,
(page_guard.template As<PageType>()->data.key_count - pos.path.back().second) * sizeof(value_type));
page_guard.template AsMut<PageType>()->data.p_data[pos.path.back().second] = std::make_pair(key, value);
page_guard.template AsMut<PageType>()->data.key_count++;
++siz;
return;
}
throw std::runtime_error("Not implemented yet: InsertEntryAt");
// TODO // TODO
} }
void RemoveEntryAt(PositionSignType &pos) { void RemoveEntryAt(PositionSignType &pos) {
if (siz == 1) {
// special case for the last entry
bpm->DeletePage(root_page_id);
root_page_id = 0;
--siz;
return;
}
auto &page_guard = pos.path.back().first;
if (page_guard.template As<PageType>()->data.key_count > _ActualDataType::kMinNumberOfKeysForLeaf ||
(page_guard.template As<PageType>()->data.page_status & PageStatusType::ROOT) != 0) {
// case 1: the page has enough keys
memmove(page_guard.template AsMut<PageType>()->data.p_data + pos.path.back().second,
page_guard.template As<PageType>()->data.p_data + pos.path.back().second + 1,
(page_guard.template As<PageType>()->data.key_count - pos.path.back().second - 1) * sizeof(value_type));
page_guard.template AsMut<PageType>()->data.key_count--;
--siz;
return;
}
throw std::runtime_error("Not implemented yet: RemoveEntryAt");
// TODO // TODO
} }
@ -71,13 +117,20 @@ class BPlusTreeIndexer {
friend class BPlusTreeIndexer; friend class BPlusTreeIndexer;
public: public:
const KeyType &GetKey() const { return guard.As<PageType>()->data.p_data[internal_offset].first; } const KeyType &GetKey() const {
const b_plus_tree_value_index_t &GetValue() { return guard.As<PageType>()->data.p_data[internal_offset].second; } std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].first;
}
const b_plus_tree_value_index_t &GetValue() {
std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].second;
}
bool operator==(iterator &that) { bool operator==(iterator &that) {
return domain == that.domain && guard.PageId() == that.guard.PageId() && return domain == that.domain && is_end == that.is_end &&
internal_offset == that.internal_offset && is_end == that.is_end; (is_end || (guard.PageId() == that.guard.PageId() && internal_offset == that.internal_offset));
} }
void SetValue(b_plus_tree_value_index_t new_value) { void SetValue(b_plus_tree_value_index_t new_value) {
std::unique_lock<std::shared_mutex> lock_guard(domain->latch);
guard.AsMut<PageType>()->data.p_data[internal_offset].second = new_value; guard.AsMut<PageType>()->data.p_data[internal_offset].second = new_value;
} }
// only support ++it // only support ++it
@ -104,11 +157,17 @@ class BPlusTreeIndexer {
friend class BPlusTreeIndexer; friend class BPlusTreeIndexer;
public: public:
const KeyType &GetKey() { return guard.As<PageType>()->data.p_data[internal_offset].first; } const KeyType &GetKey() {
const b_plus_tree_value_index_t &GetValue() { return guard.As<PageType>()->data.p_data[internal_offset].second; } std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].first;
}
const b_plus_tree_value_index_t &GetValue() {
std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].second;
}
bool operator==(const_iterator &that) { bool operator==(const_iterator &that) {
return domain == that.domain && guard.PageId() == that.guard.PageId() && return domain == that.domain && is_end == that.is_end &&
internal_offset == that.internal_offset && is_end == that.is_end; (is_end || (guard.PageId() == that.guard.PageId() && internal_offset == that.internal_offset));
} }
// only support ++it // only support ++it
const_iterator &operator++() { const_iterator &operator++() {
@ -137,6 +196,7 @@ class BPlusTreeIndexer {
memcpy(&root_page_id, raw_data_memory, sizeof(page_id_t)); memcpy(&root_page_id, raw_data_memory, sizeof(page_id_t));
memcpy(&siz, raw_data_memory + sizeof(page_id_t), sizeof(bpt_size_t)); memcpy(&siz, raw_data_memory + sizeof(page_id_t), sizeof(bpt_size_t));
} }
~BPlusTreeIndexer() { Flush(); }
iterator end() { // Finish Design iterator end() { // Finish Design
iterator res; iterator res;
res.domain = this; res.domain = this;
@ -154,8 +214,9 @@ class BPlusTreeIndexer {
PositionSignType pos(std::move(FindPosition(key))); PositionSignType pos(std::move(FindPosition(key)));
iterator res; iterator res;
res.domain = this; res.domain = this;
res.guard = bpm->FetchPageWrite(pos.path.back().first.PageId());
res.is_end = pos.is_end; res.is_end = pos.is_end;
if (res.is_end) return res;
res.guard = bpm->FetchPageWrite(pos.path.back().first.PageId());
res.internal_offset = pos.path.back().second; res.internal_offset = pos.path.back().second;
return res; return res;
} }
@ -164,18 +225,21 @@ class BPlusTreeIndexer {
PositionSignType pos(std::move(FindPosition(key))); PositionSignType pos(std::move(FindPosition(key)));
const_iterator res; const_iterator res;
res.domain = this; res.domain = this;
res.guard = bpm->FetchPageRead(pos.path.back().first.PageId());
res.is_end = pos.is_end; res.is_end = pos.is_end;
if (res.is_end) return res;
res.guard = bpm->FetchPageRead(pos.path.back().first.PageId());
res.internal_offset = pos.path.back().second; res.internal_offset = pos.path.back().second;
return res; return res;
} }
b_plus_tree_value_index_t Get(const KeyType &key) { // Finish Design b_plus_tree_value_index_t Get(const KeyType &key) { // Finish Design
std::shared_lock<std::shared_mutex> guard(latch);
auto it = lower_bound_const(key); auto it = lower_bound_const(key);
if (it == end_const()) return kInvalidValueIndex; if (it == end_const()) return kInvalidValueIndex;
if (key_cmp(key, it.GetKey())) return kInvalidValueIndex; if (key_cmp(key, it.GetKey())) return kInvalidValueIndex;
return it.GetValue(); return it.GetValue();
} }
bool Put(const KeyType &key, b_plus_tree_value_index_t value) { // Finish Design bool Put(const KeyType &key, b_plus_tree_value_index_t value) { // Finish Design
std::unique_lock<std::shared_mutex> guard(latch);
PositionSignType pos(std::move(FindPosition(key))); PositionSignType pos(std::move(FindPosition(key)));
if (!pos.is_end && if (!pos.is_end &&
!key_cmp(key, pos.path.back().first.template As<PageType>()->data.p_data[pos.path.back().second].first)) { !key_cmp(key, pos.path.back().first.template As<PageType>()->data.p_data[pos.path.back().second].first)) {
@ -186,6 +250,7 @@ class BPlusTreeIndexer {
return true; return true;
} }
bool Remove(const KeyType &key) { // Finish Design bool Remove(const KeyType &key) { // Finish Design
std::unique_lock<std::shared_mutex> guard(latch);
PositionSignType pos(std::move(FindPosition(key))); PositionSignType pos(std::move(FindPosition(key)));
if (pos.is_end) return false; if (pos.is_end) return false;
if (key_cmp(key, pos.path.back().first.template As<PageType>()->data.p_data[pos.path.back().second].first)) if (key_cmp(key, pos.path.back().first.template As<PageType>()->data.p_data[pos.path.back().second].first))
@ -195,6 +260,7 @@ class BPlusTreeIndexer {
} }
size_t Size() { return siz; } // Finish Design size_t Size() { return siz; } // Finish Design
void Flush() { // Finish Design void Flush() { // Finish Design
std::unique_lock<std::shared_mutex> guard(latch);
memcpy(raw_data_memory, &root_page_id, sizeof(page_id_t)); memcpy(raw_data_memory, &root_page_id, sizeof(page_id_t));
memcpy(raw_data_memory + sizeof(page_id_t), &siz, sizeof(bpt_size_t)); memcpy(raw_data_memory + sizeof(page_id_t), &siz, sizeof(bpt_size_t));
bpm->FlushAllPages(); bpm->FlushAllPages();

View File

@ -6,10 +6,12 @@ template <typename KeyType, size_t kPageSize = 4096>
struct ActualDataType { struct ActualDataType {
typedef std::pair<KeyType, default_numeric_index_t> value_type; typedef std::pair<KeyType, default_numeric_index_t> value_type;
page_id_t p_n; page_id_t p_n;
page_status_t page_status; // root(2) / internal(1) / leaf(0) page_status_t page_status; // root(4) / internal(2) / leaf(1)
in_page_key_count_t key_count; in_page_key_count_t key_count;
const static size_t kMaxKeyCount = const static size_t kMaxKeyCount =
(kPageSize - sizeof(page_id_t) - sizeof(page_status_t) - sizeof(in_page_key_count_t)) / sizeof(value_type); (kPageSize - sizeof(page_id_t) - sizeof(page_status_t) - sizeof(in_page_key_count_t)) / sizeof(value_type);
const static size_t kMinNumberOfKeysForInternal = (kMaxKeyCount) / 2;
const static size_t kMinNumberOfKeysForLeaf = (kMaxKeyCount + 1) / 2;
value_type p_data[kMaxKeyCount]; value_type p_data[kMaxKeyCount];
static_assert(kMaxKeyCount >= 2, "kMaxKeyCount must be greater than or equal to 2"); static_assert(kMaxKeyCount >= 2, "kMaxKeyCount must be greater than or equal to 2");
}; };

View File

@ -11,4 +11,9 @@ extern const b_plus_tree_value_index_t kInvalidValueIndex;
typedef uint8_t page_status_t; typedef uint8_t page_status_t;
typedef uint16_t in_page_key_count_t; typedef uint16_t in_page_key_count_t;
typedef uint64_t bpt_size_t; typedef uint64_t bpt_size_t;
enum PageStatusType {
LEAF = 1,
INTERNAL = 2,
ROOT = 4,
};
#endif #endif

View File

@ -12,6 +12,7 @@ class FixLengthString {
}; };
} // namespace bpt_basic_test } // namespace bpt_basic_test
TEST(BasicTest, Compile) { // This Test only test the compile of the code TEST(BasicTest, Compile) { // This Test only test the compile of the code
return;
// test for long long, int, char, long double // test for long long, int, char, long double
BPlusTreePage<long long> page_long_long; BPlusTreePage<long long> page_long_long;
static_assert(sizeof(page_long_long) == 4096, "BPlusTreePage size mismatch"); static_assert(sizeof(page_long_long) == 4096, "BPlusTreePage size mismatch");
@ -53,3 +54,85 @@ TEST(BasicTest, Compile) { // This Test only test the compile of the code
delete bpm; delete bpm;
delete dm; delete dm;
} }
TEST(BasicTest, Put_and_Get) {
remove("/tmp/bpt2.db");
DiskManager *dm = new DiskManager("/tmp/bpt2.db");
BufferPoolManager *bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
bpt.Put(1, 2);
ASSERT_EQ(bpt.Get(1), 2);
bpt.Put(2, 5);
ASSERT_EQ(bpt.Get(2), 5);
bpt.Put(3, 7);
ASSERT_EQ(bpt.Get(3), 7);
bpt.Put(4, 11);
ASSERT_EQ(bpt.Get(4), 11);
bpt.Put(2, 15);
ASSERT_EQ(bpt.Get(2), 15);
ASSERT_EQ(bpt.Get(3), 7);
ASSERT_EQ(bpt.Get(1), 2);
ASSERT_EQ(bpt.Get(4), 11);
}
delete bpm;
delete dm;
dm = new DiskManager("/tmp/bpt2.db");
bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
ASSERT_EQ(bpt.Get(2), 15);
ASSERT_EQ(bpt.Get(3), 7);
ASSERT_EQ(bpt.Get(1), 2);
ASSERT_EQ(bpt.Get(4), 11);
}
delete bpm;
delete dm;
}
TEST(BasicTest, Put_Get_Remove) {
remove("/tmp/bpt3.db");
DiskManager *dm = new DiskManager("/tmp/bpt3.db");
BufferPoolManager *bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
bpt.Put(1, 2);
ASSERT_EQ(bpt.Get(1), 2);
bpt.Put(2, 5);
ASSERT_EQ(bpt.Get(2), 5);
bpt.Put(3, 7);
ASSERT_EQ(bpt.Get(3), 7);
bpt.Put(4, 11);
ASSERT_EQ(bpt.Get(4), 11);
bpt.Put(2, 15);
ASSERT_EQ(bpt.Get(2), 15);
ASSERT_EQ(bpt.Get(3), 7);
ASSERT_EQ(bpt.Get(1), 2);
bpt.Put(9, 11);
ASSERT_EQ(bpt.Get(4), 11);
bpt.Remove(2);
bpt.Remove(2);
ASSERT_EQ(bpt.Get(2), kInvalidValueIndex);
bpt.Remove(3);
ASSERT_EQ(bpt.Get(3), kInvalidValueIndex);
bpt.Remove(1);
ASSERT_EQ(bpt.Get(1), kInvalidValueIndex);
bpt.Remove(4);
bpt.Remove(4);
ASSERT_EQ(bpt.Get(4), kInvalidValueIndex);
}
delete bpm;
delete dm;
dm = new DiskManager("/tmp/bpt3.db");
bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
ASSERT_EQ(bpt.Get(2), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(3), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(1), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(4), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(9), 11);
}
delete bpm;
delete dm;
}