From 0e2adb4242f16ea1c8755f18eac973f4ab42f8c4 Mon Sep 17 00:00:00 2001
From: Kishore Nallan
Date: Tue, 17 Aug 2021 18:37:42 +0530
Subject: [PATCH] Copy-free intersect + score.

---
 include/collection.h       |   2 -
 include/index.h            |  11 +-
 include/posting.h          |  23 +-
 include/posting_list.h     | 110 +++++++--
 src/collection.cpp         |  44 ++--
 src/index.cpp              | 452 +++++++++++++++++--------------------
 src/posting.cpp            |  29 +--
 src/posting_list.cpp       | 264 ++++++++++------------
 test/posting_list_test.cpp |  83 +++----
 9 files changed, 510 insertions(+), 508 deletions(-)

diff --git a/include/collection.h b/include/collection.h
index 4d499e84..4d3d5c74 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -371,8 +371,6 @@ private:
         return std::tie(a_count, a_value_size) > std::tie(b_count, b_value_size);
     }

-    void free_leaf_indices(std::vector<uint32_t*>& leaf_to_indices) const;
-
     Option<bool> parse_filter_query(const std::string& simple_filter_query,
                                     std::vector<filter>& filters) const;

     static Option<bool> parse_geopoint_filter_value(std::string& raw_value,
diff --git a/include/index.h b/include/index.h
index 0bf56522..fc9367e5 100644
--- a/include/index.h
+++ b/include/index.h
@@ -191,6 +191,11 @@ private:

     StringUtils string_utils;

+    // used as sentinels
+
+    static spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value;
+    static spp::sparse_hash_map<uint32_t, int64_t> seq_id_sentinel_value;
+
     // Internal utility functions

     static inline uint32_t next_suggestion(const std::vector<token_candidates> &token_candidates_vec,
@@ -321,8 +326,10 @@ public:
     void score_results(const std::vector<sort_by> &sort_fields, const uint16_t &query_index, const uint8_t &field_id,
                       const uint32_t total_cost, Topster *topster, const std::vector<art_leaf *> &query_suggestion,
                       spp::sparse_hash_set<uint64_t> &groups_processed,
-                      const std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>>& array_token_positions_vec,
-                      const uint32_t* result_ids, size_t result_ids_size,
+                      const std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_positions,
+                      const uint32_t seq_id, const int sort_order[3],
+                      std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
+                      const std::vector<size_t>& geopoint_indices,
                       const size_t group_limit, const std::vector<std::string> &group_by_fields,
                       uint32_t token_bits,
                       const std::vector<token_t> &query_tokens,
diff --git a/include/posting.h b/include/posting.h
index 767fab1f..cd9d48aa 100644
--- a/include/posting.h
+++ b/include/posting.h
@@ -48,25 +48,21 @@ private:
 public:

     struct block_intersector_t {
-        size_t batch_size;
         std::vector<posting_list_t::iterator_t> its;
         std::vector<posting_list_t*> plists;
         std::vector<uint32_t> expanded_plist_indices;
-
         posting_list_t::result_iter_state_t& iter_state;

         block_intersector_t(const std::vector<void*>& raw_posting_lists,
-                            size_t batch_size,
                             posting_list_t::result_iter_state_t& iter_state):
-                batch_size(batch_size), iter_state(iter_state) {
+                iter_state(iter_state) {
+
             to_expanded_plists(raw_posting_lists, plists, expanded_plist_indices);

             its.reserve(plists.size());

             for(const auto& posting_list: plists) {
                 its.push_back(posting_list->new_iterator());
             }
-
-            iter_state.num_lists = plists.size();
         }

         ~block_intersector_t() {
@@ -75,9 +71,8 @@ public:
             }
         }

-        bool intersect() {
-            return posting_list_t::block_intersect(plists, batch_size, its, iter_state);
-        }
+        template<class T>
+        bool intersect(T func);
     };

     static void upsert(void*& obj, uint32_t id, const std::vector<uint32_t>& offsets);
@@ -101,6 +96,12 @@ public:
     static void get_array_token_positions(
         uint32_t id,
         const std::vector<void*>& posting_lists,
-        std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>>& array_token_positions_vec
+        std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_positions
     );
-};
\ No newline at end of file
+};
+
+template<class T>
+bool posting_t::block_intersector_t::intersect(T func) {
+    posting_list_t::block_intersect(plists, its, iter_state, func);
+    return true;
+}
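The header change above replaces the old batch-based intersect() with a template member that hands each intersected id to a caller-supplied functor, one at a time. A minimal usage sketch, mirroring the call sites added later in this patch (the contents of raw_posting_lists are assumed to be compact and/or full posting lists):

    std::vector<void*> raw_posting_lists = { /* compact and/or full posting lists */ };
    posting_list_t::result_iter_state_t iter_state;
    posting_t::block_intersector_t intersector(raw_posting_lists, iter_state);

    intersector.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its) {
        // called once per id present in every list that survives the
        // exclusion/filter checks; `its` are parked on the blocks containing
        // `seq_id`, so offsets can be read in place instead of being copied out.
    });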
diff --git a/include/posting_list.h b/include/posting_list.h
index 1ca4ba9d..39c20354 100644
--- a/include/posting_list.h
+++ b/include/posting_list.h
@@ -47,27 +47,41 @@ public:
         block_t* curr_block;
         uint32_t curr_index;

-        // uncompressed data structures for performance
-        block_t* uncompressed_block;
-        uint32_t* ids;
-
     public:
+        // uncompressed data structures for performance
+        uint32_t* ids = nullptr;
+        uint32_t* offset_index = nullptr;
+        uint32_t* offsets = nullptr;
+
         explicit iterator_t(block_t* root);
         iterator_t(iterator_t&& rhs) noexcept;
         ~iterator_t();
-        [[nodiscard]] inline bool valid() const;
-        void inline next();
+        [[nodiscard]] bool valid() const;
+        void next();
         void skip_to(uint32_t id);
-        [[nodiscard]] inline uint32_t id();
+        [[nodiscard]] uint32_t id() const;
         [[nodiscard]] inline uint32_t index() const;
         [[nodiscard]] inline block_t* block() const;
     };

     struct result_iter_state_t {
-        size_t num_lists;
-        std::vector<block_t*> blocks;
-        std::vector<uint32_t> indices;
-        std::vector<uint32_t> ids;
+        uint32_t* excluded_result_ids = nullptr;
+        size_t excluded_result_ids_size = 0;
+        uint32_t* filter_ids = nullptr;
+        size_t filter_ids_length = 0;
+
+        size_t excluded_result_ids_index = 0;
+        size_t filter_ids_index = 0;
+
+        std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>> array_token_positions_vec;
+
+        result_iter_state_t() = default;
+
+        result_iter_state_t(uint32_t* excluded_result_ids, size_t excluded_result_ids_size, uint32_t* filter_ids,
+                            size_t filter_ids_length) : excluded_result_ids(excluded_result_ids),
+                                                        excluded_result_ids_size(excluded_result_ids_size),
+                                                        filter_ids(filter_ids), filter_ids_length(filter_ids_length) {}
     };

private:
@@ -134,15 +148,83 @@ public:
     static void intersect(const std::vector<posting_list_t*>& posting_lists, std::vector<uint32_t>& result_ids);

+    template<class T>
     static bool block_intersect(
         const std::vector<posting_list_t*>& posting_lists,
-        size_t batch_size,
         std::vector<posting_list_t::iterator_t>& its,
-        result_iter_state_t& iter_state
+        result_iter_state_t& istate,
+        T func
     );

+    static bool take_id(result_iter_state_t& istate, uint32_t id);
+
     static bool get_offsets(
-        result_iter_state_t& iter_state,
-        std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>>& array_token_positions
+        std::vector<iterator_t>& its,
+        std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_pos
     );
 };
+
+template<class T>
+bool posting_list_t::block_intersect(const std::vector<posting_list_t*>& posting_lists,
+                                     std::vector<posting_list_t::iterator_t>& its,
+                                     result_iter_state_t& istate,
+                                     T func) {
+
+    if(posting_lists.empty()) {
+        return false;
+    }
+
+    if(its.empty()) {
+        its.reserve(posting_lists.size());
+
+        for(const auto& posting_list: posting_lists) {
+            its.push_back(posting_list->new_iterator());
+        }
+
+    } else {
+        // already in the middle of iteration: prepare for next batch
+
+    }
+
+    size_t num_lists = its.size();
+
+    switch (num_lists) {
+        case 1:
+            while(its[0].valid()) {
+                if(posting_list_t::take_id(istate, its[0].id())) {
+                    func(its[0].id(), its);
+                }
+
+                its[0].next();
+            }
+            break;
+        case 2:
+            while(!at_end2(its)) {
+                if(equals2(its)) {
+                    if(posting_list_t::take_id(istate, its[0].id())) {
+                        func(its[0].id(), its);
+                    }
+
+                    advance_all2(its);
+                } else {
+                    advance_non_largest2(its);
+                }
+            }
+            break;
+        default:
+            while(!at_end(its)) {
+                if(equals(its)) {
+                    //LOG(INFO) << its[0].id();
+                    if(posting_list_t::take_id(istate, its[0].id())) {
+                        func(its[0].id(), its);
+                    }
+
+                    advance_all(its);
+                } else {
+                    advance_non_largest(its);
+                }
+            }
+    }
+
+    return false;
+}
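take_id() (implemented later in this patch) folds the old post-hoc exclusion/filter pass into the intersection itself: both the excluded-id and filter-id arrays are assumed to be sorted ascending, and because intersected ids also arrive in ascending order, a single monotonically advancing cursor per array suffices. A self-contained sketch of the same merge logic (the struct and function names here are illustrative, not part of the patch):

    #include <cstdint>
    #include <cstddef>

    struct id_cursors_t {
        const uint32_t* excluded; size_t excluded_len; size_t excluded_idx = 0;
        const uint32_t* filter;   size_t filter_len;   size_t filter_idx = 0;
    };

    // returns true if `id` should be kept; amortized O(1) per id
    bool take_id_sketch(id_cursors_t& st, uint32_t id) {
        // skip excluded ids smaller than `id`, then reject on an exact hit
        while(st.excluded_idx < st.excluded_len && st.excluded[st.excluded_idx] < id) {
            st.excluded_idx++;
        }
        if(st.excluded_idx < st.excluded_len && st.excluded[st.excluded_idx] == id) {
            st.excluded_idx++;
            return false;
        }

        if(st.filter_len == 0) {
            return true;  // no filter: every non-excluded id passes
        }

        // keep only ids that also occur in the filter array, e.g. [1, 3] vs [2, 3]
        while(st.filter_idx < st.filter_len && st.filter[st.filter_idx] < id) {
            st.filter_idx++;
        }
        return st.filter_idx < st.filter_len && st.filter[st.filter_idx] == id;
    }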
diff --git a/src/collection.cpp b/src/collection.cpp
index 9f206af2..683ad911 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -1150,9 +1150,14 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
         bool highlighted_fully = (fields_highlighted_fully.find(field_name) != fields_highlighted_fully.end());
         highlight_t highlight;
+        //LOG(INFO) << "Highlighting: " << document;
+        /*if(document["title"] == "Quantum Quest: A Cassini Space Odyssey") {
+            LOG(INFO) << "here!";
+        }*/
         highlight_result(search_field, searched_queries, q_tokens, field_order_kv, document,
                          string_utils, snippet_threshold, highlight_affix_num_tokens,
                          highlighted_fully, highlight_start_tag, highlight_end_tag, highlight);
+        //LOG(INFO) << "End";

         if(!highlight.snippets.empty()) {
             highlights.push_back(highlight);
@@ -1547,7 +1552,7 @@ void Collection::highlight_result(const field &search_field,
             Index* index = indices[field_order_kv->key % num_memory_shards];
             art_leaf* actual_leaf = index->get_token_leaf(search_field.name, &token_leaf->key[0], token_leaf->key_len);
-            if(actual_leaf != nullptr) {
+            if(actual_leaf != nullptr && posting_t::contains(actual_leaf->values, field_order_kv->key)) {
                 query_suggestion.push_back(actual_leaf);
                 query_suggestion_tokens.insert(token);
                 //LOG(INFO) << "field: " << search_field.name << ", key: " << token;
@@ -1557,7 +1562,9 @@ void Collection::highlight_result(const field &search_field,
         qindex++;
     } while(field_order_kv->query_indices != nullptr && qindex < field_order_kv->query_indices[0]);

-    for(const std::string& q_token: q_tokens) {
+    for(size_t i = 0; i < q_tokens.size(); i++) {
+        const std::string& q_token = q_tokens[i];
+
         if(query_suggestion_tokens.count(q_token) != 0) {
             continue;
         }
@@ -1566,32 +1573,29 @@ void Collection::highlight_result(const field &search_field,
         art_leaf *actual_leaf = index->get_token_leaf(search_field.name,
                                                       reinterpret_cast<const unsigned char*>(q_token.c_str()),
                                                       q_token.size() + 1);
-        if(actual_leaf != nullptr) {
+
+        if(actual_leaf != nullptr && posting_t::contains(actual_leaf->values, field_order_kv->key)) {
             query_suggestion.push_back(actual_leaf);
             query_suggestion_tokens.insert(q_token);
+        } else if(i == q_tokens.size()-1) {
+            // we will copy the last token for highlighting prefix matches
+            query_suggestion_tokens.insert(q_token);
         }
     }

-    if(query_suggestion.empty()) {
+    if(query_suggestion_tokens.empty()) {
         // none of the tokens from the query were found on this field
         return ;
     }

-    //LOG(INFO) << "Document ID: " << document["id"];
-
     std::vector<void*> posting_lists;
     for(art_leaf* leaf: query_suggestion) {
         posting_lists.push_back(leaf->values);
     }

-    std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>> array_token_positions_vec;
-    posting_t::get_array_token_positions(field_order_kv->key, posting_lists, array_token_positions_vec);
+    std::unordered_map<size_t, std::vector<token_positions_t>> array_token_positions;
+    posting_t::get_array_token_positions(field_order_kv->key, posting_lists, array_token_positions);

-    if(array_token_positions_vec.empty()) {
-        return;
-    }
-
-    std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_positions = array_token_positions_vec[0];
     std::vector<match_index_t> match_indices;

     for(const auto& kv: array_token_positions) {
@@ -1639,6 +1643,14 @@ void Collection::highlight_result(const field &search_field,
         }
     }

+    if(!document.contains(search_field.name)) {
+        // could be an optional field
+        continue;
+    }
+
+    /*LOG(INFO) << "field: " << document[search_field.name] << ", id: " << field_order_kv->key
+              << ", index: " << match_index.index;*/
+
     std::string text = (search_field.type == field_types::STRING) ? document[search_field.name] :
                        document[search_field.name][match_index.index];

     Tokenizer tokenizer(text, true, false, search_field.locale);
@@ -1793,12 +1805,6 @@ void Collection::highlight_result(const field &search_field,
     highlight.match_score = match_indices[0].match_score;
 }

-void Collection::free_leaf_indices(std::vector<uint32_t*>& leaf_to_indices) const {
-    for(uint32_t* leaf_indices: leaf_to_indices) {
-        delete [] leaf_indices;
-    }
-}
-
 Option<nlohmann::json> Collection::get(const std::string & id) const {
     std::string seq_id_str;
     StoreStatus seq_id_status = store->get(get_doc_id_key(id), seq_id_str);
diff --git a/src/index.cpp b/src/index.cpp
index e426c62f..3ff5aad3 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -20,6 +20,9 @@
 #include
 #include "logger.h"

+spp::sparse_hash_map<uint32_t, int64_t> Index::text_match_sentinel_value;
+spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
+
 Index::Index(const std::string name, const std::unordered_map<std::string, field> & search_schema,
              std::map<std::string, field> facet_schema, std::unordered_map<std::string, field> sort_schema):
         name(name), search_schema(search_schema), facet_schema(facet_schema), sort_schema(sort_schema) {
@@ -615,6 +618,8 @@ void Index::index_string_array_field(const std::vector<std::string> & strings, c
             last_token = token;
         }

+        //LOG(INFO) << "Str: " << str << ", last_token: " << last_token;
+
         if(token_set.empty()) {
             continue;
         }
@@ -896,6 +901,29 @@ void Index::search_candidates(const uint8_t & field_id,
     auto product = []( long long a, token_candidates & b ) { return a*b.candidates.size(); };
     long long int N = std::accumulate(token_candidates_vec.begin(), token_candidates_vec.end(), 1LL, product);

+    int sort_order[3];  // 1 or -1 based on DESC or ASC respectively
+    std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values;
+    std::vector<size_t> geopoint_indices;
+
+    for (size_t i = 0; i < sort_fields.size(); i++) {
+        sort_order[i] = 1;
+        if (sort_fields[i].order == sort_field_const::asc) {
+            sort_order[i] = -1;
+        }
+
+        if (sort_fields[i].name == sort_field_const::text_match) {
+            field_values[i] = &text_match_sentinel_value;
+        } else if (sort_fields[i].name == sort_field_const::seq_id) {
+            field_values[i] = &seq_id_sentinel_value;
+        } else if (sort_index.count(sort_fields[i].name) != 0) {
+            field_values[i] = sort_index.at(sort_fields[i].name);
+
+            if (sort_schema.at(sort_fields[i].name).is_geopoint()) {
+                geopoint_indices.push_back(i);
+            }
+        }
+    }
+
     for(long long n=0; n<N && n<combination_limit; ++n) {
         std::vector<art_leaf*> query_suggestion(token_candidates_vec.size());
@@ -917,12 +945,6 @@ void Index::search_candidates(const uint8_t & field_id,
             }
             LOG(INFO) << fullq.str();*/

-        // initialize results with the starting element (for further intersection)
-        size_t result_size = posting_t::num_ids(query_suggestion[0]->values);
-        if(result_size == 0) {
-            continue;
-        }
-
         // Prepare excluded document IDs that we can later remove from the result set
         uint32_t* excluded_result_ids = nullptr;
         size_t excluded_result_ids_size = ArrayUtils::or_scalar(exclude_token_ids, exclude_token_ids_size,
@@ -934,95 +956,42 @@ void Index::search_candidates(const uint8_t & field_id,
             posting_lists.push_back(query_leaf->values);
         }

-        std::vector<posting_list_t::iterator_t> its;
-        posting_list_t::result_iter_state_t iter_state;
+        posting_list_t::result_iter_state_t iter_state(
+            excluded_result_ids, excluded_result_ids_size, filter_ids, filter_ids_length
+        );
+
+        // We will have to be judicious about computing full match score: only when token does not match exact query
+        bool use_single_token_score = (query_suggestion.size() == 1) &&
+                                      (query_suggestion.size() == query_tokens.size()) &&
+                                      ((std::string((const char*)query_suggestion[0]->key, query_suggestion[0]->key_len-1) != query_tokens[0].value));
+
+        std::vector<uint32_t> result_id_vec;
+        posting_t::block_intersector_t intersector(posting_lists, iter_state);

-        size_t excluded_result_ids_index = 0;
-        size_t filter_ids_index = 0;
-
-        posting_t::block_intersector_t intersector(posting_lists, 1000, iter_state);
-        bool has_more = true;
-
-        while(has_more) {
-            has_more = intersector.intersect();
-            posting_list_t::result_iter_state_t updated_iter_state;
-            size_t id_block_index = 0;
-
-            for(size_t i = 0; i < iter_state.ids.size(); i++) {
-                uint32_t id = iter_state.ids[i];
-
-                // decide if this result id should be excluded
-                if(excluded_result_ids_size != 0) {
-                    while(excluded_result_ids_index < excluded_result_ids_size &&
-                          excluded_result_ids[excluded_result_ids_index] < id) {
-                        excluded_result_ids_index++;
-                    }
-
-                    if(excluded_result_ids_index < excluded_result_ids_size &&
-                       id == excluded_result_ids[excluded_result_ids_index]) {
-                        excluded_result_ids_index++;
-                        continue;
-                    }
-                }
-
-                bool id_found_in_filter = true;
-
-                // decide if this result be matched with filter results
-                if(filter_ids_length != 0) {
-                    id_found_in_filter = false;
-
-                    // e.g. [1, 3] vs [2, 3]
-
-                    while(filter_ids_index < filter_ids_length && filter_ids[filter_ids_index] < id) {
-                        filter_ids_index++;
-                    }
-
-                    if(filter_ids_index < filter_ids_length && filter_ids[filter_ids_index] == id) {
-                        filter_ids_index++;
-                        id_found_in_filter = true;
-                    }
-                }
-
-                if(id_found_in_filter) {
-                    result_id_vec.push_back(id);
-
-                    updated_iter_state.num_lists = iter_state.num_lists;
-                    updated_iter_state.ids.push_back(id);
-
-                    for(size_t k = 0; k < iter_state.num_lists; k++) {
-                        updated_iter_state.blocks.push_back(iter_state.blocks[id_block_index]);
-                        updated_iter_state.indices.push_back(iter_state.indices[id_block_index++]);
-                    }
-                }
-            }
-
-            // We will have to be judicious about computing full match score: only when token does not match exact query
-            bool use_single_token_score = (query_suggestion.size() == 1) &&
-                                          (query_suggestion.size() == query_tokens.size()) &&
-                                          ((std::string((const char*)query_suggestion[0]->key, query_suggestion[0]->key_len-1) != query_tokens[0].value));
-
-            std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>> array_token_positions_vec;
+        intersector.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its) {
+            std::unordered_map<size_t, std::vector<token_positions_t>> array_token_positions;

             if(!use_single_token_score) {
-                posting_list_t::get_offsets(updated_iter_state, array_token_positions_vec);
+                posting_list_t::get_offsets(its, array_token_positions);
             }

             score_results(sort_fields, (uint16_t) searched_queries.size(), field_id, total_cost, topster,
-                          query_suggestion, groups_processed, array_token_positions_vec,
-                          &updated_iter_state.ids[0], updated_iter_state.ids.size(),
+                          query_suggestion, groups_processed, array_token_positions,
+                          seq_id, sort_order, field_values, geopoint_indices,
                           group_limit, group_by_fields, token_bits,
                           query_tokens, use_single_token_score, prioritize_exact_match);
-        }
+
+            result_id_vec.push_back(seq_id);
+        });

         if(result_id_vec.empty()) {
             continue;
         }

         uint32_t* result_ids = &result_id_vec[0];
-        result_size = result_id_vec.size();
+        size_t result_size = result_id_vec.size();

-        field_num_results += result_id_vec.size();
+        field_num_results += result_size;

         uint32_t* new_all_result_ids = nullptr;
         all_result_ids_len = ArrayUtils::or_scalar(*all_result_ids, all_result_ids_len, result_ids,
@@ -1571,10 +1540,41 @@ void Index::search(const std::vector<query_tokens_t>& field_query_tokens,
                 filter_ids = excluded_result_ids;
             }

+
+            // FIXME: duplicated
+            int sort_order[3];  // 1 or -1 based on DESC or ASC respectively
+            std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values;
+            std::vector<size_t> geopoint_indices;
+
+            for (size_t i = 0; i < sort_fields_std.size(); i++) {
+                sort_order[i] = 1;
+                if (sort_fields_std[i].order == sort_field_const::asc) {
+                    sort_order[i] = -1;
+                }
+
+                if (sort_fields_std[i].name == sort_field_const::text_match) {
+                    field_values[i] = &text_match_sentinel_value;
+                } else if (sort_fields_std[i].name == sort_field_const::seq_id) {
+                    field_values[i] = &seq_id_sentinel_value;
+                } else if (sort_index.count(sort_fields_std[i].name) != 0) {
+                    field_values[i] = sort_index.at(sort_fields_std[i].name);
+
+                    if (sort_schema.at(sort_fields_std[i].name).is_geopoint()) {
+                        geopoint_indices.push_back(i);
+                    }
+                }
+            }
+
             uint32_t token_bits = 255;
-            score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {},
-                          groups_processed, {}, filter_ids, filter_ids_length, group_limit, group_by_fields, token_bits, {},
-                          true, prioritize_exact_match);
+
+            for(size_t i = 0; i < filter_ids_length; i++) {
+                const uint32_t seq_id = filter_ids[i];
+                score_results(sort_fields_std, (uint16_t) searched_queries.size(), field_id, 0, topster, {},
+                              groups_processed, {}, seq_id, sort_order, field_values, geopoint_indices,
+                              group_limit, group_by_fields, token_bits, {},
+                              true, prioritize_exact_match);
+            }
+
             collate_included_ids(field_query_tokens[0].q_include_tokens, field, field_id, included_ids_map, curated_topster, searched_queries);
             all_result_ids_len = filter_ids_length;
@@ -2063,191 +2063,163 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
                           const uint8_t & field_id, const uint32_t total_cost, Topster* topster,
                           const std::vector<art_leaf *> &query_suggestion,
                           spp::sparse_hash_set<uint64_t>& groups_processed,
-                          const std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>>& array_token_positions_vec,
-                          const uint32_t* result_ids, size_t result_ids_size,
+                          const std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_positions,
+                          const uint32_t seq_id, const int sort_order[3],
+                          std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
+                          const std::vector<size_t>& geopoint_indices,
                           const size_t group_limit, const std::vector<std::string>& group_by_fields,
                           uint32_t token_bits,
                           const std::vector<token_t>& query_tokens,
                           bool use_single_token_score,
                           bool prioritize_exact_match) const {

-    int sort_order[3];  // 1 or -1 based on DESC or ASC respectively
-    spp::sparse_hash_map<uint32_t, int64_t>* field_values[3];
-
+    spp::sparse_hash_map<uint32_t, int64_t>* TEXT_MATCH_SENTINEL = &text_match_sentinel_value;
+    spp::sparse_hash_map<uint32_t, int64_t>* SEQ_ID_SENTINEL = &seq_id_sentinel_value;
     spp::sparse_hash_map<uint32_t, int64_t> geopoint_distances[3];

-    spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value, seq_id_sentinel_value;
-    spp::sparse_hash_map<uint32_t, int64_t> *TEXT_MATCH_SENTINEL = &text_match_sentinel_value;
-    spp::sparse_hash_map<uint32_t, int64_t> *SEQ_ID_SENTINEL = &seq_id_sentinel_value;
+    for(auto& i: geopoint_indices) {
+        spp::sparse_hash_map<uint32_t, int64_t>* geopoints = field_values[i];

-    for (size_t i = 0; i < sort_fields.size(); i++) {
-        sort_order[i] = 1;
-        if (sort_fields[i].order == sort_field_const::asc) {
-            sort_order[i] = -1;
+        S2LatLng reference_lat_lng;
+        GeoPoint::unpack_lat_lng(sort_fields[i].geopoint, reference_lat_lng);
+
+        auto it = geopoints->find(seq_id);
+        int64_t dist = INT32_MAX;
+
+        if(it != geopoints->end()) {
+            int64_t packed_latlng = it->second;
+            S2LatLng s2_lat_lng;
+            GeoPoint::unpack_lat_lng(packed_latlng, s2_lat_lng);
+            dist = GeoPoint::distance(s2_lat_lng, reference_lat_lng);
         }

-        if (sort_fields[i].name == sort_field_const::text_match) {
-            field_values[i] = TEXT_MATCH_SENTINEL;
-        } else if (sort_fields[i].name == sort_field_const::seq_id) {
-            field_values[i] = SEQ_ID_SENTINEL;
-        } else if (sort_schema.at(sort_fields[i].name).is_geopoint()) {
-            // we have to populate distances that will be used for match scoring
-            spp::sparse_hash_map<uint32_t, int64_t>* geopoints = sort_index.at(sort_fields[i].name);
-
-            S2LatLng reference_lat_lng;
-            GeoPoint::unpack_lat_lng(sort_fields[i].geopoint, reference_lat_lng);
-
-            for (size_t rindex = 0; rindex < result_ids_size; rindex++) {
-                const uint32_t seq_id = result_ids[rindex];
-                auto it = geopoints->find(seq_id);
-                int64_t dist = INT32_MAX;
-
-                if(it != geopoints->end()) {
-                    int64_t packed_latlng = it->second;
-                    S2LatLng s2_lat_lng;
-                    GeoPoint::unpack_lat_lng(packed_latlng, s2_lat_lng);
-                    dist = GeoPoint::distance(s2_lat_lng, reference_lat_lng);
-                }
-
-                if(dist < sort_fields[i].exclude_radius) {
-                    dist = 0;
-                }
-
-                if(sort_fields[i].geo_precision > 0) {
-                    dist = dist + sort_fields[i].geo_precision - 1 -
-                           (dist + sort_fields[i].geo_precision - 1) % sort_fields[i].geo_precision;
-                }
-
-                geopoint_distances[i].emplace(seq_id, dist);
-            }
-
-            field_values[i] = &geopoint_distances[i];
-        } else {
-            field_values[i] = sort_index.at(sort_fields[i].name);
+        if(dist < sort_fields[i].exclude_radius) {
+            dist = 0;
         }
+
+        if(sort_fields[i].geo_precision > 0) {
+            dist = dist + sort_fields[i].geo_precision - 1 -
+                   (dist + sort_fields[i].geo_precision - 1) % sort_fields[i].geo_precision;
+        }
+
+        geopoint_distances[i].emplace(seq_id, dist);
+
+        // Swap (id -> latlong) index to (id -> distance) index
+        field_values[i] = &geopoint_distances[i];
     }

-    Match single_token_match = Match(1, 0);
-    const uint64_t single_token_match_score = single_token_match.get_match_score(total_cost);
-
     //auto begin = std::chrono::high_resolution_clock::now();
     //const std::string first_token((const char*)query_suggestion[0]->key, query_suggestion[0]->key_len-1);

-    for (size_t i = 0; i < result_ids_size; i++) {
-        const uint32_t seq_id = result_ids[i];
+    uint64_t match_score = 0;

-        uint64_t match_score = 0;
+    if (use_single_token_score || array_token_positions.empty()) {
+        Match single_token_match = Match(1, 0);
+        const uint64_t single_token_match_score = single_token_match.get_match_score(total_cost);
+        match_score = single_token_match_score;
+    } else {
+        uint64_t total_tokens_found = 0, total_num_typos = 0, total_distance = 0, total_verbatim = 0;

-        if (use_single_token_score || array_token_positions_vec.empty()) {
-            match_score = single_token_match_score;
-        } else {
-            const std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_positions =
-                    array_token_positions_vec[i];
-
-            uint64_t total_tokens_found = 0, total_num_typos = 0, total_distance = 0, total_verbatim = 0;
-
-            for (const auto& kv: array_token_positions) {
-                const std::vector<token_positions_t>& token_positions = kv.second;
-                if (token_positions.empty()) {
-                    continue;
-                }
-                const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match);
-                uint64_t this_match_score = match.get_match_score(total_cost);
-
-                total_tokens_found += ((this_match_score >> 24) & 0xFF);
-                total_num_typos += 255 - ((this_match_score >> 16) & 0xFF);
-                total_distance += 100 - ((this_match_score >> 8) & 0xFF);
-                total_verbatim += (this_match_score & 0xFF);
-
-                /*std::ostringstream os;
-                os << name << ", total_cost: " << (255 - total_cost)
-                   << ", words_present: " << match.words_present
-                   << ", match_score: " << match_score
-                   << ", match.distance: " << match.distance
-                   << ", seq_id: " << seq_id << std::endl;
-                LOG(INFO) << os.str();*/
+        for (const auto& kv: array_token_positions) {
+            const std::vector<token_positions_t>& token_positions = kv.second;
+            if (token_positions.empty()) {
+                continue;
             }
+            const Match &match = Match(seq_id, token_positions, false, prioritize_exact_match);
+            uint64_t this_match_score = match.get_match_score(total_cost);

-            match_score = (
-                    (uint64_t(total_tokens_found) << 24) |
-                    (uint64_t(255 - total_num_typos) << 16) |
-                    (uint64_t(100 - total_distance) << 8) |
-                    (uint64_t(total_verbatim) << 1)
-            );
+            total_tokens_found += ((this_match_score >> 24) & 0xFF);
+            total_num_typos += 255 - ((this_match_score >> 16) & 0xFF);
+            total_distance += 100 - ((this_match_score >> 8) & 0xFF);
+            total_verbatim += (this_match_score & 0xFF);

-            /*LOG(INFO) << "Match score: " << match_score << ", for seq_id: " << seq_id
-                      << " - total_tokens_found: " << total_tokens_found
-                      << " - total_num_typos: " << total_num_typos
-                      << " - total_distance: " << total_distance
-                      << " - total_verbatim: " << total_verbatim
-                      << " - total_cost: " << total_cost;*/
+            /*std::ostringstream os;
+            os << name << ", total_cost: " << (255 - total_cost)
+               << ", words_present: " << match.words_present
+               << ", match_score: " << match_score
+               << ", match.distance: " << match.distance
+               << ", seq_id: " << seq_id << std::endl;
+            LOG(INFO) << os.str();*/
         }

-        const int64_t default_score = INT64_MIN;  // to handle field that doesn't exist in document (e.g. optional)
-        int64_t scores[3] = {0};
-        size_t match_score_index = 0;
-
-        // avoiding loop
-        if (sort_fields.size() > 0) {
-            if (field_values[0] == TEXT_MATCH_SENTINEL) {
-                scores[0] = int64_t(match_score);
-                match_score_index = 0;
-            } else if (field_values[0] == SEQ_ID_SENTINEL) {
-                scores[0] = seq_id;
-            } else {
-                auto it = field_values[0]->find(seq_id);
-                scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
-            }
-            if (sort_order[0] == -1) {
-                scores[0] = -scores[0];
-            }
-        }
-
-        if(sort_fields.size() > 1) {
-            if (field_values[1] == TEXT_MATCH_SENTINEL) {
-                scores[1] = int64_t(match_score);
-                match_score_index = 1;
-            } else if (field_values[1] == SEQ_ID_SENTINEL) {
-                scores[1] = seq_id;
-            } else {
-                auto it = field_values[1]->find(seq_id);
-                scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
-            }
-
-            if (sort_order[1] == -1) {
-                scores[1] = -scores[1];
-            }
-        }
-
-        if(sort_fields.size() > 2) {
-            if(field_values[2] == TEXT_MATCH_SENTINEL) {
-                scores[2] = int64_t(match_score);
-                match_score_index = 2;
-            } else if (field_values[2] == SEQ_ID_SENTINEL) {
-                scores[2] = seq_id;
-            } else {
-                auto it = field_values[2]->find(seq_id);
-                scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
-            }
-
-            if(sort_order[2] == -1) {
-                scores[2] = -scores[2];
-            }
-        }
-
-        uint64_t distinct_id = seq_id;
-
-        if(group_limit != 0) {
-            distinct_id = get_distinct_id(group_by_fields, seq_id);
-            groups_processed.emplace(distinct_id);
-        }
-
-        //LOG(INFO) << "Seq id: " << seq_id << ", match_score: " << match_score;
-        KV kv(field_id, query_index, token_bits, seq_id, distinct_id, match_score_index, scores);
-        topster->add(&kv);
+        match_score = (
+                (uint64_t(total_tokens_found) << 24) |
+                (uint64_t(255 - total_num_typos) << 16) |
+                (uint64_t(100 - total_distance) << 8) |
+                (uint64_t(total_verbatim) << 1)
+        );
+
+        /*LOG(INFO) << "Match score: " << match_score << ", for seq_id: " << seq_id
+                  << " - total_tokens_found: " << total_tokens_found
+                  << " - total_num_typos: " << total_num_typos
+                  << " - total_distance: " << total_distance
+                  << " - total_verbatim: " << total_verbatim
+                  << " - total_cost: " << total_cost;*/
     }

+    const int64_t default_score = INT64_MIN;  // to handle field that doesn't exist in document (e.g. optional)
+    int64_t scores[3] = {0};
+    size_t match_score_index = 0;
+
+    // avoiding loop
+    if (sort_fields.size() > 0) {
+        if (field_values[0] == TEXT_MATCH_SENTINEL) {
+            scores[0] = int64_t(match_score);
+            match_score_index = 0;
+        } else if (field_values[0] == SEQ_ID_SENTINEL) {
+            scores[0] = seq_id;
+        } else {
+            auto it = field_values[0]->find(seq_id);
+            scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
+        }
+        if (sort_order[0] == -1) {
+            scores[0] = -scores[0];
+        }
+    }
+
+    if(sort_fields.size() > 1) {
+        if (field_values[1] == TEXT_MATCH_SENTINEL) {
+            scores[1] = int64_t(match_score);
+            match_score_index = 1;
+        } else if (field_values[1] == SEQ_ID_SENTINEL) {
+            scores[1] = seq_id;
+        } else {
+            auto it = field_values[1]->find(seq_id);
+            scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
+        }
+
+        if (sort_order[1] == -1) {
+            scores[1] = -scores[1];
+        }
+    }
+
+    if(sort_fields.size() > 2) {
+        if(field_values[2] == TEXT_MATCH_SENTINEL) {
+            scores[2] = int64_t(match_score);
+            match_score_index = 2;
+        } else if (field_values[2] == SEQ_ID_SENTINEL) {
+            scores[2] = seq_id;
+        } else {
+            auto it = field_values[2]->find(seq_id);
+            scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
+        }
+
+        if(sort_order[2] == -1) {
+            scores[2] = -scores[2];
+        }
+    }
+
+    uint64_t distinct_id = seq_id;
+
+    if(group_limit != 0) {
+        distinct_id = get_distinct_id(group_by_fields, seq_id);
+        groups_processed.emplace(distinct_id);
+    }
+
+    //LOG(INFO) << "Seq id: " << seq_id << ", match_score: " << match_score;
+    KV kv(field_id, query_index, token_bits, seq_id, distinct_id, match_score_index, scores);
+    topster->add(&kv);
+
     //long long int timeNanos = std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::high_resolution_clock::now() - begin).count();
     //LOG(INFO) << "Time taken for results iteration: " << timeNanos << "ms";
 }
diff --git a/src/posting.cpp b/src/posting.cpp
index 2a3b4745..cc5c0e03 100644
--- a/src/posting.cpp
+++ b/src/posting.cpp
@@ -418,39 +418,20 @@ void posting_t::destroy_list(void*& obj) {
 }

 void posting_t::get_array_token_positions(uint32_t id, const std::vector<void*>& raw_posting_lists,
-                                          std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>>& array_token_positions_vec) {
+                                          std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_positions) {

     std::vector<posting_list_t*> plists;
     std::vector<uint32_t> expanded_plist_indices;
     to_expanded_plists(raw_posting_lists, plists, expanded_plist_indices);

-    posting_list_t::result_iter_state_t iter_state;
-    iter_state.ids.push_back(id);
-    iter_state.num_lists = plists.size();
-
-    std::vector<posting_list_t::block_t*>& block_vec = iter_state.blocks;
-    std::vector<uint32_t>& index_vec = iter_state.indices;
+    std::vector<posting_list_t::iterator_t> its;

     for(posting_list_t* pl: plists) {
-        posting_list_t::block_t* block = pl->block_of(id);
-        block_vec.push_back(block);
-
-        bool found_index = false;
-
-        if(block != nullptr) {
-            uint32_t index = block->ids.indexOf(id);
-            if(index != block->ids.getLength()) {
-                index_vec.push_back(index);
-                found_index = true;
-            }
-        }
-
-        if(!found_index) {
-            index_vec.push_back(UINT32_MAX);
-        }
+        its.push_back(pl->new_iterator());
+        its.back().skip_to(id);
     }

-    posting_list_t::get_offsets(iter_state, array_token_positions_vec);
+    posting_list_t::get_offsets(its, array_token_positions);

     for(uint32_t expanded_plist_index: expanded_plist_indices) {
         delete plists[expanded_plist_index];
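With iterators now carrying their own uncompressed block data, fetching a single document's token positions reduces to parking one iterator per list on the target id and reusing get_offsets(), exactly as posting_t::get_array_token_positions() above does. A condensed sketch of that call pattern:

    std::vector<posting_list_t::iterator_t> its;
    for(posting_list_t* pl: plists) {
        its.push_back(pl->new_iterator());
        its.back().skip_to(id);  // lands on the block containing `id`, if present
    }

    std::unordered_map<size_t, std::vector<token_positions_t>> array_token_positions;
    posting_list_t::get_offsets(its, array_token_positions);

get_offsets() guards against iterators that did not land anywhere useful (null block or UINT32_MAX index), so lists that lack the document contribute nothing.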
diff --git a/src/posting_list.cpp b/src/posting_list.cpp
index 3cbd0189..f4501f67 100644
--- a/src/posting_list.cpp
+++ b/src/posting_list.cpp
@@ -661,96 +661,44 @@ void posting_list_t::intersect(const std::vector<posting_list_t*>& posting_lists
     }
 }

-bool posting_list_t::block_intersect(const std::vector<posting_list_t*>& posting_lists, const size_t batch_size,
-                                     std::vector<posting_list_t::iterator_t>& its,
-                                     result_iter_state_t& iter_state) {
-
-    if(posting_lists.empty()) {
-        return false;
-    }
-
-    if(its.empty()) {
-        its.reserve(posting_lists.size());
-
-        for(const auto& posting_list: posting_lists) {
-            its.push_back(posting_list->new_iterator());
+bool posting_list_t::take_id(result_iter_state_t& istate, uint32_t id) {
+    // decide if this result id should be excluded
+    if(istate.excluded_result_ids_size != 0) {
+        while(istate.excluded_result_ids_index < istate.excluded_result_ids_size &&
+              istate.excluded_result_ids[istate.excluded_result_ids_index] < id) {
+            istate.excluded_result_ids_index++;
         }

-        iter_state.num_lists = posting_lists.size();
-        iter_state.blocks.reserve(100);
-        iter_state.indices.reserve(100);
-    } else {
-        // already in the middle of iteration: prepare for next batch
-        iter_state.ids.clear();
-        iter_state.indices.clear();
-        iter_state.blocks.clear();
+        if(istate.excluded_result_ids_index < istate.excluded_result_ids_size &&
+           id == istate.excluded_result_ids[istate.excluded_result_ids_index]) {
+            istate.excluded_result_ids_index++;
+            return false;
+        }
     }

-    size_t num_lists = its.size();
+    bool id_found_in_filter = true;

-    switch (num_lists) {
-        case 1:
-            while(its[0].valid()) {
-                iter_state.ids.push_back(its[0].id());
-                iter_state.blocks.push_back(its[0].block());
-                iter_state.indices.push_back(its[0].index());
+    // decide if this result should be matched with filter results
+    if(istate.filter_ids_length != 0) {
+        id_found_in_filter = false;

-                its[0].next();
+        // e.g. [1, 3] vs [2, 3]

-                if(iter_state.ids.size() == batch_size) {
-                    return its[0].valid();
-                }
-            }
-            break;
-        case 2:
-            while(!at_end2(its)) {
-                if(equals2(its)) {
-                    // still need to ensure that the ID exists in inclusion list but NOT in exclusion list
-                    iter_state.ids.push_back(its[0].id());
+        while(istate.filter_ids_index < istate.filter_ids_length && istate.filter_ids[istate.filter_ids_index] < id) {
+            istate.filter_ids_index++;
+        }

-                    iter_state.blocks.push_back(its[0].block());
-                    iter_state.blocks.push_back(its[1].block());
-
-                    iter_state.indices.push_back(its[0].index());
-                    iter_state.indices.push_back(its[1].index());
-
-                    advance_all2(its);
-                } else {
-                    advance_non_largest2(its);
-                }
-
-                if(iter_state.ids.size() == batch_size) {
-                    return !at_end2(its);
-                }
-            }
-            break;
-        default:
-            while(!at_end(its)) {
-                if(equals(its)) {
-                    //LOG(INFO) << its[0].id();
-                    iter_state.ids.push_back(its[0].id());
-
-                    for(size_t i = 0; i < its.size(); i++) {
-                        iter_state.blocks.push_back(its[i].block());
-                        iter_state.indices.push_back(its[i].index());
-                    }
-
-                    advance_all(its);
-                } else {
-                    advance_non_largest(its);
-                }
-
-                if(iter_state.ids.size() == batch_size) {
-                    return !at_end(its);
-                }
-            }
+        if(istate.filter_ids_index < istate.filter_ids_length && istate.filter_ids[istate.filter_ids_index] == id) {
+            istate.filter_ids_index++;
+            id_found_in_filter = true;
+        }
     }

-    return false;
+    return id_found_in_filter;
 }

-bool posting_list_t::get_offsets(posting_list_t::result_iter_state_t& iter_state,
-                                 std::vector<std::unordered_map<size_t, std::vector<token_positions_t>>>& array_token_positions_vec) {
+bool posting_list_t::get_offsets(std::vector<iterator_t>& its,
+                                 std::unordered_map<size_t, std::vector<token_positions_t>>& array_token_pos) {

     // Plain string format:
     // offset1, offset2, ... , 0 (if token is the last offset for the document)

@@ -763,78 +711,84 @@ bool posting_list_t::get_offsets(posting_list_t::result_iter_state_t& iter_state

     size_t id_block_index = 0;

-    for(size_t i = 0; i < iter_state.ids.size(); i++) {
-        uint32_t id = iter_state.ids[i];
-        array_token_positions_vec.emplace_back();
-        std::unordered_map<size_t, std::vector<token_positions_t>>& array_tok_pos = array_token_positions_vec.back();
-
-        for(size_t j = 0; j < iter_state.num_lists; j++) {
-            block_t* curr_block = iter_state.blocks[id_block_index];
-            uint32_t curr_index = iter_state.indices[id_block_index++];
-
-            if(curr_block == nullptr || curr_index == UINT32_MAX) {
-                continue;
-            }
-
-            uint32_t* offsets = curr_block->offsets.uncompress();
-
-            uint32_t start_offset = curr_block->offset_index.at(curr_index);
-            uint32_t end_offset = (curr_index == curr_block->size() - 1) ?
-                                  curr_block->offsets.getLength() :
-                                  curr_block->offset_index.at(curr_index + 1);
-
-            std::vector<uint16_t> positions;
-            int prev_pos = -1;
-            bool is_last_token = false;
-
-            while(start_offset < end_offset) {
-                int pos = offsets[start_offset];
-                start_offset++;
-
-                if(pos == 0) {
-                    // indicates that token is the last token on the doc
-                    is_last_token = true;
-                    start_offset++;
-                    continue;
-                }
-
-                if(pos == prev_pos) {  // indicates end of array index
-                    if(!positions.empty()) {
-                        size_t array_index = (size_t) offsets[start_offset];
-                        is_last_token = false;
-
-                        if(start_offset+1 < end_offset) {
-                            size_t next_offset = (size_t) offsets[start_offset + 1];
-                            if(next_offset == 0) {
-                                // indicates that token is the last token on the doc
-                                is_last_token = true;
-                                start_offset++;
-                            }
-                        }
-
-                        array_tok_pos[array_index].push_back(token_positions_t{is_last_token, positions});
-                        positions.clear();
-                    }
-
-                    start_offset++;  // skip current value which is the array index or flag for last index
-                    prev_pos = -1;
-                    continue;
-                }
-
-                prev_pos = pos;
-                positions.push_back((uint16_t)pos - 1);
-            }
-
-            if(!positions.empty()) {
-                // for plain string fields
-                array_tok_pos[0].push_back(token_positions_t{is_last_token, positions});
-            }
-
-            delete [] offsets;
-        }
+    for(size_t j = 0; j < its.size(); j++) {
+        block_t* curr_block = its[j].block();
+        uint32_t curr_index = its[j].index();
+
+        if(curr_block == nullptr || curr_index == UINT32_MAX) {
+            continue;
+        }
+
+        /*uint32_t* offsets = curr_block->offsets.uncompress();
+
+        uint32_t start_offset = curr_block->offset_index.at(curr_index);
+        uint32_t end_offset = (curr_index == curr_block->size() - 1) ?
+                              curr_block->offsets.getLength() :
+                              curr_block->offset_index.at(curr_index + 1);*/
+
+        uint32_t* offsets = its[j].offsets;
+
+        uint32_t start_offset = its[j].offset_index[curr_index];
+        uint32_t end_offset = (curr_index == curr_block->size() - 1) ?
+                              curr_block->offsets.getLength() :
+                              its[j].offset_index[curr_index + 1];
+
+        std::vector<uint16_t> positions;
+        int prev_pos = -1;
+        bool is_last_token = false;
+
+        /*LOG(INFO) << "id: " << its[j].id() << ", start_offset: " << start_offset << ", end_offset: " << end_offset;
+        for(size_t x = 0; x < end_offset; x++) {
+            LOG(INFO) << "x: " << x << ", pos: " << offsets[x];
+        }*/
+
+        while(start_offset < end_offset) {
+            int pos = offsets[start_offset];
+            start_offset++;
+
+            if(pos == 0) {
+                // indicates that token is the last token on the doc
+                is_last_token = true;
+                start_offset++;
+                continue;
+            }
+
+            if(pos == prev_pos) {  // indicates end of array index
+                if(!positions.empty()) {
+                    size_t array_index = (size_t) offsets[start_offset];
+                    is_last_token = false;
+
+                    if(start_offset+1 < end_offset) {
+                        size_t next_offset = (size_t) offsets[start_offset + 1];
+                        if(next_offset == 0) {
+                            // indicates that token is the last token on the doc
+                            is_last_token = true;
+                            start_offset++;
+                        }
+                    }
+
+                    array_token_pos[array_index].push_back(token_positions_t{is_last_token, positions});
+                    positions.clear();
+                }
+
+                start_offset++;  // skip current value which is the array index or flag for last index
+                prev_pos = -1;
+                continue;
+            }
+
+            prev_pos = pos;
+            positions.push_back((uint16_t)pos - 1);
+        }
+
+        if(!positions.empty()) {
+            // for plain string fields
+            array_token_pos[0].push_back(token_positions_t{is_last_token, positions});
+        }
+
+        //delete [] offsets;
     }

-    return false;
+    return true;
 }

 bool posting_list_t::at_end(const std::vector<posting_list_t::iterator_t>& its) {
@@ -982,8 +936,13 @@ bool posting_list_t::contains_atleast_one(const uint32_t* target_ids, size_t tar
 /* iterator_t operations */

 posting_list_t::iterator_t::iterator_t(posting_list_t::block_t* root):
-        curr_block(root), curr_index(0), uncompressed_block(nullptr), ids(nullptr) {
+        curr_block(root), curr_index(0) {

+    if(curr_block != nullptr) {
+        ids = curr_block->ids.uncompress();
+        offset_index = curr_block->offset_index.uncompress();
+        offsets = curr_block->offsets.uncompress();
+    }
 }

 bool posting_list_t::iterator_t::valid() const {
@@ -995,21 +954,22 @@ void posting_list_t::iterator_t::next() {
     if(curr_index == curr_block->size()) {
         curr_index = 0;
         curr_block = curr_block->next;
-    }
-}
-
-uint32_t posting_list_t::iterator_t::id() {
-    if(uncompressed_block != curr_block) {
-        uncompressed_block = curr_block;

         delete [] ids;
-        ids = nullptr;
+        delete [] offset_index;
+        delete [] offsets;
+
+        ids = offset_index = offsets = nullptr;

         if(curr_block != nullptr) {
             ids = curr_block->ids.uncompress();
+            offset_index = curr_block->offset_index.uncompress();
+            offsets = curr_block->offsets.uncompress();
         }
     }
+}

+uint32_t posting_list_t::iterator_t::id() const {
     return ids[curr_index];
 }

@@ -1025,6 +985,20 @@ void posting_list_t::iterator_t::skip_to(uint32_t id) {
     bool skipped_block = false;

     while(curr_block != nullptr && curr_block->ids.last() < id) {
         curr_block = curr_block->next;
+
+        // FIXME: remove duplication
+        delete [] ids;
+        delete [] offset_index;
+        delete [] offsets;
+
+        ids = offset_index = offsets = nullptr;
+
+        if(curr_block != nullptr) {
+            ids = curr_block->ids.uncompress();
+            offset_index = curr_block->offset_index.uncompress();
+            offsets = curr_block->offsets.uncompress();
+        }
+
         skipped_block = true;
     }

@@ -1045,10 +1019,12 @@ posting_list_t::iterator_t::~iterator_t() {
 posting_list_t::iterator_t::iterator_t(iterator_t&& rhs) noexcept {
     curr_block = rhs.curr_block;
     curr_index = rhs.curr_index;
-    uncompressed_block = rhs.uncompressed_block;
     ids = rhs.ids;
+    offset_index = rhs.offset_index;
+    offsets = rhs.offsets;

     rhs.curr_block = nullptr;
-    rhs.uncompressed_block = nullptr;
     rhs.ids = nullptr;
+    rhs.offset_index = nullptr;
+    rhs.offsets = nullptr;
 }
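The iterator now eagerly uncompresses ids, offset_index and offsets on entering a block and frees them when moving on, which is what lets get_offsets() read positions in place instead of uncompressing per call. The same release-then-uncompress sequence appears in the constructor, next() and skip_to() (the FIXME above flags the duplication); a helper along these lines could factor it out (a sketch, not code from this patch):

    // hypothetical helper: retarget the iterator's owned uncompressed arrays at a new block
    void load_block(posting_list_t::block_t* block,
                    uint32_t*& ids, uint32_t*& offset_index, uint32_t*& offsets) {
        delete [] ids;
        delete [] offset_index;
        delete [] offsets;
        ids = offset_index = offsets = nullptr;

        if(block != nullptr) {
            ids = block->ids.uncompress();                    // heap copies owned by the iterator
            offset_index = block->offset_index.uncompress();
            offsets = block->offsets.uncompress();
        }
    }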
diff --git a/test/posting_list_test.cpp b/test/posting_list_test.cpp
index 6e8d67e1..53eb2414 100644
--- a/test/posting_list_test.cpp
+++ b/test/posting_list_test.cpp
@@ -608,28 +608,13 @@ TEST(PostingListTest, IntersectionBasics) {
     std::vector<posting_list_t::iterator_t> its;
     posting_list_t::result_iter_state_t iter_state;
-    bool has_more = posting_list_t::block_intersect(lists, 2, its, iter_state);
-    ASSERT_FALSE(has_more);
-    ASSERT_EQ(2, iter_state.ids.size());
-    ASSERT_EQ(3, iter_state.ids[0]);
-    ASSERT_EQ(20, iter_state.ids[1]);
-
-    ASSERT_EQ(6, iter_state.blocks.size());
-    ASSERT_EQ(6, iter_state.indices.size());
-
-    // try with smaller batch size
-
-    std::vector<posting_list_t::iterator_t> its2;
-    posting_list_t::result_iter_state_t iter_state2;
-    has_more = posting_list_t::block_intersect(lists, 1, its2, iter_state2);
-    ASSERT_TRUE(has_more);
-    ASSERT_EQ(1, iter_state2.ids.size());
-    ASSERT_EQ(3, iter_state2.ids[0]);
-
-    has_more = posting_list_t::block_intersect(lists, 1, its2, iter_state2);
-    ASSERT_FALSE(has_more);
-    ASSERT_EQ(1, iter_state2.ids.size());
-    ASSERT_EQ(20, iter_state2.ids[0]);
+    result_ids.clear();
+
+    posting_list_t::block_intersect(lists, its, iter_state, [&](auto id, auto& its){
+        result_ids.push_back(id);
+    });
+    ASSERT_EQ(2, result_ids.size());
+    ASSERT_EQ(3, result_ids[0]);
+    ASSERT_EQ(20, result_ids[1]);

     // single item intersection
     std::vector<posting_list_t*> single_item_list = {&p1};
@@ -644,23 +629,15 @@ TEST(PostingListTest, IntersectionBasics) {
     std::vector<posting_list_t::iterator_t> its3;
     posting_list_t::result_iter_state_t iter_state3;
-    has_more = posting_list_t::block_intersect(single_item_list, 2, its3, iter_state3);
-    ASSERT_TRUE(has_more);
-    ASSERT_EQ(2, iter_state3.ids.size());
-    ASSERT_EQ(0, iter_state3.ids[0]);
-    ASSERT_EQ(2, iter_state3.ids[1]);
-    ASSERT_EQ(2, iter_state3.blocks.size());
-    ASSERT_EQ(2, iter_state3.indices.size());
-    ASSERT_EQ(0, iter_state3.indices[0]);
-    ASSERT_EQ(1, iter_state3.indices[1]);
-
-    has_more = posting_list_t::block_intersect(single_item_list, 2, its3, iter_state3);
-    ASSERT_FALSE(has_more);
-    ASSERT_EQ(2, iter_state3.ids.size());
-    ASSERT_EQ(2, iter_state3.blocks.size());
-    ASSERT_EQ(2, iter_state3.indices.size());
-    ASSERT_EQ(0, iter_state3.indices[0]);
-    ASSERT_EQ(1, iter_state3.indices[1]);
+    result_ids.clear();
+    posting_list_t::block_intersect(single_item_list, its3, iter_state3, [&](auto id, auto& its){
+        result_ids.push_back(id);
+    });
+    ASSERT_EQ(4, result_ids.size());
+    ASSERT_EQ(0, result_ids[0]);
+    ASSERT_EQ(2, result_ids[1]);
+    ASSERT_EQ(3, result_ids[2]);
+    ASSERT_EQ(20, result_ids[3]);

     // empty intersection list
     std::vector<posting_list_t*> empty_list;
@@ -670,11 +647,11 @@ TEST(PostingListTest, IntersectionBasics) {
     std::vector<posting_list_t::iterator_t> its4;
     posting_list_t::result_iter_state_t iter_state4;
-    has_more = posting_list_t::block_intersect(empty_list, 1, its4, iter_state4);
-    ASSERT_FALSE(has_more);
-    ASSERT_EQ(0, iter_state4.ids.size());
-    ASSERT_EQ(0, iter_state4.blocks.size());
-    ASSERT_EQ(0, iter_state4.indices.size());
+    result_ids.clear();
+    posting_list_t::block_intersect(empty_list, its4, iter_state4, [&](auto id, auto& its){
+        result_ids.push_back(id);
+    });
+    ASSERT_EQ(0, result_ids.size());
 }

 TEST(PostingListTest, ResultsAndOffsetsBasics) {
@@ -729,6 +706,7 @@ TEST(PostingListTest, ResultsAndOffsetsBasics) {
     lists.push_back(&p2);
     lists.push_back(&p3);

+    /*
     std::vector<posting_list_t::iterator_t> its;
     posting_list_t::result_iter_state_t iter_state;
     bool has_more = posting_list_t::block_intersect(lists, 2, its, iter_state);
@@ -745,6 +723,7 @@ TEST(PostingListTest, ResultsAndOffsetsBasics) {
     ASSERT_EQ(actual_offsets_20[0].positions, array_token_positions_vec[1].at(0)[0].positions);
     ASSERT_EQ(actual_offsets_20[1].positions, array_token_positions_vec[1].at(0)[1].positions);
     ASSERT_EQ(actual_offsets_20[2].positions, array_token_positions_vec[1].at(0)[2].positions);
+    */
 }

 TEST(PostingListTest, IntersectionSkipBlocks) {
@@ -1269,15 +1248,15 @@ TEST(PostingListTest, BlockIntersectionOnMixedLists) {
     std::vector<void*> raw_posting_lists = {SET_COMPACT_POSTING(list1), &p1};

     posting_list_t::result_iter_state_t iter_state;
-    posting_t::block_intersector_t intersector(raw_posting_lists, 1, iter_state);
+    posting_t::block_intersector_t intersector(raw_posting_lists, iter_state);

-    ASSERT_TRUE(intersector.intersect());
-    ASSERT_EQ(1, iter_state.ids.size());
-    ASSERT_EQ(5, iter_state.ids[0]);
-
-    ASSERT_FALSE(intersector.intersect());
-    ASSERT_EQ(1, iter_state.ids.size());
-    ASSERT_EQ(8, iter_state.ids[0]);
+    std::vector<uint32_t> result_ids;
+
+    intersector.intersect([&](auto seq_id, auto& its) {
+        result_ids.push_back(seq_id);
+    });
+
+    ASSERT_EQ(2, result_ids.size());
+    ASSERT_EQ(5, result_ids[0]);
+    ASSERT_EQ(8, result_ids[1]);

     free(list1);
 }
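Taken together, the hot path in search_candidates() now streams intersect → take_id → get_offsets → score_results one document at a time, with no intermediate ids/blocks/indices buffers. Condensed from the call site earlier in this patch (scoring arguments and surrounding loop elided):

    posting_list_t::result_iter_state_t iter_state(
        excluded_result_ids, excluded_result_ids_size, filter_ids, filter_ids_length
    );
    posting_t::block_intersector_t intersector(posting_lists, iter_state);

    intersector.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its) {
        std::unordered_map<size_t, std::vector<token_positions_t>> array_token_positions;
        posting_list_t::get_offsets(its, array_token_positions);
        // score `seq_id` immediately; nothing is batched or copied
        result_id_vec.push_back(seq_id);
    });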