Consider dropped tokens in text match calculation.

Kishore Nallan 2023-04-10 12:19:21 +05:30
parent 0a0a2ed272
commit 48119f76eb
6 changed files with 130 additions and 26 deletions
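In short: when tokens are dropped from the query to broaden recall, a document that still contains a dropped token should outrank one that does not. The change threads the dropped tokens through fuzzy_search_fields and search_across_fields and bumps the effective query length whenever a dropped token's posting list contains the candidate document. A minimal standalone sketch of that scoring idea (hypothetical helper and plain std::string tokens, not the actual Index code):

// Illustrative sketch of the scoring idea only (hypothetical helper, not the
// Index implementation): a dropped token still counts toward the effective
// query length when the candidate document contains it.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

size_t effective_query_len(const std::vector<std::string>& truncated_tokens,
                           const std::vector<std::string>& dropped_tokens,
                           const std::unordered_set<std::string>& doc_tokens) {
    size_t query_len = truncated_tokens.size();
    for (const auto& token : dropped_tokens) {
        if (doc_tokens.count(token) != 0) {
            query_len++;   // dropped token is still present in this document
        }
    }
    return query_len;
}

int main() {
    // Query "neutrogena ultra sheer serum" after one token has been dropped.
    std::vector<std::string> kept = {"neutrogena", "ultra", "sheer"};
    std::vector<std::string> dropped = {"serum"};

    std::unordered_set<std::string> face_serum = {"neutrogena", "ultra", "sheer", "face", "serum"};
    std::unordered_set<std::string> sunscreen  = {"neutrogena", "ultra", "sheer", "liquid", "sunscreen"};

    // 4 vs 3: the document that still contains "serum" should rank higher,
    // matching the new ConsiderDroppedTokensDuringTextMatchScoring test below.
    std::cout << effective_query_len(kept, dropped, face_serum) << "\n";  // 4
    std::cout << effective_query_len(kept, dropped, sunscreen) << "\n";   // 3
    return 0;
}

In the diff below, removed lines are marked with a leading "-" and added lines with "+"; unmarked lines are unchanged context.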

View File

@@ -418,6 +418,7 @@ private:
std::vector<tok_candidates>& token_candidates_vec,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
+ const std::vector<token_t>& dropped_tokens,
Topster* topster,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -835,8 +836,8 @@ public:
void fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const std::vector<token_t>& query_tokens,
+ const std::vector<token_t>& dropped_tokens,
const text_match_type_t match_type,
- const bool dropped_tokens,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
const uint32_t* filter_ids, size_t filter_ids_length,
@@ -884,6 +885,7 @@ public:
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
+ const std::vector<token_t>& dropped_tokens,
const size_t group_limit,
const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,

View File

@@ -1269,6 +1269,7 @@ void Index::search_all_candidates(const size_t num_search_fields,
std::vector<tok_candidates>& token_candidates_vec,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
+ const std::vector<token_t>& dropped_tokens,
Topster* topster,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -1328,7 +1329,7 @@ void Index::search_all_candidates(const size_t num_search_fields,
search_across_fields(query_suggestion, num_typos, prefixes, the_fields, num_search_fields, match_type,
sort_fields, topster,groups_processed,
- searched_queries, qtoken_set, group_limit, group_by_fields,
+ searched_queries, qtoken_set, dropped_tokens, group_limit, group_by_fields,
prioritize_exact_match, prioritize_token_position,
filter_ids, filter_ids_length, total_cost, syn_orig_num_tokens,
exclude_token_ids, exclude_token_ids_size,
@@ -2960,7 +2961,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
}
- fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, match_type, false, excluded_result_ids,
+ fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, {}, match_type, excluded_result_ids,
excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted,
sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
@@ -2997,7 +2998,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
space_resolved_queries[0][j].size(), 0);
}
- fuzzy_search_fields(the_fields, resolved_tokens, match_type, false, excluded_result_ids,
+ fuzzy_search_fields(the_fields, resolved_tokens, {}, match_type, excluded_result_ids,
excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted,
sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
@@ -3028,6 +3029,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
// When at least two tokens from the query are available we can drop one
std::vector<token_t> truncated_tokens;
+ std::vector<token_t> dropped_tokens;
if(orig_tokens.size() > 1 && num_tokens_dropped < 2*(orig_tokens.size()-1)) {
bool prefix_search = false;
@@ -3035,15 +3037,23 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
if (num_tokens_dropped < orig_tokens.size() - 1) {
// drop from right
size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1;
- for (size_t i = 0; i < truncated_len; i++) {
- truncated_tokens.emplace_back(orig_tokens[i]);
+ for (size_t i = 0; i < orig_tokens.size(); i++) {
+ if(i < truncated_len) {
+ truncated_tokens.emplace_back(orig_tokens[i]);
+ } else {
+ dropped_tokens.emplace_back(orig_tokens[i]);
+ }
}
} else {
// drop from left
prefix_search = true;
size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1;
- for(size_t i = start_index; i < orig_tokens.size(); i++) {
- truncated_tokens.emplace_back(orig_tokens[i]);
+ for(size_t i = 0; i < orig_tokens.size(); i++) {
+ if(i >= start_index) {
+ truncated_tokens.emplace_back(orig_tokens[i]);
+ } else {
+ dropped_tokens.emplace_back(orig_tokens[i]);
+ }
}
}
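To make the kept/dropped split above concrete, here is a minimal standalone sketch (assuming plain std::string tokens in place of token_t) that runs the same index arithmetic for a three-token query:

// Standalone illustration of the kept/dropped split in the hunk above.
// Hypothetical simplification: tokens are plain std::string instead of token_t.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> orig_tokens = {"red", "running", "shoes"};

    // The search loop tries right-side drops first, then left-side drops,
    // for at most 2 * (N - 1) passes.
    for (size_t num_tokens_dropped = 0;
         num_tokens_dropped < 2 * (orig_tokens.size() - 1); num_tokens_dropped++) {
        std::vector<std::string> truncated_tokens;   // tokens kept for this pass
        std::vector<std::string> dropped_tokens;     // tokens excluded, but still scored

        if (num_tokens_dropped < orig_tokens.size() - 1) {
            // drop from right: keep the first truncated_len tokens
            size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1;
            for (size_t i = 0; i < orig_tokens.size(); i++) {
                (i < truncated_len ? truncated_tokens : dropped_tokens).push_back(orig_tokens[i]);
            }
        } else {
            // drop from left: keep the suffix starting at start_index
            size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1;
            for (size_t i = 0; i < orig_tokens.size(); i++) {
                (i >= start_index ? truncated_tokens : dropped_tokens).push_back(orig_tokens[i]);
            }
        }

        std::cout << "pass " << num_tokens_dropped << " kept:";
        for (const auto& t : truncated_tokens) std::cout << " " << t;
        std::cout << " | dropped:";
        for (const auto& t : dropped_tokens) std::cout << " " << t;
        std::cout << "\n";
        // pass 0 keeps {red, running},   drops {shoes}
        // pass 1 keeps {red},            drops {running, shoes}
        // pass 2 keeps {running, shoes}, drops {red}
        // pass 3 keeps {shoes},          drops {red, running}
    }
    return 0;
}

Right-side drops are tried first; once they are exhausted the loop switches to dropping from the left with prefix_search enabled, as in the hunk above.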
@@ -3054,11 +3064,14 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
drop_token_prefixes.push_back(p && prefix_search);
}
- fuzzy_search_fields(the_fields, truncated_tokens, match_type, true, excluded_result_ids,
- excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted,
- sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
- all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
- prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold,
+ fuzzy_search_fields(the_fields, truncated_tokens, dropped_tokens, match_type,
+ excluded_result_ids, excluded_result_ids_size,
+ filter_result.docs, filter_result.count,
+ curated_ids_sorted, sort_fields_std, num_typos, searched_queries,
+ qtoken_set, topster, groups_processed,
+ all_result_ids, all_result_ids_len, group_limit, group_by_fields,
+ prioritize_exact_match, prioritize_token_position, query_hashes,
+ token_order, prefixes, typo_tokens_threshold,
exhaustive_search, max_candidates, min_len_1typo,
min_len_2typo, -1, sort_order, field_values, geopoint_indices);
@@ -3391,8 +3404,8 @@ void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>
void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const std::vector<token_t>& query_tokens,
+ const std::vector<token_t>& dropped_tokens,
const text_match_type_t match_type,
- const bool dropped_tokens,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
const uint32_t* filter_ids, size_t filter_ids_length,
@@ -3488,7 +3501,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
// we will first attempt to match the prefix with the most "popular" fields of the preceding token.
// Tokens matched from popular fields will also be searched across other query fields.
// Only when we find *no results* for such an expansion, we will attempt cross field matching.
- bool last_token = query_tokens.size() > 1 && !dropped_tokens &&
+ bool last_token = query_tokens.size() > 1 && dropped_tokens.empty() &&
(token_index == (query_tokens.size() - 1));
std::vector<size_t> query_field_ids(num_search_fields);
@@ -3646,7 +3659,8 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
std::vector<uint32_t> id_buff;
search_all_candidates(num_search_fields, match_type, the_fields, filter_ids, filter_ids_length,
exclude_token_ids, exclude_token_ids_size,
- sort_fields, token_candidates_vec, searched_queries, qtoken_set, topster,
+ sort_fields, token_candidates_vec, searched_queries, qtoken_set,
+ dropped_tokens, topster,
groups_processed, all_result_ids, all_result_ids_len,
typo_tokens_threshold, group_limit, group_by_fields, query_tokens,
num_typos, prefixes, prioritize_exact_match, prioritize_token_position,
@@ -3800,6 +3814,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
+ const std::vector<token_t>& dropped_tokens,
const size_t group_limit,
const std::vector<std::string>& group_by_fields,
const bool prioritize_exact_match,
@@ -3815,6 +3830,49 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
std::vector<art_leaf*> query_suggestion;
+ // one or_iterator for each token (across multiple fields)
+ std::vector<or_iterator_t> dropped_token_its;
+ // used to track plists that must be destructed once done
+ std::vector<posting_list_t*> expanded_dropped_plists;
+ for(auto& dropped_token: dropped_tokens) {
+ auto& token = dropped_token.value;
+ auto token_c_str = (const unsigned char*) token.c_str();
+ // convert token from each field into an or_iterator
+ std::vector<posting_list_t::iterator_t> its;
+ for(size_t i = 0; i < the_fields.size(); i++) {
+ const std::string& field_name = the_fields[i].name;
+ art_tree* tree = search_index.at(field_name);
+ art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token.size()+1));
+ if(!leaf) {
+ continue;
+ }
+ LOG(INFO) << "Token: " << token << ", field_name: " << field_name
+ << ", num_ids: " << posting_t::num_ids(leaf->values);
+ if(IS_COMPACT_POSTING(leaf->values)) {
+ auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values);
+ posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list();
+ expanded_dropped_plists.push_back(full_posting_list);
+ its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
+ } else {
+ posting_list_t* full_posting_list = (posting_list_t*)(leaf->values);
+ its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
+ }
+ }
+ or_iterator_t token_fields(its);
+ dropped_token_its.push_back(std::move(token_fields));
+ }
// one iterator for each token, each underlying iterator contains results of token across multiple fields
std::vector<or_iterator_t> token_its;
@@ -3957,6 +4015,14 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
best_field_match_score, scores, match_score_index);
size_t query_len = query_tokens.size();
+ // check if seq_id exists in any of the dropped_token iters and increment matching fields accordingly
+ for(auto& dropped_token_it: dropped_token_its) {
+ if(dropped_token_it.skip_to(seq_id) && dropped_token_it.id() == seq_id) {
+ query_len++;
+ }
+ }
if(syn_orig_num_tokens != -1) {
query_len = syn_orig_num_tokens;
}
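The check above appears to rely on skip_to advancing the iterator to the first id that is not less than seq_id, so a dropped token contributes to query_len only when the iterator lands exactly on seq_id. A simplified sketch of that membership test over a single sorted id vector (hypothetical simple_iterator_t; the real or_iterator_t merges per-field posting lists):

// Hypothetical sketch of the skip_to/id membership test used above, over a
// plain sorted id vector instead of a posting-list iterator.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct simple_iterator_t {
    const std::vector<uint32_t>& ids;   // sorted doc ids for one dropped token
    size_t pos = 0;

    // Advance to the first id >= target; returns false once the list is exhausted.
    bool skip_to(uint32_t target) {
        pos = std::lower_bound(ids.begin() + pos, ids.end(), target) - ids.begin();
        return pos < ids.size();
    }
    uint32_t id() const { return ids[pos]; }
};

int main() {
    std::vector<uint32_t> dropped_token_ids = {3, 9, 42, 57};
    simple_iterator_t it{dropped_token_ids};

    uint32_t seq_id = 42;
    // Mirrors: if(dropped_token_it.skip_to(seq_id) && dropped_token_it.id() == seq_id) query_len++;
    bool matched = it.skip_to(seq_id) && it.id() == seq_id;
    std::cout << std::boolalpha << matched << "\n";   // true
    return 0;
}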
@@ -4034,6 +4100,10 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
for(posting_list_t* plist: expanded_plists) {
delete plist;
}
+ for(posting_list_t* plist: expanded_dropped_plists) {
+ delete plist;
+ }
}
void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
@@ -4460,7 +4530,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
for (const auto& syn_tokens : q_pos_synonyms) {
query_hashes.clear();
- fuzzy_search_fields(the_fields, syn_tokens, match_type, false, exclude_token_ids,
+ fuzzy_search_fields(the_fields, syn_tokens, {}, match_type, exclude_token_ids,
exclude_token_ids_size, filter_ids, filter_ids_length, curated_ids_sorted,
sort_fields_std, {0}, searched_queries, qtoken_set, actual_topster, groups_processed,
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,

View File

@@ -331,8 +331,8 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {
ASSERT_EQ(2, results["grouped_hits"][2]["found"].get<int32_t>());
ASSERT_STREQ("country", results["grouped_hits"][2]["group_key"][0].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][2]["hits"].size());
ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
collectionManager.drop_collection("coll1");
}

View File

@@ -1813,11 +1813,10 @@ TEST_F(CollectionOverrideTest, DynamicFilteringMultiplePlaceholders) {
auto results = coll1->search("Nike Air Jordan light yellow shoes", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10, 1, FREQUENCY, {false}, 10).get();
- // not happy with this order (0,2,1 is better)
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());
// query with tokens at the start that precede the placeholders in the rule
results = coll1->search("New Nike Air Jordan yellow shoes", {"name", "category", "brand"}, "",
@@ -1997,9 +1996,9 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithNumericalFilter) {
ASSERT_EQ(4, results["hits"].size());
ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][3]["document"]["id"].get<std::string>());
ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][2]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][3]["document"]["id"].get<std::string>());
results = coll1->search("adidas", {"name", "category", "brand"}, "",
{}, sort_fields, {2, 2, 2}, 10, 1, FREQUENCY, {false}, 10).get();

View File

@@ -1594,6 +1594,39 @@ TEST_F(CollectionSpecificMoreTest, ValidateQueryById) {
ASSERT_EQ("Cannot use `id` as a query by field.", res_op.error());
}
+ TEST_F(CollectionSpecificMoreTest, ConsiderDroppedTokensDuringTextMatchScoring) {
+ nlohmann::json schema = R"({
+ "name": "coll1",
+ "fields": [
+ {"name": "name", "type": "string"},
+ {"name": "brand", "type": "string"}
+ ]
+ })"_json;
+ Collection *coll1 = collectionManager.create_collection(schema).get();
+ nlohmann::json doc;
+ doc["id"] = "0";
+ doc["brand"] = "Neutrogena";
+ doc["name"] = "Neutrogena Ultra Sheer Oil-Free Face Serum With Vitamin E + SPF 60";
+ ASSERT_TRUE(coll1->add(doc.dump()).ok());
+ doc["id"] = "1";
+ doc["brand"] = "Neutrogena";
+ doc["name"] = "Neutrogena Ultra Sheer Liquid Sunscreen SPF 70";
+ ASSERT_TRUE(coll1->add(doc.dump()).ok());
+ auto res = coll1->search("Neutrogena Ultra Sheer Moisturizing Face Serum", {"brand", "name"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5,
+ spp::sparse_hash_set<std::string>(),
+ spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
+ "<mark>", "</mark>", {3, 2}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
+ 4, {off}, 0, 0, 0, 2, false, "", true, 0, max_weight).get();
+ ASSERT_EQ(2, res["hits"].size());
+ ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+ ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
+ }
TEST_F(CollectionSpecificMoreTest, NonNestedFieldNameWithDot) {
nlohmann::json schema = R"({
"name": "coll1",

View File

@@ -3810,7 +3810,7 @@ TEST_F(CollectionTest, MultiFieldMatchRankingOnArray) {
}
auto results = coll1->search("golang vue",
{"strong_skills", "skills"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5).get();
{"strong_skills", "skills"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 1).get();
ASSERT_EQ(2, results["found"].get<size_t>());
ASSERT_EQ(2, results["hits"].size());