From 0b904a3d6befb39026f47b90a800b4bbc7a3ba04 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Mon, 24 Apr 2023 17:23:27 +0530 Subject: [PATCH] Members of a curated group should not repeat. --- include/index.h | 13 ++- src/index.cpp | 69 ++++++++++--- test/collection_override_test.cpp | 166 ++++++++++++++++++++++++++++++ 3 files changed, 231 insertions(+), 17 deletions(-) diff --git a/include/index.h b/include/index.h index 53c54995..8a622960 100644 --- a/include/index.h +++ b/include/index.h @@ -414,6 +414,7 @@ private: const std::vector& the_fields, const uint32_t* filter_ids, size_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, std::vector& token_candidates_vec, std::vector>& searched_queries, @@ -737,7 +738,7 @@ public: std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, - size_t exclude_token_ids_size, + size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, uint32_t filter_ids_length, const size_t concurrency, const int* sort_order, @@ -784,6 +785,7 @@ public: std::array*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, + const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed) const; @@ -802,6 +804,7 @@ public: size_t min_len_2typo, const size_t max_candidates, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + const std::unordered_set& excluded_group_ids, Topster* actual_topster, std::vector>& q_pos_synonyms, int syn_orig_num_tokens, @@ -829,6 +832,7 @@ public: spp::sparse_hash_map& 
groups_processed, const std::set& curated_ids, const uint32_t* excluded_result_ids, size_t excluded_result_ids_size, + const std::unordered_set& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, bool is_wildcard_query, @@ -842,6 +846,7 @@ public: size_t exclude_token_ids_size, const uint32_t* filter_ids, size_t filter_ids_length, const std::vector& curated_ids, + const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, const std::vector& num_typos, std::vector>& searched_queries, @@ -895,6 +900,7 @@ public: const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + const std::unordered_set& excluded_group_ids, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices, @@ -940,11 +946,12 @@ public: void process_curated_ids(const std::vector>& included_ids, - const std::vector& excluded_ids, + const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, const bool filter_curated_hits, const uint32_t* filter_ids, uint32_t filter_ids_length, std::set& curated_ids, std::map>& included_ids_map, - std::vector& included_ids_vec) const; + std::vector& included_ids_vec, + std::unordered_set& excluded_group_ids) const; }; template diff --git a/src/index.cpp b/src/index.cpp index e6305315..46595c85 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1265,6 +1265,7 @@ void Index::search_all_candidates(const size_t num_search_fields, const std::vector& the_fields, const uint32_t* filter_ids, size_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, std::vector& token_candidates_vec, std::vector>& searched_queries, @@ -1332,7 +1333,7 @@ void Index::search_all_candidates(const size_t num_search_fields, searched_queries, qtoken_set, dropped_tokens, group_limit, group_by_fields, prioritize_exact_match, 
prioritize_token_position, filter_ids, filter_ids_length, total_cost, syn_orig_num_tokens, - exclude_token_ids, exclude_token_ids_size, + exclude_token_ids, exclude_token_ids_size, excluded_group_ids, sort_order, field_values, geopoint_indices, id_buff, all_result_ids, all_result_ids_len); @@ -2726,8 +2727,11 @@ Option Index::search(std::vector& field_query_tokens, cons std::set curated_ids; std::map> included_ids_map; // outer pos => inner pos => list of IDs std::vector included_ids_vec; - process_curated_ids(included_ids, excluded_ids, group_limit, filter_curated_hits, - filter_result.docs, filter_result.count, curated_ids, included_ids_map, included_ids_vec); + std::unordered_set excluded_group_ids; + + process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits, + filter_result.docs, filter_result.count, curated_ids, included_ids_map, + included_ids_vec, excluded_group_ids); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2764,7 +2768,7 @@ Option Index::search(std::vector& field_query_tokens, cons sort_fields_std, searched_queries, group_limit, group_by_fields, topster, sort_order, field_values, geopoint_indices, curated_ids_sorted, all_result_ids, all_result_ids_len, groups_processed, curated_ids, - excluded_result_ids, excluded_result_ids_size, curated_topster, + excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, included_ids_map, is_wildcard_query, filter_result.docs, filter_result.count); if (filter_result.count == 0) { @@ -2788,6 +2792,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint64_t distinct_id = seq_id; if (group_limit != 0) { distinct_id = get_distinct_id(group_by_fields, seq_id); + if(excluded_group_ids.count(distinct_id) != 0) { + continue; + } } int64_t scores[3] = {0}; @@ -2884,6 +2891,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint64_t 
distinct_id = seq_id; if (group_limit != 0) { distinct_id = get_distinct_id(group_by_fields, seq_id); + if(excluded_group_ids.count(distinct_id) != 0) { + continue; + } } auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : @@ -2918,7 +2928,7 @@ Option Index::search(std::vector& field_query_tokens, cons search_wildcard(filter_tree_root, included_ids_map, sort_fields_std, topster, curated_topster, groups_processed, searched_queries, group_limit, group_by_fields, curated_ids, curated_ids_sorted, - excluded_result_ids, excluded_result_ids_size, + excluded_result_ids, excluded_result_ids_size, excluded_group_ids, all_result_ids, all_result_ids_len, filter_result.docs, filter_result.count, concurrency, sort_order, field_values, geopoint_indices); @@ -2963,6 +2973,7 @@ Option Index::search(std::vector& field_query_tokens, cons fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, {}, match_type, excluded_result_ids, excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, + excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, @@ -3000,6 +3011,7 @@ Option Index::search(std::vector& field_query_tokens, cons fuzzy_search_fields(the_fields, resolved_tokens, {}, match_type, excluded_result_ids, excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, + excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, @@ -3013,7 +3025,8 @@ Option Index::search(std::vector& field_query_tokens, cons 0, group_limit, 
group_by_fields, prioritize_exact_match, prioritize_token_position, exhaustive_search, concurrency, prefixes, min_len_1typo, min_len_2typo, max_candidates, curated_ids, curated_ids_sorted, - excluded_result_ids, excluded_result_ids_size, topster, q_pos_synonyms, syn_orig_num_tokens, + excluded_result_ids, excluded_result_ids_size, excluded_group_ids, + topster, q_pos_synonyms, syn_orig_num_tokens, groups_processed, searched_queries, all_result_ids, all_result_ids_len, filter_result.docs, filter_result.count, query_hashes, sort_order, field_values, geopoint_indices, @@ -3067,7 +3080,7 @@ Option Index::search(std::vector& field_query_tokens, cons fuzzy_search_fields(the_fields, truncated_tokens, dropped_tokens, match_type, excluded_result_ids, excluded_result_ids_size, filter_result.docs, filter_result.count, - curated_ids_sorted, sort_fields_std, num_typos, searched_queries, + curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, @@ -3088,7 +3101,7 @@ Option Index::search(std::vector& field_query_tokens, cons field_query_tokens[0].q_include_tokens, topster, filter_result.docs, filter_result.count, sort_order, field_values, geopoint_indices, - curated_ids_sorted, all_result_ids, all_result_ids_len, groups_processed); + curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed); if(!vector_query.field_name.empty()) { // check at least one of sort fields is text match @@ -3324,15 +3337,26 @@ Option Index::search(std::vector& field_query_tokens, cons } void Index::process_curated_ids(const std::vector>& included_ids, - const std::vector& excluded_ids, const size_t group_limit, + const std::vector& excluded_ids, + const std::vector& group_by_fields, const size_t group_limit, const bool filter_curated_hits, const uint32_t* filter_ids, 
uint32_t filter_ids_length, std::set& curated_ids, std::map>& included_ids_map, - std::vector& included_ids_vec) const { + std::vector& included_ids_vec, + std::unordered_set& excluded_group_ids) const { for(const auto& seq_id_pos: included_ids) { included_ids_vec.push_back(seq_id_pos.first); } + + if(group_limit != 0) { + // if one `id` of a group is present in curated hits, we have to exclude that entire group from results + for(auto seq_id: included_ids_vec) { + uint64_t distinct_id = get_distinct_id(group_by_fields, seq_id); + excluded_group_ids.emplace(distinct_id); + } + } + std::sort(included_ids_vec.begin(), included_ids_vec.end()); // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition @@ -3410,6 +3434,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, size_t exclude_token_ids_size, const uint32_t* filter_ids, size_t filter_ids_length, const std::vector& curated_ids, + const std::unordered_set& excluded_group_ids, const std::vector & sort_fields, const std::vector& num_typos, std::vector> & searched_queries, @@ -3658,7 +3683,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, if(token_candidates_vec.size() == query_tokens.size()) { std::vector id_buff; search_all_candidates(num_search_fields, match_type, the_fields, filter_ids, filter_ids_length, - exclude_token_ids, exclude_token_ids_size, + exclude_token_ids, exclude_token_ids_size, excluded_group_ids, sort_fields, token_candidates_vec, searched_queries, qtoken_set, dropped_tokens, topster, groups_processed, all_result_ids, all_result_ids_len, @@ -3822,6 +3847,7 @@ void Index::search_across_fields(const std::vector& query_tokens, const uint32_t* filter_ids, uint32_t filter_ids_length, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + const std::unordered_set& excluded_group_ids, const int* sort_order, std::array*, 3>& field_values, const std::vector& 
geopoint_indices, @@ -4006,6 +4032,9 @@ void Index::search_across_fields(const std::vector& query_tokens, uint64_t distinct_id = seq_id; if(group_limit != 0) { distinct_id = get_distinct_id(group_by_fields, seq_id); + if(excluded_group_ids.count(distinct_id) != 0) { + return; + } } int64_t scores[3] = {0}; @@ -4342,6 +4371,7 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector& groups_processed, const std::set& curated_ids, const uint32_t* excluded_result_ids, size_t excluded_result_ids_size, + const std::unordered_set& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, bool is_wildcard_query, @@ -4484,6 +4514,9 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector& the_fields, size_t min_len_2typo, const size_t max_candidates, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + const std::unordered_set& excluded_group_ids, Topster* actual_topster, std::vector>& q_pos_synonyms, int syn_orig_num_tokens, @@ -4531,7 +4565,7 @@ void Index::do_synonym_search(const std::vector& the_fields, for (const auto& syn_tokens : q_pos_synonyms) { query_hashes.clear(); fuzzy_search_fields(the_fields, syn_tokens, {}, match_type, exclude_token_ids, - exclude_token_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, + exclude_token_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, excluded_group_ids, sort_fields_std, {0}, searched_queries, qtoken_set, actual_topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, @@ -4554,6 +4588,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, + const std::unordered_set& excluded_group_ids, 
uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed) const { @@ -4612,6 +4647,9 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, - size_t exclude_token_ids_size, + size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, uint32_t filter_ids_length, const size_t concurrency, const int* sort_order, @@ -4921,7 +4959,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, thread_pool->enqueue([this, &parent_search_begin, &parent_search_stop_ms, &parent_search_cutoff, thread_id, &sort_fields, &searched_queries, - &group_limit, &group_by_fields, &topsters, &tgroups_processed, + &group_limit, &group_by_fields, &topsters, &tgroups_processed, &excluded_group_ids, &sort_order, field_values, &geopoint_indices, &plists, check_for_circuit_break, batch_result_ids, batch_res_len, @@ -4949,6 +4987,9 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, uint64_t distinct_id = seq_id; if(group_limit != 0) { distinct_id = get_distinct_id(group_by_fields, seq_id); + if(excluded_group_ids.count(distinct_id) != 0) { + continue; + } } KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index f3ca16a1..b602c777 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -1118,6 +1118,105 @@ TEST_F(CollectionOverrideTest, FilterRule) { ASSERT_EQ(0, override_json_ser["rule"].count("match")); } +TEST_F(CollectionOverrideTest, CurationGroupingNonCuratedHitsShouldNotAppearOutside) { + Collection *coll1; + + std::vector fields = {field("title", 
field_types::STRING, false), + field("group_id", field_types::STRING, true),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 3, fields).get(); + } + + nlohmann::json doc; + doc["id"] = "1"; + doc["title"] = "The Harry Potter 1"; + doc["group_id"] = "hp"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["id"] = "2"; + doc["title"] = "The Harry Potter 2"; + doc["group_id"] = "hp"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["id"] = "3"; + doc["title"] = "Lord of the Rings"; + doc["group_id"] = "lotr"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + nlohmann::json override_json = R"({ + "id": "rule-1", + "rule": { + "query": "*", + "match": "exact" + }, + "includes": [{ + "id": "2", + "position": 1 + }] + })"_json; + + override_t override_rule; + auto op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + override_json = R"({ + "id": "rule-2", + "rule": { + "query": "the", + "match": "exact" + }, + "includes": [{ + "id": "2", + "position": 1 + }] + })"_json; + + override_t override_rule2; + op = override_t::parse(override_json, "rule-2", override_rule2); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule2); + + auto results = coll1->search("*", {"title"}, "", {}, {}, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + "", {}, {"group_id"}, 2).get(); + + // when only one of the 2 records belonging to the same group is used for curation, the other record + // should not reappear + + ASSERT_EQ(2, results["found"].get()); + + ASSERT_EQ(1, results["grouped_hits"][0]["hits"].size()); + ASSERT_EQ(1, results["grouped_hits"][1]["hits"].size()); + + ASSERT_EQ("2", results["grouped_hits"][0]["hits"][0]["document"]["id"].get()); + ASSERT_EQ("3", 
results["grouped_hits"][1]["hits"][0]["document"]["id"].get()); + + // same for keyword search + results = coll1->search("the", {"title"}, "", {}, {}, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + "", {}, {"group_id"}, 2).get(); + + // when only one of the 2 records belonging to the same group is used for curation, the other record + // should not reappear + + ASSERT_EQ(2, results["found"].get()); + + ASSERT_EQ(1, results["grouped_hits"][0]["hits"].size()); + ASSERT_EQ(1, results["grouped_hits"][1]["hits"].size()); + + ASSERT_EQ("2", results["grouped_hits"][0]["hits"][0]["document"]["id"].get()); + ASSERT_EQ("3", results["grouped_hits"][1]["hits"][0]["document"]["id"].get()); +} + TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) { auto pinned_hits = "13:1,4:2"; @@ -1461,6 +1560,73 @@ TEST_F(CollectionOverrideTest, PinnedHitsGrouping) { ASSERT_STREQ("16", results["grouped_hits"][4]["hits"][0]["document"]["id"].get().c_str()); } +TEST_F(CollectionOverrideTest, PinnedHitsGroupingNonPinnedHitsShouldNotAppearOutside) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("group_id", field_types::STRING, true),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 3, fields).get(); + } + + nlohmann::json doc; + doc["id"] = "1"; + doc["title"] = "The Harry Potter 1"; + doc["group_id"] = "hp"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["id"] = "2"; + doc["title"] = "The Harry Potter 2"; + doc["group_id"] = "hp"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + doc["id"] = "3"; + doc["title"] = "Lord of the Rings"; + doc["group_id"] = "lotr"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto pinned_hits = "2:1"; + + auto results = coll1->search("*", {"title"}, "", {}, {}, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, 
+ spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + pinned_hits, {}, {"group_id"}, 2).get(); + + // when only one of the 2 records belonging to the same group is used for curation, the other record + // should not reappear + + ASSERT_EQ(2, results["found"].get()); + + ASSERT_EQ(1, results["grouped_hits"][0]["hits"].size()); + ASSERT_EQ(1, results["grouped_hits"][1]["hits"].size()); + + ASSERT_EQ("2", results["grouped_hits"][0]["hits"][0]["document"]["id"].get()); + ASSERT_EQ("3", results["grouped_hits"][1]["hits"][0]["document"]["id"].get()); + + // same for keyword search + results = coll1->search("the", {"title"}, "", {}, {}, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + pinned_hits, {}, {"group_id"}, 2).get(); + + // when only one of the 2 records belonging to the same group is used for curation, the other record + // should not reappear + + ASSERT_EQ(2, results["found"].get()); + + ASSERT_EQ(1, results["grouped_hits"][0]["hits"].size()); + ASSERT_EQ(1, results["grouped_hits"][1]["hits"].size()); + + ASSERT_EQ("2", results["grouped_hits"][0]["hits"][0]["document"]["id"].get()); + ASSERT_EQ("3", results["grouped_hits"][1]["hits"][0]["document"]["id"].get()); +} + TEST_F(CollectionOverrideTest, PinnedHitsWithWildCardQuery) { Collection *coll1;