diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index f5d6add2..a8422524 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -42,11 +42,19 @@ class NumericTrie { void search_range(const int64_t& low, const int64_t& high, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector& matches); + void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + void search_less_than(const int64_t& value, const char& max_level, std::vector& matches); + void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + void search_greater_than(const int64_t& value, const char& max_level, std::vector& matches); + void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + + void search_equal_to(const int64_t& value, const char& max_level, std::vector& matches); }; Node* negative_trie = nullptr; @@ -63,17 +71,63 @@ public: delete positive_trie; } + class iterator_t { + struct match_state { + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + uint32_t index = 0; + + explicit match_state(uint32_t*& ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {} + + ~match_state() { + delete [] ids; + } + }; + + std::vector matches; + + void set_seq_id(); + + public: + + explicit iterator_t(std::vector& matches); + + ~iterator_t() { + for (auto& match: matches) { + delete match; + } + } + + iterator_t& operator=(iterator_t&& obj) noexcept; + + uint32_t seq_id = 0; + bool is_valid = true; + + void next(); + void skip_to(uint32_t id); + void reset(); + }; + void insert(const int64_t& value, const uint32_t& seq_id); void search_range(const int64_t& low, const bool& low_inclusive, const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length); + iterator_t search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive); + void search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); + iterator_t search_less_than(const int64_t& value, const bool& inclusive); + void search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); + iterator_t search_greater_than(const int64_t& value, const bool& inclusive); + void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length); + + iterator_t search_equal_to(const int64_t& value); }; diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 3076f873..9d9f4aa0 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -98,6 +98,47 @@ void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, } } +NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive) { + std::vector matches; + if (low > high) { + return NumericTrie::iterator_t(matches); + } + + if (low < 0 && high >= 0) { + // Have to combine the results of >low from negative_trie and low from negative_trie. + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, matches); + } + + if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) + positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, matches); + } + } else if (low >= 0) { + // Search only in positive_trie + if (positive_trie == nullptr) { + return NumericTrie::iterator_t(matches); + } + + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, matches); + } else { + // Search only in negative_trie + if (negative_trie == nullptr) { + return NumericTrie::iterator_t(matches); + } + + auto abs_high = std::abs(high), abs_low = std::abs(low); + // Since we store absolute values, switching low and high would produce the correct result. + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, + max_level, matches); + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { @@ -167,6 +208,35 @@ void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusiv } } +NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) { + std::vector matches; + + if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) + if (positive_trie != nullptr) { + matches.push_back(positive_trie); + } + return NumericTrie::iterator_t(matches); + } + + if (value >= 0) { + if (positive_trie != nullptr) { + positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, matches); + } + } else { + // Have to combine the results of >value from negative_trie and all the ids in positive_trie + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. + negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level, matches); + } + if (positive_trie != nullptr) { + matches.push_back(positive_trie); + } + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] if (negative_trie != nullptr) { @@ -237,6 +307,35 @@ void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, } } +NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) { + std::vector matches; + + if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] + if (negative_trie != nullptr) { + matches.push_back(negative_trie); + } + return NumericTrie::iterator_t(matches); + } + + if (value < 0) { + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, matches); + } + } else { + // Have to combine the results of search_less_than(inclusive ? value : value - 1, max_level, matches); + } + if (negative_trie != nullptr) { + matches.push_back(negative_trie); + } + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) { if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { return; @@ -259,6 +358,17 @@ void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t ids = out; } +NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) { + std::vector matches; + if (value < 0 && negative_trie != nullptr) { + negative_trie->search_equal_to(std::abs(value), max_level, matches); + } else if (value >= 0 && positive_trie != nullptr) { + positive_trie->search_equal_to(value, max_level, matches); + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) { char level = 0; return insert_helper(value, seq_id, level, max_level); @@ -331,6 +441,11 @@ void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_l ids = out; } +void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 0; + search_less_than_helper(value, level, max_level, matches); +} + void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level, std::vector& matches) { if (level == max_level) { @@ -383,6 +498,15 @@ void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, co ids = out; } +void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level, + std::vector& matches) { + if (low > high) { + return; + } + + search_range_helper(low, high, max_level, matches); +} + void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, std::vector& matches) { // Segregating the nodes into matching low, in-between, and matching high. @@ -460,6 +584,11 @@ void NumericTrie::Node::search_greater_than(const int64_t& value, const char& ma ids = out; } +void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 0; + search_greater_than_helper(value, level, max_level, matches); +} + void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level, std::vector& matches) { if (level == max_level) { @@ -500,3 +629,100 @@ void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_le root->get_all_ids(ids, ids_length); } + +void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level <= max_level) { + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + matches.push_back(root); +} + +void NumericTrie::iterator_t::reset() { + for (auto& match: matches) { + match->index = 0; + } + + is_valid = true; + set_seq_id(); +} + +void NumericTrie::iterator_t::skip_to(uint32_t id) { + for (auto& match: matches) { + ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id); + } + + set_seq_id(); +} + +void NumericTrie::iterator_t::next() { + // Advance all the matches at seq_id. + for (auto& match: matches) { + if (match->index < match->ids_length && match->ids[match->index] == seq_id) { + match->index++; + } + } + + set_seq_id(); +} + +NumericTrie::iterator_t::iterator_t(std::vector& node_matches) { + for (auto const& node_match: node_matches) { + uint32_t* ids = nullptr; + uint32_t ids_length; + node_match->get_all_ids(ids, ids_length); + if (ids_length > 0) { + matches.emplace_back(new match_state(ids, ids_length)); + } + } + + set_seq_id(); +} + +void NumericTrie::iterator_t::set_seq_id() { + // Find the lowest id of all the matches and update the seq_id. + bool one_is_valid = false; + uint32_t lowest_id = UINT32_MAX; + + for (auto& match: matches) { + if (match->index < match->ids_length) { + one_is_valid = true; + + if (match->ids[match->index] < lowest_id) { + lowest_id = match->ids[match->index]; + } + } + } + + if (one_is_valid) { + seq_id = lowest_id; + } + + is_valid = one_is_valid; +} + +NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept { + if (&obj == this) + return *this; + + for (auto& match: matches) { + delete match; + } + matches.clear(); + + matches = std::move(obj.matches); + seq_id = obj.seq_id; + is_valid = obj.is_valid; + + return *this; +} + diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 29dff68a..d2fc6e16 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -436,6 +436,71 @@ TEST_F(NumericRangeTrieTest, SearchEqualTo) { ASSERT_EQ(0, ids_length); } +TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32769, 41}, + {-32768, 43}, + {-32767, 45}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {24576, 60}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + auto iterator = trie->search_equal_to(0); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(0x202020); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(-32768); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(43, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(24576); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(60, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(false, iterator.is_valid); + + + iterator.reset(); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.skip_to(4); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.skip_to(59); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(60, iterator.seq_id); + + iterator.skip_to(66); + ASSERT_EQ(false, iterator.is_valid); +} + TEST_F(NumericRangeTrieTest, MultivalueData) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie);