mirror of
https://github.com/typesense/typesense.git
synced 2025-05-23 23:30:42 +08:00
Merge pull request #1036 from happy-san/v0.26-filter
Add `NumericTrie`.
This commit is contained in:
commit
0c5e8c5a9a
@ -54,6 +54,7 @@ namespace fields {
|
||||
static const std::string from = "from";
|
||||
static const std::string embed_from = "embed_from";
|
||||
static const std::string model_name = "model_name";
|
||||
static const std::string range_index = "range_index";
|
||||
|
||||
// Some models require additional parameters to be passed to the model during indexing/querying
|
||||
// For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
|
||||
@ -93,13 +94,17 @@ struct field {
|
||||
|
||||
std::string reference; // Foo.bar (reference to bar field in Foo collection).
|
||||
|
||||
bool range_index;
|
||||
|
||||
field() {}
|
||||
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const nlohmann::json& embed = nlohmann::json()) :
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
|
||||
std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed(embed) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
|
||||
embed(embed), range_index(range_index) {
|
||||
|
||||
set_computed_defaults(sort, infix);
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ struct filter {
|
||||
std::string field_name;
|
||||
std::vector<std::string> values;
|
||||
std::vector<NUM_COMPARATOR> comparators;
|
||||
// Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the
|
||||
// Would be set when `field: != ...` is encountered with id/string field or `field: != [ ... ]` is encountered in the
|
||||
// case of int and float fields. During filtering, all the results of matching the field against the values are
|
||||
// aggregated and then this flag is checked if negation on the aggregated result is required.
|
||||
bool apply_not_equals = false;
|
||||
|
@ -160,6 +160,9 @@ public:
|
||||
/// Returns the status of the initialization of iterator tree.
|
||||
Option<bool> init_status();
|
||||
|
||||
/// Recursively computes the result of each node and stores the final result in the root node.
|
||||
void compute_result();
|
||||
|
||||
/// Returns a tri-state:
|
||||
/// 0: id is not valid
|
||||
/// 1: id is valid
|
||||
|
@ -30,6 +30,7 @@
|
||||
#include "vector_query_ops.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
#include "filter.h"
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
static constexpr size_t ARRAY_FACET_DIM = 4;
|
||||
using facet_map_t = spp::sparse_hash_map<uint32_t, facet_hash_values_t>;
|
||||
@ -302,7 +303,9 @@ private:
|
||||
|
||||
spp::sparse_hash_map<std::string, num_tree_t*> numerical_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::vector<uint32_t>>*> geopoint_index;
|
||||
spp::sparse_hash_map<std::string, NumericTrie*> range_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, NumericTrie*> geo_range_index;
|
||||
|
||||
// geo_array_field => (seq_id => values) used for exact filtering of geo array records
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<uint32_t, int64_t*>*> geo_array_index;
|
||||
|
154
include/numeric_range_trie_test.h
Normal file
154
include/numeric_range_trie_test.h
Normal file
@ -0,0 +1,154 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include "sorted_array.h"
|
||||
|
||||
constexpr short EXPANSE = 256;
|
||||
|
||||
class NumericTrie {
|
||||
char max_level = 4;
|
||||
|
||||
class Node {
|
||||
Node** children = nullptr;
|
||||
sorted_array seq_ids;
|
||||
|
||||
void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level);
|
||||
|
||||
void insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level, const char& max_level);
|
||||
|
||||
void search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level, std::set<Node*>& matches);
|
||||
|
||||
void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
void search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
void search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches);
|
||||
|
||||
public:
|
||||
|
||||
~Node() {
|
||||
if (children != nullptr) {
|
||||
for (auto i = 0; i < EXPANSE; i++) {
|
||||
delete children[i];
|
||||
}
|
||||
}
|
||||
|
||||
delete [] children;
|
||||
}
|
||||
|
||||
void insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void remove(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level);
|
||||
|
||||
void search_geopoints(const std::vector<uint64_t>& cell_ids, const char& max_level,
|
||||
std::vector<uint32_t>& geo_result_ids);
|
||||
|
||||
void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level);
|
||||
|
||||
void get_all_ids(uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
|
||||
void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
void search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches);
|
||||
};
|
||||
|
||||
Node* negative_trie = nullptr;
|
||||
Node* positive_trie = nullptr;
|
||||
|
||||
public:
|
||||
|
||||
explicit NumericTrie(char num_bits = 32) {
|
||||
max_level = num_bits / 8;
|
||||
}
|
||||
|
||||
~NumericTrie() {
|
||||
delete negative_trie;
|
||||
delete positive_trie;
|
||||
}
|
||||
|
||||
class iterator_t {
|
||||
struct match_state {
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
uint32_t index = 0;
|
||||
|
||||
explicit match_state(uint32_t*& ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {}
|
||||
|
||||
~match_state() {
|
||||
delete [] ids;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<match_state*> matches;
|
||||
|
||||
void set_seq_id();
|
||||
|
||||
public:
|
||||
|
||||
explicit iterator_t(std::vector<Node*>& matches);
|
||||
|
||||
~iterator_t() {
|
||||
for (auto& match: matches) {
|
||||
delete match;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_t& operator=(iterator_t&& obj) noexcept;
|
||||
|
||||
uint32_t seq_id = 0;
|
||||
bool is_valid = true;
|
||||
|
||||
void next();
|
||||
void skip_to(uint32_t id);
|
||||
void reset();
|
||||
};
|
||||
|
||||
void insert(const int64_t& value, const uint32_t& seq_id);
|
||||
|
||||
void remove(const int64_t& value, const uint32_t& seq_id);
|
||||
|
||||
void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id);
|
||||
|
||||
void search_geopoints(const std::vector<uint64_t>& cell_ids, std::vector<uint32_t>& geo_result_ids);
|
||||
|
||||
void delete_geopoint(const uint64_t& cell_id, uint32_t id);
|
||||
|
||||
void search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive);
|
||||
|
||||
void search_less_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_less_than(const int64_t& value, const bool& inclusive);
|
||||
|
||||
void search_greater_than(const int64_t& value, const bool& inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_greater_than(const int64_t& value, const bool& inclusive);
|
||||
|
||||
void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length);
|
||||
|
||||
iterator_t search_equal_to(const int64_t& value);
|
||||
};
|
@ -75,6 +75,24 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::reference] = "";
|
||||
}
|
||||
|
||||
if (field_json.count(fields::range_index) != 0) {
|
||||
if (!field_json.at(fields::range_index).is_boolean()) {
|
||||
return Option<bool>(400, std::string("The `range_index` property of the field `") +
|
||||
field_json[fields::name].get<std::string>() +
|
||||
std::string("` should be a boolean."));
|
||||
}
|
||||
|
||||
auto const& type = field_json["type"];
|
||||
if (field_json[fields::range_index] &&
|
||||
type != field_types::INT32 && type != field_types::INT32_ARRAY &&
|
||||
type != field_types::INT64 && type != field_types::INT64_ARRAY &&
|
||||
type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, std::string("The `range_index` property is only allowed for the numerical fields`"));
|
||||
}
|
||||
} else {
|
||||
field_json[fields::range_index] = false;
|
||||
}
|
||||
|
||||
if(field_json["name"] == ".*") {
|
||||
if(field_json.count(fields::facet) == 0) {
|
||||
field_json[fields::facet] = false;
|
||||
@ -297,7 +315,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
|
||||
field_json[fields::reference], field_json[fields::embed])
|
||||
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
|
||||
);
|
||||
|
||||
if (!field_json[fields::reference].get<std::string>().empty()) {
|
||||
|
@ -422,7 +422,10 @@ Option<bool> toFilter(const std::string expression,
|
||||
id_comparator = EQUALS;
|
||||
while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' ');
|
||||
} else if (raw_value.size() >= 2 && raw_value[0] == '!' && raw_value[1] == '=') {
|
||||
return Option<bool>(400, "Not equals filtering is not supported on the `id` field.");
|
||||
id_comparator = NOT_EQUALS;
|
||||
filter_exp.apply_not_equals = true;
|
||||
filter_value_index++;
|
||||
while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' ');
|
||||
}
|
||||
if (filter_value_index != 0) {
|
||||
raw_value = raw_value.substr(filter_value_index);
|
||||
|
@ -401,6 +401,22 @@ void filter_result_iterator_t::next() {
|
||||
return;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
if (++result_index >= filter_result.count) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
reference.clear();
|
||||
for (auto const& item: filter_result.reference_filter_results) {
|
||||
reference[item.first] = item.second[result_index];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Advance the subtrees and then apply operators to arrive at the next valid doc.
|
||||
if (filter_node->filter_operator == AND) {
|
||||
@ -423,21 +439,6 @@ void filter_result_iterator_t::next() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_filter_result_initialized) {
|
||||
if (++result_index >= filter_result.count) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
reference.clear();
|
||||
for (auto const& item: filter_result.reference_filter_results) {
|
||||
reference[item.first] = item.second[result_index];
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const filter a_filter = filter_node->filter_exp;
|
||||
|
||||
if (!index->field_is_indexed(a_filter.field_name)) {
|
||||
@ -619,11 +620,6 @@ void filter_result_iterator_t::init() {
|
||||
}
|
||||
|
||||
if (a_filter.field_name == "id") {
|
||||
if (a_filter.values.empty()) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// we handle `ids` separately
|
||||
std::vector<uint32_t> result_ids;
|
||||
for (const auto& id_str : a_filter.values) {
|
||||
@ -636,6 +632,16 @@ void filter_result_iterator_t::init() {
|
||||
filter_result.docs = new uint32_t[result_ids.size()];
|
||||
std::copy(result_ids.begin(), result_ids.end(), filter_result.docs);
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
|
||||
if (filter_result.count == 0) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
is_filter_result_initialized = true;
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
@ -650,27 +656,62 @@ void filter_result_iterator_t::init() {
|
||||
field f = index->search_schema.at(a_filter.field_name);
|
||||
|
||||
if (f.is_integer()) {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
if (f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t)std::stol(filter_value);
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
auto const& value = (int32_t)std::stoi(filter_value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const& range_end_value = (int32_t)std::stoi(next_filter_value);
|
||||
trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == EQUALS) {
|
||||
trie->search_equal_to(value, filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
uint32_t to_exclude_ids_len = 0;
|
||||
trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = index->seq_ids->uncompress();
|
||||
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
|
||||
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
|
||||
trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
|
||||
trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
|
||||
filter_result.count = result_size;
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t)std::stol(filter_value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
|
||||
}
|
||||
|
||||
filter_result.count = result_size;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
@ -688,28 +729,64 @@ void filter_result_iterator_t::init() {
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
return;
|
||||
} else if (f.is_float()) {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
if (f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, float_int64,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == EQUALS) {
|
||||
trie->search_equal_to(float_int64, filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
uint32_t to_exclude_ids_len = 0;
|
||||
trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = index->seq_ids->uncompress();
|
||||
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
|
||||
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
|
||||
trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
|
||||
trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
|
||||
filter_result.count = result_size;
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
float value = (float)std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = Index::float_to_int64_t(value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, float_int64,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
|
||||
}
|
||||
|
||||
filter_result.count = result_size;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
@ -821,17 +898,15 @@ void filter_result_iterator_t::init() {
|
||||
S2RegionTermIndexer::Options options;
|
||||
options.set_index_contains_points_only(true);
|
||||
S2RegionTermIndexer indexer(options);
|
||||
auto const& geo_range_index = index->geo_range_index.at(a_filter.field_name);
|
||||
|
||||
std::vector<uint64_t> cell_ids;
|
||||
for (const auto& term : indexer.GetQueryTerms(*query_region, "")) {
|
||||
auto geo_index = index->geopoint_index.at(a_filter.field_name);
|
||||
const auto& ids_it = geo_index->find(term);
|
||||
if(ids_it != geo_index->end()) {
|
||||
geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end());
|
||||
}
|
||||
auto cell = S2CellId::FromToken(term);
|
||||
cell_ids.push_back(cell.id());
|
||||
}
|
||||
|
||||
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
|
||||
geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end());
|
||||
geo_range_index->search_geopoints(cell_ids, geo_result_ids);
|
||||
|
||||
// Skip exact filtering step if query radius is greater than the threshold.
|
||||
if (fi < a_filter.params.size() &&
|
||||
@ -955,20 +1030,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Skip the subtrees to id and then apply operators to arrive at the next valid doc.
|
||||
left_it->skip_to(id);
|
||||
right_it->skip_to(id);
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
ArrayUtils::skip_index_to_id(result_index, filter_result.docs, filter_result.count, id);
|
||||
|
||||
@ -986,6 +1048,20 @@ void filter_result_iterator_t::skip_to(uint32_t id) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Skip the subtrees to id and then apply operators to arrive at the next valid doc.
|
||||
left_it->skip_to(id);
|
||||
right_it->skip_to(id);
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const filter a_filter = filter_node->filter_exp;
|
||||
|
||||
if (!index->field_is_indexed(a_filter.field_name)) {
|
||||
@ -1068,6 +1144,12 @@ int filter_result_iterator_t::valid(uint32_t id) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
skip_to(id);
|
||||
return is_valid ? (seq_id == id ? 1 : 0) : -1;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
auto left_valid = left_it->valid(id), right_valid = right_it->valid(id);
|
||||
|
||||
@ -1181,21 +1263,7 @@ void filter_result_iterator_t::reset() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Reset the subtrees then apply operators to arrive at the first valid doc.
|
||||
left_it->reset();
|
||||
right_it->reset();
|
||||
is_valid = true;
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// No need to traverse iterator tree if there's only one filter or compute_result() has been called.
|
||||
if (is_filter_result_initialized) {
|
||||
if (filter_result.count == 0) {
|
||||
is_valid = false;
|
||||
@ -1214,6 +1282,21 @@ void filter_result_iterator_t::reset() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (filter_node->isOperator) {
|
||||
// Reset the subtrees then apply operators to arrive at the first valid doc.
|
||||
left_it->reset();
|
||||
right_it->reset();
|
||||
is_valid = true;
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
and_filter_iterators();
|
||||
} else {
|
||||
or_filter_iterators();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const filter a_filter = filter_node->filter_exp;
|
||||
|
||||
if (!index->field_is_indexed(a_filter.field_name)) {
|
||||
@ -1459,3 +1542,136 @@ void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_
|
||||
root_iterator->seq_id = left_it->seq_id;
|
||||
filter_result_iterator = root_iterator;
|
||||
}
|
||||
|
||||
void filter_result_iterator_t::compute_result() {
|
||||
if (filter_node->isOperator) {
|
||||
left_it->compute_result();
|
||||
right_it->compute_result();
|
||||
|
||||
if (filter_node->filter_operator == AND) {
|
||||
filter_result_t::and_filter_results(left_it->filter_result, right_it->filter_result, filter_result);
|
||||
} else {
|
||||
filter_result_t::or_filter_results(left_it->filter_result, right_it->filter_result, filter_result);
|
||||
}
|
||||
|
||||
seq_id = filter_result.docs[result_index];
|
||||
is_filter_result_initialized = true;
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
return;
|
||||
}
|
||||
|
||||
// Only string field filter needs to be evaluated.
|
||||
if (is_filter_result_initialized || index->search_index.count(filter_node->filter_exp.field_name) == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto const& a_filter = filter_node->filter_exp;
|
||||
auto const& f = index->search_schema.at(a_filter.field_name);
|
||||
art_tree* t = index->search_index.at(a_filter.field_name);
|
||||
|
||||
uint32_t* or_ids = nullptr;
|
||||
size_t or_ids_size = 0;
|
||||
|
||||
// aggregates IDs across array of filter values and reduces excessive ORing
|
||||
std::vector<uint32_t> f_id_buff;
|
||||
|
||||
for (const std::string& filter_value : a_filter.values) {
|
||||
std::vector<void*> posting_lists;
|
||||
|
||||
// there could be multiple tokens in a filter value, which we have to treat as ANDs
|
||||
// e.g. country: South Africa
|
||||
Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators);
|
||||
|
||||
std::string str_token;
|
||||
size_t token_index = 0;
|
||||
std::vector<std::string> str_tokens;
|
||||
|
||||
while (tokenizer.next(str_token, token_index)) {
|
||||
str_tokens.push_back(str_token);
|
||||
|
||||
art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(),
|
||||
str_token.length()+1);
|
||||
if (leaf == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
posting_lists.push_back(leaf->values);
|
||||
}
|
||||
|
||||
if (posting_lists.size() != str_tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) {
|
||||
// needs intersection + exact matching (unlike CONTAINS)
|
||||
std::vector<uint32_t> result_id_vec;
|
||||
posting_t::intersect(posting_lists, result_id_vec);
|
||||
|
||||
if (result_id_vec.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// need to do exact match
|
||||
uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()];
|
||||
size_t exact_str_ids_size = 0;
|
||||
std::unique_ptr<uint32_t[]> exact_str_ids_guard(exact_str_ids);
|
||||
|
||||
posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(),
|
||||
exact_str_ids, exact_str_ids_size);
|
||||
|
||||
if (exact_str_ids_size == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t ei = 0; ei < exact_str_ids_size; ei++) {
|
||||
f_id_buff.push_back(exact_str_ids[ei]);
|
||||
}
|
||||
} else {
|
||||
// CONTAINS
|
||||
size_t before_size = f_id_buff.size();
|
||||
posting_t::intersect(posting_lists, f_id_buff);
|
||||
if (f_id_buff.size() == before_size) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) {
|
||||
gfx::timsort(f_id_buff.begin(), f_id_buff.end());
|
||||
f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out);
|
||||
delete[] or_ids;
|
||||
or_ids = out;
|
||||
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
|
||||
}
|
||||
}
|
||||
|
||||
if (!f_id_buff.empty()) {
|
||||
gfx::timsort(f_id_buff.begin(), f_id_buff.end());
|
||||
f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out);
|
||||
delete[] or_ids;
|
||||
or_ids = out;
|
||||
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
|
||||
}
|
||||
|
||||
filter_result.docs = or_ids;
|
||||
filter_result.count = or_ids_size;
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), filter_result.docs, filter_result.count);
|
||||
}
|
||||
|
||||
if (filter_result.count == 0) {
|
||||
is_valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
result_index = 0;
|
||||
seq_id = filter_result.docs[result_index];
|
||||
is_filter_result_initialized = true;
|
||||
approx_filter_ids_length = filter_result.count;
|
||||
}
|
||||
|
130
src/index.cpp
130
src/index.cpp
@ -78,8 +78,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
art_tree_init(t);
|
||||
search_index.emplace(a_field.name, t);
|
||||
} else if(a_field.is_geopoint()) {
|
||||
auto field_geo_index = new spp::sparse_hash_map<std::string, std::vector<uint32_t>>();
|
||||
geopoint_index.emplace(a_field.name, field_geo_index);
|
||||
geo_range_index.emplace(a_field.name, new NumericTrie());
|
||||
|
||||
if(!a_field.is_single_geopoint()) {
|
||||
spp::sparse_hash_map<uint32_t, int64_t*> * doc_to_geos = new spp::sparse_hash_map<uint32_t, int64_t*>();
|
||||
@ -88,6 +87,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
} else {
|
||||
num_tree_t* num_tree = new num_tree_t;
|
||||
numerical_index.emplace(a_field.name, num_tree);
|
||||
|
||||
if (a_field.range_index) {
|
||||
auto trie = a_field.is_int32() ? new NumericTrie() : new NumericTrie(64);
|
||||
range_index.emplace(a_field.name, trie);
|
||||
}
|
||||
}
|
||||
|
||||
if(a_field.sort) {
|
||||
@ -136,12 +140,12 @@ Index::~Index() {
|
||||
|
||||
search_index.clear();
|
||||
|
||||
for(auto & name_index: geopoint_index) {
|
||||
for(auto & name_index: geo_range_index) {
|
||||
delete name_index.second;
|
||||
name_index.second = nullptr;
|
||||
}
|
||||
|
||||
geopoint_index.clear();
|
||||
geo_range_index.clear();
|
||||
|
||||
for(auto& name_index: geo_array_index) {
|
||||
for(auto& kv: *name_index.second) {
|
||||
@ -161,6 +165,13 @@ Index::~Index() {
|
||||
|
||||
numerical_index.clear();
|
||||
|
||||
for(auto & name_tree: range_index) {
|
||||
delete name_tree.second;
|
||||
name_tree.second = nullptr;
|
||||
}
|
||||
|
||||
range_index.clear();
|
||||
|
||||
for(auto & name_map: sort_index) {
|
||||
delete name_map.second;
|
||||
name_map.second = nullptr;
|
||||
@ -738,6 +749,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
if(!afield.is_string()) {
|
||||
if (afield.type == field_types::INT32) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
int32_t value = record.doc[afield.name].get<int32_t>();
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -747,6 +767,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::INT64) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
int64_t value = record.doc[afield.name].get<int64_t>();
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -756,6 +785,16 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::FLOAT) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
float fvalue = record.doc[afield.name].get<float>();
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -771,10 +810,10 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
num_tree->insert(value, seq_id);
|
||||
});
|
||||
} else if(afield.type == field_types::GEOPOINT || afield.type == field_types::GEOPOINT_ARRAY) {
|
||||
auto geo_index = geopoint_index.at(afield.name);
|
||||
auto geopoint_range_index = geo_range_index.at(afield.name);
|
||||
|
||||
iterate_and_index_numerical_field(iter_batch, afield,
|
||||
[&afield, &geo_array_index=geo_array_index, geo_index](const index_record& record, uint32_t seq_id) {
|
||||
[&afield, &geo_array_index=geo_array_index, geopoint_range_index](const index_record& record, uint32_t seq_id) {
|
||||
// nested geopoint value inside an array of object will be a simple array so must be treated as geopoint
|
||||
bool nested_obj_arr_geopoint = (afield.nested && afield.type == field_types::GEOPOINT_ARRAY &&
|
||||
!record.doc[afield.name].empty() && record.doc[afield.name][0].is_number());
|
||||
@ -788,9 +827,8 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
S2RegionTermIndexer indexer(options);
|
||||
S2Point point = S2LatLng::FromDegrees(latlongs[li], latlongs[li+1]).ToPoint();
|
||||
|
||||
for(const auto& term: indexer.GetIndexTerms(point, "")) {
|
||||
(*geo_index)[term].push_back(seq_id);
|
||||
}
|
||||
auto cell = S2CellId(point);
|
||||
geopoint_range_index->insert_geopoint(cell.id(), seq_id);
|
||||
}
|
||||
|
||||
if(nested_obj_arr_geopoint) {
|
||||
@ -818,9 +856,9 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
for(size_t li = 0; li < latlongs.size(); li++) {
|
||||
auto& latlong = latlongs[li];
|
||||
S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint();
|
||||
for(const auto& term: indexer.GetIndexTerms(point, "")) {
|
||||
(*geo_index)[term].push_back(seq_id);
|
||||
}
|
||||
|
||||
auto cell = S2CellId(point);
|
||||
geopoint_range_index->insert_geopoint(cell.id(), seq_id);
|
||||
|
||||
int64_t packed_latlong = GeoPoint::pack_lat_lng(latlong[0], latlong[1]);
|
||||
packed_latlongs[li + 1] = packed_latlong;
|
||||
@ -900,7 +938,8 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
// all other numerical arrays
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
auto trie = range_index.count(afield.name) > 0 ? range_index.at(afield.name) : nullptr;
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) {
|
||||
const auto& arr_value = record.doc[afield.name][arr_i];
|
||||
@ -908,17 +947,29 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
if(afield.type == field_types::INT32_ARRAY) {
|
||||
const int32_t value = arr_value;
|
||||
num_tree->insert(value, seq_id);
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::INT64_ARRAY) {
|
||||
const int64_t value = arr_value;
|
||||
num_tree->insert(value, seq_id);
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::FLOAT_ARRAY) {
|
||||
const float fvalue = arr_value;
|
||||
int64_t value = float_to_int64_t(fvalue);
|
||||
num_tree->insert(value, seq_id);
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
else if(afield.type == field_types::BOOL_ARRAY) {
|
||||
@ -1537,7 +1588,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
|
||||
bool Index::field_is_indexed(const std::string& field_name) const {
|
||||
return search_index.count(field_name) != 0 ||
|
||||
numerical_index.count(field_name) != 0 ||
|
||||
geopoint_index.count(field_name) != 0;
|
||||
geo_range_index.count(field_name) != 0;
|
||||
}
|
||||
|
||||
void Index::aproximate_numerical_match(num_tree_t* const num_tree,
|
||||
@ -4495,7 +4546,9 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
|
||||
const std::vector<size_t>& geopoint_indices) const {
|
||||
|
||||
filter_result_iterator->compute_result();
|
||||
auto const& approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length;
|
||||
|
||||
uint32_t token_bits = 0;
|
||||
const bool check_for_circuit_break = (approx_filter_ids_length > 1000000);
|
||||
|
||||
@ -5368,6 +5421,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
std::vector<int32_t>{document[field_name].get<int32_t>()} :
|
||||
document[field_name].get<std::vector<int32_t>>();
|
||||
for(int32_t value: values) {
|
||||
if (search_field.range_index) {
|
||||
auto const& trie = range_index.at(search_field.name);
|
||||
trie->remove(value, seq_id);
|
||||
}
|
||||
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
num_tree->remove(value, seq_id);
|
||||
if(search_field.facet) {
|
||||
@ -5379,6 +5437,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
std::vector<int64_t>{document[field_name].get<int64_t>()} :
|
||||
document[field_name].get<std::vector<int64_t>>();
|
||||
for(int64_t value: values) {
|
||||
if (search_field.range_index) {
|
||||
auto const& trie = range_index.at(search_field.name);
|
||||
trie->remove(value, seq_id);
|
||||
}
|
||||
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
num_tree->remove(value, seq_id);
|
||||
if(search_field.facet) {
|
||||
@ -5393,8 +5456,14 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
document[field_name].get<std::vector<float>>();
|
||||
|
||||
for(float value: values) {
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
int64_t fintval = float_to_int64_t(value);
|
||||
|
||||
if (search_field.range_index) {
|
||||
auto const& trie = range_index.at(search_field.name);
|
||||
trie->remove(fintval, seq_id);
|
||||
}
|
||||
|
||||
num_tree_t* num_tree = numerical_index.at(field_name);
|
||||
num_tree->remove(fintval, seq_id);
|
||||
if(search_field.facet) {
|
||||
remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id);
|
||||
@ -5414,7 +5483,7 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
}
|
||||
}
|
||||
} else if(search_field.is_geopoint()) {
|
||||
auto geo_index = geopoint_index[field_name];
|
||||
auto geopoint_range_index = geo_range_index[field_name];
|
||||
S2RegionTermIndexer::Options options;
|
||||
options.set_index_contains_points_only(true);
|
||||
S2RegionTermIndexer indexer(options);
|
||||
@ -5425,17 +5494,8 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const
|
||||
|
||||
for(const std::vector<double>& latlong: latlongs) {
|
||||
S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint();
|
||||
for(const auto& term: indexer.GetIndexTerms(point, "")) {
|
||||
auto term_it = geo_index->find(term);
|
||||
if(term_it == geo_index->end()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<uint32_t>& ids = term_it->second;
|
||||
ids.erase(std::remove(ids.begin(), ids.end(), seq_id), ids.end());
|
||||
if(ids.empty()) {
|
||||
geo_index->erase(term);
|
||||
}
|
||||
}
|
||||
auto cell = S2CellId(point);
|
||||
geopoint_range_index->delete_geopoint(cell.id(), seq_id);
|
||||
}
|
||||
|
||||
if(!search_field.is_single_geopoint()) {
|
||||
@ -5587,8 +5647,7 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
art_tree_init(t);
|
||||
search_index.emplace(new_field.name, t);
|
||||
} else if(new_field.is_geopoint()) {
|
||||
auto field_geo_index = new spp::sparse_hash_map<std::string, std::vector<uint32_t>>();
|
||||
geopoint_index.emplace(new_field.name, field_geo_index);
|
||||
geo_range_index.emplace(new_field.name, new NumericTrie());
|
||||
if(!new_field.is_single_geopoint()) {
|
||||
auto geo_array_map = new spp::sparse_hash_map<uint32_t, int64_t*>();
|
||||
geo_array_index.emplace(new_field.name, geo_array_map);
|
||||
@ -5596,6 +5655,10 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
} else {
|
||||
num_tree_t* num_tree = new num_tree_t;
|
||||
numerical_index.emplace(new_field.name, num_tree);
|
||||
|
||||
if (new_field.range_index) {
|
||||
range_index.emplace(new_field.name, new NumericTrie(new_field.is_int32() ? 32 : 64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5638,8 +5701,8 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
delete search_index[del_field.name];
|
||||
search_index.erase(del_field.name);
|
||||
} else if(del_field.is_geopoint()) {
|
||||
delete geopoint_index[del_field.name];
|
||||
geopoint_index.erase(del_field.name);
|
||||
delete geo_range_index[del_field.name];
|
||||
geo_range_index.erase(del_field.name);
|
||||
|
||||
if(!del_field.is_single_geopoint()) {
|
||||
spp::sparse_hash_map<uint32_t, int64_t*>* geo_array_map = geo_array_index[del_field.name];
|
||||
@ -5652,6 +5715,11 @@ void Index::refresh_schemas(const std::vector<field>& new_fields, const std::vec
|
||||
} else {
|
||||
delete numerical_index[del_field.name];
|
||||
numerical_index.erase(del_field.name);
|
||||
|
||||
if (del_field.range_index) {
|
||||
delete range_index[del_field.name];
|
||||
range_index.erase(del_field.name);
|
||||
}
|
||||
}
|
||||
|
||||
if(del_field.is_sortable()) {
|
||||
|
908
src/numeric_range_trie.cpp
Normal file
908
src/numeric_range_trie.cpp
Normal file
@ -0,0 +1,908 @@
|
||||
#include <timsort.hpp>
|
||||
#include <set>
|
||||
#include "numeric_range_trie_test.h"
|
||||
#include "array_utils.h"
|
||||
|
||||
void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) {
|
||||
if (value < 0) {
|
||||
if (negative_trie == nullptr) {
|
||||
negative_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
negative_trie->insert(std::abs(value), seq_id, max_level);
|
||||
} else {
|
||||
if (positive_trie == nullptr) {
|
||||
positive_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
positive_trie->insert(value, seq_id, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::remove(const int64_t& value, const uint32_t& seq_id) {
|
||||
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (value < 0) {
|
||||
negative_trie->remove(std::abs(value), seq_id, max_level);
|
||||
} else {
|
||||
positive_trie->remove(value, seq_id, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id) {
|
||||
if (positive_trie == nullptr) {
|
||||
positive_trie = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
positive_trie->insert_geopoint(cell_id, seq_id, max_level);
|
||||
}
|
||||
|
||||
void NumericTrie::search_geopoints(const std::vector<uint64_t>& cell_ids, std::vector<uint32_t>& geo_result_ids) {
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
positive_trie->search_geopoints(cell_ids, max_level, geo_result_ids);
|
||||
}
|
||||
|
||||
void NumericTrie::delete_geopoint(const uint64_t& cell_id, uint32_t id) {
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
positive_trie->delete_geopoint(cell_id, id, max_level);
|
||||
}
|
||||
|
||||
void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (low < 0 && high >= 0) {
|
||||
// Have to combine the results of >low from negative_trie and <high from positive_trie
|
||||
|
||||
if (negative_trie != nullptr && !(low == -1 && !low_inclusive)) { // No need to search for (-1, ...
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
auto abs_low = std::abs(low);
|
||||
|
||||
// Since we store absolute values, search_lesser would yield result for >low from negative_trie.
|
||||
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0)
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
} else if (low >= 0) {
|
||||
// Search only in positive_trie
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
} else {
|
||||
// Search only in negative_trie
|
||||
if (negative_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
// Since we store absolute values, switching low and high would produce the correct result.
|
||||
auto abs_high = std::abs(high), abs_low = std::abs(low);
|
||||
negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1,
|
||||
max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
|
||||
const int64_t& high, const bool& high_inclusive) {
|
||||
std::vector<Node*> matches;
|
||||
if (low > high) {
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
if (low < 0 && high >= 0) {
|
||||
// Have to combine the results of >low from negative_trie and <high from positive_trie
|
||||
|
||||
if (negative_trie != nullptr && !(low == -1 && !low_inclusive)) { // No need to search for (-1, ...
|
||||
auto abs_low = std::abs(low);
|
||||
// Since we store absolute values, search_lesser would yield result for >low from negative_trie.
|
||||
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, matches);
|
||||
}
|
||||
|
||||
if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0)
|
||||
positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, matches);
|
||||
}
|
||||
} else if (low >= 0) {
|
||||
// Search only in positive_trie
|
||||
if (positive_trie == nullptr) {
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, matches);
|
||||
} else {
|
||||
// Search only in negative_trie
|
||||
if (negative_trie == nullptr) {
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
auto abs_high = std::abs(high), abs_low = std::abs(low);
|
||||
// Since we store absolute values, switching low and high would produce the correct result.
|
||||
negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1,
|
||||
max_level, matches);
|
||||
}
|
||||
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
|
||||
if (positive_trie != nullptr) {
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->get_all_ids(positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (value >= 0) {
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
} else {
|
||||
// Have to combine the results of >value from negative_trie and all the ids in positive_trie
|
||||
|
||||
if (negative_trie != nullptr) {
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
auto abs_low = std::abs(value);
|
||||
|
||||
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
|
||||
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
if (positive_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->get_all_ids(positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) {
|
||||
std::vector<Node*> matches;
|
||||
|
||||
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
|
||||
if (positive_trie != nullptr) {
|
||||
matches.push_back(positive_trie);
|
||||
}
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
if (value >= 0) {
|
||||
if (positive_trie != nullptr) {
|
||||
positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, matches);
|
||||
}
|
||||
} else {
|
||||
// Have to combine the results of >value from negative_trie and all the ids in positive_trie
|
||||
if (negative_trie != nullptr) {
|
||||
auto abs_low = std::abs(value);
|
||||
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
|
||||
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level, matches);
|
||||
}
|
||||
if (positive_trie != nullptr) {
|
||||
matches.push_back(positive_trie);
|
||||
}
|
||||
}
|
||||
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
|
||||
if (negative_trie != nullptr) {
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
negative_trie->get_all_ids(negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (value < 0) {
|
||||
if (negative_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
auto abs_low = std::abs(value);
|
||||
|
||||
// Since we store absolute values, search_greater would yield result for <value from negative_trie.
|
||||
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level,
|
||||
negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
} else {
|
||||
// Have to combine the results of <value from positive_trie and all the ids in negative_trie
|
||||
|
||||
if (positive_trie != nullptr) {
|
||||
uint32_t* positive_ids = nullptr;
|
||||
uint32_t positive_ids_length = 0;
|
||||
positive_trie->search_less_than(inclusive ? value : value - 1, max_level,
|
||||
positive_ids, positive_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] positive_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
if (negative_trie == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* negative_ids = nullptr;
|
||||
uint32_t negative_ids_length = 0;
|
||||
negative_trie->get_all_ids(negative_ids, negative_ids_length);
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] negative_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
}
|
||||
|
||||
NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) {
|
||||
std::vector<Node*> matches;
|
||||
|
||||
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
|
||||
if (negative_trie != nullptr) {
|
||||
matches.push_back(negative_trie);
|
||||
}
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
if (value < 0) {
|
||||
if (negative_trie != nullptr) {
|
||||
auto abs_low = std::abs(value);
|
||||
// Since we store absolute values, search_greater would yield result for <value from negative_trie.
|
||||
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, matches);
|
||||
}
|
||||
} else {
|
||||
// Have to combine the results of <value from positive_trie and all the ids in negative_trie
|
||||
if (positive_trie != nullptr) {
|
||||
positive_trie->search_less_than(inclusive ? value : value - 1, max_level, matches);
|
||||
}
|
||||
if (negative_trie != nullptr) {
|
||||
matches.push_back(negative_trie);
|
||||
}
|
||||
}
|
||||
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) {
|
||||
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t* equal_ids = nullptr;
|
||||
uint32_t equal_ids_length = 0;
|
||||
|
||||
if (value < 0) {
|
||||
negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length);
|
||||
} else {
|
||||
positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length);
|
||||
}
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(equal_ids, equal_ids_length, ids, ids_length, &out);
|
||||
|
||||
delete [] equal_ids;
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) {
|
||||
std::vector<Node*> matches;
|
||||
if (value < 0 && negative_trie != nullptr) {
|
||||
negative_trie->search_equal_to(std::abs(value), max_level, matches);
|
||||
} else if (value >= 0 && positive_trie != nullptr) {
|
||||
positive_trie->search_equal_to(value, max_level, matches);
|
||||
}
|
||||
|
||||
return NumericTrie::iterator_t(matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level) {
|
||||
char level = 0;
|
||||
return insert_helper(cell_id, seq_id, level, max_level);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level) {
|
||||
char level = 0;
|
||||
return insert_geopoint_helper(cell_id, seq_id, level, max_level);
|
||||
}
|
||||
|
||||
inline int get_index(const int64_t& value, const char& level, const char& max_level) {
|
||||
// Values are index considering higher order of the bytes first.
|
||||
// 0x01020408 (16909320) would be indexed in the trie as follows:
|
||||
// Level Index
|
||||
// 1 1
|
||||
// 2 2
|
||||
// 3 4
|
||||
// 4 8
|
||||
return (value >> (8 * (max_level - level))) & 0xFF;
|
||||
}
|
||||
|
||||
inline int get_geopoint_index(const uint64_t& cell_id, const char& level) {
|
||||
// Doing 8-level since cell_id is a 64 bit number.
|
||||
return (cell_id >> (8 * (8 - level))) & 0xFF;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::remove(const int64_t& value, const uint32_t& id, const char& max_level) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level < max_level) {
|
||||
root->seq_ids.remove_value(id);
|
||||
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
root->seq_ids.remove_value(id);
|
||||
if (root->children != nullptr && root->children[index] != nullptr) {
|
||||
auto& child = root->children[index];
|
||||
|
||||
child->seq_ids.remove_value(id);
|
||||
if (child->seq_ids.getLength() == 0) {
|
||||
delete child;
|
||||
child = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) {
|
||||
if (level > max_level) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Root node contains all the sequence ids present in the tree.
|
||||
if (!seq_ids.contains(seq_id)) {
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
if (++level <= max_level) {
|
||||
if (children == nullptr) {
|
||||
children = new NumericTrie::Node* [EXPANSE]{nullptr};
|
||||
}
|
||||
|
||||
auto index = get_index(value, level, max_level);
|
||||
if (children[index] == nullptr) {
|
||||
children[index] = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
return children[index]->insert_helper(value, seq_id, level, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level,
|
||||
const char& max_level) {
|
||||
if (level > max_level) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Root node contains all the sequence ids present in the tree.
|
||||
if (!seq_ids.contains(seq_id)) {
|
||||
seq_ids.append(seq_id);
|
||||
}
|
||||
|
||||
if (++level <= max_level) {
|
||||
if (children == nullptr) {
|
||||
children = new NumericTrie::Node* [EXPANSE]{nullptr};
|
||||
}
|
||||
|
||||
auto index = get_geopoint_index(cell_id, level);
|
||||
if (children[index] == nullptr) {
|
||||
children[index] = new NumericTrie::Node();
|
||||
}
|
||||
|
||||
return children[index]->insert_geopoint_helper(cell_id, seq_id, level, max_level);
|
||||
}
|
||||
}
|
||||
|
||||
char get_max_search_level(const uint64_t& cell_id, const char& max_level) {
|
||||
// For cell id 0x47E66C3000000000, we only have to prefix match the top four bytes since rest of the bytes are 0.
|
||||
// So the max search level would be 4 in this case.
|
||||
|
||||
auto mask = (uint64_t) 0xFF << (8 * (8 - max_level)); // We're only indexing top 8-max_level bytes.
|
||||
char i = max_level;
|
||||
while (((cell_id & mask) == 0) && --i > 0) {
|
||||
mask <<= 8;
|
||||
}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level,
|
||||
std::set<Node*>& matches) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_geopoint_index(cell_id, level);
|
||||
auto max_search_level = get_max_search_level(cell_id, max_index_level);
|
||||
|
||||
while (level < max_search_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_geopoint_index(cell_id, ++level);
|
||||
}
|
||||
|
||||
matches.insert(root);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_geopoints(const std::vector<uint64_t>& cell_ids, const char& max_level,
|
||||
std::vector<uint32_t>& geo_result_ids) {
|
||||
std::set<Node*> matches;
|
||||
for (const auto &cell_id: cell_ids) {
|
||||
search_geopoints_helper(cell_id, max_level, matches);
|
||||
}
|
||||
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
geo_result_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
|
||||
geo_result_ids.erase(unique(geo_result_ids.begin(), geo_result_ids.end()), geo_result_ids.end());
|
||||
}
|
||||
|
||||
void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_geopoint_index(cell_id, level);
|
||||
|
||||
while (level < max_level) {
|
||||
root->seq_ids.remove_value(id);
|
||||
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_geopoint_index(cell_id, ++level);
|
||||
}
|
||||
|
||||
root->seq_ids.remove_value(id);
|
||||
if (root->children != nullptr && root->children[index] != nullptr) {
|
||||
auto& child = root->children[index];
|
||||
|
||||
child->seq_ids.remove_value(id);
|
||||
if (child->seq_ids.getLength() == 0) {
|
||||
delete child;
|
||||
child = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) {
|
||||
ids = seq_ids.uncompress();
|
||||
ids_length = seq_ids.getLength();
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_less_than_helper(value, level, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
consolidated_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
ids, ids_length, &out);
|
||||
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 0;
|
||||
search_less_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
} else if (level > max_level || children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level, max_level);
|
||||
if (children[index] != nullptr) {
|
||||
children[index]->search_less_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
while (--index >= 0) {
|
||||
if (children[index] != nullptr) {
|
||||
matches.push_back(children[index]);
|
||||
}
|
||||
}
|
||||
|
||||
--level;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_range_helper(low, high, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
consolidated_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
ids, ids_length, &out);
|
||||
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (low > high) {
|
||||
return;
|
||||
}
|
||||
|
||||
search_range_helper(low, high, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
// Segregating the nodes into matching low, in-between, and matching high.
|
||||
|
||||
NumericTrie::Node* root = this;
|
||||
char level = 1;
|
||||
auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level);
|
||||
|
||||
// Keep updating the root while the range is contained within a single child node.
|
||||
while (root->children != nullptr && low_index == high_index && level < max_level) {
|
||||
if (root->children[low_index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[low_index];
|
||||
level++;
|
||||
low_index = get_index(low, level, max_level);
|
||||
high_index = get_index(high, level, max_level);
|
||||
}
|
||||
|
||||
if (root->children == nullptr) {
|
||||
return;
|
||||
} else if (low_index == high_index) { // low and high are equal
|
||||
if (root->children[low_index] != nullptr) {
|
||||
matches.push_back(root->children[low_index]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (root->children[low_index] != nullptr) {
|
||||
// Collect all the sub-nodes that are greater than low.
|
||||
root->children[low_index]->search_greater_than_helper(low, level, max_level, matches);
|
||||
}
|
||||
|
||||
auto index = low_index + 1;
|
||||
// All the nodes in-between low and high are a match by default.
|
||||
while (index < std::min(high_index, (int)EXPANSE)) {
|
||||
if (root->children[index] != nullptr) {
|
||||
matches.push_back(root->children[index]);
|
||||
}
|
||||
|
||||
index++;
|
||||
}
|
||||
|
||||
if (index < EXPANSE && index == high_index && root->children[index] != nullptr) {
|
||||
// Collect all the sub-nodes that are lesser than high.
|
||||
root->children[index]->search_less_than_helper(high, level, max_level, matches);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 0;
|
||||
std::vector<NumericTrie::Node*> matches;
|
||||
search_greater_than_helper(value, level, max_level, matches);
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
for (auto const& match: matches) {
|
||||
auto const& m_seq_ids = match->seq_ids.uncompress();
|
||||
for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) {
|
||||
consolidated_ids.push_back(m_seq_ids[i]);
|
||||
}
|
||||
|
||||
delete [] m_seq_ids;
|
||||
}
|
||||
|
||||
gfx::timsort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end());
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
ids, ids_length, &out);
|
||||
|
||||
delete [] ids;
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 0;
|
||||
search_greater_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
|
||||
std::vector<Node*>& matches) {
|
||||
if (level == max_level) {
|
||||
matches.push_back(this);
|
||||
return;
|
||||
} else if (level > max_level || children == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto index = get_index(value, ++level, max_level);
|
||||
if (children[index] != nullptr) {
|
||||
children[index]->search_greater_than_helper(value, level, max_level, matches);
|
||||
}
|
||||
|
||||
while (++index < EXPANSE) {
|
||||
if (children[index] != nullptr) {
|
||||
matches.push_back(children[index]);
|
||||
}
|
||||
}
|
||||
|
||||
--level;
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level,
|
||||
uint32_t*& ids, uint32_t& ids_length) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level <= max_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
root->get_all_ids(ids, ids_length);
|
||||
}
|
||||
|
||||
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector<Node*>& matches) {
|
||||
char level = 1;
|
||||
Node* root = this;
|
||||
auto index = get_index(value, level, max_level);
|
||||
|
||||
while (level <= max_level) {
|
||||
if (root->children == nullptr || root->children[index] == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
root = root->children[index];
|
||||
index = get_index(value, ++level, max_level);
|
||||
}
|
||||
|
||||
matches.push_back(root);
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::reset() {
|
||||
for (auto& match: matches) {
|
||||
match->index = 0;
|
||||
}
|
||||
|
||||
is_valid = true;
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::skip_to(uint32_t id) {
|
||||
for (auto& match: matches) {
|
||||
ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id);
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::next() {
|
||||
// Advance all the matches at seq_id.
|
||||
for (auto& match: matches) {
|
||||
if (match->index < match->ids_length && match->ids[match->index] == seq_id) {
|
||||
match->index++;
|
||||
}
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
NumericTrie::iterator_t::iterator_t(std::vector<Node*>& node_matches) {
|
||||
for (auto const& node_match: node_matches) {
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length;
|
||||
node_match->get_all_ids(ids, ids_length);
|
||||
if (ids_length > 0) {
|
||||
matches.emplace_back(new match_state(ids, ids_length));
|
||||
}
|
||||
}
|
||||
|
||||
set_seq_id();
|
||||
}
|
||||
|
||||
void NumericTrie::iterator_t::set_seq_id() {
|
||||
// Find the lowest id of all the matches and update the seq_id.
|
||||
bool one_is_valid = false;
|
||||
uint32_t lowest_id = UINT32_MAX;
|
||||
|
||||
for (auto& match: matches) {
|
||||
if (match->index < match->ids_length) {
|
||||
one_is_valid = true;
|
||||
|
||||
if (match->ids[match->index] < lowest_id) {
|
||||
lowest_id = match->ids[match->index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (one_is_valid) {
|
||||
seq_id = lowest_id;
|
||||
}
|
||||
|
||||
is_valid = one_is_valid;
|
||||
}
|
||||
|
||||
NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept {
|
||||
if (&obj == this)
|
||||
return *this;
|
||||
|
||||
for (auto& match: matches) {
|
||||
delete match;
|
||||
}
|
||||
matches.clear();
|
||||
|
||||
matches = std::move(obj.matches);
|
||||
seq_id = obj.seq_id;
|
||||
is_valid = obj.is_valid;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
@ -1231,6 +1231,16 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) {
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_STREQ("123", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
results = coll1->search("*",
|
||||
{}, "id: != 123",
|
||||
{}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get();
|
||||
|
||||
ASSERT_EQ(3, results["found"].get<size_t>());
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
ASSERT_STREQ("125", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("127", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("129", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// single ID with backtick
|
||||
|
||||
results = coll1->search("*",
|
||||
@ -1283,6 +1293,14 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) {
|
||||
ASSERT_STREQ("125", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("127", results["hits"][2]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
results = coll1->search("*",
|
||||
{}, "id:!= [123,125] && num_employees: <300",
|
||||
{}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get();
|
||||
|
||||
ASSERT_EQ(1, results["found"].get<size_t>());
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
ASSERT_STREQ("127", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// empty id list not allowed
|
||||
auto res_op = coll1->search("*", {}, "id:=", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true});
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
@ -1296,13 +1314,6 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) {
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Error with filter field `id`: Filter value cannot be empty.", res_op.error());
|
||||
|
||||
// not equals is not supported yet
|
||||
res_op = coll1->search("*",
|
||||
{}, "id:!= [123,125] && num_employees: <300",
|
||||
{}, sort_fields, {0}, 10, 1, FREQUENCY, {true});
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Not equals filtering is not supported on the `id` field.", res_op.error());
|
||||
|
||||
// when no IDs exist
|
||||
results = coll1->search("*",
|
||||
{}, "id: [1000] && num_employees: <300",
|
||||
@ -1397,9 +1408,10 @@ TEST_F(CollectionFilteringTest, NumericalFilteringWithArray) {
|
||||
TEST_F(CollectionFilteringTest, NegationOperatorBasics) {
|
||||
Collection *coll1;
|
||||
|
||||
std::vector<field> fields = {field("title", field_types::STRING, false),
|
||||
field("artist", field_types::STRING, false),
|
||||
field("points", field_types::INT32, false),};
|
||||
std::vector<field> fields = {
|
||||
field("title", field_types::STRING, false),
|
||||
field("artist", field_types::STRING, false),
|
||||
field("points", field_types::INT32, false),};
|
||||
|
||||
coll1 = collectionManager.get_collection("coll1").get();
|
||||
if(coll1 == nullptr) {
|
||||
|
802
test/numeric_range_trie_test.cpp
Normal file
802
test/numeric_range_trie_test.cpp
Normal file
@ -0,0 +1,802 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <collection_manager.h>
|
||||
#include "collection.h"
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
class NumericRangeTrieTest : public ::testing::Test {
|
||||
protected:
|
||||
Store *store;
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
std::atomic<bool> quit = false;
|
||||
|
||||
std::vector<std::string> query_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
void setupCollection() {
|
||||
std::string state_dir_path = "/tmp/typesense_test/collection_filtering";
|
||||
LOG(INFO) << "Truncating and creating: " << state_dir_path;
|
||||
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
|
||||
|
||||
store = new Store(state_dir_path);
|
||||
collectionManager.init(store, 1.0, "auth_key", quit);
|
||||
collectionManager.load(8, 1000);
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
setupCollection();
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
collectionManager.dispose();
|
||||
delete store;
|
||||
}
|
||||
};
|
||||
|
||||
void reset(uint32_t*& ids, uint32_t& ids_length) {
|
||||
delete [] ids;
|
||||
ids = nullptr;
|
||||
ids_length = 0;
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, SearchRange) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-8192, 8},
|
||||
{-16384, 32},
|
||||
{-24576, 35},
|
||||
{-32768, 43},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91}
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_range(32768, true, -32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size(), ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size(); i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, 32768, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size() - 1, ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size() - 1; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, 134217728, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size(), ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size(); i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, 0, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < 4; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, 0, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < 4; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, false, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size() - 1, ids_length);
|
||||
for (uint32_t i = 0, j = 0; i < pairs.size(); i++) {
|
||||
if (i == 3) continue; // id for -32768 would not be present
|
||||
ASSERT_EQ(pairs[i].second, ids[j++]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-134217728, true, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size(), ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size(); i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-134217728, true, 134217728, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(pairs.size(), ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size(); i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-1, true, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-1, false, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-1, true, 0, true, ids, ids_length);
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-1, false, 0, false, ids, ids_length);
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(8192, true, 32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(8192, true, 0x2000000, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(16384, true, 16384, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(1, ids_length);
|
||||
ASSERT_EQ(56, ids[0]);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(16384, true, 16384, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(16384, false, 16384, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(16383, true, 16383, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(8193, true, 16383, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, -8192, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, SearchGreaterThan) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-8192, 8},
|
||||
{-16384, 32},
|
||||
{-24576, 35},
|
||||
{-32768, 43},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91}
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_greater_than(0, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(-1, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(-1, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(-24576, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(7, ids_length);
|
||||
for (uint32_t i = 0, j = 0; i < pairs.size(); i++) {
|
||||
if (i == 3) continue; // id for -32768 would not be present
|
||||
ASSERT_EQ(pairs[i].second, ids[j++]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(-32768, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(7, ids_length);
|
||||
for (uint32_t i = 0, j = 0; i < pairs.size(); i++) {
|
||||
if (i == 3) continue; // id for -32768 would not be present
|
||||
ASSERT_EQ(pairs[i].second, ids[j++]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(8192, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(8192, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(3, ids_length);
|
||||
for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(1000000, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(-1000000, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(8, ids_length);
|
||||
for (uint32_t i = 0; i < pairs.size(); i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, SearchLessThan) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-32768, 8},
|
||||
{-24576, 32},
|
||||
{-16384, 35},
|
||||
{-8192, 43},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91}
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_less_than(0, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[j]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(0, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(-1, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(-16384, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(3, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(-16384, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(2, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(8192, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(5, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(8192, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(-1000000, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(1000000, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(8, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(pairs[i].second, ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, SearchEqualTo) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-8192, 8},
|
||||
{-16384, 32},
|
||||
{-24576, 35},
|
||||
{-32769, 41},
|
||||
{-32768, 43},
|
||||
{-32767, 45},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91}
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_equal_to(0, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_equal_to(-32768, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(1, ids_length);
|
||||
ASSERT_EQ(43, ids[0]);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_equal_to(24576, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(1, ids_length);
|
||||
ASSERT_EQ(58, ids[0]);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_equal_to(0x202020, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-8192, 8},
|
||||
{-16384, 32},
|
||||
{-24576, 35},
|
||||
{-32769, 41},
|
||||
{-32768, 43},
|
||||
{-32767, 45},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{24576, 60},
|
||||
{32768, 91}
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
auto iterator = trie->search_equal_to(0);
|
||||
ASSERT_EQ(false, iterator.is_valid);
|
||||
|
||||
iterator = trie->search_equal_to(0x202020);
|
||||
ASSERT_EQ(false, iterator.is_valid);
|
||||
|
||||
iterator = trie->search_equal_to(-32768);
|
||||
ASSERT_EQ(true, iterator.is_valid);
|
||||
ASSERT_EQ(43, iterator.seq_id);
|
||||
|
||||
iterator.next();
|
||||
ASSERT_EQ(false, iterator.is_valid);
|
||||
|
||||
iterator = trie->search_equal_to(24576);
|
||||
ASSERT_EQ(true, iterator.is_valid);
|
||||
ASSERT_EQ(58, iterator.seq_id);
|
||||
|
||||
iterator.next();
|
||||
ASSERT_EQ(true, iterator.is_valid);
|
||||
ASSERT_EQ(60, iterator.seq_id);
|
||||
|
||||
iterator.next();
|
||||
ASSERT_EQ(false, iterator.is_valid);
|
||||
|
||||
|
||||
iterator.reset();
|
||||
ASSERT_EQ(true, iterator.is_valid);
|
||||
ASSERT_EQ(58, iterator.seq_id);
|
||||
|
||||
iterator.skip_to(4);
|
||||
ASSERT_EQ(true, iterator.is_valid);
|
||||
ASSERT_EQ(58, iterator.seq_id);
|
||||
|
||||
iterator.skip_to(59);
|
||||
ASSERT_EQ(true, iterator.is_valid);
|
||||
ASSERT_EQ(60, iterator.seq_id);
|
||||
|
||||
iterator.skip_to(66);
|
||||
ASSERT_EQ(false, iterator.is_valid);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, MultivalueData) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-0x202020, 32},
|
||||
{-32768, 5},
|
||||
{-32768, 8},
|
||||
{-24576, 32},
|
||||
{-16384, 35},
|
||||
{-8192, 43},
|
||||
{0, 43},
|
||||
{0, 49},
|
||||
{1, 8},
|
||||
{256, 91},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91},
|
||||
{0x202020, 35},
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_less_than(0, false, ids, ids_length);
|
||||
|
||||
std::vector<uint32_t> expected = {5, 8, 32, 35, 43};
|
||||
|
||||
ASSERT_EQ(5, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(-16380, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(4, ids_length);
|
||||
|
||||
expected = {5, 8, 32, 35};
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(16384, false, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(7, ids_length);
|
||||
|
||||
expected = {5, 8, 32, 35, 43, 49, 91};
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(0, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(7, ids_length);
|
||||
|
||||
expected = {8, 35, 43, 49, 56, 58, 91};
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(256, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(5, ids_length);
|
||||
|
||||
expected = {35, 49, 56, 58, 91};
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_greater_than(-32768, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(9, ids_length);
|
||||
|
||||
expected = {5, 8, 32, 35, 43, 49, 56, 58, 91};
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_range(-32768, true, 0, true, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(6, ids_length);
|
||||
|
||||
expected = {5, 8, 32, 35, 43, 49};
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, Remove) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
std::vector<std::pair<int32_t, uint32_t>> pairs = {
|
||||
{-0x202020, 32},
|
||||
{-32768, 5},
|
||||
{-32768, 8},
|
||||
{-24576, 32},
|
||||
{-16384, 35},
|
||||
{-8192, 43},
|
||||
{0, 2},
|
||||
{0, 49},
|
||||
{1, 8},
|
||||
{256, 91},
|
||||
{8192, 49},
|
||||
{16384, 56},
|
||||
{24576, 58},
|
||||
{32768, 91},
|
||||
{0x202020, 35},
|
||||
};
|
||||
|
||||
for (auto const& pair: pairs) {
|
||||
trie->insert(pair.first, pair.second);
|
||||
}
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_less_than(0, false, ids, ids_length);
|
||||
|
||||
std::vector<uint32_t> expected = {5, 8, 32, 35, 43};
|
||||
|
||||
ASSERT_EQ(5, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
trie->remove(-24576, 32);
|
||||
trie->remove(-0x202020, 32);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_less_than(0, false, ids, ids_length);
|
||||
|
||||
expected = {5, 8, 35, 43};
|
||||
ASSERT_EQ(4, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_equal_to(0, ids, ids_length);
|
||||
|
||||
expected = {2, 49};
|
||||
ASSERT_EQ(2, ids_length);
|
||||
for (uint32_t i = 0; i < ids_length; i++) {
|
||||
ASSERT_EQ(expected[i], ids[i]);
|
||||
}
|
||||
|
||||
trie->remove(0, 2);
|
||||
|
||||
reset(ids, ids_length);
|
||||
trie->search_equal_to(0, ids, ids_length);
|
||||
|
||||
ASSERT_EQ(1, ids_length);
|
||||
ASSERT_EQ(49, ids[0]);
|
||||
|
||||
reset(ids, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, EmptyTrieOperations) {
|
||||
auto trie = new NumericTrie();
|
||||
std::unique_ptr<NumericTrie> trie_guard(trie);
|
||||
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t ids_length = 0;
|
||||
|
||||
trie->search_range(-32768, true, 32768, true, ids, ids_length);
|
||||
std::unique_ptr<uint32_t[]> ids_guard(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_range(-32768, true, -1, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_range(1, true, 32768, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_greater_than(0, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_greater_than(15, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_greater_than(-15, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_less_than(0, false, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_less_than(-15, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_less_than(15, true, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->search_equal_to(15, ids, ids_length);
|
||||
ids_guard.reset(ids);
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
|
||||
trie->remove(15, 0);
|
||||
trie->remove(-15, 0);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, Integration) {
|
||||
Collection *coll_array_fields;
|
||||
|
||||
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
|
||||
std::vector<field> fields = {
|
||||
field("name", field_types::STRING, false),
|
||||
field("rating", field_types::FLOAT, false),
|
||||
field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
|
||||
true), // Setting range index true.
|
||||
field("years", field_types::INT32_ARRAY, false),
|
||||
field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "",
|
||||
nlohmann::json(), true),
|
||||
field("tags", field_types::STRING_ARRAY, true)
|
||||
};
|
||||
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields").get();
|
||||
if(coll_array_fields == nullptr) {
|
||||
// ensure that default_sorting_field is a non-array numerical field
|
||||
auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years");
|
||||
ASSERT_EQ(false, coll_op.ok());
|
||||
ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str());
|
||||
|
||||
// let's try again properly
|
||||
coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age");
|
||||
coll_array_fields = coll_op.get();
|
||||
}
|
||||
|
||||
std::string json_line;
|
||||
|
||||
while (std::getline(infile, json_line)) {
|
||||
auto add_op = coll_array_fields->add(json_line);
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
infile.close();
|
||||
|
||||
query_fields = {"name"};
|
||||
std::vector<std::string> facets;
|
||||
// Searching on an int32 field
|
||||
nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
std::vector<std::string> ids = {"3", "1", "4"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// searching on an int64 array field - also ensure that padded space causes no issues
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ids = {"1", "4", "0", "2"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user