mirror of
https://github.com/typesense/typesense.git
synced 2025-05-23 15:23:40 +08:00
Add NumericTrie::iterator_t.
This commit is contained in:
parent
aa753fffc1
commit
0321396f98
@ -42,11 +42,19 @@ class NumericTrie {
|
||||
void search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Node-collecting overload: appends matching nodes to `matches` instead of materialising ids.
// Does nothing when low > high; otherwise the work is delegated to search_range_helper.
void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Node-collecting overload: runs search_less_than_helper starting from level 0 and lets it
// append the matching nodes to `matches`.
void search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Node-collecting overload: runs search_greater_than_helper starting from level 0 and lets it
// append the matching nodes to `matches`.
void search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Node-collecting overload: follows one child per level (child index obtained from get_index)
// and appends the node reached after max_level levels to `matches`. Nothing is appended when
// a child on that path is missing.
void search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
};
|
||||
|
||||
Node* negative_trie = nullptr;
|
||||
@ -63,17 +71,63 @@ public:
|
||||
delete positive_trie;
|
||||
}
|
||||
|
||||
class iterator_t {
|
||||
struct match_state {
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
uint32_t index = 0;
|
||||
|
||||
explicit match_state(uint32_t*& ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {}
|
||||
|
||||
~match_state() {
|
||||
delete [] ids;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<match_state*> matches;
|
||||
|
||||
void set_seq_id();
|
||||
|
||||
public:
|
||||
|
||||
explicit iterator_t(std::vector<Node*>& matches);
|
||||
|
||||
~iterator_t() {
|
||||
for (auto& match: matches) {
|
||||
delete match;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_t& operator=(iterator_t&& obj) noexcept;
|
||||
|
||||
uint32_t seq_id = 0;
|
||||
bool is_valid = true;
|
||||
|
||||
void next();
|
||||
void skip_to(uint32_t id);
|
||||
void reset();
|
||||
};
|
||||
|
||||
void insert(const int64_t& value, const uint32_t& seq_id);
|
||||
|
||||
void search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Iterator flavour of search_range: gathers the matching nodes from negative_trie and/or
// positive_trie and wraps them in an iterator_t. The iterator has no matches when low > high.
iterator_t search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive);
|
||||
|
||||
void search_less_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Iterator flavour of search_less_than: returns an iterator_t over the nodes holding values
// below `value` (`inclusive` also admits `value` itself).
iterator_t search_less_than(const int64_t& value, const bool& inclusive);
|
||||
|
||||
void search_greater_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Iterator flavour of search_greater_than: returns an iterator_t over the nodes holding values
// above `value` (`inclusive` also admits `value` itself).
iterator_t search_greater_than(const int64_t& value, const bool& inclusive);
|
||||
|
||||
void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
// Iterator flavour of search_equal_to: looks `value` up in negative_trie (by absolute value)
// when it is negative, otherwise in positive_trie, and wraps the result in an iterator_t.
iterator_t search_equal_to(const int64_t& value);
|
||||
};
|
||||
|
@ -98,6 +98,47 @@ void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
|
||||
}
|
||||
}
|
||||
|
||||
// Iterator variant of the range search.
//
// Collects the trie nodes covering the interval between `low` and `high` (each bound may be
// inclusive or exclusive) and wraps them in an iterator_t. Negative values are stored by
// absolute value in negative_trie, so comparisons against that trie are mirrored.
//
// The returned iterator has no matches when low > high or when the trie that would have to
// be consulted does not exist.
NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
                                                  const int64_t& high, const bool& high_inclusive) {
    std::vector<Node*> nodes;

    if (low > high) {
        return NumericTrie::iterator_t(nodes);
    }

    if (low >= 0) {
        // The whole interval is non-negative: positive_trie alone can answer it.
        if (positive_trie != nullptr) {
            const int64_t first = low_inclusive ? low : low + 1;
            const int64_t last = high_inclusive ? high : high - 1;
            positive_trie->search_range(first, last, max_level, nodes);
        }

        return NumericTrie::iterator_t(nodes);
    }

    if (high < 0) {
        // The whole interval is negative: negative_trie alone can answer it.
        if (negative_trie != nullptr) {
            const auto high_magnitude = std::abs(high);
            const auto low_magnitude = std::abs(low);

            // Absolute values are stored, so the bounds swap roles.
            negative_trie->search_range(high_inclusive ? high_magnitude : high_magnitude + 1,
                                        low_inclusive ? low_magnitude : low_magnitude - 1,
                                        max_level, nodes);
        }

        return NumericTrie::iterator_t(nodes);
    }

    // The interval straddles zero: combine >low from negative_trie with <high from positive_trie.

    const bool negative_part_is_empty = (low == -1 && !low_inclusive);   // (-1, ...
    if (negative_trie != nullptr && !negative_part_is_empty) {
        const auto low_magnitude = std::abs(low);

        // Absolute values are stored, so "less than" on negative_trie yields the values >low.
        negative_trie->search_less_than(low_inclusive ? low_magnitude : low_magnitude - 1, max_level, nodes);
    }

    const bool positive_part_is_empty = (high == 0 && !high_inclusive);  // ..., 0)
    if (positive_trie != nullptr && !positive_part_is_empty) {
        positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, nodes);
    }

    return NumericTrie::iterator_t(nodes);
}
|
||||
|
||||
void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
|
||||
if (positive_trie != nullptr) {
|
||||
@ -167,6 +208,35 @@ void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusiv
|
||||
}
|
||||
}
|
||||
|
||||
// Iterator variant of the "greater than" search.
//
// Returns an iterator_t over the nodes whose values are above `value`; `inclusive` also
// admits `value` itself. Tries that do not exist simply contribute nothing.
NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) {
    std::vector<Node*> nodes;

    // [0, ∞) and (-1, ∞) both mean "everything in positive_trie".
    const bool covers_all_positives = inclusive ? (value == 0) : (value == -1);
    if (covers_all_positives) {
        if (positive_trie != nullptr) {
            nodes.push_back(positive_trie);
        }

        return NumericTrie::iterator_t(nodes);
    }

    if (value >= 0) {
        if (positive_trie != nullptr) {
            positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, nodes);
        }

        return NumericTrie::iterator_t(nodes);
    }

    // Negative bound: the values >value in negative_trie plus every id in positive_trie.
    if (negative_trie != nullptr) {
        const auto magnitude = std::abs(value);

        // Absolute values are stored, so "less than" on negative_trie yields the values >value.
        negative_trie->search_less_than(inclusive ? magnitude : magnitude - 1, max_level, nodes);
    }

    if (positive_trie != nullptr) {
        nodes.push_back(positive_trie);
    }

    return NumericTrie::iterator_t(nodes);
}
|
||||
|
||||
void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
|
||||
if (negative_trie != nullptr) {
|
||||
@ -237,6 +307,35 @@ void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive,
|
||||
}
|
||||
}
|
||||
|
||||
// Iterator variant of the "less than" search.
//
// Returns an iterator_t over the nodes whose values are below `value`; `inclusive` also
// admits `value` itself. Tries that do not exist simply contribute nothing.
NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) {
    std::vector<Node*> nodes;

    // (-∞, 0) and (-∞, -1] both mean "everything in negative_trie".
    const bool covers_all_negatives = inclusive ? (value == -1) : (value == 0);
    if (covers_all_negatives) {
        if (negative_trie != nullptr) {
            nodes.push_back(negative_trie);
        }

        return NumericTrie::iterator_t(nodes);
    }

    if (value < 0) {
        if (negative_trie != nullptr) {
            const auto magnitude = std::abs(value);

            // Absolute values are stored, so "greater than" on negative_trie yields the values <value.
            negative_trie->search_greater_than(inclusive ? magnitude : magnitude + 1, max_level, nodes);
        }

        return NumericTrie::iterator_t(nodes);
    }

    // Non-negative bound: the values <value in positive_trie plus every id in negative_trie.
    if (positive_trie != nullptr) {
        positive_trie->search_less_than(inclusive ? value : value - 1, max_level, nodes);
    }

    if (negative_trie != nullptr) {
        nodes.push_back(negative_trie);
    }

    return NumericTrie::iterator_t(nodes);
}
|
||||
|
||||
void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
|
||||
return;
|
||||
@ -259,6 +358,17 @@ void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t
|
||||
ids = out;
|
||||
}
|
||||
|
||||
// Iterator variant of the equality search.
//
// Negative values are looked up in negative_trie by their absolute value, everything else in
// positive_trie. When the relevant trie does not exist the iterator has no matches.
NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) {
    std::vector<Node*> nodes;

    if (value < 0) {
        if (negative_trie != nullptr) {
            negative_trie->search_equal_to(std::abs(value), max_level, nodes);
        }
    } else if (positive_trie != nullptr) {
        positive_trie->search_equal_to(value, max_level, nodes);
    }

    return NumericTrie::iterator_t(nodes);
}
|
||||
|
||||
void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) {
|
||||
char level = 0;
|
||||
return insert_helper(value, seq_id, level, max_level);
|
||||
@ -331,6 +441,11 @@ void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_l
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 0;
|
||||
search_less_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
@ -383,6 +498,15 @@ void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, co
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
|
||||
search_range_helper(low, high, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
// Segregating the nodes into matching low, in-between, and matching high.
|
||||
@ -460,6 +584,11 @@ void NumericTrie::Node::search_greater_than(const int64_t& value, const char& ma
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 0;
|
||||
search_greater_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
@ -500,3 +629,100 @@ void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_le
|
||||
|
||||
root->get_all_ids(ids, ids_length);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level <= max_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
matches.push_back(root);
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::reset() {
|
||||
for (auto& match: matches) {
|
||||
match->index = 0;
|
||||
}
|
||||
|
||||
is_valid = true;
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::skip_to(uint32_t id) {
|
||||
for (auto& match: matches) {
|
||||
ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id);
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::next() {
|
||||
// Advance all the matches at seq_id.
|
||||
for (auto& match: matches) {
|
||||
if (match->index < match->ids_length && match->ids[match->index] == seq_id) {
|
||||
match->index++;
|
||||
}
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
// Builds the iterator from the matched trie nodes.
//
// For every node the ids are fetched through get_all_ids; nodes that report at least one id
// get a match_state (which takes ownership of the array). Finally seq_id / is_valid are
// initialised from the collected matches.
NumericTrie::iterator_t::iterator_t(std::vector<Node*>& node_matches) {
    for (auto const& node_match: node_matches) {
        uint32_t* ids = nullptr;
        // Start from a defined value: the length is read below and must not depend on
        // get_all_ids assigning it on every path.
        uint32_t ids_length = 0;

        node_match->get_all_ids(ids, ids_length);

        if (ids_length > 0) {
            matches.emplace_back(new match_state(ids, ids_length));
        } else {
            // No match_state takes ownership of an empty result, so release it here instead
            // of leaking it. Assumes get_all_ids allocates with new[] (match_state frees the
            // array the same way); delete [] on nullptr is a no-op.
            delete [] ids;
        }
    }

    set_seq_id();
}
|
||||
|
||||
void NumericTrie::iterator_t::set_seq_id() {
|
||||
// Find the lowest id of all the matches and update the seq_id.
|
||||
bool one_is_valid = false;
|
||||
uint32_t lowest_id = UINT32_MAX;
|
||||
|
||||
for (auto& match: matches) {
|
||||
if (match->index < match->ids_length) {
|
||||
one_is_valid = true;
|
||||
|
||||
if (match->ids[match->index] < lowest_id) {
|
||||
lowest_id = match->ids[match->index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (one_is_valid) {
|
||||
seq_id = lowest_id;
|
||||
}
|
||||
|
||||
is_valid = one_is_valid;
|
||||
}
|
||||
|
||||
// Move assignment: releases the match states currently held, then takes over those of `obj`
// together with its position (seq_id / is_valid). Self-assignment is a no-op.
NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept {
    if (&obj == this) {
        return *this;
    }

    for (auto& match: matches) {
        delete match;
    }
    matches.clear();

    // Swap with the now-empty vector rather than move-assigning: this guarantees that
    // obj.matches ends up empty, so obj's destructor cannot delete the match states that
    // were just taken over. A moved-from vector is only specified to be "valid".
    matches.swap(obj.matches);

    seq_id = obj.seq_id;
    is_valid = obj.is_valid;

    return *this;
}
|
||||
|
||||
|
@ -436,6 +436,71 @@ TEST_F(NumericRangeTrieTest, SearchEqualTo) {
|
||||
ASSERT_EQ(0, ids_length);
|
||||
}
|
||||
|
||||
// Exercises the iterator returned by NumericTrie::search_equal_to: lookups that miss, a value
// stored once, a value stored under two ids, and the reset() / skip_to() operations.
TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) {
    auto trie = new NumericTrie();
    std::unique_ptr<NumericTrie> trie_guard(trie);
    std::vector<std::pair<int32_t, uint32_t>> pairs = {
        {-8192, 8},
        {-16384, 32},
        {-24576, 35},
        {-32769, 41},
        {-32768, 43},
        {-32767, 45},
        {8192, 49},
        {16384, 56},
        {24576, 58},
        {24576, 60},
        {32768, 91}
    };

    for (auto const& pair: pairs) {
        trie->insert(pair.first, pair.second);
    }

    // Values that were never inserted yield an iterator that is invalid from the start.
    auto iterator = trie->search_equal_to(0);
    ASSERT_FALSE(iterator.is_valid);

    iterator = trie->search_equal_to(0x202020);
    ASSERT_FALSE(iterator.is_valid);

    // A value stored under a single id.
    iterator = trie->search_equal_to(-32768);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(43, iterator.seq_id);

    iterator.next();
    ASSERT_FALSE(iterator.is_valid);

    // A value stored under two ids is iterated in ascending id order.
    iterator = trie->search_equal_to(24576);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    iterator.next();
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(60, iterator.seq_id);

    iterator.next();
    ASSERT_FALSE(iterator.is_valid);

    // reset() makes an exhausted iterator usable again.
    iterator.reset();
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    // Skipping to an id below the current one leaves the position unchanged.
    iterator.skip_to(4);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(58, iterator.seq_id);

    // Skipping to an id that is absent lands on the next larger one.
    iterator.skip_to(59);
    ASSERT_TRUE(iterator.is_valid);
    ASSERT_EQ(60, iterator.seq_id);

    // Skipping past the last id exhausts the iterator.
    iterator.skip_to(66);
    ASSERT_FALSE(iterator.is_valid);
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, MultivalueData) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
|
Loading…
x
Reference in New Issue
Block a user