mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 14:55:26 +08:00
Add support for int64
and float
fields in NumericTrie
.
This commit is contained in:
parent
dd31e841b9
commit
43d235bbd0
@ -3,21 +3,25 @@
|
||||
#include <map>
|
||||
#include "sorted_array.h"
|
||||
|
||||
constexpr char MAX_LEVEL = 4;
|
||||
constexpr short EXPANSE = 256;
|
||||
|
||||
class NumericTrie {
|
||||
char max_level = 4;
|
||||
|
||||
class Node {
|
||||
Node** children = nullptr;
|
||||
sorted_array seq_ids;
|
||||
|
||||
void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level);
|
||||
void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level);
|
||||
|
||||
void search_range_helper(const int32_t& low,const int32_t& high, std::vector<Node*>& matches);
|
||||
void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
void search_less_than_helper(const int32_t& value, char& level, std::vector<Node*>& matches);
|
||||
void search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
void search_greater_than_helper(const int32_t& value, char& level, std::vector<Node*>& matches);
|
||||
void search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
public:
|
||||
|
||||
@ -31,18 +35,18 @@ class NumericTrie {
|
||||
delete [] children;
|
||||
}
|
||||
|
||||
void insert(const int32_t& value, const uint32_t& seq_id);
|
||||
void insert(const int64_t& value, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void get_all_ids(uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_range(const int32_t& low, const int32_t& high,
|
||||
void search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
};
|
||||
|
||||
Node* negative_trie = nullptr;
|
||||
@ -50,22 +54,26 @@ class NumericTrie {
|
||||
|
||||
public:
|
||||
|
||||
explicit NumericTrie(char num_bits = 32) {
|
||||
max_level = num_bits / 8;
|
||||
}
|
||||
|
||||
~NumericTrie() {
|
||||
delete negative_trie;
|
||||
delete positive_trie;
|
||||
}
|
||||
|
||||
void insert(const int32_t& value, const uint32_t& seq_id);
|
||||
void insert(const int64_t& value, const uint32_t& seq_id);
|
||||
|
||||
void search_range(const int32_t& low, const bool& low_inclusive,
|
||||
const int32_t& high, const bool& high_inclusive,
|
||||
void search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_less_than(const int32_t& value, const bool& inclusive,
|
||||
void search_less_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_greater_than(const int32_t& value, const bool& inclusive,
|
||||
void search_greater_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
};
|
||||
|
@ -646,7 +646,7 @@ void filter_result_iterator_t::init() {
|
||||
field f = index->search_schema.at(a_filter.field_name);
|
||||
|
||||
if (f.is_integer()) {
|
||||
if (f.is_int32() && f.range_index) {
|
||||
if (f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
@ -718,28 +718,64 @@ void filter_result_iterator_t::init() {
|
||||
is_filter_result_initialized = true;
|
||||
return;
|
||||
} else if (f.is_float()) {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
if (f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, float_int64,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == EQUALS) {
|
||||
trie->search_equal_to(float_int64, filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
uint32_t to_exclude_ids_len = 0;
|
||||
trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = index->seq_ids->uncompress();
|
||||
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
|
||||
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
|
||||
trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
|
||||
trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
|
||||
filter_result.count = result_size;
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, float_int64,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
|
||||
}
|
||||
|
||||
filter_result.count = result_size;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
|
@ -90,7 +90,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
numerical_index.emplace(a_field.name, num_tree);
|
||||
|
||||
if (a_field.range_index) {
|
||||
auto trie = new NumericTrie();
|
||||
auto trie = a_field.is_int32() ? new NumericTrie() : new NumericTrie(64);
|
||||
range_index.emplace(a_field.name, trie);
|
||||
}
|
||||
}
|
||||
@ -767,6 +767,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::INT64) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
int64_t value = record.doc[afield.name].get<int64_t>();
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -776,6 +785,16 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::FLOAT) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
float fvalue = record.doc[afield.name].get<float>();
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -928,23 +947,30 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
if(afield.type == field_types::INT32_ARRAY) {
|
||||
const int32_t value = arr_value;
|
||||
num_tree->insert(value, seq_id);
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
|
||||
num_tree->insert(value, seq_id);
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::INT64_ARRAY) {
|
||||
const int64_t value = arr_value;
|
||||
num_tree->insert(value, seq_id);
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::FLOAT_ARRAY) {
|
||||
const float fvalue = arr_value;
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
num_tree->insert(value, seq_id);
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::BOOL_ARRAY) {
|
||||
|
@ -2,24 +2,24 @@
|
||||
#include "numeric_range_trie_test.h"
|
||||
#include "array_utils.h"
|
||||
|
||||
void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) {
|
||||
void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) {
|
||||
if (value < 0) {
|
||||
if (negative_trie == nullptr) {
|
||||
negative_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
negative_trie->insert(std::abs(value), seq_id);
|
||||
negative_trie->insert(std::abs(value), seq_id, max_level);
|
||||
} else {
|
||||
if (positive_trie == nullptr) {
|
||||
positive_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
positive_trie->insert(value, seq_id);
|
||||
positive_trie->insert(value, seq_id, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
const int32_t& high, const bool& high_inclusive,
|
||||
void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low > high) {
|
||||
return;
|
||||
@ -34,7 +34,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
auto abs_low = std::abs(low);
|
||||
|
||||
// Since we store absolute values, search_lesser would yield result for >low from negative_trie.
|
||||
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length);
|
||||
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
@ -47,7 +48,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0)
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_less_than(high_inclusive ? high : high - 1, positive_ids, positive_ids_length);
|
||||
positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
@ -64,7 +66,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1,
|
||||
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
@ -84,6 +86,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
// Since we store absolute values, switching low and high would produce the correct result.
|
||||
auto abs_high = std::abs(high), abs_low = std::abs(low);
|
||||
negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1,
|
||||
max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
@ -95,7 +98,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
|
||||
if (positive_trie != nullptr) {
|
||||
uint32_t* positive_ids = nullptr;
|
||||
@ -119,7 +122,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length);
|
||||
positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
@ -136,7 +139,8 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv
|
||||
auto abs_low = std::abs(value);
|
||||
|
||||
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
|
||||
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length);
|
||||
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
@ -163,7 +167,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
|
||||
if (negative_trie != nullptr) {
|
||||
uint32_t* negative_ids = nullptr;
|
||||
@ -190,7 +194,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive,
|
||||
auto abs_low = std::abs(value);
|
||||
|
||||
// Since we store absolute values, search_greater would yield result for <value from negative_trie.
|
||||
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length);
|
||||
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
@ -204,7 +209,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive,
|
||||
if (positive_trie != nullptr) {
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_less_than(inclusive ? value : value - 1, positive_ids, positive_ids_length);
|
||||
positive_trie->search_less_than(inclusive ? value : value - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
@ -231,7 +237,7 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive,
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
|
||||
return;
|
||||
}
|
||||
@ -240,9 +246,9 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t
|
||||
uint32_t equal_ids_length = 0;
|
||||
|
||||
if (value < 0) {
|
||||
negative_trie->search_equal_to(std::abs(value), equal_ids, equal_ids_length);
|
||||
negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length);
|
||||
} else {
|
||||
positive_trie->search_equal_to(value, equal_ids, equal_ids_length);
|
||||
positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length);
|
||||
}
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
@ -253,12 +259,12 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) {
|
||||
void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) {
|
||||
char level = 0;
|
||||
return insert_helper(value, seq_id, level);
|
||||
return insert_helper(value, seq_id, level, max_level);
|
||||
}
|
||||
|
||||
inline int get_index(const int32_t& value, char& level) {
|
||||
inline int get_index(const int64_t& value, const char& level, const char& max_level) {
|
||||
// Values are index considering higher order of the bytes first.
|
||||
// 0x01020408 (16909320) would be indexed in the trie as follows:
|
||||
// Level Index
|
||||
@ -266,11 +272,11 @@ inline int get_index(const int32_t& value, char& level) {
|
||||
// 2 2
|
||||
// 3 4
|
||||
// 4 8
|
||||
return (value >> (8 * (MAX_LEVEL - level))) & 0xFF;
|
||||
return (value >> (8 * (max_level - level))) & 0xFF;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) {
|
||||
if (level > MAX_LEVEL) {
|
||||
void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) {
|
||||
if (level > max_level) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -279,17 +285,17 @@ void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
if (++level <= MAX_LEVEL) {
|
||||
if (++level <= max_level) {
|
||||
if (children == nullptr) {
|
||||
children = new NumericTrie::Node* [EXPANSE]{nullptr};
|
||||
}
|
||||
|
||||
auto index = get_index(value, level);
|
||||
auto index = get_index(value, level, max_level);
|
||||
if (children[index] == nullptr) {
|
||||
children[index] = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
return children[index]->insert_helper(value, seq_id, level);
|
||||
return children[index]->insert_helper(value, seq_id, level, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
@ -298,10 +304,11 @@ void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) {
|
||||
ids_length = seq_ids.getLength();
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_less_than_helper(value, level, matches);
|
||||
search_less_than_helper(value, level, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
@ -324,17 +331,18 @@ void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, u
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& level, std::vector<NumericTrie::Node*>& matches) {
|
||||
if (level == MAX_LEVEL) {
|
||||
void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
} else if (level > MAX_LEVEL || children == nullptr) {
|
||||
} else if (level > max_level || children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level);
|
||||
auto index = get_index(value, ++level, max_level);
|
||||
if (children[index] != nullptr) {
|
||||
children[index]->search_less_than_helper(value, level, matches);
|
||||
children[index]->search_less_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
while (--index >= 0) {
|
||||
@ -346,12 +354,13 @@ void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& leve
|
||||
--level;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_range_helper(low, high, matches);
|
||||
search_range_helper(low, high, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
@ -374,24 +383,24 @@ void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, ui
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& high,
|
||||
std::vector<NumericTrie::Node*>& matches) {
|
||||
void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
// Segregating the nodes into matching low, in-between, and matching high.
|
||||
|
||||
NumericTrie::Node* root = this;
|
||||
char level = 1;
|
||||
auto low_index = get_index(low, level), high_index = get_index(high, level);
|
||||
auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level);
|
||||
|
||||
// Keep updating the root while the range is contained within a single child node.
|
||||
while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) {
|
||||
while (root->children != nullptr && low_index == high_index && level < max_level) {
|
||||
if (root->children[low_index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[low_index];
|
||||
level++;
|
||||
low_index = get_index(low, level);
|
||||
high_index = get_index(high, level);
|
||||
low_index = get_index(low, level, max_level);
|
||||
high_index = get_index(high, level, max_level);
|
||||
}
|
||||
|
||||
if (root->children == nullptr) {
|
||||
@ -405,7 +414,7 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h
|
||||
|
||||
if (root->children[low_index] != nullptr) {
|
||||
// Collect all the sub-nodes that are greater than low.
|
||||
root->children[low_index]->search_greater_than_helper(low, level, matches);
|
||||
root->children[low_index]->search_greater_than_helper(low, level, max_level, matches);
|
||||
}
|
||||
|
||||
auto index = low_index + 1;
|
||||
@ -420,14 +429,15 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h
|
||||
|
||||
if (index < EXPANSE && index == high_index && root->children[index] != nullptr) {
|
||||
// Collect all the sub-nodes that are lesser than high.
|
||||
root->children[index]->search_less_than_helper(high, level, matches);
|
||||
root->children[index]->search_less_than_helper(high, level, max_level, matches);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_greater_than_helper(value, level, matches);
|
||||
search_greater_than_helper(value, level, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
@ -450,17 +460,18 @@ void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& level, std::vector<NumericTrie::Node*>& matches) {
|
||||
if (level == MAX_LEVEL) {
|
||||
void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
} else if (level > MAX_LEVEL || children == nullptr) {
|
||||
} else if (level > max_level || children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level);
|
||||
auto index = get_index(value, ++level, max_level);
|
||||
if (children[index] != nullptr) {
|
||||
children[index]->search_greater_than_helper(value, level, matches);
|
||||
children[index]->search_greater_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
while (++index < EXPANSE) {
|
||||
@ -472,18 +483,19 @@ void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& l
|
||||
--level;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level);
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level <= MAX_LEVEL) {
|
||||
while (level <= max_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level);
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
root->get_all_ids(ids, ids_length);
|
||||
|
@ -604,7 +604,8 @@ TEST_F(NumericRangeTrieTest, Integration) {
|
||||
field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
|
||||
true), // Setting range index true.
|
||||
field("years", field_types::INT32_ARRAY, false),
|
||||
field("timestamps", field_types::INT64_ARRAY, false),
|
||||
field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "",
|
||||
nlohmann::json(), true),
|
||||
field("tags", field_types::STRING_ARRAY, true)
|
||||
};
|
||||
|
||||
@ -626,32 +627,18 @@ TEST_F(NumericRangeTrieTest, Integration) {
|
||||
|
||||
while (std::getline(infile, json_line)) {
|
||||
auto add_op = coll_array_fields->add(json_line);
|
||||
LOG(INFO) << add_op.error();
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
infile.close();
|
||||
|
||||
// Plain search with no filters - results should be sorted by rank fields
|
||||
query_fields = {"name"};
|
||||
std::vector<std::string> facets;
|
||||
nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
std::vector<std::string> ids = {"3", "1", "4", "0", "2"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// Searching on an int32 field
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
ids = {"3", "1", "4"};
|
||||
std::vector<std::string> ids = {"3", "1", "4"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
@ -659,4 +646,20 @@ TEST_F(NumericRangeTrieTest, Integration) {
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// searching on an int64 array field - also ensure that padded space causes no issues
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ids = {"1", "4", "0", "2"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user