More exhaustive multi-field ranking.

kishorenc 2020-12-20 18:34:41 +05:30
parent 90ac320c93
commit 8f818f7fcb
7 changed files with 194 additions and 31 deletions

View File

@@ -183,6 +183,7 @@ private:
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
uint32_t** all_result_ids,
size_t & all_result_ids_len,
+ size_t& field_num_results,
const size_t typo_tokens_threshold);
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
@@ -242,6 +243,8 @@ public:
std::vector<std::vector<KV*>> & override_result_kvs,
const size_t typo_tokens_threshold);
+ static void concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids);
Option<uint32_t> do_filtering(uint32_t** filter_ids_out, const std::vector<filter> & filters);
Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document);

View File

@@ -45,6 +45,17 @@ struct Match {
}
+ // Explicit construction of match score
+ static inline uint64_t get_match_score(const uint32_t words_present, const uint32_t total_cost, const uint8_t distance,
+ const uint8_t field_id) {
+ uint64_t match_score = ((int64_t) (words_present) << 24) |
+ ((int64_t) (255 - total_cost) << 16) |
+ ((int64_t) (distance) << 8) |
+ ((int64_t) (field_id));
+ return match_score;
+ }
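The static overload added above makes the score layout explicit: words_present sits in the high bits, then the inverted typo cost, then token distance, then field id, so one unsigned comparison ranks candidates on all four signals at once. A minimal, self-contained sketch of the same packing (illustrative names, not Typesense's API):

    #include <cassert>
    #include <cstdint>

    // Same layout as get_match_score() above: bits 24+ hold the words matched,
    // bits 16-23 the inverted typo cost, bits 8-15 the token distance,
    // bits 0-7 the field id.
    static uint64_t pack_score(uint32_t words_present, uint32_t total_cost,
                               uint8_t distance, uint8_t field_id) {
        return (uint64_t(words_present) << 24) |
               (uint64_t(255 - total_cost) << 16) |
               (uint64_t(distance) << 8) |
               uint64_t(field_id);
    }

    int main() {
        // One extra matched word beats any typo/distance/field advantage.
        assert(pack_score(2, 2, 0, 0) > pack_score(1, 0, 255, 255));
        return 0;
    }

Because the components are ordered by importance, summing such scores across fields (as the index changes below do) still lets word matches dominate the ranking.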
// Construct a single match score from individual components (for multi-field sort)
inline uint64_t get_match_score(const uint32_t total_cost, const uint8_t field_id) const {
uint64_t match_score = ((int64_t) (words_present) << 24) |

View File

@@ -88,9 +88,9 @@ struct Topster {
}
bool add(KV* kv) {
//LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->match_score;
/*for(auto kv: kv_map) {
LOG(INFO) << "kv key: " << kv.first << " => " << kv.second->match_score;
/*LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->scores[kvs[0]->match_score_index];
for(auto& mkv: kv_map) {
LOG(INFO) << "kv key: " << mkv.first << " => " << mkv.second->scores[mkv.second->match_score_index];
}*/
bool less_than_min_heap = (size >= MAX_SIZE) && is_smaller(kv, kvs[0]);
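For context, kvs[0] is the root of the Topster's min-heap, so the check above can cheaply reject a candidate once the structure is full: the root is by definition the worst retained result. A simplified sketch of the same bounded top-K pattern, using std::priority_queue rather than Typesense's fixed array:

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <queue>
    #include <vector>

    // Keep only the best max_size scores; reject a candidate early when the
    // heap is full and the candidate is no better than the current minimum.
    struct TopK {
        size_t max_size;
        std::priority_queue<uint64_t, std::vector<uint64_t>,
                            std::greater<uint64_t>> min_heap;

        bool add(uint64_t score) {
            if (min_heap.size() >= max_size && score <= min_heap.top()) {
                return false;    // the early-out mirrored by less_than_min_heap
            }
            min_heap.push(score);
            if (min_heap.size() > max_size) {
                min_heap.pop();  // evict the smallest to stay bounded
            }
            return true;
        }
    };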

View File

@@ -1290,13 +1290,13 @@ void Collection::highlight_result(const field &search_field,
const std::string& highlight_end_tag,
highlight_t & highlight) {
- std::vector<uint32_t*> leaf_to_indices;
- std::vector<art_leaf*> query_suggestion;
if(searched_queries.size() <= field_order_kv->query_index) {
return ;
}
+ std::vector<uint32_t*> leaf_to_indices;
+ std::vector<art_leaf*> query_suggestion;
for (const art_leaf *token_leaf : searched_queries[field_order_kv->query_index]) {
// Must search for the token string fresh on that field for the given document since `token_leaf`
// is from the best matched field and need not be present in other fields of a document.
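The reshuffle above is a guard-clause cleanup: the two vectors are now declared after the early return, so an out-of-range query index costs no allocations. The general shape, with placeholder names:

    #include <cstddef>
    #include <vector>

    // Validate first, construct later: the failure path does no work.
    void highlight_sketch(const std::vector<int>& searched_queries,
                          size_t query_index) {
        if (searched_queries.size() <= query_index) {
            return;  // bail out before building any state
        }
        std::vector<int> scratch;  // constructed only on the success path
        scratch.push_back(searched_queries[query_index]);
    }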

View File

@@ -828,6 +828,7 @@ void Index::search_candidates(const uint8_t & field_id,
std::vector<std::vector<art_leaf*>> & searched_queries, Topster* topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
uint32_t** all_result_ids, size_t & all_result_ids_len,
+ size_t& field_num_results,
const size_t typo_tokens_threshold) {
const long long combination_limit = 10;
@@ -901,7 +902,6 @@ void Index::search_candidates(const uint8_t & field_id,
log_query << query_suggestion[i]->key << " ";
}
if(filter_ids != nullptr) {
// intersect once again with filter ids
uint32_t* filtered_result_ids = nullptr;
@@ -918,6 +918,8 @@ void Index::search_candidates(const uint8_t & field_id,
score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion,
groups_processed, filtered_result_ids, filtered_results_size);
+ field_num_results += filtered_results_size;
delete[] filtered_result_ids;
delete[] result_ids;
} else {
@@ -934,13 +936,15 @@ void Index::search_candidates(const uint8_t & field_id,
LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size;
}*/
+ field_num_results += result_size;
delete[] result_ids;
}
searched_queries.push_back(actual_query_suggestion);
- //LOG(INFO) << "all_result_ids_len: " << all_result_ids_len << ", typo_tokens_threshold: " << typo_tokens_threshold;
- if(all_result_ids_len >= typo_tokens_threshold) {
+ //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
+ if(field_num_results >= typo_tokens_threshold) {
break;
}
}
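The running change in this function: search_candidates now counts how many results the current field produced (field_num_results, an out-parameter) and exits the candidate loop on that per-field count, where it previously keyed off the global all_result_ids_len. A rough sketch of the new control flow, with hypothetical stand-ins for the candidate machinery:

    #include <cstddef>
    #include <vector>

    struct CostCombination {};  // hypothetical: one token/typo-cost combination

    // Stub standing in for score_results(); pretend each combination yields 2 hits.
    static size_t run_combination(const CostCombination&) { return 2; }

    static void search_candidates_sketch(const std::vector<CostCombination>& combos,
                                         size_t& field_num_results,
                                         size_t typo_tokens_threshold) {
        for (const CostCombination& combo : combos) {
            field_num_results += run_combination(combo);
            if (field_num_results >= typo_tokens_threshold) {
                break;  // this field alone has enough hits; skip costlier combos
            }
        }
    }

Counting per field matters because each field is now searched into its own Topster; a field should not stop early merely because some other field already filled the global result list.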
@@ -1275,6 +1279,23 @@ void Index::collate_included_ids(const std::vector<std::string>& q_included_toke
searched_queries.push_back(override_query);
}
+ void Index::concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) {
+ if(topster->distinct) {
+ for(auto &group_topster_entry: topster->group_kv_map) {
+ Topster* group_topster = group_topster_entry.second;
+ for(const auto& map_kv: group_topster->kv_map) {
+ topster_ids[map_kv.first].push_back(map_kv.second);
+ }
+ }
+ } else {
+ for(const auto& map_kv: topster->kv_map) {
+ //LOG(INFO) << "map_kv.second.key: " << map_kv.second->key;
+ //LOG(INFO) << "map_kv.first: " << map_kv.first;
+ topster_ids[map_kv.first].push_back(map_kv.second);
+ }
+ }
+ }
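concat_topster_ids is the bridge between per-field search and cross-field ranking: it flattens one field's Topster (or each group's Topster in distinct mode) into a map keyed by document, accumulating every KV that document earned across fields. A simplified equivalent, with a hypothetical HitKV standing in for Typesense's KV:

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    struct HitKV {        // hypothetical stand-in for KV
        uint64_t key;     // document key
        uint8_t  field_id;
        int64_t  score;
    };

    // Bucket hits by document: a doc found via several fields ends up with one
    // vector holding all of its per-field KVs, ready for score aggregation.
    static void concat_hits(const std::unordered_map<uint64_t, HitKV*>& field_kv_map,
                            std::unordered_map<uint64_t, std::vector<HitKV*>>& by_doc) {
        for (const auto& entry : field_kv_map) {
            by_doc[entry.first].push_back(entry.second);
        }
    }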
void Index::search(Option<uint32_t> & outcome,
const std::vector<std::string>& q_include_tokens,
const std::vector<std::string>& q_exclude_tokens,
@@ -1413,6 +1434,9 @@ void Index::search(Option<uint32_t> & outcome,
all_result_ids = filter_ids;
filter_ids = nullptr;
} else {
+ spp::sparse_hash_map<uint64_t, std::vector<KV*>> topster_ids;
+ std::vector<Topster*> ftopsters;
// non-wildcard
for(size_t i = 0; i < num_search_fields; i++) {
// proceed to query search only when no filters are provided or when filtering produces results
@@ -1425,10 +1449,15 @@
size_t num_tokens_dropped = 0;
//LOG(INFO) << "searching field! " << field;
+ Topster* ftopster = new Topster(topster->MAX_SIZE, topster->distinct);
+ ftopsters.push_back(ftopster);
+ // Don't waste additional cycles for single field searches
+ Topster* actual_topster = (num_search_fields == 1) ? topster : ftopster;
search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
- num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
+ num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
// do synonym based searches
@@ -1439,16 +1468,88 @@
// for synonym we use a smaller field id than for original tokens
search_field(field_id-1, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
- num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
+ num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
}
+ concat_topster_ids(ftopster, topster_ids);
collate_included_ids(q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries);
+ //LOG(INFO) << "topster_ids.size: " << topster_ids.size();
}
}
+ for(const auto& key_kvs: topster_ids) {
+ // first calculate existing aggregate scores across best matching fields
+ spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
+ const auto& kvs = key_kvs.second;
+ const uint64_t seq_id = key_kvs.first;
+ //LOG(INFO) << "DOC ID: " << seq_id;
+ /*if(seq_id == 12 || seq_id == 15) {
+ LOG(INFO) << "here";
+ }*/
+ for(const auto kv: kvs) {
+ existing_field_kvs.emplace(kv->field_id, kv);
+ }
+ for(size_t i = 0; i < num_search_fields && num_search_fields > 1; i++) {
+ const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - (2*i)); // Order of `fields` used to sort results
+ if(field_id == kvs[0]->field_id) {
+ continue;
+ }
+ if(existing_field_kvs.count(field_id) != 0) {
+ // for existing field, we will simply sum field-wise match scores
+ kvs[0]->scores[kvs[0]->match_score_index] +=
+ existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
+ continue;
+ }
+ const std::string & field = search_fields[i];
+ // compute approximate match score for this field from actual query
+ size_t words_present = 0;
+ for(size_t token_index=0; token_index < q_include_tokens.size(); token_index++) {
+ const auto& token = q_include_tokens[token_index];
+ std::vector<art_leaf*> leaves;
+ const bool prefix_search = prefix && (token_index == q_include_tokens.size()-1);
+ const size_t token_len = prefix_search ? (int) token.length() : (int) token.length() + 1;
+ art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len,
+ 0, 0, 1, token_order, prefix_search, leaves);
+ if(!leaves.empty() && leaves[0]->values->ids.contains(seq_id)) {
+ words_present++;
+ }
+ /*if(!leaves.empty()) {
+ LOG(INFO) << "tok: " << leaves[0]->key;
+ }*/
+ }
+ if(words_present != 0) {
+ uint64_t match_score = Match::get_match_score(words_present, 0, 100, field_id);
+ kvs[0]->scores[kvs[0]->match_score_index] += match_score;
+ }
+ }
+ //LOG(INFO) << "kvs[0].key: " << kvs[0]->key;
+ topster->add(kvs[0]);
+ }
+ for(Topster* ftopster: ftopsters) {
+ delete ftopster;
+ }
}
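This block is the heart of the commit. For every document seen in any field, real match scores from fields where it ranked are summed onto its best KV; for the remaining fields, an approximate score is synthesized by probing that field's ART index with zero typo tolerance and counting how many query tokens the document actually contains. A condensed sketch of the policy, where doc_has_token is a hypothetical stand-in for the art_fuzzy_search probe plus the posting-list check:

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Stub for: art_fuzzy_search(field, token, 0 typos) and then checking
    // whether the matched leaf's id list contains seq_id.
    static bool doc_has_token(uint64_t, uint8_t, const std::string&) { return false; }

    static int64_t aggregate_score(uint64_t seq_id,
                                   const std::unordered_map<uint8_t, int64_t>& field_scores,
                                   const std::vector<uint8_t>& field_ids,
                                   const std::vector<std::string>& query_tokens) {
        int64_t total = 0;
        for (uint8_t field_id : field_ids) {
            auto it = field_scores.find(field_id);
            if (it != field_scores.end()) {
                total += it->second;     // exact: the doc ranked in this field
                continue;
            }
            uint32_t words_present = 0;  // approximation for unranked fields
            for (const auto& token : query_tokens) {
                if (doc_has_token(seq_id, field_id, token)) {
                    words_present++;
                }
            }
            if (words_present != 0) {
                // mirrors Match::get_match_score(words_present, 0, 100, field_id)
                total += (int64_t(words_present) << 24) | (int64_t(255) << 16)
                       | (int64_t(100) << 8) | int64_t(field_id);
            }
        }
        return total;
    }

Only the best KV per document (kvs[0]) carries the aggregate into the final topster, so each document is ranked once, by its combined cross-field evidence.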
//LOG(INFO) << "topster size: " << topster->size;
delete [] exclude_token_ids;
do_facets(facets, facet_query, all_result_ids, all_result_ids_len);
@@ -1496,6 +1597,9 @@ void Index::search_field(const uint8_t & field_id,
const size_t max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos;
+ // tracks the number of results found for the current field
+ size_t field_num_results = 0;
// To prevent us from doing ART search repeatedly as we iterate through possible corrections
spp::sparse_hash_map<std::string, std::vector<art_leaf*>> token_cost_cache;
@@ -1565,19 +1669,14 @@
//log_leaves(costs[token_index], token, leaves);
token_candidates_vec.push_back(token_candidates{token, costs[token_index], leaves});
} else {
- // No result at `cost = costs[token_index]`. Remove costs until `cost` for token and re-do combinations
+ // No result at `cost = costs[token_index]`. Remove `cost` for token and re-do combinations
auto it = std::find(token_to_costs[token_index].begin(), token_to_costs[token_index].end(), costs[token_index]);
if(it != token_to_costs[token_index].end()) {
token_to_costs[token_index].erase(it);
- // when no more costs are left for this token and `drop_tokens_threshold` is breached
- if(token_to_costs[token_index].empty() && all_result_ids_len >= drop_tokens_threshold) {
- n = combination_limit; // to break outer loop
- break;
- }
- // otherwise, we try to drop the token and search with remaining tokens
+ // when no more costs are left for this token
+ if(token_to_costs[token_index].empty()) {
+ // we can try to drop the token and search with remaining tokens
token_to_costs.erase(token_to_costs.begin()+token_index);
search_tokens.erase(search_tokens.begin()+token_index);
query_tokens.erase(query_tokens.begin()+token_index);
@@ -1585,32 +1684,35 @@
}
}
- // To continue outerloop on new cost combination
+ // Continue outerloop on new cost combination
n = -1;
N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product);
- break;
+ goto resume_typo_loop;
}
token_index++;
}
- if(!token_candidates_vec.empty() && token_candidates_vec.size() == search_tokens.size()) {
- // If all tokens were found, go ahead and search for candidates with what we have so far
+ if(!token_candidates_vec.empty()) {
+ // If atleast one token is found, go ahead and search for candidates
search_candidates(field_id, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size,
curated_ids, sort_fields, token_candidates_vec, searched_queries, topster,
- groups_processed, all_result_ids, all_result_ids_len, typo_tokens_threshold);
+ groups_processed, all_result_ids, all_result_ids_len, field_num_results,
+ typo_tokens_threshold);
}
- if (all_result_ids_len >= typo_tokens_threshold) {
- // If we don't find enough results, we continue outerloop (looking at tokens with greater typo cost)
- break;
+ resume_typo_loop:
+ if(field_num_results >= drop_tokens_threshold || field_num_results >= typo_tokens_threshold) {
+ // if either threshold is breached, we are done
+ return ;
}
n++;
}
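The goto earns its keep here: previously a dead-end cost combination broke out of the inner loop and the only exit check used the global result count, whereas now every path, dead end or successful search, funnels through the single resume_typo_loop check on the field-local count and returns as soon as either threshold is satisfied. Roughly, with a stubbed combination step:

    #include <cstddef>

    // Stub: hits yielded by cost combination n (0 means a dead end).
    static size_t try_combination(long long n) { return (n % 3 == 0) ? 0 : 4; }

    static void typo_loop_sketch(size_t drop_tokens_threshold,
                                 size_t typo_tokens_threshold) {
        size_t field_num_results = 0;
        const long long N = 16;         // number of cost combinations (illustrative)
        for (long long n = 0; n < N; n++) {
            size_t found = try_combination(n);
            if (found == 0) {
                goto resume_typo_loop;  // dead end: jump straight to the checks
            }
            field_num_results += found; // search_candidates accumulates here

        resume_typo_loop:
            if (field_num_results >= drop_tokens_threshold ||
                field_num_results >= typo_tokens_threshold) {
                return;                 // either threshold met: this field is done
            }
        }
    }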
- // When there are not enough overall results and atleast one token has results
- if(all_result_ids_len < drop_tokens_threshold && !query_tokens.empty() && num_tokens_dropped < query_tokens.size()) {
+ // When atleast one token from the query is available
+ if(!query_tokens.empty() && num_tokens_dropped < query_tokens.size()) {
// Drop tokens from right until (len/2 + 1), and then from left until (len/2 + 1)
std::vector<std::string> truncated_tokens;
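With the global-count condition gone, token dropping now kicks in per field whenever that field's own search returns without hitting a threshold. The dropping order the comment describes, enumerated as a standalone helper (hypothetical; the commit itself drops one token per retry via num_tokens_dropped):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Enumerate truncated queries: trim tokens from the right down to
    // (len/2 + 1) tokens remaining, then do the same from the left.
    static std::vector<std::vector<std::string>>
    truncations(const std::vector<std::string>& tokens) {
        std::vector<std::vector<std::string>> out;
        if (tokens.empty()) {
            return out;
        }
        const size_t keep_min = tokens.size() / 2 + 1;
        for (size_t keep = tokens.size() - 1; keep >= keep_min; keep--) {
            out.emplace_back(tokens.begin(), tokens.begin() + keep);  // drop right
        }
        for (size_t keep = tokens.size() - 1; keep >= keep_min; keep--) {
            out.emplace_back(tokens.end() - keep, tokens.end());      // drop left
        }
        return out;
    }

    // {"taylor", "swift", "style"} yields {"taylor", "swift"} and {"swift", "style"}.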

View File

@@ -307,8 +307,8 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {
ASSERT_STREQ("country", results["grouped_hits"][2]["group_key"][0].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][2]["hits"].size());
- ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
- ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
+ ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
+ ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
collectionManager.drop_collection("coll1");
}

View File

@@ -2659,6 +2659,53 @@ TEST_F(CollectionTest, MultiFieldRelevance) {
collectionManager.drop_collection("coll1");
}
+ TEST_F(CollectionTest, MultiFieldMatchRanking) {
+ Collection *coll1;
+ std::vector<field> fields = {field("title", field_types::STRING, false),
+ field("artist", field_types::STRING, false),
+ field("points", field_types::INT32, false),};
+ coll1 = collectionManager.get_collection("coll1");
+ if(coll1 == nullptr) {
+ coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+ }
+ std::vector<std::vector<std::string>> records = {
+ {"Style", "Taylor Swift"},
+ {"Blank Space", "Taylor Swift"},
+ {"Balance Overkill", "Taylor Swift"},
+ {"Cardigan", "Taylor Swift"},
+ {"Invisible String", "Taylor Swift"},
+ {"The Last Great American Dynasty", "Taylor Swift"},
+ {"Mirrorball", "Taylor Swift"},
+ {"Peace", "Taylor Swift"},
+ {"Betty", "Taylor Swift"},
+ {"Mad Woman", "Taylor Swift"},
+ };
+ for(size_t i=0; i<records.size(); i++) {
+ nlohmann::json doc;
+ doc["id"] = std::to_string(i);
+ doc["title"] = records[i][0];
+ doc["artist"] = records[i][1];
+ doc["points"] = i;
+ ASSERT_TRUE(coll1->add(doc.dump()).ok());
+ }
+ auto results = coll1->search("taylor swift style",
+ {"artist", "title"}, "", {}, {}, 0, 3, 1, FREQUENCY, true, 5).get();
+ LOG(INFO) << results;
+ ASSERT_EQ(10, results["found"].get<size_t>());
+ ASSERT_EQ(3, results["hits"].size());
+ collectionManager.drop_collection("coll1");
+ }
TEST_F(CollectionTest, HighlightWithAccentedCharacters) {
Collection *coll1;