mirror of
https://github.com/typesense/typesense.git
synced 2025-05-18 04:32:38 +08:00
Add !=
compatibility for numeric type. (#835)
* Add `!=` compatibility for numeric type. * Add `numeric_not_equals_filter` method. * Refactor `numeric_not_equals_filter` method. * Support `!=` for numeric field multi-value filter_by. * Ignore local settings file. * Add comment. * Add tests. * Update negation test. * Set `apply_not_equals` flag for string fields as well. * Add `field: [!=value]` test.
This commit is contained in:
parent
60aafae585
commit
34b16cdb21
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,3 +11,5 @@ cmake-build-release
|
||||
.DS_Store
|
||||
/bazel-*
|
||||
typesense-server-data/
|
||||
.clwb/.bazelproject
|
||||
.vscode/settings.json
|
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
@ -1,5 +0,0 @@
|
||||
{
|
||||
"clangd.arguments": [
|
||||
"--compile-commands-dir=bazel-bin"
|
||||
]
|
||||
}
|
@ -436,6 +436,10 @@ struct filter {
|
||||
std::string field_name;
|
||||
std::vector<std::string> values;
|
||||
std::vector<NUM_COMPARATOR> comparators;
|
||||
// Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the
|
||||
// case of int and float fields. During filtering, all the results of matching the field against the values are
|
||||
// aggregated and then this flag is checked if negation on the aggregated result is required.
|
||||
bool apply_not_equals = false;
|
||||
|
||||
static const std::string RANGE_OPERATOR() {
|
||||
return "..";
|
||||
@ -473,6 +477,10 @@ struct filter {
|
||||
num_comparator = GREATER_THAN_EQUALS;
|
||||
}
|
||||
|
||||
else if(comp_and_value.compare(0, 2, "!=") == 0) {
|
||||
num_comparator = NOT_EQUALS;
|
||||
}
|
||||
|
||||
else if(comp_and_value.compare(0, 1, "<") == 0) {
|
||||
num_comparator = LESS_THAN;
|
||||
}
|
||||
@ -491,7 +499,7 @@ struct filter {
|
||||
|
||||
if(num_comparator == LESS_THAN || num_comparator == GREATER_THAN) {
|
||||
comp_and_value = comp_and_value.substr(1);
|
||||
} else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS) {
|
||||
} else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS || num_comparator == NOT_EQUALS) {
|
||||
comp_and_value = comp_and_value.substr(2);
|
||||
}
|
||||
|
||||
|
@ -465,6 +465,11 @@ private:
|
||||
const size_t num_search_fields,
|
||||
std::vector<size_t>& popular_field_ids);
|
||||
|
||||
void numeric_not_equals_filter(num_tree_t* const num_tree,
|
||||
const int64_t value,
|
||||
uint32_t*& ids,
|
||||
size_t& ids_len) const;
|
||||
|
||||
void do_filtering(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const root) const;
|
||||
|
@ -122,6 +122,40 @@ Option<bool> toPostfix(std::queue<std::string>& tokens, std::queue<std::string>&
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> toMultiValueNumericFilter(std::string& raw_value, filter& filter_exp, const field& _field) {
|
||||
std::vector<std::string> filter_values;
|
||||
StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ",");
|
||||
filter_exp = {_field.name, {}, {}};
|
||||
for (std::string& filter_value: filter_values) {
|
||||
Option<NUM_COMPARATOR> op_comparator = filter::extract_num_comparator(filter_value);
|
||||
if (!op_comparator.ok()) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
|
||||
}
|
||||
if (op_comparator.get() == RANGE_INCLUSIVE) {
|
||||
// split the value around range operator to extract bounds
|
||||
std::vector<std::string> range_values;
|
||||
StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR());
|
||||
for (const std::string& range_value: range_values) {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
|
||||
if (!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
filter_exp.values.push_back(range_value);
|
||||
filter_exp.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
} else {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, filter_value);
|
||||
if (!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
filter_exp.values.push_back(filter_value);
|
||||
filter_exp.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> toFilter(const std::string expression,
|
||||
filter& filter_exp,
|
||||
const tsl::htrie_map<char, field>& search_schema,
|
||||
@ -204,34 +238,9 @@ Option<bool> toFilter(const std::string expression,
|
||||
if (_field.is_integer() || _field.is_float()) {
|
||||
// could be a single value or a list
|
||||
if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') {
|
||||
std::vector<std::string> filter_values;
|
||||
StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ",");
|
||||
filter_exp = {field_name, {}, {}};
|
||||
for (std::string& filter_value: filter_values) {
|
||||
Option<NUM_COMPARATOR> op_comparator = filter::extract_num_comparator(filter_value);
|
||||
if (!op_comparator.ok()) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
|
||||
}
|
||||
if (op_comparator.get() == RANGE_INCLUSIVE) {
|
||||
// split the value around range operator to extract bounds
|
||||
std::vector<std::string> range_values;
|
||||
StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR());
|
||||
for (const std::string& range_value: range_values) {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
|
||||
if (!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
filter_exp.values.push_back(range_value);
|
||||
filter_exp.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
} else {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, filter_value);
|
||||
if (!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
filter_exp.values.push_back(filter_value);
|
||||
filter_exp.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
Option<bool> op = toMultiValueNumericFilter(raw_value, filter_exp, _field);
|
||||
if (!op.ok()) {
|
||||
return op;
|
||||
}
|
||||
} else {
|
||||
Option<NUM_COMPARATOR> op_comparator = filter::extract_num_comparator(raw_value);
|
||||
@ -251,6 +260,12 @@ Option<bool> toFilter(const std::string expression,
|
||||
filter_exp.values.push_back(range_value);
|
||||
filter_exp.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
} else if (op_comparator.get() == NOT_EQUALS && raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') {
|
||||
Option<bool> op = toMultiValueNumericFilter(raw_value, filter_exp, _field);
|
||||
if (!op.ok()) {
|
||||
return op;
|
||||
}
|
||||
filter_exp.apply_not_equals = true;
|
||||
} else {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, raw_value);
|
||||
if (!validate_op.ok()) {
|
||||
@ -353,6 +368,8 @@ Option<bool> toFilter(const std::string expression,
|
||||
} else {
|
||||
filter_exp = {field_name, {raw_value.substr(filter_value_index)}, {str_comparator}};
|
||||
}
|
||||
|
||||
filter_exp.apply_not_equals = (str_comparator == NOT_EQUALS);
|
||||
} else {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name +
|
||||
"`: Unidentified field data type, see docs for supported data types.");
|
||||
|
102
src/index.cpp
102
src/index.cpp
@ -1548,6 +1548,35 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
|
||||
}
|
||||
}
|
||||
|
||||
void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
|
||||
const int64_t value,
|
||||
uint32_t*& ids,
|
||||
size_t& ids_len) const {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
size_t to_exclude_ids_len = 0;
|
||||
num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = seq_ids->uncompress();
|
||||
auto all_ids_size = seq_ids->num_ids();
|
||||
|
||||
uint32_t* to_include_ids = nullptr;
|
||||
size_t to_include_ids_len = 0;
|
||||
|
||||
to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, to_exclude_ids,
|
||||
to_exclude_ids_len, &to_include_ids);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
ids_len = ArrayUtils::or_scalar(ids, ids_len,
|
||||
to_include_ids, to_include_ids_len, &out);
|
||||
delete[] ids;
|
||||
delete[] to_include_ids;
|
||||
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void Index::do_filtering(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const root) const {
|
||||
@ -1603,6 +1632,8 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
int64_t range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
|
||||
}
|
||||
@ -1620,6 +1651,8 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value, result_ids, result_ids_len);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
|
||||
}
|
||||
@ -1631,28 +1664,7 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
for (const std::string& filter_value : a_filter.values) {
|
||||
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
|
||||
if (a_filter.comparators[value_index] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
size_t to_exclude_ids_len = 0;
|
||||
num_tree->search(EQUALS, bool_int64, &to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = seq_ids->uncompress();
|
||||
auto all_ids_size = seq_ids->num_ids();
|
||||
|
||||
uint32_t* excluded_ids = nullptr;
|
||||
size_t excluded_ids_len = 0;
|
||||
|
||||
excluded_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, to_exclude_ids,
|
||||
to_exclude_ids_len, &excluded_ids);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
|
||||
uint32_t* out = nullptr;
|
||||
result_ids_len = ArrayUtils::or_scalar(result_ids, result_ids_len,
|
||||
excluded_ids, excluded_ids_len, &out);
|
||||
delete[] result_ids;
|
||||
result_ids = out;
|
||||
delete[] excluded_ids;
|
||||
numeric_not_equals_filter(num_tree, bool_int64, result_ids, result_ids_len);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
|
||||
}
|
||||
@ -1869,39 +1881,27 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
|
||||
}
|
||||
|
||||
if (a_filter.comparators[0] == NOT_EQUALS) {
|
||||
// exclude records from existing IDs (from previous filters or ALL records)
|
||||
// "not equals" can only be applied to the entire array so we can do the exclusion operations once here
|
||||
uint32_t* excluded_strt_ids = nullptr;
|
||||
size_t excluded_strt_size = 0;
|
||||
|
||||
if (result_ids == nullptr) {
|
||||
if (filter_ids == nullptr) {
|
||||
result_ids = seq_ids->uncompress();
|
||||
result_ids_len = seq_ids->num_ids();
|
||||
} else {
|
||||
result_ids = filter_ids;
|
||||
result_ids_len = filter_ids_length;
|
||||
}
|
||||
}
|
||||
|
||||
excluded_strt_size = ArrayUtils::exclude_scalar(result_ids, result_ids_len, or_ids,
|
||||
or_ids_size, &excluded_strt_ids);
|
||||
|
||||
if (filter_ids == nullptr) {
|
||||
// means we had to uncompress `seq_ids` so need to free that
|
||||
delete[] result_ids;
|
||||
}
|
||||
|
||||
delete[] or_ids;
|
||||
or_ids = excluded_strt_ids;
|
||||
or_ids_size = excluded_strt_size;
|
||||
}
|
||||
|
||||
result_ids = or_ids;
|
||||
result_ids_len = or_ids_size;
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
auto all_ids = seq_ids->uncompress();
|
||||
auto all_ids_size = seq_ids->num_ids();
|
||||
|
||||
uint32_t* to_include_ids = nullptr;
|
||||
size_t to_include_ids_len = 0;
|
||||
|
||||
to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, result_ids,
|
||||
result_ids_len, &to_include_ids);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] result_ids;
|
||||
|
||||
result_ids = to_include_ids;
|
||||
result_ids_len = to_include_ids_len;
|
||||
}
|
||||
|
||||
filter_ids = result_ids;
|
||||
filter_ids_length = result_ids_len;
|
||||
|
||||
|
@ -598,6 +598,18 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// not equals
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age:!= 24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ids = {"3", "1", "4", "2"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// multiple filters
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "years:<2005 && years:>1987", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
@ -626,6 +638,42 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age:= [21, 24, 63]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
// individual comparators can still be applied.
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age: [!=21, >30]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ids = {"3", "1", "4", "0"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_EQ(id, result_id);
|
||||
}
|
||||
|
||||
// negate multiple search values (works like SQL's NOT IN) against a single int field
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age:!= [21, 24, 63]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
ids = {"1", "4"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_EQ(id, result_id);
|
||||
}
|
||||
|
||||
// individual comparators can still be applied.
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age: != [<30, >60]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
ids = {"1", "4"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_EQ(id, result_id);
|
||||
}
|
||||
|
||||
// multiple search values against an int32 array field - also use extra padding between symbols
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "years : [ 2015, 1985 , 1999]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
@ -753,6 +801,17 @@ TEST_F(CollectionFilteringTest, FilterOnFloatFields) {
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str()); //?
|
||||
}
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating:!=0", facets, sort_fields_asc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ids = {"0", "4", "2", "1"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str()); //?
|
||||
}
|
||||
|
||||
// Searching on a float field, sorted desc by rating
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating:>0.0", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
@ -802,6 +861,30 @@ TEST_F(CollectionFilteringTest, FilterOnFloatFields) {
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// negate multiple search values (works like SQL's NOT IN operator) against a single float field
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating:!= [1.09, 7.812]", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
ids = {"1", "4", "3"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_EQ(id, result_id);
|
||||
}
|
||||
|
||||
// individual comparators can still be applied.
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating: != [<5.4, >9]", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
ids = {"2", "4"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_EQ(id, result_id);
|
||||
}
|
||||
|
||||
// multiple search values against a float array field - also use extra padding between symbols
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "top_3 : [ 5.431, 0.001 , 7.812, 11.992]", facets, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
Loading…
x
Reference in New Issue
Block a user