From 93c31be88f56599a4dbf554f45d46038de835d1e Mon Sep 17 00:00:00 2001
From: Kishore Nallan
Date: Fri, 15 Apr 2022 15:38:05 +0530
Subject: [PATCH] Cover missing curated hits with hits below.

---
Notes (below the `---` marker, so not part of the commit message):
- This copy was re-flowed from a whitespace-collapsed extraction in which every
  `<...>` span was lost. Template arguments were restored from how each value is
  used; context-line text and indentation are a best-effort reconstruction and
  must be checked against the tree before applying.
- Changes relative to the extracted patch, all in Index::process_curated_ids:
  `included_ids_vec.data()` replaces `&included_ids_vec[0]` (undefined on an
  empty vector), the unused `pos_count` local is gone, `std::unique` is
  qualified, the remapped positions are swapped in instead of copied, and the
  function is documented. Hunk sizes and the diffstat are updated; the blob ids
  on the `index` lines are not.

 include/index.h                   |   8 ++
 src/index.cpp                     | 149 ++++++++++++++++++++----------
 test/collection_override_test.cpp |  85 ++++++++++++++++-
 3 files changed, 187 insertions(+), 55 deletions(-)

diff --git a/include/index.h b/include/index.h
index 535e4caa..a001dec2 100644
--- a/include/index.h
+++ b/include/index.h
@@ -992,6 +992,14 @@ public:
                              const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
                              int64_t max_field_match_score,
                              int64_t* scores, int64_t& match_score_index) const;
+
+    void
+    process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
+                        const std::vector<uint32_t>& excluded_ids,
+                        const size_t group_limit, const bool filter_curated_hits, const uint32_t* filter_ids,
+                        uint32_t filter_ids_length, std::set<uint32_t>& curated_ids,
+                        std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
+                        std::vector<uint32_t>& included_ids_vec) const;
 };
 
 template<class T>
diff --git a/src/index.cpp b/src/index.cpp
index 5ec0bd95..c57d247d 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -2163,61 +2163,11 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
         return ;
     }
 
-    std::vector<uint32_t> included_ids_vec;
-    for(const auto& seq_id_pos: included_ids) {
-        included_ids_vec.push_back(seq_id_pos.first);
-    }
-    std::sort(included_ids_vec.begin(), included_ids_vec.end());
-
-    std::map<size_t, std::map<size_t, uint32_t>> included_ids_map;  // outer pos => inner pos => list of IDs
-
-    // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition
-    std::set<uint32_t> included_ids_set;
-
-    if(filter_ids_length != 0 && filter_curated_hits) {
-        uint32_t* included_ids_arr = nullptr;
-        size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids,
-                                                         filter_ids_length, &included_ids_arr);
-
-        included_ids_vec.clear();
-
-        for(size_t i = 0; i < included_ids_len; i++) {
-            included_ids_set.insert(included_ids_arr[i]);
-            included_ids_vec.push_back(included_ids_arr[i]);
-        }
-
-        delete [] included_ids_arr;
-    } else {
-        included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end());
-    }
-
-    std::map<size_t, std::vector<uint32_t>> included_ids_grouped;
-
-    for(const auto& seq_id_pos: included_ids) {
-        if(included_ids_set.count(seq_id_pos.first) == 0) {
-            continue;
-        }
-        included_ids_grouped[seq_id_pos.second].push_back(seq_id_pos.first);
-    }
-
-    for(const auto& pos_ids: included_ids_grouped) {
-        size_t outer_pos = pos_ids.first;
-        size_t ids_per_pos = std::max(size_t(1), group_limit);
-
-        for(size_t inner_pos = 0; inner_pos < std::min(ids_per_pos, pos_ids.second.size()); inner_pos++) {
-            auto seq_id = pos_ids.second[inner_pos];
-            included_ids_map[outer_pos][inner_pos] = seq_id;
-        }
-    }
-
     std::set<uint32_t> curated_ids;
-    curated_ids.insert(excluded_ids.begin(), excluded_ids.end());
-
-    for(const auto& outer_pos_inner_pos_ids: included_ids_map) {
-        for(const auto& inner_pos_ids: outer_pos_inner_pos_ids.second) {
-            curated_ids.insert(inner_pos_ids.second);
-        }
-    }
+    std::map<size_t, std::map<size_t, uint32_t>> included_ids_map;  // outer pos => inner pos => list of IDs
+    std::vector<uint32_t> included_ids_vec;
+    process_curated_ids(included_ids, excluded_ids, group_limit, filter_curated_hits,
+                        filter_ids, filter_ids_length, curated_ids, included_ids_map, included_ids_vec);
 
     std::vector<uint32_t> curated_ids_sorted(curated_ids.begin(), curated_ids.end());
     std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end());
@@ -2540,6 +2490,97 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
     //LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms";
 }
 
+// Resolves the curated hits of a search: pinned hits (`included_ids`) and hidden hits (`excluded_ids`).
+//
+// included_ids: (seq_id, position) pairs of the pinned hits.
+// excluded_ids: seq_ids of the hidden hits; all of them are added to `curated_ids`.
+// group_limit: number of pinned IDs kept per position (at least 1 is always kept).
+// filter_curated_hits, filter_ids, filter_ids_length: when the flag is set and `filter_ids_length` is
+//     non-zero, pinned IDs missing from the intersection with `filter_ids` are dropped.
+// curated_ids [out]: receives every kept pinned ID and every excluded ID.
+// included_ids_map [out]: outer position => inner position => seq_id of the kept pinned hits; positions
+//     vacated by dropped hits are filled by shifting lower placed hits upwards.
+// included_ids_vec [out]: the pinned seq_ids that remain after the optional filtering.
+void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
+                                const std::vector<uint32_t>& excluded_ids, const size_t group_limit,
+                                const bool filter_curated_hits, const uint32_t* filter_ids, uint32_t filter_ids_length,
+                                std::set<uint32_t>& curated_ids,
+                                std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
+                                std::vector<uint32_t>& included_ids_vec) const {
+
+    for(const auto& seq_id_pos: included_ids) {
+        included_ids_vec.push_back(seq_id_pos.first);
+    }
+    std::sort(included_ids_vec.begin(), included_ids_vec.end());
+
+    // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition
+    std::set<uint32_t> included_ids_set;
+
+    if(filter_ids_length != 0 && filter_curated_hits) {
+        // data() rather than &included_ids_vec[0]: indexing an empty vector is undefined behaviour
+        uint32_t* included_ids_arr = nullptr;
+        size_t included_ids_len = ArrayUtils::and_scalar(included_ids_vec.data(), included_ids_vec.size(), filter_ids,
+                                                         filter_ids_length, &included_ids_arr);
+
+        included_ids_vec.clear();
+
+        for(size_t i = 0; i < included_ids_len; i++) {
+            included_ids_set.insert(included_ids_arr[i]);
+            included_ids_vec.push_back(included_ids_arr[i]);
+        }
+
+        delete [] included_ids_arr;
+    } else {
+        included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end());
+    }
+
+    std::map<size_t, std::vector<uint32_t>> included_ids_grouped;     // pos -> seq_ids
+    std::vector<uint32_t> all_positions;
+
+    for(const auto& seq_id_pos: included_ids) {
+        all_positions.push_back(seq_id_pos.second);
+        if(included_ids_set.count(seq_id_pos.first) == 0) {
+            continue;
+        }
+        included_ids_grouped[seq_id_pos.second].push_back(seq_id_pos.first);
+    }
+
+
+    for(const auto& pos_ids: included_ids_grouped) {
+        size_t outer_pos = pos_ids.first;
+        size_t ids_per_pos = std::max(size_t(1), group_limit);
+        auto num_inner_ids = std::min(ids_per_pos, pos_ids.second.size());
+
+        for(size_t inner_pos = 0; inner_pos < num_inner_ids; inner_pos++) {
+            auto seq_id = pos_ids.second[inner_pos];
+            included_ids_map[outer_pos][inner_pos] = seq_id;
+            curated_ids.insert(seq_id);
+        }
+    }
+
+    curated_ids.insert(excluded_ids.begin(), excluded_ids.end());
+
+    if(all_positions.size() > included_ids_map.size()) {
+        // Some curated IDs may have been removed via filtering or simply don't exist.
+        // We have to shift lower placed hits upwards to fill those positions.
+        std::sort(all_positions.begin(), all_positions.end());
+        all_positions.erase(std::unique(all_positions.begin(), all_positions.end()), all_positions.end());
+
+        std::map<size_t, std::map<size_t, uint32_t>> new_included_ids_map;
+        auto included_id_it = included_ids_map.begin();
+        auto all_pos_it = all_positions.begin();
+
+        // every key of included_ids_map is also in all_positions, so all_pos_it cannot run past the end
+        while(included_id_it != included_ids_map.end()) {
+            new_included_ids_map[*all_pos_it] = included_id_it->second;
+            ++all_pos_it;
+            ++included_id_it;
+        }
+
+        included_ids_map.swap(new_included_ids_map);
+    }
+}
+
 void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
                                 const std::vector<token_t>& query_tokens,
                                 const uint32_t* exclude_token_ids,
diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp
index 44fbe967..faad8572 100644
--- a/test/collection_override_test.cpp
+++ b/test/collection_override_test.cpp
@@ -518,7 +518,90 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
     coll_mul_fields->remove_override("include-rule");
 }
 
-TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
+TEST_F(CollectionOverrideTest, FilterCuratedHitsSlideToCoverMissingSlots) {
+    // when some of the curated hits are filtered away, lower ranked hits must be pulled up
+    nlohmann::json override_json_include = {
+        {"id", "include-rule"},
+        {
+            "rule", {
+                {"query", "scott"},
+                {"match", override_t::MATCH_EXACT}
+            }
+        }
+    };
+
+    // first 2 hits won't match the filter, 3rd position should float up to position 1
+    override_json_include["includes"] = nlohmann::json::array();
+    override_json_include["includes"][0] = nlohmann::json::object();
+    override_json_include["includes"][0]["id"] = "7";
+    override_json_include["includes"][0]["position"] = 1;
+
+    override_json_include["includes"][1] = nlohmann::json::object();
+    override_json_include["includes"][1]["id"] = "17";
+    override_json_include["includes"][1]["position"] = 2;
+
+    override_json_include["includes"][2] = nlohmann::json::object();
+    override_json_include["includes"][2]["id"] = "10";
+    override_json_include["includes"][2]["position"] = 3;
+
+    override_json_include["filter_curated_hits"] = true;
+
+    override_t override_include;
+    override_t::parse(override_json_include, "", override_include);
+    coll_mul_fields->add_override(override_include);
+
+    auto results = coll_mul_fields->search("scott", {"starring"}, "points:>55", {}, {}, {0}, 10, 1, FREQUENCY,
+                                           {false}, Index::DROP_TOKENS_THRESHOLD,
+                                           spp::sparse_hash_set<std::string>(),
+                                           spp::sparse_hash_set<std::string>(), 10, "").get();
+
+    ASSERT_EQ(3, results["hits"].size());
+    ASSERT_EQ("10", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("11", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("12", results["hits"][2]["document"]["id"].get<std::string>());
+
+    // another curation where there is an ID missing in the middle
+    override_json_include = {
+        {"id", "include-rule"},
+        {
+            "rule", {
+                {"query", "glenn"},
+                {"match", override_t::MATCH_EXACT}
+            }
+        }
+    };
+
+    // middle hit ("10") will not satisfy filter, so "11" will move to position 2
+    override_json_include["includes"] = nlohmann::json::array();
+    override_json_include["includes"][0] = nlohmann::json::object();
+    override_json_include["includes"][0]["id"] = "9";
+    override_json_include["includes"][0]["position"] = 1;
+
+    override_json_include["includes"][1] = nlohmann::json::object();
+    override_json_include["includes"][1]["id"] = "10";
+    override_json_include["includes"][1]["position"] = 2;
+
+    override_json_include["includes"][2] = nlohmann::json::object();
+    override_json_include["includes"][2]["id"] = "11";
+    override_json_include["includes"][2]["position"] = 3;
+
+    override_json_include["filter_curated_hits"] = true;
+
+    override_t override_include2;
+    override_t::parse(override_json_include, "", override_include2);
+    coll_mul_fields->add_override(override_include2);
+
+    results = coll_mul_fields->search("glenn", {"starring"}, "points:[43,86]", {}, {}, {0}, 10, 1, FREQUENCY,
+                                      {false}, Index::DROP_TOKENS_THRESHOLD,
+                                      spp::sparse_hash_set<std::string>(),
+                                      spp::sparse_hash_set<std::string>(), 10, "").get();
+
+    ASSERT_EQ(2, results["hits"].size());
+    ASSERT_EQ("9", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("11", results["hits"][1]["document"]["id"].get<std::string>());
+}
+
+TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) {
     auto pinned_hits = "13:1,4:2";
 
     // basic pinning