Add support for int64 and float fields in NumericTrie.

This commit is contained in:
Harpreet Sangar 2023-06-02 19:14:11 +05:30
parent dd31e841b9
commit 43d235bbd0
5 changed files with 194 additions and 109 deletions

View File

@ -3,21 +3,25 @@
#include <map>
#include "sorted_array.h"
constexpr char MAX_LEVEL = 4;
constexpr short EXPANSE = 256;
class NumericTrie {
char max_level = 4;
class Node {
Node** children = nullptr;
sorted_array seq_ids;
void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level);
void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level);
void search_range_helper(const int32_t& low,const int32_t& high, std::vector<Node*>& matches);
void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
std::vector<Node*>& matches);
void search_less_than_helper(const int32_t& value, char& level, std::vector<Node*>& matches);
void search_less_than_helper(const int64_t& value, char& level, const char& max_level,
std::vector<Node*>& matches);
void search_greater_than_helper(const int32_t& value, char& level, std::vector<Node*>& matches);
void search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
std::vector<Node*>& matches);
public:
@ -31,18 +35,18 @@ class NumericTrie {
delete [] children;
}
void insert(const int32_t& value, const uint32_t& seq_id);
void insert(const int64_t& value, const uint32_t& seq_id, const char& max_level);
void get_all_ids(uint32_t*& ids, uint32_t& ids_length);
void search_range(const int32_t& low, const int32_t& high,
void search_range(const int64_t& low, const int64_t& high, const char& max_level,
uint32_t*& ids, uint32_t& ids_length);
void search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
void search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length);
};
Node* negative_trie = nullptr;
@ -50,22 +54,26 @@ class NumericTrie {
public:
explicit NumericTrie(char num_bits = 32) {
max_level = num_bits / 8;
}
~NumericTrie() {
delete negative_trie;
delete positive_trie;
}
void insert(const int32_t& value, const uint32_t& seq_id);
void insert(const int64_t& value, const uint32_t& seq_id);
void search_range(const int32_t& low, const bool& low_inclusive,
const int32_t& high, const bool& high_inclusive,
void search_range(const int64_t& low, const bool& low_inclusive,
const int64_t& high, const bool& high_inclusive,
uint32_t*& ids, uint32_t& ids_length);
void search_less_than(const int32_t& value, const bool& inclusive,
void search_less_than(const int64_t& value, const bool& inclusive,
uint32_t*& ids, uint32_t& ids_length);
void search_greater_than(const int32_t& value, const bool& inclusive,
void search_greater_than(const int64_t& value, const bool& inclusive,
uint32_t*& ids, uint32_t& ids_length);
void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length);
void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length);
};

View File

@ -646,7 +646,7 @@ void filter_result_iterator_t::init() {
field f = index->search_schema.at(a_filter.field_name);
if (f.is_integer()) {
if (f.is_int32() && f.range_index) {
if (f.range_index) {
auto const& trie = index->range_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
@ -718,28 +718,64 @@ void filter_result_iterator_t::init() {
is_filter_result_initialized = true;
return;
} else if (f.is_float()) {
auto num_tree = index->numerical_index.at(a_filter.field_name);
if (f.range_index) {
auto const& trie = index->range_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
float value = (float)std::atof(filter_value.c_str());
int64_t float_int64 = Index::float_to_int64_t(value);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
float value = (float)std::atof(filter_value.c_str());
int64_t float_int64 = Index::float_to_int64_t(value);
size_t result_size = filter_result.count;
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, float_int64,
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
filter_result.docs, result_size);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count);
fi++;
} else if (a_filter.comparators[fi] == EQUALS) {
trie->search_equal_to(float_int64, filter_result.docs, filter_result.count);
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
uint32_t* to_exclude_ids = nullptr;
uint32_t to_exclude_ids_len = 0;
trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len);
auto all_ids = index->seq_ids->uncompress();
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
delete[] all_ids;
delete[] to_exclude_ids;
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
filter_result.docs, filter_result.count);
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS,
filter_result.docs, filter_result.count);
}
}
} else {
auto num_tree = index->numerical_index.at(a_filter.field_name);
filter_result.count = result_size;
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
float value = (float)std::atof(filter_value.c_str());
int64_t float_int64 = Index::float_to_int64_t(value);
size_t result_size = filter_result.count;
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str()));
num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size);
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, float_int64,
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
filter_result.docs, result_size);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size);
}
filter_result.count = result_size;
}
}
if (a_filter.apply_not_equals) {

View File

@ -90,7 +90,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
numerical_index.emplace(a_field.name, num_tree);
if (a_field.range_index) {
auto trie = new NumericTrie();
auto trie = a_field.is_int32() ? new NumericTrie() : new NumericTrie(64);
range_index.emplace(a_field.name, trie);
}
}
@ -767,6 +767,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
}
else if(afield.type == field_types::INT64) {
if (afield.range_index) {
auto const& trie = range_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
(const index_record& record, uint32_t seq_id) {
int64_t value = record.doc[afield.name].get<int64_t>();
trie->insert(value, seq_id);
});
}
auto num_tree = numerical_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
(const index_record& record, uint32_t seq_id) {
@ -776,6 +785,16 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
}
else if(afield.type == field_types::FLOAT) {
if (afield.range_index) {
auto const& trie = range_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
(const index_record& record, uint32_t seq_id) {
float fvalue = record.doc[afield.name].get<float>();
int64_t value = float_to_int64_t(fvalue);
trie->insert(value, seq_id);
});
}
auto num_tree = numerical_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
(const index_record& record, uint32_t seq_id) {
@ -928,23 +947,30 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
if(afield.type == field_types::INT32_ARRAY) {
const int32_t value = arr_value;
num_tree->insert(value, seq_id);
if (afield.range_index) {
trie->insert(value, seq_id);
}
num_tree->insert(value, seq_id);
}
else if(afield.type == field_types::INT64_ARRAY) {
const int64_t value = arr_value;
num_tree->insert(value, seq_id);
if (afield.range_index) {
trie->insert(value, seq_id);
}
}
else if(afield.type == field_types::FLOAT_ARRAY) {
const float fvalue = arr_value;
int64_t value = float_to_int64_t(fvalue);
num_tree->insert(value, seq_id);
if (afield.range_index) {
trie->insert(value, seq_id);
}
}
else if(afield.type == field_types::BOOL_ARRAY) {

View File

@ -2,24 +2,24 @@
#include "numeric_range_trie_test.h"
#include "array_utils.h"
void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) {
void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) {
if (value < 0) {
if (negative_trie == nullptr) {
negative_trie = new NumericTrie::Node();
}
negative_trie->insert(std::abs(value), seq_id);
negative_trie->insert(std::abs(value), seq_id, max_level);
} else {
if (positive_trie == nullptr) {
positive_trie = new NumericTrie::Node();
}
positive_trie->insert(value, seq_id);
positive_trie->insert(value, seq_id, max_level);
}
}
void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
const int32_t& high, const bool& high_inclusive,
void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive,
const int64_t& high, const bool& high_inclusive,
uint32_t*& ids, uint32_t& ids_length) {
if (low > high) {
return;
@ -34,7 +34,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
auto abs_low = std::abs(low);
// Since we store absolute values, search_lesser would yield result for >low from negative_trie.
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length);
negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level,
negative_ids, negative_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
@ -47,7 +48,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0)
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
positive_trie->search_less_than(high_inclusive ? high : high - 1, positive_ids, positive_ids_length);
positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level,
positive_ids, positive_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
@ -64,7 +66,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1,
positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level,
positive_ids, positive_ids_length);
uint32_t* out = nullptr;
@ -84,6 +86,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
// Since we store absolute values, switching low and high would produce the correct result.
auto abs_high = std::abs(high), abs_low = std::abs(low);
negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1,
max_level,
negative_ids, negative_ids_length);
uint32_t* out = nullptr;
@ -95,7 +98,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive,
}
}
void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞)
if (positive_trie != nullptr) {
uint32_t* positive_ids = nullptr;
@ -119,7 +122,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length);
positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
@ -136,7 +139,8 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv
auto abs_low = std::abs(value);
// Since we store absolute values, search_lesser would yield result for >value from negative_trie.
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length);
negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, max_level,
negative_ids, negative_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
@ -163,7 +167,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv
}
}
void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) {
if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1]
if (negative_trie != nullptr) {
uint32_t* negative_ids = nullptr;
@ -190,7 +194,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive,
auto abs_low = std::abs(value);
// Since we store absolute values, search_greater would yield result for <value from negative_trie.
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length);
negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level,
negative_ids, negative_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out);
@ -204,7 +209,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive,
if (positive_trie != nullptr) {
uint32_t* positive_ids = nullptr;
uint32_t positive_ids_length = 0;
positive_trie->search_less_than(inclusive ? value : value - 1, positive_ids, positive_ids_length);
positive_trie->search_less_than(inclusive ? value : value - 1, max_level,
positive_ids, positive_ids_length);
uint32_t* out = nullptr;
ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out);
@ -231,7 +237,7 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive,
}
}
void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) {
if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) {
return;
}
@ -240,9 +246,9 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t
uint32_t equal_ids_length = 0;
if (value < 0) {
negative_trie->search_equal_to(std::abs(value), equal_ids, equal_ids_length);
negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length);
} else {
positive_trie->search_equal_to(value, equal_ids, equal_ids_length);
positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length);
}
uint32_t* out = nullptr;
@ -253,12 +259,12 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t
ids = out;
}
void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) {
void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) {
char level = 0;
return insert_helper(value, seq_id, level);
return insert_helper(value, seq_id, level, max_level);
}
inline int get_index(const int32_t& value, char& level) {
inline int get_index(const int64_t& value, const char& level, const char& max_level) {
// Values are index considering higher order of the bytes first.
// 0x01020408 (16909320) would be indexed in the trie as follows:
// Level Index
@ -266,11 +272,11 @@ inline int get_index(const int32_t& value, char& level) {
// 2 2
// 3 4
// 4 8
return (value >> (8 * (MAX_LEVEL - level))) & 0xFF;
return (value >> (8 * (max_level - level))) & 0xFF;
}
void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) {
if (level > MAX_LEVEL) {
void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) {
if (level > max_level) {
return;
}
@ -279,17 +285,17 @@ void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_
seq_ids.append(seq_id);
}
if (++level <= MAX_LEVEL) {
if (++level <= max_level) {
if (children == nullptr) {
children = new NumericTrie::Node* [EXPANSE]{nullptr};
}
auto index = get_index(value, level);
auto index = get_index(value, level, max_level);
if (children[index] == nullptr) {
children[index] = new NumericTrie::Node();
}
return children[index]->insert_helper(value, seq_id, level);
return children[index]->insert_helper(value, seq_id, level, max_level);
}
}
@ -298,10 +304,11 @@ void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) {
ids_length = seq_ids.getLength();
}
void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level,
uint32_t*& ids, uint32_t& ids_length) {
char level = 0;
std::vector<NumericTrie::Node*> matches;
search_less_than_helper(value, level, matches);
search_less_than_helper(value, level, max_level, matches);
std::vector<uint32_t> consolidated_ids;
for (auto const& match: matches) {
@ -324,17 +331,18 @@ void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, u
ids = out;
}
void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& level, std::vector<NumericTrie::Node*>& matches) {
if (level == MAX_LEVEL) {
void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level,
std::vector<Node*>& matches) {
if (level == max_level) {
matches.push_back(this);
return;
} else if (level > MAX_LEVEL || children == nullptr) {
} else if (level > max_level || children == nullptr) {
return;
}
auto index = get_index(value, ++level);
auto index = get_index(value, ++level, max_level);
if (children[index] != nullptr) {
children[index]->search_less_than_helper(value, level, matches);
children[index]->search_less_than_helper(value, level, max_level, matches);
}
while (--index >= 0) {
@ -346,12 +354,13 @@ void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& leve
--level;
}
void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level,
uint32_t*& ids, uint32_t& ids_length) {
if (low > high) {
return;
}
std::vector<NumericTrie::Node*> matches;
search_range_helper(low, high, matches);
search_range_helper(low, high, max_level, matches);
std::vector<uint32_t> consolidated_ids;
for (auto const& match: matches) {
@ -374,24 +383,24 @@ void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, ui
ids = out;
}
void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& high,
std::vector<NumericTrie::Node*>& matches) {
void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level,
std::vector<Node*>& matches) {
// Segregating the nodes into matching low, in-between, and matching high.
NumericTrie::Node* root = this;
char level = 1;
auto low_index = get_index(low, level), high_index = get_index(high, level);
auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level);
// Keep updating the root while the range is contained within a single child node.
while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) {
while (root->children != nullptr && low_index == high_index && level < max_level) {
if (root->children[low_index] == nullptr) {
return;
}
root = root->children[low_index];
level++;
low_index = get_index(low, level);
high_index = get_index(high, level);
low_index = get_index(low, level, max_level);
high_index = get_index(high, level, max_level);
}
if (root->children == nullptr) {
@ -405,7 +414,7 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h
if (root->children[low_index] != nullptr) {
// Collect all the sub-nodes that are greater than low.
root->children[low_index]->search_greater_than_helper(low, level, matches);
root->children[low_index]->search_greater_than_helper(low, level, max_level, matches);
}
auto index = low_index + 1;
@ -420,14 +429,15 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h
if (index < EXPANSE && index == high_index && root->children[index] != nullptr) {
// Collect all the sub-nodes that are lesser than high.
root->children[index]->search_less_than_helper(high, level, matches);
root->children[index]->search_less_than_helper(high, level, max_level, matches);
}
}
void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level,
uint32_t*& ids, uint32_t& ids_length) {
char level = 0;
std::vector<NumericTrie::Node*> matches;
search_greater_than_helper(value, level, matches);
search_greater_than_helper(value, level, max_level, matches);
std::vector<uint32_t> consolidated_ids;
for (auto const& match: matches) {
@ -450,17 +460,18 @@ void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids
ids = out;
}
void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& level, std::vector<NumericTrie::Node*>& matches) {
if (level == MAX_LEVEL) {
void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level,
std::vector<Node*>& matches) {
if (level == max_level) {
matches.push_back(this);
return;
} else if (level > MAX_LEVEL || children == nullptr) {
} else if (level > max_level || children == nullptr) {
return;
}
auto index = get_index(value, ++level);
auto index = get_index(value, ++level, max_level);
if (children[index] != nullptr) {
children[index]->search_greater_than_helper(value, level, matches);
children[index]->search_greater_than_helper(value, level, max_level, matches);
}
while (++index < EXPANSE) {
@ -472,18 +483,19 @@ void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& l
--level;
}
void NumericTrie::Node::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) {
void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level,
uint32_t*& ids, uint32_t& ids_length) {
char level = 1;
Node* root = this;
auto index = get_index(value, level);
auto index = get_index(value, level, max_level);
while (level <= MAX_LEVEL) {
while (level <= max_level) {
if (root->children == nullptr || root->children[index] == nullptr) {
return;
}
root = root->children[index];
index = get_index(value, ++level);
index = get_index(value, ++level, max_level);
}
root->get_all_ids(ids, ids_length);

View File

@ -604,7 +604,8 @@ TEST_F(NumericRangeTrieTest, Integration) {
field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
true), // Setting range index true.
field("years", field_types::INT32_ARRAY, false),
field("timestamps", field_types::INT64_ARRAY, false),
field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "",
nlohmann::json(), true),
field("tags", field_types::STRING_ARRAY, true)
};
@ -626,32 +627,18 @@ TEST_F(NumericRangeTrieTest, Integration) {
while (std::getline(infile, json_line)) {
auto add_op = coll_array_fields->add(json_line);
LOG(INFO) << add_op.error();
ASSERT_TRUE(add_op.ok());
}
infile.close();
// Plain search with no filters - results should be sorted by rank fields
query_fields = {"name"};
std::vector<std::string> facets;
nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(5, results["hits"].size());
std::vector<std::string> ids = {"3", "1", "4", "0", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// Searching on an int32 field
results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(3, results["hits"].size());
ids = {"3", "1", "4"};
std::vector<std::string> ids = {"3", "1", "4"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
@ -659,4 +646,20 @@ TEST_F(NumericRangeTrieTest, Integration) {
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// searching on an int64 array field - also ensure that padded space causes no issues
results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(4, results["hits"].size());
ids = {"1", "4", "0", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(3, results["hits"].size());
}