Add option to filter curated hits.

This commit is contained in:
Kishore Nallan 2022-03-11 20:52:22 +05:30
parent 671be25190
commit 9197627e81
6 changed files with 129 additions and 49 deletions

View File

@ -217,7 +217,7 @@ private:
void curate_results(string& actual_query, bool enable_overrides, bool already_segmented,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
std::vector<uint32_t>& excluded_ids, std::vector<const override_t*>& filter_overrides) const;
Option<bool> check_and_update_schema(nlohmann::json& document, const DIRTY_VALUES& dirty_values);
@ -407,7 +407,8 @@ public:
const std::vector<infix_t>& infixes = {off},
const size_t max_extra_prefix = INT16_MAX,
const size_t max_extra_suffix = INT16_MAX,
const size_t facet_query_num_typos = 2) const;
const size_t facet_query_num_typos = 2,
const bool filter_curated_hits = false) const;
Option<bool> get_filter_ids(const std::string & simple_filter_query,
std::vector<std::pair<size_t, uint32_t*>>& index_ids);

View File

@ -274,7 +274,7 @@ struct search_args {
std::vector<search_field_t> search_fields;
std::vector<filter> filters;
std::vector<facet>& facets;
std::map<size_t, std::map<size_t, uint32_t>> included_ids;
std::vector<std::pair<uint32_t, uint32_t>>& included_ids;
std::vector<uint32_t> excluded_ids;
std::vector<sort_by> sort_fields_std;
facet_query_t facet_query;
@ -302,6 +302,7 @@ struct search_args {
const size_t max_extra_prefix;
const size_t max_extra_suffix;
const size_t facet_query_num_typos;
const bool filter_curated_hits;
spp::sparse_hash_set<uint64_t> groups_processed;
std::vector<std::vector<art_leaf*>> searched_queries;
@ -312,7 +313,7 @@ struct search_args {
search_args(std::vector<query_tokens_t> field_query_tokens, std::vector<search_field_t> search_fields,
std::vector<filter> filters, std::vector<facet>& facets,
std::map<size_t, std::map<size_t, uint32_t>> included_ids, std::vector<uint32_t> excluded_ids,
std::vector<std::pair<uint32_t, uint32_t>>& included_ids, std::vector<uint32_t> excluded_ids,
std::vector<sort_by> sort_fields_std, facet_query_t facet_query, const std::vector<uint32_t>& num_typos,
size_t max_facet_values, size_t max_hits, size_t per_page, size_t page, token_ordering token_order,
const std::vector<bool>& prefixes, size_t drop_tokens_threshold, size_t typo_tokens_threshold,
@ -320,7 +321,8 @@ struct search_args {
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
size_t concurrency, const std::vector<const override_t*>& dynamic_overrides, size_t search_cutoff_ms,
size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<infix_t>& infixes,
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos) :
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits) :
field_query_tokens(field_query_tokens),
search_fields(search_fields), filters(filters), facets(facets),
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@ -333,7 +335,7 @@ struct search_args {
filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms),
min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates),
infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix),
facet_query_num_typos(facet_query_num_typos) {
facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits) {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
topster = new Topster(topster_size, group_limit);
@ -681,7 +683,7 @@ public:
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
const size_t per_page,
@ -694,7 +696,8 @@ public:
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
size_t max_candidates, const std::vector<infix_t>& infixes, const size_t max_extra_prefix,
const size_t max_extra_suffix, const size_t facet_query_num_typos) const;
const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits) const;
Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document, const bool is_update);

View File

@ -398,7 +398,7 @@ void Collection::prune_document(nlohmann::json &document, const spp::sparse_hash
void Collection::curate_results(string& actual_query, bool enable_overrides, bool already_segmented,
const std::map<size_t, std::vector<std::string>>& pinned_hits,
const std::vector<std::string>& hidden_hits,
std::map<size_t, std::vector<uint32_t>>& include_ids,
std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
std::vector<uint32_t>& excluded_ids,
std::vector<const override_t*>& filter_overrides) const {
@ -452,7 +452,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
uint32_t seq_id = seq_id_op.get();
bool excluded = (excluded_set.count(seq_id) != 0);
if(!excluded) {
include_ids[hit.position].push_back(seq_id);
included_ids.emplace_back(seq_id, hit.position);
}
}
@ -480,7 +480,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
uint32_t seq_id = seq_id_op.get();
bool excluded = (excluded_set.count(seq_id) != 0);
if(!excluded) {
include_ids[pos].push_back(seq_id);
included_ids.emplace_back(seq_id, pos);
}
}
}
@ -698,7 +698,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
const std::vector<infix_t>& infixes,
const size_t max_extra_prefix,
const size_t max_extra_suffix,
const size_t facet_query_num_typos) const {
const size_t facet_query_num_typos,
const bool filter_curated_hits) const {
std::shared_lock lock(mutex);
@ -919,7 +920,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
spp::sparse_hash_set<uint64_t> groups_processed; // used to calculate total_found for grouped query
std::vector<uint32_t> excluded_ids;
std::map<size_t, std::vector<uint32_t>> include_ids; // position => list of IDs
std::vector<std::pair<uint32_t, uint32_t>> included_ids; // ID -> position
std::map<size_t, std::vector<std::string>> pinned_hits;
Option<bool> pinned_hits_op = parse_pinned_hits(pinned_hits_str, pinned_hits);
@ -934,9 +935,9 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
std::vector<const override_t*> filter_overrides;
std::string query = raw_query;
curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits,
include_ids, excluded_ids, filter_overrides);
included_ids, excluded_ids, filter_overrides);
/*for(auto& kv: include_ids) {
/*for(auto& kv: included_ids) {
LOG(INFO) << "key: " << kv.first;
for(auto val: kv.second) {
LOG(INFO) << val;
@ -949,8 +950,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
LOG(INFO) << id;
}
LOG(INFO) << "include_ids size: " << include_ids.size();
for(auto& group: include_ids) {
LOG(INFO) << "included_ids size: " << included_ids.size();
for(auto& group: included_ids) {
for(uint32_t& seq_id: group.second) {
LOG(INFO) << "seq_id: " << seq_id;
}
@ -959,19 +960,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
}
*/
std::map<size_t, std::map<size_t, uint32_t>> included_ids;
for(const auto& pos_ids: include_ids) {
size_t outer_pos = pos_ids.first;
size_t ids_per_pos = std::max(size_t(1), group_limit);
for(size_t inner_pos = 0; inner_pos < std::min(ids_per_pos, pos_ids.second.size()); inner_pos++) {
auto seq_id = pos_ids.second[inner_pos];
included_ids[outer_pos][inner_pos] = seq_id;
//LOG(INFO) << "Adding seq_id " << seq_id << " to index_id " << index_id;
}
}
//LOG(INFO) << "Num indices used for querying: " << indices.size();
std::vector<query_tokens_t> field_query_tokens;
std::vector<std::string> q_tokens; // used for auxillary highlighting
@ -1033,7 +1021,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
exhaustive_search, 4, filter_overrides,
search_stop_millis,
min_len_1typo, min_len_2typo, max_candidates, infixes,
max_extra_prefix, max_extra_suffix, facet_query_num_typos);
max_extra_prefix, max_extra_suffix, facet_query_num_typos,
filter_curated_hits);
index->run_search(search_params);

View File

@ -613,6 +613,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *PINNED_HITS = "pinned_hits";
const char *HIDDEN_HITS = "hidden_hits";
const char *ENABLE_OVERRIDES = "enable_overrides";
const char *FILTER_CURATED_HITS = "filter_curated_hits";
const char *MAX_CANDIDATES = "max_candidates";
const char *INFIX = "infix";
@ -703,6 +705,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
bool prioritize_exact_match = true;
bool pre_segmented_query = false;
bool enable_overrides = true;
bool filter_curated_hits = false;
std::string highlight_fields;
bool exhaustive_search = false;
size_t search_stop_millis;
@ -749,6 +752,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{EXHAUSTIVE_SEARCH, &exhaustive_search},
{SPLIT_JOIN_TOKENS, &split_join_tokens},
{ENABLE_OVERRIDES, &enable_overrides},
{FILTER_CURATED_HITS, &filter_curated_hits},
};
std::unordered_map<std::string, std::vector<std::string>*> str_list_values = {
@ -892,7 +896,9 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
max_candidates,
infixes,
max_extra_prefix,
max_extra_suffix
max_extra_suffix,
facet_query_num_typos,
filter_curated_hits
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -1653,7 +1653,8 @@ void Index::run_search(search_args* search_params) {
search_params->infixes,
search_params->max_extra_prefix,
search_params->max_extra_suffix,
search_params->facet_query_num_typos);
search_params->facet_query_num_typos,
search_params->filter_curated_hits);
}
void Index::collate_included_ids(const std::vector<std::string>& q_included_tokens,
@ -1666,7 +1667,7 @@ void Index::collate_included_ids(const std::vector<std::string>& q_included_toke
return;
}
// calculate match_score and add to topster independently
// created searched queries so that curated results can be highlighted
std::vector<art_leaf *> override_query;
@ -2084,7 +2085,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name
void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
const size_t per_page,
@ -2099,7 +2100,8 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
size_t max_candidates, const std::vector<infix_t>& infixes, const size_t max_extra_prefix,
const size_t max_extra_suffix, const size_t facet_query_num_typos) const {
const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits) const {
search_begin = std::chrono::high_resolution_clock::now();
search_stop_ms = search_cutoff_ms;
@ -2112,26 +2114,68 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
std::shared_lock lock(mutex);
// we will be removing all curated IDs from organic result ids before running topster
std::set<uint32_t> curated_ids;
std::vector<uint32_t> included_ids;
process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters);
do_filtering(filter_ids, filter_ids_length, filters, true);
for(const auto& outer_pos_ids: included_ids_map) {
for(const auto& inner_pos_seq_id: outer_pos_ids.second) {
curated_ids.insert(inner_pos_seq_id.second);
included_ids.push_back(inner_pos_seq_id.second);
std::vector<uint32_t> included_ids_vec;
for(const auto& seq_id_pos: included_ids) {
included_ids_vec.push_back(seq_id_pos.first);
}
std::sort(included_ids_vec.begin(), included_ids_vec.end());
std::map<size_t, std::map<size_t, uint32_t>> included_ids_map; // outer pos => inner pos => list of IDs
// if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition
std::set<uint32_t> included_ids_set;
if(filter_ids_length != 0 && filter_curated_hits) {
uint32_t* included_ids_arr = nullptr;
size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids,
filter_ids_length, &included_ids_arr);
included_ids_vec.clear();
for(size_t i = 0; i < included_ids_len; i++) {
included_ids_set.insert(included_ids_arr[i]);
included_ids_vec.push_back(included_ids_arr[i]);
}
delete [] included_ids_arr;
} else {
included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end());
}
std::map<size_t, std::vector<uint32_t>> included_ids_grouped;
for(const auto& seq_id_pos: included_ids) {
if(included_ids_set.count(seq_id_pos.first) == 0) {
continue;
}
included_ids_grouped[seq_id_pos.second].push_back(seq_id_pos.first);
}
for(const auto& pos_ids: included_ids_grouped) {
size_t outer_pos = pos_ids.first;
size_t ids_per_pos = std::max(size_t(1), group_limit);
for(size_t inner_pos = 0; inner_pos < std::min(ids_per_pos, pos_ids.second.size()); inner_pos++) {
auto seq_id = pos_ids.second[inner_pos];
included_ids_map[outer_pos][inner_pos] = seq_id;
}
}
std::set<uint32_t> curated_ids;
curated_ids.insert(excluded_ids.begin(), excluded_ids.end());
for(const auto& outer_pos_inner_pos_ids: included_ids_map) {
for(const auto& inner_pos_ids: outer_pos_inner_pos_ids.second) {
curated_ids.insert(inner_pos_ids.second);
}
}
std::vector<uint32_t> curated_ids_sorted(curated_ids.begin(), curated_ids.end());
std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end());
process_filter_overrides(filter_overrides, field_query_tokens, token_order, filters);
do_filtering(filter_ids, filter_ids_length, filters, true);
// Order of `fields` are used to sort results
//auto begin = std::chrono::high_resolution_clock::now();
uint32_t* all_result_ids = nullptr;
@ -2305,8 +2349,8 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
std::vector<facet_info_t> facet_infos(facets.size());
compute_facet_infos(facets, facet_query, facet_query_num_typos,
&included_ids[0], included_ids.size(), group_by_fields, facet_infos);
do_facets(facets, facet_query, facet_infos, group_limit, group_by_fields, &included_ids[0], included_ids.size());
&included_ids_vec[0], included_ids_vec.size(), group_by_fields, facet_infos);
do_facets(facets, facet_query, facet_infos, group_limit, group_by_fields, &included_ids_vec[0], included_ids_vec.size());
all_result_ids_len += curated_topster->size;

View File

@ -425,8 +425,45 @@ TEST_F(CollectionOverrideTest, IncludeExcludeHitsQuery) {
ASSERT_STREQ("16", results["hits"][3]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("6", results["hits"][4]["document"]["id"].get<std::string>().c_str());
// pinning + filtering
results = coll_mul_fields->search("of", {"title"}, "points:>58", {}, {}, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10,
pinned_hits, {}).get();
ASSERT_EQ(5, results["found"].get<size_t>());
ASSERT_STREQ("13", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("11", results["hits"][2]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("12", results["hits"][3]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("5", results["hits"][4]["document"]["id"].get<std::string>().c_str());
// pinning + filtering with filter_curated_hits: true
pinned_hits = "14:1,4:2";
results = coll_mul_fields->search("of", {"title"}, "points:>58", {}, {}, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, pinned_hits, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, true,
4, {off}, 32767, 32767, 2, true).get();
ASSERT_EQ(4, results["found"].get<size_t>());
ASSERT_STREQ("14", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("11", results["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("12", results["hits"][2]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("5", results["hits"][3]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ("The Silence <mark>of</mark> the Lambs", results["hits"][1]["highlights"][0]["snippet"].get<std::string>());
ASSERT_EQ("Confessions <mark>of</mark> a Shopaholic", results["hits"][2]["highlights"][0]["snippet"].get<std::string>());
ASSERT_EQ("Percy Jackson: Sea <mark>of</mark> Monsters", results["hits"][3]["highlights"][0]["snippet"].get<std::string>());
// both pinning and hiding
pinned_hits = "13:1,4:2";
std::string hidden_hits="11,16";
results = coll_mul_fields->search("the", {"title"}, "", {"starring"}, {}, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,