diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index ce1da83d..b3d4436c 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -13,6 +13,9 @@ class NumericTrieNode { void insert(const int32_t& value, const uint32_t& seq_id, char& level); void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); + + void search_greater_helper(const int32_t& value, char& level, std::vector& matches); + public: ~NumericTrieNode() { @@ -27,6 +30,8 @@ public: void insert(const int32_t& value, const uint32_t& seq_id); + void get_all_ids(uint32_t*& ids, uint32_t& ids_length); + void search_range(const int32_t& low,const int32_t& high, uint32_t*& ids, uint32_t& ids_length); diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index d7476a28..50cf3315 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -39,6 +39,41 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + delete [] negative_ids; + delete [] positive_ids; + return; + } else if (low >= 0) { + // Search only in positive_trie + } else { + // Search only in negative_trie + } +} + +void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { + if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) + positive_trie->get_all_ids(ids, ids_length); + return; + } + + if (value >= 0) { + uint32_t* positive_ids = nullptr; + positive_trie->search_greater(inclusive ? value : value + 1, positive_ids, ids_length); + ids = positive_ids; + } else { + // Have to combine the results of >value from negative_trie and all the ids in positive_trie + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + auto abs_low = std::abs(value); + // Since we store absolute values, search_lesser would yield result for >low from negative_trie. + negative_trie->search_lesser(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->get_all_ids(positive_ids, positive_ids_length); + + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + delete [] negative_ids; delete [] positive_ids; return; @@ -51,6 +86,13 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { } inline int get_index(const int32_t& value, char& level) { + // Values are index considering higher order of the bytes first. + // 0x01020408 (16909320) would be indexed in the trie as follows: + // Level Index + // 1 1 + // 2 2 + // 3 4 + // 4 8 return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; } @@ -59,6 +101,7 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& return; } + // Root node contains all the sequence ids present in the tree. if (!seq_ids.contains(seq_id)) { seq_ids.append(seq_id); } @@ -77,6 +120,11 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& } } +void NumericTrieNode::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { + ids = seq_ids.uncompress(); + ids_length = seq_ids.getLength(); +} + void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; @@ -114,3 +162,45 @@ void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, st --level; } + +void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { + +} + +void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { + char level = 0; + std::vector matches; + search_greater_helper(value, level, matches); + + for (auto const& match: matches) { + uint32_t* out = nullptr; + auto const& m_seq_ids = match->seq_ids.uncompress(); + ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + + delete [] m_seq_ids; + delete [] ids; + ids = out; + } +} + +void NumericTrieNode::search_greater_helper(const int32_t& value, char& level, std::vector& matches) { + if (level == MAX_LEVEL) { + matches.push_back(this); + return; + } else if (level > MAX_LEVEL || children == nullptr) { + return; + } + + auto index = get_index(value, ++level); + if (children[index] != nullptr) { + children[index]->search_greater_helper(value, level, matches); + } + + while (++index < EXPANSE) { + if (children[index] != nullptr) { + matches.push_back(children[index]); + } + } + + --level; +} diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 526cdb32..690209a8 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -109,3 +109,83 @@ TEST_F(NumericRangeTrieTest, SearchRange) { trie->search_range(-1, false, 0, false, ids, ids_length); ASSERT_EQ(0, ids_length); } + +TEST_F(NumericRangeTrieTest, SearchGreater) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32768, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_greater(0, true, ids, ids_length); + std::unique_ptr ids_guard(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(-1, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(-1, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(-24576, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(7, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + trie->search_greater(-32768, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(7, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + trie->search_greater(8192, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(8192, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(3, ids_length); + for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } +}