From d935cb4041c86f7196a700186caab960e5d6d157 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 19 May 2022 16:12:33 +0530 Subject: [PATCH] Fix proper prefix expansion. --- include/index.h | 14 ++- src/collection_manager.cpp | 2 +- src/index.cpp | 198 +++++++++++++++++++++++++------ test/collection_sorting_test.cpp | 11 ++ test/collection_test.cpp | 6 +- 5 files changed, 192 insertions(+), 39 deletions(-) diff --git a/include/index.h b/include/index.h index eb8c0316..e339ff9b 100644 --- a/include/index.h +++ b/include/index.h @@ -501,8 +501,6 @@ private: static inline uint32_t next_suggestion2(const std::vector& token_candidates_vec, long long int n, std::vector& query_suggestion, - int syn_orig_num_tokens, - uint32_t& token_bits, uint64& qhash); static inline uint32_t next_suggestion(const std::vector &token_candidates_vec, @@ -680,6 +678,7 @@ public: enum {COMBINATION_MAX_LIMIT = 10000}; enum {COMBINATION_MIN_LIMIT = 10}; + enum {MAX_CANDIDATES_DEFAULT = 4}; // If the number of results found is less than this threshold, Typesense will attempt to drop the tokens // in the query that have the least individual hits one by one until enough results are found. @@ -939,6 +938,17 @@ public: std::array*, 3>& field_values, const std::vector& geopoint_indices) const; + void find_across_fields(const std::vector& query_tokens, + const size_t num_query_tokens, + const std::vector& num_typos, + const std::vector& prefixes, + const std::vector& the_fields, + const size_t num_search_fields, + const uint32_t* filter_ids, uint32_t filter_ids_length, + const uint32_t* exclude_token_ids, + size_t exclude_token_ids_size, + std::vector& id_buff) const; + void search_across_fields(const std::vector& query_tokens, const std::vector& num_typos, const std::vector& prefixes, diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index aa7b0ddd..2b98aa8c 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -866,7 +866,7 @@ Option CollectionManager::do_search(std::map& re } if(!max_candidates) { - max_candidates = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::COMBINATION_MIN_LIMIT; + max_candidates = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::MAX_CANDIDATES_DEFAULT; } Option result_op = collection->search(raw_query, search_fields, simple_filter_query, facet_fields, diff --git a/src/index.cpp b/src/index.cpp index 3580fe30..c4bdcb04 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1167,26 +1167,76 @@ void Index::search_all_candidates(const size_t num_search_fields, auto product = []( long long a, tok_candidates & b ) { return a*b.candidates.size(); }; long long int N = std::accumulate(token_candidates_vec.begin(), token_candidates_vec.end(), 1LL, product); + // escape hatch to prevent too much looping but subject to being overriden explicitly via `max_candidates` long long combination_limit = std::max(Index::COMBINATION_MIN_LIMIT, max_candidates); + if(token_candidates_vec.size() > 1 && token_candidates_vec.back().candidates.size() > max_candidates) { + std::vector trimmed_candidates; + std::vector temp_ids; + + find_across_fields(query_tokens, query_tokens.size()-1, num_typos, prefixes, the_fields, num_search_fields, + filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size, + temp_ids); + + for(auto& token_str: token_candidates_vec.back().candidates) { + const bool prefix_search = query_tokens.back().is_prefix_searched; + const uint32_t token_num_typos = query_tokens.back().num_typos; + const bool token_prefix = query_tokens.back().is_prefix_searched; + + auto token_c_str = (const unsigned char*) token_str.c_str(); + const size_t token_len = token_str.size() + 1; + std::vector its; + + for(size_t i = 0; i < num_search_fields; i++) { + const std::string& field_name = the_fields[i].name; + const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0]; + const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0]; + + if (token_num_typos > field_num_typos) { + // since the token can come from any field, we still have to respect per-field num_typos + continue; + } + + if (token_prefix && !field_prefix) { + // even though this token is an outcome of prefix search, we can't use it for this field, since + // this field has prefix search disabled. + continue; + } + + art_tree* tree = search_index.at(field_name); + art_leaf* leaf = static_cast(art_search(tree, token_c_str, token_len)); + + if (!leaf) { + continue; + } + + bool found_atleast_one = posting_t::contains_atleast_one(leaf->values, &temp_ids[0], + temp_ids.size()); + if(!found_atleast_one) { + continue; + } + + trimmed_candidates.push_back(token_str); + if(trimmed_candidates.size() == max_candidates) { + break; + } + } + } + + if(trimmed_candidates.empty()) { + return ; + } + + token_candidates_vec.back().candidates = std::move(trimmed_candidates); + } + for(long long n = 0; n < N && n < combination_limit; ++n) { RETURN_CIRCUIT_BREAKER - // every element in `query_suggestion` contains a token and its associated hits std::vector query_suggestion(token_candidates_vec.size()); uint64 qhash; - - uint32_t token_bits = 0; - uint32_t total_cost = next_suggestion2(token_candidates_vec, n, query_suggestion, syn_orig_num_tokens, - token_bits, qhash); - - if(query_hashes.find(qhash) != query_hashes.end()) { - // skip this query since it has already been processed before - continue; - } - - query_hashes.insert(qhash); + uint32_t total_cost = next_suggestion2(token_candidates_vec, n, query_suggestion, qhash); //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold; //LOG(INFO) << "n: " << n; @@ -1198,12 +1248,20 @@ void Index::search_all_candidates(const size_t num_search_fields, sort_order, field_values, geopoint_indices, id_buff, all_result_ids, all_result_ids_len); + if(query_hashes.find(qhash) != query_hashes.end()) { + // skip this query since it has already been processed before + continue; + } + + query_hashes.insert(qhash); + /*std::stringstream fullq; for(const auto& qtok : query_suggestion) { fullq << qtok.value << " "; } - LOG(INFO) << "query: " << fullq.str() << ", total_cost: " << total_cost << ", num: " << all_result_ids_len;*/ + LOG(INFO) << "query: " << fullq.str() << ", total_cost: " << total_cost + << ", all_result_ids_len: " << all_result_ids_len << ", bufsiz: " << id_buff.size();*/ } } @@ -1238,9 +1296,10 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, populate_sort_mapping(sort_order, geopoint_indices, sort_fields, field_values); + // escape hatch to prevent too much looping size_t combination_limit = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::COMBINATION_MIN_LIMIT; - for(long long n=0; n& the_fields, long long n = 0; long long int N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product); - const long long combination_limit = std::max(Index::COMBINATION_MIN_LIMIT, max_candidates); + const long long combination_limit = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::COMBINATION_MIN_LIMIT; while(n < N && n < combination_limit) { RETURN_CIRCUIT_BREAKER @@ -2699,9 +2758,10 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, continue; } + size_t max_words = (num_search_fields == 1 && prefix_search) ? max_candidates : 100000; // need less candidates for filtered searches since we already only pick tokens with results art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, - costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, + costs[token_index], costs[token_index], max_words, token_order, prefix_search, filter_ids, filter_ids_length, leaves, unique_tokens); /*auto timeMillis = std::chrono::duration_cast( @@ -2718,10 +2778,6 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); unique_tokens.emplace(tok); } - - if(unique_tokens.size() > max_candidates) { - break; - } } } @@ -2792,6 +2848,93 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, } } +void Index::find_across_fields(const std::vector& query_tokens, + const size_t num_query_tokens, + const std::vector& num_typos, + const std::vector& prefixes, + const std::vector& the_fields, + const size_t num_search_fields, + const uint32_t* filter_ids, uint32_t filter_ids_length, + const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, + std::vector& id_buff) const { + + // one iterator for each token, each underlying iterator contains results of token across multiple fields + std::vector token_its; + + // used to track plists that must be destructed once done + std::vector expanded_plists; + + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length); + + // for each token, find the posting lists across all query_by fields + for(size_t ti = 0; ti < num_query_tokens; ti++) { + const bool prefix_search = query_tokens[ti].is_prefix_searched; + const uint32_t token_num_typos = query_tokens[ti].num_typos; + const bool token_prefix = query_tokens[ti].is_prefix_searched; + + auto& token_str = query_tokens[ti].value; + auto token_c_str = (const unsigned char*) token_str.c_str(); + const size_t token_len = token_str.size() + 1; + std::vector its; + + for(size_t i = 0; i < num_search_fields; i++) { + const std::string& field_name = the_fields[i].name; + const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0]; + const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0]; + + if(token_num_typos > field_num_typos) { + // since the token can come from any field, we still have to respect per-field num_typos + continue; + } + + if(token_prefix && !field_prefix) { + // even though this token is an outcome of prefix search, we can't use it for this field, since + // this field has prefix search disabled. + continue; + } + + art_tree* tree = search_index.at(field_name); + art_leaf* leaf = static_cast(art_search(tree, token_c_str, token_len)); + + if(!leaf) { + continue; + } + + /*LOG(INFO) << "Token: " << token_str << ", field_name: " << field_name + << ", num_ids: " << posting_t::num_ids(leaf->values);*/ + + if(IS_COMPACT_POSTING(leaf->values)) { + auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values); + posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list(); + expanded_plists.push_back(full_posting_list); + its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied + } else { + posting_list_t* full_posting_list = (posting_list_t*)(leaf->values); + its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied + } + } + + if(its.empty()) { + // this token does not have any match across *any* field: probably a typo + LOG(INFO) << "No matching field found for token: " << token_str; + continue; + } + + or_iterator_t token_fields(its); + token_its.push_back(std::move(token_fields)); + } + + or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector& its) { + // Convert [token -> fields] orientation to [field -> tokens] orientation + //LOG(INFO) << "seq_id: " << seq_id; + id_buff.push_back(seq_id); + }); + + for(posting_list_t* plist: expanded_plists) { + delete plist; + } +} + void Index::search_across_fields(const std::vector& query_tokens, const std::vector& num_typos, const std::vector& prefixes, @@ -2886,8 +3029,8 @@ void Index::search_across_fields(const std::vector& query_tokens, std::vector result_ids; or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector& its) { - // Convert [token -> fields] orientation to [field -> tokens] orientation //LOG(INFO) << "seq_id: " << seq_id; + // Convert [token -> fields] orientation to [field -> tokens] orientation std::vector> field_to_tokens(num_search_fields); for(size_t ti = 0; ti < its.size(); ti++) { @@ -4256,15 +4399,13 @@ uint64_t Index::get_distinct_id(const std::vector& group_by_fields, inline uint32_t Index::next_suggestion2(const std::vector& token_candidates_vec, long long int n, std::vector& query_suggestion, - int syn_orig_num_tokens, - uint32_t& token_bits, uint64& qhash) { uint32_t total_cost = 0; qhash = 1; // generate the next combination from `token_leaves` and store it in `query_suggestion` ldiv_t q { n, 0 }; - for(long long i = 0 ; i < (long long) token_candidates_vec.size(); i++) { + for(size_t i = 0 ; i < token_candidates_vec.size(); i++) { size_t token_size = token_candidates_vec[i].token.value.size(); q = ldiv(q.quot, token_candidates_vec[i].candidates.size()); const auto& candidate = token_candidates_vec[i].candidates[q.rem]; @@ -4276,8 +4417,6 @@ inline uint32_t Index::next_suggestion2(const std::vector& token size_t actual_cost = (2 * token_candidates_vec[i].cost) + uint32_t(is_prefix_searched); total_cost += actual_cost; - token_bits |= 1UL << token_candidates_vec[i].token.position; // sets n-th bit - query_suggestion[i] = token_t(i, candidate, is_prefix_searched, token_size, token_candidates_vec[i].cost); uint64_t this_hash = StringUtils::hash_wy(query_suggestion[i].value.c_str(), query_suggestion[i].value.size()); @@ -4288,13 +4427,6 @@ inline uint32_t Index::next_suggestion2(const std::vector& token LOG(INFO) << ".";*/ } - if(syn_orig_num_tokens != -1) { - token_bits = 0; - for(size_t i = 0; i < size_t(syn_orig_num_tokens); i++) { - token_bits |= 1UL << i; - } - } - return total_cost; } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index bf19e5e1..1b8fa703 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -221,6 +221,7 @@ TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField) } } + // max candidates as default 4 auto results = coll1->search("e", {"title"}, "", {}, {}, {0}, 100, 1, NOT_SET, {true}).get(); // [11 + 10 + 9 + 8] + 7 + 6 + 5 + 4 + 3 + 2 @@ -234,6 +235,16 @@ TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField) } } + // 2 candidates + results = coll1->search("e", {"title"}, "", {}, {}, {0}, 100, 1, NOT_SET, {true}, + 0, spp::sparse_hash_set(), spp::sparse_hash_set(), + 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + false, 2).get(); + + // [11 + 10] + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2 + ASSERT_EQ(21, results["found"].get()); + ASSERT_FALSE(found_end); } diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 3609d5f0..ed757fbc 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -1040,7 +1040,7 @@ TEST_F(CollectionTest, KeywordQueryReturnsResultsBasedOnPerPageParam) { FREQUENCY, {true}, 1000, empty, empty, 10).get(); ASSERT_EQ(3, results["hits"].size()); - ASSERT_EQ(6, results["found"].get()); + ASSERT_EQ(7, results["found"].get()); // cannot fetch more than in-built limit of 250 auto res_op = coll_mul_fields->search("w", query_fields, "", facets, sort_fields, {0}, 251, 1, @@ -1062,13 +1062,13 @@ TEST_F(CollectionTest, KeywordQueryReturnsResultsBasedOnPerPageParam) { FREQUENCY, {true}, 1000, empty, empty, 10).get(); ASSERT_EQ(3, results["hits"].size()); - ASSERT_EQ(6, results["found"].get()); + ASSERT_EQ(7, results["found"].get()); results = coll_mul_fields->search("w", query_fields, "", facets, sort_fields, {0}, 3, 2, FREQUENCY, {true}, 1000, empty, empty, 10).get(); ASSERT_EQ(3, results["hits"].size()); - ASSERT_EQ(6, results["found"].get()); + ASSERT_EQ(7, results["found"].get()); collectionManager.drop_collection("coll_mul_fields"); }