Extract out cross-field aggregation and scoring.
parent 87e8e7b0ce
commit d218db6487
@@ -857,6 +857,11 @@ public:
                           const std::vector<search_field_t>& the_fields, size_t& all_result_ids_len,
                           uint32_t*& all_result_ids,
                           spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const;
+
+    void aggregate_and_score_fields(const std::vector<query_tokens_t>& field_query_tokens,
+                                    const std::vector<search_field_t>& the_fields, Topster* topster,
+                                    const size_t num_search_fields,
+                                    spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const;
 };
 
 template<class T>
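For orientation: the `topster_ids` map that both of these methods take groups per-field result entries (KV) by document seq_id, so cross-field aggregation can later walk one document at a time. A minimal sketch of that contract, assuming a simplified stand-in KV and std::unordered_map in place of the real Topster KV and spp::sparse_hash_map:

#include <cstdint>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for illustration only: the real KV lives in the
// Topster code and the real map is spp::sparse_hash_map.
struct KV {
    uint8_t  field_id;    // which field produced this hit
    uint64_t key;         // document seq_id
    int64_t  score;       // field-level match score
};

using topster_ids_t = std::unordered_map<uint64_t, std::vector<KV*>>;

// Each matching field contributes one KV for a document; aggregation later
// merges the group stored under that document's seq_id.
inline void group_hit(topster_ids_t& topster_ids, KV* kv) {
    topster_ids[kv->key].push_back(kv);
}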
387 src/index.cpp
@@ -1153,12 +1153,12 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
     //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
     //LOG(INFO) << "n: " << n;
 
-    std::stringstream fullq;
+    /*std::stringstream fullq;
     for(const auto& qleaf : actual_query_suggestion) {
         std::string qtok(reinterpret_cast<char*>(qleaf->key),qleaf->key_len - 1);
         fullq << qtok << " ";
     }
-    LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;
+    LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;*/
 
     // Prepare excluded document IDs that we can later remove from the result set
     uint32_t* excluded_result_ids = nullptr;
@@ -2211,194 +2211,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
         total_q_tokens += phrase.size();
     }*/
 
-    for(auto& seq_id_kvs: topster_ids) {
-        const uint64_t seq_id = seq_id_kvs.first;
-        auto& kvs = seq_id_kvs.second; // each `kv` can be from a different field
-
-        std::sort(kvs.begin(), kvs.end(), Topster::is_greater);
-
-        // kvs[0] will store query indices of the kv group (across fields)
-        kvs[0]->query_indices = new uint64_t[kvs.size() + 1];
-        kvs[0]->query_indices[0] = kvs.size();
-
-        //LOG(INFO) << "DOC ID: " << seq_id << ", score: " << kvs[0]->scores[kvs[0]->match_score_index];
-
-        // to calculate existing aggregate scores across best matching fields
-        spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
-        for(size_t kv_i = 0; kv_i < kvs.size(); kv_i++) {
-            existing_field_kvs.emplace(kvs[kv_i]->field_id, kvs[kv_i]);
-            kvs[0]->query_indices[kv_i+1] = kvs[kv_i]->query_index;
-            /*LOG(INFO) << "kv_i: " << kv_i << ", kvs[kv_i]->query_index: " << kvs[kv_i]->query_index << ", "
-                      << "searched_query: " << searched_queries[kvs[kv_i]->query_index][0];*/
-        }
-
-        uint32_t token_bits = (uint32_t(1) << 31); // topmost bit set to guarantee at least 1 bit set
-        uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;
-
-        uint64_t verbatim_match_fields = 0; // field value *exactly* the same as the query tokens
-        uint64_t exact_match_fields = 0; // number of fields that contain all of the query tokens
-        uint64_t max_weighted_tokens_match = 0; // weighted max number of tokens matched in a field
-        uint64_t total_token_matches = 0; // total matches across fields (including fuzzy ones)
-
-        //LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);
-
-        for(size_t i = 0; i < num_search_fields; i++) {
-            const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
-            const size_t priority = the_fields[i].priority;
-            const size_t weight = the_fields[i].weight;
-
-            //LOG(INFO) << "--- field index: " << i << ", priority: " << priority;
-
-            if(existing_field_kvs.count(field_id) != 0) {
-                // for an existing field, we simply sum the field-wise weighted scores
-                token_bits |= existing_field_kvs[field_id]->token_bits;
-                //LOG(INFO) << "existing_field_kvs.count pop count: " << __builtin_popcount(token_bits);
-
-                int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
-
-                uint64_t tokens_found = ((match_score >> 24) & 0xFF);
-                uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
-                total_typos += (field_typos + 1) * priority;
-                total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
-
-                int64_t exact_match_score = match_score & 0xFF;
-                verbatim_match_fields += (weight * exact_match_score);
-
-                uint64_t unique_tokens_found =
-                        int64_t(__builtin_popcount(existing_field_kvs[field_id]->token_bits)) - 1;
-
-                if(field_typos == 0 && unique_tokens_found == field_query_tokens[i].q_include_tokens.size()) {
-                    exact_match_fields += weight;
-                }
-
-                auto weighted_tokens_match = (tokens_found * weight);
-                if(weighted_tokens_match > max_weighted_tokens_match) {
-                    max_weighted_tokens_match = weighted_tokens_match;
-                }
-
-                if(field_typos < min_typos) {
-                    min_typos = field_typos;
-                }
-
-                total_token_matches += (weight * tokens_found);
-
-                /*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
-                          << ", weighted typos: " << std::max<uint64_t>((255 - ((match_score >> 8) & 0xFF)), 1) * priority
-                          << ", total dist: " << (((match_score & 0xFF)))
-                          << ", weighted dist: " << std::max<uint64_t>((100 - (match_score & 0xFF)), 1) * priority;*/
-                continue;
-            }
-
-            // compute approximate match score for this field from the actual query
-            const std::string& field = the_fields[i].name;
-            size_t words_present = 0;
-
-            // FIXME: must consider phrase tokens also
-            for(size_t token_index=0; token_index < field_query_tokens[i].q_include_tokens.size(); token_index++) {
-                const auto& token = field_query_tokens[i].q_include_tokens[token_index];
-                const art_leaf* leaf = (art_leaf *) art_search(search_index.at(field), (const unsigned char*) token.c_str(),
-                                                               token.length()+1);
-
-                if(!leaf) {
-                    continue;
-                }
-
-                if(!posting_t::contains(leaf->values, seq_id)) {
-                    continue;
-                }
-
-                token_bits |= 1UL << token_index; // sets nth bit
-                //LOG(INFO) << "token_index: " << token_index << ", pop count: " << __builtin_popcount(token_bits);
-
-                words_present += 1;
-
-                /*if(!leaves.empty()) {
-                    LOG(INFO) << "tok: " << leaves[0]->key;
-                }*/
-            }
-
-            if(words_present != 0) {
-                uint64_t match_score = Match::get_match_score(words_present, 0, 0);
-
-                uint64_t tokens_found = ((match_score >> 24) & 0xFF);
-                uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
-                total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
-                total_typos += (field_typos + 1) * priority;
-
-                if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
-                    exact_match_fields += weight;
-                    // not possible to calculate verbatim_match_fields accurately here, so we won't
-                }
-
-                auto weighted_tokens_match = (tokens_found * weight);
-
-                if(weighted_tokens_match > max_weighted_tokens_match) {
-                    max_weighted_tokens_match = weighted_tokens_match;
-                }
-
-                if(field_typos < min_typos) {
-                    min_typos = field_typos;
-                }
-
-                total_token_matches += (weight * tokens_found);
-                //LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
-            }
-        }
-
-        // number of tokens present across fields, including those containing typos
-        int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
-
-        // verbatim match should not consider dropped-token cases
-        if(uniq_tokens_found != field_query_tokens[0].q_include_tokens.size()) {
-            // also check for synonyms
-            bool found_verbatim_syn = false;
-            for(const auto& synonym: field_query_tokens[0].q_synonyms) {
-                if(uniq_tokens_found == synonym.size()) {
-                    found_verbatim_syn = true;
-                    break;
-                }
-            }
-
-            if(!found_verbatim_syn) {
-                verbatim_match_fields = 0;
-            }
-        }
-
-        // protect the most significant byte from overflow, since topster uses int64_t
-        verbatim_match_fields = std::min<uint64_t>(INT8_MAX, verbatim_match_fields);
-
-        exact_match_fields += verbatim_match_fields;
-        exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
-        max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
-        total_typos = std::min<uint64_t>(255, total_typos);
-        total_distances = std::min<uint64_t>(100, total_distances);
-
-        uint64_t aggregated_score = (
-            (exact_match_fields << 48) |         // number of fields that contain *all tokens* in the query
-            (max_weighted_tokens_match << 40) |  // weighted max number of tokens matched in a field
-            (uniq_tokens_found << 32) |          // number of unique tokens found across fields including typos
-            ((255 - min_typos) << 24) |          // minimum typo cost across all fields
-            (total_token_matches << 16) |        // total matches across fields including typos
-            ((255 - total_typos) << 8) |         // total typos across fields (weighted)
-            ((100 - total_distances) << 0)       // total distances across fields (weighted)
-        );
-
-        //LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
-
-        /*LOG(INFO) << "seq id: " << seq_id
-                  << ", verbatim_match_fields: " << verbatim_match_fields
-                  << ", exact_match_fields: " << exact_match_fields
-                  << ", max_weighted_tokens_match: " << max_weighted_tokens_match
-                  << ", uniq_tokens_found: " << uniq_tokens_found
-                  << ", min typo score: " << (255 - min_typos)
-                  << ", total_token_matches: " << total_token_matches
-                  << ", typo score: " << (255 - total_typos)
-                  << ", distance score: " << (100 - total_distances)
-                  << ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/
-
-        kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
-        topster->add(kvs[0]);
-    }
+    aggregate_and_score_fields(field_query_tokens, the_fields, topster, num_search_fields, topster_ids);
 
     /*auto timeMillis0 = std::chrono::duration_cast<std::chrono::milliseconds>(
             std::chrono::high_resolution_clock::now() - begin0).count();
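The loop removed above (and re-added verbatim as aggregate_and_score_fields below) leans on two bit tricks: a per-field match_score whose byte lanes are unpacked with shifts and masks, and a token_bits word whose population count yields the number of unique query tokens seen. A compilable sketch of both; the byte layout is inferred from the shifts in the diff, and the helper names are invented for illustration:

#include <cstdint>
#include <cstdio>

// Layout inferred from the decoding above (most significant byte first):
//   bits 24..31: tokens found in the field
//   bits 16..23: 255 - typos    (stored inverted so higher is better)
//   bits  8..15: 100 - distance (also stored inverted)
//   bits  0..7 : verbatim/exact match score
static uint64_t tokens_found(int64_t s)   { return (s >> 24) & 0xFF; }
static uint64_t field_typos(int64_t s)    { return 255 - ((s >> 16) & 0xFF); }
static uint64_t field_distance(int64_t s) { return 100 - ((s >> 8) & 0xFF); }
static uint64_t exact_score(int64_t s)    { return s & 0xFF; }

int main() {
    // token_bits: bit 31 is a sentinel so the word is never zero;
    // bit n is set when query token n matched, so popcount - 1 = unique tokens.
    uint32_t token_bits = (uint32_t(1) << 31);
    token_bits |= 1u << 0;
    token_bits |= 1u << 2;
    int uniq_tokens_found = __builtin_popcount(token_bits) - 1;  // 2

    // 3 tokens found, 0 typos, distance 2, exact match flag set
    int64_t match_score = (3 << 24) | (255 << 16) | (98 << 8) | 1;
    printf("tokens=%llu typos=%llu dist=%llu exact=%llu uniq=%d\n",
           (unsigned long long) tokens_found(match_score),
           (unsigned long long) field_typos(match_score),
           (unsigned long long) field_distance(match_score),
           (unsigned long long) exact_score(match_score),
           uniq_tokens_found);
    return 0;
}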
@@ -2521,6 +2334,200 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
     //LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms";
 }
 
+void Index::aggregate_and_score_fields(const std::vector<query_tokens_t>& field_query_tokens,
+                                       const std::vector<search_field_t>& the_fields, Topster* topster,
+                                       const size_t num_search_fields,
+                                       spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const {
+    for(auto& seq_id_kvs: topster_ids) {
+        const uint64_t seq_id = seq_id_kvs.first;
+        auto& kvs = seq_id_kvs.second; // each `kv` can be from a different field
+
+        std::sort(kvs.begin(), kvs.end(), Topster::is_greater);
+
+        // kvs[0] will store query indices of the kv group (across fields)
+        kvs[0]->query_indices = new uint64_t[kvs.size() + 1];
+        kvs[0]->query_indices[0] = kvs.size();
+
+        //LOG(INFO) << "DOC ID: " << seq_id << ", score: " << kvs[0]->scores[kvs[0]->match_score_index];
+
+        // to calculate existing aggregate scores across best matching fields
+        spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
+        for(size_t kv_i = 0; kv_i < kvs.size(); kv_i++) {
+            existing_field_kvs.emplace(kvs[kv_i]->field_id, kvs[kv_i]);
+            kvs[0]->query_indices[kv_i+1] = kvs[kv_i]->query_index;
+            /*LOG(INFO) << "kv_i: " << kv_i << ", kvs[kv_i]->query_index: " << kvs[kv_i]->query_index << ", "
+                      << "searched_query: " << searched_queries[kvs[kv_i]->query_index][0];*/
+        }
+
+        uint32_t token_bits = (uint32_t(1) << 31); // topmost bit set to guarantee at least 1 bit set
+        uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;
+
+        uint64_t verbatim_match_fields = 0; // field value *exactly* the same as the query tokens
+        uint64_t exact_match_fields = 0; // number of fields that contain all of the query tokens
+        uint64_t max_weighted_tokens_match = 0; // weighted max number of tokens matched in a field
+        uint64_t total_token_matches = 0; // total matches across fields (including fuzzy ones)
+
+        //LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);
+
+        for(size_t i = 0; i < num_search_fields; i++) {
+            const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
+            const size_t priority = the_fields[i].priority;
+            const size_t weight = the_fields[i].weight;
+
+            //LOG(INFO) << "--- field index: " << i << ", priority: " << priority;
+
+            if(existing_field_kvs.count(field_id) != 0) {
+                // for an existing field, we simply sum the field-wise weighted scores
+                token_bits |= existing_field_kvs[field_id]->token_bits;
+                //LOG(INFO) << "existing_field_kvs.count pop count: " << __builtin_popcount(token_bits);
+
+                int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
+
+                uint64_t tokens_found = ((match_score >> 24) & 0xFF);
+                uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
+                total_typos += (field_typos + 1) * priority;
+                total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
+
+                int64_t exact_match_score = match_score & 0xFF;
+                verbatim_match_fields += (weight * exact_match_score);
+
+                uint64_t unique_tokens_found =
+                        int64_t(__builtin_popcount(existing_field_kvs[field_id]->token_bits)) - 1;
+
+                if(field_typos == 0 && unique_tokens_found == field_query_tokens[i].q_include_tokens.size()) {
+                    exact_match_fields += weight;
+                }
+
+                auto weighted_tokens_match = (tokens_found * weight);
+                if(weighted_tokens_match > max_weighted_tokens_match) {
+                    max_weighted_tokens_match = weighted_tokens_match;
+                }
+
+                if(field_typos < min_typos) {
+                    min_typos = field_typos;
+                }
+
+                total_token_matches += (weight * tokens_found);
+
+                /*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
+                          << ", weighted typos: " << std::max<uint64_t>((255 - ((match_score >> 8) & 0xFF)), 1) * priority
+                          << ", total dist: " << (((match_score & 0xFF)))
+                          << ", weighted dist: " << std::max<uint64_t>((100 - (match_score & 0xFF)), 1) * priority;*/
+                continue;
+            }
+
+            // compute approximate match score for this field from the actual query
+            const std::string& field = the_fields[i].name;
+            size_t words_present = 0;
+
+            // FIXME: must consider phrase tokens also
+            for(size_t token_index=0; token_index < field_query_tokens[i].q_include_tokens.size(); token_index++) {
+                const auto& token = field_query_tokens[i].q_include_tokens[token_index];
+                const art_leaf* leaf = (art_leaf *) art_search(search_index.at(field), (const unsigned char*) token.c_str(),
+                                                               token.length()+1);
+
+                if(!leaf) {
+                    continue;
+                }
+
+                if(!posting_t::contains(leaf->values, seq_id)) {
+                    continue;
+                }
+
+                token_bits |= 1UL << token_index; // sets nth bit
+                //LOG(INFO) << "token_index: " << token_index << ", pop count: " << __builtin_popcount(token_bits);
+
+                words_present += 1;
+
+                /*if(!leaves.empty()) {
+                    LOG(INFO) << "tok: " << leaves[0]->key;
+                }*/
+            }
+
+            if(words_present != 0) {
+                uint64_t match_score = Match::get_match_score(words_present, 0, 0);
+
+                uint64_t tokens_found = ((match_score >> 24) & 0xFF);
+                uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
+                total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
+                total_typos += (field_typos + 1) * priority;
+
+                if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
+                    exact_match_fields += weight;
+                    // not possible to calculate verbatim_match_fields accurately here, so we won't
+                }
+
+                auto weighted_tokens_match = (tokens_found * weight);
+
+                if(weighted_tokens_match > max_weighted_tokens_match) {
+                    max_weighted_tokens_match = weighted_tokens_match;
+                }
+
+                if(field_typos < min_typos) {
+                    min_typos = field_typos;
+                }
+
+                total_token_matches += (weight * tokens_found);
+                //LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
+            }
+        }
+
+        // number of tokens present across fields, including those containing typos
+        int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
+
+        // verbatim match should not consider dropped-token cases
+        if(uniq_tokens_found != field_query_tokens[0].q_include_tokens.size()) {
+            // also check for synonyms
+            bool found_verbatim_syn = false;
+            for(const auto& synonym: field_query_tokens[0].q_synonyms) {
+                if(uniq_tokens_found == synonym.size()) {
+                    found_verbatim_syn = true;
+                    break;
+                }
+            }
+
+            if(!found_verbatim_syn) {
+                verbatim_match_fields = 0;
+            }
+        }
+
+        // protect the most significant byte from overflow, since topster uses int64_t
+        verbatim_match_fields = std::min<uint64_t>(INT8_MAX, verbatim_match_fields);
+
+        exact_match_fields += verbatim_match_fields;
+        exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
+        max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
+        total_typos = std::min<uint64_t>(255, total_typos);
+        total_distances = std::min<uint64_t>(100, total_distances);
+
+        uint64_t aggregated_score = (
+            (exact_match_fields << 48) |         // number of fields that contain *all tokens* in the query
+            (max_weighted_tokens_match << 40) |  // weighted max number of tokens matched in a field
+            (uniq_tokens_found << 32) |          // number of unique tokens found across fields including typos
+            ((255 - min_typos) << 24) |          // minimum typo cost across all fields
+            (total_token_matches << 16) |        // total matches across fields including typos
+            ((255 - total_typos) << 8) |         // total typos across fields (weighted)
+            ((100 - total_distances) << 0)       // total distances across fields (weighted)
+        );
+
+        //LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
+
+        /*LOG(INFO) << "seq id: " << seq_id
+                  << ", verbatim_match_fields: " << verbatim_match_fields
+                  << ", exact_match_fields: " << exact_match_fields
+                  << ", max_weighted_tokens_match: " << max_weighted_tokens_match
+                  << ", uniq_tokens_found: " << uniq_tokens_found
+                  << ", min typo score: " << (255 - min_typos)
+                  << ", total_token_matches: " << total_token_matches
+                  << ", typo score: " << (255 - total_typos)
+                  << ", distance score: " << (100 - total_distances)
+                  << ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/
+
+        kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
+        topster->add(kvs[0]);
+    }
+}
+
 void Index::search_fields(const std::vector<filter>& filters,
                           const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
                           const std::vector<sort_by>& sort_fields_std, const std::vector<uint32_t>& num_typos,
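The aggregated cross-field score packs seven components into one uint64_t, most significant first, so Topster ranks documents with a single integer comparison. A minimal sketch mirroring the packing above; the function name and the two extra safety clamps are mine, everything else follows the diff:

#include <algorithm>
#include <cstdint>

// Mirrors the packing at the end of aggregate_and_score_fields: components
// are laid out most-significant-first, so ranking precedence falls out of a
// plain integer comparison. Function name is illustrative, not from the diff.
uint64_t pack_aggregated_score(uint64_t exact_match_fields,
                               uint64_t max_weighted_tokens_match,
                               uint64_t uniq_tokens_found,
                               uint64_t min_typos,
                               uint64_t total_token_matches,
                               uint64_t total_typos,
                               uint64_t total_distances) {
    // same clamps as the diff, keeping each component inside its byte lane
    exact_match_fields        = std::min<uint64_t>(255, exact_match_fields);
    max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
    total_typos               = std::min<uint64_t>(255, total_typos);
    total_distances           = std::min<uint64_t>(100, total_distances);
    // extra safety clamps, not present in the diff
    min_typos                 = std::min<uint64_t>(255, min_typos);
    total_token_matches       = std::min<uint64_t>(255, total_token_matches);

    return (exact_match_fields        << 48) |  // fields containing *all* query tokens
           (max_weighted_tokens_match << 40) |  // weighted max tokens matched in one field
           (uniq_tokens_found         << 32) |  // unique tokens found across fields
           ((255 - min_typos)         << 24) |  // minimum typo cost across fields
           (total_token_matches       << 16) |  // total matches across fields
           ((255 - total_typos)       <<  8) |  // inverted weighted typo total
           ((100 - total_distances)   <<  0);   // inverted weighted distance total
}

Because higher lanes dominate, a document that matches all query tokens in more fields always outranks one that merely has a lower typo count or distance.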