diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index b3d4436c..2144c8a9 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -10,7 +10,9 @@ class NumericTrieNode { NumericTrieNode** children = nullptr; sorted_array seq_ids; - void insert(const int32_t& value, const uint32_t& seq_id, char& level); + void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level); + + void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); @@ -32,7 +34,7 @@ public: void get_all_ids(uint32_t*& ids, uint32_t& ids_length); - void search_range(const int32_t& low,const int32_t& high, + void search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length); void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index be0c7f2c..f975df7b 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -20,7 +20,7 @@ void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, const int32_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length) { - if (low >= high) { + if (low > high) { return; } @@ -48,8 +48,30 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, return; } else if (low >= 0) { // Search only in positive_trie + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + if (positive_trie != nullptr) { + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, + positive_ids, positive_ids_length); + } + + ids = positive_ids; + ids_length = positive_ids_length; } else { // Search only in negative_trie + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + if (negative_trie != nullptr) { + // Since we store absolute values, switching low and high would produce the correct result. + auto abs_high = std::abs(high), abs_low = std::abs(low); + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, + negative_ids, negative_ids_length); + } + + ids = negative_ids; + ids_length = negative_ids_length; } } @@ -139,7 +161,7 @@ void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uin void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { char level = 0; - return insert(value, seq_id, level); + return insert_helper(value, seq_id, level); } inline int get_index(const int32_t& value, char& level) { @@ -153,7 +175,7 @@ inline int get_index(const int32_t& value, char& level) { return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; } -void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& level) { +void NumericTrieNode::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) { if (level > MAX_LEVEL) { return; } @@ -173,7 +195,7 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& children[index] = new NumericTrieNode(); } - return children[index]->insert(value, seq_id, level); + return children[index]->insert_helper(value, seq_id, level); } } @@ -221,7 +243,71 @@ void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, st } void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { + if (low > high) { + return; + } + std::vector matches; + search_range_helper(low, high, matches); + for (auto const& match: matches) { + uint32_t* out = nullptr; + auto const& m_seq_ids = match->seq_ids.uncompress(); + ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + + delete [] m_seq_ids; + delete [] ids; + ids = out; + } +} + +void NumericTrieNode::search_range_helper(const int32_t& low, const int32_t& high, + std::vector& matches) { + // Segregating the nodes into matching low, in-between, and matching high. + + NumericTrieNode* root = this; + char level = 1; + auto low_index = get_index(low, level), high_index = get_index(high, level); + + // Keep updating the root while the range is contained within a single child node. + while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) { + if (root->children[low_index] == nullptr) { + return; + } + + root = root->children[low_index]; + level++; + low_index = get_index(low, level); + high_index = get_index(high, level); + } + + if (root->children == nullptr) { + return; + } else if (low_index == high_index) { // low and high are equal + if (root->children[low_index] != nullptr) { + matches.push_back(root->children[low_index]); + } + return; + } + + if (root->children[low_index] != nullptr) { + // Collect all the sub-nodes that are greater than low. + root->children[low_index]->search_greater_helper(low, level, matches); + } + + auto index = low_index + 1; + // All the nodes in-between low and high are a match by default. + while (index < std::min(high_index, (int)EXPANSE)) { + if (root->children[index] != nullptr) { + matches.push_back(root->children[index]); + } + + index++; + } + + if (index < EXPANSE && index == high_index && root->children[index] != nullptr) { + // Collect all the sub-nodes that are lesser than high. + root->children[index]->search_lesser_helper(high, level, matches); + } } void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index ff297aec..875ed544 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -92,6 +92,14 @@ TEST_F(NumericRangeTrieTest, SearchRange) { ASSERT_EQ(pairs[i].second, ids[i]); } + trie->search_range(-134217728, true, 134217728, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + trie->search_range(-1, true, 32768, true, ids, ids_length); ids_guard.reset(ids); @@ -113,6 +121,56 @@ TEST_F(NumericRangeTrieTest, SearchRange) { trie->search_range(-1, false, 0, false, ids, ids_length); ASSERT_EQ(0, ids_length); + + trie->search_range(8192, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_range(8192, true, 0x2000000, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_range(16384, true, 16384, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(56, ids[0]); + + trie->search_range(16384, true, 16384, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(16384, false, 16384, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(16383, true, 16383, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(8193, true, 16383, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(-32768, true, -8192, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } } TEST_F(NumericRangeTrieTest, SearchGreater) { @@ -194,12 +252,12 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(100000, false, ids, ids_length); + trie->search_greater(1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_greater(-100000, false, ids, ids_length); + trie->search_greater(-1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(8, ids_length); @@ -285,12 +343,12 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(-100000, false, ids, ids_length); + trie->search_lesser(-1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_lesser(100000, true, ids, ids_length); + trie->search_lesser(1000000, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(8, ids_length);