From d218db6487de4145b8aff12ef57c23ad874160f3 Mon Sep 17 00:00:00 2001
From: Kishore Nallan
Date: Tue, 1 Mar 2022 16:57:17 +0530
Subject: [PATCH] Extract out cross-field aggregation and scoring.

---
 include/index.h |   5 +
 src/index.cpp   | 387 ++++++++++++++++++++++++------------------
 2 files changed, 202 insertions(+), 190 deletions(-)

diff --git a/include/index.h b/include/index.h
index 4e25cffd..c1e2fbbd 100644
--- a/include/index.h
+++ b/include/index.h
@@ -857,6 +857,11 @@ public:
                            const std::vector<search_field_t>& the_fields,
                            size_t& all_result_ids_len, uint32_t*& all_result_ids,
                            spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const;
+
+    void aggregate_and_score_fields(const std::vector<query_tokens_t>& field_query_tokens,
+                                    const std::vector<search_field_t>& the_fields, Topster* topster,
+                                    const size_t num_search_fields,
+                                    spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const;
 };

 template

diff --git a/src/index.cpp b/src/index.cpp
index a41500b8..8f50e983 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -1153,12 +1153,12 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
     //LOG(INFO) << "field_num_results: " << field_num_results << ", typo_tokens_threshold: " << typo_tokens_threshold;
     //LOG(INFO) << "n: " << n;

-    std::stringstream fullq;
+    /*std::stringstream fullq;
     for(const auto& qleaf : actual_query_suggestion) {
         std::string qtok(reinterpret_cast<char*>(qleaf->key), qleaf->key_len - 1);
         fullq << qtok << " ";
     }
-    LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;
+    LOG(INFO) << "field: " << size_t(field_id) << ", query: " << fullq.str() << ", total_cost: " << total_cost;*/

     // Prepare excluded document IDs that we can later remove from the result set
     uint32_t* excluded_result_ids = nullptr;
@@ -2211,194 +2211,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
             total_q_tokens += phrase.size();
         }*/

-        for(auto& seq_id_kvs: topster_ids) {
-            const uint64_t seq_id = seq_id_kvs.first;
-            auto& kvs = seq_id_kvs.second; // each `kv` can be from a different field
-
-            std::sort(kvs.begin(), kvs.end(), Topster::is_greater);
-
-            // kvs[0] will store query indices of the kv group (across fields)
-            kvs[0]->query_indices = new uint64_t[kvs.size() + 1];
-            kvs[0]->query_indices[0] = kvs.size();
-
-            //LOG(INFO) << "DOC ID: " << seq_id << ", score: " << kvs[0]->scores[kvs[0]->match_score_index];
-
-            // to calculate existing aggregate scores across best matching fields
-            spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
-            for(size_t kv_i = 0; kv_i < kvs.size(); kv_i++) {
-                existing_field_kvs.emplace(kvs[kv_i]->field_id, kvs[kv_i]);
-                kvs[0]->query_indices[kv_i+1] = kvs[kv_i]->query_index;
-                /*LOG(INFO) << "kv_i: " << kv_i << ", kvs[kv_i]->query_index: " << kvs[kv_i]->query_index << ", "
-                          << "searched_query: " << searched_queries[kvs[kv_i]->query_index][0];*/
-            }
-
-            uint32_t token_bits = (uint32_t(1) << 31);  // top most bit set to guarantee atleast 1 bit set
-            uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;
-
-            uint64_t verbatim_match_fields = 0;        // field value *exactly* same as query tokens
-            uint64_t exact_match_fields = 0;           // number of fields that contains all of query tokens
-            uint64_t max_weighted_tokens_match = 0;    // weighted max number of tokens matched in a field
-            uint64_t total_token_matches = 0;          // total matches across fields (including fuzzy ones)
-
-            //LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);
-
-            for(size_t i = 0; i < num_search_fields; i++) {
-                const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
-                const size_t priority = the_fields[i].priority;
-                const size_t weight = the_fields[i].weight;
-
-                //LOG(INFO) << "--- field index: " << i << ", priority: " << priority;
-
-                if(existing_field_kvs.count(field_id) != 0) {
-                    // for existing field, we will simply sum field-wise weighted scores
-                    token_bits |= existing_field_kvs[field_id]->token_bits;
-                    //LOG(INFO) << "existing_field_kvs.count pop count: " << __builtin_popcount(token_bits);
-
-                    int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
-
-                    uint64_t tokens_found = ((match_score >> 24) & 0xFF);
-                    uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
-                    total_typos += (field_typos + 1) * priority;
-                    total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
-
-                    int64_t exact_match_score = match_score & 0xFF;
-                    verbatim_match_fields += (weight * exact_match_score);
-
-                    uint64_t unique_tokens_found =
-                            int64_t(__builtin_popcount(existing_field_kvs[field_id]->token_bits)) - 1;
-
-                    if(field_typos == 0 && unique_tokens_found == field_query_tokens[i].q_include_tokens.size()) {
-                        exact_match_fields += weight;
-                    }
-
-                    auto weighted_tokens_match = (tokens_found * weight);
-                    if(weighted_tokens_match > max_weighted_tokens_match) {
-                        max_weighted_tokens_match = weighted_tokens_match;
-                    }
-
-                    if(field_typos < min_typos) {
-                        min_typos = field_typos;
-                    }
-
-                    total_token_matches += (weight * tokens_found);
-
-                    /*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
-                              << ", weighted typos: " << std::max((255 - ((match_score >> 8) & 0xFF)), 1) * priority
-                              << ", total dist: " << (((match_score & 0xFF)))
-                              << ", weighted dist: " << std::max((100 - (match_score & 0xFF)), 1) * priority;*/
-                    continue;
-                }
-
-                // compute approximate match score for this field from actual query
-                const std::string& field = the_fields[i].name;
-                size_t words_present = 0;
-
-                // FIXME: must consider phrase tokens also
-                for(size_t token_index=0; token_index < field_query_tokens[i].q_include_tokens.size(); token_index++) {
-                    const auto& token = field_query_tokens[i].q_include_tokens[token_index];
-                    const art_leaf* leaf = (art_leaf *) art_search(search_index.at(field), (const unsigned char*) token.c_str(),
-                                                                   token.length()+1);
-
-                    if(!leaf) {
-                        continue;
-                    }
-
-                    if(!posting_t::contains(leaf->values, seq_id)) {
-                        continue;
-                    }
-
-                    token_bits |= 1UL << token_index; // sets nth bit
-                    //LOG(INFO) << "token_index: " << token_index << ", pop count: " << __builtin_popcount(token_bits);
-
-                    words_present += 1;
-
-                    /*if(!leaves.empty()) {
-                        LOG(INFO) << "tok: " << leaves[0]->key;
-                    }*/
-                }
-
-                if(words_present != 0) {
-                    uint64_t match_score = Match::get_match_score(words_present, 0, 0);
-
-                    uint64_t tokens_found = ((match_score >> 24) & 0xFF);
-                    uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
-                    total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
-                    total_typos += (field_typos + 1) * priority;
-
-                    if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
-                        exact_match_fields += weight;
-                        // not possible to calculate verbatim_match_fields accurately here, so we won't
-                    }
-
-                    auto weighted_tokens_match = (tokens_found * weight);
-
-                    if(weighted_tokens_match > max_weighted_tokens_match) {
-                        max_weighted_tokens_match = weighted_tokens_match;
-                    }
-
-                    if(field_typos < min_typos) {
-                        min_typos = field_typos;
-                    }
-
-                    total_token_matches += (weight * tokens_found);
-                    //LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
-                }
-            }
-
-            // num tokens present across fields including those containing typos
-            int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
-
-            // verbtaim match should not consider dropped-token cases
-            if(uniq_tokens_found != field_query_tokens[0].q_include_tokens.size()) {
-                // also check for synonyms
-                bool found_verbatim_syn = false;
-                for(const auto& synonym: field_query_tokens[0].q_synonyms) {
-                    if(uniq_tokens_found == synonym.size()) {
-                        found_verbatim_syn = true;
-                        break;
-                    }
-                }
-
-                if(!found_verbatim_syn) {
-                    verbatim_match_fields = 0;
-                }
-            }
-
-            // protect most significant byte from overflow, since topster uses int64_t
-            verbatim_match_fields = std::min<uint64_t>(INT8_MAX, verbatim_match_fields);
-
-            exact_match_fields += verbatim_match_fields;
-            exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
-            max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
-            total_typos = std::min<uint64_t>(255, total_typos);
-            total_distances = std::min<uint64_t>(100, total_distances);
-
-            uint64_t aggregated_score = (
-                (exact_match_fields << 48) |            // number of fields that contain *all tokens* in the query
-                (max_weighted_tokens_match << 40) |     // weighted max number of tokens matched in a field
-                (uniq_tokens_found << 32) |             // number of unique tokens found across fields including typos
-                ((255 - min_typos) << 24) |             // minimum typo cost across all fields
-                (total_token_matches << 16) |           // total matches across fields including typos
-                ((255 - total_typos) << 8) |            // total typos across fields (weighted)
-                ((100 - total_distances) << 0)          // total distances across fields (weighted)
-            );
-
-            //LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
-
-            /*LOG(INFO) << "seq id: " << seq_id
-                      << ", verbatim_match_fields: " << verbatim_match_fields
-                      << ", exact_match_fields: " << exact_match_fields
-                      << ", max_weighted_tokens_match: " << max_weighted_tokens_match
-                      << ", uniq_tokens_found: " << uniq_tokens_found
-                      << ", min typo score: " << (255 - min_typos)
-                      << ", total_token_matches: " << total_token_matches
-                      << ", typo score: " << (255 - total_typos)
-                      << ", distance score: " << (100 - total_distances)
-                      << ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/
-
-            kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
-            topster->add(kvs[0]);
-        }
+        aggregate_and_score_fields(field_query_tokens, the_fields, topster, num_search_fields, topster_ids);

         /*auto timeMillis0 = std::chrono::duration_cast<std::chrono::milliseconds>(
                 std::chrono::high_resolution_clock::now() - begin0).count();
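A note on the packed per-field score that the moved block decodes: each field-level `kv` stores its `match_score` as bytes, which the code above reads back with shift-and-mask expressions. A minimal decoding sketch, assuming only the byte layout implied by those shifts (the struct and helper names are illustrative, not part of the codebase):

    #include <cstdint>

    // Illustrative only: unpacks the per-field match score whose layout is
    // implied by the shift-and-mask expressions in the diff above.
    struct match_score_parts_t {
        uint8_t tokens_found;    // bits 24..31: query tokens matched in the field
        uint8_t typo_score;      // bits 16..23: stored as 255 - typos, so higher is better
        uint8_t distance_score;  // bits 8..15: stored as 100 - distance, so higher is better
        uint8_t exact_match;     // bits 0..7: verbatim-match component
    };

    match_score_parts_t decode_match_score(int64_t match_score) {
        return match_score_parts_t{
            static_cast<uint8_t>((match_score >> 24) & 0xFF),
            static_cast<uint8_t>((match_score >> 16) & 0xFF),
            static_cast<uint8_t>((match_score >> 8) & 0xFF),
            static_cast<uint8_t>(match_score & 0xFF)
        };
    }

Because the typo and distance bytes store "goodness" values, the aggregation converts them back into costs (e.g. `field_typos = 255 - ((match_score >> 16) & 0xFF)`) before weighting them by field priority.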
@@ -2521,6 +2334,200 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens,
     //LOG(INFO) << "Time taken for result calc: " << timeMillis << "ms";
 }

+void Index::aggregate_and_score_fields(const std::vector<query_tokens_t>& field_query_tokens,
+                                       const std::vector<search_field_t>& the_fields, Topster* topster,
+                                       const size_t num_search_fields,
+                                       spp::sparse_hash_map<uint64_t, std::vector<KV*>>& topster_ids) const {
+    for(auto& seq_id_kvs: topster_ids) {
+        const uint64_t seq_id = seq_id_kvs.first;
+        auto& kvs = seq_id_kvs.second; // each `kv` can be from a different field
+
+        std::sort(kvs.begin(), kvs.end(), Topster::is_greater);
+
+        // kvs[0] will store query indices of the kv group (across fields)
+        kvs[0]->query_indices = new uint64_t[kvs.size() + 1];
+        kvs[0]->query_indices[0] = kvs.size();
+
+        //LOG(INFO) << "DOC ID: " << seq_id << ", score: " << kvs[0]->scores[kvs[0]->match_score_index];
+
+        // to calculate existing aggregate scores across best matching fields
+        spp::sparse_hash_map<uint8_t, KV*> existing_field_kvs;
+        for(size_t kv_i = 0; kv_i < kvs.size(); kv_i++) {
+            existing_field_kvs.emplace(kvs[kv_i]->field_id, kvs[kv_i]);
+            kvs[0]->query_indices[kv_i+1] = kvs[kv_i]->query_index;
+            /*LOG(INFO) << "kv_i: " << kv_i << ", kvs[kv_i]->query_index: " << kvs[kv_i]->query_index << ", "
+                      << "searched_query: " << searched_queries[kvs[kv_i]->query_index][0];*/
+        }
+
+        uint32_t token_bits = (uint32_t(1) << 31);  // topmost bit set to guarantee at least 1 bit is set
+        uint64_t total_typos = 0, total_distances = 0, min_typos = 1000;
+
+        uint64_t verbatim_match_fields = 0;        // field value *exactly* the same as the query tokens
+        uint64_t exact_match_fields = 0;           // number of fields that contain all of the query tokens
+        uint64_t max_weighted_tokens_match = 0;    // weighted max number of tokens matched in a field
+        uint64_t total_token_matches = 0;          // total matches across fields (including fuzzy ones)
+
+        //LOG(INFO) << "Init pop count: " << __builtin_popcount(token_bits);
+
+        for(size_t i = 0; i < num_search_fields; i++) {
+            const auto field_id = (uint8_t)(FIELD_LIMIT_NUM - i);
+            const size_t priority = the_fields[i].priority;
+            const size_t weight = the_fields[i].weight;
+
+            //LOG(INFO) << "--- field index: " << i << ", priority: " << priority;
+
+            if(existing_field_kvs.count(field_id) != 0) {
+                // for an existing field, we simply sum the field-wise weighted scores
+                token_bits |= existing_field_kvs[field_id]->token_bits;
+                //LOG(INFO) << "existing_field_kvs.count pop count: " << __builtin_popcount(token_bits);
+
+                int64_t match_score = existing_field_kvs[field_id]->scores[existing_field_kvs[field_id]->match_score_index];
+
+                uint64_t tokens_found = ((match_score >> 24) & 0xFF);
+                uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
+                total_typos += (field_typos + 1) * priority;
+                total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
+
+                int64_t exact_match_score = match_score & 0xFF;
+                verbatim_match_fields += (weight * exact_match_score);
+
+                uint64_t unique_tokens_found =
+                        int64_t(__builtin_popcount(existing_field_kvs[field_id]->token_bits)) - 1;
+
+                if(field_typos == 0 && unique_tokens_found == field_query_tokens[i].q_include_tokens.size()) {
+                    exact_match_fields += weight;
+                }
+
+                auto weighted_tokens_match = (tokens_found * weight);
+                if(weighted_tokens_match > max_weighted_tokens_match) {
+                    max_weighted_tokens_match = weighted_tokens_match;
+                }
+
+                if(field_typos < min_typos) {
+                    min_typos = field_typos;
+                }
+
+                total_token_matches += (weight * tokens_found);
+
+                /*LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << (255 - ((match_score >> 8) & 0xFF))
+                          << ", weighted typos: " << std::max((255 - ((match_score >> 8) & 0xFF)), 1) * priority
+                          << ", total dist: " << (((match_score & 0xFF)))
+                          << ", weighted dist: " << std::max((100 - (match_score & 0xFF)), 1) * priority;*/
+                continue;
+            }
+
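+            // Note (editorial comment, not part of the original commit): when a
+            // field produced no kv for this document, the fallback below probes the
+            // field's ART index for each query token and derives an approximate
+            // score via Match::get_match_score(words_present, 0, 0).
+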
+            // compute approximate match score for this field from the actual query
+            const std::string& field = the_fields[i].name;
+            size_t words_present = 0;
+
+            // FIXME: must consider phrase tokens also
+            for(size_t token_index = 0; token_index < field_query_tokens[i].q_include_tokens.size(); token_index++) {
+                const auto& token = field_query_tokens[i].q_include_tokens[token_index];
+                const art_leaf* leaf = (art_leaf *) art_search(search_index.at(field), (const unsigned char*) token.c_str(),
+                                                               token.length()+1);
+
+                if(!leaf) {
+                    continue;
+                }
+
+                if(!posting_t::contains(leaf->values, seq_id)) {
+                    continue;
+                }
+
+                token_bits |= 1UL << token_index; // sets the nth bit
+                //LOG(INFO) << "token_index: " << token_index << ", pop count: " << __builtin_popcount(token_bits);
+
+                words_present += 1;
+
+                /*if(!leaves.empty()) {
+                    LOG(INFO) << "tok: " << leaves[0]->key;
+                }*/
+            }
+
+            if(words_present != 0) {
+                uint64_t match_score = Match::get_match_score(words_present, 0, 0);
+
+                uint64_t tokens_found = ((match_score >> 24) & 0xFF);
+                uint64_t field_typos = 255 - ((match_score >> 16) & 0xFF);
+                total_distances += ((100 - ((match_score >> 8) & 0xFF)) + 1) * priority;
+                total_typos += (field_typos + 1) * priority;
+
+                if(field_typos == 0 && tokens_found == field_query_tokens[i].q_include_tokens.size()) {
+                    exact_match_fields += weight;
+                    // not possible to calculate verbatim_match_fields accurately here, so we won't
+                }
+
+                auto weighted_tokens_match = (tokens_found * weight);
+
+                if(weighted_tokens_match > max_weighted_tokens_match) {
+                    max_weighted_tokens_match = weighted_tokens_match;
+                }
+
+                if(field_typos < min_typos) {
+                    min_typos = field_typos;
+                }
+
+                total_token_matches += (weight * tokens_found);
+                //LOG(INFO) << "seq_id: " << seq_id << ", total_typos: " << ((match_score >> 8) & 0xFF);
+            }
+        }
+
+        // number of tokens present across fields, including those containing typos
+        int64_t uniq_tokens_found = int64_t(__builtin_popcount(token_bits)) - 1;
+
+        // verbatim match should not consider dropped-token cases
+        if(uniq_tokens_found != field_query_tokens[0].q_include_tokens.size()) {
+            // also check for synonyms
+            bool found_verbatim_syn = false;
+            for(const auto& synonym: field_query_tokens[0].q_synonyms) {
+                if(uniq_tokens_found == synonym.size()) {
+                    found_verbatim_syn = true;
+                    break;
+                }
+            }
+
+            if(!found_verbatim_syn) {
+                verbatim_match_fields = 0;
+            }
+        }
+
+        // protect the most significant byte from overflow, since topster uses int64_t
+        verbatim_match_fields = std::min<uint64_t>(INT8_MAX, verbatim_match_fields);
+
+        exact_match_fields += verbatim_match_fields;
+        exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
+        max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
+        total_typos = std::min<uint64_t>(255, total_typos);
+        total_distances = std::min<uint64_t>(100, total_distances);
+
+        uint64_t aggregated_score = (
+            (exact_match_fields << 48) |            // number of fields that contain *all tokens* of the query
+            (max_weighted_tokens_match << 40) |     // weighted max number of tokens matched in a field
+            (uniq_tokens_found << 32) |             // number of unique tokens found across fields, including typos
+            ((255 - min_typos) << 24) |             // minimum typo cost across all fields
+            (total_token_matches << 16) |           // total matches across fields, including typos
+            ((255 - total_typos) << 8) |            // total typos across fields (weighted)
+            ((100 - total_distances) << 0)          // total distances across fields (weighted)
+        );
+
+        //LOG(INFO) << "seq id: " << seq_id << ", aggregated_score: " << aggregated_score;
+
+        /*LOG(INFO) << "seq id: " << seq_id
+                  << ", verbatim_match_fields: " << verbatim_match_fields
+                  << ", exact_match_fields: " << exact_match_fields
+                  << ", max_weighted_tokens_match: " << max_weighted_tokens_match
+                  << ", uniq_tokens_found: " << uniq_tokens_found
+                  << ", min typo score: " << (255 - min_typos)
+                  << ", total_token_matches: " << total_token_matches
+                  << ", typo score: " << (255 - total_typos)
+                  << ", distance score: " << (100 - total_distances)
+                  << ", aggregated_score: " << aggregated_score << ", token_bits: " << token_bits;*/
+
+        kvs[0]->scores[kvs[0]->match_score_index] = aggregated_score;
+        topster->add(kvs[0]);
+    }
+}
+
 void Index::search_fields(const std::vector& filters,
                           const std::map>& included_ids_map,
                           const std::vector& sort_fields_std,
                           const std::vector& num_typos,
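The extracted function packs seven bounded components into one 64-bit integer, most significant first, so a single integer comparison inside the topster orders documents by exact-field matches first and weighted distances last. A minimal standalone sketch of that packing, assuming only the layout visible in the diff (the function name is illustrative, not part of the codebase):

    #include <algorithm>
    #include <cstdint>

    // Illustrative helper mirroring the `aggregated_score` layout above. The diff
    // clamps four of these components explicitly; all are clamped here so the
    // sketch stays well-defined for arbitrary inputs (e.g. min_typos starts at
    // 1000 in the original before any field lowers it).
    uint64_t pack_aggregated_score(uint64_t exact_match_fields, uint64_t max_weighted_tokens_match,
                                   uint64_t uniq_tokens_found, uint64_t min_typos,
                                   uint64_t total_token_matches, uint64_t total_typos,
                                   uint64_t total_distances) {
        exact_match_fields = std::min<uint64_t>(255, exact_match_fields);
        max_weighted_tokens_match = std::min<uint64_t>(255, max_weighted_tokens_match);
        uniq_tokens_found = std::min<uint64_t>(255, uniq_tokens_found);
        min_typos = std::min<uint64_t>(255, min_typos);
        total_token_matches = std::min<uint64_t>(255, total_token_matches);
        total_typos = std::min<uint64_t>(255, total_typos);
        total_distances = std::min<uint64_t>(100, total_distances);

        return (exact_match_fields << 48) |        // fields containing *all* query tokens
               (max_weighted_tokens_match << 40) | // best weighted token coverage in any field
               (uniq_tokens_found << 32) |         // unique tokens found anywhere (incl. typos)
               ((255 - min_typos) << 24) |         // fewer typos in the best field => higher
               (total_token_matches << 16) |       // weighted matches across all fields
               ((255 - total_typos) << 8) |        // fewer weighted typos overall => higher
               ((100 - total_distances) << 0);     // smaller weighted distances => higher
    }

Note also the sentinel in `token_bits`: bit 31 is pre-set so that `__builtin_popcount(token_bits)` is always at least 1, which is why the unique-token count is computed as the popcount minus one.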