Extract out cross-field aggregation and scoring.

Kishore Nallan 2022-03-01 16:57:17 +05:30
parent 87e8e7b0ce
commit d218db6487
2 changed files with 202 additions and 190 deletions
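For reference, the aggregation collapses the per-field match signals into a single 64-bit score by packing each clamped component into its own byte, so comparing two scores compares the components in priority order. The standalone helper below is an illustrative sketch that mirrors the bit layout used by aggregate_and_score_fields in this commit; the name pack_aggregated_score and its free-function form are not part of the change.

#include <algorithm>
#include <cstdint>

// Illustrative sketch (not part of this commit) of the cross-field score packing.
// Components that are not clamped here are expected to already fit in one byte.
uint64_t pack_aggregated_score(uint64_t exact_match_fields, uint64_t max_weighted_tokens_match,
                               uint64_t uniq_tokens_found, uint64_t min_typos,
                               uint64_t total_token_matches, uint64_t total_typos,
                               uint64_t total_distances) {
    // clamp the remaining components so each fits its byte
    exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
    max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
    total_typos = std::min<uint64_t>(255, total_typos);
    total_distances = std::min<uint64_t>(100, total_distances);

    return (exact_match_fields << 48) |         // fields containing *all* query tokens
           (max_weighted_tokens_match << 40) |  // weighted max tokens matched in a single field
           (uniq_tokens_found << 32) |          // unique tokens found across fields, incl. typos
           ((255 - min_typos) << 24) |          // minimum typo cost across fields (inverted)
           (total_token_matches << 16) |        // total matches across fields, incl. typos
           ((255 - total_typos) << 8) |         // weighted typo total (inverted)
           ((100 - total_distances) << 0);      // weighted distance total (inverted)
}

Because the higher-order bytes are compared first, a document matched exactly in more fields outranks one with fewer exact-match fields regardless of the lower-order components.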

View File

@@ -857,6 +857,11 @@ public:
const std::vector<search_field_t>& the_fields, size_t& all_result_ids_len,
uint32_t*& all_result_ids,
spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const;
void aggregate_and_score_fields(const std::vector<query_tokens_t>& field_query_tokens,
const std::vector<search_field_t>& the_fields, Topster* topster,
const size_t num_search_fields,
spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const;
};
template<class T>

View File

@@ -1153,12 +1153,12 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
//LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
//LOG(INFO) << "n: " << n;
std::stringstream fullq;
/*std::stringstream fullq;
for(const auto& qleaf : actual_query_suggestion) {
std::string qtok(reinterpret_cast<char*>(qleaf->key),qleaf->key_len - 1);
fullq << qtok << " ";
}
LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;
LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;*/
// Prepare excluded document IDs that we can later remove from the result set
uint32_t* excluded_result_ids = nullptr;
@@ -2211,194 +2211,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
total_q_tokens += phrase.size();
}*/
for(auto& seq_id_kvs: topster_ids) {
const uint64_t seq_id = seq_id_kvs.first;
auto& kvs = seq_id_kvs.second; // each `kv` can be from a different field
std::sort(kvs.begin(), kvs.end(), Topster::is_greater);
// kvs[0] will store query indices of the kv group (across fields)
kvs[0]->query_indices = new uint64_t[kvs.size() + 1];
kvs[0]->query_indices[0] = kvs.size();
//LOG(INFO) << "DOC ID: " << seq_id << ", score: " << kvs[0]->scores[kvs[0]->match_score_index];
// to calculate existing aggregate scores across best matching fields
spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
for(size_t kv_i = 0; kv_i < kvs.size(); kv_i++) {
existing_field_kvs.emplace(kvs[kv_i]->field_id, kvs[kv_i]);
kvs[0]->query_indices[kv_i+1] = kvs[kv_i]->query_index;
/*LOG(INFO) << "kv_i: " << kv_i << ", kvs[kv_i]->query_index: " << kvs[kv_i]->query_index << ", "
<< "searched_query: " << searched_queries[kvs[kv_i]->query_index][0];*/
}
uint32_t token_bits = (uint32_t(1) << 31); // top-most bit set to guarantee at least 1 bit set
uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;
uint64_t verbatim_match_fields = 0; // field value is *exactly* the same as the query tokens
uint64_t exact_match_fields = 0; // number of fields that contain all of the query tokens
uint64_t max_weighted_tokens_match = 0; // weighted max number of tokens matched in a field
uint64_t total_token_matches = 0; // total matches across fields (including fuzzy ones)
//LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);
for(size_t i = 0; i < num_search_fields; i++) {
const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
const size_t priority = the_fields[i].priority;
const size_t weight = the_fields[i].weight;
//LOG(INFO) << "--- field index: " << i << ", priority: " << priority;
if(existing_field_kvs.count(field_id) != 0) {
// for existing field, we will simply sum field-wise weighted scores
token_bits |= existing_field_kvs[field_id]->token_bits;
//LOG(INFO) << "existing_field_kvs.count pop count: " << __builtin_popcount(token_bits);
int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
uint64_t tokens_found = ((match_score >> 24) & 0xFF);
uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
total_typos += (field_typos + 1) * priority;
total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
int64_t exact_match_score = match_score & 0xFF;
verbatim_match_fields += (weight * exact_match_score);
uint64_t unique_tokens_found =
int64_t(__builtin_popcount(existing_field_kvs[field_id]->token_bits)) - 1;
if(field_typos == 0 && unique_tokens_found == field_query_tokens[i].q_include_tokens.size()) {
exact_match_fields += weight;
}
auto weighted_tokens_match = (tokens_found * weight);
if(weighted_tokens_match > max_weighted_tokens_match) {
max_weighted_tokens_match = weighted_tokens_match;
}
if(field_typos < min_typos) {
min_typos = field_typos;
}
total_token_matches += (weight * tokens_found);
/*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
<< ", weighted typos: " << std::max<uint64_t>((255 - ((match_score >> 8) & 0xFF)), 1) * priority
<< ", total dist: " << (((match_score & 0xFF)))
<< ", weighted dist: " << std::max<uint64_t>((100 - (match_score & 0xFF)), 1) * priority;*/
continue;
}
// compute approximate match score for this field from actual query
const std::string& field = the_fields[i].name;
size_t words_present = 0;
// FIXME: must consider phrase tokens also
for(size_t token_index=0; token_index < field_query_tokens[i].q_include_tokens.size(); token_index++) {
const auto& token = field_query_tokens[i].q_include_tokens[token_index];
const art_leaf* leaf = (art_leaf *) art_search(search_index.at(field), (const unsigned char*) token.c_str(),
token.length()+1);
if(!leaf) {
continue;
}
if(!posting_t::contains(leaf->values, seq_id)) {
continue;
}
token_bits |= 1UL << token_index; // sets nth bit
//LOG(INFO) << "token_index: " << token_index << ", pop count: " << __builtin_popcount(token_bits);
words_present += 1;
/*if(!leaves.empty()) {
LOG(INFO) << "tok: " << leaves[0]->key;
}*/
}
if(words_present != 0) {
uint64_t match_score = Match::get_match_score(words_present, 0, 0);
uint64_t tokens_found = ((match_score >> 24) & 0xFF);
uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
total_typos += (field_typos + 1) * priority;
if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
exact_match_fields += weight;
// not possible to calculate verbatim_match_fields accurately here, so we won't
}
auto weighted_tokens_match = (tokens_found * weight);
if(weighted_tokens_match > max_weighted_tokens_match) {
max_weighted_tokens_match = weighted_tokens_match;
}
if(field_typos < min_typos) {
min_typos = field_typos;
}
total_token_matches += (weight * tokens_found);
//LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
}
}
// num tokens present across fields including those containing typos
int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
// verbatim match should not consider dropped-token cases
if(uniq_tokens_found != field_query_tokens[0].q_include_tokens.size()) {
// also check for synonyms
bool found_verbatim_syn = false;
for(const auto& synonym: field_query_tokens[0].q_synonyms) {
if(uniq_tokens_found == synonym.size()) {
found_verbatim_syn = true;
break;
}
}
if(!found_verbatim_syn) {
verbatim_match_fields = 0;
}
}
// protect most significant byte from overflow, since topster uses int64_t
verbatim_match_fields = std::min<uint64_t>(INT8_MAX, verbatim_match_fields);
exact_match_fields += verbatim_match_fields;
exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
total_typos = std::min<uint64_t>(255, total_typos);
total_distances = std::min<uint64_t>(100, total_distances);
uint64_t aggregated_score = (
(exact_match_fields << 48) | // number of fields that contain *all tokens* in the query
(max_weighted_tokens_match << 40) | // weighted max number of tokens matched in a field
(uniq_tokens_found << 32) | // number of unique tokens found across fields including typos
((255 - min_typos) << 24) | // minimum typo cost across all fields
(total_token_matches << 16) | // total matches across fields including typos
((255 - total_typos) << 8) | // total typos across fields (weighted)
((100 - total_distances) << 0) // total distances across fields (weighted)
);
//LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
/*LOG(INFO) << "seq id: " << seq_id
<< ", verbatim_match_fields: " << verbatim_match_fields
<< ", exact_match_fields: " << exact_match_fields
<< ", max_weighted_tokens_match: " << max_weighted_tokens_match
<< ", uniq_tokens_found: " << uniq_tokens_found
<< ", min typo score: " << (255 - min_typos)
<< ", total_token_matches: " << total_token_matches
<< ", typo score: " << (255 - total_typos)
<< ", distance score: " << (100 - total_distances)
<< ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/
kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
topster->add(kvs[0]);
}
aggregate_and_score_fields(field_query_tokens, the_fields, topster, num_search_fields, topster_ids);
/*auto timeMillis0 = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin0).count();
@@ -2521,6 +2334,200 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
//LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms";
}
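// Collapses the per-field KVs gathered for each document in `topster_ids` into a
// single cross-field aggregate score and adds the best KV of each group to `topster`.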
void Index::aggregate_and_score_fields(const std::vector<query_tokens_t>& field_query_tokens,
const std::vector<search_field_t>& the_fields, Topster* topster,
const size_t num_search_fields,
spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const {
for(auto& seq_id_kvs: topster_ids) {
const uint64_t seq_id = seq_id_kvs.first;
auto& kvs = seq_id_kvs.second; // each `kv` can be from a different field
std::sort(kvs.begin(), kvs.end(), Topster::is_greater);
// kvs[0] will store query indices of the kv group (across fields)
kvs[0]->query_indices = new uint64_t[kvs.size() + 1];
kvs[0]->query_indices[0] = kvs.size();
//LOG(INFO) << "DOC ID: " << seq_id << ", score: " << kvs[0]->scores[kvs[0]->match_score_index];
// to calculate existing aggregate scores across best matching fields
spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
for(size_t kv_i = 0; kv_i < kvs.size(); kv_i++) {
existing_field_kvs.emplace(kvs[kv_i]->field_id, kvs[kv_i]);
kvs[0]->query_indices[kv_i+1] = kvs[kv_i]->query_index;
/*LOG(INFO) << "kv_i: " << kv_i << ", kvs[kv_i]->query_index: " << kvs[kv_i]->query_index << ", "
<< "searched_query: " << searched_queries[kvs[kv_i]->query_index][0];*/
}
uint32_t token_bits = (uint32_t(1) << 31); // top-most bit set to guarantee at least 1 bit set
uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;
uint64_t verbatim_match_fields = 0; // field value is *exactly* the same as the query tokens
uint64_t exact_match_fields = 0; // number of fields that contain all of the query tokens
uint64_t max_weighted_tokens_match = 0; // weighted max number of tokens matched in a field
uint64_t total_token_matches = 0; // total matches across fields (including fuzzy ones)
//LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);
for(size_t i = 0; i < num_search_fields; i++) {
const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
const size_t priority = the_fields[i].priority;
const size_t weight = the_fields[i].weight;
//LOG(INFO) << "--- field index: " << i << ", priority: " << priority;
if(existing_field_kvs.count(field_id) != 0) {
// for existing field, we will simply sum field-wise weighted scores
token_bits |= existing_field_kvs[field_id]->token_bits;
//LOG(INFO) << "existing_field_kvs.count pop count: " << __builtin_popcount(token_bits);
int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
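// per-field packed match_score, decoded below as:
// bits 24-31 = tokens found, bits 16-23 = 255 - typos,
// bits 8-15 = 100 - token distance, bits 0-7 = verbatim/exact match score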
uint64_t tokens_found = ((match_score >> 24) & 0xFF);
uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
total_typos += (field_typos + 1) * priority;
total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
int64_t exact_match_score = match_score & 0xFF;
verbatim_match_fields += (weight * exact_match_score);
uint64_t unique_tokens_found =
int64_t(__builtin_popcount(existing_field_kvs[field_id]->token_bits)) - 1;
if(field_typos == 0 && unique_tokens_found == field_query_tokens[i].q_include_tokens.size()) {
exact_match_fields += weight;
}
auto weighted_tokens_match = (tokens_found * weight);
if(weighted_tokens_match > max_weighted_tokens_match) {
max_weighted_tokens_match = weighted_tokens_match;
}
if(field_typos < min_typos) {
min_typos = field_typos;
}
total_token_matches += (weight * tokens_found);
/*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
<< ", weighted typos: " << std::max<uint64_t>((255 - ((match_score >> 8) & 0xFF)), 1) * priority
<< ", total dist: " << (((match_score & 0xFF)))
<< ", weighted dist: " << std::max<uint64_t>((100 - (match_score & 0xFF)), 1) * priority;*/
continue;
}
// compute approximate match score for this field from actual query
const std::string& field = the_fields[i].name;
size_t words_present = 0;
// FIXME: must consider phrase tokens also
for(size_t token_index=0; token_index < field_query_tokens[i].q_include_tokens.size(); token_index++) {
const auto& token = field_query_tokens[i].q_include_tokens[token_index];
const art_leaf* leaf = (art_leaf *) art_search(search_index.at(field), (const unsigned char*) token.c_str(),
token.length()+1);
if(!leaf) {
continue;
}
if(!posting_t::contains(leaf->values, seq_id)) {
continue;
}
token_bits |= 1UL << token_index; // sets nth bit
//LOG(INFO) << "token_index: " << token_index << ", pop count: " << __builtin_popcount(token_bits);
words_present += 1;
/*if(!leaves.empty()) {
LOG(INFO) << "tok: " << leaves[0]->key;
}*/
}
if(words_present != 0) {
uint64_t match_score = Match::get_match_score(words_present, 0, 0);
uint64_t tokens_found = ((match_score >> 24) & 0xFF);
uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
total_typos += (field_typos + 1) * priority;
if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
exact_match_fields += weight;
// not possible to calculate verbatim_match_fields accurately here, so we won't
}
auto weighted_tokens_match = (tokens_found * weight);
if(weighted_tokens_match > max_weighted_tokens_match) {
max_weighted_tokens_match = weighted_tokens_match;
}
if(field_typos < min_typos) {
min_typos = field_typos;
}
total_token_matches += (weight * tokens_found);
//LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
}
}
// num tokens present across fields including those containing typos
int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
// verbatim match should not consider dropped-token cases
if(uniq_tokens_found != field_query_tokens[0].q_include_tokens.size()) {
// also check for synonyms
bool found_verbatim_syn = false;
for(const auto& synonym: field_query_tokens[0].q_synonyms) {
if(uniq_tokens_found == synonym.size()) {
found_verbatim_syn = true;
break;
}
}
if(!found_verbatim_syn) {
verbatim_match_fields = 0;
}
}
// protect most significant byte from overflow, since topster uses int64_t
verbatim_match_fields = std::min<uint64_t>(INT8_MAX, verbatim_match_fields);
exact_match_fields += verbatim_match_fields;
exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
total_typos = std::min<uint64_t>(255, total_typos);
total_distances = std::min<uint64_t>(100, total_distances);
uint64_t aggregated_score = (
(exact_match_fields << 48) | // number of fields that contain *all tokens* in the query
(max_weighted_tokens_match << 40) | // weighted max number of tokens matched in a field
(uniq_tokens_found << 32) | // number of unique tokens found across fields including typos
((255 - min_typos) << 24) | // minimum typo cost across all fields
(total_token_matches << 16) | // total matches across fields including typos
((255 - total_typos) << 8) | // total typos across fields (weighted)
((100 - total_distances) << 0) // total distances across fields (weighted)
);
//LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
/*LOG(INFO) << "seq id: " << seq_id
<< ", verbatim_match_fields: " << verbatim_match_fields
<< ", exact_match_fields: " << exact_match_fields
<< ", max_weighted_tokens_match: " << max_weighted_tokens_match
<< ", uniq_tokens_found: " << uniq_tokens_found
<< ", min typo score: " << (255 - min_typos)
<< ", total_token_matches: " << total_token_matches
<< ", typo score: " << (255 - total_typos)
<< ", distance score: " << (100 - total_distances)
<< ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/
kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
topster->add(kvs[0]);
}
}
void Index::search_fields(const std::vector<filter>& filters,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
const std::vector<sort_by>& sort_fields_std, const std::vector<uint32_t>& num_typos,