From 9197627e8149ae53d4f87382b88e73dc173e85e0 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 11 Mar 2022 20:52:22 +0530 Subject: [PATCH] Add option to filter curated hits. --- include/collection.h | 5 +- include/index.h | 15 +++--- src/collection.cpp | 35 +++++--------- src/collection_manager.cpp | 8 +++- src/index.cpp | 78 ++++++++++++++++++++++++------- test/collection_override_test.cpp | 37 +++++++++++++++ 6 files changed, 129 insertions(+), 49 deletions(-) diff --git a/include/collection.h b/include/collection.h index 7eaa0d4e..e54eed0c 100644 --- a/include/collection.h +++ b/include/collection.h @@ -217,7 +217,7 @@ private: void curate_results(string& actual_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, const std::vector& hidden_hits, - std::map>& include_ids, + std::vector>& included_ids, std::vector& excluded_ids, std::vector& filter_overrides) const; Option check_and_update_schema(nlohmann::json& document, const DIRTY_VALUES& dirty_values); @@ -407,7 +407,8 @@ public: const std::vector& infixes = {off}, const size_t max_extra_prefix = INT16_MAX, const size_t max_extra_suffix = INT16_MAX, - const size_t facet_query_num_typos = 2) const; + const size_t facet_query_num_typos = 2, + const bool filter_curated_hits = false) const; Option get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids); diff --git a/include/index.h b/include/index.h index d3360553..497bd5af 100644 --- a/include/index.h +++ b/include/index.h @@ -274,7 +274,7 @@ struct search_args { std::vector search_fields; std::vector filters; std::vector& facets; - std::map> included_ids; + std::vector>& included_ids; std::vector excluded_ids; std::vector sort_fields_std; facet_query_t facet_query; @@ -302,6 +302,7 @@ struct search_args { const size_t max_extra_prefix; const size_t max_extra_suffix; const size_t facet_query_num_typos; + const bool filter_curated_hits; spp::sparse_hash_set groups_processed; std::vector> searched_queries; @@ -312,7 +313,7 @@ struct search_args { search_args(std::vector field_query_tokens, std::vector search_fields, std::vector filters, std::vector& facets, - std::map> included_ids, std::vector excluded_ids, + std::vector>& included_ids, std::vector excluded_ids, std::vector sort_fields_std, facet_query_t facet_query, const std::vector& num_typos, size_t max_facet_values, size_t max_hits, size_t per_page, size_t page, token_ordering token_order, const std::vector& prefixes, size_t drop_tokens_threshold, size_t typo_tokens_threshold, @@ -320,7 +321,8 @@ struct search_args { const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search, size_t concurrency, const std::vector& dynamic_overrides, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, - const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos) : + const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos, + const bool filter_curated_hits) : field_query_tokens(field_query_tokens), search_fields(search_fields), filters(filters), facets(facets), included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std), @@ -333,7 +335,7 @@ struct search_args { filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms), min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates), infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix), - facet_query_num_typos(facet_query_num_typos) { + facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits) { const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory topster = new Topster(topster_size, group_limit); @@ -681,7 +683,7 @@ public: void search(std::vector& field_query_tokens, const std::vector& the_fields, std::vector& filters, std::vector& facets, facet_query_t& facet_query, - const std::map>& included_ids_map, + const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, const size_t per_page, @@ -694,7 +696,8 @@ public: const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search, size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, - const size_t max_extra_suffix, const size_t facet_query_num_typos) const; + const size_t max_extra_suffix, const size_t facet_query_num_typos, + const bool filter_curated_hits) const; Option remove(const uint32_t seq_id, const nlohmann::json & document, const bool is_update); diff --git a/src/collection.cpp b/src/collection.cpp index 23b67eca..7cd2faf4 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -398,7 +398,7 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash void Collection::curate_results(string& actual_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, const std::vector& hidden_hits, - std::map>& include_ids, + std::vector>& included_ids, std::vector& excluded_ids, std::vector& filter_overrides) const { @@ -452,7 +452,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo uint32_t seq_id = seq_id_op.get(); bool excluded = (excluded_set.count(seq_id) != 0); if(!excluded) { - include_ids[hit.position].push_back(seq_id); + included_ids.emplace_back(seq_id, hit.position); } } @@ -480,7 +480,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo uint32_t seq_id = seq_id_op.get(); bool excluded = (excluded_set.count(seq_id) != 0); if(!excluded) { - include_ids[pos].push_back(seq_id); + included_ids.emplace_back(seq_id, pos); } } } @@ -698,7 +698,8 @@ Option Collection::search(const std::string & raw_query, const s const std::vector& infixes, const size_t max_extra_prefix, const size_t max_extra_suffix, - const size_t facet_query_num_typos) const { + const size_t facet_query_num_typos, + const bool filter_curated_hits) const { std::shared_lock lock(mutex); @@ -919,7 +920,7 @@ Option Collection::search(const std::string & raw_query, const s spp::sparse_hash_set groups_processed; // used to calculate total_found for grouped query std::vector excluded_ids; - std::map> include_ids; // position => list of IDs + std::vector> included_ids; // ID -> position std::map> pinned_hits; Option pinned_hits_op = parse_pinned_hits(pinned_hits_str, pinned_hits); @@ -934,9 +935,9 @@ Option Collection::search(const std::string & raw_query, const s std::vector filter_overrides; std::string query = raw_query; curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, - include_ids, excluded_ids, filter_overrides); + included_ids, excluded_ids, filter_overrides); - /*for(auto& kv: include_ids) { + /*for(auto& kv: included_ids) { LOG(INFO) << "key: " << kv.first; for(auto val: kv.second) { LOG(INFO) << val; @@ -949,8 +950,8 @@ Option Collection::search(const std::string & raw_query, const s LOG(INFO) << id; } - LOG(INFO) << "include_ids size: " << include_ids.size(); - for(auto& group: include_ids) { + LOG(INFO) << "included_ids size: " << included_ids.size(); + for(auto& group: included_ids) { for(uint32_t& seq_id: group.second) { LOG(INFO) << "seq_id: " << seq_id; } @@ -959,19 +960,6 @@ Option Collection::search(const std::string & raw_query, const s } */ - std::map> included_ids; - - for(const auto& pos_ids: include_ids) { - size_t outer_pos = pos_ids.first; - size_t ids_per_pos = std::max(size_t(1), group_limit); - - for(size_t inner_pos = 0; inner_pos < std::min(ids_per_pos, pos_ids.second.size()); inner_pos++) { - auto seq_id = pos_ids.second[inner_pos]; - included_ids[outer_pos][inner_pos] = seq_id; - //LOG(INFO) << "Adding seq_id " << seq_id << " to index_id " << index_id; - } - } - //LOG(INFO) << "Num indices used for querying: " << indices.size(); std::vector field_query_tokens; std::vector q_tokens; // used for auxillary highlighting @@ -1033,7 +1021,8 @@ Option Collection::search(const std::string & raw_query, const s exhaustive_search, 4, filter_overrides, search_stop_millis, min_len_1typo, min_len_2typo, max_candidates, infixes, - max_extra_prefix, max_extra_suffix, facet_query_num_typos); + max_extra_prefix, max_extra_suffix, facet_query_num_typos, + filter_curated_hits); index->run_search(search_params); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 03601e2e..0eeacbb3 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -613,6 +613,8 @@ Option CollectionManager::do_search(std::map& re const char *PINNED_HITS = "pinned_hits"; const char *HIDDEN_HITS = "hidden_hits"; const char *ENABLE_OVERRIDES = "enable_overrides"; + const char *FILTER_CURATED_HITS = "filter_curated_hits"; + const char *MAX_CANDIDATES = "max_candidates"; const char *INFIX = "infix"; @@ -703,6 +705,7 @@ Option CollectionManager::do_search(std::map& re bool prioritize_exact_match = true; bool pre_segmented_query = false; bool enable_overrides = true; + bool filter_curated_hits = false; std::string highlight_fields; bool exhaustive_search = false; size_t search_stop_millis; @@ -749,6 +752,7 @@ Option CollectionManager::do_search(std::map& re {EXHAUSTIVE_SEARCH, &exhaustive_search}, {SPLIT_JOIN_TOKENS, &split_join_tokens}, {ENABLE_OVERRIDES, &enable_overrides}, + {FILTER_CURATED_HITS, &filter_curated_hits}, }; std::unordered_map*> str_list_values = { @@ -892,7 +896,9 @@ Option CollectionManager::do_search(std::map& re max_candidates, infixes, max_extra_prefix, - max_extra_suffix + max_extra_suffix, + facet_query_num_typos, + filter_curated_hits ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/index.cpp b/src/index.cpp index 3dd10e15..8365dc73 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1653,7 +1653,8 @@ void Index::run_search(search_args* search_params) { search_params->infixes, search_params->max_extra_prefix, search_params->max_extra_suffix, - search_params->facet_query_num_typos); + search_params->facet_query_num_typos, + search_params->filter_curated_hits); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -1666,7 +1667,7 @@ void Index::collate_included_ids(const std::vector& q_included_toke return; } - // calculate match_score and add to topster independently + // created searched queries so that curated results can be highlighted std::vector override_query; @@ -2084,7 +2085,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name void Index::search(std::vector& field_query_tokens, const std::vector& the_fields, std::vector& filters, std::vector& facets, facet_query_t& facet_query, - const std::map>& included_ids_map, + const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, const size_t per_page, @@ -2099,7 +2100,8 @@ void Index::search(std::vector& field_query_tokens, const std::v const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search, size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, - const size_t max_extra_suffix, const size_t facet_query_num_typos) const { + const size_t max_extra_suffix, const size_t facet_query_num_typos, + const bool filter_curated_hits) const { search_begin = std::chrono::high_resolution_clock::now(); search_stop_ms = search_cutoff_ms; @@ -2112,26 +2114,68 @@ void Index::search(std::vector& field_query_tokens, const std::v std::shared_lock lock(mutex); - // we will be removing all curated IDs from organic result ids before running topster - std::set curated_ids; - std::vector included_ids; + process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters); + do_filtering(filter_ids, filter_ids_length, filters, true); - for(const auto& outer_pos_ids: included_ids_map) { - for(const auto& inner_pos_seq_id: outer_pos_ids.second) { - curated_ids.insert(inner_pos_seq_id.second); - included_ids.push_back(inner_pos_seq_id.second); + std::vector included_ids_vec; + for(const auto& seq_id_pos: included_ids) { + included_ids_vec.push_back(seq_id_pos.first); + } + std::sort(included_ids_vec.begin(), included_ids_vec.end()); + + std::map> included_ids_map; // outer pos => inner pos => list of IDs + + // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition + std::set included_ids_set; + + if(filter_ids_length != 0 && filter_curated_hits) { + uint32_t* included_ids_arr = nullptr; + size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids, + filter_ids_length, &included_ids_arr); + + included_ids_vec.clear(); + + for(size_t i = 0; i < included_ids_len; i++) { + included_ids_set.insert(included_ids_arr[i]); + included_ids_vec.push_back(included_ids_arr[i]); + } + + delete [] included_ids_arr; + } else { + included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end()); + } + + std::map> included_ids_grouped; + + for(const auto& seq_id_pos: included_ids) { + if(included_ids_set.count(seq_id_pos.first) == 0) { + continue; + } + included_ids_grouped[seq_id_pos.second].push_back(seq_id_pos.first); + } + + for(const auto& pos_ids: included_ids_grouped) { + size_t outer_pos = pos_ids.first; + size_t ids_per_pos = std::max(size_t(1), group_limit); + + for(size_t inner_pos = 0; inner_pos < std::min(ids_per_pos, pos_ids.second.size()); inner_pos++) { + auto seq_id = pos_ids.second[inner_pos]; + included_ids_map[outer_pos][inner_pos] = seq_id; } } + std::set curated_ids; curated_ids.insert(excluded_ids.begin(), excluded_ids.end()); + for(const auto& outer_pos_inner_pos_ids: included_ids_map) { + for(const auto& inner_pos_ids: outer_pos_inner_pos_ids.second) { + curated_ids.insert(inner_pos_ids.second); + } + } + std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); - process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters); - - do_filtering(filter_ids, filter_ids_length, filters, true); - // Order of `fields` are used to sort results //auto begin = std::chrono::high_resolution_clock::now(); uint32_t* all_result_ids = nullptr; @@ -2305,8 +2349,8 @@ void Index::search(std::vector& field_query_tokens, const std::v std::vector facet_infos(facets.size()); compute_facet_infos(facets, facet_query, facet_query_num_typos, - &included_ids[0], included_ids.size(), group_by_fields, facet_infos); - do_facets(facets, facet_query, facet_infos, group_limit, group_by_fields, &included_ids[0], included_ids.size()); + &included_ids_vec[0], included_ids_vec.size(), group_by_fields, facet_infos); + do_facets(facets, facet_query, facet_infos, group_limit, group_by_fields, &included_ids_vec[0], included_ids_vec.size()); all_result_ids_len += curated_topster->size; diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index debb0a91..d2d5f2a3 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -425,8 +425,45 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) { ASSERT_STREQ("16", results["hits"][3]["document"]["id"].get().c_str()); ASSERT_STREQ("6", results["hits"][4]["document"]["id"].get().c_str()); + // pinning + filtering + results = coll_mul_fields->search("of", {"title"}, "points:>58", {}, {}, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + pinned_hits, {}).get(); + + ASSERT_EQ(5, results["found"].get()); + ASSERT_STREQ("13", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("11", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_STREQ("12", results["hits"][3]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][4]["document"]["id"].get().c_str()); + + // pinning + filtering with filter_curated_hits: true + pinned_hits = "14:1,4:2"; + + results = coll_mul_fields->search("of", {"title"}, "points:>58", {}, {}, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, pinned_hits, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true, + 4, {off}, 32767, 32767, 2, true).get(); + + ASSERT_EQ(4, results["found"].get()); + ASSERT_STREQ("14", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("11", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("12", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][3]["document"]["id"].get().c_str()); + + ASSERT_EQ("The Silence of the Lambs", results["hits"][1]["highlights"][0]["snippet"].get()); + ASSERT_EQ("Confessions of a Shopaholic", results["hits"][2]["highlights"][0]["snippet"].get()); + ASSERT_EQ("Percy Jackson: Sea of Monsters", results["hits"][3]["highlights"][0]["snippet"].get()); + // both pinning and hiding + pinned_hits = "13:1,4:2"; std::string hidden_hits="11,16"; results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, {0}, 50, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD,