Refactor NumericTrie::search_greater_than.

This commit is contained in:
Harpreet Sangar 2023-06-01 11:03:21 +05:30
parent 6348fbcf03
commit 4635e5cebf
2 changed files with 58 additions and 25 deletions

View File

@ -78,42 +78,68 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
if (positive_trie != nullptr) {
positive_trie->get_all_ids(ids, ids_length);
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
positive_trie->get_all_ids(positive_ids, positive_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
delete [] positive_ids;
delete [] ids;
ids = out;
}
return;
}
if (value >= 0) {
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
if (positive_trie != nullptr) {
positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length);
if (positive_trie == nullptr) {
return;
}
ids_length = positive_ids_length;
ids = positive_ids;
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
delete [] positive_ids;
delete [] ids;
ids = out;
} else {
// Have to combine the results of >value from negative_trie and all the ids in positive_trie
uint32_t* negative_ids = nullptr;
uint32_t negative_ids_length = 0;
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
if (negative_trie != nullptr) {
uint32_t* negative_ids = nullptr;
uint32_t negative_ids_length = 0;
auto abs_low = std::abs(value);
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
delete [] negative_ids;
delete [] ids;
ids = out;
}
if (positive_trie == nullptr) {
return;
}
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
if (positive_trie != nullptr) {
positive_trie->get_all_ids(positive_ids, positive_ids_length);
}
positive_trie->get_all_ids(positive_ids, positive_ids_length);
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
delete [] negative_ids;
delete [] positive_ids;
return;
delete [] ids;
ids = out;
}
}

View File

@ -9,6 +9,12 @@ protected:
virtual void TearDown() {}
};
void reset(uint32_t*& ids, uint32_t& ids_length) {
delete [] ids;
ids = nullptr;
ids_length = 0;
}
TEST_F(NumericRangeTrieTest, SearchRange) {
auto trie = new NumericTrie();
std::unique_ptr<NumericTrie> trie_guard(trie);
@ -195,31 +201,30 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) {
uint32_t ids_length = 0;
trie->search_greater_than(0, true, ids, ids_length);
std::unique_ptr<uint32_t[]> ids_guard(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
reset(ids, ids_length);
trie->search_greater_than(-1, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
reset(ids, ids_length);
trie->search_greater_than(-1, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
reset(ids, ids_length);
trie->search_greater_than(-24576, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(7, ids_length);
for (uint32_t i = 0, j = 0; i < pairs.size(); i++) {
@ -227,8 +232,8 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) {
ASSERT_EQ(pairs[i].second, ids[j++]);
}
reset(ids, ids_length);
trie->search_greater_than(-32768, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(7, ids_length);
for (uint32_t i = 0, j = 0; i < pairs.size(); i++) {
@ -236,34 +241,36 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) {
ASSERT_EQ(pairs[i].second, ids[j++]);
}
reset(ids, ids_length);
trie->search_greater_than(8192, true, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(4, ids_length);
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
reset(ids, ids_length);
trie->search_greater_than(8192, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(3, ids_length);
for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) {
ASSERT_EQ(pairs[i].second, ids[j]);
}
reset(ids, ids_length);
trie->search_greater_than(1000000, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(0, ids_length);
reset(ids, ids_length);
trie->search_greater_than(-1000000, false, ids, ids_length);
ids_guard.reset(ids);
ASSERT_EQ(8, ids_length);
for (uint32_t i = 0; i < pairs.size(); i++) {
ASSERT_EQ(pairs[i].second, ids[i]);
}
reset(ids, ids_length);
}
TEST_F(NumericRangeTrieTest, SearchLessThan) {