From 7efa6908100e5b6c68c2a932733fae9c752eb291 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 15 Jun 2023 14:25:37 +0530 Subject: [PATCH] Add `NumericTrie::remove`. --- include/numeric_range_trie_test.h | 6 ++- src/index.cpp | 27 +++++++++++- src/numeric_range_trie.cpp | 55 ++++++++++++++++++++--- test/numeric_range_trie_test.cpp | 72 +++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 7 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index 3b6f8f68..8b7bd22c 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -41,9 +41,11 @@ class NumericTrie { void insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + void remove(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level); - void search_geopoints(const std::vector& cell_ids, const char& max_index_level, + void search_geopoints(const std::vector& cell_ids, const char& max_level, std::vector& geo_result_ids); void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level); @@ -121,6 +123,8 @@ public: void insert(const int64_t& value, const uint32_t& seq_id); + void remove(const int64_t& value, const uint32_t& seq_id); + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id); void search_geopoints(const std::vector& cell_ids, std::vector& geo_result_ids); diff --git a/src/index.cpp b/src/index.cpp index a4e6f8a4..fd52d2c4 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -5423,6 +5423,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const std::vector{document[field_name].get()} : document[field_name].get>(); for(int32_t value: values) { + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(value, seq_id); + } + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(value, seq_id); if(search_field.facet) { @@ -5434,6 +5439,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const std::vector{document[field_name].get()} : document[field_name].get>(); for(int64_t value: values) { + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(value, seq_id); + } + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(value, seq_id); if(search_field.facet) { @@ -5448,8 +5458,14 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const document[field_name].get>(); for(float value: values) { - num_tree_t* num_tree = numerical_index.at(field_name); int64_t fintval = float_to_int64_t(value); + + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(fintval, seq_id); + } + + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(fintval, seq_id); if(search_field.facet) { remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id); @@ -5641,6 +5657,10 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec } else { num_tree_t* num_tree = new num_tree_t; numerical_index.emplace(new_field.name, num_tree); + + if (new_field.range_index) { + range_index.emplace(new_field.name, new NumericTrie(new_field.is_int32() ? 32 : 64)); + } } } @@ -5697,6 +5717,11 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec } else { delete numerical_index[del_field.name]; numerical_index.erase(del_field.name); + + if (del_field.range_index) { + delete range_index[del_field.name]; + range_index.erase(del_field.name); + } } if(del_field.is_sortable()) { diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 7ac88590..f70de113 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -19,6 +19,18 @@ void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) { } } +void NumericTrie::remove(const int64_t& value, const uint32_t& seq_id) { + if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { + return; + } + + if (value < 0) { + negative_trie->remove(std::abs(value), seq_id, max_level); + } else { + positive_trie->remove(value, seq_id, max_level); + } +} + void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id) { if (positive_trie == nullptr) { positive_trie = new NumericTrie::Node(); @@ -420,6 +432,34 @@ inline int get_geopoint_index(const uint64_t& cell_id, const char& level) { return (cell_id >> (8 * (8 - level))) & 0xFF; } +void NumericTrie::Node::remove(const int64_t& value, const uint32_t& id, const char& max_level) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level < max_level) { + root->seq_ids.remove_value(id); + + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + root->seq_ids.remove_value(id); + if (root->children != nullptr && root->children[index] != nullptr) { + auto& child = root->children[index]; + + child->seq_ids.remove_value(id); + if (child->seq_ids.getLength() == 0) { + delete child; + child = nullptr; + } + } +} + void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) { if (level > max_level) { return; @@ -501,11 +541,11 @@ void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const c matches.insert(root); } -void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, const char& max_index_level, +void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, const char& max_level, std::vector& geo_result_ids) { std::set matches; for (const auto &cell_id: cell_ids) { - search_geopoints_helper(cell_id, max_index_level, matches); + search_geopoints_helper(cell_id, max_level, matches); } for (auto const& match: matches) { @@ -538,9 +578,14 @@ void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, co } root->seq_ids.remove_value(id); - if (root->children != nullptr || root->children[index] != nullptr) { - delete root->children[index]; - root->children[index] = nullptr; + if (root->children != nullptr && root->children[index] != nullptr) { + auto& child = root->children[index]; + + child->seq_ids.remove_value(id); + if (child->seq_ids.getLength() == 0) { + delete child; + child = nullptr; + } } } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index d2fc6e16..2412b5a5 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -601,6 +601,75 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { reset(ids, ids_length); } +TEST_F(NumericRangeTrieTest, Remove) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-0x202020, 32}, + {-32768, 5}, + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {0, 2}, + {0, 49}, + {1, 8}, + {256, 91}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91}, + {0x202020, 35}, + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_less_than(0, false, ids, ids_length); + + std::vector expected = {5, 8, 32, 35, 43}; + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->remove(-24576, 32); + trie->remove(-0x202020, 32); + + reset(ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); + + expected = {5, 8, 35, 43}; + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_equal_to(0, ids, ids_length); + + expected = {2, 49}; + ASSERT_EQ(2, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->remove(0, 2); + + reset(ids, ids_length); + trie->search_equal_to(0, ids, ids_length); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(49, ids[0]); + + reset(ids, ids_length); +} + TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); @@ -657,6 +726,9 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + + trie->remove(15, 0); + trie->remove(-15, 0); } TEST_F(NumericRangeTrieTest, Integration) {