Add != compatibility for numeric type. (#835)

* Add `!=` compatibility for numeric type.

* Add `numeric_not_equals_filter` method.

* Refactor `numeric_not_equals_filter` method.

* Support `!=` for numeric field multi-value filter_by.

* Ignore local settings file.

* Add comment.

* Add tests.

* Update negation test.

* Set `apply_not_equals` flag for string fields as well.

* Add `field: [!=value]` test.
This commit is contained in:
Harpreet Sangar 2022-12-29 20:57:10 +05:30 committed by GitHub
parent 60aafae585
commit 34b16cdb21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 195 additions and 85 deletions

2
.gitignore vendored
View File

@ -11,3 +11,5 @@ cmake-build-release
.DS_Store
/bazel-*
typesense-server-data/
.clwb/.bazelproject
.vscode/settings.json

View File

@ -1,5 +0,0 @@
{
"clangd.arguments": [
"--compile-commands-dir=bazel-bin"
]
}

View File

@ -436,6 +436,10 @@ struct filter {
std::string field_name;
std::vector<std::string> values;
std::vector<NUM_COMPARATOR> comparators;
// Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the
// case of int and float fields. During filtering, all the results of matching the field against the values are
// aggregated and then this flag is checked if negation on the aggregated result is required.
bool apply_not_equals = false;
static const std::string RANGE_OPERATOR() {
return "..";
@ -473,6 +477,10 @@ struct filter {
num_comparator = GREATER_THAN_EQUALS;
}
else if(comp_and_value.compare(0, 2, "!=") == 0) {
num_comparator = NOT_EQUALS;
}
else if(comp_and_value.compare(0, 1, "<") == 0) {
num_comparator = LESS_THAN;
}
@ -491,7 +499,7 @@ struct filter {
if(num_comparator == LESS_THAN || num_comparator == GREATER_THAN) {
comp_and_value = comp_and_value.substr(1);
} else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS) {
} else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS || num_comparator == NOT_EQUALS) {
comp_and_value = comp_and_value.substr(2);
}

View File

@ -465,6 +465,11 @@ private:
const size_t num_search_fields,
std::vector<size_t>& popular_field_ids);
void numeric_not_equals_filter(num_tree_t* const num_tree,
const int64_t value,
uint32_t*& ids,
size_t& ids_len) const;
void do_filtering(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const root) const;

View File

@ -122,6 +122,40 @@ Option<bool> toPostfix(std::queue<std::string>& tokens, std::queue<std::string>&
return Option<bool>(true);
}
Option<bool> toMultiValueNumericFilter(std::string& raw_value, filter& filter_exp, const field& _field) {
std::vector<std::string> filter_values;
StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ",");
filter_exp = {_field.name, {}, {}};
for (std::string& filter_value: filter_values) {
Option<NUM_COMPARATOR> op_comparator = filter::extract_num_comparator(filter_value);
if (!op_comparator.ok()) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
}
if (op_comparator.get() == RANGE_INCLUSIVE) {
// split the value around range operator to extract bounds
std::vector<std::string> range_values;
StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR());
for (const std::string& range_value: range_values) {
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
if (!validate_op.ok()) {
return validate_op;
}
filter_exp.values.push_back(range_value);
filter_exp.comparators.push_back(op_comparator.get());
}
} else {
auto validate_op = filter::validate_numerical_filter_value(_field, filter_value);
if (!validate_op.ok()) {
return validate_op;
}
filter_exp.values.push_back(filter_value);
filter_exp.comparators.push_back(op_comparator.get());
}
}
return Option<bool>(true);
}
Option<bool> toFilter(const std::string expression,
filter& filter_exp,
const tsl::htrie_map<char, field>& search_schema,
@ -204,34 +238,9 @@ Option<bool> toFilter(const std::string expression,
if (_field.is_integer() || _field.is_float()) {
// could be a single value or a list
if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') {
std::vector<std::string> filter_values;
StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ",");
filter_exp = {field_name, {}, {}};
for (std::string& filter_value: filter_values) {
Option<NUM_COMPARATOR> op_comparator = filter::extract_num_comparator(filter_value);
if (!op_comparator.ok()) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
}
if (op_comparator.get() == RANGE_INCLUSIVE) {
// split the value around range operator to extract bounds
std::vector<std::string> range_values;
StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR());
for (const std::string& range_value: range_values) {
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
if (!validate_op.ok()) {
return validate_op;
}
filter_exp.values.push_back(range_value);
filter_exp.comparators.push_back(op_comparator.get());
}
} else {
auto validate_op = filter::validate_numerical_filter_value(_field, filter_value);
if (!validate_op.ok()) {
return validate_op;
}
filter_exp.values.push_back(filter_value);
filter_exp.comparators.push_back(op_comparator.get());
}
Option<bool> op = toMultiValueNumericFilter(raw_value, filter_exp, _field);
if (!op.ok()) {
return op;
}
} else {
Option<NUM_COMPARATOR> op_comparator = filter::extract_num_comparator(raw_value);
@ -251,6 +260,12 @@ Option<bool> toFilter(const std::string expression,
filter_exp.values.push_back(range_value);
filter_exp.comparators.push_back(op_comparator.get());
}
} else if (op_comparator.get() == NOT_EQUALS && raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') {
Option<bool> op = toMultiValueNumericFilter(raw_value, filter_exp, _field);
if (!op.ok()) {
return op;
}
filter_exp.apply_not_equals = true;
} else {
auto validate_op = filter::validate_numerical_filter_value(_field, raw_value);
if (!validate_op.ok()) {
@ -353,6 +368,8 @@ Option<bool> toFilter(const std::string expression,
} else {
filter_exp = {field_name, {raw_value.substr(filter_value_index)}, {str_comparator}};
}
filter_exp.apply_not_equals = (str_comparator == NOT_EQUALS);
} else {
return Option<bool>(400, "Error with filter field `" + _field.name +
"`: Unidentified field data type, see docs for supported data types.");

View File

@ -1548,6 +1548,35 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
}
}
void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
const int64_t value,
uint32_t*& ids,
size_t& ids_len) const {
uint32_t* to_exclude_ids = nullptr;
size_t to_exclude_ids_len = 0;
num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len);
auto all_ids = seq_ids->uncompress();
auto all_ids_size = seq_ids->num_ids();
uint32_t* to_include_ids = nullptr;
size_t to_include_ids_len = 0;
to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, to_exclude_ids,
to_exclude_ids_len, &to_include_ids);
delete[] all_ids;
delete[] to_exclude_ids;
uint32_t* out = nullptr;
ids_len = ArrayUtils::or_scalar(ids, ids_len,
to_include_ids, to_include_ids_len, &out);
delete[] ids;
delete[] to_include_ids;
ids = out;
}
void Index::do_filtering(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const root) const {
@ -1603,6 +1632,8 @@ void Index::do_filtering(uint32_t*& filter_ids,
int64_t range_end_value = (int64_t)std::stol(next_filter_value);
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
}
@ -1620,6 +1651,8 @@ void Index::do_filtering(uint32_t*& filter_ids,
int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
}
@ -1631,28 +1664,7 @@ void Index::do_filtering(uint32_t*& filter_ids,
for (const std::string& filter_value : a_filter.values) {
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
if (a_filter.comparators[value_index] == NOT_EQUALS) {
uint32_t* to_exclude_ids = nullptr;
size_t to_exclude_ids_len = 0;
num_tree->search(EQUALS, bool_int64, &to_exclude_ids, to_exclude_ids_len);
auto all_ids = seq_ids->uncompress();
auto all_ids_size = seq_ids->num_ids();
uint32_t* excluded_ids = nullptr;
size_t excluded_ids_len = 0;
excluded_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, to_exclude_ids,
to_exclude_ids_len, &excluded_ids);
delete[] all_ids;
delete[] to_exclude_ids;
uint32_t* out = nullptr;
result_ids_len = ArrayUtils::or_scalar(result_ids, result_ids_len,
excluded_ids, excluded_ids_len, &out);
delete[] result_ids;
result_ids = out;
delete[] excluded_ids;
numeric_not_equals_filter(num_tree, bool_int64, result_ids, result_ids_len);
} else {
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
}
@ -1869,39 +1881,27 @@ void Index::do_filtering(uint32_t*& filter_ids,
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
}
if (a_filter.comparators[0] == NOT_EQUALS) {
// exclude records from existing IDs (from previous filters or ALL records)
// "not equals" can only be applied to the entire array so we can do the exclusion operations once here
uint32_t* excluded_strt_ids = nullptr;
size_t excluded_strt_size = 0;
if (result_ids == nullptr) {
if (filter_ids == nullptr) {
result_ids = seq_ids->uncompress();
result_ids_len = seq_ids->num_ids();
} else {
result_ids = filter_ids;
result_ids_len = filter_ids_length;
}
}
excluded_strt_size = ArrayUtils::exclude_scalar(result_ids, result_ids_len, or_ids,
or_ids_size, &excluded_strt_ids);
if (filter_ids == nullptr) {
// means we had to uncompress `seq_ids` so need to free that
delete[] result_ids;
}
delete[] or_ids;
or_ids = excluded_strt_ids;
or_ids_size = excluded_strt_size;
}
result_ids = or_ids;
result_ids_len = or_ids_size;
}
if (a_filter.apply_not_equals) {
auto all_ids = seq_ids->uncompress();
auto all_ids_size = seq_ids->num_ids();
uint32_t* to_include_ids = nullptr;
size_t to_include_ids_len = 0;
to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, result_ids,
result_ids_len, &to_include_ids);
delete[] all_ids;
delete[] result_ids;
result_ids = to_include_ids;
result_ids_len = to_include_ids_len;
}
filter_ids = result_ids;
filter_ids_length = result_ids_len;

View File

@ -598,6 +598,18 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// not equals
results = coll_array_fields->search("Jeremy", query_fields, "age:!= 24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(4, results["hits"].size());
ids = {"3", "1", "4", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// multiple filters
results = coll_array_fields->search("Jeremy", query_fields, "years:<2005 && years:>1987", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(1, results["hits"].size());
@ -626,6 +638,42 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
results = coll_array_fields->search("Jeremy", query_fields, "age:= [21, 24, 63]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(3, results["hits"].size());
// individual comparators can still be applied.
results = coll_array_fields->search("Jeremy", query_fields, "age: [!=21, >30]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(4, results["hits"].size());
ids = {"3", "1", "4", "0"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_EQ(id, result_id);
}
// negate multiple search values (works like SQL's NOT IN) against a single int field
results = coll_array_fields->search("Jeremy", query_fields, "age:!= [21, 24, 63]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(2, results["hits"].size());
ids = {"1", "4"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_EQ(id, result_id);
}
// individual comparators can still be applied.
results = coll_array_fields->search("Jeremy", query_fields, "age: != [<30, >60]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(2, results["hits"].size());
ids = {"1", "4"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_EQ(id, result_id);
}
// multiple search values against an int32 array field - also use extra padding between symbols
results = coll_array_fields->search("Jeremy", query_fields, "years : [ 2015, 1985 , 1999]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(4, results["hits"].size());
@ -753,6 +801,17 @@ TEST_F(CollectionFilteringTest, FilterOnFloatFields) {
ASSERT_STREQ(id.c_str(), result_id.c_str()); //?
}
results = coll_array_fields->search("Jeremy", query_fields, "rating:!=0", facets, sort_fields_asc, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(4, results["hits"].size());
ids = {"0", "4", "2", "1"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str()); //?
}
// Searching on a float field, sorted desc by rating
results = coll_array_fields->search("Jeremy", query_fields, "rating:>0.0", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(4, results["hits"].size());
@ -802,6 +861,30 @@ TEST_F(CollectionFilteringTest, FilterOnFloatFields) {
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// negate multiple search values (works like SQL's NOT IN operator) against a single float field
results = coll_array_fields->search("Jeremy", query_fields, "rating:!= [1.09, 7.812]", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(3, results["hits"].size());
ids = {"1", "4", "3"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_EQ(id, result_id);
}
// individual comparators can still be applied.
results = coll_array_fields->search("Jeremy", query_fields, "rating: != [<5.4, >9]", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(2, results["hits"].size());
ids = {"2", "4"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_EQ(id, result_id);
}
// multiple search values against a float array field - also use extra padding between symbols
results = coll_array_fields->search("Jeremy", query_fields, "top_3 : [ 5.431, 0.001 , 7.812, 11.992]", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(3, results["hits"].size());