diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 784aa385..2e30a2cb 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -172,6 +172,7 @@ public: /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The /// iterator may become invalid during this operation. void get_n_ids(const uint32_t &n, + size_t& excluded_result_index, uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, std::vector &results); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 4ea630cd..9977c6be 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1348,16 +1348,23 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { - if (excluded_result_ids == nullptr || excluded_result_ids_size == 0) { + if (excluded_result_ids == nullptr || excluded_result_ids_size == 0 || + excluded_result_index >= excluded_result_ids_size) { return get_n_ids(n, results); } if (is_filter_result_initialized) { for (uint32_t count = 0; count < n && result_index < filter_result.count;) { auto id = filter_result.docs[result_index++]; - if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, id)) { + + while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < id) { + excluded_result_index++; + } + + if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != id) { results.push_back(id); count++; } @@ -1368,7 +1375,11 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } for (uint32_t count = 0; count < n && is_valid;) { - if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, seq_id)) { + while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < seq_id) { + excluded_result_index++; + } + + if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != seq_id) { results.push_back(seq_id); count++; } diff --git a/src/index.cpp b/src/index.cpp index eef5cc67..ba405b53 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4966,12 +4966,14 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_begin = search_begin_us; const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; + size_t excluded_result_index = 0; for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, exclude_token_ids, exclude_token_ids_size, batch_result_ids); + filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, + batch_result_ids); num_queued++;