diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 1184b74a..784aa385 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -169,6 +169,12 @@ public: /// Collects n doc ids while advancing the iterator. The iterator may become invalid during this operation. void get_n_ids(const uint32_t& n, std::vector& results); + /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The + /// iterator may become invalid during this operation. + void get_n_ids(const uint32_t &n, + uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, + std::vector &results); + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. void skip_to(uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d0037512..4ea630cd 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1275,6 +1275,10 @@ filter_result_iterator_t::filter_result_iterator_t(const std::string collection_ } init(); + + if (!is_valid) { + this->approx_filter_ids_length = 0; + } } filter_result_iterator_t::~filter_result_iterator_t() { @@ -1342,3 +1346,32 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { + if (excluded_result_ids == nullptr || excluded_result_ids_size == 0) { + return get_n_ids(n, results); + } + + if (is_filter_result_initialized) { + for (uint32_t count = 0; count < n && result_index < filter_result.count;) { + auto id = filter_result.docs[result_index++]; + if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, id)) { + results.push_back(id); + count++; + } + } + + is_valid = result_index < filter_result.count; + return; + } + + for (uint32_t count = 0; count < n && is_valid;) { + if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, seq_id)) { + results.push_back(seq_id); + count++; + } + next(); + } +} diff --git a/src/index.cpp b/src/index.cpp index 9003dfac..eef5cc67 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1970,8 +1970,14 @@ void Index::aproximate_numerical_match(num_tree_t* const num_tree, uint32_t to_exclude_ids_len = 0; num_tree->approx_search_count(EQUALS, value, to_exclude_ids_len); - auto all_ids_size = seq_ids->num_ids(); - filter_ids_length += (all_ids_size - to_exclude_ids_len); + if (to_exclude_ids_len == 0) { + filter_ids_length += seq_ids->num_ids(); + } else if (to_exclude_ids_len >= seq_ids->num_ids()) { + filter_ids_length += 0; + } else { + filter_ids_length += (seq_ids->num_ids() - to_exclude_ids_len); + } + return; } @@ -4965,7 +4971,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, batch_result_ids); + filter_result_iterator.get_n_ids(window_size, exclude_token_ids, exclude_token_ids_size, batch_result_ids); num_queued++;