Make exact match behavior configurable.

This commit is contained in:
Kishore Nallan 2021-05-13 14:36:57 +05:30
parent 09e2e62312
commit 529bb55c5c
8 changed files with 125 additions and 50 deletions

View File

@ -530,7 +530,8 @@ public:
const std::string& highlight_start_tag="<mark>",
const std::string& highlight_end_tag="</mark>",
std::vector<size_t> query_by_weights={},
size_t limit_hits=UINT32_MAX) const;
size_t limit_hits=UINT32_MAX,
bool prioritize_exact_match=true) const;
Option<bool> get_filter_ids(const std::string & simple_filter_query,
std::vector<std::pair<size_t, uint32_t*>>& index_ids);

View File

@ -63,6 +63,7 @@ struct search_args {
std::vector<std::string> group_by_fields;
size_t group_limit;
std::string default_sorting_field;
bool prioritize_exact_match;
size_t all_result_ids_len;
spp::sparse_hash_set<uint64_t> groups_processed;
std::vector<std::vector<art_leaf*>> searched_queries;
@ -82,7 +83,8 @@ struct search_args {
size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix,
size_t drop_tokens_threshold, size_t typo_tokens_threshold,
const std::vector<std::string>& group_by_fields, size_t group_limit,
const std::string& default_sorting_field):
const std::string& default_sorting_field,
bool prioritize_exact_match):
field_query_tokens(field_query_tokens),
search_fields(search_fields), filters(filters), facets(facets),
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@ -90,7 +92,7 @@ struct search_args {
page(page), token_order(token_order), prefix(prefix),
drop_tokens_threshold(drop_tokens_threshold), typo_tokens_threshold(typo_tokens_threshold),
group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field),
all_result_ids_len(0) {
prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0) {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
topster = new Topster(topster_size, group_limit);
@ -211,6 +213,7 @@ private:
size_t& field_num_results,
const size_t group_limit,
const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,
const token_ordering token_order = FREQUENCY, const bool prefix = false,
const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD,
const size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD) const;
@ -227,7 +230,8 @@ private:
size_t& field_num_results,
const size_t typo_tokens_threshold,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
const std::vector<token_t>& query_tokens) const;
const std::vector<token_t>& query_tokens,
bool prioritize_exact_match) const;
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;
@ -302,12 +306,12 @@ public:
static void concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids);
void score_results(const std::vector<sort_by> & sort_fields, const uint16_t & query_index, const uint8_t & field_id,
const uint32_t total_cost, Topster* topster, const std::vector<art_leaf *> & query_suggestion,
spp::sparse_hash_set<uint64_t>& groups_processed,
const uint32_t *result_ids, const size_t result_size,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
uint32_t token_bits, const std::vector<token_t>& query_tokens) const;
void score_results(const std::vector<sort_by> &sort_fields, const uint16_t &query_index, const uint8_t &field_id,
const uint32_t total_cost, Topster *topster, const std::vector<art_leaf *> &query_suggestion,
spp::sparse_hash_set<uint64_t> &groups_processed, const uint32_t *result_ids,
const size_t result_size, const size_t group_limit,
const std::vector<std::string> &group_by_fields, uint32_t token_bits,
const std::vector<token_t> &query_tokens, bool prioritize_exact_match) const;
static int64_t get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field);
@ -353,7 +357,8 @@ public:
const size_t typo_tokens_threshold,
const size_t group_limit,
const std::vector<std::string>& group_by_fields,
const std::string& default_sorting_field) const;
const std::string& default_sorting_field,
bool prioritize_exact_match) const;
Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document, const bool is_update);

View File

@ -123,7 +123,8 @@ struct Match {
Until queue size is 1.
*/
Match(uint32_t doc_id, const std::vector<token_positions_t>& token_offsets, bool populate_window=true) {
Match(uint32_t doc_id, const std::vector<token_positions_t>& token_offsets,
bool populate_window=true, bool check_exact_match=false) {
// in case if number of tokens in query is greater than max window
const size_t tokens_size = std::min(token_offsets.size(), WINDOW_SIZE);
@ -216,23 +217,27 @@ struct Match {
offsets = best_window;
}
int last_token_index = -1;
size_t total_offsets = 0;
exact_match = 0;
for(const auto& token_positions: token_offsets) {
if(token_positions.last_token && !token_positions.positions.empty()) {
last_token_index = token_positions.positions.back();
}
total_offsets += token_positions.positions.size();
if(total_offsets > token_offsets.size()) {
break;
}
}
if(check_exact_match) {
int last_token_index = -1;
size_t total_offsets = 0;
if(last_token_index == int(token_offsets.size())-1 &&
total_offsets == token_offsets.size() && distance == token_offsets.size()-1) {
exact_match = 1;
for(const auto& token_positions: token_offsets) {
if(token_positions.last_token && !token_positions.positions.empty()) {
last_token_index = token_positions.positions.back();
}
total_offsets += token_positions.positions.size();
if(total_offsets > token_offsets.size()) {
// if total offsets exceed query length, there cannot possibly be an exact match
return;
}
}
if(last_token_index == int(token_offsets.size())-1 &&
total_offsets == token_offsets.size() && distance == token_offsets.size()-1) {
exact_match = 1;
}
}
}
};

View File

@ -503,7 +503,8 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
const std::string& highlight_start_tag,
const std::string& highlight_end_tag,
std::vector<size_t> query_by_weights,
size_t limit_hits) const {
size_t limit_hits,
bool prioritize_exact_match) const {
std::shared_lock lock(mutex);
@ -857,7 +858,7 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
sort_fields_std, facet_query, num_typos, max_facet_values, max_hits,
per_page, page, token_order, prefix,
drop_tokens_threshold, typo_tokens_threshold,
group_by_fields, group_limit, default_sorting_field);
group_by_fields, group_limit, default_sorting_field, prioritize_exact_match);
search_args_vec.push_back(search_params);
@ -1465,7 +1466,7 @@ void Collection::highlight_result(const field &search_field,
continue;
}
const Match & this_match = Match(field_order_kv->key, token_positions);
const Match & this_match = Match(field_order_kv->key, token_positions, true, true);
uint64_t this_match_score = this_match.get_match_score(1);
match_indices.emplace_back(this_match, this_match_score, array_index);

View File

@ -498,6 +498,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *HIGHLIGHT_START_TAG = "highlight_start_tag";
const char *HIGHLIGHT_END_TAG = "highlight_end_tag";
const char *PRIORITIZE_EXACT_MATCH = "prioritize_exact_match";
if(req_params.count(NUM_TYPOS) == 0) {
req_params[NUM_TYPOS] = "2";
}
@ -583,6 +585,10 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
}
}
if(req_params.count(PRIORITIZE_EXACT_MATCH) == 0) {
req_params[PRIORITIZE_EXACT_MATCH] = "true";
}
std::vector<std::string> query_by_weights_str;
std::vector<size_t> query_by_weights;
@ -638,6 +644,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
return Option<bool>(400,"Parameter `" + std::string(GROUP_LIMIT) + "` must be an unsigned integer.");
}
bool prioritize_exact_match = (req_params[PRIORITIZE_EXACT_MATCH] == "true");
std::string filter_str = req_params.count(FILTER) != 0 ? req_params[FILTER] : "";
std::vector<std::string> search_fields;
@ -718,7 +726,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
req_params[HIGHLIGHT_START_TAG],
req_params[HIGHLIGHT_END_TAG],
query_by_weights,
static_cast<size_t>(std::stol(req_params[LIMIT_HITS]))
static_cast<size_t>(std::stol(req_params[LIMIT_HITS])),
prioritize_exact_match
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@ -885,7 +885,8 @@ void Index::search_candidates(const uint8_t & field_id,
const size_t typo_tokens_threshold,
const size_t group_limit,
const std::vector<std::string>& group_by_fields,
const std::vector<token_t>& query_tokens) const {
const std::vector<token_t>& query_tokens,
bool prioritize_exact_match) const {
const long long combination_limit = 10;
@ -969,9 +970,10 @@ void Index::search_candidates(const uint8_t & field_id,
*all_result_ids = new_all_result_ids;
// go through each matching document id and calculate match score
score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion,
score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster,
query_suggestion,
groups_processed, filtered_result_ids, filtered_results_size,
group_limit, group_by_fields, token_bits, query_tokens);
group_limit, group_by_fields, token_bits, query_tokens, prioritize_exact_match);
field_num_results += filtered_results_size;
@ -988,8 +990,10 @@ void Index::search_candidates(const uint8_t & field_id,
LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size;
}*/
score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion,
groups_processed, result_ids, result_size, group_limit, group_by_fields, token_bits, query_tokens);
score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster,
query_suggestion,
groups_processed, result_ids, result_size, group_limit, group_by_fields, token_bits,
query_tokens, prioritize_exact_match);
field_num_results += result_size;
@ -1382,7 +1386,8 @@ void Index::run_search(search_args* search_params) {
search_params->raw_result_kvs, search_params->override_result_kvs,
search_params->typo_tokens_threshold,
search_params->group_limit, search_params->group_by_fields,
search_params->default_sorting_field);
search_params->default_sorting_field,
search_params->prioritize_exact_match);
}
void Index::collate_included_ids(const std::vector<std::string>& q_included_tokens,
@ -1473,7 +1478,8 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
const size_t typo_tokens_threshold,
const size_t group_limit,
const std::vector<std::string>& group_by_fields,
const std::string& default_sorting_field) const {
const std::string& default_sorting_field,
bool prioritize_exact_match) const {
std::shared_lock lock(mutex);
@ -1575,7 +1581,8 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
uint32_t token_bits = 255;
score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {},
groups_processed, filter_ids, filter_ids_length, group_limit, group_by_fields, token_bits, {});
groups_processed, filter_ids, filter_ids_length, group_limit, group_by_fields, token_bits, {},
prioritize_exact_match);
collate_included_ids(field_query_tokens[0].q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries);
all_result_ids_len = filter_ids_length;
@ -1625,7 +1632,7 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
field_name, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
field_num_results, group_limit, group_by_fields, token_order, prefix,
field_num_results, group_limit, group_by_fields, prioritize_exact_match, token_order, prefix,
drop_tokens_threshold, typo_tokens_threshold);
// do synonym based searches
@ -1637,7 +1644,7 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
field_name, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
field_num_results, group_limit, group_by_fields, token_order, prefix,
field_num_results, group_limit, group_by_fields, prioritize_exact_match, token_order, prefix,
drop_tokens_threshold, typo_tokens_threshold);
}
@ -1812,7 +1819,8 @@ void Index::search_field(const uint8_t & field_id,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
uint32_t** all_result_ids, size_t & all_result_ids_len, size_t& field_num_results,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
const token_ordering token_order, const bool prefix,
bool prioritize_exact_match,
const token_ordering token_order, const bool prefix,
const size_t drop_tokens_threshold, const size_t typo_tokens_threshold) const {
size_t max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos;
@ -1920,7 +1928,7 @@ void Index::search_field(const uint8_t & field_id,
search_candidates(field_id, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size,
curated_ids, sort_fields, token_candidates_vec, searched_queries, topster,
groups_processed, all_result_ids, all_result_ids_len, field_num_results,
typo_tokens_threshold, group_limit, group_by_fields, query_tokens);
typo_tokens_threshold, group_limit, group_by_fields, query_tokens, prioritize_exact_match);
}
resume_typo_loop:
@ -1958,7 +1966,7 @@ void Index::search_field(const uint8_t & field_id,
return search_field(field_id, query_tokens, truncated_tokens, exclude_token_ids, exclude_token_ids_size,
num_tokens_dropped, field, filter_ids, filter_ids_length, curated_ids,facets,
sort_fields, num_typos,searched_queries, topster, groups_processed, all_result_ids,
all_result_ids_len, field_num_results, group_limit, group_by_fields,
all_result_ids_len, field_num_results, group_limit, group_by_fields, prioritize_exact_match,
token_order, prefix);
}
}
@ -1991,7 +1999,8 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
const uint32_t *result_ids, const size_t result_size,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
uint32_t token_bits,
const std::vector<token_t>& query_tokens) const {
const std::vector<token_t>& query_tokens,
bool prioritize_exact_match) const {
int sort_order[3]; // 1 or -1 based on DESC or ASC respectively
spp::sparse_hash_map<uint32_t, int64_t>* field_values[3];
@ -2074,7 +2083,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
if (token_positions.empty()) {
continue;
}
const Match &match = Match(seq_id, token_positions, false);
const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match);
uint64_t this_match_score = match.get_match_score(total_cost);
match_score += this_match_score;

View File

@ -3126,6 +3126,19 @@ TEST_F(CollectionTest, MultiFieldRelevance6) {
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get<std::string>().c_str());
// when exact matches are disabled
results = coll1->search("taylor swift",
{"title", "artist"}, "", {}, {}, 2, 10, 1, FREQUENCY,
true, 10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 40, {}, {}, {}, 0,
"<mark>", "</mark>", {1, 1}, 100, false).get();
ASSERT_EQ(2, results["found"].get<size_t>());
ASSERT_EQ(2, results["hits"].size());
ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get<std::string>().c_str());
collectionManager.drop_collection("coll1");
}
@ -3170,8 +3183,6 @@ TEST_F(CollectionTest, ExactMatch) {
results = coll1->search("alpha", {"title"}, "", {}, {}, 2, 10, 1, FREQUENCY, true, 10).get();
LOG(INFO) << results;
ASSERT_EQ(3, results["found"].get<size_t>());
ASSERT_EQ(3, results["hits"].size());

View File

@ -47,12 +47,13 @@ TEST(MatchTest, MatchScoreV2) {
token_offsets.clear();
token_offsets.push_back(token_positions_t{false, {38, 50, 170, 187, 195, 222}});
token_offsets.push_back(token_positions_t{false, {39, 140, 171, 189, 223}});
token_offsets.push_back(token_positions_t{true, {39, 140, 171, 189, 223}});
token_offsets.push_back(token_positions_t{false, {169, 180}});
match = Match(100, token_offsets, true);
match = Match(100, token_offsets, true, true);
ASSERT_EQ(3, match.words_present);
ASSERT_EQ(2, match.distance);
ASSERT_EQ(0, match.exact_match);
expected_offsets = {170, 171, 169};
for(size_t i=0; i<token_offsets.size(); i++) {
@ -62,11 +63,12 @@ TEST(MatchTest, MatchScoreV2) {
token_offsets.clear();
token_offsets.push_back(token_positions_t{false, {38, 50, 187, 195, 201}});
token_offsets.push_back(token_positions_t{false, {120, 167, 171, 223}});
token_offsets.push_back(token_positions_t{false, {240, 250}});
token_offsets.push_back(token_positions_t{true, {240, 250}});
match = Match(100, token_offsets, true);
ASSERT_EQ(1, match.words_present);
ASSERT_EQ(0, match.distance);
ASSERT_EQ(0, match.exact_match);
expected_offsets = {38, MAX_DISPLACEMENT, MAX_DISPLACEMENT};
for(size_t i=0; i<token_offsets.size(); i++) {
@ -78,7 +80,39 @@ TEST(MatchTest, MatchScoreV2) {
ASSERT_EQ(1, match.words_present);
ASSERT_EQ(0, match.distance);
ASSERT_EQ(0, match.offsets.size());
ASSERT_EQ(0, match.exact_match);
// exact match
token_offsets.clear();
token_offsets.push_back(token_positions_t{false, {0}});
token_offsets.push_back(token_positions_t{true, {2}});
token_offsets.push_back(token_positions_t{false, {1}});
match = Match(100, token_offsets, true, true);
ASSERT_EQ(3, match.words_present);
ASSERT_EQ(2, match.distance);
ASSERT_EQ(1, match.exact_match);
match = Match(100, token_offsets, true, false);
ASSERT_EQ(3, match.words_present);
ASSERT_EQ(2, match.distance);
ASSERT_EQ(0, match.exact_match);
token_offsets.clear();
token_offsets.push_back(token_positions_t{false, {1}});
token_offsets.push_back(token_positions_t{false, {2}});
token_offsets.push_back(token_positions_t{true, {3}});
match = Match(100, token_offsets, true, true);
ASSERT_EQ(0, match.exact_match);
token_offsets.clear();
token_offsets.push_back(token_positions_t{false, {0}});
token_offsets.push_back(token_positions_t{false, {1}});
token_offsets.push_back(token_positions_t{false, {2}});
match = Match(100, token_offsets, true, true);
ASSERT_EQ(0, match.exact_match);
/*size_t total_distance = 0, words_present = 0, offset_sum = 0;
auto begin = std::chrono::high_resolution_clock::now();