From 33be7e6c6881545f9ce2c284b42788fc8c1ee79c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 09:39:51 +0530 Subject: [PATCH] Add `ArrayUtils::skip_index_to_id`. --- include/array_utils.h | 6 ++++++ include/filter_result_iterator.h | 2 +- src/array_utils.cpp | 25 +++++++++++++++++++++++++ src/filter_result_iterator.cpp | 16 ++++------------ src/index.cpp | 2 +- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/include/array_utils.h b/include/array_utils.h index 81a0576c..008a8a36 100644 --- a/include/array_utils.h +++ b/include/array_utils.h @@ -16,4 +16,10 @@ public: static size_t exclude_scalar(const uint32_t *src, const size_t lenSrc, const uint32_t *filter, const size_t lenFilter, uint32_t **out); + + /// Performs binary search to find the index of id. If id is not found, curr_index is set to the index of next bigger + /// number than id in the array. + /// \return Whether or not id was found in array. + static bool skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, + const uint32_t& id); }; \ No newline at end of file diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 2e30a2cb..259b93f0 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -172,7 +172,7 @@ public: /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The /// iterator may become invalid during this operation. void get_n_ids(const uint32_t &n, - size_t& excluded_result_index, + uint32_t& excluded_result_index, uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, std::vector &results); diff --git a/src/array_utils.cpp b/src/array_utils.cpp index 9f0c7f4a..ad22a85f 100644 --- a/src/array_utils.cpp +++ b/src/array_utils.cpp @@ -149,4 +149,29 @@ size_t ArrayUtils::exclude_scalar(const uint32_t *A, const size_t lenA, delete[] results; return res_index; +} + +bool ArrayUtils::skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, + const uint32_t& id) { + if (id <= array[curr_index]) { + return id == array[curr_index]; + } + + long start = curr_index, mid, end = array_len; + + while (start <= end) { + mid = start + (end - start) / 2; + + if (array[mid] == id) { + curr_index = mid; + return true; + } else if (array[mid] < id) { + start = mid + 1; + } else { + end = mid - 1; + } + } + + curr_index = start; + return false; } \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 9977c6be..63a76f73 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -959,7 +959,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } if (is_filter_result_initialized) { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + ArrayUtils::skip_index_to_id(result_index, filter_result.docs, filter_result.count, id); if (result_index >= filter_result.count) { is_valid = false; @@ -1348,7 +1348,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { if (excluded_result_ids == nullptr || excluded_result_ids_size == 0 || @@ -1360,11 +1360,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, for (uint32_t count = 0; count < n && result_index < filter_result.count;) { auto id = filter_result.docs[result_index++]; - while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < id) { - excluded_result_index++; - } - - if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != id) { + if (!ArrayUtils::skip_index_to_id(excluded_result_index, excluded_result_ids, excluded_result_ids_size, id)) { results.push_back(id); count++; } @@ -1375,11 +1371,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } for (uint32_t count = 0; count < n && is_valid;) { - while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < seq_id) { - excluded_result_index++; - } - - if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != seq_id) { + if (!ArrayUtils::skip_index_to_id(excluded_result_index, excluded_result_ids, excluded_result_ids_size, seq_id)) { results.push_back(seq_id); count++; } diff --git a/src/index.cpp b/src/index.cpp index ba405b53..225a7b7b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4966,7 +4966,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_begin = search_begin_us; const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - size_t excluded_result_index = 0; + uint32_t excluded_result_index = 0; for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { std::vector batch_result_ids;