Mirror of https://github.com/typesense/typesense.git (synced 2025-05-19 21:22:25 +08:00)
More exhaustive multi-field ranking.

commit 8f818f7fcb
parent 90ac320c93
@@ -183,6 +183,7 @@ private:
                       Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
                       uint32_t** all_result_ids,
                       size_t & all_result_ids_len,
+                      size_t& field_num_results,
                       const size_t typo_tokens_threshold);

     void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
@@ -242,6 +243,8 @@ public:
                   std::vector<std::vector<KV*>> & override_result_kvs,
                   const size_t typo_tokens_threshold);

+    static void concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids);
+
     Option<uint32_t> do_filtering(uint32_t** filter_ids_out, const std::vector<filter> & filters);

     Option<uint32_t> remove(const uint32_t seq_id, const nlohmann::json & document);
@@ -45,6 +45,17 @@ struct Match {
     }

+    // Explicit construction of match score
+    static inline uint64_t get_match_score(const uint32_t words_present, const uint32_t total_cost, const uint8_t distance,
+                                           const uint8_t field_id) {
+        uint64_t match_score = ((int64_t) (words_present) << 24) |
+                               ((int64_t) (255 - total_cost) << 16) |
+                               ((int64_t) (distance) << 8) |
+                               ((int64_t) (field_id));
+        return match_score;
+    }
+
     // Construct a single match score from individual components (for multi-field sort)
     inline uint64_t get_match_score(const uint32_t total_cost, const uint8_t field_id) const {
         uint64_t match_score = ((int64_t) (words_present) << 24) |
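Reviewer note: because words_present occupies the highest bits of the packed score, a document matching more query tokens always outranks one matching fewer, regardless of typo cost, proximity, or field. A minimal standalone sketch (not part of the commit) mirroring the byte layout above:

#include <cassert>
#include <cstdint>

// Same layout as Match::get_match_score: words_present (bits 24+),
// 255 - total_cost (bits 16-23), distance (bits 8-15), field_id (bits 0-7).
static uint64_t pack_score(uint32_t words_present, uint32_t total_cost,
                           uint8_t distance, uint8_t field_id) {
    return ((uint64_t) words_present << 24) |
           ((uint64_t) (255 - total_cost) << 16) |
           ((uint64_t) distance << 8) |
           ((uint64_t) field_id);
}

int main() {
    // Hypothetical values: field_id 200, distance 100.
    uint64_t two_tokens_exact = pack_score(2, 0, 100, 200);
    uint64_t one_token_exact  = pack_score(1, 0, 100, 200);
    uint64_t two_tokens_typos = pack_score(2, 2, 100, 200);
    // More tokens present dominates both zero typo cost and equal field.
    assert(two_tokens_exact > one_token_exact);
    assert(two_tokens_typos > one_token_exact);
    return 0;
}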
@@ -88,9 +88,9 @@ struct Topster {
     }

     bool add(KV* kv) {
-        //LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->match_score;
-        /*for(auto kv: kv_map) {
-            LOG(INFO) << "kv key: " << kv.first << " => " << kv.second->match_score;
+        /*LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->scores[kvs[0]->match_score_index];
+        for(auto& mkv: kv_map) {
+            LOG(INFO) << "kv key: " << mkv.first << " => " << mkv.second->scores[mkv.second->match_score_index];
         }*/

         bool less_than_min_heap = (size >= MAX_SIZE) && is_smaller(kv, kvs[0]);
@@ -1290,13 +1290,13 @@ void Collection::highlight_result(const field &search_field,
                                   const std::string& highlight_end_tag,
                                   highlight_t & highlight) {

-    std::vector<uint32_t*> leaf_to_indices;
-    std::vector<art_leaf*> query_suggestion;
-
     if(searched_queries.size() <= field_order_kv->query_index) {
         return ;
     }

+    std::vector<uint32_t*> leaf_to_indices;
+    std::vector<art_leaf*> query_suggestion;
+
     for (const art_leaf *token_leaf : searched_queries[field_order_kv->query_index]) {
         // Must search for the token string fresh on that field for the given document since `token_leaf`
         // is from the best matched field and need not be present in other fields of a document.
src/index.cpp
@@ -828,6 +828,7 @@ void Index::search_candidates(const uint8_t & field_id,
                               std::vector<std::vector<art_leaf*>> & searched_queries, Topster* topster,
                               spp::sparse_hash_set<uint64_t>& groups_processed,
                               uint32_t** all_result_ids, size_t & all_result_ids_len,
+                              size_t& field_num_results,
                               const size_t typo_tokens_threshold) {
     const long long combination_limit = 10;

@@ -901,7 +902,6 @@ void Index::search_candidates(const uint8_t & field_id,
                 log_query << query_suggestion[i]->key << " ";
             }

-
             if(filter_ids != nullptr) {
                 // intersect once again with filter ids
                 uint32_t* filtered_result_ids = nullptr;
@@ -918,6 +918,8 @@ void Index::search_candidates(const uint8_t & field_id,
                 score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster, query_suggestion,
                               groups_processed, filtered_result_ids, filtered_results_size);

+                field_num_results += filtered_results_size;
+
                 delete[] filtered_result_ids;
                 delete[] result_ids;
             } else {
@@ -934,13 +936,15 @@ void Index::search_candidates(const uint8_t & field_id,
                     LOG(INFO) << size_t(field_id) << " - " << log_query.str() << ", result_size: " << result_size;
                 }*/

+                field_num_results += result_size;
+
                 delete[] result_ids;
             }

             searched_queries.push_back(actual_query_suggestion);

-            //LOG(INFO) << "all_result_ids_len: " << all_result_ids_len << ", typo_tokens_threshold: " << typo_tokens_threshold;
-            if(all_result_ids_len >= typo_tokens_threshold) {
+            //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
+            if(field_num_results >= typo_tokens_threshold) {
                 break;
             }
         }
@@ -1275,6 +1279,23 @@ void Index::collate_included_ids(const std::vector<std::string>& q_included_tokens,
     searched_queries.push_back(override_query);
 }

+void Index::concat_topster_ids(Topster* topster, spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) {
+    if(topster->distinct) {
+        for(auto &group_topster_entry: topster->group_kv_map) {
+            Topster* group_topster = group_topster_entry.second;
+            for(const auto& map_kv: group_topster->kv_map) {
+                topster_ids[map_kv.first].push_back(map_kv.second);
+            }
+        }
+    } else {
+        for(const auto& map_kv: topster->kv_map) {
+            //LOG(INFO) << "map_kv.second.key: " << map_kv.second->key;
+            //LOG(INFO) << "map_kv.first: " << map_kv.first;
+            topster_ids[map_kv.first].push_back(map_kv.second);
+        }
+    }
+}
+
 void Index::search(Option<uint32_t> & outcome,
                    const std::vector<std::string>& q_include_tokens,
                    const std::vector<std::string>& q_exclude_tokens,
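Reviewer note: concat_topster_ids flattens each per-field topster into a single map keyed by document seq_id, so a document that surfaced in several fields accumulates one KV per field. A compilable sketch of the resulting shape, using a simplified stand-in for KV and made-up field ids and scores:

#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified stand-in for KV: just the bits relevant here.
struct KVLite {
    uint8_t field_id;
    int64_t match_score;
};

int main() {
    // Shape produced by concat_topster_ids across two field topsters
    // (field ids 200 and 198 and the scores are made-up values):
    std::unordered_map<uint64_t, std::vector<KVLite>> topster_ids;
    topster_ids[12] = { {200, 50}, {198, 40} }; // doc 12 matched both fields
    topster_ids[15] = { {198, 35} };            // doc 15 matched one field
    // The aggregation pass in Index::search later folds each entry's
    // field-wise scores into its first KV before the final topster->add().
    return 0;
}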
@@ -1413,6 +1434,9 @@ void Index::search(Option<uint32_t> & outcome,
             all_result_ids = filter_ids;
             filter_ids = nullptr;
         } else {
+            spp::sparse_hash_map<uint64_t, std::vector<KV*>> topster_ids;
+            std::vector<Topster*> ftopsters;
+
             // non-wildcard
             for(size_t i = 0; i < num_search_fields; i++) {
                 // proceed to query search only when no filters are provided or when filtering produces results
@@ -1425,10 +1449,15 @@ void Index::search(Option<uint32_t> & outcome,
                 size_t num_tokens_dropped = 0;

                 //LOG(INFO) << "searching field! " << field;
+                Topster* ftopster = new Topster(topster->MAX_SIZE, topster->distinct);
+                ftopsters.push_back(ftopster);
+
+                // Don't waste additional cycles for single field searches
+                Topster* actual_topster = (num_search_fields == 1) ? topster : ftopster;
+
                 search_field(field_id, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
                              field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
-                             num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
+                             num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                              token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);

                 // do synonym based searches
@@ -1439,16 +1468,88 @@ void Index::search(Option<uint32_t> & outcome,
                     // for synonym we use a smaller field id than for original tokens
                     search_field(field_id-1, query_tokens, search_tokens, exclude_token_ids, exclude_token_ids_size, num_tokens_dropped,
                                  field, filter_ids, filter_ids_length, curated_ids_sorted, facets, sort_fields_std,
-                                 num_typos, searched_queries, topster, groups_processed, &all_result_ids, all_result_ids_len,
+                                 num_typos, searched_queries, actual_topster, groups_processed, &all_result_ids, all_result_ids_len,
                                  token_order, prefix, drop_tokens_threshold, typo_tokens_threshold);
                 }

+                concat_topster_ids(ftopster, topster_ids);
                 collate_included_ids(q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries);
+                //LOG(INFO) << "topster_ids.size: " << topster_ids.size();
             }
         }

+        for(const auto& key_kvs: topster_ids) {
+            // first calculate existing aggregate scores across best matching fields
+            spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
+
+            const auto& kvs = key_kvs.second;
+            const uint64_t seq_id = key_kvs.first;
+
+            //LOG(INFO) << "DOC ID: " << seq_id;
+
+            /*if(seq_id == 12 || seq_id == 15) {
+                LOG(INFO) << "here";
+            }*/
+
+            for(const auto kv: kvs) {
+                existing_field_kvs.emplace(kv->field_id, kv);
+            }
+
+            for(size_t i = 0; i < num_search_fields && num_search_fields > 1; i++) {
+                const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - (2*i)); // Order of `fields` used to sort results
+
+                if(field_id == kvs[0]->field_id) {
+                    continue;
+                }
+
+                if(existing_field_kvs.count(field_id) != 0) {
+                    // for existing field, we will simply sum field-wise match scores
+                    kvs[0]->scores[kvs[0]->match_score_index] +=
+                        existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
+                    continue;
+                }
+
+                const std::string & field = search_fields[i];
+
+                // compute approximate match score for this field from actual query
+                size_t words_present = 0;
+
+                for(size_t token_index=0; token_index < q_include_tokens.size(); token_index++) {
+                    const auto& token = q_include_tokens[token_index];
+
+                    std::vector<art_leaf*> leaves;
+                    const bool prefix_search = prefix && (token_index == q_include_tokens.size()-1);
+                    const size_t token_len = prefix_search ? (int) token.length() : (int) token.length() + 1;
+                    art_fuzzy_search(search_index.at(field), (const unsigned char *) token.c_str(), token_len,
+                                     0, 0, 1, token_order, prefix_search, leaves);
+
+                    if(!leaves.empty() && leaves[0]->values->ids.contains(seq_id)) {
+                        words_present++;
+                    }

+                    /*if(!leaves.empty()) {
+                        LOG(INFO) << "tok: " << leaves[0]->key;
+                    }*/
+                }
+
+                if(words_present != 0) {
+                    uint64_t match_score = Match::get_match_score(words_present, 0, 100, field_id);
+                    kvs[0]->scores[kvs[0]->match_score_index] += match_score;
+                }
+            }
+
+            //LOG(INFO) << "kvs[0].key: " << kvs[0]->key;
+            topster->add(kvs[0]);
+        }
+
+        for(Topster* ftopster: ftopsters) {
+            delete ftopster;
+        }
     }

     //LOG(INFO) << "topster size: " << topster->size;

     delete [] exclude_token_ids;

     do_facets(facets, facet_query, all_result_ids, all_result_ids_len);
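Reviewer note, a worked example with hypothetical numbers: each document's best-field KV (kvs[0]) is the base score; for other fields where the document also produced a KV, the exact field score is summed in, and for fields where it did not, the loop probes the ART index per query token at cost 0 and adds an approximate score via Match::get_match_score(words_present, 0, 100, field_id). A minimal sketch re-using the bit layout from match_score.h (field ids 200 and 198 are made-up):

#include <cstdint>
#include <iostream>

// Same bit layout as Match::get_match_score in match_score.h.
static uint64_t match_score(uint32_t words_present, uint32_t total_cost,
                            uint8_t distance, uint8_t field_id) {
    return ((uint64_t) words_present << 24) |
           ((uint64_t) (255 - total_cost) << 16) |
           ((uint64_t) distance << 8) |
           ((uint64_t) field_id);
}

int main() {
    // Best field (field_id 200) matched 2 tokens exactly; the doc also
    // contains 1 query token in a second field (field_id 198), found only
    // via the approximate ART probe (cost 0, distance fixed at 100).
    uint64_t base         = match_score(2, 0, 100, 200);
    uint64_t approx_other = match_score(1, 0, 100, 198);
    // The aggregated value kvs[0] carries into the final topster:
    std::cout << base + approx_other << "\n";
    return 0;
}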
@@ -1496,6 +1597,9 @@ void Index::search_field(const uint8_t & field_id,

     const size_t max_cost = (num_typos < 0 || num_typos > 2) ? 2 : num_typos;

+    // tracks the number of results found for the current field
+    size_t field_num_results = 0;
+
     // To prevent us from doing ART search repeatedly as we iterate through possible corrections
     spp::sparse_hash_map<std::string, std::vector<art_leaf*>> token_cost_cache;

@@ -1565,19 +1669,14 @@ void Index::search_field(const uint8_t & field_id,
                 //log_leaves(costs[token_index], token, leaves);
                 token_candidates_vec.push_back(token_candidates{token, costs[token_index], leaves});
             } else {
-                // No result at `cost = costs[token_index]`. Remove costs until `cost` for token and re-do combinations
+                // No result at `cost = costs[token_index]`. Remove `cost` for token and re-do combinations
                 auto it = std::find(token_to_costs[token_index].begin(), token_to_costs[token_index].end(), costs[token_index]);
                 if(it != token_to_costs[token_index].end()) {
                     token_to_costs[token_index].erase(it);

-                    // when no more costs are left for this token and `drop_tokens_threshold` is breached
-                    if(token_to_costs[token_index].empty() && all_result_ids_len >= drop_tokens_threshold) {
-                        n = combination_limit; // to break outer loop
-                        break;
-                    }
-
-                    // otherwise, we try to drop the token and search with remaining tokens
+                    // when no more costs are left for this token
+                    if(token_to_costs[token_index].empty()) {
+                        // we can try to drop the token and search with remaining tokens
                         token_to_costs.erase(token_to_costs.begin()+token_index);
                         search_tokens.erase(search_tokens.begin()+token_index);
                         query_tokens.erase(query_tokens.begin()+token_index);
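Reviewer note: the enclosing loop enumerates typo-cost combinations across tokens. N is the product of the per-token cost-list sizes (the code computes it with std::accumulate and a product functor), and n selects one combination per pass; erasing a dead cost shrinks a token's list and forces N to be recomputed. A minimal sketch of such an enumeration, assuming plain vectors; the exact decoding of n into per-token costs in the source may differ:

#include <numeric>
#include <vector>

int main() {
    // Two tokens: the first may be tried at costs {0,1,2}, the second at {0,1}.
    std::vector<std::vector<int>> token_to_costs = {{0, 1, 2}, {0, 1}};

    long long N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL,
        [](long long acc, const std::vector<int>& costs) {
            return acc * (long long) costs.size();
        });

    for (long long n = 0; n < N; n++) {
        // Mixed-radix decode: pick one cost per token for this combination.
        long long q = n;
        std::vector<int> costs(token_to_costs.size());
        for (size_t i = 0; i < token_to_costs.size(); i++) {
            costs[i] = token_to_costs[i][q % token_to_costs[i].size()];
            q /= (long long) token_to_costs[i].size();
        }
        // costs[] now holds the typo budget tried for each token in this pass.
    }
    return 0;
}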
@@ -1585,32 +1684,35 @@ void Index::search_field(const uint8_t & field_id,
                     }
                 }

-                // To continue outerloop on new cost combination
+                // Continue outerloop on new cost combination
                 n = -1;
                 N = std::accumulate(token_to_costs.begin(), token_to_costs.end(), 1LL, product);
-                break;
+                goto resume_typo_loop;
             }

             token_index++;
         }

-        if(!token_candidates_vec.empty() && token_candidates_vec.size() == search_tokens.size()) {
-            // If all tokens were found, go ahead and search for candidates with what we have so far
+        if(!token_candidates_vec.empty()) {
+            // If atleast one token is found, go ahead and search for candidates
             search_candidates(field_id, filter_ids, filter_ids_length, exclude_token_ids, exclude_token_ids_size,
                               curated_ids, sort_fields, token_candidates_vec, searched_queries, topster,
-                              groups_processed, all_result_ids, all_result_ids_len, typo_tokens_threshold);
+                              groups_processed, all_result_ids, all_result_ids_len, field_num_results,
+                              typo_tokens_threshold);
         }

-        if (all_result_ids_len >= typo_tokens_threshold) {
-            // If we don't find enough results, we continue outerloop (looking at tokens with greater typo cost)
-            break;
+        resume_typo_loop:
+
+        if(field_num_results >= drop_tokens_threshold || field_num_results >= typo_tokens_threshold) {
+            // if either threshold is breached, we are done
+            return ;
         }

         n++;
     }

-    // When there are not enough overall results and atleast one token has results
-    if(all_result_ids_len < drop_tokens_threshold && !query_tokens.empty() && num_tokens_dropped < query_tokens.size()) {
+    // When atleast one token from the query is available
+    if(!query_tokens.empty() && num_tokens_dropped < query_tokens.size()) {
         // Drop tokens from right until (len/2 + 1), and then from left until (len/2 + 1)

         std::vector<std::string> truncated_tokens;
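Reviewer note: the old code simply broke out of the typo loop once enough overall results existed; the new `goto resume_typo_loop` instead jumps past the candidate search so the per-field thresholds are still checked before the next cost combination, and a dead combination can reset n without skipping that check. A compilable skeleton of this control flow, with stand-ins (search_field_skeleton, combination_alive, the fixed N) for the real Index::search_field state:

#include <cstddef>
#include <iostream>

// Skeleton of the rewritten typo-cost loop (illustrative only).
static size_t search_field_skeleton(size_t drop_tokens_threshold,
                                    size_t typo_tokens_threshold) {
    size_t field_num_results = 0;
    long long N = 8; // stand-in for the product of per-token cost-list sizes

    for (long long n = 0; n < N; n++) {
        // Stand-in for building token_candidates_vec at combination `n`:
        bool combination_alive = (n % 3 != 2);

        if (!combination_alive) {
            // A dead combination adjusts the cost lists, recomputes N and
            // jumps ahead so the thresholds below still run (previously: break).
            goto resume_typo_loop;
        }

        field_num_results += 1; // stand-in for search_candidates(...)

resume_typo_loop:
        if (field_num_results >= drop_tokens_threshold ||
            field_num_results >= typo_tokens_threshold) {
            return field_num_results; // either threshold breached: done
        }
    }
    return field_num_results;
}

int main() {
    std::cout << search_field_skeleton(100, 4) << "\n"; // stops at 4 results
    return 0;
}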
@@ -307,8 +307,8 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {

     ASSERT_STREQ("country", results["grouped_hits"][2]["group_key"][0].get<std::string>().c_str());
     ASSERT_EQ(2, results["grouped_hits"][2]["hits"].size());
-    ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
-    ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());
+    ASSERT_STREQ("8", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
+    ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][1]["document"]["id"].get<std::string>().c_str());

     collectionManager.drop_collection("coll1");
 }
@@ -2659,6 +2659,53 @@ TEST_F(CollectionTest, MultiFieldRelevance) {
     collectionManager.drop_collection("coll1");
 }

+TEST_F(CollectionTest, MultiFieldMatchRanking) {
+    Collection *coll1;
+
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("artist", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    coll1 = collectionManager.get_collection("coll1");
+    if(coll1 == nullptr) {
+        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+    }
+
+    std::vector<std::vector<std::string>> records = {
+        {"Style", "Taylor Swift"},
+        {"Blank Space", "Taylor Swift"},
+        {"Balance Overkill", "Taylor Swift"},
+        {"Cardigan", "Taylor Swift"},
+        {"Invisible String", "Taylor Swift"},
+        {"The Last Great American Dynasty", "Taylor Swift"},
+        {"Mirrorball", "Taylor Swift"},
+        {"Peace", "Taylor Swift"},
+        {"Betty", "Taylor Swift"},
+        {"Mad Woman", "Taylor Swift"},
+    };
+
+    for(size_t i=0; i<records.size(); i++) {
+        nlohmann::json doc;
+
+        doc["id"] = std::to_string(i);
+        doc["title"] = records[i][0];
+        doc["artist"] = records[i][1];
+        doc["points"] = i;
+
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    auto results = coll1->search("taylor swift style",
+                                 {"artist", "title"}, "", {}, {}, 0, 3, 1, FREQUENCY, true, 5).get();
+
+    LOG(INFO) << results;
+
+    ASSERT_EQ(10, results["found"].get<size_t>());
+    ASSERT_EQ(3, results["hits"].size());
+
+    collectionManager.drop_collection("coll1");
+}
+
 TEST_F(CollectionTest, HighlightWithAccentedCharacters) {
     Collection *coll1;
