From 529bb55c5c1080eb2d3fe918de7d43b0cc65d470 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 13 May 2021 14:36:57 +0530 Subject: [PATCH] Make exact match behavior configurable. --- include/collection.h | 3 ++- include/index.h | 25 ++++++++++++++---------- include/match_score.h | 35 +++++++++++++++++++-------------- src/collection.cpp | 7 ++++--- src/collection_manager.cpp | 11 ++++++++++- src/index.cpp | 39 +++++++++++++++++++++++-------------- test/collection_test.cpp | 15 ++++++++++++-- test/match_score_test.cpp | 40 +++++++++++++++++++++++++++++++++++--- 8 files changed, 125 insertions(+), 50 deletions(-) diff --git a/include/collection.h b/include/collection.h index b12e7ebc..d56bf68e 100644 --- a/include/collection.h +++ b/include/collection.h @@ -530,7 +530,8 @@ public: const std::string& highlight_start_tag="", const std::string& highlight_end_tag="", std::vector query_by_weights={}, - size_t limit_hits=UINT32_MAX) const; + size_t limit_hits=UINT32_MAX, + bool prioritize_exact_match=true) const; Option get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids); diff --git a/include/index.h b/include/index.h index 300414a4..25b69a6a 100644 --- a/include/index.h +++ b/include/index.h @@ -63,6 +63,7 @@ struct search_args { std::vector group_by_fields; size_t group_limit; std::string default_sorting_field; + bool prioritize_exact_match; size_t all_result_ids_len; spp::sparse_hash_set groups_processed; std::vector> searched_queries; @@ -82,7 +83,8 @@ struct search_args { size_t max_hits, size_t per_page, size_t page, token_ordering token_order, bool prefix, size_t drop_tokens_threshold, size_t typo_tokens_threshold, const std::vector& group_by_fields, size_t group_limit, - const std::string& default_sorting_field): + const std::string& default_sorting_field, + bool prioritize_exact_match): field_query_tokens(field_query_tokens), search_fields(search_fields), filters(filters), facets(facets), included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std), @@ -90,7 +92,7 @@ struct search_args { page(page), token_order(token_order), prefix(prefix), drop_tokens_threshold(drop_tokens_threshold), typo_tokens_threshold(typo_tokens_threshold), group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field), - all_result_ids_len(0) { + prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0) { const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory topster = new Topster(topster_size, group_limit); @@ -211,6 +213,7 @@ private: size_t& field_num_results, const size_t group_limit, const std::vector& group_by_fields, + bool prioritize_exact_match, const token_ordering token_order = FREQUENCY, const bool prefix = false, const size_t drop_tokens_threshold = Index::DROP_TOKENS_THRESHOLD, const size_t typo_tokens_threshold = Index::TYPO_TOKENS_THRESHOLD) const; @@ -227,7 +230,8 @@ private: size_t& field_num_results, const size_t typo_tokens_threshold, const size_t group_limit, const std::vector& group_by_fields, - const std::vector& query_tokens) const; + const std::vector& query_tokens, + bool prioritize_exact_match) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; @@ -302,12 +306,12 @@ public: static void concat_topster_ids(Topster* topster, spp::sparse_hash_map>& topster_ids); - void score_results(const std::vector & sort_fields, const uint16_t & query_index, const uint8_t & field_id, - const uint32_t total_cost, Topster* topster, const std::vector & query_suggestion, - spp::sparse_hash_set& groups_processed, - const uint32_t *result_ids, const size_t result_size, - const size_t group_limit, const std::vector& group_by_fields, - uint32_t token_bits, const std::vector& query_tokens) const; + void score_results(const std::vector &sort_fields, const uint16_t &query_index, const uint8_t &field_id, + const uint32_t total_cost, Topster *topster, const std::vector &query_suggestion, + spp::sparse_hash_set &groups_processed, const uint32_t *result_ids, + const size_t result_size, const size_t group_limit, + const std::vector &group_by_fields, uint32_t token_bits, + const std::vector &query_tokens, bool prioritize_exact_match) const; static int64_t get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field); @@ -353,7 +357,8 @@ public: const size_t typo_tokens_threshold, const size_t group_limit, const std::vector& group_by_fields, - const std::string& default_sorting_field) const; + const std::string& default_sorting_field, + bool prioritize_exact_match) const; Option remove(const uint32_t seq_id, const nlohmann::json & document, const bool is_update); diff --git a/include/match_score.h b/include/match_score.h index b89523cc..a07d9704 100644 --- a/include/match_score.h +++ b/include/match_score.h @@ -123,7 +123,8 @@ struct Match { Until queue size is 1. */ - Match(uint32_t doc_id, const std::vector& token_offsets, bool populate_window=true) { + Match(uint32_t doc_id, const std::vector& token_offsets, + bool populate_window=true, bool check_exact_match=false) { // in case if number of tokens in query is greater than max window const size_t tokens_size = std::min(token_offsets.size(), WINDOW_SIZE); @@ -216,23 +217,27 @@ struct Match { offsets = best_window; } - int last_token_index = -1; - size_t total_offsets = 0; exact_match = 0; - for(const auto& token_positions: token_offsets) { - if(token_positions.last_token && !token_positions.positions.empty()) { - last_token_index = token_positions.positions.back(); - } - total_offsets += token_positions.positions.size(); - if(total_offsets > token_offsets.size()) { - break; - } - } + if(check_exact_match) { + int last_token_index = -1; + size_t total_offsets = 0; - if(last_token_index == int(token_offsets.size())-1 && - total_offsets == token_offsets.size() && distance == token_offsets.size()-1) { - exact_match = 1; + for(const auto& token_positions: token_offsets) { + if(token_positions.last_token && !token_positions.positions.empty()) { + last_token_index = token_positions.positions.back(); + } + total_offsets += token_positions.positions.size(); + if(total_offsets > token_offsets.size()) { + // if total offsets exceed query length, there cannot possibly be an exact match + return; + } + } + + if(last_token_index == int(token_offsets.size())-1 && + total_offsets == token_offsets.size() && distance == token_offsets.size()-1) { + exact_match = 1; + } } } }; diff --git a/src/collection.cpp b/src/collection.cpp index fe920256..1d7491d9 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -503,7 +503,8 @@ Option Collection::search(const std::string & query, const std:: const std::string& highlight_start_tag, const std::string& highlight_end_tag, std::vector query_by_weights, - size_t limit_hits) const { + size_t limit_hits, + bool prioritize_exact_match) const { std::shared_lock lock(mutex); @@ -857,7 +858,7 @@ Option Collection::search(const std::string & query, const std:: sort_fields_std, facet_query, num_typos, max_facet_values, max_hits, per_page, page, token_order, prefix, drop_tokens_threshold, typo_tokens_threshold, - group_by_fields, group_limit, default_sorting_field); + group_by_fields, group_limit, default_sorting_field, prioritize_exact_match); search_args_vec.push_back(search_params); @@ -1465,7 +1466,7 @@ void Collection::highlight_result(const field &search_field, continue; } - const Match & this_match = Match(field_order_kv->key, token_positions); + const Match & this_match = Match(field_order_kv->key, token_positions, true, true); uint64_t this_match_score = this_match.get_match_score(1); match_indices.emplace_back(this_match, this_match_score, array_index); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 33cf12ec..81c88e2c 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -498,6 +498,8 @@ Option CollectionManager::do_search(std::map& re const char *HIGHLIGHT_START_TAG = "highlight_start_tag"; const char *HIGHLIGHT_END_TAG = "highlight_end_tag"; + const char *PRIORITIZE_EXACT_MATCH = "prioritize_exact_match"; + if(req_params.count(NUM_TYPOS) == 0) { req_params[NUM_TYPOS] = "2"; } @@ -583,6 +585,10 @@ Option CollectionManager::do_search(std::map& re } } + if(req_params.count(PRIORITIZE_EXACT_MATCH) == 0) { + req_params[PRIORITIZE_EXACT_MATCH] = "true"; + } + std::vector query_by_weights_str; std::vector query_by_weights; @@ -638,6 +644,8 @@ Option CollectionManager::do_search(std::map& re return Option(400,"Parameter `" + std::string(GROUP_LIMIT) + "` must be an unsigned integer."); } + bool prioritize_exact_match = (req_params[PRIORITIZE_EXACT_MATCH] == "true"); + std::string filter_str = req_params.count(FILTER) != 0 ? req_params[FILTER] : ""; std::vector search_fields; @@ -718,7 +726,8 @@ Option CollectionManager::do_search(std::map& re req_params[HIGHLIGHT_START_TAG], req_params[HIGHLIGHT_END_TAG], query_by_weights, - static_cast(std::stol(req_params[LIMIT_HITS])) + static_cast(std::stol(req_params[LIMIT_HITS])), + prioritize_exact_match ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/index.cpp b/src/index.cpp index fbc6cf84..8c77988d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -885,7 +885,8 @@ void Index::search_candidates(const uint8_t & field_id, const size_t typo_tokens_threshold, const size_t group_limit, const std::vector& group_by_fields, - const std::vector& query_tokens) const { + const std::vector& query_tokens, + bool prioritize_exact_match) const { const long long combination_limit = 10; @@ -969,9 +970,10 @@ void Index::search_candidates(const uint8_t & field_id, *all_result_ids = new_all_result_ids; // go through each matching document id and calculate match score - score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion, + score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, + query_suggestion, groups_processed, filtered_result_ids, filtered_results_size, - group_limit, group_by_fields, token_bits, query_tokens); + group_limit, group_by_fields, token_bits, query_tokens, prioritize_exact_match); field_num_results += filtered_results_size; @@ -988,8 +990,10 @@ void Index::search_candidates(const uint8_t & field_id, LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size; }*/ - score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion, - groups_processed, result_ids, result_size, group_limit, group_by_fields, token_bits, query_tokens); + score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, + query_suggestion, + groups_processed, result_ids, result_size, group_limit, group_by_fields, token_bits, + query_tokens, prioritize_exact_match); field_num_results += result_size; @@ -1382,7 +1386,8 @@ void Index::run_search(search_args* search_params) { search_params->raw_result_kvs, search_params->override_result_kvs, search_params->typo_tokens_threshold, search_params->group_limit, search_params->group_by_fields, - search_params->default_sorting_field); + search_params->default_sorting_field, + search_params->prioritize_exact_match); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -1473,7 +1478,8 @@ void Index::search(const std::vector& field_query_tokens, const size_t typo_tokens_threshold, const size_t group_limit, const std::vector& group_by_fields, - const std::string& default_sorting_field) const { + const std::string& default_sorting_field, + bool prioritize_exact_match) const { std::shared_lock lock(mutex); @@ -1575,7 +1581,8 @@ void Index::search(const std::vector& field_query_tokens, uint32_t token_bits = 255; score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {}, - groups_processed, filter_ids, filter_ids_length, group_limit, group_by_fields, token_bits, {}); + groups_processed, filter_ids, filter_ids_length, group_limit, group_by_fields, token_bits, {}, + prioritize_exact_match); collate_included_ids(field_query_tokens[0].q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries); all_result_ids_len = filter_ids_length; @@ -1625,7 +1632,7 @@ void Index::search(const std::vector& field_query_tokens, search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped, field_name, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std, num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len, - field_num_results, group_limit, group_by_fields, token_order, prefix, + field_num_results, group_limit, group_by_fields, prioritize_exact_match, token_order, prefix, drop_tokens_threshold, typo_tokens_threshold); // do synonym based searches @@ -1637,7 +1644,7 @@ void Index::search(const std::vector& field_query_tokens, search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped, field_name, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std, num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len, - field_num_results, group_limit, group_by_fields, token_order, prefix, + field_num_results, group_limit, group_by_fields, prioritize_exact_match, token_order, prefix, drop_tokens_threshold, typo_tokens_threshold); } @@ -1812,7 +1819,8 @@ void Index::search_field(const uint8_t & field_id, Topster* topster, spp::sparse_hash_set& groups_processed, uint32_t** all_result_ids, size_t & all_result_ids_len, size_t& field_num_results, const size_t group_limit, const std::vector& group_by_fields, - const token_ordering token_order, const bool prefix, + bool prioritize_exact_match, + const token_ordering token_order, const bool prefix, const size_t drop_tokens_threshold, const size_t typo_tokens_threshold) const { size_t max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos; @@ -1920,7 +1928,7 @@ void Index::search_field(const uint8_t & field_id, search_candidates(field_id, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size, curated_ids, sort_fields, token_candidates_vec, searched_queries, topster, groups_processed, all_result_ids, all_result_ids_len, field_num_results, - typo_tokens_threshold, group_limit, group_by_fields, query_tokens); + typo_tokens_threshold, group_limit, group_by_fields, query_tokens, prioritize_exact_match); } resume_typo_loop: @@ -1958,7 +1966,7 @@ void Index::search_field(const uint8_t & field_id, return search_field(field_id, query_tokens, truncated_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped, field, filter_ids, filter_ids_length, curated_ids,facets, sort_fields, num_typos,searched_queries, topster, groups_processed, all_result_ids, - all_result_ids_len, field_num_results, group_limit, group_by_fields, + all_result_ids_len, field_num_results, group_limit, group_by_fields, prioritize_exact_match, token_order, prefix); } } @@ -1991,7 +1999,8 @@ void Index::score_results(const std::vector & sort_fields, const uint16 const uint32_t *result_ids, const size_t result_size, const size_t group_limit, const std::vector& group_by_fields, uint32_t token_bits, - const std::vector& query_tokens) const { + const std::vector& query_tokens, + bool prioritize_exact_match) const { int sort_order[3]; // 1 or -1 based on DESC or ASC respectively spp::sparse_hash_map* field_values[3]; @@ -2074,7 +2083,7 @@ void Index::score_results(const std::vector & sort_fields, const uint16 if (token_positions.empty()) { continue; } - const Match &match = Match(seq_id, token_positions, false); + const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match); uint64_t this_match_score = match.get_match_score(total_cost); match_score += this_match_score; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index f826ba55..c66944cc 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -3126,6 +3126,19 @@ TEST_F(CollectionTest, MultiFieldRelevance6) { ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); + // when exact matches are disabled + results = coll1->search("taylor swift", + {"title", "artist"}, "", {}, {}, 2, 10, 1, FREQUENCY, + true, 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 40, {}, {}, {}, 0, + "", "", {1, 1}, 100, false).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + collectionManager.drop_collection("coll1"); } @@ -3170,8 +3183,6 @@ TEST_F(CollectionTest, ExactMatch) { results = coll1->search("alpha", {"title"}, "", {}, {}, 2, 10, 1, FREQUENCY, true, 10).get(); - LOG(INFO) << results; - ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); diff --git a/test/match_score_test.cpp b/test/match_score_test.cpp index 41459a87..9aeb3619 100644 --- a/test/match_score_test.cpp +++ b/test/match_score_test.cpp @@ -47,12 +47,13 @@ TEST(MatchTest, MatchScoreV2) { token_offsets.clear(); token_offsets.push_back(token_positions_t{false, {38, 50, 170, 187, 195, 222}}); - token_offsets.push_back(token_positions_t{false, {39, 140, 171, 189, 223}}); + token_offsets.push_back(token_positions_t{true, {39, 140, 171, 189, 223}}); token_offsets.push_back(token_positions_t{false, {169, 180}}); - match = Match(100, token_offsets, true); + match = Match(100, token_offsets, true, true); ASSERT_EQ(3, match.words_present); ASSERT_EQ(2, match.distance); + ASSERT_EQ(0, match.exact_match); expected_offsets = {170, 171, 169}; for(size_t i=0; i