Add a per-override `stop_processing` flag for filter overrides.

* include/index.h: `override_t` gains `bool stop_processing` (default `true`);
  a non-boolean value is rejected with a 400 error, and the flag is read from
  and written back to the override JSON.
* src/index.cpp: in `Index::process_filter_overrides`, a matching rule now
  returns early only when its `stop_processing` is set; when it is `false`
  the function no longer returns at that point.
* test/collection_override_test.cpp: adds StaticFilteringMultipleRuleMatch and
  DynamicFilteringMultipleRuleMatch.

diff --git a/include/index.h b/include/index.h
index 7eb5f5cd..672b4575 100644
--- a/include/index.h
+++ b/include/index.h
@@ -108,6 +108,8 @@ struct override_t {
     bool remove_matched_tokens = false;
     bool filter_curated_hits = false;
 
+    bool stop_processing = true;
+
     std::string sort_by;
 
     override_t() = default;
@@ -199,6 +201,12 @@ struct override_t {
             }
         }
 
+        if(override_json.count("stop_processing") != 0) {
+            if (!override_json["stop_processing"].is_boolean()) {
+                return Option<bool>(400, "The `stop_processing` must be a boolean.");
+            }
+        }
+
         if(!id.empty()) {
             override.id = id;
         } else if(override_json.count("id") != 0) {
@@ -245,6 +253,10 @@ struct override_t {
             override.filter_curated_hits = override_json["filter_curated_hits"].get<bool>();
         }
 
+        if(override_json.count("stop_processing") != 0) {
+            override.stop_processing = override_json["stop_processing"].get<bool>();
+        }
+
         // we have to also detect if it is a dynamic query rule
         size_t i = 0;
         while(i < override.rule.query.size()) {
@@ -299,6 +311,7 @@ struct override_t {
 
         override["remove_matched_tokens"] = remove_matched_tokens;
         override["filter_curated_hits"] = filter_curated_hits;
+        override["stop_processing"] = stop_processing;
 
         return override;
     }
diff --git a/src/index.cpp b/src/index.cpp
index 814d6cd5..838c82ff 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -2035,15 +2035,16 @@ void Index::process_filter_overrides(const std::vector& filte
             bool resolved_override = static_filter_query_eval(override, query_tokens, filters);
 
             if(resolved_override) {
-                if(!override->remove_matched_tokens) {
-                    return ;
+                if(override->remove_matched_tokens) {
+                    std::vector<std::string> rule_tokens;
+                    Tokenizer(override->rule.query, true).tokenize(rule_tokens);
+                    std::set<std::string> rule_token_set(rule_tokens.begin(), rule_tokens.end());
+                    remove_matched_tokens(query_tokens, rule_token_set);
                 }
 
-                std::vector<std::string> rule_tokens;
-                Tokenizer(override->rule.query, true).tokenize(rule_tokens);
-                std::set<std::string> rule_token_set(rule_tokens.begin(), rule_tokens.end());
-                remove_matched_tokens(query_tokens, rule_token_set);
-                return ;
+                if(override->stop_processing) {
+                    return;
+                }
             }
         } else {
             // need to extract placeholder field names from the search query, filter on them and rewrite query
@@ -2072,7 +2073,9 @@ void Index::process_filter_overrides(const std::vector& filte
                 }
             }
 
-            return ;
+            if(override->stop_processing) {
+                return;
+            }
         }
     }
 }
diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp
index f318dde6..654e9111 100644
--- a/test/collection_override_test.cpp
+++ b/test/collection_override_test.cpp
@@ -1809,6 +1809,214 @@ TEST_F(CollectionOverrideTest, StaticFiltering) {
     collectionManager.drop_collection("coll1");
 }
 
+TEST_F(CollectionOverrideTest, StaticFilteringMultipleRuleMatch) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("name", field_types::STRING, false),
+                                 field("tags", field_types::STRING_ARRAY, true),
+                                 field("points", field_types::INT32, false)};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["name"] = "Amazing Shoes";
+    doc1["tags"] = {"twitter"};
+    doc1["points"] = 3;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["name"] = "Track Shoes";
+    doc2["tags"] = {"starred"};
+    doc2["points"] = 5;
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["name"] = "Track Shoes";
+    doc3["tags"] = {"twitter", "starred"};
+    doc3["points"] = 10;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+
+    std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
+
+    nlohmann::json override_filter1_json = {
+        {"id", "static-filter-1"},
+        {
+            "rule", {
+                {"query", "twitter"},
+                {"match", override_t::MATCH_CONTAINS}
+            }
+        },
+        {"remove_matched_tokens", true},
+        {"stop_processing", false},
+        {"filter_by", "tags: twitter"}
+    };
+
+    override_t override_filter1;
+    auto op = override_t::parse(override_filter1_json, "static-filter-1", override_filter1);
+    ASSERT_TRUE(op.ok());
+
+    coll1->add_override(override_filter1);
+
+    nlohmann::json override_filter2_json = {
+        {"id", "static-filter-2"},
+        {
+            "rule", {
+                {"query", "starred"},
+                {"match", override_t::MATCH_CONTAINS}
+            }
+        },
+        {"remove_matched_tokens", true},
+        {"stop_processing", false},
+        {"filter_by", "tags: starred"}
+    };
+
+    override_t override_filter2;
+    op = override_t::parse(override_filter2_json, "static-filter-2", override_filter2);
+    ASSERT_TRUE(op.ok());
+
+    coll1->add_override(override_filter2);
+
+    auto results = coll1->search("starred twitter", {"name"}, "",
+                                 {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+    ASSERT_EQ("2", results["hits"][0]["document"]["id"].get<std::string>());
+
+    // when stop_processing is enabled (default is true)
+    override_filter1_json.erase("stop_processing");
+    override_filter2_json.erase("stop_processing");
+
+    override_t override_filter1_reset;
+    op = override_t::parse(override_filter1_json, "static-filter-1", override_filter1_reset);
+    ASSERT_TRUE(op.ok());
+    override_t override_filter2_reset;
+    op = override_t::parse(override_filter2_json, "static-filter-2", override_filter2_reset);
+    ASSERT_TRUE(op.ok());
+
+    coll1->add_override(override_filter1_reset);
+    coll1->add_override(override_filter2_reset);
+
+    results = coll1->search("starred twitter", {"name"}, "",
+                            {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+
+    ASSERT_EQ(0, results["hits"].size());
+
+    collectionManager.drop_collection("coll1");
+}
+
+TEST_F(CollectionOverrideTest, DynamicFilteringMultipleRuleMatch) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("name", field_types::STRING, false),
+                                 field("brand", field_types::STRING, false),
+                                 field("tags", field_types::STRING_ARRAY, true),
+                                 field("points", field_types::INT32, false)};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["name"] = "Amazing Shoes";
+    doc1["brand"] = "Nike";
+    doc1["tags"] = {"twitter"};
+    doc1["points"] = 3;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["name"] = "Track Shoes";
+    doc2["brand"] = "Adidas";
+    doc2["tags"] = {"starred"};
+    doc2["points"] = 5;
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["name"] = "Track Shoes";
+    doc3["brand"] = "Nike";
+    doc3["tags"] = {"twitter", "starred"};
+    doc3["points"] = 10;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+
+    std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
+
+    nlohmann::json override_filter1_json = {
+        {"id", "dynamic-filter-1"},
+        {
+            "rule", {
+                {"query", "{brand}"},
+                {"match", override_t::MATCH_CONTAINS}
+            }
+        },
+        {"remove_matched_tokens", true},
+        {"stop_processing", false},
+        {"filter_by", "tags: twitter"}
+    };
+
+    override_t override_filter1;
+    auto op = override_t::parse(override_filter1_json, "dynamic-filter-1", override_filter1);
+    ASSERT_TRUE(op.ok());
+
+    coll1->add_override(override_filter1);
+
+    nlohmann::json override_filter2_json = {
+        {"id", "dynamic-filter-2"},
+        {
+            "rule", {
+                {"query", "{tags}"},
+                {"match", override_t::MATCH_CONTAINS}
+            }
+        },
+        {"remove_matched_tokens", true},
+        {"stop_processing", false},
+        {"filter_by", "tags: starred"}
+    };
+
+    override_t override_filter2;
+    op = override_t::parse(override_filter2_json, "dynamic-filter-2", override_filter2);
+    ASSERT_TRUE(op.ok());
+
+    coll1->add_override(override_filter2);
+
+    auto results = coll1->search("starred nike", {"name"}, "",
+                                 {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+
+    ASSERT_EQ(1, results["hits"].size());
+    ASSERT_EQ("2", results["hits"][0]["document"]["id"].get<std::string>());
+
+    // when stop_processing is enabled (default is true)
+    override_filter1_json.erase("stop_processing");
+    override_filter2_json.erase("stop_processing");
+
+    override_t override_filter1_reset;
+    op = override_t::parse(override_filter1_json, "dynamic-filter-1", override_filter1_reset);
+    ASSERT_TRUE(op.ok());
+    override_t override_filter2_reset;
+    op = override_t::parse(override_filter2_json, "dynamic-filter-2", override_filter2_reset);
+    ASSERT_TRUE(op.ok());
+
+    coll1->add_override(override_filter1_reset);
+    coll1->add_override(override_filter2_reset);
+
+    results = coll1->search("starred nike", {"name"}, "",
+                            {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+
+    ASSERT_EQ(0, results["hits"].size());
+
+    collectionManager.drop_collection("coll1");
+}
+
 TEST_F(CollectionOverrideTest, SynonymsAppliedToOverridenQuery) {
     Collection *coll1;
 