From 3f9544535cee8669b02b2ac54d3b0aedb6f2ca88 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 12 Apr 2022 08:04:43 +0530 Subject: [PATCH] Ensure that synonyms are ranked equally. --- include/index.h | 7 ++- src/index.cpp | 99 +++++++++++++++++-------------- test/collection_synonyms_test.cpp | 62 +++++++++++++++++-- 3 files changed, 115 insertions(+), 53 deletions(-) diff --git a/include/index.h b/include/index.h index 53f0b02a..4b8a9a31 100644 --- a/include/index.h +++ b/include/index.h @@ -706,6 +706,7 @@ public: const size_t group_limit, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool single_exact_query_token, + size_t num_query_tokens, int syn_orig_num_tokens, const std::vector& posting_lists) const; @@ -892,7 +893,8 @@ public: const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, Topster* actual_topster, - std::vector& field_query_tokens, + std::vector>& q_pos_synonyms, + int syn_orig_num_tokens, spp::sparse_hash_set& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, @@ -901,8 +903,7 @@ public: const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices, - tsl::htrie_map& qtoken_set, - std::vector>& all_queries) const; + tsl::htrie_map& qtoken_set) const; void do_phrase_search(const size_t num_search_fields, const std::vector& search_fields, std::vector& field_query_tokens, diff --git a/src/index.cpp b/src/index.cpp index 32cdb5fd..ac2d2ae9 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2261,12 +2261,42 @@ void Index::search(std::vector& field_query_tokens, const std::v // FIXME: needed? std::set query_hashes; + // resolve synonyms so that we can compute `syn_orig_num_tokens` + std::vector> all_queries = {field_query_tokens[0].q_include_tokens}; + std::vector> q_pos_synonyms; + std::vector q_include_tokens; + int syn_orig_num_tokens = -1; + + for(size_t j = 0; j < field_query_tokens[0].q_include_tokens.size(); j++) { + q_include_tokens.push_back(field_query_tokens[0].q_include_tokens[j].value); + } + synonym_index->synonym_reduction(q_include_tokens, field_query_tokens[0].q_synonyms); + + if(!field_query_tokens[0].q_synonyms.empty()) { + syn_orig_num_tokens = field_query_tokens[0].q_include_tokens.size(); + } + + for(const auto& q_syn_vec: field_query_tokens[0].q_synonyms) { + std::vector q_pos_syn; + for(size_t j=0; j < q_syn_vec.size(); j++) { + bool is_prefix = (j == q_syn_vec.size()-1); + q_pos_syn.emplace_back(j, q_syn_vec[j], is_prefix, q_syn_vec[j].size(), 0); + } + + q_pos_synonyms.push_back(q_pos_syn); + all_queries.push_back(q_pos_syn); + + if(q_syn_vec.size() > syn_orig_num_tokens) { + syn_orig_num_tokens = q_syn_vec.size(); + } + } + fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, excluded_result_ids, excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, - max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, + max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); // try split/joining tokens if no results are found @@ -2274,12 +2304,12 @@ void Index::search(std::vector& field_query_tokens, const std::v std::vector> space_resolved_queries; for(size_t i = 0; i < num_search_fields; i++) { - std::vector q_include_tokens; + std::vector orig_q_include_tokens; for(auto& q_include_token: field_query_tokens[i].q_include_tokens) { - q_include_tokens.push_back(q_include_token.value); + orig_q_include_tokens.push_back(q_include_token.value); } - resolve_space_as_typos(q_include_tokens, the_fields[i].name,space_resolved_queries); + resolve_space_as_typos(orig_q_include_tokens, the_fields[i].name,space_resolved_queries); if(!space_resolved_queries.empty()) { break; @@ -2302,28 +2332,26 @@ void Index::search(std::vector& field_query_tokens, const std::v sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, - max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); + max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); } } // do synonym based searches - std::vector> all_queries = {field_query_tokens[0].q_include_tokens}; - do_synonym_search(the_fields, filters, included_ids_map, sort_fields_std, curated_topster, token_order, 0, group_limit, group_by_fields, prioritize_exact_match, exhaustive_search, concurrency, min_len_1typo, min_len_2typo, max_candidates, curated_ids, curated_ids_sorted, - excluded_result_ids, excluded_result_ids_size, topster, field_query_tokens, + excluded_result_ids, excluded_result_ids_size, topster, q_pos_synonyms, syn_orig_num_tokens, groups_processed, searched_queries, all_result_ids, all_result_ids_len, filter_ids, filter_ids_length, query_hashes, sort_order, field_values, geopoint_indices, - qtoken_set, all_queries); + qtoken_set); + + // gather up both original query and synonym queries and do drop tokens if(all_result_ids_len < drop_tokens_threshold) { - // gather up both original query and synonym queries and do drop tokens for(size_t qi = 0; qi < all_queries.size(); qi++) { auto& orig_tokens = all_queries[qi]; size_t num_tokens_dropped = 0; - int syn_orig_num_token = (qi == 0) ? -1 : all_queries[0].size(); while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) { // When atleast two tokens from the query are available we can drop one @@ -2360,7 +2388,7 @@ void Index::search(std::vector& field_query_tokens, const std::v all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, - min_len_2typo, syn_orig_num_token, sort_order, field_values, geopoint_indices); + min_len_2typo, -1, sort_order, field_values, geopoint_indices); } else { break; @@ -2828,7 +2856,8 @@ void Index::search_across_fields(const std::vector& query_tokens, score_results2(sort_fields, searched_queries.size(), field_is_array, total_cost, field_match_score, seq_id, sort_order, group_limit, group_by_fields, - prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, token_postings); + prioritize_exact_match, single_exact_query_token, + query_tokens.size(), syn_orig_num_tokens, token_postings); if(field_match_score > max_field_match_score) { max_field_match_score = field_match_score; @@ -2861,7 +2890,10 @@ void Index::search_across_fields(const std::vector& query_tokens, (int64_t(the_fields[max_field_match_index].weight) << 0); /*LOG(INFO) << "seq_id: " << seq_id << ", query_tokens.size(): " << query_tokens.size() + << ", syn_orig_num_tokens: " << syn_orig_num_tokens << ", max_field_match_score: " << max_field_match_score + << ", max_field_match_index: " << max_field_match_index + << ", field_weight: " << the_fields[max_field_match_index].weight << ", aggregated_score: " << aggregated_score;*/ KV kv(0, searched_queries.size(), 0, seq_id, distinct_id, match_score_index, scores); @@ -3121,7 +3153,8 @@ void Index::do_synonym_search(const std::vector& the_fields, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, Topster* actual_topster, - std::vector& field_query_tokens, + std::vector>& q_pos_synonyms, + int syn_orig_num_tokens, spp::sparse_hash_set& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, @@ -3130,28 +3163,7 @@ void Index::do_synonym_search(const std::vector& the_fields, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices, - tsl::htrie_map& qtoken_set, - std::vector>& all_queries) const { - - std::vector q_include_tokens; - for(size_t j = 0; j < field_query_tokens[0].q_include_tokens.size(); j++) { - q_include_tokens.push_back(field_query_tokens[0].q_include_tokens[j].value); - } - synonym_index->synonym_reduction(q_include_tokens, field_query_tokens[0].q_synonyms); - - std::vector> q_pos_synonyms; - for(const auto& q_syn_vec: field_query_tokens[0].q_synonyms) { - std::vector q_pos_syn; - for(size_t j=0; j < q_syn_vec.size(); j++) { - bool is_prefix = (j == q_syn_vec.size()-1); - q_pos_syn.emplace_back(j, q_syn_vec[j], is_prefix, q_syn_vec[j].size(), 0); - } - q_pos_synonyms.push_back(q_pos_syn); - all_queries.push_back(q_pos_syn); - } - - int syn_orig_num_tokens = field_query_tokens[0].q_include_tokens.size(); - bool syn_wildcard_filter_init_done = false; + tsl::htrie_map& qtoken_set) const { for(const auto& syn_tokens: q_pos_synonyms) { query_hashes.clear(); @@ -3161,7 +3173,7 @@ void Index::do_synonym_search(const std::vector& the_fields, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, query_hashes, token_order, {0}, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, - min_len_2typo, q_include_tokens.size(), sort_order, field_values, geopoint_indices); + min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -3225,7 +3237,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& filters, score_results2(sort_fields, (uint16_t) searched_queries.size(), false, 0, match_score, seq_id, sort_order, group_limit, group_by_fields, false, - false, -1, plists); + false, 1, -1, plists); int64_t scores[3] = {0}; int64_t match_score_index = 0; @@ -3850,6 +3862,7 @@ int64_t Index::score_results2(const std::vector & sort_fields, const ui const size_t group_limit, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool single_exact_query_token, + size_t num_query_tokens, int syn_orig_num_tokens, const std::vector& posting_lists) const { @@ -3861,8 +3874,8 @@ int64_t Index::score_results2(const std::vector & sort_fields, const ui prioritize_exact_match && single_exact_query_token && posting_list_t::is_single_token_verbatim_match(posting_lists[0], field_is_array) ); - size_t words_present = (syn_orig_num_tokens == -1) ? 1 : syn_orig_num_tokens; - size_t distance = (syn_orig_num_tokens == -1) ? 0 : syn_orig_num_tokens-1; + size_t words_present = (num_query_tokens == 1 && syn_orig_num_tokens != -1) ? syn_orig_num_tokens : 1; + size_t distance = (num_query_tokens == 1 && syn_orig_num_tokens != -1) ? syn_orig_num_tokens-1 : 0; Match single_token_match = Match(words_present, distance, is_verbatim_match); match_score = single_token_match.get_match_score(total_cost, words_present); } else { @@ -3884,7 +3897,7 @@ int64_t Index::score_results2(const std::vector & sort_fields, const ui auto proximity = ((this_match_score >> 8) & 0xFF); auto verbatim = (this_match_score & 0xFF); - if(syn_orig_num_tokens != -1) { + if(syn_orig_num_tokens != -1 && num_query_tokens == posting_lists.size()) { unique_words = syn_orig_num_tokens; this_words_present = syn_orig_num_tokens; proximity = 100 - (syn_orig_num_tokens - 1); @@ -4179,8 +4192,6 @@ inline uint32_t Index::next_suggestion2(const std::vector& token q = ldiv(q.quot, token_candidates_vec[i].candidates.size()); const auto& candidate = token_candidates_vec[i].candidates[q.rem]; - bool exact_match = token_candidates_vec[i].cost == 0 && token_size == candidate.size(); - // we assume that toke was found via prefix search if candidate is longer than token's typo tolerance bool is_prefix_searched = token_candidates_vec[i].prefix_search && (candidate.size() > (token_size + token_candidates_vec[i].cost)); diff --git a/test/collection_synonyms_test.cpp b/test/collection_synonyms_test.cpp index 91426240..b0057b7e 100644 --- a/test/collection_synonyms_test.cpp +++ b/test/collection_synonyms_test.cpp @@ -375,11 +375,8 @@ TEST_F(CollectionSynonymsTest, SynonymQueryVariantWithDropTokens) { auto res = coll1->search("us sneakers", {"category", "location"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); ASSERT_EQ(3, res["hits"].size()); - // NOTE: "1" is ranked above "0" because synonym matches uses the root query's number of tokens for counting - // This means that "united states" == "us" so both records have 2 tokens matched, so tie breaking happens on points. - - ASSERT_EQ("1", res["hits"][0]["document"]["id"].get()); - ASSERT_EQ("0", res["hits"][1]["document"]["id"].get()); + ASSERT_EQ("0", res["hits"][0]["document"]["id"].get()); + ASSERT_EQ("1", res["hits"][1]["document"]["id"].get()); ASSERT_EQ("2", res["hits"][2]["document"]["id"].get()); collectionManager.drop_collection("coll1"); @@ -418,7 +415,7 @@ TEST_F(CollectionSynonymsTest, SynonymsTextMatchSameAsRootQuery) { ASSERT_TRUE(coll1->add(doc1.dump()).ok()); ASSERT_TRUE(coll1->add(doc2.dump()).ok()); - auto res = coll1->search("ceo", {"name", "title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get(); + auto res = coll1->search("ceo", {"name", "title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 0).get(); ASSERT_EQ(2, res["hits"].size()); ASSERT_EQ("1", res["hits"][0]["document"]["id"].get()); @@ -535,6 +532,59 @@ TEST_F(CollectionSynonymsTest, ExactMatchRankedSameAsSynonymMatch) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionSynonymsTest, ExactMatchVsSynonymMatchCrossFields) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("description", field_types::STRING, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Head of Marketing", "The Chief Marketing Officer", "100"}, + {"VP of Sales", "Preparing marketing and sales materials.", "120"}, + }; + + for(size_t i=0; iadd(doc.dump()).ok()); + } + + nlohmann::json syn_json = { + {"id", "syn-1"}, + {"synonyms", {"cmo", "Chief Marketing Officer", "VP of Marketing"}} + }; + + synonym_t synonym; + auto syn_op = synonym_t::parse(syn_json, synonym); + ASSERT_TRUE(syn_op.ok()); + + coll1->add_synonym(synonym); + + auto res = coll1->search("cmo", {"title", "description"}, "", {}, {}, + {0}, 10, 1, FREQUENCY, {false}, 0).get(); + + LOG(INFO) << res; + + ASSERT_EQ(2, res["hits"].size()); + ASSERT_EQ(2, res["found"].get()); + + ASSERT_EQ("0", res["hits"][0]["document"]["id"].get()); + ASSERT_EQ("1", res["hits"][1]["document"]["id"].get()); + + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionSynonymsTest, SynonymFieldOrdering) { // Synonym match on a field earlier in the fields list should rank above exact match of another field Collection *coll1;