diff --git a/include/index.h b/include/index.h
index d0f012a7..32c8b5ad 100644
--- a/include/index.h
+++ b/include/index.h
@@ -528,6 +528,7 @@ private:
                       size_t drop_tokens_threshold,
                       size_t typo_tokens_threshold,
                       bool exhaustive_search,
+                      int syn_orig_num_tokens,
                       size_t min_len_1typo,
                       size_t min_len_2typo,
                       size_t max_candidates) const;
@@ -548,6 +549,7 @@ private:
                            const std::vector<token_t>& query_tokens,
                            bool prioritize_exact_match,
                            bool exhaustive_search,
+                           int syn_orig_num_tokens,
                            size_t concurrency,
                            std::set<uint64_t>& query_hashes,
                            std::vector<uint32_t>& id_buff) const;
@@ -656,6 +658,7 @@ public:
                        const std::vector<std::string> &group_by_fields, uint32_t token_bits,
                        bool prioritize_exact_match, bool single_exact_query_token,
+                       int syn_orig_num_tokens,
                        const std::vector<posting_list_t::iterator_t>& posting_lists) const;
 
     static int64_t get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field);
diff --git a/src/collection.cpp b/src/collection.cpp
index 51e08cdc..3887a562 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -6,8 +6,6 @@
 #include
 #include
 #include
-#include
-#include
 #include
 #include
 #include
diff --git a/src/index.cpp b/src/index.cpp
index 1a441d39..0c666115 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -1114,6 +1114,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
                               const std::vector<token_t>& query_tokens,
                               bool prioritize_exact_match,
                               const bool exhaustive_search,
+                              int syn_orig_num_tokens,
                               const size_t concurrency,
                               std::set<uint64_t>& query_hashes,
                               std::vector<uint32_t>& id_buff) const {
@@ -1208,7 +1209,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
                           total_cost, topsters[index], query_suggestion, groups_processed_vec[index],
                           seq_id, sort_order, field_values, geopoint_indices,
                           group_limit, group_by_fields, token_bits,
-                          prioritize_exact_match, single_exact_query_token, its);
+                          prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, its);
 
             result_id_vecs[index].push_back(seq_id);
         }, concurrency);
@@ -1952,7 +1953,7 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
     search_field(0, window_tokens, search_tokens, nullptr, 0, num_toks_dropped,
                  field_it->second, field_name, nullptr, 0, {}, {}, 2, searched_queries, topster, groups_processed,
                  &result_ids, result_ids_len, field_num_results, 0, group_by_fields,
-                 false, 4, query_hashes, token_order, false, 0, 1, false, 3, 7, 4);
+                 false, 4, query_hashes, token_order, false, 0, 1, false, -1, 3, 7, 4);
 
     delete [] result_ids;
@@ -2536,7 +2537,7 @@ void Index::search_fields(const std::vector<filter>& filters,
                  field_num_typos,
                  searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                  field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency,
                  query_hashes, token_order, field_prefix,
-                 drop_tokens_threshold, typo_tokens_threshold, exhaustive_search,
+                 drop_tokens_threshold, typo_tokens_threshold, exhaustive_search, -1,
                  min_len_1typo, min_len_2typo, max_candidates);
 
     do_infix_search(sort_fields_std, searched_queries, group_limit, group_by_fields,
@@ -2568,8 +2569,7 @@ void Index::search_fields(const std::vector<filter>& filters,
                  field_num_typos,
                  searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                  field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency,
                  query_hashes, token_order, field_prefix,
-                 0, 0, exhaustive_search,
-                 min_len_1typo, min_len_2typo, max_candidates);
+                 0, 0, exhaustive_search, -1, min_len_1typo, min_len_2typo, max_candidates);
     }
 }
@@ -2588,12 +2588,14 @@ void Index::search_fields(const std::vector<filter>& filters,
         }
 
         // do synonym based searches
+        // since typos are disabled for synonym searches, we pass drop_tokens_threshold as the
+        // typo_tokens_threshold too; otherwise, dropping of tokens can't be supported here.
         do_synonym_search(filters, included_ids_map, sort_fields_std, curated_topster, token_order,
-                          drop_tokens_threshold, typo_tokens_threshold, group_limit,
+                          drop_tokens_threshold, drop_tokens_threshold, group_limit,
                           group_by_fields, prioritize_exact_match, exhaustive_search, concurrency,
                           min_len_1typo, min_len_2typo, max_candidates, curated_ids, curated_ids_sorted, exclude_token_ids,
-                          exclude_token_ids_size, i, actual_filter_ids_length, field_num_typos, field_prefix,
+                          exclude_token_ids_size, i, actual_filter_ids_length, 0, field_prefix,
                           field_id, field_name, field_it, query_tokens, search_tokens, num_tokens_dropped,
                           actual_topster, field_num_results, field_query_tokens, all_result_ids_len,
                           groups_processed, searched_queries,
@@ -2703,6 +2705,7 @@ void Index::do_synonym_search(const std::vector<filter>& filters,
             q_pos_synonyms.emplace_back(q_pos_syn);
         }
 
+        int syn_orig_num_tokens = field_query_tokens[i].q_include_tokens.size();
         bool syn_wildcard_filter_init_done = false;
 
         for(const auto& syn_tokens: q_pos_synonyms) {
@@ -2734,7 +2737,7 @@ void Index::do_synonym_search(const std::vector<filter>& filters,
                          field_num_typos,
                          searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                          field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency,
                          query_hashes, token_order, field_prefix,
-                         drop_tokens_threshold, typo_tokens_threshold, exhaustive_search,
+                         drop_tokens_threshold, typo_tokens_threshold, exhaustive_search, syn_orig_num_tokens,
                          min_len_1typo, min_len_2typo, max_candidates);
         }
     }
@@ -2792,7 +2795,7 @@ void Index::do_infix_search(const std::vector<sort_by>& sort_fields_std,
                 score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 2,
                               actual_topster, {}, groups_processed, seq_id, sort_order, field_values,
                               geopoint_indices, group_limit, group_by_fields, token_bits,
-                              false, false, {});
+                              false, false, -1, {});
             }
 
             uint32_t* new_all_result_ids = nullptr;
@@ -2928,7 +2931,7 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
                      facet_field, facet_field.faceted_name(), all_result_ids, all_result_ids_len, {}, {},
                      facet_query_num_typos, searched_queries, topster, groups_processed,
                      &field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields,
-                     false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, 3, 1000, 4);
+                     false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, -1, 3, 1000, 4);
 
         //LOG(INFO) << "searched_queries.size: " << searched_queries.size();
@@ -3100,7 +3103,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
             score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 0,
                           topsters[thread_id], {}, tgroups_processed[thread_id], seq_id, sort_order,
                           field_values, geopoint_indices, group_limit, group_by_fields, token_bits,
-                          false, false, plists);
+                          false, false, -1, plists);
 
             if(check_for_circuit_break && ((i + 1) % (1 << 15)) == 0) {
                 // check only once every 2^15 docs to reduce overhead
@@ -3201,6 +3204,7 @@ void Index::search_field(const uint8_t & field_id,
                          const size_t drop_tokens_threshold,
                          const size_t typo_tokens_threshold,
                          const bool exhaustive_search,
+                         int syn_orig_num_tokens,
                          size_t min_len_1typo,
                          size_t min_len_2typo,
                          const size_t max_candidates) const {
@@ -3339,7 +3343,8 @@
                               curated_ids, sort_fields, token_candidates_vec, searched_queries, topster,
                               groups_processed, all_result_ids, all_result_ids_len, field_num_results,
                               typo_tokens_threshold, group_limit, group_by_fields, query_tokens,
-                              prioritize_exact_match, combination_limit, concurrency, query_hashes, id_buff);
+                              prioritize_exact_match, exhaustive_search, syn_orig_num_tokens,
+                              concurrency, query_hashes, id_buff);
 
             if(id_buff.size() > 1) {
                 std::sort(id_buff.begin(), id_buff.end());
@@ -3395,7 +3400,7 @@ void Index::search_field(const uint8_t & field_id,
                          all_result_ids_len, field_num_results, group_limit, group_by_fields,
                          prioritize_exact_match, concurrency, query_hashes,
                          token_order, prefix, drop_tokens_threshold, typo_tokens_threshold,
-                         exhaustive_search, min_len_1typo, min_len_2typo, max_candidates);
+                         exhaustive_search, syn_orig_num_tokens, min_len_1typo, min_len_2typo, max_candidates);
     }
 }
@@ -3429,16 +3434,17 @@ void Index::log_leaves(const int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const {
 void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16_t & query_index,
                           const uint8_t & field_id, const bool field_is_array, const uint32_t total_cost,
-                          Topster* topster /**/,
+                          Topster* topster,
                           const std::vector<art_leaf *> &query_suggestion,
-                          spp::sparse_hash_set<uint64_t>& groups_processed /**/,
+                          spp::sparse_hash_set<uint64_t>& groups_processed,
                           const uint32_t seq_id, const int sort_order[3],
-                          std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values /**/,
+                          std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
                           const std::vector<size_t>& geopoint_indices,
                           const size_t group_limit, const std::vector<std::string>& group_by_fields,
                           const uint32_t token_bits,
                           const bool prioritize_exact_match, const bool single_exact_query_token,
+                          int syn_orig_num_tokens,
                           const std::vector<posting_list_t::iterator_t>& posting_lists) const {
 
     int64_t geopoint_distances[3];
@@ -3503,17 +3509,20 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16_t & query_index,
             prioritize_exact_match && single_exact_query_token &&
             posting_list_t::is_single_token_verbatim_match(posting_lists[0], field_is_array)
         );
-        Match single_token_match = Match(1, 0, is_verbatim_match);
-        match_score = single_token_match.get_match_score(total_cost, 1);
+        size_t words_present = (syn_orig_num_tokens == -1) ? 1 : syn_orig_num_tokens;
+        size_t distance = (syn_orig_num_tokens == -1) ? 0 : syn_orig_num_tokens - 1;
+        Match single_token_match = Match(words_present, distance, is_verbatim_match);
+        match_score = single_token_match.get_match_score(total_cost, words_present);
     } else {
-        uint64_t total_tokens_found = 0, total_num_typos = 0, total_distance = 0, total_verbatim = 0;
-
         std::unordered_map<size_t, std::vector<token_positions_t>> array_token_positions;
         posting_list_t::get_offsets(posting_lists, array_token_positions);
 
         // NOTE: tokens found returned by matcher is only within the best matched window, so we have to still consider
        // unique tokens found if they are spread across the text.
         uint32_t unique_tokens_found = __builtin_popcount(token_bits);
+        if(syn_orig_num_tokens != -1) {
+            unique_tokens_found = syn_orig_num_tokens;
+        }
 
         for (const auto& kv: array_token_positions) {
             const std::vector<token_positions_t>& token_positions = kv.second;
@@ -3524,10 +3533,27 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16_t & query_index,
 
             const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match);
             uint64_t this_match_score = match.get_match_score(total_cost, unique_tokens_found);
 
-            total_tokens_found += ((this_match_score >> 24) & 0xFF);
-            total_num_typos += 255 - ((this_match_score >> 16) & 0xFF);
-            total_distance += 100 - ((this_match_score >> 8) & 0xFF);
-            total_verbatim += (this_match_score & 0xFF);
+            auto this_words_present = ((this_match_score >> 24) & 0xFF);
+            auto typo_score = ((this_match_score >> 16) & 0xFF);
+            auto proximity = ((this_match_score >> 8) & 0xFF);
+            auto verbatim = (this_match_score & 0xFF);
+
+            if(syn_orig_num_tokens != -1) {
+                this_words_present = syn_orig_num_tokens;
+                proximity = 100 - (syn_orig_num_tokens - 1);
+            }
+
+            uint64_t mod_match_score = (
+                (int64_t(unique_tokens_found) << 32) |
+                (int64_t(this_words_present) << 24) |
+                (int64_t(typo_score) << 16) |
+                (int64_t(proximity) << 8) |
+                (int64_t(verbatim) << 0)
+            );
+
+            if(mod_match_score > match_score) {
+                match_score = mod_match_score;
+            }
 
             /*std::ostringstream os;
             os << name << ", total_cost: " << (255 - total_cost)
@@ -3537,21 +3563,6 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16_t & query_index,
                << ", seq_id: " << seq_id << std::endl;
             LOG(INFO) << os.str();*/
         }
-
-        match_score = (
-                (uint64_t(unique_tokens_found) << 32) |
-                (uint64_t(total_tokens_found) << 24) |
-                (uint64_t(255 - total_num_typos) << 16) |
-                (uint64_t(100 - total_distance) << 8) |
-                (uint64_t(total_verbatim) << 0)
-        );
-
-        /*LOG(INFO) << "Match score: " << match_score << ", for seq_id: " << seq_id
-                  << " - total_tokens_found: " << total_tokens_found
-                  << " - total_num_typos: " << total_num_typos
-                  << " - total_distance: " << total_distance
-                  << " - total_verbatim: " << total_verbatim
-                  << " - total_cost: " << total_cost;*/
     }
 
     const int64_t default_score = INT64_MIN;  // to handle field that doesn't exist in document (e.g. optional)
diff --git a/test/collection_specific_test.cpp b/test/collection_specific_test.cpp
index 125bb949..5ab05b28 100644
--- a/test/collection_specific_test.cpp
+++ b/test/collection_specific_test.cpp
@@ -2535,3 +2535,36 @@ TEST_F(CollectionSpecificTest, PhraseSearchOnLongText) {
 
     collectionManager.drop_collection("coll1");
 }
+
+TEST_F(CollectionSpecificTest, RepeatedTokensInArray) {
+    // all three docs must end up with the same text match score
+    std::vector<field> fields = {field("tags", field_types::STRING_ARRAY, false),};
+
+    Collection* coll1 = collectionManager.create_collection("coll1", 1, fields).get();
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["tags"] = {"Harry Mark"};
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["tags"] = {"Harry is random", "Harry Simpson"};
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["tags"] = {"Harry is Harry"};
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+
+    auto results = coll1->search("harry", {"tags"},
+                                 "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 10).get();
+
+    ASSERT_EQ(3, results["hits"].size());
+    ASSERT_EQ(results["hits"][0]["text_match"].get<size_t>(), results["hits"][1]["text_match"].get<size_t>());
+    ASSERT_EQ(results["hits"][1]["text_match"].get<size_t>(), results["hits"][2]["text_match"].get<size_t>());
+
+    collectionManager.drop_collection("coll1");
+}
+
diff --git a/test/collection_synonyms_test.cpp b/test/collection_synonyms_test.cpp
index 727af751..4506892c 100644
--- a/test/collection_synonyms_test.cpp
+++ b/test/collection_synonyms_test.cpp
@@ -426,8 +426,8 @@ TEST_F(CollectionSynonymsTest, MultiWaySynonym) {
     ASSERT_EQ(2, res["hits"].size());
    ASSERT_EQ(2, res["found"].get<uint32_t>());
 
-    ASSERT_STREQ("Samuel L. Jackson", res["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str());
-    ASSERT_STREQ("Samuel L. Jackson", res["hits"][1]["highlights"][0]["snippet"].get<std::string>().c_str());
+    ASSERT_STREQ("Samuel L. Jackson", res["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str());
+    ASSERT_STREQ("Samuel L. Jackson", res["hits"][1]["highlights"][0]["snippet"].get<std::string>().c_str());
 
     // for now we don't support synonyms on ANY prefix
 
@@ -624,4 +624,64 @@ TEST_F(CollectionSynonymsTest, SynonymSingleTokenExactMatch) {
     ASSERT_STREQ("2", res["hits"][0]["document"]["id"].get<std::string>().c_str());
 
     collectionManager.drop_collection("coll1");
-}
\ No newline at end of file
+}
+
+TEST_F(CollectionSynonymsTest, SynonymExpansionAndCompressionRanking) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    std::vector<std::vector<std::string>> records = {
+        {"Smashed Lemon", "100"},
+        {"Lulu Lemon", "100"},
+        {"Lululemon", "200"},
+    };
+
+    for(size_t i = 0; i < records.size(); i++) {
+        nlohmann::json doc;
+        doc["id"] = std::to_string(i);
+        doc["title"] = records[i][0];
+        doc["points"] = std::stoi(records[i][1]);
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    synonym_t synonym1{"syn-1", {"lululemon"}, {{"lulu", "lemon"}}};
+    coll1->add_synonym(synonym1);
+
+    auto res = coll1->search("lululemon", {"title"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+    ASSERT_EQ(2, res["hits"].size());
+    ASSERT_EQ(2, res["found"].get<uint32_t>());
+
+    // Even though "lulu lemon" is a two-token synonym match, it should get the same text match
+    // score as "lululemon", so the two hits are tied and then ranked on "points"
+    ASSERT_EQ("2", res["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
+
+    ASSERT_EQ(res["hits"][0]["text_match"].get<size_t>(), res["hits"][1]["text_match"].get<size_t>());
+
+    // now with compression synonym
+    synonym1.root = {"lulu", "lemon"};
+    synonym1.synonyms = {{"lululemon"}};
+    coll1->add_synonym(synonym1);
+
+    res = coll1->search("lulu lemon", {"title"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 0).get();
+    ASSERT_EQ(2, res["hits"].size());
+    ASSERT_EQ(2, res["found"].get<uint32_t>());
+
+    // Even though "lululemon" is a single-token synonym match, it should get the same text match
+    // score as "lulu lemon", so the two hits are tied and then ranked on "points"
+    ASSERT_EQ("2", res["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
+
+    ASSERT_EQ(res["hits"][0]["text_match"].get<size_t>(), res["hits"][1]["text_match"].get<size_t>());
+
+    collectionManager.drop_collection("coll1");
+}