diff --git a/include/collection.h b/include/collection.h index ff3c991b..52ebb431 100644 --- a/include/collection.h +++ b/include/collection.h @@ -150,7 +150,7 @@ private: void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); - void curate_results(string& actual_query, bool enable_overrides, bool already_segmented, + void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, const std::vector& hidden_hits, std::vector>& included_ids, diff --git a/include/override.h b/include/override.h index 1003d865..344f32ae 100644 --- a/include/override.h +++ b/include/override.h @@ -11,6 +11,7 @@ struct override_t { std::string query; std::string match; bool dynamic_query = false; + std::string filter_by; }; struct add_hit_t { diff --git a/src/collection.cpp b/src/collection.cpp index d5392014..b5aa5afc 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -405,7 +405,8 @@ size_t Collection::batch_index_in_memory(std::vector& index_record return num_indexed; } -void Collection::curate_results(string& actual_query, bool enable_overrides, bool already_segmented, +void Collection::curate_results(string& actual_query, const string& filter_query, + bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, const std::vector& hidden_hits, std::vector>& included_ids, @@ -455,6 +456,10 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo (override.rule.match == override_t::MATCH_CONTAINS && StringUtils::contains_word(query, override.rule.query))) { + if(!override.rule.filter_by.empty() && override.rule.filter_by != filter_query) { + continue; + } + // have to ensure that dropped hits take precedence over added hits for(const auto & hit: override.drop_hits) { Option seq_id_op = doc_id_to_seq_id(hit.doc_id); @@ -1099,7 +1104,7 @@ Option Collection::search(const std::string & raw_query, std::string query = raw_query; bool filter_curated_hits = false; std::string curated_sort_by; - curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, + curate_results(query, simple_filter_query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, included_ids, excluded_ids, filter_overrides, filter_curated_hits, curated_sort_by); if(filter_curated_hits_option == 0 || filter_curated_hits_option == 1) { diff --git a/src/override.cpp b/src/override.cpp index f85bb8e3..1249f4a9 100644 --- a/src/override.cpp +++ b/src/override.cpp @@ -106,6 +106,14 @@ Option override_t::parse(const nlohmann::json& override_json, const std::s override.rule.query = override_json["rule"]["query"].get(); override.rule.match = override_json["rule"]["match"].get(); + if(override_json["rule"].count("filter_by") != 0) { + if(!override_json["rule"]["filter_by"].is_string()) { + return Option(400, "Override `rule.filter_by` must be a string."); + } + + override.rule.filter_by = override_json["rule"]["filter_by"].get(); + } + if (override_json.count("includes") != 0) { for(const auto & include: override_json["includes"]) { add_hit_t add_hit; @@ -187,6 +195,9 @@ nlohmann::json override_t::to_json() const { override["id"] = id; override["rule"]["query"] = rule.query; override["rule"]["match"] = rule.match; + if(!rule.filter_by.empty()) { + override["rule"]["filter_by"] = rule.filter_by; + } override["includes"] = nlohmann::json::array(); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index ea40ccb5..5f73d6a6 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -918,6 +918,68 @@ TEST_F(CollectionOverrideTest, WindowForRule) { ASSERT_EQ(1, results["hits"].size()); } +TEST_F(CollectionOverrideTest, FilterRule) { + Collection *coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["points"] = 30; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["name"] = "Fast Shoes"; + doc2["points"] = 50; + + nlohmann::json doc3; + doc3["id"] = "2"; + doc3["name"] = "Comfortable Socks"; + doc3["points"] = 1; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + ASSERT_TRUE(coll1->add(doc3.dump()).ok()); + + std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; + + nlohmann::json override_json = R"({ + "id": "rule-1", + "rule": { + "query": "*", + "match": "exact", + "filter_by": "points: 50" + }, + "includes": [{ + "id": "0", + "position": 1 + }] + })"_json; + + override_t override_rule; + auto op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + auto results = coll1->search("*", {}, "points: 50", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); + + // check to_json + nlohmann::json override_json_ser = override_rule.to_json(); + ASSERT_EQ("points: 50", override_json_ser["rule"]["filter_by"]); +} + TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) { auto pinned_hits = "13:1,4:2";