mirror of
https://github.com/typesense/typesense.git
synced 2025-05-24 15:50:42 +08:00
Add NumericTrie
.
This commit is contained in:
parent
fedf8f4ec1
commit
c9180a0541
60
include/numeric_range_trie_test.h
Normal file
60
include/numeric_range_trie_test.h
Normal file
@ -0,0 +1,60 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include "sorted_array.h"
|
||||
|
||||
constexpr char MAX_LEVEL = 4;
|
||||
constexpr short EXPANSE = 256;
|
||||
|
||||
class NumericTrieNode {
|
||||
NumericTrieNode** children = nullptr;
|
||||
sorted_array seq_ids;
|
||||
|
||||
void insert(const int32_t& value, const uint32_t& seq_id, char& level);
|
||||
|
||||
void search_lesser_helper(const int32_t& value, char& level, std::vector<NumericTrieNode*>& matches);
|
||||
public:
|
||||
|
||||
~NumericTrieNode() {
|
||||
if (children != nullptr) {
|
||||
for (auto i = 0; i < EXPANSE; i++) {
|
||||
delete children[i];
|
||||
}
|
||||
}
|
||||
|
||||
delete [] children;
|
||||
}
|
||||
|
||||
void insert(const int32_t& value, const uint32_t& seq_id);
|
||||
|
||||
void search_range(const int32_t& low,const int32_t& high,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
};
|
||||
|
||||
class NumericTrie {
|
||||
NumericTrieNode* negative_trie = nullptr;
|
||||
NumericTrieNode* positive_trie = nullptr;
|
||||
|
||||
public:
|
||||
|
||||
~NumericTrie() {
|
||||
delete negative_trie;
|
||||
delete positive_trie;
|
||||
}
|
||||
|
||||
void insert(const int32_t& value, const uint32_t& seq_id);
|
||||
|
||||
void search_range(const int32_t& low, const bool& low_inclusive,
|
||||
const int32_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_lesser(const int32_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_greater(const int32_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
};
|
122
src/numeric_range_trie.cpp
Normal file
122
src/numeric_range_trie.cpp
Normal file
@ -0,0 +1,122 @@
|
||||
#include "numeric_range_trie_test.h"
|
||||
#include "array_utils.h"
|
||||
|
||||
void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) {
|
||||
if (value < 0) {
|
||||
if (negative_trie == nullptr) {
|
||||
negative_trie = new NumericTrieNode();
|
||||
}
|
||||
|
||||
negative_trie->insert(std::abs(value), seq_id);
|
||||
} else {
|
||||
if (positive_trie == nullptr) {
|
||||
positive_trie = new NumericTrieNode();
|
||||
}
|
||||
|
||||
positive_trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
const int32_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low < 0 && high >= 0) {
|
||||
// Have to combine the results of >low from negative_trie and <high from positive_trie
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
if (!(low == -1 && !low_inclusive)) {
|
||||
auto abs_low = std::abs(low);
|
||||
// Since we store absolute values, search_lesser would yield result for >low from negative_trie.
|
||||
negative_trie->search_lesser(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length);
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
if (!(high == 0 && !high_inclusive)) {
|
||||
positive_trie->search_lesser(high_inclusive ? high : high - 1, positive_ids, positive_ids_length);
|
||||
}
|
||||
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] positive_ids;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) {
|
||||
char level = 0;
|
||||
return insert(value, seq_id, level);
|
||||
}
|
||||
|
||||
inline int get_index(const int32_t& value, char& level) {
|
||||
return (value >> (8 * (MAX_LEVEL - level))) & 0xFF;
|
||||
}
|
||||
|
||||
void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& level) {
|
||||
if (level > MAX_LEVEL) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!seq_ids.contains(seq_id)) {
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
if (++level <= MAX_LEVEL) {
|
||||
if (children == nullptr) {
|
||||
children = new NumericTrieNode* [EXPANSE]{nullptr};
|
||||
}
|
||||
|
||||
auto index = get_index(value, level);
|
||||
if (children[index] == nullptr) {
|
||||
children[index] = new NumericTrieNode();
|
||||
}
|
||||
|
||||
return children[index]->insert(value, seq_id, level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrieNode*> matches;
|
||||
search_lesser_helper(value, level, matches);
|
||||
|
||||
for (auto const& match: matches) {
|
||||
uint32_t* out = nullptr;
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out);
|
||||
|
||||
delete [] m_seq_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, std::vector<NumericTrieNode*>& matches) {
|
||||
if (level > MAX_LEVEL) {
|
||||
return;
|
||||
} else if (level == MAX_LEVEL) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
}
|
||||
|
||||
if (children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level);
|
||||
if (children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
children[index]->search_lesser_helper(value, level, matches);
|
||||
|
||||
while (--index >= 0) {
|
||||
if (children[index] != nullptr) {
|
||||
matches.push_back(children[index]);
|
||||
}
|
||||
}
|
||||
|
||||
--level;
|
||||
}
|
41
test/numeric_range_trie_test.cpp
Normal file
41
test/numeric_range_trie_test.cpp
Normal file
@ -0,0 +1,41 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
class NumericRangeTrieTest : public ::testing::Test {
|
||||
protected:
|
||||
|
||||
virtual void SetUp() {}
|
||||
|
||||
virtual void TearDown() {}
|
||||
};
|
||||
|
||||
TEST_F(NumericRangeTrieTest, Insert) {
|
||||
auto trie = new NumericTrie();
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-8192, 8},
|
||||
{-16384, 32},
|
||||
{-24576, 35},
|
||||
{-32768, 43},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91}
|
||||
};
|
||||
|
||||
for (auto const pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_range(-32768, true, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size(), ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size(); i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
delete [] ids;
|
||||
delete trie;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user