Add NumericTrie::iterator_t.

This commit is contained in:
Harpreet Sangar 2023-06-07 14:42:46 +05:30
parent aa753fffc1
commit 0321396f98
3 changed files with 345 additions and 0 deletions

View File

@ -42,11 +42,19 @@ class NumericTrie {
// Node-level searches. Each query comes in two overloads: one reporting through an
// `ids` / `ids_length` output pair, and one collecting the matching nodes into `matches`.
// `max_level` is the trie depth handed down to the traversal.
// NOTE(review): the bounds appear to be inclusive at this level (NumericTrie adjusts
// exclusive bounds by one before calling these) — confirm against the helpers.

// Range query between `low` and `high`.
void search_range(const int64_t& low, const int64_t& high, const char& max_level,
uint32_t*& ids, uint32_t& ids_length);
void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector<Node*>& matches);
// Query for values below `value`.
void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
void search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
// Query for values above `value`.
void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
void search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
// Exact-match query for `value`.
void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
void search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
};
Node* negative_trie = nullptr;
@ -63,17 +71,63 @@ public:
delete positive_trie;
}
/// Iterates, in ascending id order, over the ids held by a set of matched trie nodes.
///
/// The iterator owns one id array per matched node and exposes the smallest id that
/// has not been consumed yet through `seq_id`. `is_valid` turns false once every
/// array has been exhausted (or when there was nothing to iterate in the first place).
///
/// The type is move-only: it owns heap allocations, so copying is disabled.
class iterator_t {
    /// One owned id array together with the read cursor into it.
    struct match_state {
        uint32_t* ids = nullptr;
        uint32_t ids_length = 0;
        uint32_t index = 0;

        explicit match_state(uint32_t*& ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {}

        // `ids` is released with delete[] in the destructor, so a copy would free it twice.
        match_state(const match_state&) = delete;
        match_state& operator=(const match_state&) = delete;

        ~match_state() {
            delete [] ids;
        }
    };

    /// Owned states, one per matched node that produced at least one id.
    std::vector<match_state*> matches;

    /// Recomputes `seq_id` / `is_valid` from the current cursors.
    void set_seq_id();

public:
    /// Builds the iterator from the nodes collected by a search.
    explicit iterator_t(std::vector<Node*>& matches);

    ~iterator_t() {
        for (auto& match: matches) {
            delete match;
        }
    }

    // The iterator owns its match states; only moves are allowed.
    iterator_t(const iterator_t&) = delete;
    iterator_t& operator=(const iterator_t&) = delete;

    /// Takes over the states of `obj`, leaving it empty and invalid.
    iterator_t(iterator_t&& obj) noexcept : seq_id(obj.seq_id), is_valid(obj.is_valid) {
        matches.swap(obj.matches);  // `matches` starts empty, so `obj` ends up owning nothing.
        obj.is_valid = false;
    }

    iterator_t& operator=(iterator_t&& obj) noexcept;

    /// Smallest id not yet consumed; only meaningful while `is_valid` is true.
    uint32_t seq_id = 0;

    /// False once there is no further id to visit.
    bool is_valid = true;

    /// Moves past the current `seq_id`.
    void next();

    /// Moves every cursor towards `id` and refreshes `seq_id`.
    void skip_to(uint32_t id);

    /// Rewinds the iterator to its first id.
    void reset();
};
// Stores `value` in the trie and associates it with `seq_id`.
void insert(const int64_t& value, const uint32_t& seq_id);
// Range query between `low` and `high`; each bound is part of the range only when its
// `*_inclusive` flag is set. This overload reports through `ids` / `ids_length`.
// NOTE(review): ownership of `ids` presumably passes to the caller — confirm.
void search_range(const int64_t& low, const bool& low_inclusive,
const int64_t& high, const bool& high_inclusive,
uint32_t*& ids, uint32_t& ids_length);
// Same range query, returned as an iterator_t built from the matching nodes.
iterator_t search_range(const int64_t& low, const bool& low_inclusive,
const int64_t& high, const bool& high_inclusive);
// Query for values below `value` (or equal to it when `inclusive`), reported through
// `ids` / `ids_length`.
void search_less_than(const int64_t& value, const bool& inclusive,
uint32_t*& ids, uint32_t& ids_length);
// Same query, returned as an iterator_t.
iterator_t search_less_than(const int64_t& value, const bool& inclusive);
// Query for values above `value` (or equal to it when `inclusive`), reported through
// `ids` / `ids_length`.
void search_greater_than(const int64_t& value, const bool& inclusive,
uint32_t*& ids, uint32_t& ids_length);
// Same query, returned as an iterator_t.
iterator_t search_greater_than(const int64_t& value, const bool& inclusive);
// Exact-match query for `value`, reported through `ids` / `ids_length`.
void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length);
// Exact-match query for `value`, returned as an iterator_t.
iterator_t search_equal_to(const int64_t& value);
};

View File

@ -98,6 +98,47 @@ void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
}
}
/**
 * Iterator flavour of the range query.
 *
 * Collects the trie nodes covering the interval between `low` and `high` (each bound
 * taking part only when its `*_inclusive` flag is set) and wraps them in an iterator_t.
 * Negative values live in `negative_trie` under their absolute value, so queries that
 * touch the negative side are mirrored before being forwarded.
 *
 * @return an iterator over the collected nodes; it is built from an empty node list
 *         when `low > high` or when the trie that would be searched does not exist.
 */
NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
                                                  const int64_t& high, const bool& high_inclusive) {
    std::vector<Node*> matched_nodes;

    // An inverted interval cannot contain anything.
    if (low > high) {
        return NumericTrie::iterator_t(matched_nodes);
    }

    // Entire interval is non-negative: only positive_trie is involved.
    if (low >= 0) {
        if (positive_trie != nullptr) {
            positive_trie->search_range(low_inclusive ? low : low + 1,
                                        high_inclusive ? high : high - 1,
                                        max_level, matched_nodes);
        }

        return NumericTrie::iterator_t(matched_nodes);
    }

    // Entire interval is negative: only negative_trie is involved.
    if (high < 0) {
        if (negative_trie != nullptr) {
            auto abs_high = std::abs(high), abs_low = std::abs(low);

            // Absolute values are stored, so the bounds swap roles.
            negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1,
                                        low_inclusive ? abs_low : abs_low - 1,
                                        max_level, matched_nodes);
        }

        return NumericTrie::iterator_t(matched_nodes);
    }

    // The interval straddles zero: merge ">low" from negative_trie with "<high" from positive_trie.
    const bool negative_side_empty = (low == -1 && !low_inclusive);  // (-1, ... holds no negative value
    if (negative_trie != nullptr && !negative_side_empty) {
        auto low_magnitude = std::abs(low);

        // On absolute values, ">low" becomes a less-than search.
        negative_trie->search_less_than(low_inclusive ? low_magnitude : low_magnitude - 1, max_level, matched_nodes);
    }

    const bool positive_side_empty = (high == 0 && !high_inclusive);  // ..., 0) holds no non-negative value
    if (positive_trie != nullptr && !positive_side_empty) {
        positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, matched_nodes);
    }

    return NumericTrie::iterator_t(matched_nodes);
}
void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
if (positive_trie != nullptr) {
@ -167,6 +208,35 @@ void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusiv
}
}
/**
 * Iterator flavour of the greater-than query.
 *
 * Collects the nodes covering every stored value above `value` (or equal to it when
 * `inclusive` is set) and wraps them in an iterator_t.
 */
NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) {
    std::vector<Node*> matched_nodes;

    // [0, ∞) and (-1, ∞) both mean "everything in positive_trie".
    const bool whole_positive_side = inclusive ? (value == 0) : (value == -1);
    if (whole_positive_side) {
        if (positive_trie != nullptr) {
            matched_nodes.push_back(positive_trie);
        }

        return NumericTrie::iterator_t(matched_nodes);
    }

    if (value < 0) {
        // Part of negative_trie qualifies, plus all of positive_trie.
        if (negative_trie != nullptr) {
            auto magnitude = std::abs(value);

            // Absolute values are stored, so ">value" turns into a less-than search.
            negative_trie->search_less_than(inclusive ? magnitude : magnitude - 1, max_level, matched_nodes);
        }

        if (positive_trie != nullptr) {
            matched_nodes.push_back(positive_trie);
        }
    } else if (positive_trie != nullptr) {
        positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, matched_nodes);
    }

    return NumericTrie::iterator_t(matched_nodes);
}
void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
if (negative_trie != nullptr) {
@ -237,6 +307,35 @@ void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive,
}
}
/**
 * Iterator flavour of the less-than query.
 *
 * Collects the nodes covering every stored value below `value` (or equal to it when
 * `inclusive` is set) and wraps them in an iterator_t.
 */
NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) {
    std::vector<Node*> matched_nodes;

    // (-∞, 0) and (-∞, -1] both mean "everything in negative_trie".
    const bool whole_negative_side = inclusive ? (value == -1) : (value == 0);
    if (whole_negative_side) {
        if (negative_trie != nullptr) {
            matched_nodes.push_back(negative_trie);
        }

        return NumericTrie::iterator_t(matched_nodes);
    }

    if (value >= 0) {
        // Part of positive_trie qualifies, plus all of negative_trie.
        if (positive_trie != nullptr) {
            positive_trie->search_less_than(inclusive ? value : value - 1, max_level, matched_nodes);
        }

        if (negative_trie != nullptr) {
            matched_nodes.push_back(negative_trie);
        }
    } else if (negative_trie != nullptr) {
        auto magnitude = std::abs(value);

        // Absolute values are stored, so "<value" turns into a greater-than search.
        negative_trie->search_greater_than(inclusive ? magnitude : magnitude + 1, max_level, matched_nodes);
    }

    return NumericTrie::iterator_t(matched_nodes);
}
void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) {
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
return;
@ -259,6 +358,17 @@ void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t
ids = out;
}
/**
 * Iterator flavour of the exact-match query.
 *
 * Negative values are looked up in `negative_trie` by their absolute value, all other
 * values in `positive_trie`. When the relevant trie does not exist the iterator is
 * built from an empty node list.
 */
NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) {
    std::vector<Node*> matched_nodes;

    if (value < 0) {
        if (negative_trie != nullptr) {
            negative_trie->search_equal_to(std::abs(value), max_level, matched_nodes);
        }
    } else if (positive_trie != nullptr) {
        positive_trie->search_equal_to(value, max_level, matched_nodes);
    }

    return NumericTrie::iterator_t(matched_nodes);
}
void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) {
char level = 0;
return insert_helper(value, seq_id, level, max_level);
@ -331,6 +441,11 @@ void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_l
ids = out;
}
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
char level = 0;
search_less_than_helper(value, level, max_level, matches);
}
void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level,
std::vector<Node*>& matches) {
if (level == max_level) {
@ -383,6 +498,15 @@ void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, co
ids = out;
}
/**
 * Node-collecting overload of the range search. An inverted range (`low > high`)
 * leaves `matches` untouched; otherwise the work is delegated to search_range_helper.
 */
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
                                     std::vector<Node*>& matches) {
    const bool range_is_ordered = low <= high;

    if (range_is_ordered) {
        search_range_helper(low, high, max_level, matches);
    }
}
void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
std::vector<Node*>& matches) {
// Segregating the nodes into matching low, in-between, and matching high.
@ -460,6 +584,11 @@ void NumericTrie::Node::search_greater_than(const int64_t& value, const char& ma
ids = out;
}
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
char level = 0;
search_greater_than_helper(value, level, max_level, matches);
}
void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
std::vector<Node*>& matches) {
if (level == max_level) {
@ -500,3 +629,100 @@ void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_le
root->get_all_ids(ids, ids_length);
}
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
char level = 1;
Node* root = this;
auto index = get_index(value, level, max_level);
while (level <= max_level) {
if (root->children == nullptr || root->children[index] == nullptr) {
return;
}
root = root->children[index];
index = get_index(value, ++level, max_level);
}
matches.push_back(root);
}
void NumericTrie::iterator_t::reset() {
for (auto& match: matches) {
match->index = 0;
}
is_valid = true;
set_seq_id();
}
void NumericTrie::iterator_t::skip_to(uint32_t id) {
for (auto& match: matches) {
ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id);
}
set_seq_id();
}
void NumericTrie::iterator_t::next() {
// Advance all the matches at seq_id.
for (auto& match: matches) {
if (match->index < match->ids_length && match->ids[match->index] == seq_id) {
match->index++;
}
}
set_seq_id();
}
/**
 * Builds the iterator from the nodes collected by a search.
 *
 * For each node the ids are fetched with get_all_ids(); nodes that yield at least one
 * id contribute a match_state that takes ownership of the returned array. Finally
 * `seq_id` / `is_valid` are initialised from the collected states.
 *
 * @param node_matches nodes whose ids should be iterated.
 */
NumericTrie::iterator_t::iterator_t(std::vector<Node*>& node_matches) {
    for (auto const& node_match: node_matches) {
        uint32_t* ids = nullptr;
        // Start from a defined length: get_all_ids() receives it by reference and must
        // never observe an indeterminate value.
        uint32_t ids_length = 0;

        node_match->get_all_ids(ids, ids_length);

        if (ids_length == 0) {
            // Nothing to iterate. Release whatever buffer may have been handed back so
            // it is not leaked (delete[] on a null pointer is a no-op).
            delete [] ids;
            continue;
        }

        matches.emplace_back(new match_state(ids, ids_length));
    }

    set_seq_id();
}
void NumericTrie::iterator_t::set_seq_id() {
// Find the lowest id of all the matches and update the seq_id.
bool one_is_valid = false;
uint32_t lowest_id = UINT32_MAX;
for (auto& match: matches) {
if (match->index < match->ids_length) {
one_is_valid = true;
if (match->ids[match->index] < lowest_id) {
lowest_id = match->ids[match->index];
}
}
}
if (one_is_valid) {
seq_id = lowest_id;
}
is_valid = one_is_valid;
}
/**
 * Move assignment.
 *
 * Releases the match states currently owned by this iterator, takes over those of
 * `obj` together with its `seq_id` / `is_valid`, and leaves `obj` empty and invalid
 * so that it no longer refers to states it does not own.
 *
 * @param obj iterator to take the state from; self-assignment is a no-op.
 * @return *this
 */
NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept {
    if (&obj == this) {
        return *this;
    }

    // Free what we own before adopting the new states.
    for (auto& match: matches) {
        delete match;
    }
    matches.clear();

    matches = std::move(obj.matches);
    // A moved-from vector is only "valid but unspecified": clear it explicitly so the
    // source destructor can never delete pointers that now belong to us.
    obj.matches.clear();

    seq_id = obj.seq_id;
    is_valid = obj.is_valid;

    // The source has nothing left to iterate.
    obj.is_valid = false;

    return *this;
}

View File

@ -436,6 +436,71 @@ TEST_F(NumericRangeTrieTest, SearchEqualTo) {
ASSERT_EQ(0, ids_length);
}
// Exercises the iterator returned by NumericTrie::search_equal_to: missing values,
// a single hit, a value stored under two ids, and the reset()/skip_to() navigation.
TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) {
    auto trie = new NumericTrie();
    std::unique_ptr<NumericTrie> trie_guard(trie);

    // {value, seq_id}; 24576 is deliberately stored under two ids.
    std::vector<std::pair<int32_t, uint32_t>> pairs = {
        {-8192, 8},
        {-16384, 32},
        {-24576, 35},
        {-32769, 41},
        {-32768, 43},
        {-32767, 45},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {24576, 60},
        {32768, 91}
    };

    for (auto const& pair: pairs) {
        trie->insert(pair.first, pair.second);
    }

    // Values that were never inserted yield an invalid iterator.
    auto iterator = trie->search_equal_to(0);
    ASSERT_FALSE(iterator.is_valid);

    iterator = trie->search_equal_to(0x202020);
    ASSERT_FALSE(iterator.is_valid);

    // Single hit on the negative side.
    iterator = trie->search_equal_to(-32768);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(43, iterator.seq_id);

    iterator.next();
    ASSERT_FALSE(iterator.is_valid);

    // Two ids share the value 24576 and come out in ascending order.
    iterator = trie->search_equal_to(24576);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    iterator.next();
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(60, iterator.seq_id);

    iterator.next();
    ASSERT_FALSE(iterator.is_valid);

    // reset() rewinds to the first id.
    iterator.reset();
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    // Skipping to an id below the current one does not move the iterator.
    iterator.skip_to(4);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    // Skipping to an absent id lands on the next larger one.
    iterator.skip_to(59);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(60, iterator.seq_id);

    // Skipping past the last id exhausts the iterator.
    iterator.skip_to(66);
    ASSERT_FALSE(iterator.is_valid);
}
TEST_F(NumericRangeTrieTest, MultivalueData) {
auto trie = new NumericTrie();
std::unique_ptr<NumericTrie> trie_guard(trie);