Implement search cutoff during intersection.

This commit is contained in:
Kishore Nallan 2022-11-30 15:37:12 +05:30
parent 71260ea2f0
commit b911766379
6 changed files with 102 additions and 70 deletions

View File

@ -56,6 +56,7 @@ template<class T>
bool or_iterator_t::intersect(std::vector<or_iterator_t>& its, result_iter_state_t& istate, T func) {
size_t it_size = its.size();
bool is_excluded;
size_t num_processed = 0;
switch (its.size()) {
case 0:
@ -66,6 +67,14 @@ bool or_iterator_t::intersect(std::vector<or_iterator_t>& its, result_iter_state
}
while(its.size() == it_size && its[0].valid()) {
num_processed++;
if (num_processed % 65536 == 0 &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) {
search_cutoff = true;
break;
}
auto id = its[0].id();
if(take_id(istate, id, is_excluded)) {
func(id, its);
@ -90,6 +99,14 @@ bool or_iterator_t::intersect(std::vector<or_iterator_t>& its, result_iter_state
}
while(its.size() == it_size && !at_end2(its)) {
num_processed++;
if (num_processed % 65536 == 0 &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) {
search_cutoff = true;
break;
}
if(equals2(its)) {
auto id = its[0].id();
if(take_id(istate, id, is_excluded)) {
@ -120,6 +137,14 @@ bool or_iterator_t::intersect(std::vector<or_iterator_t>& its, result_iter_state
}
while(its.size() == it_size && !at_end(its)) {
num_processed++;
if (num_processed % 65536 == 0 &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) {
search_cutoff = true;
break;
}
if(equals(its)) {
auto id = its[0].id();
if(take_id(istate, id, is_excluded)) {

View File

@ -48,13 +48,10 @@ public:
std::vector<posting_list_t*> plists;
std::vector<posting_list_t*> expanded_plists;
result_iter_state_t& iter_state;
ThreadPool* thread_pool;
block_intersector_t(const std::vector<void*>& raw_posting_lists,
result_iter_state_t& iter_state,
ThreadPool* thread_pool,
size_t parallelize_min_ids = 1):
iter_state(iter_state), thread_pool(thread_pool) {
result_iter_state_t& iter_state):
iter_state(iter_state) {
to_expanded_plists(raw_posting_lists, plists, expanded_plists);
@ -72,7 +69,7 @@ public:
}
template<class T>
bool intersect(T func, size_t concurrency=4);
bool intersect(T func);
};
static void to_expanded_plists(const std::vector<void*>& raw_posting_lists, std::vector<posting_list_t*>& plists,
@ -115,7 +112,7 @@ public:
};
template<class T>
bool posting_t::block_intersector_t::intersect(T func, size_t concurrency) {
bool posting_t::block_intersector_t::intersect(T func) {
if(plists.empty()) {
return true;
}

View File

@ -5,6 +5,7 @@
#include "sorted_array.h"
#include "array.h"
#include "match_score.h"
#include "thread_local_vars.h"
typedef uint32_t last_id_t;
@ -17,7 +18,6 @@ struct result_iter_state_t {
size_t excluded_result_ids_index = 0;
size_t filter_ids_index = 0;
size_t index = 0;
result_iter_state_t() = default;
@ -203,13 +203,23 @@ template<class T>
bool posting_list_t::block_intersect(std::vector<posting_list_t::iterator_t>& its, result_iter_state_t& istate,
T func) {
size_t num_processed = 0;
switch (its.size()) {
case 0:
break;
case 1:
while(its[0].valid()) {
num_processed++;
if (num_processed % 65536 == 0 &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) {
search_cutoff = true;
break;
}
if(posting_list_t::take_id(istate, its[0].id())) {
func(its[0].id(), its, istate.index);
func(its[0].id(), its);
}
its[0].next();
@ -217,9 +227,17 @@ bool posting_list_t::block_intersect(std::vector<posting_list_t::iterator_t>& it
break;
case 2:
while(!at_end2(its)) {
num_processed++;
if (num_processed % 65536 == 0 &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) {
search_cutoff = true;
break;
}
if(equals2(its)) {
if(posting_list_t::take_id(istate, its[0].id())) {
func(its[0].id(), its, istate.index);
func(its[0].id(), its);
}
advance_all2(its);
@ -230,10 +248,18 @@ bool posting_list_t::block_intersect(std::vector<posting_list_t::iterator_t>& it
break;
default:
while(!at_end(its)) {
num_processed++;
if (num_processed % 65536 == 0 &&
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - search_begin).count() > search_stop_ms) {
search_cutoff = true;
break;
}
if(equals(its)) {
//LOG(INFO) << its[0].id();
if(posting_list_t::take_id(istate, its[0].id())) {
func(its[0].id(), its, istate.index);
func(its[0].id(), its);
}
advance_all(its);

View File

@ -1470,66 +1470,26 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
single_exact_query_token = true;
}
std::vector<uint32_t> result_id_vecs[concurrency];
Topster* topsters[concurrency];
std::vector<spp::sparse_hash_set<uint64_t>> groups_processed_vec(concurrency);
if(topster == nullptr) {
posting_t::block_intersector_t(
posting_lists, iter_state, thread_pool, 100
)
.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its, size_t index) {
result_id_vecs[index].push_back(seq_id);
}, concurrency);
posting_t::block_intersector_t(posting_lists, iter_state)
.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its) {
id_buff.push_back(seq_id);
});
} else {
for(size_t i = 0; i < concurrency; i++) {
topsters[i] = new Topster(topster->MAX_SIZE, topster->distinct);
}
posting_t::block_intersector_t(posting_lists, iter_state)
.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its) {
score_results(sort_fields, searched_queries.size(), field_id, field_is_array,
total_cost, topster, query_suggestion, groups_processed,
seq_id, sort_order, field_values, geopoint_indices,
group_limit, group_by_fields, token_bits,
prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, its);
posting_t::block_intersector_t(
posting_lists, iter_state, thread_pool, 100
)
.intersect([&](uint32_t seq_id, std::vector<posting_list_t::iterator_t>& its, size_t index) {
score_results(sort_fields, searched_queries.size(), field_id, field_is_array,
total_cost, topsters[index], query_suggestion, groups_processed_vec[index],
seq_id, sort_order, field_values, geopoint_indices,
group_limit, group_by_fields, token_bits,
prioritize_exact_match, single_exact_query_token, syn_orig_num_tokens, its);
result_id_vecs[index].push_back(seq_id);
}, concurrency);
id_buff.push_back(seq_id);
});
}
delete [] excluded_result_ids;
size_t num_result_ids = 0;
for(size_t i = 0; i < concurrency; i++) {
// empty vec can happen if not all threads produce results
if (!result_id_vecs[i].empty()) {
if(exhaustive_search) {
id_buff.insert(id_buff.end(), result_id_vecs[i].begin(), result_id_vecs[i].end());
} else {
uint32_t* new_all_result_ids = nullptr;
all_result_ids_len = ArrayUtils::or_scalar(*all_result_ids, all_result_ids_len, &result_id_vecs[i][0],
result_id_vecs[i].size(), &new_all_result_ids);
delete[] *all_result_ids;
*all_result_ids = new_all_result_ids;
}
num_result_ids += result_id_vecs[i].size();
if (topster != nullptr) {
// topster is null when used by overrides which requires only IDs but not actual processing
aggregate_topster(topster, topsters[i]);
groups_processed.insert(groups_processed_vec[i].begin(), groups_processed_vec[i].end());
}
}
if(topster != nullptr) {
delete topsters[i];
}
}
const size_t num_result_ids = id_buff.size();
if(id_buff.size() > 100000) {
// prevents too many ORs during exhaustive searching

View File

@ -1712,3 +1712,27 @@ TEST_F(CollectionSpecificMoreTest, HighlightOnFieldNameWithDot) {
highlight = R"({"org.title":{"matched_tokens":["Infinity"],"snippet":"<mark>Infinity</mark> Inc."}})"_json;
ASSERT_EQ(highlight.dump(), res["hits"][0]["highlight"].dump());
}
// Indexes a large number of identical two-token documents and checks that a
// search over both tokens reports `search_cutoff` in its response.
TEST_F(CollectionSpecificMoreTest, SearchCutoffTest) {
    nlohmann::json coll_schema = R"({
"name": "coll1",
"fields": [
{"name": "title", "type": "string"}
]
})"_json;

    Collection* collection = collectionManager.create_collection(coll_schema).get();

    // The document count is deliberately above 65536: the intersection loops
    // only compare elapsed time against the stop threshold once every 65536
    // processed entries, so fewer documents would never reach that check.
    const size_t num_docs = 70000;
    size_t num_added = 0;

    while(num_added < num_docs) {
        nlohmann::json record;
        record["title"] = "1 2";
        ASSERT_TRUE(collection->add(record.dump()).ok());
        ++num_added;
    }

    // NOTE(review): the trailing `1` is presumably the search time budget in
    // milliseconds -- confirm against the Collection::search signature. If so,
    // this assertion relies on 65536 iterations taking longer than that budget.
    auto results = collection->search("1 2", {"title"}, "", {}, {}, {0}, 3, 1, FREQUENCY, {false}, 5,
                                      spp::sparse_hash_set<std::string>(),
                                      spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
                                      "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 1).get();

    const bool was_cut_off = results["search_cutoff"].get<bool>();
    ASSERT_TRUE(was_cut_off);
}

View File

@ -642,7 +642,7 @@ TEST_F(PostingListTest, IntersectionBasics) {
result_iter_state_t iter_state;
result_ids.clear();
posting_t::block_intersector_t(raw_lists, iter_state, pool).intersect([&](auto id, auto& its, size_t index){
posting_t::block_intersector_t(raw_lists, iter_state).intersect([&](auto id, auto& its){
std::unique_lock lk(vecm);
result_ids.push_back(id);
});
@ -668,7 +668,7 @@ TEST_F(PostingListTest, IntersectionBasics) {
result_ids.clear();
raw_lists = {&p1};
posting_t::block_intersector_t(raw_lists, iter_state2, pool).intersect([&](auto id, auto& its, size_t index){
posting_t::block_intersector_t(raw_lists, iter_state2).intersect([&](auto id, auto& its){
std::unique_lock lk(vecm);
result_ids.push_back(id);
});
@ -691,7 +691,7 @@ TEST_F(PostingListTest, IntersectionBasics) {
result_ids.clear();
raw_lists.clear();
posting_t::block_intersector_t(raw_lists, iter_state3, pool).intersect([&](auto id, auto& its, size_t index){
posting_t::block_intersector_t(raw_lists, iter_state3).intersect([&](auto id, auto& its){
std::unique_lock lk(vecm);
result_ids.push_back(id);
});
@ -1305,8 +1305,8 @@ TEST_F(PostingListTest, BlockIntersectionOnMixedLists) {
std::vector<uint32_t> result_ids;
std::mutex vecm;
posting_t::block_intersector_t(raw_posting_lists, iter_state, pool)
.intersect([&](auto seq_id, auto& its, size_t index) {
posting_t::block_intersector_t(raw_posting_lists, iter_state)
.intersect([&](auto seq_id, auto& its) {
std::unique_lock lock(vecm);
result_ids.push_back(seq_id);
});