Mirror of https://github.com/typesense/typesense.git
Consider dropped token in text match calculation.
This commit is contained in:
parent 0a0a2ed272
commit 48119f76eb
@@ -418,6 +418,7 @@ private:
                                std::vector<tok_candidates>& token_candidates_vec,
                                std::vector<std::vector<art_leaf*>>& searched_queries,
                                tsl::htrie_map<char, token_leaf>& qtoken_set,
+                               const std::vector<token_t>& dropped_tokens,
                                Topster* topster,
                                spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
                                uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -835,8 +836,8 @@ public:

     void fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
                              const std::vector<token_t>& query_tokens,
+                             const std::vector<token_t>& dropped_tokens,
                              const text_match_type_t match_type,
-                             const bool dropped_tokens,
                              const uint32_t* exclude_token_ids,
                              size_t exclude_token_ids_size,
                              const uint32_t* filter_ids, size_t filter_ids_length,
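In the declaration above, the old boolean flag is replaced by the list of tokens that were actually removed from the query; the call-site hunks further down pass either `dropped_tokens` or `{}`. A minimal sketch of the shape of that change, using simplified stand-in types and keeping only the two parameters that matter here (names and the demo query are illustrative, not the real signature):

```cpp
#include <string>
#include <vector>

struct token_t { std::string value; };  // simplified stand-in

// Before (assumed shape): a bare flag only said "this is a drop-tokens pass".
// void fuzzy_search_fields(const std::vector<token_t>& query_tokens, bool dropped_tokens);

// After: the removed tokens themselves travel with the call, so they can still
// be matched against documents further down the pipeline.
void fuzzy_search_fields(const std::vector<token_t>& query_tokens,
                         const std::vector<token_t>& dropped_tokens) {
    // ... search on query_tokens; hand dropped_tokens onwards ...
    (void) query_tokens;
    (void) dropped_tokens;
}

int main() {
    std::vector<token_t> kept = {{"ultra"}, {"sheer"}};
    std::vector<token_t> dropped = {{"serum"}};
    fuzzy_search_fields(kept, dropped);  // drop-tokens pass
    fuzzy_search_fields(kept, {});       // nothing was dropped
}
```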
@@ -884,6 +885,7 @@ public:
                              spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
                              std::vector<std::vector<art_leaf*>>& searched_queries,
                              tsl::htrie_map<char, token_leaf>& qtoken_set,
+                             const std::vector<token_t>& dropped_tokens,
                              const size_t group_limit,
                              const std::vector<std::string>& group_by_fields,
                              bool prioritize_exact_match,
src/index.cpp (102 changed lines)
@@ -1269,6 +1269,7 @@ void Index::search_all_candidates(const size_t num_search_fields,
                                   std::vector<tok_candidates>& token_candidates_vec,
                                   std::vector<std::vector<art_leaf*>>& searched_queries,
                                   tsl::htrie_map<char, token_leaf>& qtoken_set,
+                                  const std::vector<token_t>& dropped_tokens,
                                   Topster* topster,
                                   spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
                                   uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -1328,7 +1329,7 @@ void Index::search_all_candidates(const size_t num_search_fields,

         search_across_fields(query_suggestion, num_typos, prefixes, the_fields, num_search_fields, match_type,
                              sort_fields, topster,groups_processed,
-                             searched_queries, qtoken_set, group_limit, group_by_fields,
+                             searched_queries, qtoken_set, dropped_tokens, group_limit, group_by_fields,
                              prioritize_exact_match, prioritize_token_position,
                              filter_ids, filter_ids_length, total_cost, syn_orig_num_tokens,
                              exclude_token_ids, exclude_token_ids_size,
@@ -2960,7 +2961,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
             }
         }

-        fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, match_type, false, excluded_result_ids,
+        fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, {}, match_type, excluded_result_ids,
                             excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted,
                             sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
                             all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
@@ -2997,7 +2998,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 space_resolved_queries[0][j].size(), 0);
             }

-            fuzzy_search_fields(the_fields, resolved_tokens, match_type, false, excluded_result_ids,
+            fuzzy_search_fields(the_fields, resolved_tokens, {}, match_type, excluded_result_ids,
                                 excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted,
                                 sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
                                 all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
@@ -3028,6 +3029,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
         while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
             // When atleast two tokens from the query are available we can drop one
             std::vector<token_t> truncated_tokens;
+            std::vector<token_t> dropped_tokens;

             if(orig_tokens.size() > 1 && num_tokens_dropped < 2*(orig_tokens.size()-1)) {
                 bool prefix_search = false;
@@ -3035,15 +3037,23 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 if (num_tokens_dropped < orig_tokens.size() - 1) {
                     // drop from right
                     size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1;
-                    for (size_t i = 0; i < truncated_len; i++) {
-                        truncated_tokens.emplace_back(orig_tokens[i]);
+                    for (size_t i = 0; i < orig_tokens.size(); i++) {
+                        if(i < truncated_len) {
+                            truncated_tokens.emplace_back(orig_tokens[i]);
+                        } else {
+                            dropped_tokens.emplace_back(orig_tokens[i]);
+                        }
                     }
                 } else {
                     // drop from left
                     prefix_search = true;
                     size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1;
-                    for(size_t i = start_index; i < orig_tokens.size(); i++) {
-                        truncated_tokens.emplace_back(orig_tokens[i]);
+                    for(size_t i = 0; i < orig_tokens.size(); i++) {
+                        if(i >= start_index) {
+                            truncated_tokens.emplace_back(orig_tokens[i]);
+                        } else {
+                            dropped_tokens.emplace_back(orig_tokens[i]);
+                        }
                     }
                 }

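Previously these loops collected only the kept tokens; the rewritten loops above also collect the tokens being dropped so they can be handed to `fuzzy_search_fields`. A self-contained sketch of the same split, with plain strings in place of `token_t` (the helper name and the demo query are made up):

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Splits the original query into the tokens kept for this drop iteration and
// the tokens removed from it, dropping from the right first, then from the left.
void split_tokens(const std::vector<std::string>& orig_tokens, size_t num_tokens_dropped,
                  std::vector<std::string>& truncated, std::vector<std::string>& dropped) {
    if (num_tokens_dropped < orig_tokens.size() - 1) {
        // drop from right: keep the first `truncated_len` tokens
        size_t truncated_len = orig_tokens.size() - num_tokens_dropped - 1;
        for (size_t i = 0; i < orig_tokens.size(); i++) {
            (i < truncated_len ? truncated : dropped).push_back(orig_tokens[i]);
        }
    } else {
        // drop from left: keep the tokens from `start_index` onwards
        size_t start_index = (num_tokens_dropped + 1) - orig_tokens.size() + 1;
        for (size_t i = 0; i < orig_tokens.size(); i++) {
            (i >= start_index ? truncated : dropped).push_back(orig_tokens[i]);
        }
    }
}

int main() {
    std::vector<std::string> query = {"ultra", "sheer", "face", "serum"};
    for (size_t drops = 0; drops < 2 * (query.size() - 1); drops++) {
        std::vector<std::string> truncated, dropped;
        split_tokens(query, drops, truncated, dropped);
        std::cout << "drop #" << drops << ": kept=" << truncated.size()
                  << " dropped=" << dropped.size() << "\n";
    }
}
```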
@@ -3054,11 +3064,14 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     drop_token_prefixes.push_back(p && prefix_search);
                 }

-                fuzzy_search_fields(the_fields, truncated_tokens, match_type, true, excluded_result_ids,
-                                    excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted,
-                                    sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
-                                    all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
-                                    prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold,
+                fuzzy_search_fields(the_fields, truncated_tokens, dropped_tokens, match_type,
+                                    excluded_result_ids, excluded_result_ids_size,
+                                    filter_result.docs, filter_result.count,
+                                    curated_ids_sorted, sort_fields_std, num_typos, searched_queries,
+                                    qtoken_set, topster, groups_processed,
+                                    all_result_ids, all_result_ids_len, group_limit, group_by_fields,
+                                    prioritize_exact_match, prioritize_token_position, query_hashes,
+                                    token_order, prefixes, typo_tokens_threshold,
                                     exhaustive_search, max_candidates, min_len_1typo,
                                     min_len_2typo, -1, sort_order, field_values, geopoint_indices);

@ -3391,8 +3404,8 @@ void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>
|
||||
|
||||
void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
|
||||
const std::vector<token_t>& query_tokens,
|
||||
const std::vector<token_t>& dropped_tokens,
|
||||
const text_match_type_t match_type,
|
||||
const bool dropped_tokens,
|
||||
const uint32_t* exclude_token_ids,
|
||||
size_t exclude_token_ids_size,
|
||||
const uint32_t* filter_ids, size_t filter_ids_length,
|
||||
@@ -3488,7 +3501,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
         // we will first attempt to match the prefix with the most "popular" fields of the preceding token.
         // Tokens matched from popular fields will also be searched across other query fields.
         // Only when we find *no results* for such an expansion, we will attempt cross field matching.
-        bool last_token = query_tokens.size() > 1 && !dropped_tokens &&
+        bool last_token = query_tokens.size() > 1 && dropped_tokens.empty() &&
                           (token_index == (query_tokens.size() - 1));

         std::vector<size_t> query_field_ids(num_search_fields);
@@ -3646,7 +3659,8 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
             std::vector<uint32_t> id_buff;
             search_all_candidates(num_search_fields, match_type, the_fields, filter_ids, filter_ids_length,
                                   exclude_token_ids, exclude_token_ids_size,
-                                  sort_fields, token_candidates_vec, searched_queries, qtoken_set, topster,
+                                  sort_fields, token_candidates_vec, searched_queries, qtoken_set,
+                                  dropped_tokens, topster,
                                   groups_processed, all_result_ids, all_result_ids_len,
                                   typo_tokens_threshold, group_limit, group_by_fields, query_tokens,
                                   num_typos, prefixes, prioritize_exact_match, prioritize_token_position,
@@ -3800,6 +3814,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
                                  spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
                                  std::vector<std::vector<art_leaf*>>& searched_queries,
                                  tsl::htrie_map<char, token_leaf>& qtoken_set,
+                                 const std::vector<token_t>& dropped_tokens,
                                  const size_t group_limit,
                                  const std::vector<std::string>& group_by_fields,
                                  const bool prioritize_exact_match,
@@ -3815,6 +3830,49 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,

     std::vector<art_leaf*> query_suggestion;

+    // one or_iterator for each token (across multiple fields)
+    std::vector<or_iterator_t> dropped_token_its;
+
+    // used to track plists that must be destructed once done
+    std::vector<posting_list_t*> expanded_dropped_plists;
+
+    for(auto& dropped_token: dropped_tokens) {
+        auto& token = dropped_token.value;
+        auto token_c_str = (const unsigned char*) token.c_str();
+
+        // convert token from each field into an or_iterator
+        std::vector<posting_list_t::iterator_t> its;
+
+        for(size_t i = 0; i < the_fields.size(); i++) {
+            const std::string& field_name = the_fields[i].name;
+
+            art_tree* tree = search_index.at(field_name);
+            art_leaf* leaf = static_cast<art_leaf*>(art_search(tree, token_c_str, token.size()+1));
+
+            if(!leaf) {
+                continue;
+            }
+
+            LOG(INFO) << "Token: " << token << ", field_name: " << field_name
+                      << ", num_ids: " << posting_t::num_ids(leaf->values);
+
+            if(IS_COMPACT_POSTING(leaf->values)) {
+                auto compact_posting_list = COMPACT_POSTING_PTR(leaf->values);
+                posting_list_t* full_posting_list = compact_posting_list->to_full_posting_list();
+                expanded_dropped_plists.push_back(full_posting_list);
+                its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
+            } else {
+                posting_list_t* full_posting_list = (posting_list_t*)(leaf->values);
+                its.push_back(full_posting_list->new_iterator(nullptr, nullptr, i)); // moved, not copied
+            }
+        }
+
+        or_iterator_t token_fields(its);
+        dropped_token_its.push_back(std::move(token_fields));
+    }
+
+
+
     // one iterator for each token, each underlying iterator contains results of token across multiple fields
     std::vector<or_iterator_t> token_its;

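For each dropped token, the block above looks the token up in every searched field's ART tree and wraps the resulting posting-list iterators in one `or_iterator_t`, so a single check later can tell whether a document contains that token in any field. A much-simplified illustration of that idea, with sorted id vectors standing in for posting lists; it mimics only the membership check, not the real `or_iterator_t`/`posting_list_t` API:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for a per-token, cross-field iterator: one sorted doc-id list per field.
class simple_or_iterator {
public:
    explicit simple_or_iterator(std::vector<std::vector<uint32_t>> field_postings)
        : postings_(std::move(field_postings)) {}

    // Reports whether any field's posting list contains `id`
    // (roughly what the skip_to()/id() pair is used for later in this diff).
    bool contains(uint32_t id) const {
        for (const auto& plist : postings_) {
            if (std::binary_search(plist.begin(), plist.end(), id)) {
                return true;
            }
        }
        return false;
    }

private:
    std::vector<std::vector<uint32_t>> postings_;  // one sorted list per field
};

int main() {
    // dropped token "serum": no hits in the "brand" field, doc 0 in the "name" field
    simple_or_iterator serum_it({{}, {0}});
    std::cout << serum_it.contains(0) << "\n";  // 1: doc 0 still carries the dropped token
    std::cout << serum_it.contains(1) << "\n";  // 0: doc 1 does not
}
```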
@@ -3957,6 +4015,14 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
                                               best_field_match_score, scores, match_score_index);

         size_t query_len = query_tokens.size();
+
+        // check if seq_id exists in any of the dropped_token iters and increment matching fields accordingly
+        for(auto& dropped_token_it: dropped_token_its) {
+            if(dropped_token_it.skip_to(seq_id) && dropped_token_it.id() == seq_id) {
+                query_len++;
+            }
+        }
+
         if(syn_orig_num_tokens != -1) {
             query_len = syn_orig_num_tokens;
         }
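The loop above is where dropped tokens finally influence ranking: for every dropped token whose iterator still lands on `seq_id`, `query_len` is bumped, so documents that happen to contain the dropped words get credit for them. A hypothetical toy calculation of the effect; the real text-match score is packed from several signals elsewhere in `Index`, and the counts below merely mirror the two documents in the new test:

```cpp
#include <cstddef>
#include <iostream>

int main() {
    // Query "neutrogena ultra sheer moisturizing face serum", searched after
    // dropping down to "neutrogena ultra sheer" (3 kept tokens).
    size_t kept_tokens = 3;

    // Doc 0 ("... Oil-Free Face Serum ...") still contains the dropped tokens
    // "face" and "serum"; doc 1 ("... Liquid Sunscreen ...") contains none of them.
    size_t doc0_query_len = kept_tokens + 2;  // query_len++ fires twice
    size_t doc1_query_len = kept_tokens;      // no dropped-token hits

    std::cout << "doc0 covers " << doc0_query_len << " query tokens\n";  // 5
    std::cout << "doc1 covers " << doc1_query_len << " query tokens\n";  // 3
    // With dropped tokens considered, doc 0 outranks doc 1, as the new test asserts.
}
```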
@@ -4034,6 +4100,10 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
     for(posting_list_t* plist: expanded_plists) {
         delete plist;
     }
+
+    for(posting_list_t* plist: expanded_dropped_plists) {
+        delete plist;
+    }
 }

 void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
@@ -4460,7 +4530,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,

     for (const auto& syn_tokens : q_pos_synonyms) {
         query_hashes.clear();
-        fuzzy_search_fields(the_fields, syn_tokens, match_type, false, exclude_token_ids,
+        fuzzy_search_fields(the_fields, syn_tokens, {}, match_type, exclude_token_ids,
                             exclude_token_ids_size, filter_ids, filter_ids_length, curated_ids_sorted,
                             sort_fields_std, {0}, searched_queries, qtoken_set, actual_topster, groups_processed,
                             all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
@@ -331,8 +331,8 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {
     ASSERT_EQ(2, results["grouped_hits"][2]["found"].get<int32_t>());
     ASSERT_STREQ("country", results["grouped_hits"][2]["group_key"][0].get<std::string>().c_str());
     ASSERT_EQ(2, results["grouped_hits"][2]["hits"].size());
-    ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
-    ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
+    ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
+    ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());

     collectionManager.drop_collection("coll1");
 }
@@ -1813,11 +1813,10 @@ TEST_F(CollectionOverrideTest, DynamicFilteringMultiplePlaceholders) {
     auto results = coll1->search("Nike Air Jordan light yellow shoes", {"name", "category", "brand"}, "",
                                  {}, sort_fields, {2, 2, 2}, 10, 1, FREQUENCY, {false}, 10).get();

-    // not happy with this order (0,2,1 is better)
     ASSERT_EQ(3, results["hits"].size());
     ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
-    ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
-    ASSERT_EQ("2", results["hits"][2]["document"]["id"].get<std::string>());
+    ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());

     // query with tokens at the start that preceding the placeholders in the rule
     results = coll1->search("New Nike Air Jordan yellow shoes", {"name", "category", "brand"}, "",
@@ -1997,9 +1996,9 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithNumericalFilter) {

     ASSERT_EQ(4, results["hits"].size());
     ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
-    ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
-    ASSERT_EQ("1", results["hits"][2]["document"]["id"].get<std::string>());
-    ASSERT_EQ("2", results["hits"][3]["document"]["id"].get<std::string>());
+    ASSERT_EQ("2", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("0", results["hits"][2]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][3]["document"]["id"].get<std::string>());

     results = coll1->search("adidas", {"name", "category", "brand"}, "",
                             {}, sort_fields, {2, 2, 2}, 10, 1, FREQUENCY, {false}, 10).get();
@@ -1594,6 +1594,39 @@ TEST_F(CollectionSpecificMoreTest, ValidateQueryById) {
     ASSERT_EQ("Cannot use `id` as a query by field.", res_op.error());
 }

+TEST_F(CollectionSpecificMoreTest, ConsiderDroppedTokensDuringTextMatchScoring) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "fields": [
+            {"name": "name", "type": "string"},
+            {"name": "brand", "type": "string"}
+        ]
+    })"_json;
+
+    Collection *coll1 = collectionManager.create_collection(schema).get();
+
+    nlohmann::json doc;
+    doc["id"] = "0";
+    doc["brand"] = "Neutrogena";
+    doc["name"] = "Neutrogena Ultra Sheer Oil-Free Face Serum With Vitamin E + SPF 60";
+    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+
+    doc["id"] = "1";
+    doc["brand"] = "Neutrogena";
+    doc["name"] = "Neutrogena Ultra Sheer Liquid Sunscreen SPF 70";
+    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+
+    auto res = coll1->search("Neutrogena Ultra Sheer Moisturizing Face Serum", {"brand", "name"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 5,
+                             spp::sparse_hash_set<std::string>(),
+                             spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
+                             "<mark>", "</mark>", {3, 2}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
+                             4, {off}, 0, 0, 0, 2, false, "", true, 0, max_weight).get();
+
+    ASSERT_EQ(2, res["hits"].size());
+    ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
+}
+
 TEST_F(CollectionSpecificMoreTest, NonNestedFieldNameWithDot) {
     nlohmann::json schema = R"({
         "name": "coll1",
@@ -3810,7 +3810,7 @@ TEST_F(CollectionTest, MultiFieldMatchRankingOnArray) {
     }

     auto results = coll1->search("golang vue",
-                                 {"strong_skills", "skills"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 5).get();
+                                 {"strong_skills", "skills"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {true}, 1).get();

     ASSERT_EQ(2, results["found"].get<size_t>());
     ASSERT_EQ(2, results["hits"].size());