Ensure that synonyms are ranked equally.

Kishore Nallan 2022-04-12 08:04:43 +05:30
parent f92b8f59bb
commit 3f9544535c
3 changed files with 115 additions and 53 deletions
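
The heart of the change is how a synonym-expanded query is scored relative to the root query. Two counts are now threaded through scoring: num_query_tokens, the size of the query variant actually being scored, and syn_orig_num_tokens, the token count of the longest variant once synonyms are resolved (-1 when no synonym applies). With both available, score_results2 can credit a synonym match with as many words as the root query, so synonym and exact matches tie on text score and fall through to tie-breakers such as points. A minimal standalone sketch of the single-token branch, using simplified stand-in types rather than the actual Typesense interfaces:

#include <cstddef>
#include <iostream>

// Simplified sketch of the new single-token scoring rule in score_results2.
// syn_orig_num_tokens is -1 unless the query went through synonym expansion.
struct SingleTokenScore {
    size_t words_present;
    size_t distance;
};

SingleTokenScore score_single_token(size_t num_query_tokens, int syn_orig_num_tokens) {
    // Credit a single-token synonym query with the root query's full token
    // count, so e.g. "us" standing in for "united states" scores the same
    // number of words as a literal "united states" match.
    bool synonym_variant = (num_query_tokens == 1 && syn_orig_num_tokens != -1);
    size_t words = synonym_variant ? (size_t) syn_orig_num_tokens : 1;
    size_t dist  = synonym_variant ? (size_t) syn_orig_num_tokens - 1 : 0;
    return {words, dist};
}

int main() {
    // "us" expanded from the two-token synonym "united states"
    SingleTokenScore s = score_single_token(1, 2);
    std::cout << s.words_present << " word(s), distance " << s.distance << "\n";  // 2 word(s), distance 1
}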

View File

@@ -706,6 +706,7 @@ public:
const size_t group_limit, const std::vector<std::string>& group_by_fields,
const bool prioritize_exact_match,
const bool single_exact_query_token,
+ size_t num_query_tokens,
int syn_orig_num_tokens,
const std::vector<posting_list_t::iterator_t>& posting_lists) const;
@@ -892,7 +893,8 @@ public:
const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
Topster* actual_topster,
- std::vector<query_tokens_t>& field_query_tokens,
+ std::vector<std::vector<token_t>>& q_pos_synonyms,
+ int syn_orig_num_tokens,
spp::sparse_hash_set<uint64_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -901,8 +903,7 @@ public:
const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
const std::vector<size_t>& geopoint_indices,
- tsl::htrie_map<char, token_leaf>& qtoken_set,
- std::vector<std::vector<token_t>>& all_queries) const;
+ tsl::htrie_map<char, token_leaf>& qtoken_set) const;
void do_phrase_search(const size_t num_search_fields, const std::vector<search_field_t>& search_fields,
std::vector<query_tokens_t>& field_query_tokens,
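
Note how the header drops do_synonym_search's field_query_tokens and all_queries parameters in favour of a pre-built q_pos_synonyms list plus syn_orig_num_tokens: synonym resolution now happens once, up front, in Index::search (next file). There, each resolved synonym phrase is converted into positioned tokens, with only the trailing token treated as a prefix candidate. A rough standalone sketch of that conversion, using an assumed stand-in for token_t whose fields mirror the five constructor arguments used in the diff (position, value, is_prefix, root_len, num_typos):

#include <string>
#include <vector>

// Assumed stand-in for the index's token_t; field names here are illustrative.
struct SynToken {
    size_t position;
    std::string value;
    bool is_prefix;
    size_t root_len;
    size_t num_typos;
};

// Convert one resolved synonym phrase into positioned tokens. Only the last
// token is flagged as a prefix candidate, matching the loop added to
// Index::search in this commit.
std::vector<SynToken> positioned_tokens(const std::vector<std::string>& phrase) {
    std::vector<SynToken> tokens;
    for (size_t j = 0; j < phrase.size(); j++) {
        bool is_prefix = (j == phrase.size() - 1);
        tokens.push_back({j, phrase[j], is_prefix, phrase[j].size(), 0});
    }
    return tokens;
}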

View File

@@ -2261,12 +2261,42 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
// FIXME: needed?
std::set<uint64> query_hashes;
+ // resolve synonyms so that we can compute `syn_orig_num_tokens`
+ std::vector<std::vector<token_t>> all_queries = {field_query_tokens[0].q_include_tokens};
+ std::vector<std::vector<token_t>> q_pos_synonyms;
+ std::vector<std::string> q_include_tokens;
+ int syn_orig_num_tokens = -1;
+ for(size_t j = 0; j < field_query_tokens[0].q_include_tokens.size(); j++) {
+     q_include_tokens.push_back(field_query_tokens[0].q_include_tokens[j].value);
+ }
+ synonym_index->synonym_reduction(q_include_tokens, field_query_tokens[0].q_synonyms);
+ if(!field_query_tokens[0].q_synonyms.empty()) {
+     syn_orig_num_tokens = field_query_tokens[0].q_include_tokens.size();
+ }
+ for(const auto& q_syn_vec: field_query_tokens[0].q_synonyms) {
+     std::vector<token_t> q_pos_syn;
+     for(size_t j=0; j < q_syn_vec.size(); j++) {
+         bool is_prefix = (j == q_syn_vec.size()-1);
+         q_pos_syn.emplace_back(j, q_syn_vec[j], is_prefix, q_syn_vec[j].size(), 0);
+     }
+     q_pos_synonyms.push_back(q_pos_syn);
+     all_queries.push_back(q_pos_syn);
+     if(q_syn_vec.size() > syn_orig_num_tokens) {
+         syn_orig_num_tokens = q_syn_vec.size();
+     }
+ }
fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, excluded_result_ids,
excluded_result_ids_size, filter_ids, filter_ids_length, curated_ids_sorted,
sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search,
- max_candidates, min_len_1typo, min_len_2typo, -1, sort_order,
+ max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order,
field_values, geopoint_indices);
// try split/joining tokens if no results are found
@@ -2274,12 +2304,12 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
std::vector<std::vector<std::string>> space_resolved_queries;
for(size_t i = 0; i < num_search_fields; i++) {
- std::vector<std::string> q_include_tokens;
+ std::vector<std::string> orig_q_include_tokens;
for(auto& q_include_token: field_query_tokens[i].q_include_tokens) {
-     q_include_tokens.push_back(q_include_token.value);
+     orig_q_include_tokens.push_back(q_include_token.value);
}
- resolve_space_as_typos(q_include_tokens, the_fields[i].name,space_resolved_queries);
+ resolve_space_as_typos(orig_q_include_tokens, the_fields[i].name,space_resolved_queries);
if(!space_resolved_queries.empty()) {
break;
@@ -2302,28 +2332,26 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed,
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search,
- max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices);
+ max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices);
}
}
// do synonym based searches
- std::vector<std::vector<token_t>> all_queries = {field_query_tokens[0].q_include_tokens};
do_synonym_search(the_fields, filters, included_ids_map, sort_fields_std, curated_topster, token_order,
0, group_limit, group_by_fields, prioritize_exact_match, exhaustive_search, concurrency,
min_len_1typo, min_len_2typo, max_candidates, curated_ids, curated_ids_sorted,
- excluded_result_ids, excluded_result_ids_size, topster, field_query_tokens,
+ excluded_result_ids, excluded_result_ids_size, topster, q_pos_synonyms, syn_orig_num_tokens,
groups_processed, searched_queries, all_result_ids, all_result_ids_len,
filter_ids, filter_ids_length, query_hashes,
sort_order, field_values, geopoint_indices,
- qtoken_set, all_queries);
+ qtoken_set);
+ // gather up both original query and synonym queries and do drop tokens
if(all_result_ids_len < drop_tokens_threshold) {
- // gather up both original query and synonym queries and do drop tokens
for(size_t qi = 0; qi < all_queries.size(); qi++) {
auto& orig_tokens = all_queries[qi];
size_t num_tokens_dropped = 0;
- int syn_orig_num_token = (qi == 0) ? -1 : all_queries[0].size();
while(exhaustive_search || all_result_ids_len < drop_tokens_threshold) {
// When at least two tokens from the query are available we can drop one
@@ -2360,7 +2388,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
query_hashes, token_order, prefixes, typo_tokens_threshold,
exhaustive_search, max_candidates, min_len_1typo,
- min_len_2typo, syn_orig_num_token, sort_order, field_values, geopoint_indices);
+ min_len_2typo, -1, sort_order, field_values, geopoint_indices);
} else {
break;
@@ -2828,7 +2856,8 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
score_results2(sort_fields, searched_queries.size(), field_is_array,
total_cost, field_match_score,
seq_id, sort_order, group_limit, group_by_fields,
- prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, token_postings);
+ prioritize_exact_match, single_exact_query_token,
+ query_tokens.size(), syn_orig_num_tokens, token_postings);
if(field_match_score > max_field_match_score) {
max_field_match_score = field_match_score;
@@ -2861,7 +2890,10 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
(int64_t(the_fields[max_field_match_index].weight) << 0);
/*LOG(INFO) << "seq_id: " << seq_id << ", query_tokens.size(): " << query_tokens.size()
<< ", syn_orig_num_tokens: " << syn_orig_num_tokens
<< ", max_field_match_score: " << max_field_match_score
<< ", max_field_match_index: " << max_field_match_index
<< ", field_weight: " << the_fields[max_field_match_index].weight
<< ", aggregated_score: " << aggregated_score;*/
KV kv(0, searched_queries.size(), 0, seq_id, distinct_id, match_score_index, scores);
@@ -3121,7 +3153,8 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
Topster* actual_topster,
- std::vector<query_tokens_t>& field_query_tokens,
+ std::vector<std::vector<token_t>>& q_pos_synonyms,
+ int syn_orig_num_tokens,
spp::sparse_hash_set<uint64_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
@@ -3130,28 +3163,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
const std::vector<size_t>& geopoint_indices,
- tsl::htrie_map<char, token_leaf>& qtoken_set,
- std::vector<std::vector<token_t>>& all_queries) const {
- std::vector<std::string> q_include_tokens;
- for(size_t j = 0; j < field_query_tokens[0].q_include_tokens.size(); j++) {
-     q_include_tokens.push_back(field_query_tokens[0].q_include_tokens[j].value);
- }
- synonym_index->synonym_reduction(q_include_tokens, field_query_tokens[0].q_synonyms);
- std::vector<std::vector<token_t>> q_pos_synonyms;
- for(const auto& q_syn_vec: field_query_tokens[0].q_synonyms) {
-     std::vector<token_t> q_pos_syn;
-     for(size_t j=0; j < q_syn_vec.size(); j++) {
-         bool is_prefix = (j == q_syn_vec.size()-1);
-         q_pos_syn.emplace_back(j, q_syn_vec[j], is_prefix, q_syn_vec[j].size(), 0);
-     }
-     q_pos_synonyms.push_back(q_pos_syn);
-     all_queries.push_back(q_pos_syn);
- }
- int syn_orig_num_tokens = field_query_tokens[0].q_include_tokens.size();
- bool syn_wildcard_filter_init_done = false;
+ tsl::htrie_map<char, token_leaf>& qtoken_set) const {
for(const auto& syn_tokens: q_pos_synonyms) {
query_hashes.clear();
@@ -3161,7 +3173,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
query_hashes, token_order, {0}, typo_tokens_threshold,
exhaustive_search, max_candidates, min_len_1typo,
- min_len_2typo, q_include_tokens.size(), sort_order, field_values, geopoint_indices);
+ min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices);
}
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
@@ -3225,7 +3237,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
int64_t match_score = 0;
score_results2(sort_fields, searched_queries.size(), field_is_array,
0, match_score, seq_id, sort_order, group_limit, group_by_fields,
- false, false, -1, {});
+ false, false, 1, -1, {});
int64_t scores[3] = {0};
int64_t match_score_index = 0;
@@ -3548,7 +3560,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
score_results2(sort_fields, (uint16_t) searched_queries.size(), false, 0,
match_score, seq_id, sort_order, group_limit, group_by_fields, false,
- false, -1, plists);
+ false, 1, -1, plists);
int64_t scores[3] = {0};
int64_t match_score_index = 0;
@@ -3850,6 +3862,7 @@ int64_t Index::score_results2(const std::vector<sort_by> & sort_fields, const ui
const size_t group_limit, const std::vector<std::string>& group_by_fields,
const bool prioritize_exact_match,
const bool single_exact_query_token,
+ size_t num_query_tokens,
int syn_orig_num_tokens,
const std::vector<posting_list_t::iterator_t>& posting_lists) const {
@@ -3861,8 +3874,8 @@ int64_t Index::score_results2(const std::vector<sort_by> & sort_fields, const ui
prioritize_exact_match && single_exact_query_token &&
posting_list_t::is_single_token_verbatim_match(posting_lists[0], field_is_array)
);
- size_t words_present = (syn_orig_num_tokens == -1) ? 1 : syn_orig_num_tokens;
- size_t distance = (syn_orig_num_tokens == -1) ? 0 : syn_orig_num_tokens-1;
+ size_t words_present = (num_query_tokens == 1 && syn_orig_num_tokens != -1) ? syn_orig_num_tokens : 1;
+ size_t distance = (num_query_tokens == 1 && syn_orig_num_tokens != -1) ? syn_orig_num_tokens-1 : 0;
Match single_token_match = Match(words_present, distance, is_verbatim_match);
match_score = single_token_match.get_match_score(total_cost, words_present);
} else {
@@ -3884,7 +3897,7 @@ int64_t Index::score_results2(const std::vector<sort_by> & sort_fields, const ui
auto proximity = ((this_match_score >> 8) & 0xFF);
auto verbatim = (this_match_score & 0xFF);
- if(syn_orig_num_tokens != -1) {
+ if(syn_orig_num_tokens != -1 && num_query_tokens == posting_lists.size()) {
unique_words = syn_orig_num_tokens;
this_words_present = syn_orig_num_tokens;
proximity = 100 - (syn_orig_num_tokens - 1);
@@ -4179,8 +4192,6 @@ inline uint32_t Index::next_suggestion2(const std::vector<tok_candidates>& token
q = ldiv(q.quot, token_candidates_vec[i].candidates.size());
const auto& candidate = token_candidates_vec[i].candidates[q.rem];
bool exact_match = token_candidates_vec[i].cost == 0 && token_size == candidate.size();
// we assume that token was found via prefix search if candidate is longer than token's typo tolerance
bool is_prefix_searched = token_candidates_vec[i].prefix_search &&
(candidate.size() > (token_size + token_candidates_vec[i].cost));
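
For multi-token variants, the score_results2 hunk above only applies the rewrite when the variant is synonym-affected (syn_orig_num_tokens != -1) and every query token produced a posting list, scoring the match as if the longest variant had matched as a contiguous phrase. A hedged sketch of that normalization, with assumed parameter names rather than the real signature:

#include <cstddef>
#include <cstdint>

// Sketch of the multi-token normalization added to score_results2: applied
// only when the query is a synonym-affected variant and every one of its
// tokens found a posting list (num_query_tokens == num_posting_lists).
void normalize_synonym_match(int syn_orig_num_tokens, size_t num_query_tokens,
                             size_t num_posting_lists, uint64_t& unique_words,
                             uint64_t& words_present, uint64_t& proximity) {
    if (syn_orig_num_tokens != -1 && num_query_tokens == num_posting_lists) {
        // Score as if the longest variant matched as a contiguous phrase.
        unique_words  = (uint64_t) syn_orig_num_tokens;
        words_present = (uint64_t) syn_orig_num_tokens;
        proximity     = 100 - ((uint64_t) syn_orig_num_tokens - 1);
    }
}

Worked numbers for the cross-field test in the next file: the root query "cmo" resolves against the three-token synonym "chief marketing officer", so syn_orig_num_tokens becomes 3. The single-token root is then credited words_present = 3 (distance 2), the fully matched synonym phrase also gets words_present = 3 with proximity 100 - (3 - 1) = 98, and the two documents end up separated by field weight and points rather than by which variant happened to match.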

View File

@@ -375,11 +375,8 @@ TEST_F(CollectionSynonymsTest, SynonymQueryVariantWithDropTokens) {
auto res = coll1->search("us sneakers", {"category", "location"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(3, res["hits"].size());
- // NOTE: "1" is ranked above "0" because synonym matches uses the root query's number of tokens for counting
- // This means that "united states" == "us" so both records have 2 tokens matched, so tie breaking happens on points.
- ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
- ASSERT_EQ("0", res["hits"][1]["document"]["id"].get<std::string>());
+ ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+ ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ("2", res["hits"][2]["document"]["id"].get<std::string>());
collectionManager.drop_collection("coll1");
@@ -418,7 +415,7 @@ TEST_F(CollectionSynonymsTest, SynonymsTextMatchSameAsRootQuery) {
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
ASSERT_TRUE(coll1->add(doc2.dump()).ok());
auto res = coll1->search("ceo", {"name", "title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 10).get();
auto res = coll1->search("ceo", {"name", "title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(2, res["hits"].size());
ASSERT_EQ("1", res["hits"][0]["document"]["id"].get<std::string>());
@@ -535,6 +532,59 @@ TEST_F(CollectionSynonymsTest, ExactMatchRankedSameAsSynonymMatch) {
collectionManager.drop_collection("coll1");
}
+ TEST_F(CollectionSynonymsTest, ExactMatchVsSynonymMatchCrossFields) {
+     Collection *coll1;
+     std::vector<field> fields = {field("title", field_types::STRING, false),
+                                  field("description", field_types::STRING, false),
+                                  field("points", field_types::INT32, false),};
+     coll1 = collectionManager.get_collection("coll1").get();
+     if(coll1 == nullptr) {
+         coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+     }
+     std::vector<std::vector<std::string>> records = {
+         {"Head of Marketing", "The Chief Marketing Officer", "100"},
+         {"VP of Sales", "Preparing marketing and sales materials.", "120"},
+     };
+     for(size_t i=0; i<records.size(); i++) {
+         nlohmann::json doc;
+         doc["id"] = std::to_string(i);
+         doc["title"] = records[i][0];
+         doc["description"] = records[i][1];
+         doc["points"] = std::stoi(records[i][2]);
+         ASSERT_TRUE(coll1->add(doc.dump()).ok());
+     }
+     nlohmann::json syn_json = {
+         {"id", "syn-1"},
+         {"synonyms", {"cmo", "Chief Marketing Officer", "VP of Marketing"}}
+     };
+     synonym_t synonym;
+     auto syn_op = synonym_t::parse(syn_json, synonym);
+     ASSERT_TRUE(syn_op.ok());
+     coll1->add_synonym(synonym);
+     auto res = coll1->search("cmo", {"title", "description"}, "", {}, {},
+                              {0}, 10, 1, FREQUENCY, {false}, 0).get();
+     LOG(INFO) << res;
+     ASSERT_EQ(2, res["hits"].size());
+     ASSERT_EQ(2, res["found"].get<uint32_t>());
+     ASSERT_EQ("0", res["hits"][0]["document"]["id"].get<std::string>());
+     ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
+     collectionManager.drop_collection("coll1");
+ }
TEST_F(CollectionSynonymsTest, SynonymFieldOrdering) {
// Synonym match on a field earlier in the fields list should rank above exact match of another field
Collection *coll1;