diff --git a/src/collection.cpp b/src/collection.cpp index 84b7bc70..e4b16d7b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -2149,6 +2149,27 @@ Option Collection::parse_filter_query(const std::string& simple_filter_que } } } else if(_field.is_bool()) { + NUM_COMPARATOR bool_comparator = EQUALS; + size_t filter_value_index = 0; + + if(raw_value[0] == '=') { + bool_comparator = EQUALS; + while(++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if(raw_value.size() >= 2 && raw_value[0] == '!' && raw_value[1] == '=') { + bool_comparator = NOT_EQUALS; + filter_value_index++; + while(++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } + + if(filter_value_index != 0) { + raw_value = raw_value.substr(filter_value_index); + } + + if(filter_value_index == raw_value.size()) { + return Option(400, "Error with filter field `" + _field.name + + "`: Filter value cannot be empty."); + } + if(raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { std::vector filter_values; StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); @@ -2162,14 +2183,15 @@ Option Collection::parse_filter_query(const std::string& simple_filter_que filter_value = (filter_value == "true") ? "1" : "0"; f.values.push_back(filter_value); - f.comparators.push_back(EQUALS); + f.comparators.push_back(bool_comparator); } } else { if(raw_value != "true" && raw_value != "false") { return Option(400, "Value of filter field `" + _field.name + "` must be `true` or `false`."); } + std::string bool_value = (raw_value == "true") ? "1" : "0"; - f = {field_name, {bool_value}, {EQUALS}}; + f = {field_name, {bool_value}, {bool_comparator}}; } } else if(_field.is_geopoint()) { @@ -2224,8 +2246,8 @@ Option Collection::parse_filter_query(const std::string& simple_filter_que // string filter should be evaluated in strict "equals" mode str_comparator = EQUALS; - while(raw_value[++filter_value_index] == ' '); - } else if(raw_value[0] == '-') { + while(++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if(raw_value.size() >= 2 && raw_value[0] == '!' && raw_value[1] == '=') { if(!_field.facet) { // EXCLUDE filtering on string is possible only on facet fields return Option(400, "To perform exclude filtering, filter field `" + @@ -2233,7 +2255,13 @@ Option Collection::parse_filter_query(const std::string& simple_filter_que } str_comparator = NOT_EQUALS; - while(raw_value[++filter_value_index] == ' '); + filter_value_index++; + while(++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } + + if(filter_value_index == raw_value.size()) { + return Option(400, "Error with filter field `" + _field.name + + "`: Filter value cannot be empty."); } if(raw_value[filter_value_index] == '[' && raw_value[raw_value.size() - 1] == ']') { diff --git a/src/index.cpp b/src/index.cpp index 5d6295ba..09f1b381 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1086,7 +1086,33 @@ uint32_t Index::do_filtering(uint32_t** filter_ids_out, const std::vectorsearch(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len); + if(a_filter.comparators[value_index] == NOT_EQUALS) { + uint32_t* to_exclude_ids = nullptr; + size_t to_exclude_ids_len = 0; + num_tree->search(EQUALS, bool_int64, &to_exclude_ids, to_exclude_ids_len); + + auto all_ids = seq_ids.uncompress(); + auto all_ids_size = seq_ids.getLength(); + + uint32_t* excluded_ids = nullptr; + size_t excluded_ids_len = 0; + + excluded_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, to_exclude_ids, + to_exclude_ids_len, &excluded_ids); + + delete [] all_ids; + delete [] to_exclude_ids; + + uint32_t *out = nullptr; + result_ids_len = ArrayUtils::or_scalar(result_ids, result_ids_len, + excluded_ids, excluded_ids_len, &out); + delete [] result_ids; + result_ids = out; + delete [] excluded_ids; + } else { + num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len); + } + value_index++; } @@ -1200,7 +1226,9 @@ uint32_t Index::do_filtering(uint32_t** filter_ids_out, const std::vectorvalues); } - if(posting_lists.size() != str_tokens.size()) { + // For NOT_EQUALS alone, it is okay for none of the results to match prior to negation + // e.g. field:- [RANDOM_NON_EXISTING_STRING] + if(a_filter.comparators[0] != NOT_EQUALS && posting_lists.size() != str_tokens.size()) { continue; } diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index f03d4677..5ad85c4a 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -129,6 +129,11 @@ TEST_F(CollectionFilteringTest, FilterOnTextFields) { results = coll_array_fields->search("Jeremy", query_fields, "tags:>BRONZE", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(2, results["hits"].size()); + // bad filter value (empty) + auto res_op = coll_array_fields->search("Jeremy", query_fields, "tags:=", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `tags`: Filter value cannot be empty.", res_op.error()); + collectionManager.drop_collection("coll_array_fields"); } @@ -319,7 +324,7 @@ TEST_F(CollectionFilteringTest, HandleBadlyFormedFilterQuery) { std::vector fields = {field("name", field_types::STRING, false), field("age", field_types::INT32, false), field("years", field_types::INT32_ARRAY, false), field("timestamps", field_types::INT64_ARRAY, false), - field("tags", field_types::STRING_ARRAY, false)}; + field("tags", field_types::STRING_ARRAY, true)}; std::vector sort_fields = { sort_by("age", "DESC") }; @@ -363,6 +368,20 @@ TEST_F(CollectionFilteringTest, HandleBadlyFormedFilterQuery) { results = coll_array_fields->search("Jeremy", query_fields, "age: '21'", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(0, results["hits"].size()); + // empty value for a numerical filter field + auto res_op = coll_array_fields->search("Jeremy", query_fields, "age:", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `age`: Numerical field has an invalid comparator.", res_op.error()); + + // empty value for string filter field + res_op = coll_array_fields->search("Jeremy", query_fields, "tags:", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `tags`: Filter value cannot be empty.", res_op.error()); + + res_op = coll_array_fields->search("Jeremy", query_fields, "tags:= ", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `tags`: Filter value cannot be empty.", res_op.error()); + collectionManager.drop_collection("coll_array_fields"); } @@ -371,10 +390,10 @@ TEST_F(CollectionFilteringTest, FilterAndQueryFieldRestrictions) { std::ifstream infile(std::string(ROOT_DIR)+"test/multi_field_documents.jsonl"); std::vector fields = { - field("title", field_types::STRING, false), - field("starring", field_types::STRING, false), - field("cast", field_types::STRING_ARRAY, true), - field("points", field_types::INT32, false) + field("title", field_types::STRING, false), + field("starring", field_types::STRING, false), + field("cast", field_types::STRING_ARRAY, true), + field("points", field_types::INT32, false) }; coll_mul_fields = collectionManager.get_collection("coll_mul_fields").get(); @@ -1302,7 +1321,7 @@ TEST_F(CollectionFilteringTest, NegationOperatorBasics) { coll1 = collectionManager.get_collection("coll1").get(); if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + coll1 = collectionManager.create_collection("coll1", 2, fields, "points").get(); } std::vector> records = { @@ -1323,7 +1342,7 @@ TEST_F(CollectionFilteringTest, NegationOperatorBasics) { ASSERT_TRUE(coll1->add(doc.dump()).ok()); } - auto results = coll1->search("*", {"artist"}, "artist:- Michael Jackson", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); + auto results = coll1->search("*", {"artist"}, "artist:!=Michael Jackson", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); ASSERT_EQ(3, results["found"].get()); @@ -1331,17 +1350,26 @@ TEST_F(CollectionFilteringTest, NegationOperatorBasics) { ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get().c_str()); ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); - results = coll1->search("*", {"artist"}, "artist:- Michael Jackson && points: >0", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); + results = coll1->search("*", {"artist"}, "artist:!= Michael Jackson && points: >0", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); ASSERT_EQ(2, results["found"].get()); ASSERT_STREQ("3", results["hits"][0]["document"]["id"].get().c_str()); ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get().c_str()); // negation operation on multiple values - results = coll1->search("*", {"artist"}, "artist:- [Michael Jackson, Taylor Swift]", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); + results = coll1->search("*", {"artist"}, "artist:!= [Michael Jackson, Taylor Swift]", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); ASSERT_EQ(1, results["found"].get()); ASSERT_STREQ("3", results["hits"][0]["document"]["id"].get().c_str()); + // empty value (bad filtering) + auto res_op = coll1->search("*", {"artist"}, "artist:!=", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `artist`: Filter value cannot be empty.", res_op.error()); + + res_op = coll1->search("*", {"artist"}, "artist:!= ", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `artist`: Filter value cannot be empty.", res_op.error()); + collectionManager.drop_collection("coll1"); } @@ -1436,3 +1464,137 @@ TEST_F(CollectionFilteringTest, NumericalRangeFilter) { collectionManager.drop_collection("coll1"); } + +TEST_F(CollectionFilteringTest, QueryBoolFields) { + Collection *coll_bool; + + std::ifstream infile(std::string(ROOT_DIR)+"test/bool_documents.jsonl"); + std::vector fields = { + field("popular", field_types::BOOL, false), + field("title", field_types::STRING, false), + field("rating", field_types::FLOAT, false), + field("bool_array", field_types::BOOL_ARRAY, false), + }; + + std::vector sort_fields = { sort_by("popular", "DESC"), sort_by("rating", "DESC") }; + + coll_bool = collectionManager.get_collection("coll_bool").get(); + if(coll_bool == nullptr) { + coll_bool = collectionManager.create_collection("coll_bool", 1, fields, "rating").get(); + } + + std::string json_line; + + while (std::getline(infile, json_line)) { + coll_bool->add(json_line); + } + + infile.close(); + + // Plain search with no filters - results should be sorted correctly + query_fields = {"title"}; + std::vector facets; + nlohmann::json results = coll_bool->search("the", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(5, results["hits"].size()); + + std::vector ids = {"1", "3", "4", "9", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // Searching on a bool field + results = coll_bool->search("the", query_fields, "popular:true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); + + ids = {"1", "3", "4"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // alternative `:=` syntax + results = coll_bool->search("the", query_fields, "popular:=true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); + + results = coll_bool->search("the", query_fields, "popular:false", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(2, results["hits"].size()); + + results = coll_bool->search("the", query_fields, "popular:= false", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(2, results["hits"].size()); + + ids = {"9", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // searching against a bool array field + + // should be able to filter with an array of boolean values + Option res_op = coll_bool->search("the", query_fields, "bool_array:[true, false]", facets, + sort_fields, {0}, 10, 1, FREQUENCY, {false}); + ASSERT_TRUE(res_op.ok()); + results = res_op.get(); + + ASSERT_EQ(5, results["hits"].size()); + + results = coll_bool->search("the", query_fields, "bool_array: true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(4, results["hits"].size()); + ids = {"1", "4", "9", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // should be able to search using array with a single element boolean value + + results = coll_bool->search("the", query_fields, "bool_array:[true]", facets, + sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + + ASSERT_EQ(4, results["hits"].size()); + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // not equals on bool field + + results = coll_bool->search("the", query_fields, "popular:!= true", facets, + sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("9", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("2", results["hits"][1]["document"]["id"].get()); + + // not equals on bool array field + results = coll_bool->search("the", query_fields, "bool_array:!= [true]", facets, + sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ("3", results["hits"][0]["document"]["id"].get()); + + // empty filter value + res_op = coll_bool->search("the", query_fields, "bool_array:=", facets, + sort_fields, {0}, 10, 1, FREQUENCY, {false}); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error with filter field `bool_array`: Filter value cannot be empty.", res_op.error()); + + collectionManager.drop_collection("coll_bool"); +} diff --git a/test/collection_test.cpp b/test/collection_test.cpp index dab4a2f9..77d6dd1f 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -1448,115 +1448,6 @@ TEST_F(CollectionTest, ImportDocuments) { collectionManager.drop_collection("coll_mul_fields"); } -TEST_F(CollectionTest, QueryBoolFields) { - Collection *coll_bool; - - std::ifstream infile(std::string(ROOT_DIR)+"test/bool_documents.jsonl"); - std::vector fields = { - field("popular", field_types::BOOL, false), - field("title", field_types::STRING, false), - field("rating", field_types::FLOAT, false), - field("bool_array", field_types::BOOL_ARRAY, false), - }; - - std::vector sort_fields = { sort_by("popular", "DESC"), sort_by("rating", "DESC") }; - - coll_bool = collectionManager.get_collection("coll_bool").get(); - if(coll_bool == nullptr) { - coll_bool = collectionManager.create_collection("coll_bool", 4, fields, "rating").get(); - } - - std::string json_line; - - while (std::getline(infile, json_line)) { - coll_bool->add(json_line); - } - - infile.close(); - - // Plain search with no filters - results should be sorted correctly - query_fields = {"title"}; - std::vector facets; - nlohmann::json results = coll_bool->search("the", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(5, results["hits"].size()); - - std::vector ids = {"1", "3", "4", "9", "2"}; - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - - // Searching on a bool field - results = coll_bool->search("the", query_fields, "popular:true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(3, results["hits"].size()); - - ids = {"1", "3", "4"}; - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - - // alternative `:=` syntax - results = coll_bool->search("the", query_fields, "popular:=true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(3, results["hits"].size()); - - results = coll_bool->search("the", query_fields, "popular:false", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(2, results["hits"].size()); - - ids = {"9", "2"}; - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - - // searching against a bool array field - - // should be able to filter with an array of boolean values - Option res_op = coll_bool->search("the", query_fields, "bool_array:[true, false]", facets, - sort_fields, {0}, 10, 1, FREQUENCY, {false}); - ASSERT_TRUE(res_op.ok()); - results = res_op.get(); - - ASSERT_EQ(5, results["hits"].size()); - - results = coll_bool->search("the", query_fields, "bool_array: true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(4, results["hits"].size()); - ids = {"1", "4", "9", "2"}; - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - - // should be able to search using array with a single element boolean value - - auto res = coll_bool->search("the", query_fields, "bool_array:[true]", facets, - sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - - results = coll_bool->search("the", query_fields, "bool_array: true", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(4, results["hits"].size()); - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - - collectionManager.drop_collection("coll_bool"); -} - TEST_F(CollectionTest, SearchingWithMissingFields) { // return error without crashing when searching for fields that do not conform to the schema Collection *coll_array_fields;