diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index cfa9e7e4..3c4e7cd1 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -78,42 +78,68 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { - positive_trie->get_all_ids(ids, ids_length); + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->get_all_ids(positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; } return; } if (value >= 0) { - uint32_t* positive_ids = nullptr; - uint32_t positive_ids_length = 0; - if (positive_trie != nullptr) { - positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length); + if (positive_trie == nullptr) { + return; } - ids_length = positive_ids_length; - ids = positive_ids; + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; } else { // Have to combine the results of >value from negative_trie and all the ids in positive_trie - uint32_t* negative_ids = nullptr; - uint32_t negative_ids_length = 0; - // Since we store absolute values, search_lesser would yield result for >value from negative_trie. if (negative_trie != nullptr) { + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; auto abs_low = std::abs(value); + + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } + + if (positive_trie == nullptr) { + return; } uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - if (positive_trie != nullptr) { - positive_trie->get_all_ids(positive_ids, positive_ids_length); - } + positive_trie->get_all_ids(positive_ids, positive_ids_length); - ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); - delete [] negative_ids; delete [] positive_ids; - return; + delete [] ids; + ids = out; } } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index baad5ef7..3f591dfc 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -9,6 +9,12 @@ protected: virtual void TearDown() {} }; +void reset(uint32_t*& ids, uint32_t& ids_length) { + delete [] ids; + ids = nullptr; + ids_length = 0; +} + TEST_F(NumericRangeTrieTest, SearchRange) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); @@ -195,31 +201,30 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) { uint32_t ids_length = 0; trie->search_greater_than(0, true, ids, ids_length); - std::unique_ptr ids_guard(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(-1, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(-1, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(-24576, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(7, ids_length); for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { @@ -227,8 +232,8 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) { ASSERT_EQ(pairs[i].second, ids[j++]); } + reset(ids, ids_length); trie->search_greater_than(-32768, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(7, ids_length); for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { @@ -236,34 +241,36 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) { ASSERT_EQ(pairs[i].second, ids[j++]); } + reset(ids, ids_length); trie->search_greater_than(8192, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(8192, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(3, ids_length); for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(1000000, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_greater_than(-1000000, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(8, ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + + reset(ids, ids_length); } TEST_F(NumericRangeTrieTest, SearchLessThan) {