From 77af30ef936089d28b96bdc83bd46a00afe8ebd3 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 29 Sep 2022 15:55:11 +0530 Subject: [PATCH] Handle prefix expansion for the same field. --- include/index.h | 27 ++- src/index.cpp | 269 ++++++++++++++----------- test/collection_specific_more_test.cpp | 38 ++++ 3 files changed, 211 insertions(+), 123 deletions(-) diff --git a/include/index.h b/include/index.h index bc4f909f..e0fdc8b2 100644 --- a/include/index.h +++ b/include/index.h @@ -354,6 +354,13 @@ private: uint32_t& token_bits, uint64& qhash); + static bool is_valid_token_prefix(const std::vector& the_fields, size_t field_id, + const unsigned char* token_c_str, size_t token_len, + const std::vector& num_typos, const std::vector& prefixes, + size_t token_num_typos, bool token_prefix, + const spp::sparse_hash_map& search_index, + const std::vector& prev_token_doc_ids); + void log_leaves(int cost, const std::string &token, const std::vector &leaves) const; void do_facets(std::vector & facets, facet_query_t & facet_query, @@ -797,16 +804,16 @@ public: std::array*, 3>& field_values, const std::vector& geopoint_indices) const; - void find_across_fields(const std::vector& query_tokens, - const size_t num_query_tokens, - const std::vector& num_typos, - const std::vector& prefixes, - const std::vector& the_fields, - const size_t num_search_fields, - const uint32_t* filter_ids, uint32_t filter_ids_length, - const uint32_t* exclude_token_ids, - size_t exclude_token_ids_size, - std::vector& id_buff) const; + void find_across_fields(const token_t& previous_token, + const std::vector& num_typos, + const std::vector& prefixes, + const std::vector& the_fields, + const size_t num_search_fields, + const uint32_t* filter_ids, uint32_t filter_ids_length, + const uint32_t* exclude_token_ids, + size_t exclude_token_ids_size, + std::vector& prev_token_doc_ids, + std::vector& top_prefix_field_ids) const; void search_across_fields(const std::vector& query_tokens, const std::vector& num_typos, diff --git a/src/index.cpp b/src/index.cpp index ab9dbae1..854d5856 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1265,6 +1265,38 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { } } +bool Index::is_valid_token_prefix(const std::vector& the_fields, size_t field_id, + const unsigned char* token_c_str, size_t token_len, + const std::vector& num_typos, const std::vector& prefixes, + size_t token_num_typos, bool token_prefix, + const spp::sparse_hash_map& search_index, + const std::vector& prev_token_doc_ids) { + + const std::string& field_name = the_fields[field_id].name; + const uint32_t field_num_typos = (field_id < num_typos.size()) ? num_typos[field_id] : num_typos[0]; + const bool field_prefix = (field_id < prefixes.size()) ? prefixes[field_id] : prefixes[0]; + + if (token_num_typos > field_num_typos) { + // since the token can come from any field, we still have to respect per-field num_typos + return false; + } + + if (token_prefix && !field_prefix) { + // even though this token is an outcome of prefix search, we can't use it for this field, since + // this field has prefix search disabled. + return false; + } + + art_tree* tree = search_index.at(field_name); + art_leaf* leaf = static_cast(art_search(tree, token_c_str, token_len)); + + if (!leaf) { + return false; + } + + return posting_t::contains_atleast_one(leaf->values, &prev_token_doc_ids[0], prev_token_doc_ids.size()); +} + void Index::search_all_candidates(const size_t num_search_fields, const std::vector& the_fields, const uint32_t* filter_ids, size_t filter_ids_length, @@ -1301,18 +1333,20 @@ void Index::search_all_candidates(const size_t num_search_fields, std::set trimmed_candidates; - if(token_candidates_vec.size() > 1 && token_candidates_vec.back().candidates.size() > max_candidates) { - std::vector temp_ids; + if(token_candidates_vec.size() >= 2 && token_candidates_vec.back().candidates.size() > max_candidates) { + std::vector prev_token_doc_ids; // documents that contain the previous token across fields + std::vector top_prefix_field_ids; // fields which contained the token the most across documents - find_across_fields(query_tokens, query_tokens.size()-1, num_typos, prefixes, the_fields, num_search_fields, + find_across_fields(query_tokens[query_tokens.size()-2], num_typos, prefixes, the_fields, num_search_fields, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size, - temp_ids); + prev_token_doc_ids, top_prefix_field_ids); - //LOG(INFO) << "temp_ids found: " << temp_ids.size(); + //LOG(INFO) << "prev_token_doc_ids found: " << prev_token_doc_ids.size(); + + std::unordered_set processed_field_ids; for(auto& token_str: token_candidates_vec.back().candidates) { //LOG(INFO) << "Prefix token: " << token_str; - const bool prefix_search = query_tokens.back().is_prefix_searched; const uint32_t token_num_typos = query_tokens.back().num_typos; const bool token_prefix = (token_str.size() > token_candidates_vec.back().token.value.size()); @@ -1321,46 +1355,51 @@ void Index::search_all_candidates(const size_t num_search_fields, const size_t token_len = token_str.size() + 1; for(size_t i = 0; i < num_search_fields; i++) { - const std::string& field_name = the_fields[i].name; - const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0]; - const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0]; - - if (token_num_typos > field_num_typos) { - // since the token can come from any field, we still have to respect per-field num_typos - continue; - } - - if (token_prefix && !field_prefix) { - // even though this token is an outcome of prefix search, we can't use it for this field, since - // this field has prefix search disabled. - continue; - } - - art_tree* tree = search_index.at(field_name); - art_leaf* leaf = static_cast(art_search(tree, token_c_str, token_len)); - - if (!leaf) { - continue; - } - - bool found_atleast_one = posting_t::contains_atleast_one(leaf->values, &temp_ids[0], - temp_ids.size()); - if(!found_atleast_one) { + if(!is_valid_token_prefix(the_fields, i, token_c_str, token_len, num_typos, prefixes, + token_num_typos, token_prefix, search_index, prev_token_doc_ids)) { continue; } trimmed_candidates.insert(token_str); - //LOG(INFO) << "Pushing back found token_str: " << token_str; - /*LOG(INFO) << "max_candidates: " << max_candidates << ", trimmed_candidates.size(): " - << trimmed_candidates.size();*/ - + processed_field_ids.insert(i); if(trimmed_candidates.size() == max_candidates) { - goto outer_loop; + goto outer1; } } } - outer_loop: + outer1: + + for(auto& token_str: token_candidates_vec.back().candidates) { + //LOG(INFO) << "Prefix token: " << token_str; + const bool prefix_search = query_tokens.back().is_prefix_searched; + const uint32_t token_num_typos = query_tokens.back().num_typos; + const bool token_prefix = (token_str.size() > token_candidates_vec.back().token.value.size()); + auto token_c_str = (const unsigned char*) token_str.c_str(); + const size_t token_len = token_str.size() + 1; + + size_t top_fields_processed = 0; + for(auto field_id: top_prefix_field_ids) { + if(processed_field_ids.count(field_id) != 0) { + continue; + } + + if(!is_valid_token_prefix(the_fields, field_id, token_c_str, token_len, num_typos, prefixes, + token_num_typos, token_prefix, search_index, prev_token_doc_ids)) { + continue; + } + + top_fields_processed++; + trimmed_candidates.insert(token_str); + + if(top_fields_processed == 3) { + // limit to only 3 unprocessed top fields + goto outer2; + } + } + } + + outer2: if(trimmed_candidates.empty()) { return ; @@ -2987,7 +3026,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, // NOTE: `query_tokens` preserve original tokens, while `search_tokens` could be a result of dropped tokens // To prevent us from doing ART search repeatedly as we iterate through possible corrections - spp::sparse_hash_map> token_cost_cache; + spp::sparse_hash_map> token_cost_cache; std::vector> token_to_costs; @@ -3041,10 +3080,10 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const std::string& token = query_tokens[token_index].value; const std::string token_cost_hash = token + std::to_string(costs[token_index]); - std::vector leaves; + std::vector leaf_tokens; if(token_cost_cache.count(token_cost_hash) != 0) { - leaves = token_cost_cache[token_cost_hash]; + leaf_tokens = token_cost_cache[token_cost_hash]; } else { //auto begin = std::chrono::high_resolution_clock::now(); @@ -3071,41 +3110,38 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, continue; } - size_t max_words = 100000; + //LOG(INFO) << "Searching for field: " << the_field.name << ", found token:" << token; + + std::vector field_leaves; + int max_words = 100000; art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_words, token_order, prefix_search, - filter_ids, filter_ids_length, leaves, unique_tokens); + filter_ids, filter_ids_length, field_leaves, unique_tokens); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); LOG(INFO) << "Time taken for fuzzy search: " << timeMillis << "ms";*/ - if(leaves.empty()) { + if(field_leaves.empty()) { // look at the next field continue; } - token_cost_cache.emplace(token_cost_hash, leaves); - for(auto leaf: leaves) { + for(size_t i = 0; i < field_leaves.size(); i++) { + auto leaf = field_leaves[i]; std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); - unique_tokens.emplace(tok); + if(unique_tokens.count(tok) == 0) { + unique_tokens.emplace(tok); + leaf_tokens.push_back(tok); + } } + + token_cost_cache.emplace(token_cost_hash, leaf_tokens); } } - if(!leaves.empty()) { + if(!leaf_tokens.empty()) { //log_leaves(costs[token_index], token, leaves); - std::vector leaf_tokens; - std::unordered_set leaf_token_set; - - for(auto leaf: leaves) { - std::string ltok(reinterpret_cast(leaf->key), leaf->key_len - 1); - if(leaf_token_set.count(ltok) == 0) { - leaf_tokens.push_back(ltok); - leaf_token_set.insert(ltok); - } - } - token_candidates_vec.push_back(tok_candidates{query_tokens[token_index], costs[token_index], query_tokens[token_index].is_prefix_searched, leaf_tokens}); } else { @@ -3167,15 +3203,15 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, } } -void Index::find_across_fields(const std::vector& query_tokens, - const size_t num_query_tokens, +void Index::find_across_fields(const token_t& previous_token, const std::vector& num_typos, const std::vector& prefixes, const std::vector& the_fields, const size_t num_search_fields, const uint32_t* filter_ids, uint32_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - std::vector& id_buff) const { + std::vector& prev_token_doc_ids, + std::vector& top_prefix_field_ids) const { // one iterator for each token, each underlying iterator contains results of token across multiple fields std::vector token_its; @@ -3185,68 +3221,75 @@ void Index::find_across_fields(const std::vector& query_tokens, result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length); - // for each token, find the posting lists across all query_by fields - for(size_t ti = 0; ti < num_query_tokens; ti++) { - const bool prefix_search = query_tokens[ti].is_prefix_searched; - const uint32_t token_num_typos = query_tokens[ti].num_typos; - const bool token_prefix = query_tokens[ti].is_prefix_searched; + const bool prefix_search = previous_token.is_prefix_searched; + const uint32_t token_num_typos = previous_token.num_typos; + const bool token_prefix = previous_token.is_prefix_searched; - auto& token_str = query_tokens[ti].value; - auto token_c_str = (const unsigned char*) token_str.c_str(); - const size_t token_len = token_str.size() + 1; - std::vector its; + auto& token_str = previous_token.value; + auto token_c_str = (const unsigned char*) token_str.c_str(); + const size_t token_len = token_str.size() + 1; + std::vector its; - for(size_t i = 0; i < num_search_fields; i++) { - const std::string& field_name = the_fields[i].name; - const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0]; - const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0]; + std::vector> field_id_doc_counts; - if(token_num_typos > field_num_typos) { - // since the token can come from any field, we still have to respect per-field num_typos - continue; - } + for(size_t i = 0; i < num_search_fields; i++) { + const std::string& field_name = the_fields[i].name; + const uint32_t field_num_typos = (i < num_typos.size()) ? num_typos[i] : num_typos[0]; + const bool field_prefix = (i < prefixes.size()) ? prefixes[i] : prefixes[0]; - if(token_prefix && !field_prefix) { - // even though this token is an outcome of prefix search, we can't use it for this field, since - // this field has prefix search disabled. - continue; - } - - art_tree* tree = search_index.at(field_name); - art_leaf* leaf = static_cast(art_search(tree, token_c_str, token_len)); - - if(!leaf) { - continue; - } - - /*LOG(INFO) << "Token: " << token_str << ", field_name: " << field_name - << ", num_ids: " << posting_t::num_ids(leaf->values);*/ - - if(IS_COMPACT_POSTING(leaf->values)) { - auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values); - posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list(); - expanded_plists.push_back(full_posting_list); - its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied - } else { - posting_list_t* full_posting_list = (posting_list_t*)(leaf->values); - its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied - } - } - - if(its.empty()) { - // this token does not have any match across *any* field: probably a typo - LOG(INFO) << "No matching field found for token: " << token_str; + if(token_num_typos > field_num_typos) { + // since the token can come from any field, we still have to respect per-field num_typos continue; } - or_iterator_t token_fields(its); - token_its.push_back(std::move(token_fields)); + if(token_prefix && !field_prefix) { + // even though this token is an outcome of prefix search, we can't use it for this field, since + // this field has prefix search disabled. + continue; + } + + art_tree* tree = search_index.at(field_name); + art_leaf* leaf = static_cast(art_search(tree, token_c_str, token_len)); + + if(!leaf) { + continue; + } + + /*LOG(INFO) << "Token: " << token_str << ", field_name: " << field_name + << ", num_ids: " << posting_t::num_ids(leaf->values);*/ + + if(IS_COMPACT_POSTING(leaf->values)) { + auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values); + posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list(); + expanded_plists.push_back(full_posting_list); + its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied + } else { + posting_list_t* full_posting_list = (posting_list_t*)(leaf->values); + its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied + } + + field_id_doc_counts.emplace_back(i, posting_t::num_ids(leaf->values)); } + if(its.empty()) { + // this token does not have any match across *any* field: probably a typo + LOG(INFO) << "No matching field found for token: " << token_str; + return; + } + + std::sort(field_id_doc_counts.begin(), field_id_doc_counts.end(), [](const auto& p1, const auto& p2) { + return p1.second > p2.second; + }); + + for(auto& field_id_doc_count: field_id_doc_counts) { + top_prefix_field_ids.push_back(field_id_doc_count.first); + } + + or_iterator_t token_fields(its); + token_its.push_back(std::move(token_fields)); + or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector& its) { - // Convert [token -> fields] orientation to [field -> tokens] orientation - //LOG(INFO) << "seq_id: " << seq_id; - id_buff.push_back(seq_id); + prev_token_doc_ids.push_back(seq_id); }); for(posting_list_t* plist: expanded_plists) { diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index e233b632..4ed3caf6 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -114,6 +114,44 @@ TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnSingleField) { ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); } +TEST_F(CollectionSpecificMoreTest, PrefixExpansionOnMultiField) { + Collection *coll1; + std::vector fields = {field("location", field_types::STRING, false), + field("name", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector names = { + "John Stewart", "John Smith", "John Scott", "John Stone", "John Romero", "John Oliver", "John Adams" + }; + + std::vector locations = { + "Switzerland", "Seoul", "Sydney", "Surat", "Stockholm", "Salem", "Sevilla" + }; + + for(size_t i = 0; i < names.size(); i++) { + nlohmann::json doc; + doc["location"] = locations[i]; + doc["name"] = names[i]; + doc["points"] = i; + coll1->add(doc.dump()); + } + + auto results = coll1->search("john s", {"location", "name"}, "", {}, {}, {0}, 100, 1, MAX_SCORE, {true}).get(); + + // tokens are ordered by max_score, but prefix continuation on the same field should be prioritized + ASSERT_EQ(7, results["hits"].size()); + ASSERT_EQ("3", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("2", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][2]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][3]["document"]["id"].get()); + ASSERT_EQ("6", results["hits"][4]["document"]["id"].get()); +} + TEST_F(CollectionSpecificMoreTest, ArrayElementMatchShouldBeMoreImportantThanTotalMatch) { std::vector fields = {field("title", field_types::STRING, false), field("author", field_types::STRING, false),