Handle prefix expansion for the same field.

Kishore Nallan 2022-09-29 15:55:11 +05:30
parent e3a0eb60b0
commit 77af30ef93
3 changed files with 211 additions and 123 deletions
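In short: when the last query token is prefix-expanded, each candidate expansion is now validated against the documents that already contain the previous token, instead of being accepted blindly, so expansions that never co-occur with the previous token get trimmed. Below is a minimal, self-contained sketch of that co-occurrence idea; the data structures and names are hypothetical stand-ins, not Typesense's actual ART tree or posting_list_t API.

// Sketch only: hypothetical in-memory index, not Typesense's real structures.
#include <cstddef>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

using DocIds = std::vector<uint32_t>;               // sorted doc ids
using FieldIndex = std::map<std::string, DocIds>;   // token -> sorted doc ids, one map per field

// True if `candidate` appears in at least one of `prev_docs` within this field
// (the analogue of posting_t::contains_atleast_one in this commit).
static bool co_occurs(const FieldIndex& field_index, const std::string& candidate,
                      const DocIds& prev_docs) {
    auto it = field_index.find(candidate);
    if (it == field_index.end()) {
        return false;
    }
    const DocIds& docs = it->second;
    size_t i = 0, j = 0;
    while (i < docs.size() && j < prev_docs.size()) {   // two-pointer walk over sorted ids
        if (docs[i] == prev_docs[j]) {
            return true;
        }
        if (docs[i] < prev_docs[j]) { i++; } else { j++; }
    }
    return false;
}

// Keep only prefix candidates that share a document with the previous token
// in at least one searched field, capped at max_candidates.
std::set<std::string> trim_candidates(const std::vector<FieldIndex>& fields,
                                      const std::vector<std::string>& candidates,
                                      const DocIds& prev_token_doc_ids,
                                      size_t max_candidates) {
    std::set<std::string> trimmed;
    for (const auto& cand : candidates) {
        for (const auto& field_index : fields) {
            if (co_occurs(field_index, cand, prev_token_doc_ids)) {
                trimmed.insert(cand);
                break;
            }
        }
        if (trimmed.size() == max_candidates) break;
    }
    return trimmed;
}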

View File

@@ -354,6 +354,13 @@ private:
uint32_t& token_bits,
uint64& qhash);
static bool is_valid_token_prefix(const std::vector<search_field_t>& the_fields, size_t field_id,
const unsigned char* token_c_str, size_t token_len,
const std::vector<uint32_t>& num_typos, const std::vector<bool>& prefixes,
size_t token_num_typos, bool token_prefix,
const spp::sparse_hash_map<std::string, art_tree*>& search_index,
const std::vector<uint32_t>& prev_token_doc_ids);
void log_leaves(int cost, const std::string &token, const std::vector<art_leaf *> &leaves) const;
void do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
@@ -797,16 +804,16 @@ public:
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
const std::vector<size_t>& geopoint_indices) const;
void find_across_fields(const std::vector<token_t>& query_tokens,
const size_t num_query_tokens,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
const uint32_t* filter_ids, uint32_t filter_ids_length,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
std::vector<uint32_t>& id_buff) const;
void find_across_fields(const token_t& previous_token,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
const uint32_t* filter_ids, uint32_t filter_ids_length,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
std::vector<uint32_t>& prev_token_doc_ids,
std::vector<size_t>& top_prefix_field_ids) const;
void search_across_fields(const std::vector<token_t>& query_tokens,
const std::vector<uint32_t>& num_typos,

View File

@@ -1265,6 +1265,38 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) {
}
}
bool Index::is_valid_token_prefix(const std::vector<search_field_t>& the_fields, size_t field_id,
const unsigned char* token_c_str, size_t token_len,
const std::vector<uint32_t>& num_typos, const std::vector<bool>& prefixes,
size_t token_num_typos, bool token_prefix,
const spp::sparse_hash_map<std::string, art_tree*>& search_index,
const std::vector<uint32_t>& prev_token_doc_ids) {
const std::string& field_name = the_fields[field_id].name;
const uint32_t field_num_typos = (field_id < num_typos.size()) ? num_typos[field_id] : num_typos[0];
const bool field_prefix = (field_id < prefixes.size()) ? prefixes[field_id] : prefixes[0];
if (token_num_typos > field_num_typos) {
// since the token can come from any field, we still have to respect per-field num_typos
return false;
}
if (token_prefix && !field_prefix) {
// even though this token is an outcome of prefix search, we can't use it for this field, since
// this field has prefix search disabled.
return false;
}
art_tree* tree = search_index.at(field_name);
art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token_len));
if (!leaf) {
return false;
}
return posting_t::contains_atleast_one(leaf->values, &prev_token_doc_ids[0], prev_token_doc_ids.size());
}
void Index::search_all_candidates(const size_t num_search_fields,
const std::vector<search_field_t>& the_fields,
const uint32_t* filter_ids, size_t filter_ids_length,
@@ -1301,18 +1333,20 @@ void Index::search_all_candidates(const size_t num_search_fields,
std::set<std::string> trimmed_candidates;
if(token_candidates_vec.size() > 1 && token_candidates_vec.back().candidates.size() > max_candidates) {
std::vector<uint32_t> temp_ids;
if(token_candidates_vec.size() >= 2 && token_candidates_vec.back().candidates.size() > max_candidates) {
std::vector<uint32_t> prev_token_doc_ids; // documents that contain the previous token across fields
std::vector<size_t> top_prefix_field_ids; // fields that contain the previous token in the most documents
find_across_fields(query_tokens, query_tokens.size()-1, num_typos, prefixes, the_fields, num_search_fields,
find_across_fields(query_tokens[query_tokens.size()-2], num_typos, prefixes, the_fields, num_search_fields,
filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size,
temp_ids);
prev_token_doc_ids, top_prefix_field_ids);
//LOG(INFO) << "temp_ids found: " << temp_ids.size();
//LOG(INFO) << "prev_token_doc_ids found: " << prev_token_doc_ids.size();
std::unordered_set<size_t> processed_field_ids;
for(auto& token_str: token_candidates_vec.back().candidates) {
//LOG(INFO) << "Prefix token: " << token_str;
const bool prefix_search = query_tokens.back().is_prefix_searched;
const uint32_t token_num_typos = query_tokens.back().num_typos;
const bool token_prefix = (token_str.size() > token_candidates_vec.back().token.value.size());
@@ -1321,46 +1355,51 @@ void Index::search_all_candidates(const size_t num_search_fields,
const size_t token_len = token_str.size() + 1;
for(size_t i = 0; i < num_search_fields; i++) {
const std::string& field_name = the_fields[i].name;
const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0];
const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0];
if (token_num_typos > field_num_typos) {
// since the token can come from any field, we still have to respect per-field num_typos
continue;
}
if (token_prefix && !field_prefix) {
// even though this token is an outcome of prefix search, we can't use it for this field, since
// this field has prefix search disabled.
continue;
}
art_tree* tree = search_index.at(field_name);
art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token_len));
if (!leaf) {
continue;
}
bool found_atleast_one = posting_t::contains_atleast_one(leaf->values, &temp_ids[0],
temp_ids.size());
if(!found_atleast_one) {
if(!is_valid_token_prefix(the_fields, i, token_c_str, token_len, num_typos, prefixes,
token_num_typos, token_prefix, search_index, prev_token_doc_ids)) {
continue;
}
trimmed_candidates.insert(token_str);
//LOG(INFO) << "Pushing back found token_str: " << token_str;
/*LOG(INFO) << "max_candidates: " << max_candidates << ", trimmed_candidates.size(): "
<< trimmed_candidates.size();*/
processed_field_ids.insert(i);
if(trimmed_candidates.size() == max_candidates) {
goto outer_loop;
goto outer1;
}
}
}
outer_loop:
outer1:
for(auto& token_str: token_candidates_vec.back().candidates) {
//LOG(INFO) << "Prefix token: " << token_str;
const bool prefix_search = query_tokens.back().is_prefix_searched;
const uint32_t token_num_typos = query_tokens.back().num_typos;
const bool token_prefix = (token_str.size() > token_candidates_vec.back().token.value.size());
auto token_c_str = (const unsigned char*) token_str.c_str();
const size_t token_len = token_str.size() + 1;
size_t top_fields_processed = 0;
for(auto field_id: top_prefix_field_ids) {
if(processed_field_ids.count(field_id) != 0) {
continue;
}
if(!is_valid_token_prefix(the_fields, field_id, token_c_str, token_len, num_typos, prefixes,
token_num_typos, token_prefix, search_index, prev_token_doc_ids)) {
continue;
}
top_fields_processed++;
trimmed_candidates.insert(token_str);
if(top_fields_processed == 3) {
// limit to only 3 unprocessed top fields
goto outer2;
}
}
}
outer2:
if(trimmed_candidates.empty()) {
return ;
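Taken together, the two loops above implement a two-pass trimming: pass one walks all fields in declared order until max_candidates is reached, and pass two revisits only the highest-yield fields that pass one never accepted a candidate from, bailing out via outer2 once three such fields validate. A standalone sketch of that control flow, with the per-field validity check abstracted into a caller-supplied predicate (hypothetical signature, not the Index API):

// Sketch of the two-pass candidate trimming; hypothetical, self-contained.
#include <cstddef>
#include <functional>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

std::set<std::string> two_pass_trim(
        const std::vector<std::string>& candidates,
        size_t num_fields,
        const std::vector<size_t>& top_field_ids,
        size_t max_candidates,
        const std::function<bool(const std::string&, size_t)>& is_valid) {
    std::set<std::string> trimmed;
    std::unordered_set<size_t> processed_field_ids;

    // pass 1: walk every field in declared order until max_candidates is hit
    for (const auto& tok : candidates) {
        for (size_t i = 0; i < num_fields; i++) {
            if (!is_valid(tok, i)) continue;
            trimmed.insert(tok);
            processed_field_ids.insert(i);
            if (trimmed.size() == max_candidates) return trimmed; // like `goto outer1`
        }
    }

    // pass 2: retry only top fields that never validated a candidate in pass 1,
    // stopping entirely once three of them succeed (like `goto outer2`)
    for (const auto& tok : candidates) {
        size_t top_fields_processed = 0;
        for (size_t field_id : top_field_ids) {
            if (processed_field_ids.count(field_id) != 0) continue;
            if (!is_valid(tok, field_id)) continue;
            top_fields_processed++;
            trimmed.insert(tok);
            if (top_fields_processed == 3) return trimmed;
        }
    }
    return trimmed;
}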
@@ -2987,7 +3026,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
// NOTE: `query_tokens` preserve original tokens, while `search_tokens` could be a result of dropped tokens
// To prevent us from doing ART search repeatedly as we iterate through possible corrections
spp::sparse_hash_map<std::string, std::vector<art_leaf*>> token_cost_cache;
spp::sparse_hash_map<std::string, std::vector<std::string>> token_cost_cache;
std::vector<std::vector<int>> token_to_costs;
@@ -3041,10 +3080,10 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const std::string& token = query_tokens[token_index].value;
const std::string token_cost_hash = token + std::to_string(costs[token_index]);
std::vector<art_leaf*> leaves;
std::vector<std::string> leaf_tokens;
if(token_cost_cache.count(token_cost_hash) != 0) {
leaves = token_cost_cache[token_cost_hash];
leaf_tokens = token_cost_cache[token_cost_hash];
} else {
//auto begin = std::chrono::high_resolution_clock::now();
@@ -3071,41 +3110,38 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
continue;
}
size_t max_words = 100000;
//LOG(INFO) << "Searching for field: " << the_field.name << ", found token:" << token;
std::vector<art_leaf*> field_leaves;
int max_words = 100000;
art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len,
costs[token_index], costs[token_index], max_words, token_order, prefix_search,
filter_ids, filter_ids_length, leaves, unique_tokens);
filter_ids, filter_ids_length, field_leaves, unique_tokens);
/*auto timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin).count();
LOG(INFO) << "Time taken for fuzzy search: " << timeMillis << "ms";*/
if(leaves.empty()) {
if(field_leaves.empty()) {
// look at the next field
continue;
}
token_cost_cache.emplace(token_cost_hash, leaves);
for(auto leaf: leaves) {
for(size_t i = 0; i < field_leaves.size(); i++) {
auto leaf = field_leaves[i];
std::string tok(reinterpret_cast<char*>(leaf->key), leaf->key_len - 1);
unique_tokens.emplace(tok);
if(unique_tokens.count(tok) == 0) {
unique_tokens.emplace(tok);
leaf_tokens.push_back(tok);
}
}
token_cost_cache.emplace(token_cost_hash, leaf_tokens);
}
}
if(!leaves.empty()) {
if(!leaf_tokens.empty()) {
//log_leaves(costs[token_index], token, leaves);
std::vector<std::string> leaf_tokens;
std::unordered_set<std::string> leaf_token_set;
for(auto leaf: leaves) {
std::string ltok(reinterpret_cast<char*>(leaf->key), leaf->key_len - 1);
if(leaf_token_set.count(ltok) == 0) {
leaf_tokens.push_back(ltok);
leaf_token_set.insert(ltok);
}
}
token_candidates_vec.push_back(tok_candidates{query_tokens[token_index], costs[token_index],
query_tokens[token_index].is_prefix_searched, leaf_tokens});
} else {
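The cache's value type changes here from art_leaf pointers to deduplicated token strings, so a (token, cost) pair computed once can be reused without holding leaf pointers across fields. A generic sketch of this memoize-and-dedupe pattern, using std::unordered_map in place of spp::sparse_hash_map and a caller-supplied per-field search function (hypothetical, not fuzzy_search_fields' real internals):

// Sketch: memoize deduped result tokens per (token, cost) key. Hypothetical API.
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using TokenCache = std::unordered_map<std::string, std::vector<std::string>>;

std::vector<std::string> cached_candidates(
        TokenCache& cache, const std::string& token, int cost,
        const std::function<std::vector<std::string>(const std::string&, int)>& search_fields) {
    const std::string key = token + std::to_string(cost);  // same key scheme as token_cost_hash
    auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;                                  // cache hit: reuse deduped strings
    }

    std::vector<std::string> leaf_tokens;
    std::unordered_set<std::string> unique_tokens;
    for (const std::string& tok : search_fields(token, cost)) {
        if (unique_tokens.insert(tok).second) {             // dedupe across fields
            leaf_tokens.push_back(tok);
        }
    }
    cache.emplace(key, leaf_tokens);
    return leaf_tokens;
}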
@@ -3167,15 +3203,15 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
}
}
void Index::find_across_fields(const std::vector<token_t>& query_tokens,
const size_t num_query_tokens,
void Index::find_across_fields(const token_t& previous_token,
const std::vector<uint32_t>& num_typos,
const std::vector<bool>& prefixes,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
const uint32_t* filter_ids, uint32_t filter_ids_length,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
std::vector<uint32_t>& id_buff) const {
std::vector<uint32_t>& prev_token_doc_ids,
std::vector<size_t>& top_prefix_field_ids) const {
// one iterator for each token, each underlying iterator contains results of token across multiple fields
std::vector<or_iterator_t> token_its;
@@ -3185,68 +3221,75 @@ void Index::find_across_fields(const std::vector<token_t>& query_tokens,
result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length);
// for each token, find the posting lists across all query_by fields
for(size_t ti = 0; ti < num_query_tokens; ti++) {
const bool prefix_search = query_tokens[ti].is_prefix_searched;
const uint32_t token_num_typos = query_tokens[ti].num_typos;
const bool token_prefix = query_tokens[ti].is_prefix_searched;
const bool prefix_search = previous_token.is_prefix_searched;
const uint32_t token_num_typos = previous_token.num_typos;
const bool token_prefix = previous_token.is_prefix_searched;
auto& token_str = query_tokens[ti].value;
auto token_c_str = (const unsigned char*) token_str.c_str();
const size_t token_len = token_str.size() + 1;
std::vector<posting_list_t::iterator_t> its;
auto& token_str = previous_token.value;
auto token_c_str = (const unsigned char*) token_str.c_str();
const size_t token_len = token_str.size() + 1;
std::vector<posting_list_t::iterator_t> its;
for(size_t i = 0; i < num_search_fields; i++) {
const std::string& field_name = the_fields[i].name;
const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0];
const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0];
std::vector<std::pair<size_t, size_t>> field_id_doc_counts;
if(token_num_typos > field_num_typos) {
// since the token can come from any field, we still have to respect per-field num_typos
continue;
}
for(size_t i = 0; i < num_search_fields; i++) {
const std::string& field_name = the_fields[i].name;
const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0];
const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0];
if(token_prefix && !field_prefix) {
// even though this token is an outcome of prefix search, we can't use it for this field, since
// this field has prefix search disabled.
continue;
}
art_tree* tree = search_index.at(field_name);
art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token_len));
if(!leaf) {
continue;
}
/*LOG(INFO) << "Token: " << token_str << ", field_name: " << field_name
<< ", num_ids: " << posting_t::num_ids(leaf->values);*/
if(IS_COMPACT_POSTING(leaf->values)) {
auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values);
posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list();
expanded_plists.push_back(full_posting_list);
its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
} else {
posting_list_t* full_posting_list = (posting_list_t*)(leaf->values);
its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
}
}
if(its.empty()) {
// this token does not have any match across *any* field: probably a typo
LOG(INFO) << "No matching field found for token: " << token_str;
if(token_num_typos > field_num_typos) {
// since the token can come from any field, we still have to respect per-field num_typos
continue;
}
or_iterator_t token_fields(its);
token_its.push_back(std::move(token_fields));
if(token_prefix && !field_prefix) {
// even though this token is an outcome of prefix search, we can't use it for this field, since
// this field has prefix search disabled.
continue;
}
art_tree* tree = search_index.at(field_name);
art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token_len));
if(!leaf) {
continue;
}
/*LOG(INFO) << "Token: " << token_str << ", field_name: " << field_name
<< ", num_ids: " << posting_t::num_ids(leaf->values);*/
if(IS_COMPACT_POSTING(leaf->values)) {
auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values);
posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list();
expanded_plists.push_back(full_posting_list);
its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
} else {
posting_list_t* full_posting_list = (posting_list_t*)(leaf->values);
its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
}
field_id_doc_counts.emplace_back(i, posting_t::num_ids(leaf->values));
}
if(its.empty()) {
// this token does not have any match across *any* field: probably a typo
LOG(INFO) << "No matching field found for token: " << token_str;
return;
}
std::sort(field_id_doc_counts.begin(), field_id_doc_counts.end(), [](const auto& p1, const auto& p2) {
return p1.second > p2.second;
});
for(auto& field_id_doc_count: field_id_doc_counts) {
top_prefix_field_ids.push_back(field_id_doc_count.first);
}
or_iterator_t token_fields(its);
token_its.push_back(std::move(token_fields));
or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector<or_iterator_t>& its) {
// Convert [token -> fields] orientation to [field -> tokens] orientation
//LOG(INFO) << "seq_id: " << seq_id;
id_buff.push_back(seq_id);
prev_token_doc_ids.push_back(seq_id);
});
for(posting_list_t* plist: expanded_plists) {
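The reworked find_across_fields also records, per field, how many documents contain the previous token, then sorts those counts descending to produce top_prefix_field_ids for the second trimming pass. A standalone sketch of just that ranking step:

// Sketch: rank field ids by how many documents contain the previous token,
// descending (mirrors the field_id_doc_counts sort in this commit).
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

std::vector<size_t> rank_fields_by_doc_count(
        std::vector<std::pair<size_t, size_t>> field_id_doc_counts) {
    // each pair: (field_id, number of documents containing the token in that field)
    std::sort(field_id_doc_counts.begin(), field_id_doc_counts.end(),
              [](const auto& p1, const auto& p2) { return p1.second > p2.second; });

    std::vector<size_t> top_field_ids;
    top_field_ids.reserve(field_id_doc_counts.size());
    for (const auto& fc : field_id_doc_counts) {
        top_field_ids.push_back(fc.first);
    }
    return top_field_ids;
}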

View File

@@ -114,6 +114,44 @@ TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnSingleField) {
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
}
TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnMultiField) {
Collection *coll1;
std::vector<field> fields = {field("location", field_types::STRING, false),
field("name", field_types::STRING, false),
field("points", field_types::INT32, false)};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
std::vector<std::string> names = {
"John Stewart", "John Smith", "John Scott", "John Stone", "John Romero", "John Oliver", "John Adams"
};
std::vector<std::string> locations = {
"Switzerland", "Seoul", "Sydney", "Surat", "Stockholm", "Salem", "Sevilla"
};
for(size_t i = 0; i < names.size(); i++) {
nlohmann::json doc;
doc["location"] = locations[i];
doc["name"] = names[i];
doc["points"] = i;
coll1->add(doc.dump());
}
auto results = coll1->search("john s", {"location", "name"}, "", {}, {}, {0}, 100, 1, MAX_SCORE, {true}).get();
// tokens are ordered by max_score, but prefix continuation on the same field should be prioritized
ASSERT_EQ(7, results["hits"].size());
ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][3]["document"]["id"].get<std::string>());
ASSERT_EQ("6", results["hits"][4]["document"]["id"].get<std::string>());
}
TEST_F(CollectionSpecificMoreTest, ArrayElementMatchShouldBeMoreImportantThanTotalMatch) {
std::vector<field> fields = {field("title", field_types::STRING, false),
field("author", field_types::STRING, false),