use filter_curated_hits to filter pinned_ids (#1711)

This commit is contained in:
Krunal Gandhi 2024-05-08 06:04:23 +05:30 committed by GitHub
parent d69c6a09c7
commit 39cb7ba9ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 107 additions and 11 deletions

View File

@ -563,7 +563,7 @@ public:
const size_t max_extra_prefix = INT16_MAX,
const size_t max_extra_suffix = INT16_MAX,
const size_t facet_query_num_typos = 2,
const size_t filter_curated_hits_option = 2,
const bool filter_curated_hits_option = false,
const bool prioritize_token_position = false,
const std::string& vector_query_str = "",
const bool enable_highlight_v1 = true,

View File

@ -1723,7 +1723,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t max_extra_prefix,
const size_t max_extra_suffix,
const size_t facet_query_num_typos,
const size_t filter_curated_hits_option,
const bool filter_curated_hits_option,
const bool prioritize_token_position,
const std::string& vector_query_str,
const bool enable_highlight_v1,
@ -2222,7 +2222,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
nlohmann::json override_metadata;
std::vector<const override_t*> filter_overrides;
bool filter_curated_hits = false;
std::string curated_sort_by;
std::set<std::string> override_tag_set;
@ -2232,14 +2231,13 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
override_tag_set.insert(tag);
}
bool filter_curated_hits_overrides = false;
curate_results(query, filter_query, enable_overrides, pre_segmented_query, override_tag_set,
pinned_hits, hidden_hits, included_ids, excluded_ids, filter_overrides, filter_curated_hits,
pinned_hits, hidden_hits, included_ids, excluded_ids, filter_overrides, filter_curated_hits_overrides,
curated_sort_by, override_metadata);
if(filter_curated_hits_option == 0 || filter_curated_hits_option == 1) {
// When query param has explicit value set, override level configuration takes lower precedence.
filter_curated_hits = bool(filter_curated_hits_option);
}
bool filter_curated_hits = filter_curated_hits_option || filter_curated_hits_overrides;
/*for(auto& kv: included_ids) {
LOG(INFO) << "key: " << kv.first;

View File

@ -1616,7 +1616,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
bool synonym_prefix = false;
size_t synonym_num_typos = 0;
size_t filter_curated_hits_option = 2;
bool filter_curated_hits_option = false;
std::string highlight_fields;
bool exhaustive_search = false;
size_t search_cutoff_ms = 30 * 1000;
@ -1669,7 +1669,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{MAX_EXTRA_SUFFIX, &max_extra_suffix},
{MAX_CANDIDATES, &max_candidates},
{FACET_QUERY_NUM_TYPOS, &facet_query_num_typos},
{FILTER_CURATED_HITS, &filter_curated_hits_option},
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
@ -1710,6 +1709,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{SYNONYM_PREFIX, &synonym_prefix},
{ENABLE_LAZY_FILTER, &enable_lazy_filter},
{ENABLE_TYPOS_FOR_ALPHA_NUMERICAL_TOKENS, &enable_typos_for_alpha_numerical_tokens},
{FILTER_CURATED_HITS, &filter_curated_hits_option},
};
std::unordered_map<std::string, std::vector<std::string>*> str_list_values = {

View File

@ -410,7 +410,7 @@ TEST_F(CollectionOverrideTest, IncludeHitsFilterOverrides) {
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
4, {off}, 32767, 32767, 2, 0).get();
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ(1, results["hits"].size());
}
@ -4223,3 +4223,101 @@ TEST_F(CollectionOverrideTest, RetrieveOverideByID) {
auto op = coll2->get_override("override1");
ASSERT_TRUE(op.ok());
}
TEST_F(CollectionOverrideTest, FilterPinnedHits) {
std::vector<field> fields = {field("title", field_types::STRING, false),
field("points", field_types::INT32, false)};
Collection* coll3 = collectionManager.get_collection("coll3").get();
if(coll3 == nullptr) {
coll3 = collectionManager.create_collection("coll3", 1, fields, "points").get();
}
nlohmann::json doc;
doc["title"] = "Snapdragon 7 gen 2023";
doc["points"] = 100;
ASSERT_TRUE(coll3->add(doc.dump()).ok());
doc["title"] = "Snapdragon 732G 2023";
doc["points"] = 91;
ASSERT_TRUE(coll3->add(doc.dump()).ok());
doc["title"] = "Snapdragon 4 gen 2023";
doc["points"] = 65;
ASSERT_TRUE(coll3->add(doc.dump()).ok());
doc["title"] = "Mediatek Dimensity 720G 2022";
doc["points"] = 87;
ASSERT_TRUE(coll3->add(doc.dump()).ok());
doc["title"] = "Mediatek Dimensity 470G 2023";
doc["points"] = 63;
ASSERT_TRUE(coll3->add(doc.dump()).ok());
auto pinned_hits = "3:1, 4:2";
bool filter_curated_hits = false;
auto results = coll3->search("2023", {"title"}, "title: snapdragon", {}, {},
{0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10,
"", 30, 5, "",
10, pinned_hits, {}, {}, 3,
"<mark>", "</mark>", {}, UINT_MAX,
true, false, true, "",
false, 6000 * 1000, 4, 7,
fallback, 4, {off}, INT16_MAX,
INT16_MAX, 2, filter_curated_hits ).get();
ASSERT_EQ(5, results["hits"].size());
ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("4", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][3]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][4]["document"]["id"].get<std::string>());
filter_curated_hits = true;
results = coll3->search("2023", {"title"}, "title: snapdragon", {}, {},
{0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10,
"", 30, 5, "",
10, pinned_hits, {}, {}, 3,
"<mark>", "</mark>", {}, UINT_MAX,
true, false, true, "",
false, 6000 * 1000, 4, 7,
fallback, 4, {off}, INT16_MAX,
INT16_MAX, 2, filter_curated_hits).get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][2]["document"]["id"].get<std::string>());
//partial filter out ids, remaining will take higher precedence than their assignment
results = coll3->search("snapdragon", {"title"}, "title: 2023", {}, {},
{0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10,
"", 30, 5, "",
10, pinned_hits, {}, {}, 3,
"<mark>", "</mark>", {}, UINT_MAX,
true, false, true, "",
false, 6000 * 1000, 4, 7,
fallback, 4, {off}, INT16_MAX,
INT16_MAX, 2, filter_curated_hits).get();
ASSERT_EQ(4, results["hits"].size());
ASSERT_EQ("4", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][3]["document"]["id"].get<std::string>());
}