finish simple case

This commit is contained in:
2024-04-27 13:46:33 +00:00
parent 3d7b616dc7
commit 87e19fb96c
4 changed files with 174 additions and 18 deletions

View File

@ -1,6 +1,7 @@
#ifndef BPT_HPP
#define BPT_HPP
#include <algorithm>
#include <cstring>
#include <shared_mutex>
#include <vector>
#include "bpt/bpt_page.hpp"
@ -9,11 +10,13 @@
/**
* @brief B+ Tree Indexer
* @warning The KeyType must can be stored byte by byte. As this is only the indexer, the type of value is always
* b_plus_tree_value_index_t.
* b_plus_tree_value_index_t. And also, this is only the indexer, the value is not stored in the indexer, the value is
* stored in the value file, and the BPlusTreeIndexer should not be used directly.
*/
template <typename KeyType, typename KeyComparator>
class BPlusTreeIndexer {
typedef BPlusTreePage<KeyType> PageType;
typedef ActualDataType<KeyType> _ActualDataType;
typedef std::pair<KeyType, default_numeric_index_t> key_index_pair_t;
typedef std::pair<KeyType, b_plus_tree_value_index_t> value_type;
@ -24,6 +27,7 @@ class BPlusTreeIndexer {
};
PositionSignType FindPosition(const KeyType &key) { // Finish Design
if (root_page_id == 0) {
// special case for the empty tree
return PositionSignType{.is_end = true};
}
BasicPageGuard current_page_guard(bpm->FetchPageBasic(root_page_id));
@ -36,7 +40,8 @@ class BPlusTreeIndexer {
key, comparer_for_key_index_pair) -
current_page_guard.As<PageType>()->data.p_data;
PositionSignType res;
while (res.path.back().first.template As<PageType>()->data.page_status != 0) {
res.path.push_back(std::make_pair(std::move(current_page_guard), nxt));
while ((res.path.back().first.template As<PageType>()->data.page_status & PageStatusType::LEAF) == 0) {
default_numeric_index_t nxt_page_id;
in_page_key_count_t internal_id = res.path.back().second;
if (internal_id < res.path.back().first.template As<PageType>()->data.key_count)
@ -55,9 +60,50 @@ class BPlusTreeIndexer {
return res;
}
void InsertEntryAt(PositionSignType &pos, const KeyType &key, b_plus_tree_value_index_t value) {
if (siz == 0) {
// special case for the first entry
BasicPageGuard new_page_guard = bpm->NewPageGuarded(&root_page_id);
new_page_guard.AsMut<PageType>()->data.page_status = PageStatusType::ROOT | PageStatusType::LEAF;
new_page_guard.AsMut<PageType>()->data.key_count = 1;
new_page_guard.AsMut<PageType>()->data.p_data[0] = std::make_pair(key, value);
new_page_guard.AsMut<PageType>()->data.p_n = 0;
++siz;
return;
}
auto &page_guard = pos.path.back().first;
if (page_guard.template As<PageType>()->data.key_count < _ActualDataType::kMaxKeyCount) {
// case 1: the page has enough space
memmove(page_guard.template AsMut<PageType>()->data.p_data + pos.path.back().second + 1,
page_guard.template As<PageType>()->data.p_data + pos.path.back().second,
(page_guard.template As<PageType>()->data.key_count - pos.path.back().second) * sizeof(value_type));
page_guard.template AsMut<PageType>()->data.p_data[pos.path.back().second] = std::make_pair(key, value);
page_guard.template AsMut<PageType>()->data.key_count++;
++siz;
return;
}
throw std::runtime_error("Not implemented yet: InsertEntryAt");
// TODO
}
void RemoveEntryAt(PositionSignType &pos) {
if (siz == 1) {
// special case for the last entry
bpm->DeletePage(root_page_id);
root_page_id = 0;
--siz;
return;
}
auto &page_guard = pos.path.back().first;
if (page_guard.template As<PageType>()->data.key_count > _ActualDataType::kMinNumberOfKeysForLeaf ||
(page_guard.template As<PageType>()->data.page_status & PageStatusType::ROOT) != 0) {
// case 1: the page has enough keys
memmove(page_guard.template AsMut<PageType>()->data.p_data + pos.path.back().second,
page_guard.template As<PageType>()->data.p_data + pos.path.back().second + 1,
(page_guard.template As<PageType>()->data.key_count - pos.path.back().second - 1) * sizeof(value_type));
page_guard.template AsMut<PageType>()->data.key_count--;
--siz;
return;
}
throw std::runtime_error("Not implemented yet: RemoveEntryAt");
// TODO
}
@ -71,13 +117,20 @@ class BPlusTreeIndexer {
friend class BPlusTreeIndexer;
public:
const KeyType &GetKey() const { return guard.As<PageType>()->data.p_data[internal_offset].first; }
const b_plus_tree_value_index_t &GetValue() { return guard.As<PageType>()->data.p_data[internal_offset].second; }
const KeyType &GetKey() const {
std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].first;
}
const b_plus_tree_value_index_t &GetValue() {
std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].second;
}
bool operator==(iterator &that) {
return domain == that.domain && guard.PageId() == that.guard.PageId() &&
internal_offset == that.internal_offset && is_end == that.is_end;
return domain == that.domain && is_end == that.is_end &&
(is_end || (guard.PageId() == that.guard.PageId() && internal_offset == that.internal_offset));
}
void SetValue(b_plus_tree_value_index_t new_value) {
std::unique_lock<std::shared_mutex> lock_guard(domain->latch);
guard.AsMut<PageType>()->data.p_data[internal_offset].second = new_value;
}
// only support ++it
@ -104,11 +157,17 @@ class BPlusTreeIndexer {
friend class BPlusTreeIndexer;
public:
const KeyType &GetKey() { return guard.As<PageType>()->data.p_data[internal_offset].first; }
const b_plus_tree_value_index_t &GetValue() { return guard.As<PageType>()->data.p_data[internal_offset].second; }
const KeyType &GetKey() {
std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].first;
}
const b_plus_tree_value_index_t &GetValue() {
std::shared_lock<std::shared_mutex> lock_guard(domain->latch);
return guard.As<PageType>()->data.p_data[internal_offset].second;
}
bool operator==(const_iterator &that) {
return domain == that.domain && guard.PageId() == that.guard.PageId() &&
internal_offset == that.internal_offset && is_end == that.is_end;
return domain == that.domain && is_end == that.is_end &&
(is_end || (guard.PageId() == that.guard.PageId() && internal_offset == that.internal_offset));
}
// only support ++it
const_iterator &operator++() {
@ -137,13 +196,14 @@ class BPlusTreeIndexer {
memcpy(&root_page_id, raw_data_memory, sizeof(page_id_t));
memcpy(&siz, raw_data_memory + sizeof(page_id_t), sizeof(bpt_size_t));
}
iterator end() { // Finish Design
~BPlusTreeIndexer() { Flush(); }
iterator end() { // Finish Design
iterator res;
res.domain = this;
res.is_end = true;
return res;
}
const_iterator end_const() { // Finish Design
const_iterator end_const() { // Finish Design
const_iterator res;
res.domain = this;
res.is_end = true;
@ -154,8 +214,9 @@ class BPlusTreeIndexer {
PositionSignType pos(std::move(FindPosition(key)));
iterator res;
res.domain = this;
res.guard = bpm->FetchPageWrite(pos.path.back().first.PageId());
res.is_end = pos.is_end;
if (res.is_end) return res;
res.guard = bpm->FetchPageWrite(pos.path.back().first.PageId());
res.internal_offset = pos.path.back().second;
return res;
}
@ -164,18 +225,21 @@ class BPlusTreeIndexer {
PositionSignType pos(std::move(FindPosition(key)));
const_iterator res;
res.domain = this;
res.guard = bpm->FetchPageRead(pos.path.back().first.PageId());
res.is_end = pos.is_end;
if (res.is_end) return res;
res.guard = bpm->FetchPageRead(pos.path.back().first.PageId());
res.internal_offset = pos.path.back().second;
return res;
}
b_plus_tree_value_index_t Get(const KeyType &key) { // Finish Design
b_plus_tree_value_index_t Get(const KeyType &key) { // Finish Design
std::shared_lock<std::shared_mutex> guard(latch);
auto it = lower_bound_const(key);
if (it == end_const()) return kInvalidValueIndex;
if (key_cmp(key, it.GetKey())) return kInvalidValueIndex;
return it.GetValue();
}
bool Put(const KeyType &key, b_plus_tree_value_index_t value) { // Finish Design
bool Put(const KeyType &key, b_plus_tree_value_index_t value) { // Finish Design
std::unique_lock<std::shared_mutex> guard(latch);
PositionSignType pos(std::move(FindPosition(key)));
if (!pos.is_end &&
!key_cmp(key, pos.path.back().first.template As<PageType>()->data.p_data[pos.path.back().second].first)) {
@ -185,7 +249,8 @@ class BPlusTreeIndexer {
InsertEntryAt(pos, key, value);
return true;
}
bool Remove(const KeyType &key) { // Finish Design
bool Remove(const KeyType &key) { // Finish Design
std::unique_lock<std::shared_mutex> guard(latch);
PositionSignType pos(std::move(FindPosition(key)));
if (pos.is_end) return false;
if (key_cmp(key, pos.path.back().first.template As<PageType>()->data.p_data[pos.path.back().second].first))
@ -195,6 +260,7 @@ class BPlusTreeIndexer {
}
size_t Size() { return siz; } // Finish Design
void Flush() { // Finish Design
std::unique_lock<std::shared_mutex> guard(latch);
memcpy(raw_data_memory, &root_page_id, sizeof(page_id_t));
memcpy(raw_data_memory + sizeof(page_id_t), &siz, sizeof(bpt_size_t));
bpm->FlushAllPages();

View File

@ -6,10 +6,12 @@ template <typename KeyType, size_t kPageSize = 4096>
struct ActualDataType {
typedef std::pair<KeyType, default_numeric_index_t> value_type;
page_id_t p_n;
page_status_t page_status; // root(2) / internal(1) / leaf(0)
page_status_t page_status; // root(4) / internal(2) / leaf(1)
in_page_key_count_t key_count;
const static size_t kMaxKeyCount =
(kPageSize - sizeof(page_id_t) - sizeof(page_status_t) - sizeof(in_page_key_count_t)) / sizeof(value_type);
const static size_t kMinNumberOfKeysForInternal = (kMaxKeyCount) / 2;
const static size_t kMinNumberOfKeysForLeaf = (kMaxKeyCount + 1) / 2;
value_type p_data[kMaxKeyCount];
static_assert(kMaxKeyCount >= 2, "kMaxKeyCount must be greater than or equal to 2");
};

View File

@ -11,4 +11,9 @@ extern const b_plus_tree_value_index_t kInvalidValueIndex;
typedef uint8_t page_status_t;
typedef uint16_t in_page_key_count_t;
typedef uint64_t bpt_size_t;
enum PageStatusType {
LEAF = 1,
INTERNAL = 2,
ROOT = 4,
};
#endif

View File

@ -12,6 +12,7 @@ class FixLengthString {
};
} // namespace bpt_basic_test
TEST(BasicTest, Compile) { // This Test only test the compile of the code
return;
// test for long long, int, char, long double
BPlusTreePage<long long> page_long_long;
static_assert(sizeof(page_long_long) == 4096, "BPlusTreePage size mismatch");
@ -52,4 +53,86 @@ TEST(BasicTest, Compile) { // This Test only test the compile of the code
bpt.Remove(1);
delete bpm;
delete dm;
}
TEST(BasicTest, Put_and_Get) {
remove("/tmp/bpt2.db");
DiskManager *dm = new DiskManager("/tmp/bpt2.db");
BufferPoolManager *bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
bpt.Put(1, 2);
ASSERT_EQ(bpt.Get(1), 2);
bpt.Put(2, 5);
ASSERT_EQ(bpt.Get(2), 5);
bpt.Put(3, 7);
ASSERT_EQ(bpt.Get(3), 7);
bpt.Put(4, 11);
ASSERT_EQ(bpt.Get(4), 11);
bpt.Put(2, 15);
ASSERT_EQ(bpt.Get(2), 15);
ASSERT_EQ(bpt.Get(3), 7);
ASSERT_EQ(bpt.Get(1), 2);
ASSERT_EQ(bpt.Get(4), 11);
}
delete bpm;
delete dm;
dm = new DiskManager("/tmp/bpt2.db");
bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
ASSERT_EQ(bpt.Get(2), 15);
ASSERT_EQ(bpt.Get(3), 7);
ASSERT_EQ(bpt.Get(1), 2);
ASSERT_EQ(bpt.Get(4), 11);
}
delete bpm;
delete dm;
}
TEST(BasicTest, Put_Get_Remove) {
remove("/tmp/bpt3.db");
DiskManager *dm = new DiskManager("/tmp/bpt3.db");
BufferPoolManager *bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
bpt.Put(1, 2);
ASSERT_EQ(bpt.Get(1), 2);
bpt.Put(2, 5);
ASSERT_EQ(bpt.Get(2), 5);
bpt.Put(3, 7);
ASSERT_EQ(bpt.Get(3), 7);
bpt.Put(4, 11);
ASSERT_EQ(bpt.Get(4), 11);
bpt.Put(2, 15);
ASSERT_EQ(bpt.Get(2), 15);
ASSERT_EQ(bpt.Get(3), 7);
ASSERT_EQ(bpt.Get(1), 2);
bpt.Put(9, 11);
ASSERT_EQ(bpt.Get(4), 11);
bpt.Remove(2);
bpt.Remove(2);
ASSERT_EQ(bpt.Get(2), kInvalidValueIndex);
bpt.Remove(3);
ASSERT_EQ(bpt.Get(3), kInvalidValueIndex);
bpt.Remove(1);
ASSERT_EQ(bpt.Get(1), kInvalidValueIndex);
bpt.Remove(4);
bpt.Remove(4);
ASSERT_EQ(bpt.Get(4), kInvalidValueIndex);
}
delete bpm;
delete dm;
dm = new DiskManager("/tmp/bpt3.db");
bpm = new BufferPoolManager(20, 3, dm);
{
BPlusTreeIndexer<long long, std::less<long long>> bpt(bpm);
ASSERT_EQ(bpt.Get(2), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(3), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(1), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(4), kInvalidValueIndex);
ASSERT_EQ(bpt.Get(9), 11);
}
delete bpm;
delete dm;
}