diff --git a/src/art.cpp b/src/art.cpp index 0b47efa2..5ba4aea0 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1531,7 +1531,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, std::sort(results.begin(), results.end(), compare_art_leaf_score); } - if(exact_leaf) { + if(exact_leaf && min_cost == 0) { results.insert(results.begin(), exact_leaf); } diff --git a/src/index.cpp b/src/index.cpp index c4bdcb04..26a64855 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1170,22 +1170,33 @@ void Index::search_all_candidates(const size_t num_search_fields, // escape hatch to prevent too much looping but subject to being overriden explicitly via `max_candidates` long long combination_limit = std::max(Index::COMBINATION_MIN_LIMIT, max_candidates); + /*if(!token_candidates_vec.empty()) { + LOG(INFO) << "Prefix candidates size: " << token_candidates_vec.back().candidates.size(); + LOG(INFO) << "max_candidates: " << max_candidates; + LOG(INFO) << "combination_limit: " << combination_limit; + LOG(INFO) << "token_candidates_vec.size(): " << token_candidates_vec.size(); + }*/ + + std::unordered_set<std::string> trimmed_candidates; + if(token_candidates_vec.size() > 1 && token_candidates_vec.back().candidates.size() > max_candidates) { - std::vector<std::string> trimmed_candidates; std::vector<uint32_t> temp_ids; find_across_fields(query_tokens, query_tokens.size()-1, num_typos, prefixes, the_fields, num_search_fields, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size, temp_ids); + //LOG(INFO) << "temp_ids found: " << temp_ids.size(); + for(auto& token_str: token_candidates_vec.back().candidates) { + //LOG(INFO) << "Prefix token: " << token_str; + const bool prefix_search = query_tokens.back().is_prefix_searched; const uint32_t token_num_typos = query_tokens.back().num_typos; const bool token_prefix = query_tokens.back().is_prefix_searched; auto token_c_str = (const unsigned char*) token_str.c_str(); const size_t token_len = token_str.size() + 1; - std::vector<or_iterator_t> 
its; for(size_t i = 0; i < num_search_fields; i++) { const std::string& field_name = the_fields[i].name; @@ -1216,18 +1227,31 @@ void Index::search_all_candidates(const size_t num_search_fields, continue; } - trimmed_candidates.push_back(token_str); + trimmed_candidates.insert(token_str); + //LOG(INFO) << "Pushing back found token_str: " << token_str; + /*LOG(INFO) << "max_candidates: " << max_candidates << ", trimmed_candidates.size(): " + << trimmed_candidates.size();*/ + if(trimmed_candidates.size() == max_candidates) { - break; + goto outer_loop; } } } + outer_loop: + if(trimmed_candidates.empty()) { return ; } - token_candidates_vec.back().candidates = std::move(trimmed_candidates); + /*LOG(INFO) << "Final trimmed_candidates.size: " << trimmed_candidates.size(); + for(const auto& trimmed_candidate: trimmed_candidates) { + LOG(INFO) << "trimmed_candidate: " << trimmed_candidate; + }*/ + + token_candidates_vec.back().candidates.clear(); + token_candidates_vec.back().candidates.assign(trimmed_candidates.begin(), trimmed_candidates.end()); + } for(long long n = 0; n < N && n < combination_limit; ++n) { @@ -1238,8 +1262,22 @@ void Index::search_all_candidates(const size_t num_search_fields, uint64 qhash; uint32_t total_cost = next_suggestion2(token_candidates_vec, n, query_suggestion, qhash); + /*LOG(INFO) << "n: " << n; + std::stringstream fullq; + for(const auto& qtok : query_suggestion) { + fullq << qtok.value << " "; + } + LOG(INFO) << "query: " << fullq.str() << ", total_cost: " << total_cost + << ", all_result_ids_len: " << all_result_ids_len << ", bufsiz: " << id_buff.size();*/ + + if(query_hashes.find(qhash) != query_hashes.end()) { + // skip this query since it has already been processed before + //LOG(INFO) << "Skipping qhash " << qhash; + continue; + } + //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold; - //LOG(INFO) << "n: " << n; + search_across_fields(query_suggestion, num_typos, 
prefixes, the_fields, num_search_fields, sort_fields, topster,groups_processed, searched_queries, qtoken_set, group_limit, group_by_fields, prioritize_exact_match, @@ -1248,20 +1286,7 @@ void Index::search_all_candidates(const size_t num_search_fields, sort_order, field_values, geopoint_indices, id_buff, all_result_ids, all_result_ids_len); - if(query_hashes.find(qhash) != query_hashes.end()) { - // skip this query since it has already been processed before - continue; - } - query_hashes.insert(qhash); - - /*std::stringstream fullq; - for(const auto& qtok : query_suggestion) { - fullq << qtok.value << " "; - } - - LOG(INFO) << "query: " << fullq.str() << ", total_cost: " << total_cost - << ", all_result_ids_len: " << all_result_ids_len << ", bufsiz: " << id_buff.size();*/ } } @@ -2712,6 +2737,8 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields, while(n < N && n < combination_limit) { RETURN_CIRCUIT_BREAKER + //LOG(INFO) << "fuzzy_search_fields, n: " << n; + // Outerloop generates combinations of [cost to max_cost] for each token // For e.g. for a 3-token query: [0, 0, 0], [0, 0, 1], [0, 1, 1] etc. 
std::vector<uint32_t> costs(token_to_costs.size()); @@ -2784,10 +2811,16 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields, if(!leaves.empty()) { //log_leaves(costs[token_index], token, leaves); std::vector<std::string> leaf_tokens; + std::unordered_set<std::string> leaf_token_set; + for(auto leaf: leaves) { std::string ltok(reinterpret_cast<char*>(leaf->key), leaf->key_len - 1); - leaf_tokens.push_back(ltok); + if(leaf_token_set.count(ltok) == 0) { + leaf_tokens.push_back(ltok); + leaf_token_set.insert(ltok); + } } + token_candidates_vec.push_back(tok_candidates{query_tokens[token_index], costs[token_index], query_tokens[token_index].is_prefix_searched, leaf_tokens}); } else { diff --git a/test/art_test.cpp b/test/art_test.cpp index c7550bf8..4f30b042 100644 --- a/test/art_test.cpp +++ b/test/art_test.cpp @@ -683,6 +683,10 @@ TEST(ArtTest, test_art_fuzzy_search_prefix_token_ordering) { std::string third_key(reinterpret_cast<char*>(leaves[2]->key), leaves[2]->key_len - 1); ASSERT_EQ("elephant", third_key); + leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "enter", 5, 1, 1, 3, MAX_SCORE, true, nullptr, 0, leaves); + ASSERT_TRUE(leaves.empty()); + res = art_tree_destroy(&t); ASSERT_TRUE(res == 0); } diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index c75859fc..44f08453 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -55,3 +55,33 @@ TEST_F(CollectionSpecificMoreTest, MaxCandidatesShouldBeRespected) { ASSERT_EQ(200, results["found"].get<size_t>()); collectionManager.drop_collection("coll1"); } + +TEST_F(CollectionSpecificMoreTest, PrefixExpansionWhenExactMatchExists) { + std::vector<field> fields = {field("title", field_types::STRING, false), + field("author", field_types::STRING, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields).get(); + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["title"] = "The Little Prince [by] Antoine de Saint Exupéry : teacher guide"; + 
doc1["author"] = "Barbara Valdez"; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["title"] = "Little Prince"; + doc2["author"] = "Antoine de Saint-Exupery"; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + + auto results = coll1->search("little prince antoine saint", {"title", "author"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 1, spp::sparse_hash_set<std::string>(), + spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 5, {}, {}, {}, 0, + "", "", {}, 1000, true).get(); + + ASSERT_EQ(2, results["hits"].size()); + collectionManager.drop_collection("coll1"); +}