Synonym matches should share parent match score.

Kishore Nallan 2022-03-19 14:05:21 +05:30
parent 59a55d0195
commit 5f244bc588
5 changed files with 149 additions and 44 deletions
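Before the diff, a minimal compilable sketch of what changes: score_results() packs five components into a 40-bit match score, and the new syn_orig_num_tokens parameter (-1 for a regular search, otherwise the token count of the original query whose synonym is being searched) forces a synonym match to pack the same score as its parent query. pack_match_score and the sample values below are illustrative only, not code from this commit; the bit layout and the override mirror the score_results() hunks further down.

#include <cstdint>
#include <cstdio>

// Illustrative helper: packs the match-score components the way the new
// score_results() loop does, applying the synonym override first.
uint64_t pack_match_score(uint32_t unique_tokens_found, uint32_t words_present,
                          uint32_t typo_score, uint32_t proximity,
                          uint32_t verbatim, int syn_orig_num_tokens) {
    if (syn_orig_num_tokens != -1) {
        // Credit the synonym with the parent query's token count and with the
        // best proximity those tokens could have had (adjacent placement).
        unique_tokens_found = syn_orig_num_tokens;
        words_present = syn_orig_num_tokens;
        proximity = 100 - (syn_orig_num_tokens - 1);
    }
    return (uint64_t(unique_tokens_found) << 32) |
           (uint64_t(words_present) << 24) |
           (uint64_t(typo_score) << 16) |
           (uint64_t(proximity) << 8) |
           (uint64_t(verbatim) << 0);
}

int main() {
    // Query "lululemon" (1 token) expanded to the synonym "lulu lemon":
    // the synonym hit packs the same score as a direct one-token match.
    uint64_t direct  = pack_match_score(1, 1, 255, 100, 1, -1);
    uint64_t synonym = pack_match_score(2, 2, 255, 99, 1, /*parent tokens*/ 1);
    printf("direct=%llu synonym=%llu equal=%d\n",
           (unsigned long long) direct, (unsigned long long) synonym,
           int(direct == synonym));
    return 0;
}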

View File

@@ -528,6 +528,7 @@ private:
size_t drop_tokens_threshold,
size_t typo_tokens_threshold,
bool exhaustive_search,
int syn_orig_num_tokens,
size_t min_len_1typo,
size_t min_len_2typo,
size_t max_candidates) const;
@@ -548,6 +549,7 @@ private:
const std::vector<token_t>& query_tokens,
bool prioritize_exact_match,
bool exhaustive_search,
int syn_orig_num_tokens,
size_t concurrency,
std::set<uint64>& query_hashes,
std::vector<uint32_t>& id_buff) const;
@@ -656,6 +658,7 @@ public:
const std::vector<std::string> &group_by_fields, uint32_t token_bits,
bool prioritize_exact_match,
bool single_exact_query_token,
int syn_orig_num_tokens,
const std::vector<posting_list_t::iterator_t>& posting_lists) const;
static int64_t get_points_from_doc(const nlohmann::json &document, const std::string & default_sorting_field);

View File

@@ -6,8 +6,6 @@
#include <match_score.h>
#include <string_utils.h>
#include <art.h>
#include <thread>
#include <future>
#include <rocksdb/write_batch.h>
#include <system_metrics.h>
#include <tokenizer.h>

View File

@@ -1114,6 +1114,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
const std::vector<token_t>& query_tokens,
bool prioritize_exact_match,
const bool exhaustive_search,
int syn_orig_num_tokens,
const size_t concurrency,
std::set<uint64>& query_hashes,
std::vector<uint32_t>& id_buff) const {
@@ -1208,7 +1209,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
total_cost, topsters[index], query_suggestion, groups_processed_vec[index],
seq_id, sort_order, field_values, geopoint_indices,
group_limit, group_by_fields, token_bits,
prioritize_exact_match, single_exact_query_token, its);
prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, its);
result_id_vecs[index].push_back(seq_id);
}, concurrency);
@@ -1952,7 +1953,7 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
search_field(0, window_tokens, search_tokens, nullptr, 0, num_toks_dropped, field_it->second, field_name,
nullptr, 0, {}, {}, 2, searched_queries, topster, groups_processed,
&result_ids, result_ids_len, field_num_results, 0, group_by_fields,
false, 4, query_hashes, token_order, false, 0, 1, false, 3, 7, 4);
false, 4, query_hashes, token_order, false, 0, 1, false, -1, 3, 7, 4);
delete [] result_ids;
@@ -2536,7 +2537,7 @@ void Index::search_fields(const std::vector<filter>& filters,
field_num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency,
query_hashes, token_order, field_prefix,
drop_tokens_threshold, typo_tokens_threshold, exhaustive_search,
drop_tokens_threshold, typo_tokens_threshold, exhaustive_search, -1,
min_len_1typo, min_len_2typo, max_candidates);
do_infix_search(sort_fields_std, searched_queries, group_limit, group_by_fields,
@@ -2568,8 +2569,7 @@ void Index::search_fields(const std::vector<filter>& filters,
field_num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency,
query_hashes, token_order, field_prefix,
0, 0, exhaustive_search,
min_len_1typo, min_len_2typo, max_candidates);
0, 0, exhaustive_search, -1, min_len_1typo, min_len_2typo, max_candidates);
}
}
@@ -2588,12 +2588,14 @@ void Index::search_fields(const std::vector<filter>& filters,
}
// do synonym based searches
// since typos are disabled, we will use drop_tokens_threshold for typo_tokens_threshold as well
// otherwise, we can't support dropping of tokens here.
do_synonym_search(filters, included_ids_map, sort_fields_std, curated_topster, token_order,
drop_tokens_threshold, typo_tokens_threshold, group_limit,
drop_tokens_threshold, drop_tokens_threshold, group_limit,
group_by_fields, prioritize_exact_match, exhaustive_search, concurrency,
min_len_1typo, min_len_2typo,
max_candidates, curated_ids, curated_ids_sorted, exclude_token_ids,
exclude_token_ids_size, i, actual_filter_ids_length, field_num_typos, field_prefix,
exclude_token_ids_size, i, actual_filter_ids_length, 0, field_prefix,
field_id, field_name, field_it,
query_tokens, search_tokens, num_tokens_dropped, actual_topster, field_num_results,
field_query_tokens, all_result_ids_len, groups_processed, searched_queries,
@@ -2703,6 +2705,7 @@ void Index::do_synonym_search(const std::vector<filter>& filters,
q_pos_synonyms.emplace_back(q_pos_syn);
}
int syn_orig_num_tokens = field_query_tokens[i].q_include_tokens.size();
bool syn_wildcard_filter_init_done = false;
for(const auto& syn_tokens: q_pos_synonyms) {
@@ -2734,7 +2737,7 @@ void Index::do_synonym_search(const std::vector<filter>& filters,
field_num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
field_num_results, group_limit, group_by_fields, prioritize_exact_match, concurrency,
query_hashes, token_order, field_prefix,
drop_tokens_threshold, typo_tokens_threshold, exhaustive_search,
drop_tokens_threshold, typo_tokens_threshold, exhaustive_search, syn_orig_num_tokens,
min_len_1typo, min_len_2typo, max_candidates);
}
}
@@ -2792,7 +2795,7 @@ void Index::do_infix_search(const std::vector<sort_by>& sort_fields_std,
score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 2,
actual_topster, {}, groups_processed, seq_id, sort_order, field_values,
geopoint_indices, group_limit, group_by_fields, token_bits,
false, false, {});
false, false, -1, {});
}
uint32_t* new_all_result_ids = nullptr;
@@ -2928,7 +2931,7 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
facet_field, facet_field.faceted_name(),
all_result_ids, all_result_ids_len, {}, {}, facet_query_num_typos, searched_queries, topster, groups_processed,
&field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields,
false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, 3, 1000, 4);
false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, -1, 3, 1000, 4);
//LOG(INFO) << "searched_queries.size: " << searched_queries.size();
@@ -3100,7 +3103,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, false, 0,
topsters[thread_id], {}, tgroups_processed[thread_id], seq_id, sort_order, field_values,
geopoint_indices, group_limit, group_by_fields, token_bits,
false, false, plists);
false, false, -1, plists);
if(check_for_circuit_break && ((i + 1) % (1 << 15)) == 0) {
// check only once every 2^15 docs to reduce overhead
@@ -3201,6 +3204,7 @@ void Index::search_field(const uint8_t & field_id,
const size_t drop_tokens_threshold,
const size_t typo_tokens_threshold,
const bool exhaustive_search,
int syn_orig_num_tokens,
size_t min_len_1typo,
size_t min_len_2typo,
const size_t max_candidates) const {
@@ -3339,7 +3343,8 @@ void Index::search_field(const uint8_t & field_id,
curated_ids, sort_fields, token_candidates_vec, searched_queries, topster,
groups_processed, all_result_ids, all_result_ids_len, field_num_results,
typo_tokens_threshold, group_limit, group_by_fields, query_tokens,
prioritize_exact_match, combination_limit, concurrency, query_hashes, id_buff);
prioritize_exact_match, exhaustive_search, syn_orig_num_tokens,
concurrency, query_hashes, id_buff);
if(id_buff.size() > 1) {
std::sort(id_buff.begin(), id_buff.end());
@@ -3395,7 +3400,7 @@ void Index::search_field(const uint8_t & field_id,
all_result_ids_len, field_num_results, group_limit, group_by_fields,
prioritize_exact_match, concurrency, query_hashes,
token_order, prefix, drop_tokens_threshold, typo_tokens_threshold,
exhaustive_search, min_len_1typo, min_len_2typo, max_candidates);
exhaustive_search, syn_orig_num_tokens, min_len_1typo, min_len_2typo, max_candidates);
}
}
@@ -3429,16 +3434,17 @@ void Index::log_leaves(const int cost, const std::string &token, const std::vect
void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16_t & query_index,
const uint8_t & field_id, const bool field_is_array, const uint32_t total_cost,
Topster* topster /**/,
Topster* topster,
const std::vector<art_leaf *> &query_suggestion,
spp::sparse_hash_set<uint64_t>& groups_processed /**/,
spp::sparse_hash_set<uint64_t>& groups_processed,
const uint32_t seq_id, const int sort_order[3],
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values /**/,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
const uint32_t token_bits,
const bool prioritize_exact_match,
const bool single_exact_query_token,
int syn_orig_num_tokens,
const std::vector<posting_list_t::iterator_t>& posting_lists) const {
int64_t geopoint_distances[3];
@@ -3503,17 +3509,20 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
prioritize_exact_match && single_exact_query_token &&
posting_list_t::is_single_token_verbatim_match(posting_lists[0], field_is_array)
);
Match single_token_match = Match(1, 0, is_verbatim_match);
match_score = single_token_match.get_match_score(total_cost, 1);
size_t words_present = (syn_orig_num_tokens == -1) ? 1 : syn_orig_num_tokens;
size_t distance = (syn_orig_num_tokens == -1) ? 0 : syn_orig_num_tokens-1;
Match single_token_match = Match(words_present, distance, is_verbatim_match);
match_score = single_token_match.get_match_score(total_cost, words_present);
} else {
uint64_t total_tokens_found = 0, total_num_typos = 0, total_distance = 0, total_verbatim = 0;
std::unordered_map<size_t, std::vector<token_positions_t>> array_token_positions;
posting_list_t::get_offsets(posting_lists, array_token_positions);
// NOTE: the tokens-found count returned by the matcher covers only the best matched window, so we still have to
// consider unique tokens found even when they are spread across the text.
uint32_t unique_tokens_found = __builtin_popcount(token_bits);
if(syn_orig_num_tokens != -1) {
unique_tokens_found = syn_orig_num_tokens;
}
for (const auto& kv: array_token_positions) {
const std::vector<token_positions_t>& token_positions = kv.second;
@@ -3524,10 +3533,27 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match);
uint64_t this_match_score = match.get_match_score(total_cost, unique_tokens_found);
total_tokens_found += ((this_match_score >> 24) & 0xFF);
total_num_typos += 255 - ((this_match_score >> 16) & 0xFF);
total_distance += 100 - ((this_match_score >> 8) & 0xFF);
total_verbatim += (this_match_score & 0xFF);
auto this_words_present = ((this_match_score >> 24) & 0xFF);
auto typo_score = ((this_match_score >> 16) & 0xFF);
auto proximity = ((this_match_score >> 8) & 0xFF);
auto verbatim = (this_match_score & 0xFF);
if(syn_orig_num_tokens != -1) {
this_words_present = syn_orig_num_tokens;
proximity = 100 - (syn_orig_num_tokens - 1);
}
uint64_t mod_match_score = (
(int64_t(unique_tokens_found) << 32) |
(int64_t(this_words_present) << 24) |
(int64_t(typo_score) << 16) |
(int64_t(proximity) << 8) |
(int64_t(verbatim) << 0)
);
if(mod_match_score > match_score) {
match_score = mod_match_score;
}
/*std::ostringstream os;
os << name << ", total_cost: " << (255 - total_cost)
@@ -3537,21 +3563,6 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
<< ", seq_id: " << seq_id << std::endl;
LOG(INFO) << os.str();*/
}
match_score = (
(uint64_t(unique_tokens_found) << 32) |
(uint64_t(total_tokens_found) << 24) |
(uint64_t(255 - total_num_typos) << 16) |
(uint64_t(100 - total_distance) << 8) |
(uint64_t(total_verbatim) << 0)
);
/*LOG(INFO) << "Match score: " << match_score << ", for seq_id: " << seq_id
<< " - total_tokens_found: " << total_tokens_found
<< " - total_num_typos: " << total_num_typos
<< " - total_distance: " << total_distance
<< " - total_verbatim: " << total_verbatim
<< " - total_cost: " << total_cost;*/
}
const int64_t default_score = INT64_MIN; // to handle field that doesn't exist in document (e.g. optional)
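One subtlety in the single-posting-list branch above: a one-token synonym standing in for a multi-token query is scored as if every original token had been found adjacent to one another, since the derived values feed Match(words_present, distance, is_verbatim_match). A condensed, compilable sketch; single_token_inputs is an illustrative wrapper around the two ternaries in the diff, not code from this commit:

#include <cassert>
#include <cstddef>

// Mirrors the two ternaries in the single-posting-list branch of score_results().
void single_token_inputs(int syn_orig_num_tokens, size_t& words_present, size_t& distance) {
    words_present = (syn_orig_num_tokens == -1) ? 1 : syn_orig_num_tokens;
    distance      = (syn_orig_num_tokens == -1) ? 0 : syn_orig_num_tokens - 1;
}

int main() {
    size_t wp = 0, d = 0;
    // "lululemon" matched as the synonym of the 2-token query "lulu lemon":
    // scored as both original tokens found, adjacent (distance 1).
    single_token_inputs(2, wp, d);
    assert(wp == 2 && d == 1);
    // a regular single-token query is scored exactly as before:
    single_token_inputs(-1, wp, d);
    assert(wp == 1 && d == 0);
    return 0;
}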

View File

@@ -2535,3 +2535,36 @@ TEST_F(CollectionSpecificTest, PhraseSearchOnLongText) {
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionSpecificTest, RepeatedTokensInArray) {
// all three documents should have the same text match score
std::vector<field> fields = {field("tags", field_types::STRING_ARRAY, false),};
Collection* coll1 = collectionManager.create_collection("coll1", 1, fields).get();
nlohmann::json doc1;
doc1["id"] = "0";
doc1["tags"] = {"Harry Mark"};
nlohmann::json doc2;
doc2["id"] = "1";
doc2["tags"] = {"Harry is random", "Harry Simpson"};
nlohmann::json doc3;
doc3["id"] = "2";
doc3["tags"] = {"Harry is Harry"};
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
ASSERT_TRUE(coll1->add(doc2.dump()).ok());
ASSERT_TRUE(coll1->add(doc3.dump()).ok());
auto results = coll1->search("harry", {"tags"},
"", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(3, results["hits"].size());
ASSERT_EQ(results["hits"][0]["text_match"].get<size_t>(), results["hits"][1]["text_match"].get<size_t>());
ASSERT_EQ(results["hits"][1]["text_match"].get<size_t>(), results["hits"][2]["text_match"].get<size_t>());
collectionManager.drop_collection("coll1");
}
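The tie asserted above falls out of two changes in score_results(): the tokens-found component is capped by the number of unique query tokens matched (popcount of token_bits, or syn_orig_num_tokens for synonym searches), and the per-element window scores are now reduced with max instead of being summed, so a token repeated across (or within) array elements no longer inflates the score. A minimal sketch of the reduction, using hypothetical packed scores rather than real Match output:

#include <algorithm>
#include <cstdint>
#include <initializer_list>

// Illustrative: keep only the best packed window score, as the new loop does.
uint64_t best_window_score(std::initializer_list<uint64_t> per_element_scores) {
    uint64_t best = 0;
    for (uint64_t score : per_element_scores) {
        best = std::max(best, score);  // max, not sum: repeats don't accumulate
    }
    return best;
}

int main() {
    // "Harry Mark" (one element) vs "Harry is random" + "Harry Simpson"
    // (two elements): with max-reduction both documents score the same.
    return best_window_score({42}) == best_window_score({42, 42}) ? 0 : 1;
}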

View File

@@ -426,8 +426,8 @@ TEST_F(CollectionSynonymsTest, MultiWaySynonym) {
ASSERT_EQ(2, res["hits"].size());
ASSERT_EQ(2, res["found"].get<uint32_t>());
ASSERT_STREQ("<mark>Samuel</mark> <mark>L.</mark> <mark>Jackson</mark>", res["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str());
ASSERT_STREQ("<mark>Samuel</mark> <mark>L.</mark> <mark>Jackson</mark>", res["hits"][1]["highlights"][0]["snippet"].get<std::string>().c_str());
ASSERT_STREQ("<mark>Samuel</mark> L. <mark>Jackson</mark>", res["hits"][0]["highlights"][0]["snippet"].get<std::string>().c_str());
ASSERT_STREQ("<mark>Samuel</mark> L. <mark>Jackson</mark>", res["hits"][1]["highlights"][0]["snippet"].get<std::string>().c_str());
// for now we don't support synonyms on ANY prefix
@@ -624,4 +624,64 @@ TEST_F(CollectionSynonymsTest, SynonymSingleTokenExactMatch) {
ASSERT_STREQ("2", res["hits"][0]["document"]["id"].get<std::string>().c_str());
collectionManager.drop_collection("coll1");
}
}
TEST_F(CollectionSynonymsTest, SynonymExpansionAndCompressionRanking) {
Collection *coll1;
std::vector<field> fields = {field("title", field_types::STRING, false),
field("points", field_types::INT32, false),};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
std::vector<std::vector<std::string>> records = {
{"Smashed Lemon", "100"},
{"Lulu Lemon", "100"},
{"Lululemon", "200"},
};
for(size_t i=0; i<records.size(); i++) {
nlohmann::json doc;
doc["id"] = std::to_string(i);
doc["title"] = records[i][0];
doc["points"] = std::stoi(records[i][1]);
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
synonym_t synonym1{"syn-1", {"lululemon"}, {{"lulu", "lemon"}}};
coll1->add_synonym(synonym1);
auto res = coll1->search("lululemon", {"title"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(2, res["hits"].size());
ASSERT_EQ(2, res["found"].get<uint32_t>());
// Even though "lulu lemon" is a two-token synonym match, it should have the same text match score as "lululemon",
// and hence must be tied and then ranked on "points"
ASSERT_EQ("2", res["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ(res["hits"][0]["text_match"].get<size_t>(), res["hits"][1]["text_match"].get<size_t>());
// now with compression synonym
synonym1.root = {"lulu", "lemon"};
synonym1.synonyms = {{"lululemon"}};
coll1->add_synonym(synonym1);
res = coll1->search("lulu lemon", {"title"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(2, res["hits"].size());
ASSERT_EQ(2, res["found"].get<uint32_t>());
// Even though "lululemon" is a single-token synonym match, it should have the same text match score as "lulu lemon",
// and hence must be tied and then ranked on "points"
ASSERT_EQ("2", res["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("1", res["hits"][1]["document"]["id"].get<std::string>());
ASSERT_EQ(res["hits"][0]["text_match"].get<size_t>(), res["hits"][1]["text_match"].get<size_t>());
collectionManager.drop_collection("coll1");
}