Add NumericTrieNode::search_range.

This commit is contained in:
Harpreet Sangar 2023-05-30 18:04:55 +05:30
parent a10e5e532a
commit 92bbe8de9f
3 changed files with 156 additions and 10 deletions

View File

@ -10,7 +10,9 @@ class NumericTrieNode {
NumericTrieNode** children = nullptr;
sorted_array seq_ids;
void insert(const int32_t& value, const uint32_t& seq_id, char& level);
void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level);
void search_range_helper(const int32_t& low,const int32_t& high, std::vector<NumericTrieNode*>& matches);
void search_lesser_helper(const int32_t& value, char& level, std::vector<NumericTrieNode*>& matches);
@ -32,7 +34,7 @@ public:
void get_all_ids(uint32_t*& ids, uint32_t& ids_length);
void search_range(const int32_t& low,const int32_t& high,
void search_range(const int32_t& low, const int32_t& high,
uint32_t*& ids, uint32_t& ids_length);
void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);

View File

@ -20,7 +20,7 @@ void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) {
void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
const int32_t& high, const bool& high_inclusive,
uint32_t*& ids, uint32_t& ids_length) {
if (low >= high) {
if (low > high) {
return;
}
@ -48,8 +48,30 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
return;
} else if (low >= 0) {
// Search only in positive_trie
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
if (positive_trie != nullptr) {
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1,
positive_ids, positive_ids_length);
}
ids = positive_ids;
ids_length = positive_ids_length;
} else {
// Search only in negative_trie
uint32_t* negative_ids = nullptr;
uint32_t negative_ids_length = 0;
if (negative_trie != nullptr) {
// Since we store absolute values, switching low and high would produce the correct result.
auto abs_high = std::abs(high), abs_low = std::abs(low);
negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1,
negative_ids, negative_ids_length);
}
ids = negative_ids;
ids_length = negative_ids_length;
}
}
@ -139,7 +161,7 @@ void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uin
void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) {
char level = 0;
return insert(value, seq_id, level);
return insert_helper(value, seq_id, level);
}
inline int get_index(const int32_t& value, char& level) {
@ -153,7 +175,7 @@ inline int get_index(const int32_t& value, char& level) {
return (value >> (8 * (MAX_LEVEL - level))) & 0xFF;
}
void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& level) {
void NumericTrieNode::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) {
if (level > MAX_LEVEL) {
return;
}
@ -173,7 +195,7 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char&
children[index] = new NumericTrieNode();
}
return children[index]->insert(value, seq_id, level);
return children[index]->insert_helper(value, seq_id, level);
}
}
@ -221,7 +243,71 @@ void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, st
}
void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) {
if (low > high) {
return;
}
std::vector<NumericTrieNode*> matches;
search_range_helper(low, high, matches);
for (auto const& match: matches) {
uint32_t* out = nullptr;
auto const& m_seq_ids = match->seq_ids.uncompress();
ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out);
delete [] m_seq_ids;
delete [] ids;
ids = out;
}
}
void NumericTrieNode::search_range_helper(const int32_t& low, const int32_t& high,
std::vector<NumericTrieNode*>& matches) {
// Segregating the nodes into matching low, in-between, and matching high.
NumericTrieNode* root = this;
char level = 1;
auto low_index = get_index(low, level), high_index = get_index(high, level);
// Keep updating the root while the range is contained within a single child node.
while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) {
if (root->children[low_index] == nullptr) {
return;
}
root = root->children[low_index];
level++;
low_index = get_index(low, level);
high_index = get_index(high, level);
}
if (root->children == nullptr) {
return;
} else if (low_index == high_index) { // low and high are equal
if (root->children[low_index] != nullptr) {
matches.push_back(root->children[low_index]);
}
return;
}
if (root->children[low_index] != nullptr) {
// Collect all the sub-nodes that are greater than low.
root->children[low_index]->search_greater_helper(low, level, matches);
}
auto index = low_index + 1;
// All the nodes in-between low and high are a match by default.
while (index < std::min(high_index, (int)EXPANSE)) {
if (root->children[index] != nullptr) {
matches.push_back(root->children[index]);
}
index++;
}
if (index < EXPANSE && index == high_index && root->children[index] != nullptr) {
// Collect all the sub-nodes that are lesser than high.
root->children[index]->search_lesser_helper(high, level, matches);
}
}
void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {

View File

@ -92,6 +92,14 @@ TEST_F(NumericRangeTrieTest, SearchRange) {
ASSERT_EQ(pairs[i].second, ids[i]);
}
trie->search_range(-134217728, true, 134217728, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(pairs.size(), ids_length);
for (uint32_t i = 0; i < pairs.size(); i++) {
ASSERT_EQ(pairs[i].second, ids[i]);
}
trie->search_range(-1, true, 32768, true, ids, ids_length);
ids_guard.reset(ids);
@ -113,6 +121,56 @@ TEST_F(NumericRangeTrieTest, SearchRange) {
trie->search_range(-1, false, 0, false, ids, ids_length);
ASSERT_EQ(0, ids_length);
trie->search_range(8192, true, 32768, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
trie->search_range(8192, true, 0x2000000, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
trie->search_range(16384, true, 16384, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(1, ids_length);
ASSERT_EQ(56, ids[0]);
trie->search_range(16384, true, 16384, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
trie->search_range(16384, false, 16384, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
trie->search_range(16383, true, 16383, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
trie->search_range(8193, true, 16383, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
trie->search_range(-32768, true, -8192, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 0; i < ids_length; i++) {
ASSERT_EQ(pairs[i].second, ids[i]);
}
}
TEST_F(NumericRangeTrieTest, SearchGreater) {
@ -194,12 +252,12 @@ TEST_F(NumericRangeTrieTest, SearchGreater) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
trie->search_greater(100000, false, ids, ids_length);
trie->search_greater(1000000, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
trie->search_greater(-100000, false, ids, ids_length);
trie->search_greater(-1000000, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(8, ids_length);
@ -285,12 +343,12 @@ TEST_F(NumericRangeTrieTest, SearchLesser) {
ASSERT_EQ(pairs[i].second, ids[i]);
}
trie->search_lesser(-100000, false, ids, ids_length);
trie->search_lesser(-1000000, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
trie->search_lesser(100000, true, ids, ids_length);
trie->search_lesser(1000000, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(8, ids_length);