diff --git a/include/art.h b/include/art.h index 11f57a68..92e043e3 100644 --- a/include/art.h +++ b/include/art.h @@ -279,7 +279,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, const size_t max_words, const token_ordering token_order, const bool prefix, bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::vector &results, std::set& exclude_leaves); void encode_int32(int32_t n, unsigned char *chars); diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index bc8c4c23..b3b12555 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -109,6 +109,8 @@ private: std::vector> posting_list_iterators; std::vector expanded_plists; + bool delete_filter_node = false; + /// Initializes the state of iterator node after it's creation. void init(); @@ -127,6 +129,8 @@ private: /// Finds the next match for a filter on string field. void get_string_filter_next_match(const bool& field_is_array); + explicit filter_result_iterator_t(uint32_t approx_filter_ids_length); + public: uint32_t seq_id = 0; /// Collection name -> references @@ -143,6 +147,8 @@ public: /// iterator reaching it's end. (is_valid would be false in both these cases) uint32_t approx_filter_ids_length; + explicit filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count); + explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node, uint32_t approx_filter_ids_length = UINT32_MAX); @@ -193,4 +199,7 @@ public: /// Performs AND with the contents of A and allocates a new array of results. /// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); + + static void add_phrase_ids(filter_result_iterator_t*& filter_result_iterator, + uint32_t* phrase_result_ids, const uint32_t& phrase_result_count); }; diff --git a/include/index.h b/include/index.h index e80d5696..5d825555 100644 --- a/include/index.h +++ b/include/index.h @@ -408,7 +408,7 @@ private: void search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -725,7 +725,7 @@ public: const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, @@ -766,7 +766,7 @@ public: std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -797,7 +797,7 @@ public: spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -814,6 +814,7 @@ public: std::array*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, + filter_result_iterator_t*& filter_result_iterator, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed, const std::set& curated_ids, @@ -821,8 +822,7 @@ public: const std::unordered_set& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, - bool is_wildcard_query, - uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const; + bool is_wildcard_query) const; void fuzzy_search_fields(const std::vector& the_fields, const std::vector& query_tokens, @@ -830,7 +830,7 @@ public: const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -859,7 +859,7 @@ public: const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, @@ -881,7 +881,7 @@ public: const std::vector& group_by_fields, bool prioritize_exact_match, const bool search_all_candidates, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, @@ -933,7 +933,7 @@ public: void process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, const bool filter_curated_hits, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, diff --git a/src/art.cpp b/src/art.cpp index ee0fb53a..65d3ca63 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1003,7 +1003,7 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str bool validate_and_add_leaf(art_leaf* leaf, const std::string& prev_token, const art_leaf* prev_leaf, const art_leaf* exact_leaf, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& exclude_leaves, std::vector& results) { if(leaf == exact_leaf) { @@ -1016,10 +1016,10 @@ bool validate_and_add_leaf(art_leaf* leaf, } if(prev_token.empty() || !prev_leaf) { - if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) { + if (filter_result_iterator->is_valid && !filter_result_iterator->contains_atleast_one(leaf->values)) { return false; } - } else if (!filter_result_iterator.is_valid) { + } else if (!filter_result_iterator->is_valid) { std::vector prev_leaf_ids; posting_t::merge({prev_leaf->values}, prev_leaf_ids); @@ -1031,8 +1031,8 @@ bool validate_and_add_leaf(art_leaf* leaf, posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); bool found = false; - for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator.is_valid && !found; i++) { - found = (filter_result_iterator.valid(leaf_ids[i]) == 1); + for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator->is_valid && !found; i++) { + found = (filter_result_iterator->valid(leaf_ids[i]) == 1); } if (!found) { @@ -1145,7 +1145,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, const art_leaf* exact_leaf, const bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const art_tree* t, std::set& exclude_leaves, std::vector& results) { printf("INSIDE art_topk_iter: root->type: %d\n", root->type); @@ -1177,7 +1177,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator, exclude_leaves, results); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if (++num_processed % 1024 == 0 && (microseconds( std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { @@ -1767,7 +1767,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, const size_t max_words, const token_ordering token_order, const bool prefix, bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::vector &results, std::set& exclude_leaves) { std::vector nodes; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f155d955..c4dcd0b5 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1252,6 +1252,10 @@ filter_result_iterator_t::~filter_result_iterator_t() { delete expanded_plist; } + if (delete_filter_node) { + delete filter_node; + } + delete left_it; delete right_it; } @@ -1343,3 +1347,44 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, next(); } } + +filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_length) : + approx_filter_ids_length(approx_filter_ids_length) { + filter_node = new filter_node_t(AND, nullptr, nullptr); + delete_filter_node = true; +} + +filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count) { + filter_result.count = approx_filter_ids_length = ids_count; + filter_result.docs = ids; + is_valid = ids_count > 0; + + if (is_valid) { + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + filter_node = new filter_node_t({"dummy", {}, {}}); + delete_filter_node = true; + } +} + +void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_result_iterator, + uint32_t* phrase_result_ids, const uint32_t& phrase_result_count) { + auto root_iterator = new filter_result_iterator_t(std::min(phrase_result_count, filter_result_iterator->approx_filter_ids_length)); + root_iterator->left_it = new filter_result_iterator_t(phrase_result_ids, phrase_result_count); + root_iterator->right_it = filter_result_iterator; + + auto& left_it = root_iterator->left_it; + auto& right_it = root_iterator->right_it; + + while (left_it->is_valid && right_it->is_valid && left_it->seq_id != right_it->seq_id) { + if (left_it->seq_id < right_it->seq_id) { + left_it->skip_to(right_it->seq_id); + } else { + right_it->skip_to(left_it->seq_id); + } + } + + root_iterator->is_valid = left_it->is_valid && right_it->is_valid; + root_iterator->seq_id = left_it->seq_id; + filter_result_iterator = root_iterator; +} diff --git a/src/index.cpp b/src/index.cpp index 552b3149..9090230d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1263,7 +1263,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { void Index::search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -2247,14 +2247,16 @@ Option Index::search(std::vector& field_query_tokens, cons return rearrange_op; } - auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root, + auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root, approx_filter_ids_length); - auto filter_init_op = filter_result_iterator.init_status(); + std::unique_ptr filter_iterator_guard(filter_result_iterator); + + auto filter_init_op = filter_result_iterator->init_status(); if (!filter_init_op.ok()) { return filter_init_op; } - if (filter_tree_root != nullptr && !filter_result_iterator.is_valid) { + if (filter_tree_root != nullptr && !filter_result_iterator->is_valid) { return Option(true); } @@ -2268,7 +2270,7 @@ Option Index::search(std::vector& field_query_tokens, cons process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits, filter_result_iterator, curated_ids, included_ids_map, included_ids_vec, excluded_group_ids); - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2299,24 +2301,19 @@ Option Index::search(std::vector& field_query_tokens, cons field_query_tokens[0].q_include_tokens[0].value == "*"; - // TODO: Do AND with phrase ids at last // handle phrase searches - uint32_t* phrase_result_ids = nullptr; - uint32_t phrase_result_count = 0; - std::unique_ptr phrase_result_ids_guard; - if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, topster, sort_order, field_values, geopoint_indices, curated_ids_sorted, - all_result_ids, all_result_ids_len, groups_processed, curated_ids, + filter_result_iterator, all_result_ids, all_result_ids_len, groups_processed, curated_ids, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, - included_ids_map, is_wildcard_query, - phrase_result_ids, phrase_result_count); + included_ids_map, is_wildcard_query); - phrase_result_ids_guard.reset(phrase_result_ids); + filter_iterator_guard.release(); + filter_iterator_guard.reset(filter_result_iterator); - if (phrase_result_count == 0) { + if (filter_result_iterator->approx_filter_ids_length == 0) { goto process_search_results; } } @@ -2324,7 +2321,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.is_valid); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator->is_valid); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id && @@ -2372,8 +2369,10 @@ Option Index::search(std::vector& field_query_tokens, cons Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, store, doc_id_prefix, filter_tree_root); - filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); - approx_filter_ids_length = filter_result_iterator.approx_filter_ids_length; + filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root); + filter_iterator_guard.reset(filter_result_iterator); + + approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -2391,9 +2390,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint32_t filter_id_count = 0; while (!no_filters_provided && - filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { - auto seq_id = filter_result_iterator.seq_id; - filter_result_iterator.next(); + filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator->is_valid) { + auto seq_id = filter_result_iterator->seq_id; + filter_result_iterator->next(); std::vector values; try { @@ -2417,12 +2416,13 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels.emplace_back(dist, seq_id); filter_id_count++; } + filter_result_iterator->reset(); if(no_filters_provided || - (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.is_valid)) { + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->is_valid)) { dist_labels.clear(); - VectorFilterFunctor filterFunctor(&filter_result_iterator); + VectorFilterFunctor filterFunctor(filter_result_iterator); if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); @@ -2432,8 +2432,7 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } } - - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector nearest_ids; @@ -2488,7 +2487,7 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } // filter tree was initialized to have all sequence ids in this flow. @@ -2549,7 +2548,7 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -2586,7 +2585,7 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } } @@ -2602,7 +2601,7 @@ Option Index::search(std::vector& field_query_tokens, cons filter_result_iterator, query_hashes, sort_order, field_values, geopoint_indices, qtoken_set); - filter_result_iterator.reset(); + filter_result_iterator->reset(); // gather up both original query and synonym queries and do drop tokens @@ -2659,7 +2658,7 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } else { break; @@ -2676,7 +2675,7 @@ Option Index::search(std::vector& field_query_tokens, cons sort_order, field_values, geopoint_indices, curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if(!vector_query.field_name.empty()) { // check at least one of sort fields is text match @@ -2693,7 +2692,7 @@ Option Index::search(std::vector& field_query_tokens, cons constexpr float TEXT_MATCH_WEIGHT = 0.7; constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT; - VectorFilterFunctor filterFunctor(&filter_result_iterator); + VectorFilterFunctor filterFunctor(filter_result_iterator); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; auto k = std::max(vector_query.k, fetch_size); @@ -2705,7 +2704,7 @@ Option Index::search(std::vector& field_query_tokens, cons } else { dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector> vec_results; for (const auto& dist_label : dist_labels) { @@ -2915,7 +2914,7 @@ Option Index::search(std::vector& field_query_tokens, cons void Index::process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, - const bool filter_curated_hits, filter_result_iterator_t& filter_result_iterator, + const bool filter_curated_hits, filter_result_iterator_t* const filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, @@ -2938,9 +2937,9 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_result_iterator.is_valid && filter_curated_hits) { + if(filter_result_iterator->is_valid && filter_curated_hits) { for (const auto &included_id: included_ids_vec) { - auto result = filter_result_iterator.valid(included_id); + auto result = filter_result_iterator->valid(included_id); if (result == -1) { break; @@ -3007,7 +3006,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector & sort_fields, @@ -3153,7 +3152,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, last_token, prev_token, filter_result_iterator, field_leaves, unique_tokens); - filter_result_iterator.reset(); + filter_result_iterator->reset(); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3184,7 +3183,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, token_candidates_vec.back().candidates[0], the_fields, num_search_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, prev_token_doc_ids, popular_field_ids); - filter_result_iterator.reset(); + filter_result_iterator->reset(); for(size_t field_id: query_field_ids) { auto& the_field = the_fields[field_id]; @@ -3207,7 +3206,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, false, "", filter_result_iterator, field_leaves, unique_tokens); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if(field_leaves.empty()) { // look at the next field @@ -3271,7 +3270,6 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, exhaustive_search, max_candidates, syn_orig_num_tokens, sort_order, field_values, geopoint_indices, query_hashes, id_buff); - filter_result_iterator.reset(); if(id_buff.size() > 1) { gfx::timsort(id_buff.begin(), id_buff.end()); @@ -3332,7 +3330,7 @@ void Index::find_across_fields(const token_t& previous_token, const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, std::vector& top_prefix_field_ids) const { @@ -3343,7 +3341,7 @@ void Index::find_across_fields(const token_t& previous_token, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator); const bool prefix_search = previous_token.is_prefix_searched; const uint32_t token_num_typos = previous_token.num_typos; @@ -3424,7 +3422,7 @@ void Index::search_across_fields(const std::vector& query_tokens, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool prioritize_token_position, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, @@ -3485,7 +3483,7 @@ void Index::search_across_fields(const std::vector& query_tokens, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator); // for each token, find the posting lists across all query_by fields for(size_t ti = 0; ti < query_tokens.size(); ti++) { @@ -3947,6 +3945,7 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, + filter_result_iterator_t*& filter_result_iterator, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed, const std::set& curated_ids, @@ -3954,9 +3953,10 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, - bool is_wildcard_query, - uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const { + bool is_wildcard_query) const { + uint32_t* phrase_result_ids = nullptr; + uint32_t phrase_result_count = 0; std::map phrase_match_id_scores; for(size_t i = 0; i < num_search_fields; i++) { @@ -4045,12 +4045,19 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vectoris_valid) { + filter_result_iterator_t::add_phrase_ids(filter_result_iterator, phrase_result_ids, phrase_result_count); + } else { + delete filter_result_iterator; + filter_result_iterator = new filter_result_iterator_t(phrase_result_ids, phrase_result_count); + } + size_t filter_index = 0; if(is_wildcard_query) { - all_result_ids = new uint32_t[phrase_result_count]; - std::copy(phrase_result_ids, phrase_result_ids + phrase_result_count, all_result_ids); - all_result_ids_len = phrase_result_count; + all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids); + filter_result_iterator->reset(); } else { // this means that the there are non-phrase tokens in the query // so we cannot directly copy to the all_result_ids array @@ -4058,8 +4065,8 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector(10000, phrase_result_count); i++) { - auto seq_id = phrase_result_ids[i]; + for(size_t i = 0; i < std::min(10000, all_result_ids_len); i++) { + auto seq_id = all_result_ids[i]; int64_t match_score = phrase_match_id_scores[seq_id]; int64_t scores[3] = {0}; @@ -4112,7 +4119,7 @@ void Index::do_synonym_search(const std::vector& the_fields, spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -4140,7 +4147,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -4174,10 +4181,10 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vectoris_valid) { uint32_t *filtered_raw_infix_ids = nullptr; - raw_infix_ids_length = filter_result_iterator.and_scalar(raw_infix_ids, raw_infix_ids_length, + raw_infix_ids_length = filter_result_iterator->and_scalar(raw_infix_ids, raw_infix_ids_length, filtered_raw_infix_ids); if(raw_infix_ids != &infix_ids[0]) { delete [] raw_infix_ids; @@ -4472,7 +4479,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, @@ -4502,11 +4509,11 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, auto parent_search_cutoff = search_cutoff; uint32_t excluded_result_index = 0; - for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { + for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator->is_valid; thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, + filter_result_iterator->get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, batch_result_ids); num_queued++; @@ -4588,8 +4595,8 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - filter_result_iterator.reset(); - all_result_ids_len = filter_result_iterator.to_filter_id_array(all_result_ids); + filter_result_iterator->reset(); + all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids); } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, diff --git a/test/filter_test.cpp b/test/filter_test.cpp index d3aa3f31..ac6efdb4 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -482,5 +482,19 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_EQ(6, iter_skip_test4.seq_id); ASSERT_TRUE(iter_skip_test4.is_valid); + auto iter_add_phrase_ids_test = new filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + std::unique_ptr filter_iter_guard(iter_add_phrase_ids_test); + ASSERT_TRUE(iter_add_phrase_ids_test->init_status().ok()); + + auto phrase_ids = new uint32_t[4]; + for (uint32_t i = 0; i < 4; i++) { + phrase_ids[i] = i * 2; + } + filter_result_iterator_t::add_phrase_ids(iter_add_phrase_ids_test, phrase_ids, 4); + filter_iter_guard.reset(iter_add_phrase_ids_test); + + ASSERT_TRUE(iter_add_phrase_ids_test->is_valid); + ASSERT_EQ(6, iter_add_phrase_ids_test->seq_id); + delete filter_tree_root; } \ No newline at end of file