Fix proper prefix expansion.

Kishore Nallan 2022-05-19 16:12:33 +05:30
parent c481920979
commit d935cb4041
5 changed files with 192 additions and 39 deletions

View File

@@ -501,8 +501,6 @@ private:
static inline uint32_t next_suggestion2(const std::vector<tok_candidates>& token_candidates_vec,
long long int n,
std::vector<token_t>& query_suggestion,
int syn_orig_num_tokens,
uint32_t& token_bits,
uint64& qhash);
static inline uint32_t next_suggestion(const std::vector<token_candidates> &token_candidates_vec,
@@ -680,6 +678,7 @@ public:
enum {COMBINATION_MAX_LIMIT = 10000};
enum {COMBINATION_MIN_LIMIT = 10};
enum {MAX_CANDIDATES_DEFAULT = 4};
// If the number of results found is less than this threshold, Typesense will attempt to drop the tokens
// in the query that have the least individual hits one by one until enough results are found.
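The drop-tokens comment above describes a retry loop; here is a minimal sketch of one drop step, under the assumption that each token carries its individual hit count (`qtoken` and `drop_weakest_token` are hypothetical names, not Typesense's API):

```cpp
#include <algorithm>
#include <string>
#include <vector>

struct qtoken {
    std::string value;
    size_t num_hits;   // individual hits for this token alone
};

// drop the token with the fewest individual hits; the caller re-runs the
// search with the returned tokens until enough results are found
std::vector<qtoken> drop_weakest_token(std::vector<qtoken> tokens) {
    if(tokens.size() <= 1) {
        return {};  // nothing sensible left to drop
    }
    auto weakest = std::min_element(tokens.begin(), tokens.end(),
        [](const qtoken& a, const qtoken& b) { return a.num_hits < b.num_hits; });
    tokens.erase(weakest);
    return tokens;
}
```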
@@ -939,6 +938,17 @@ public:
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
const std::vector<size_t>& geopoint_indices) const;
void find_across_fields(const std::vector<token_t>& query_tokens,
const size_t num_query_tokens,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
const uint32_t* filter_ids, uint32_t filter_ids_length,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
std::vector<uint32_t>& id_buff) const;
void search_across_fields(const std::vector<token_t>& query_tokens,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,

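The newly declared `find_across_fields` (above) gives the index a way to answer: which documents match all of the first N query tokens, where each token may match in any searched field? A compressed sketch of that orientation, using plain sorted id vectors instead of Typesense's posting-list iterators (all names here are illustrative):

```cpp
#include <cstdint>
#include <set>
#include <vector>

using posting_ids = std::vector<uint32_t>;        // sorted doc ids for one field
using token_postings = std::vector<posting_ids>;  // one list per searched field

std::vector<uint32_t> find_across_fields_sketch(const std::vector<token_postings>& tokens) {
    // OR: a document matches a token if any field's posting list contains it
    std::vector<std::set<uint32_t>> per_token;
    for(const auto& fields : tokens) {
        std::set<uint32_t> ids;
        for(const auto& plist : fields) {
            ids.insert(plist.begin(), plist.end());
        }
        per_token.push_back(std::move(ids));
    }

    // AND: a document matches the query if it matches every token
    std::vector<uint32_t> result;
    if(per_token.empty()) {
        return result;
    }
    for(uint32_t id : per_token[0]) {
        bool in_all = true;
        for(size_t t = 1; t < per_token.size(); t++) {
            if(per_token[t].count(id) == 0) {
                in_all = false;
                break;
            }
        }
        if(in_all) {
            result.push_back(id);
        }
    }
    return result;
}
```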
View File

@@ -866,7 +866,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
}
if(!max_candidates) {
max_candidates = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::COMBINATION_MIN_LIMIT;
max_candidates = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::MAX_CANDIDATES_DEFAULT;
}
Option<nlohmann::json> result_op = collection->search(raw_query, search_fields, simple_filter_query, facet_fields,

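The change above means a non-exhaustive search that leaves `max_candidates` unset now expands each token to 4 candidates instead of 10. A hedged restatement with the enum values inlined (`effective_max_candidates` is an illustrative helper, not code from this commit):

```cpp
#include <cstddef>

size_t effective_max_candidates(bool exhaustive_search, size_t max_candidates) {
    if(max_candidates != 0) {
        return max_candidates;  // an explicit value always wins
    }
    // new fallback: COMBINATION_MAX_LIMIT (10000) when exhaustive,
    // MAX_CANDIDATES_DEFAULT (4) otherwise -- previously COMBINATION_MIN_LIMIT (10)
    return exhaustive_search ? 10000 : 4;
}
```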
View File

@@ -1167,26 +1167,76 @@ void Index::search_all_candidates(const size_t num_search_fields,
auto product = []( long long a, tok_candidates & b ) { return a*b.candidates.size(); };
long long int N = std::accumulate(token_candidates_vec.begin(), token_candidates_vec.end(), 1LL, product);
// escape hatch to prevent too much looping but subject to being overridden explicitly via `max_candidates`
long long combination_limit = std::max<size_t>(Index::COMBINATION_MIN_LIMIT, max_candidates);
if(token_candidates_vec.size() > 1 && token_candidates_vec.back().candidates.size() > max_candidates) {
std::vector<std::string> trimmed_candidates;
std::vector<uint32_t> temp_ids;
find_across_fields(query_tokens, query_tokens.size()-1, num_typos, prefixes, the_fields, num_search_fields,
filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size,
temp_ids);
for(auto& token_str: token_candidates_vec.back().candidates) {
const bool prefix_search = query_tokens.back().is_prefix_searched;
const uint32_t token_num_typos = query_tokens.back().num_typos;
const bool token_prefix = query_tokens.back().is_prefix_searched;
auto token_c_str = (const unsigned char*) token_str.c_str();
const size_t token_len = token_str.size() + 1;
std::vector<posting_list_t::iterator_t> its;
for(size_t i = 0; i < num_search_fields; i++) {
const std::string& field_name = the_fields[i].name;
const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0];
const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0];
if (token_num_typos > field_num_typos) {
// since the token can come from any field, we still have to respect per-field num_typos
continue;
}
if (token_prefix && !field_prefix) {
// even though this token is an outcome of prefix search, we can't use it for this field, since
// this field has prefix search disabled.
continue;
}
art_tree* tree = search_index.at(field_name);
art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token_len));
if (!leaf) {
continue;
}
bool found_atleast_one = posting_t::contains_atleast_one(leaf->values, &temp_ids[0],
temp_ids.size());
if(!found_atleast_one) {
continue;
}
trimmed_candidates.push_back(token_str);
if(trimmed_candidates.size() == max_candidates) {
break;
}
}
}
if(trimmed_candidates.empty()) {
return;
}
token_candidates_vec.back().candidates = std::move(trimmed_candidates);
}
for(long long n = 0; n < N && n < combination_limit; ++n) {
RETURN_CIRCUIT_BREAKER
// every element in `query_suggestion` contains a token and its associated hits
std::vector<token_t> query_suggestion(token_candidates_vec.size());
uint64 qhash;
uint32_t token_bits = 0;
uint32_t total_cost = next_suggestion2(token_candidates_vec, n, query_suggestion, syn_orig_num_tokens,
token_bits, qhash);
if(query_hashes.find(qhash) != query_hashes.end()) {
// skip this query since it has already been processed before
continue;
}
query_hashes.insert(qhash);
uint32_t total_cost = next_suggestion2(token_candidates_vec, n, query_suggestion, qhash);
//LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
//LOG(INFO) << "n: " << n;
@@ -1198,12 +1248,20 @@ void Index::search_all_candidates(const size_t num_search_fields,
sort_order, field_values, geopoint_indices,
id_buff, all_result_ids, all_result_ids_len);
if(query_hashes.find(qhash) != query_hashes.end()) {
// skip this query since it has already been processed before
continue;
}
query_hashes.insert(qhash);
/*std::stringstream fullq;
for(const auto& qtok : query_suggestion) {
fullq << qtok.value << " ";
}
LOG(INFO) << "query: " << fullq.str() << ", total_cost: " << total_cost << ", num: " << all_result_ids_len;*/
LOG(INFO) << "query: " << fullq.str() << ", total_cost: " << total_cost
<< ", all_result_ids_len: " << all_result_ids_len << ", bufsiz: " << id_buff.size();*/
}
}
@@ -1238,9 +1296,10 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
populate_sort_mapping(sort_order, geopoint_indices, sort_fields, field_values);
// escape hatch to prevent too much looping
size_t combination_limit = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::COMBINATION_MIN_LIMIT;
for(long long n=0; n<N && n<combination_limit; ++n) {
for (long long n = 0; n < N && n < combination_limit; ++n) {
RETURN_CIRCUIT_BREAKER
// every element in `query_suggestion` contains a token and its associated hits
@@ -2648,7 +2707,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
long long n = 0;
long long int N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product);
const long long combination_limit = std::max<size_t>(Index::COMBINATION_MIN_LIMIT, max_candidates);
const long long combination_limit = exhaustive_search ? Index::COMBINATION_MAX_LIMIT : Index::COMBINATION_MIN_LIMIT;
while(n < N && n < combination_limit) {
RETURN_CIRCUIT_BREAKER
@@ -2699,9 +2758,10 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
continue;
}
size_t max_words = (num_search_fields == 1 && prefix_search) ? max_candidates : 100000;
// need fewer candidates for filtered searches since we already only pick tokens with results
art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len,
costs[token_index], costs[token_index], max_candidates, token_order, prefix_search,
costs[token_index], costs[token_index], max_words, token_order, prefix_search,
filter_ids, filter_ids_length, leaves, unique_tokens);
/*auto timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
@@ -2718,10 +2778,6 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
std::string tok(reinterpret_cast<char*>(leaf->key), leaf->key_len - 1);
unique_tokens.emplace(tok);
}
if(unique_tokens.size() > max_candidates) {
break;
}
}
}
@@ -2792,6 +2848,93 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
}
}
void Index::find_across_fields(const std::vector<token_t>& query_tokens,
const size_t num_query_tokens,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
const uint32_t* filter_ids, uint32_t filter_ids_length,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
std::vector<uint32_t>& id_buff) const {
// one iterator for each token, each underlying iterator contains results of token across multiple fields
std::vector<or_iterator_t> token_its;
// used to track plists that must be destructed once done
std::vector<posting_list_t*> expanded_plists;
result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length);
// for each token, find the posting lists across all query_by fields
for(size_t ti = 0; ti < num_query_tokens; ti++) {
const bool prefix_search = query_tokens[ti].is_prefix_searched;
const uint32_t token_num_typos = query_tokens[ti].num_typos;
const bool token_prefix = query_tokens[ti].is_prefix_searched;
auto& token_str = query_tokens[ti].value;
auto token_c_str = (const unsigned char*) token_str.c_str();
const size_t token_len = token_str.size() + 1;
std::vector<posting_list_t::iterator_t> its;
for(size_t i = 0; i < num_search_fields; i++) {
const std::string& field_name = the_fields[i].name;
const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0];
const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0];
if(token_num_typos > field_num_typos) {
// since the token can come from any field, we still have to respect per-field num_typos
continue;
}
if(token_prefix && !field_prefix) {
// even though this token is an outcome of prefix search, we can't use it for this field, since
// this field has prefix search disabled.
continue;
}
art_tree* tree = search_index.at(field_name);
art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token_len));
if(!leaf) {
continue;
}
/*LOG(INFO) << "Token: " << token_str << ", field_name: " << field_name
<< ", num_ids: " << posting_t::num_ids(leaf->values);*/
if(IS_COMPACT_POSTING(leaf->values)) {
auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values);
posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list();
expanded_plists.push_back(full_posting_list);
its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
} else {
posting_list_t* full_posting_list = (posting_list_t*)(leaf->values);
its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
}
}
if(its.empty()) {
// this token does not have any match across *any* field: probably a typo
LOG(INFO) << "No matching field found for token: " << token_str;
continue;
}
or_iterator_t token_fields(its);
token_its.push_back(std::move(token_fields));
}
or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector<or_iterator_t>& its) {
// Convert [token -> fields] orientation to [field -> tokens] orientation
//LOG(INFO) << "seq_id: " << seq_id;
id_buff.push_back(seq_id);
});
for(posting_list_t* plist: expanded_plists) {
delete plist;
}
}
void Index::search_across_fields(const std::vector<token_t>& query_tokens,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,
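Within the field loop above, a token is only matched against fields whose limits could have produced it. A standalone sketch of that eligibility gate (`field_limits` is an illustrative struct, not the index's real type):

```cpp
#include <cstdint>

struct field_limits {
    uint32_t num_typos;   // per-field typo budget
    bool prefix_enabled;  // per-field prefix-search flag
};

bool field_eligible(uint32_t token_num_typos, bool token_is_prefix,
                    const field_limits& f) {
    if(token_num_typos > f.num_typos) {
        return false;  // token needed more typos than this field permits
    }
    if(token_is_prefix && !f.prefix_enabled) {
        return false;  // prefix-expanded token, but field disables prefix search
    }
    return true;
}
```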
@@ -2886,8 +3029,8 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
std::vector<uint32_t> result_ids;
or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector<or_iterator_t>& its) {
// Convert [token -> fields] orientation to [field -> tokens] orientation
//LOG(INFO) << "seq_id: " << seq_id;
// Convert [token -> fields] orientation to [field -> tokens] orientation
std::vector<std::vector<posting_list_t::iterator_t>> field_to_tokens(num_search_fields);
for(size_t ti = 0; ti < its.size(); ti++) {
@@ -4256,15 +4399,13 @@ uint64_t Index::get_distinct_id(const std::vector<std::string>& group_by_fields,
inline uint32_t Index::next_suggestion2(const std::vector<tok_candidates>& token_candidates_vec,
long long int n,
std::vector<token_t>& query_suggestion,
int syn_orig_num_tokens,
uint32_t& token_bits,
uint64& qhash) {
uint32_t total_cost = 0;
qhash = 1;
// generate the next combination from `token_leaves` and store it in `query_suggestion`
ldiv_t q { n, 0 };
for(long long i = 0 ; i < (long long) token_candidates_vec.size(); i++) {
for(size_t i = 0 ; i < token_candidates_vec.size(); i++) {
size_t token_size = token_candidates_vec[i].token.value.size();
q = ldiv(q.quot, token_candidates_vec[i].candidates.size());
const auto& candidate = token_candidates_vec[i].candidates[q.rem];
@@ -4276,8 +4417,6 @@ inline uint32_t Index::next_suggestion2(const std::vector<tok_candidates>& token
size_t actual_cost = (2 * token_candidates_vec[i].cost) + uint32_t(is_prefix_searched);
total_cost += actual_cost;
token_bits |= 1UL << token_candidates_vec[i].token.position; // sets n-th bit
query_suggestion[i] = token_t(i, candidate, is_prefix_searched, token_size, token_candidates_vec[i].cost);
uint64_t this_hash = StringUtils::hash_wy(query_suggestion[i].value.c_str(), query_suggestion[i].value.size());
@@ -4288,13 +4427,6 @@ inline uint32_t Index::next_suggestion2(const std::vector<tok_candidates>& token
LOG(INFO) << ".";*/
}
if(syn_orig_num_tokens != -1) {
token_bits = 0;
for(size_t i = 0; i < size_t(syn_orig_num_tokens); i++) {
token_bits |= 1UL << i;
}
}
return total_cost;
}
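`next_suggestion2` (above) treats the combination index `n` as a mixed-radix number: the i-th division peels off a digit whose base is the i-th token's candidate count. A self-contained sketch of the same decoding (using `lldiv` to match the `long long` index):

```cpp
#include <cstdlib>
#include <string>
#include <vector>

std::vector<std::string> decode_combination(
        long long n, const std::vector<std::vector<std::string>>& candidates) {
    std::vector<std::string> picked(candidates.size());
    lldiv_t q;
    q.quot = n;  // the part of `n` not yet consumed
    for(size_t i = 0; i < candidates.size(); i++) {
        q = lldiv(q.quot, (long long) candidates[i].size());
        picked[i] = candidates[i][q.rem];  // remainder selects token i's candidate
    }
    return picked;
}
```

With candidates {{"a","b"}, {"x","y","z"}}, n = 0..5 yields (a,x), (b,x), (a,y), (b,y), (a,z), (b,z): the first token's digit varies fastest, so all N = 2 * 3 suggestions are enumerated exactly once.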

View File

@@ -221,6 +221,7 @@ TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField)
}
}
// max_candidates at its default of 4
auto results = coll1->search("e", {"title"}, "", {}, {}, {0}, 100, 1, NOT_SET, {true}).get();
// [11 + 10 + 9 + 8] + 7 + 6 + 5 + 4 + 3 + 2
@@ -234,6 +235,16 @@ TEST_F(CollectionSortingTest, FrequencyOrderedTokensWithoutDefaultSortingField)
}
}
// 2 candidates
results = coll1->search("e", {"title"}, "", {}, {}, {0}, 100, 1, NOT_SET, {true},
0, spp::sparse_hash_set<std::string>(), spp::sparse_hash_set<std::string>(),
10, "", 30, 4, "title", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
false, 2).get();
// [11 + 10] + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2
ASSERT_EQ(21, results["found"].get<size_t>());
ASSERT_FALSE(found_end);
}
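The bracketed numbers in the comments above are per-token doc counts: for the query "e", prefix expansion keeps only the `max_candidates` most frequent matching tokens, so with `max_candidates` of 2 only the top two survive and the test asserts 11 + 10 = 21 results. A sketch of that frequency-ordered selection (the pair layout is an assumption, not the index's real structure):

```cpp
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

// keep the `max_candidates` most frequent tokens matching a prefix
std::vector<std::string> top_prefix_expansions(
        std::vector<std::pair<std::string, size_t>> matches,  // (token, doc count)
        size_t max_candidates) {
    std::sort(matches.begin(), matches.end(),
              [](const auto& a, const auto& b) { return a.second > b.second; });
    std::vector<std::string> out;
    for(size_t i = 0; i < matches.size() && i < max_candidates; i++) {
        out.push_back(matches[i].first);
    }
    return out;
}
```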

View File

@@ -1040,7 +1040,7 @@ TEST_F(CollectionTest, KeywordQueryReturnsResultsBasedOnPerPageParam) {
FREQUENCY, {true}, 1000, empty, empty, 10).get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ(6, results["found"].get<int>());
ASSERT_EQ(7, results["found"].get<int>());
// cannot fetch more than in-built limit of 250
auto res_op = coll_mul_fields->search("w", query_fields, "", facets, sort_fields, {0}, 251, 1,
@@ -1062,13 +1062,13 @@ TEST_F(CollectionTest, KeywordQueryReturnsResultsBasedOnPerPageParam) {
FREQUENCY, {true}, 1000, empty, empty, 10).get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ(6, results["found"].get<int>());
ASSERT_EQ(7, results["found"].get<int>());
results = coll_mul_fields->search("w", query_fields, "", facets, sort_fields, {0}, 3, 2,
FREQUENCY, {true}, 1000, empty, empty, 10).get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ(6, results["found"].get<int>());
ASSERT_EQ(7, results["found"].get<int>());
collectionManager.drop_collection("coll_mul_fields");
}