diff --git a/include/or_iterator.h b/include/or_iterator.h index 958d10b0..d98b94a9 100644 --- a/include/or_iterator.h +++ b/include/or_iterator.h @@ -56,6 +56,7 @@ template bool or_iterator_t::intersect(std::vector& its, result_iter_state_t& istate, T func) { size_t it_size = its.size(); bool is_excluded; + size_t num_processed = 0; switch (its.size()) { case 0: @@ -66,6 +67,14 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } while(its.size() == it_size && its[0].valid()) { + num_processed++; + if (num_processed % 65536 == 0 && + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) { + search_cutoff = true; + break; + } + auto id = its[0].id(); if(take_id(istate, id, is_excluded)) { func(id, its); @@ -90,6 +99,14 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } while(its.size() == it_size && !at_end2(its)) { + num_processed++; + if (num_processed % 65536 == 0 && + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) { + search_cutoff = true; + break; + } + if(equals2(its)) { auto id = its[0].id(); if(take_id(istate, id, is_excluded)) { @@ -120,6 +137,14 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } while(its.size() == it_size && !at_end(its)) { + num_processed++; + if (num_processed % 65536 == 0 && + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) { + search_cutoff = true; + break; + } + if(equals(its)) { auto id = its[0].id(); if(take_id(istate, id, is_excluded)) { diff --git a/include/posting.h b/include/posting.h index 9bb68c9e..29ab8cc4 100644 --- a/include/posting.h +++ b/include/posting.h @@ -48,13 +48,10 @@ public: std::vector plists; std::vector expanded_plists; result_iter_state_t& iter_state; - ThreadPool* thread_pool; block_intersector_t(const std::vector& raw_posting_lists, - result_iter_state_t& iter_state, - ThreadPool* thread_pool, - size_t parallelize_min_ids = 1): - iter_state(iter_state), thread_pool(thread_pool) { + result_iter_state_t& iter_state): + iter_state(iter_state) { to_expanded_plists(raw_posting_lists, plists, expanded_plists); @@ -72,7 +69,7 @@ public: } template - bool intersect(T func, size_t concurrency=4); + bool intersect(T func); }; static void to_expanded_plists(const std::vector& raw_posting_lists, std::vector& plists, @@ -115,7 +112,7 @@ public: }; template -bool posting_t::block_intersector_t::intersect(T func, size_t concurrency) { +bool posting_t::block_intersector_t::intersect(T func) { if(plists.empty()) { return true; } diff --git a/include/posting_list.h b/include/posting_list.h index 76dc52c8..95a57cbc 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -5,6 +5,7 @@ #include "sorted_array.h" #include "array.h" #include "match_score.h" +#include "thread_local_vars.h" typedef uint32_t last_id_t; @@ -17,7 +18,6 @@ struct result_iter_state_t { size_t excluded_result_ids_index = 0; size_t filter_ids_index = 0; - size_t index = 0; result_iter_state_t() = default; @@ -203,13 +203,23 @@ template bool posting_list_t::block_intersect(std::vector& its, result_iter_state_t& istate, T func) { + size_t num_processed = 0; + switch (its.size()) { case 0: break; case 1: while(its[0].valid()) { + num_processed++; + if (num_processed % 65536 == 0 && + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) { + search_cutoff = true; + break; + } + if(posting_list_t::take_id(istate, its[0].id())) { - func(its[0].id(), its, istate.index); + func(its[0].id(), its); } its[0].next(); @@ -217,9 +227,17 @@ bool posting_list_t::block_intersect(std::vector& it break; case 2: while(!at_end2(its)) { + num_processed++; + if (num_processed % 65536 == 0 && + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) { + search_cutoff = true; + break; + } + if(equals2(its)) { if(posting_list_t::take_id(istate, its[0].id())) { - func(its[0].id(), its, istate.index); + func(its[0].id(), its); } advance_all2(its); @@ -230,10 +248,18 @@ bool posting_list_t::block_intersect(std::vector& it break; default: while(!at_end(its)) { + num_processed++; + if (num_processed % 65536 == 0 && + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) { + search_cutoff = true; + break; + } + if(equals(its)) { //LOG(INFO) << its[0].id(); if(posting_list_t::take_id(istate, its[0].id())) { - func(its[0].id(), its, istate.index); + func(its[0].id(), its); } advance_all(its); diff --git a/src/index.cpp b/src/index.cpp index e742003c..fb80f4f5 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1470,66 +1470,26 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, single_exact_query_token = true; } - std::vector result_id_vecs[concurrency]; - Topster* topsters[concurrency]; - std::vector> groups_processed_vec(concurrency); - if(topster == nullptr) { - posting_t::block_intersector_t( - posting_lists, iter_state, thread_pool, 100 - ) - .intersect([&](uint32_t seq_id, std::vector& its, size_t index) { - result_id_vecs[index].push_back(seq_id); - }, concurrency); + posting_t::block_intersector_t(posting_lists, iter_state) + .intersect([&](uint32_t seq_id, std::vector& its) { + id_buff.push_back(seq_id); + }); } else { - for(size_t i = 0; i < concurrency; i++) { - topsters[i] = new Topster(topster->MAX_SIZE, topster->distinct); - } + posting_t::block_intersector_t(posting_lists, iter_state) + .intersect([&](uint32_t seq_id, std::vector& its) { + score_results(sort_fields, searched_queries.size(), field_id, field_is_array, + total_cost, topster, query_suggestion, groups_processed, + seq_id, sort_order, field_values, geopoint_indices, + group_limit, group_by_fields, token_bits, + prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, its); - posting_t::block_intersector_t( - posting_lists, iter_state, thread_pool, 100 - ) - .intersect([&](uint32_t seq_id, std::vector& its, size_t index) { - score_results(sort_fields, searched_queries.size(), field_id, field_is_array, - total_cost, topsters[index], query_suggestion, groups_processed_vec[index], - seq_id, sort_order, field_values, geopoint_indices, - group_limit, group_by_fields, token_bits, - prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, its); - - result_id_vecs[index].push_back(seq_id); - }, concurrency); + id_buff.push_back(seq_id); + }); } delete [] excluded_result_ids; - - size_t num_result_ids = 0; - - for(size_t i = 0; i < concurrency; i++) { - // empty vec can happen if not all threads produce results - if (!result_id_vecs[i].empty()) { - if(exhaustive_search) { - id_buff.insert(id_buff.end(), result_id_vecs[i].begin(), result_id_vecs[i].end()); - } else { - uint32_t* new_all_result_ids = nullptr; - all_result_ids_len = ArrayUtils::or_scalar(*all_result_ids, all_result_ids_len, &result_id_vecs[i][0], - result_id_vecs[i].size(), &new_all_result_ids); - delete[] *all_result_ids; - *all_result_ids = new_all_result_ids; - } - - num_result_ids += result_id_vecs[i].size(); - - if (topster != nullptr) { - // topster is null when used by overrides which requires only IDs but not actual processing - aggregate_topster(topster, topsters[i]); - groups_processed.insert(groups_processed_vec[i].begin(), groups_processed_vec[i].end()); - } - } - - if(topster != nullptr) { - delete topsters[i]; - } - } + const size_t num_result_ids = id_buff.size(); if(id_buff.size() > 100000) { // prevents too many ORs during exhaustive searching diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 0e5d2338..5c941eb9 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -1712,3 +1712,27 @@ TEST_F(CollectionSpecificMoreTest, HighlightOnFieldNameWithDot) { highlight = R"({"org.title":{"matched_tokens":["Infinity"],"snippet":"Infinity Inc."}})"_json; ASSERT_EQ(highlight.dump(), res["hits"][0]["highlight"].dump()); } + +TEST_F(CollectionSpecificMoreTest, SearchCutoffTest) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string"} + ] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + for(size_t i = 0; i < 70000; i++) { + nlohmann::json doc; + doc["title"] = "1 2"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + auto res = coll1->search("1 2", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, 5, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 1).get(); + + ASSERT_TRUE(res["search_cutoff"].get()); +} diff --git a/test/posting_list_test.cpp b/test/posting_list_test.cpp index 55495152..16e510aa 100644 --- a/test/posting_list_test.cpp +++ b/test/posting_list_test.cpp @@ -642,7 +642,7 @@ TEST_F(PostingListTest, IntersectionBasics) { result_iter_state_t iter_state; result_ids.clear(); - posting_t::block_intersector_t(raw_lists, iter_state, pool).intersect([&](auto id, auto& its, size_t index){ + posting_t::block_intersector_t(raw_lists, iter_state).intersect([&](auto id, auto& its){ std::unique_lock lk(vecm); result_ids.push_back(id); }); @@ -668,7 +668,7 @@ TEST_F(PostingListTest, IntersectionBasics) { result_ids.clear(); raw_lists = {&p1}; - posting_t::block_intersector_t(raw_lists, iter_state2, pool).intersect([&](auto id, auto& its, size_t index){ + posting_t::block_intersector_t(raw_lists, iter_state2).intersect([&](auto id, auto& its){ std::unique_lock lk(vecm); result_ids.push_back(id); }); @@ -691,7 +691,7 @@ TEST_F(PostingListTest, IntersectionBasics) { result_ids.clear(); raw_lists.clear(); - posting_t::block_intersector_t(raw_lists, iter_state3, pool).intersect([&](auto id, auto& its, size_t index){ + posting_t::block_intersector_t(raw_lists, iter_state3).intersect([&](auto id, auto& its){ std::unique_lock lk(vecm); result_ids.push_back(id); }); @@ -1305,8 +1305,8 @@ TEST_F(PostingListTest, BlockIntersectionOnMixedLists) { std::vector result_ids; std::mutex vecm; - posting_t::block_intersector_t(raw_posting_lists, iter_state, pool) - .intersect([&](auto seq_id, auto& its, size_t index) { + posting_t::block_intersector_t(raw_posting_lists, iter_state) + .intersect([&](auto seq_id, auto& its) { std::unique_lock lock(vecm); result_ids.push_back(seq_id); });