Cover missing curated hits with hits below.

This commit is contained in:
Kishore Nallan 2022-04-15 15:38:05 +05:30
parent 21c31de3b8
commit 93c31be88f
3 changed files with 175 additions and 55 deletions

View File

@ -992,6 +992,14 @@ public:
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index) const;
void
process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids,
const size_t group_limit, const bool filter_curated_hits, const uint32_t* filter_ids,
uint32_t filter_ids_length, std::set<uint32_t>& curated_ids,
std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
std::vector<uint32_t>& included_ids_vec) const;
};
template<class T>

View File

@ -2163,61 +2163,11 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
return ;
}
std::vector<uint32_t> included_ids_vec;
for(const auto& seq_id_pos: included_ids) {
included_ids_vec.push_back(seq_id_pos.first);
}
std::sort(included_ids_vec.begin(), included_ids_vec.end());
std::map<size_t, std::map<size_t, uint32_t>> included_ids_map; // outer pos => inner pos => list of IDs
// if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition
std::set<uint32_t> included_ids_set;
if(filter_ids_length != 0 && filter_curated_hits) {
uint32_t* included_ids_arr = nullptr;
size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids,
filter_ids_length, &included_ids_arr);
included_ids_vec.clear();
for(size_t i = 0; i < included_ids_len; i++) {
included_ids_set.insert(included_ids_arr[i]);
included_ids_vec.push_back(included_ids_arr[i]);
}
delete [] included_ids_arr;
} else {
included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end());
}
std::map<size_t, std::vector<uint32_t>> included_ids_grouped;
for(const auto& seq_id_pos: included_ids) {
if(included_ids_set.count(seq_id_pos.first) == 0) {
continue;
}
included_ids_grouped[seq_id_pos.second].push_back(seq_id_pos.first);
}
for(const auto& pos_ids: included_ids_grouped) {
size_t outer_pos = pos_ids.first;
size_t ids_per_pos = std::max(size_t(1), group_limit);
for(size_t inner_pos = 0; inner_pos < std::min(ids_per_pos, pos_ids.second.size()); inner_pos++) {
auto seq_id = pos_ids.second[inner_pos];
included_ids_map[outer_pos][inner_pos] = seq_id;
}
}
std::set<uint32_t> curated_ids;
curated_ids.insert(excluded_ids.begin(), excluded_ids.end());
for(const auto& outer_pos_inner_pos_ids: included_ids_map) {
for(const auto& inner_pos_ids: outer_pos_inner_pos_ids.second) {
curated_ids.insert(inner_pos_ids.second);
}
}
std::map<size_t, std::map<size_t, uint32_t>> included_ids_map; // outer pos => inner pos => list of IDs
std::vector<uint32_t> included_ids_vec;
process_curated_ids(included_ids, excluded_ids, group_limit, filter_curated_hits,
filter_ids, filter_ids_length, curated_ids, included_ids_map, included_ids_vec);
std::vector<uint32_t> curated_ids_sorted(curated_ids.begin(), curated_ids.end());
std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end());
@ -2540,6 +2490,85 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
//LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms";
}
void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const size_t group_limit,
const bool filter_curated_hits, const uint32_t* filter_ids, uint32_t filter_ids_length,
std::set<uint32_t>& curated_ids,
std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
std::vector<uint32_t>& included_ids_vec) const {
for(const auto& seq_id_pos: included_ids) {
included_ids_vec.push_back(seq_id_pos.first);
}
std::sort(included_ids_vec.begin(), included_ids_vec.end());
// if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition
std::set<uint32_t> included_ids_set;
if(filter_ids_length != 0 && filter_curated_hits) {
uint32_t* included_ids_arr = nullptr;
size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids,
filter_ids_length, &included_ids_arr);
included_ids_vec.clear();
for(size_t i = 0; i < included_ids_len; i++) {
included_ids_set.insert(included_ids_arr[i]);
included_ids_vec.push_back(included_ids_arr[i]);
}
delete [] included_ids_arr;
} else {
included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end());
}
std::map<size_t, std::vector<uint32_t>> included_ids_grouped; // pos -> seq_ids
std::vector<uint32_t> all_positions;
for(const auto& seq_id_pos: included_ids) {
all_positions.push_back(seq_id_pos.second);
if(included_ids_set.count(seq_id_pos.first) == 0) {
continue;
}
included_ids_grouped[seq_id_pos.second].push_back(seq_id_pos.first);
}
for(const auto& pos_ids: included_ids_grouped) {
size_t outer_pos = pos_ids.first;
size_t ids_per_pos = std::max(size_t(1), group_limit);
auto num_inner_ids = std::min(ids_per_pos, pos_ids.second.size());
for(size_t inner_pos = 0; inner_pos < num_inner_ids; inner_pos++) {
auto seq_id = pos_ids.second[inner_pos];
included_ids_map[outer_pos][inner_pos] = seq_id;
curated_ids.insert(seq_id);
}
}
curated_ids.insert(excluded_ids.begin(), excluded_ids.end());
if(all_positions.size() > included_ids_map.size()) {
// Some curated IDs may have been removed via filtering or simply don't exist.
// We have to shift lower placed hits upwards to fill those positions.
std::sort(all_positions.begin(), all_positions.end());
all_positions.erase(unique(all_positions.begin(), all_positions.end()), all_positions.end());
size_t pos_count = 0;
std::map<size_t, std::map<size_t, uint32_t>> new_included_ids_map;
auto included_id_it = included_ids_map.begin();
auto all_pos_it = all_positions.begin();
while(included_id_it != included_ids_map.end()) {
new_included_ids_map[*all_pos_it] = included_id_it->second;
all_pos_it++;
included_id_it++;
}
included_ids_map = new_included_ids_map;
}
}
void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const std::vector<token_t>& query_tokens,
const uint32_t* exclude_token_ids,

View File

@ -518,7 +518,90 @@ TEST_F(CollectionOverrideTest, ExcludeIncludeFacetFilterQuery) {
coll_mul_fields->remove_override("include-rule");
}
TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
TEST_F(CollectionOverrideTest, FilterCuratedHitsSlideToCoverMissingSlots) {
// when some of the curated hits are filtered away, lower ranked hits must be pulled up
nlohmann::json override_json_include = {
{"id", "include-rule"},
{
"rule", {
{"query", "scott"},
{"match", override_t::MATCH_EXACT}
}
}
};
// first 2 hits won't match the filter, 3rd position should float up to position 1
override_json_include["includes"] = nlohmann::json::array();
override_json_include["includes"][0] = nlohmann::json::object();
override_json_include["includes"][0]["id"] = "7";
override_json_include["includes"][0]["position"] = 1;
override_json_include["includes"][1] = nlohmann::json::object();
override_json_include["includes"][1]["id"] = "17";
override_json_include["includes"][1]["position"] = 2;
override_json_include["includes"][2] = nlohmann::json::object();
override_json_include["includes"][2]["id"] = "10";
override_json_include["includes"][2]["position"] = 3;
override_json_include["filter_curated_hits"] = true;
override_t override_include;
override_t::parse(override_json_include, "", override_include);
coll_mul_fields->add_override(override_include);
auto results = coll_mul_fields->search("scott", {"starring"}, "points:>55", {}, {}, {0}, 10, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "").get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ("10", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("11", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("12", results["hits"][2]["document"]["id"].get<std::string>());
// another curation where there is an ID missing in the middle
override_json_include = {
{"id", "include-rule"},
{
"rule", {
{"query", "glenn"},
{"match", override_t::MATCH_EXACT}
}
}
};
// middle hit ("10") will not satisfy filter, so "11" will move to position 2
override_json_include["includes"] = nlohmann::json::array();
override_json_include["includes"][0] = nlohmann::json::object();
override_json_include["includes"][0]["id"] = "9";
override_json_include["includes"][0]["position"] = 1;
override_json_include["includes"][1] = nlohmann::json::object();
override_json_include["includes"][1]["id"] = "10";
override_json_include["includes"][1]["position"] = 2;
override_json_include["includes"][2] = nlohmann::json::object();
override_json_include["includes"][2]["id"] = "11";
override_json_include["includes"][2]["position"] = 3;
override_json_include["filter_curated_hits"] = true;
override_t override_include2;
override_t::parse(override_json_include, "", override_include2);
coll_mul_fields->add_override(override_include2);
results = coll_mul_fields->search("glenn", {"starring"}, "points:[43,86]", {}, {}, {0}, 10, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "").get();
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ("9", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("11", results["hits"][1]["document"]["id"].get<std::string>());
}
TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) {
auto pinned_hits = "13:1,4:2";
// basic pinning