diff --git a/include/collection.h b/include/collection.h index f04a8034..cadfe47e 100644 --- a/include/collection.h +++ b/include/collection.h @@ -563,7 +563,7 @@ public: const size_t max_extra_prefix = INT16_MAX, const size_t max_extra_suffix = INT16_MAX, const size_t facet_query_num_typos = 2, - const size_t filter_curated_hits_option = 2, + const bool filter_curated_hits_option = false, const bool prioritize_token_position = false, const std::string& vector_query_str = "", const bool enable_highlight_v1 = true, diff --git a/src/collection.cpp b/src/collection.cpp index 6b288948..f2d61c13 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1723,7 +1723,7 @@ Option Collection::search(std::string raw_query, const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos, - const size_t filter_curated_hits_option, + const bool filter_curated_hits_option, const bool prioritize_token_position, const std::string& vector_query_str, const bool enable_highlight_v1, @@ -2222,7 +2222,6 @@ Option Collection::search(std::string raw_query, nlohmann::json override_metadata; std::vector filter_overrides; - bool filter_curated_hits = false; std::string curated_sort_by; std::set override_tag_set; @@ -2232,14 +2231,13 @@ Option Collection::search(std::string raw_query, override_tag_set.insert(tag); } + bool filter_curated_hits_overrides = false; + curate_results(query, filter_query, enable_overrides, pre_segmented_query, override_tag_set, - pinned_hits, hidden_hits, included_ids, excluded_ids, filter_overrides, filter_curated_hits, + pinned_hits, hidden_hits, included_ids, excluded_ids, filter_overrides, filter_curated_hits_overrides, curated_sort_by, override_metadata); - if(filter_curated_hits_option == 0 || filter_curated_hits_option == 1) { - // When query param has explicit value set, override level configuration takes lower precedence. 
- filter_curated_hits = bool(filter_curated_hits_option); - } + bool filter_curated_hits = filter_curated_hits_option || filter_curated_hits_overrides; /*for(auto& kv: included_ids) { LOG(INFO) << "key: " << kv.first; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 40d64a10..4c5af903 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1616,7 +1616,7 @@ Option CollectionManager::do_search(std::map& re bool synonym_prefix = false; size_t synonym_num_typos = 0; - size_t filter_curated_hits_option = 2; + bool filter_curated_hits_option = false; std::string highlight_fields; bool exhaustive_search = false; size_t search_cutoff_ms = 30 * 1000; @@ -1669,7 +1669,6 @@ Option CollectionManager::do_search(std::map& re {MAX_EXTRA_SUFFIX, &max_extra_suffix}, {MAX_CANDIDATES, &max_candidates}, {FACET_QUERY_NUM_TYPOS, &facet_query_num_typos}, - {FILTER_CURATED_HITS, &filter_curated_hits_option}, {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, @@ -1710,6 +1709,7 @@ Option CollectionManager::do_search(std::map& re {SYNONYM_PREFIX, &synonym_prefix}, {ENABLE_LAZY_FILTER, &enable_lazy_filter}, {ENABLE_TYPOS_FOR_ALPHA_NUMERICAL_TOKENS, &enable_typos_for_alpha_numerical_tokens}, + {FILTER_CURATED_HITS, &filter_curated_hits_option}, }; std::unordered_map*> str_list_values = { diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index a3cd3fb4..d696273f 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -410,7 +410,7 @@ TEST_F(CollectionOverrideTest, IncludeHitsFilterOverrides) { "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, 4, {off}, 32767, 32767, 2, 0).get(); - ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ(1, results["hits"].size()); } @@ -4223,3 +4223,101 @@ TEST_F(CollectionOverrideTest, RetrieveOverideByID) { auto op = 
coll2->get_override("override1");
     ASSERT_TRUE(op.ok());
 }
+
+// Verifies that pinned hits failing the filter are dropped only when filter_curated_hits is true.
+TEST_F(CollectionOverrideTest, FilterPinnedHits) {
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("points", field_types::INT32, false)};
+
+    Collection* coll3 = collectionManager.get_collection("coll3").get();
+    if(coll3 == nullptr) {
+        coll3 = collectionManager.create_collection("coll3", 1, fields, "points").get();
+    }
+
+    nlohmann::json doc;
+    // No explicit id is set; the assertions below expect ids "0".."4" in insertion order.
+    doc["title"] = "Snapdragon 7 gen 2023";
+    doc["points"] = 100;
+    ASSERT_TRUE(coll3->add(doc.dump()).ok());
+
+    doc["title"] = "Snapdragon 732G 2023";
+    doc["points"] = 91;
+    ASSERT_TRUE(coll3->add(doc.dump()).ok());
+
+    doc["title"] = "Snapdragon 4 gen 2023";
+    doc["points"] = 65;
+    ASSERT_TRUE(coll3->add(doc.dump()).ok());
+
+    doc["title"] = "Mediatek Dimensity 720G 2022";
+    doc["points"] = 87;
+    ASSERT_TRUE(coll3->add(doc.dump()).ok());
+
+    doc["title"] = "Mediatek Dimensity 470G 2023";
+    doc["points"] = 63;
+    ASSERT_TRUE(coll3->add(doc.dump()).ok());
+    // Pin doc "3" to position 1 and doc "4" to position 2; neither title contains "snapdragon".
+    auto pinned_hits = "3:1, 4:2";
+    // Flag off: the pinned hits are returned even though they do not match the filter.
+    bool filter_curated_hits = false;
+    auto results = coll3->search("2023", {"title"}, "title: snapdragon", {}, {},
+                                 {0}, 50, 1, FREQUENCY,
+                                 {false}, Index::DROP_TOKENS_THRESHOLD,
+                                 spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10,
+                                 "", 30, 5, "",
+                                 10, pinned_hits, {}, {}, 3,
+                                 "<mark>", "</mark>", {}, UINT_MAX,
+                                 true, false, true, "",
+                                 false, 6000 * 1000, 4, 7,
+                                 fallback, 4, {off}, INT16_MAX,
+                                 INT16_MAX, 2, filter_curated_hits).get();
+
+
+    ASSERT_EQ(5, results["hits"].size());
+    ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("4", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("0", results["hits"][2]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][3]["document"]["id"].get<std::string>());
+    ASSERT_EQ("2", results["hits"][4]["document"]["id"].get<std::string>());
+    // Flag on: pinned hits that fail the "title: snapdragon" filter are removed.
+    filter_curated_hits = true;
+    results = coll3->search("2023", {"title"}, "title: snapdragon", {}, {},
+                            {0}, 50, 1, FREQUENCY,
+                            {false}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set<std::string>(),
+                            spp::sparse_hash_set<std::string>(), 10,
+                            "", 30, 5, "",
+                            10, pinned_hits, {}, {}, 3,
+                            "<mark>", "</mark>", {}, UINT_MAX,
+                            true, false, true, "",
+                            false, 6000 * 1000, 4, 7,
+                            fallback, 4, {off}, INT16_MAX,
+                            INT16_MAX, 2, filter_curated_hits).get();
+
+
+    ASSERT_EQ(3, results["hits"].size());
+    ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("2", results["hits"][2]["document"]["id"].get<std::string>());
+
+    // Partial filtering: doc "3" (2022) fails "title: 2023" and is dropped, so doc "4" is expected first.
+    results = coll3->search("snapdragon", {"title"}, "title: 2023", {}, {},
+                            {0}, 50, 1, FREQUENCY,
+                            {false}, Index::DROP_TOKENS_THRESHOLD,
+                            spp::sparse_hash_set<std::string>(),
+                            spp::sparse_hash_set<std::string>(), 10,
+                            "", 30, 5, "",
+                            10, pinned_hits, {}, {}, 3,
+                            "<mark>", "</mark>", {}, UINT_MAX,
+                            true, false, true, "",
+                            false, 6000 * 1000, 4, 7,
+                            fallback, 4, {off}, INT16_MAX,
+                            INT16_MAX, 2, filter_curated_hits).get();
+
+
+    ASSERT_EQ(4, results["hits"].size());
+    ASSERT_EQ("4", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());
+    ASSERT_EQ("2", results["hits"][3]["document"]["id"].get<std::string>());
+}