diff --git a/include/index.h b/include/index.h index 26e94e84..f3733db1 100644 --- a/include/index.h +++ b/include/index.h @@ -110,6 +110,7 @@ struct override_t { bool stop_processing = true; std::string sort_by; + std::string replace_query; override_t() = default; @@ -128,9 +129,10 @@ struct override_t { if(override_json.count("includes") == 0 && override_json.count("excludes") == 0 && override_json.count("filter_by") == 0 && override_json.count("sort_by") == 0 && - override_json.count("remove_matched_tokens") == 0) { + override_json.count("remove_matched_tokens") == 0 && + override_json.count("replace_query") == 0) { return Option(400, "Must contain one of: `includes`, `excludes`, " - "`filter_by`, `sort_by`, `remove_matched_tokens`."); + "`filter_by`, `sort_by`, `remove_matched_tokens`, `replace_query`."); } if(override_json.count("includes") != 0) { @@ -242,6 +244,13 @@ struct override_t { override.sort_by = override_json["sort_by"].get(); } + if (override_json.count("replace_query") != 0) { + if(override_json.count("remove_matched_tokens") != 0) { + return Option(400, "Only one of `replace_query` or `remove_matched_tokens` can be specified."); + } + override.replace_query = override_json["replace_query"].get(); + } + if(override_json.count("remove_matched_tokens") != 0) { override.remove_matched_tokens = override_json["remove_matched_tokens"].get(); } else { @@ -308,6 +317,10 @@ struct override_t { override["sort_by"] = sort_by; } + if(!replace_query.empty()) { + override["replace_query"] = replace_query; + } + override["remove_matched_tokens"] = remove_matched_tokens; override["filter_curated_hits"] = filter_curated_hits; override["stop_processing"] = stop_processing; diff --git a/src/collection.cpp b/src/collection.cpp index 00de1d2d..c57707e4 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -481,7 +481,9 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo } } - if(override.remove_matched_tokens && override.filter_by.empty()) { + if(!override.replace_query.empty()) { + actual_query = override.replace_query; + } else if(override.remove_matched_tokens && override.filter_by.empty()) { // don't prematurely remove tokens from query because dynamic filtering will require them StringUtils::replace_all(query, override.rule.query, ""); StringUtils::trim(query); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 27100308..5834ccea 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -781,6 +781,66 @@ TEST_F(CollectionOverrideTest, IncludeOverrideWithFilterBy) { ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); } +TEST_F(CollectionOverrideTest, ReplaceQuery) { + Collection *coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["points"] = 30; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["name"] = "Fast Shoes"; + doc2["points"] = 50; + + nlohmann::json doc3; + doc3["id"] = "2"; + doc3["name"] = "Comfortable Socks"; + doc3["points"] = 1; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + ASSERT_TRUE(coll1->add(doc3.dump()).ok()); + + std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; + + nlohmann::json override_json = R"({ + "id": "rule-1", + "rule": { + "query": "boots", + "match": "exact" + }, + "replace_query": "shoes" + })"_json; + + override_t override_rule; + auto op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + auto results = coll1->search("boots", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + + // don't allow both remove_matched_tokens and replace_query + override_json["remove_matched_tokens"] = true; + op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_FALSE(op.ok()); + ASSERT_EQ("Only one of `replace_query` or `remove_matched_tokens` can be specified.", op.error()); +} + TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) { auto pinned_hits = "13:1,4:2";