commit 3f9544535c
parent f92b8f59bb

    Ensure that synonyms are ranked equally.
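In short: synonym resolution is hoisted out of do_synonym_search() into Index::search(), which now computes syn_orig_num_tokens (the largest token count among the root query and its synonym expansions) once, and threads it together with the positional synonym tokens (q_pos_synonyms) through fuzzy_search_fields() and do_synonym_search(). score_results2() gains a num_query_tokens parameter and substitutes the synonym token count only when the whole query matched, so a synonym expansion scores the same as the equivalent root query instead of out-ranking it or being out-ranked.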
--- a/include/index.h
+++ b/include/index.h
@@ -706,6 +706,7 @@ public:
                         const size_t group_limit, const std::vector<std::string>& group_by_fields,
                         const bool prioritize_exact_match,
                         const bool single_exact_query_token,
+                        size_t num_query_tokens,
                         int syn_orig_num_tokens,
                         const std::vector<posting_list_t::iterator_t>& posting_lists) const;
 
@@ -892,7 +893,8 @@ public:
                           const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
                           size_t exclude_token_ids_size,
                           Topster* actual_topster,
-                          std::vector<query_tokens_t>& field_query_tokens,
+                          std::vector<std::vector<token_t>>& q_pos_synonyms,
+                          int syn_orig_num_tokens,
                           spp::sparse_hash_set<uint64_t>& groups_processed,
                           std::vector<std::vector<art_leaf*>>& searched_queries,
                           uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -901,8 +903,7 @@ public:
                           const int* sort_order,
                           std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
                           const std::vector<size_t>& geopoint_indices,
-                          tsl::htrie_map<char, token_leaf>& qtoken_set,
-                          std::vector<std::vector<token_t>>& all_queries) const;
+                          tsl::htrie_map<char, token_leaf>& qtoken_set) const;
 
     void do_phrase_search(const size_t num_search_fields, const std::vector<search_field_t>& search_fields,
                           std::vector<query_tokens_t>& field_query_tokens,
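The three header hunks above track those signature changes: score_results2() takes the new num_query_tokens argument, while do_synonym_search() drops field_query_tokens and all_queries in favor of the pre-resolved q_pos_synonyms and syn_orig_num_tokens passed in by Index::search().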
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -2261,12 +2261,42 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
     // FIXME: needed?
     std::set<uint64> query_hashes;
 
+    // resolve synonyms so that we can compute `syn_orig_num_tokens`
+    std::vector<std::vector<token_t>> all_queries = {field_query_tokens[0].q_include_tokens};
+    std::vector<std::vector<token_t>> q_pos_synonyms;
+    std::vector<std::string> q_include_tokens;
+    int syn_orig_num_tokens = -1;
+
+    for(size_t j = 0; j < field_query_tokens[0].q_include_tokens.size(); j++) {
+        q_include_tokens.push_back(field_query_tokens[0].q_include_tokens[j].value);
+    }
+    synonym_index->synonym_reduction(q_include_tokens, field_query_tokens[0].q_synonyms);
+
+    if(!field_query_tokens[0].q_synonyms.empty()) {
+        syn_orig_num_tokens = field_query_tokens[0].q_include_tokens.size();
+    }
+
+    for(const auto& q_syn_vec: field_query_tokens[0].q_synonyms) {
+        std::vector<token_t> q_pos_syn;
+        for(size_t j=0; j < q_syn_vec.size(); j++) {
+            bool is_prefix = (j == q_syn_vec.size()-1);
+            q_pos_syn.emplace_back(j, q_syn_vec[j], is_prefix, q_syn_vec[j].size(), 0);
+        }
+
+        q_pos_synonyms.push_back(q_pos_syn);
+        all_queries.push_back(q_pos_syn);
+
+        if(q_syn_vec.size() > syn_orig_num_tokens) {
+            syn_orig_num_tokens = q_syn_vec.size();
+        }
+    }
+
     fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, excluded_result_ids,
                         excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted,
                         sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
                         all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
                         query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search,
-                        max_candidates, min_len_1typo, min_len_2typo, -1, sort_order,
+                        max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order,
                         field_values, geopoint_indices);
 
     // try split/joining tokens if no results are found
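The block added above turns each resolved synonym into positional tokens, flagging only the final token for prefix search, and records the largest token count in syn_orig_num_tokens. A minimal, self-contained sketch of that shaping step, using a simplified stand-in for token_t (judging by the call site, the real constructor also takes the token length as root length and 0 as the typo cost):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for typesense's token_t, keeping only the
// fields this sketch needs.
struct SynToken {
    size_t position;
    std::string value;
    bool is_prefix;   // only the synonym's last token is prefix-searched
};

int main() {
    // one resolved synonym for the query "us"
    std::vector<std::string> q_syn_vec = {"united", "states"};

    std::vector<SynToken> q_pos_syn;
    for(size_t j = 0; j < q_syn_vec.size(); j++) {
        bool is_prefix = (j == q_syn_vec.size() - 1);
        q_pos_syn.push_back({j, q_syn_vec[j], is_prefix});
    }

    for(const auto& t : q_pos_syn) {
        std::cout << t.position << " " << t.value
                  << (t.is_prefix ? " (prefix)" : "") << "\n";
    }
    // prints: 0 united / 1 states (prefix)
}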
@@ -2274,12 +2304,12 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
         std::vector<std::vector<std::string>> space_resolved_queries;
 
         for(size_t i = 0; i < num_search_fields; i++) {
-            std::vector<std::string> q_include_tokens;
+            std::vector<std::string> orig_q_include_tokens;
             for(auto& q_include_token: field_query_tokens[i].q_include_tokens) {
-                q_include_tokens.push_back(q_include_token.value);
+                orig_q_include_tokens.push_back(q_include_token.value);
             }
 
-            resolve_space_as_typos(q_include_tokens, the_fields[i].name,space_resolved_queries);
+            resolve_space_as_typos(orig_q_include_tokens, the_fields[i].name,space_resolved_queries);
 
             if(!space_resolved_queries.empty()) {
                 break;
@@ -2302,28 +2332,26 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
                                 sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
                                 all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
                                 query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search,
-                                max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices);
+                                max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices);
         }
     }
 
     // do synonym based searches
-    std::vector<std::vector<token_t>> all_queries = {field_query_tokens[0].q_include_tokens};
     do_synonym_search(the_fields, filters, included_ids_map, sort_fields_std, curated_topster, token_order,
                       0, group_limit, group_by_fields, prioritize_exact_match, exhaustive_search, concurrency,
                       min_len_1typo, min_len_2typo, max_candidates, curated_ids, curated_ids_sorted,
-                      excluded_result_ids, excluded_result_ids_size, topster, field_query_tokens,
+                      excluded_result_ids, excluded_result_ids_size, topster, q_pos_synonyms, syn_orig_num_tokens,
                       groups_processed, searched_queries, all_result_ids, all_result_ids_len,
                       filter_ids, filter_ids_length, query_hashes,
                       sort_order, field_values, geopoint_indices,
-                      qtoken_set, all_queries);
+                      qtoken_set);
 
-    // gather up both original query and synonym queries and do drop tokens
 
     if(all_result_ids_len < drop_tokens_threshold) {
+        // gather up both original query and synonym queries and do drop tokens
         for(size_t qi = 0; qi < all_queries.size(); qi++) {
             auto& orig_tokens = all_queries[qi];
             size_t num_tokens_dropped = 0;
-            int syn_orig_num_token = (qi == 0) ? -1 : all_queries[0].size();
 
             while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
                 // When atleast two tokens from the query are available we can drop one
@@ -2360,7 +2388,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
                                 all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
                                 query_hashes, token_order, prefixes, typo_tokens_threshold,
                                 exhaustive_search, max_candidates, min_len_1typo,
-                                min_len_2typo, syn_orig_num_token, sort_order, field_values, geopoint_indices);
+                                min_len_2typo, -1, sort_order, field_values, geopoint_indices);
 
             } else {
                 break;
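Two knock-on effects are visible here: all_queries is now built once at the top of search() (it previously lived just above the do_synonym_search() call), and the drop-tokens pass now hands fuzzy_search_fields() a -1 instead of the old per-query syn_orig_num_token, so token-dropped variants are scored on the tokens they actually contain rather than inheriting a synonym's token count.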
@@ -2828,7 +2856,8 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
             score_results2(sort_fields, searched_queries.size(), field_is_array,
                            total_cost, field_match_score,
                            seq_id, sort_order, group_limit, group_by_fields,
-                           prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, token_postings);
+                           prioritize_exact_match, single_exact_query_token,
+                           query_tokens.size(), syn_orig_num_tokens, token_postings);
 
             if(field_match_score > max_field_match_score) {
                 max_field_match_score = field_match_score;
@@ -2861,7 +2890,10 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
                                   (int64_t(the_fields[max_field_match_index].weight) << 0);
 
+        /*LOG(INFO) << "seq_id: " << seq_id << ", query_tokens.size(): " << query_tokens.size()
+                  << ", syn_orig_num_tokens: " << syn_orig_num_tokens
+                  << ", max_field_match_score: " << max_field_match_score
+                  << ", max_field_match_index: " << max_field_match_index
+                  << ", field_weight: " << the_fields[max_field_match_index].weight
+                  << ", aggregated_score: " << aggregated_score;*/
 
         KV kv(0, searched_queries.size(), 0, seq_id, distinct_id, match_score_index, scores);
@@ -3121,7 +3153,8 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
                              const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
                              size_t exclude_token_ids_size,
                              Topster* actual_topster,
-                             std::vector<query_tokens_t>& field_query_tokens,
+                             std::vector<std::vector<token_t>>& q_pos_synonyms,
+                             int syn_orig_num_tokens,
                              spp::sparse_hash_set<uint64_t>& groups_processed,
                              std::vector<std::vector<art_leaf*>>& searched_queries,
                              uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -3130,28 +3163,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
                              const int* sort_order,
                              std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
                              const std::vector<size_t>& geopoint_indices,
-                             tsl::htrie_map<char, token_leaf>& qtoken_set,
-                             std::vector<std::vector<token_t>>& all_queries) const {
-
-    std::vector<std::string> q_include_tokens;
-    for(size_t j = 0; j < field_query_tokens[0].q_include_tokens.size(); j++) {
-        q_include_tokens.push_back(field_query_tokens[0].q_include_tokens[j].value);
-    }
-    synonym_index->synonym_reduction(q_include_tokens, field_query_tokens[0].q_synonyms);
-
-    std::vector<std::vector<token_t>> q_pos_synonyms;
-    for(const auto& q_syn_vec: field_query_tokens[0].q_synonyms) {
-        std::vector<token_t> q_pos_syn;
-        for(size_t j=0; j < q_syn_vec.size(); j++) {
-            bool is_prefix = (j == q_syn_vec.size()-1);
-            q_pos_syn.emplace_back(j, q_syn_vec[j], is_prefix, q_syn_vec[j].size(), 0);
-        }
-        q_pos_synonyms.push_back(q_pos_syn);
-        all_queries.push_back(q_pos_syn);
-    }
-
-    int syn_orig_num_tokens = field_query_tokens[0].q_include_tokens.size();
-    bool syn_wildcard_filter_init_done = false;
+                             tsl::htrie_map<char, token_leaf>& qtoken_set) const {
 
     for(const auto& syn_tokens: q_pos_synonyms) {
         query_hashes.clear();
@@ -3161,7 +3173,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
                             all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
                             query_hashes, token_order, {0}, typo_tokens_threshold,
                             exhaustive_search, max_candidates, min_len_1typo,
-                            min_len_2typo, q_include_tokens.size(), sort_order, field_values, geopoint_indices);
+                            min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices);
     }
 
     collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
@@ -3225,7 +3237,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
                 int64_t match_score = 0;
                 score_results2(sort_fields, searched_queries.size(), field_is_array,
                                0, match_score, seq_id, sort_order, group_limit, group_by_fields,
-                               false, false, -1, {});
+                               false, false, 1, -1, {});
 
                 int64_t scores[3] = {0};
                 int64_t match_score_index = 0;
@@ -3548,7 +3560,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
 
             score_results2(sort_fields, (uint16_t) searched_queries.size(), false, 0,
                            match_score, seq_id, sort_order, group_limit, group_by_fields, false,
-                           false, -1, plists);
+                           false, 1, -1, plists);
 
             int64_t scores[3] = {0};
             int64_t match_score_index = 0;
@@ -3850,6 +3862,7 @@ int64_t Index::score_results2(const std::vector<sort_by> & sort_fields, const ui
                               const size_t group_limit, const std::vector<std::string>& group_by_fields,
                               const bool prioritize_exact_match,
                               const bool single_exact_query_token,
+                              size_t num_query_tokens,
                               int syn_orig_num_tokens,
                               const std::vector<posting_list_t::iterator_t>& posting_lists) const {
 
@@ -3861,8 +3874,8 @@ int64_t Index::score_results2(const std::vector<sort_by> & sort_fields, const ui
             prioritize_exact_match && single_exact_query_token &&
             posting_list_t::is_single_token_verbatim_match(posting_lists[0], field_is_array)
         );
-        size_t words_present = (syn_orig_num_tokens == -1) ? 1 : syn_orig_num_tokens;
-        size_t distance = (syn_orig_num_tokens == -1) ? 0 : syn_orig_num_tokens-1;
+        size_t words_present = (num_query_tokens == 1 && syn_orig_num_tokens != -1) ? syn_orig_num_tokens : 1;
+        size_t distance = (num_query_tokens == 1 && syn_orig_num_tokens != -1) ? syn_orig_num_tokens-1 : 0;
         Match single_token_match = Match(words_present, distance, is_verbatim_match);
         match_score = single_token_match.get_match_score(total_cost, words_present);
     } else {
@@ -3884,7 +3897,7 @@ int64_t Index::score_results2(const std::vector<sort_by> & sort_fields, const ui
             auto proximity = ((this_match_score >> 8) & 0xFF);
             auto verbatim = (this_match_score & 0xFF);
 
-            if(syn_orig_num_tokens != -1) {
+            if(syn_orig_num_tokens != -1 && num_query_tokens == posting_lists.size()) {
                 unique_words = syn_orig_num_tokens;
                 this_words_present = syn_orig_num_tokens;
                 proximity = 100 - (syn_orig_num_tokens - 1);
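These two hunks are the heart of the ranking fix. In the single-posting-list branch, a match is credited with syn_orig_num_tokens words only when the query itself is a single token with synonym context; in the multi-token branch, the substitution now also requires that every query token produced a posting list. A standalone sketch of the new single-token selection, with values matching the tests below (the helper name and the standalone framing are mine, not typesense's):

#include <cstddef>
#include <iostream>

// Mirrors the words_present/distance selection added to score_results2():
// a single-token query coming from a synonym group is credited with the
// group's original token count, so "us" scores like "united states".
void single_token_score(size_t num_query_tokens, int syn_orig_num_tokens) {
    size_t words_present = (num_query_tokens == 1 && syn_orig_num_tokens != -1)
                           ? syn_orig_num_tokens : 1;
    size_t distance = (num_query_tokens == 1 && syn_orig_num_tokens != -1)
                      ? syn_orig_num_tokens - 1 : 0;
    std::cout << "words_present=" << words_present
              << " distance=" << distance << "\n";
}

int main() {
    single_token_score(1, 2);   // "us" with synonym "united states" -> 2, 1
    single_token_score(1, -1);  // plain single-token query           -> 1, 0
    single_token_score(2, 2);   // multi-token query: no substitution -> 1, 0
}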
@@ -4179,8 +4192,6 @@ inline uint32_t Index::next_suggestion2(const std::vector<tok_candidates>& token
         q = ldiv(q.quot, token_candidates_vec[i].candidates.size());
         const auto& candidate = token_candidates_vec[i].candidates[q.rem];
 
-        bool exact_match = token_candidates_vec[i].cost == 0 && token_size == candidate.size();
-
         // we assume that toke was found via prefix search if candidate is longer than token's typo tolerance
         bool is_prefix_searched = token_candidates_vec[i].prefix_search &&
                                   (candidate.size() > (token_size + token_candidates_vec[i].cost));
--- a/test/collection_synonyms_test.cpp
+++ b/test/collection_synonyms_test.cpp
@@ -375,11 +375,8 @@ TEST_F(CollectionSynonymsTest, SynonymQueryVariantWithDropTokens) {
     auto res = coll1->search("us sneakers", {"category", "location"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get();
     ASSERT_EQ(3, res["hits"].size());
 
-    // NOTE: "1" is ranked above "0" because synonym matches uses the root query's number of tokens for counting
-    // This means that "united states" == "us" so both records have 2 tokens matched, so tie breaking happens on points.
-
-    ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
-    ASSERT_EQ("0", res["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
     ASSERT_EQ("2", res["hits"][2]["document"]["id"].get<std::string>());
 
     collectionManager.drop_collection("coll1");
@@ -418,7 +415,7 @@ TEST_F(CollectionSynonymsTest, SynonymsTextMatchSameAsRootQuery) {
     ASSERT_TRUE(coll1->add(doc1.dump()).ok());
     ASSERT_TRUE(coll1->add(doc2.dump()).ok());
 
-    auto res = coll1->search("ceo", {"name", "title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get();
+    auto res = coll1->search("ceo", {"name", "title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 0).get();
     ASSERT_EQ(2, res["hits"].size());
 
     ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
@@ -535,6 +532,59 @@ TEST_F(CollectionSynonymsTest, ExactMatchRankedSameAsSynonymMatch) {
     collectionManager.drop_collection("coll1");
 }
 
+TEST_F(CollectionSynonymsTest, ExactMatchVsSynonymMatchCrossFields) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("description", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    coll1 = collectionManager.get_collection("coll1").get();
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    std::vector<std::vector<std::string>> records = {
+        {"Head of Marketing", "The Chief Marketing Officer", "100"},
+        {"VP of Sales", "Preparing marketing and sales materials.", "120"},
+    };
+
+    for(size_t i=0; i<records.size(); i++) {
+        nlohmann::json doc;
+
+        doc["id"] = std::to_string(i);
+        doc["title"] = records[i][0];
+        doc["description"] = records[i][1];
+        doc["points"] = std::stoi(records[i][2]);
+
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    nlohmann::json syn_json = {
+        {"id", "syn-1"},
+        {"synonyms", {"cmo", "Chief Marketing Officer", "VP of Marketing"}}
+    };
+
+    synonym_t synonym;
+    auto syn_op = synonym_t::parse(syn_json, synonym);
+    ASSERT_TRUE(syn_op.ok());
+
+    coll1->add_synonym(synonym);
+
+    auto res = coll1->search("cmo", {"title", "description"}, "", {}, {},
+                             {0}, 10, 1, FREQUENCY, {false}, 0).get();
+
+    LOG(INFO) << res;
+
+    ASSERT_EQ(2, res["hits"].size());
+    ASSERT_EQ(2, res["found"].get<uint32_t>());
+
+    ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
+
+    collectionManager.drop_collection("coll1");
+}
+
 TEST_F(CollectionSynonymsTest, SynonymFieldOrdering) {
     // Synonym match on a field earlier in the fields list should rank above exact match of another field
     Collection *coll1;
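The new ExactMatchVsSynonymMatchCrossFields test pins the intent down: for the query "cmo", document 0, whose description contains the full expansion "Chief Marketing Officer", must outrank document 1, which matches the expansions only partially through loose tokens like "marketing", even though document 1 has more points (120 vs 100).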