Fix double locking of collection mutex: thread the collection name down through Index's filtering call chain (run_search/search/adaptive_filter/recursive_filter/do_filtering) instead of re-resolving it via CollectionManager::get_collection_with_id()->get_name() inside do_filtering(), which re-acquired a collection mutex that the caller already held.

This commit is contained in:
Harpreet Sangar 2023-02-03 14:30:17 +05:30
parent 34f039e584
commit 16d6a5cbf0
4 changed files with 198 additions and 38 deletions

View File

@ -484,21 +484,27 @@ private:
uint32_t*& ids,
size_t& ids_len) const;
void do_filtering(filter_node_t* const root) const;
void do_filtering(filter_node_t* const root, const std::string& collection_name) const;
void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const;
void rearranging_recursive_filter (uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const root,
const std::string& collection_name) const;
void recursive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const root,
const bool enable_short_circuit = false) const;
const std::string& collection_name) const;
void adaptive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const filter_tree_root,
const bool enable_short_circuit = false) const;
const std::string& collection_name = "") const;
void get_filter_matches(filter_node_t* const root, std::vector<std::pair<uint32_t, filter_node_t*>>& vec) const;
void get_filter_matches(filter_node_t* const root,
std::vector<std::pair<uint32_t,
filter_node_t*>>& vec,
const std::string& collection_name) const;
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;
@ -656,7 +662,7 @@ public:
// Public operations
void run_search(search_args* search_params);
void run_search(search_args* search_params, const std::string& collection_name);
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
@ -679,7 +685,8 @@ public:
size_t max_candidates, const std::vector<enable_t>& infixes, const size_t max_extra_prefix,
const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits, enable_t split_join_tokens,
const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold) const;
const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold,
const std::string& collection_name) const;
void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name);
@ -720,7 +727,8 @@ public:
void do_filtering_with_lock(
uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* filter_tree_root) const;
filter_node_t* filter_tree_root,
const std::string& collection_name) const;
void do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& reference_index_ids,
filter_node_t* filter_tree_root,

View File

@ -1449,7 +1449,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
filter_curated_hits, split_join_tokens, vector_query,
facet_sample_percent, facet_sample_threshold);
index->run_search(search_params);
index->run_search(search_params, name);
// for grouping we have to re-aggregate
@ -2405,7 +2405,7 @@ Option<bool> Collection::get_filter_ids(const std::string & filter_query,
uint32_t* filter_ids = nullptr;
uint32_t filter_ids_len = 0;
index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root);
index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root, name);
index_ids.emplace_back(filter_ids_len, filter_ids);
delete filter_tree_root;

View File

@ -1617,7 +1617,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
ids = out;
}
void Index::do_filtering(filter_node_t* const root) const {
void Index::do_filtering(filter_node_t* const root, const std::string& collection_name) const {
// auto begin = std::chrono::high_resolution_clock::now();
const filter a_filter = root->filter_exp;
@ -1628,7 +1628,7 @@ void Index::do_filtering(filter_node_t* const root) const {
auto collection = cm.get_collection(a_filter.referenced_collection_name);
auto op = collection->get_reference_filter_ids(a_filter.field_name,
cm.get_collection_with_id(collection_id)->get_name(),
collection_name,
root->match_index_ids);
if (!op.ok()) {
return;
@ -1957,26 +1957,29 @@ void Index::do_filtering(filter_node_t* const root) const {
LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/
}
void Index::get_filter_matches(filter_node_t* const root, std::vector<std::pair<uint32_t, filter_node_t*>>& vec) const {
void Index::get_filter_matches(filter_node_t* const root,
std::vector<std::pair<uint32_t,
filter_node_t*>>& vec,
const std::string& collection_name) const {
if (root == nullptr) {
return;
}
if (root->isOperator) {
if (root->filter_operator == AND) {
get_filter_matches(root->left, vec);
get_filter_matches(root->right, vec);
get_filter_matches(root->left, vec, collection_name);
get_filter_matches(root->right, vec, collection_name);
} else {
uint32_t *l_filter_ids = nullptr;
uint32_t l_filter_ids_length = 0;
if (root->left != nullptr) {
rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left);
rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left, collection_name);
}
uint32_t *r_filter_ids = nullptr;
uint32_t r_filter_ids_length = 0;
if (root->right != nullptr) {
rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right);
rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right, collection_name);
}
root->match_index_ids.first = ArrayUtils::or_scalar(
@ -1992,7 +1995,7 @@ void Index::get_filter_matches(filter_node_t* const root, std::vector<std::pair<
return;
}
do_filtering(root);
do_filtering(root, collection_name);
vec.emplace_back(root->match_index_ids.first, root);
}
@ -2031,9 +2034,12 @@ void evaluate_rearranged_filter_tree(uint32_t*& filter_ids,
}
}
void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const {
void Index::rearranging_recursive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const root,
const std::string& collection_name) const {
std::vector<std::pair<uint32_t, filter_node_t*>> vec;
get_filter_matches(root, vec);
get_filter_matches(root, vec, collection_name);
std::sort(vec.begin(), vec.end(),
[](const std::pair<uint32_t, filter_node_t*>& lhs, const std::pair<uint32_t, filter_node_t*>& rhs) {
@ -2050,7 +2056,7 @@ void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter
void Index::recursive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const root,
const bool enable_short_circuit) const {
const std::string& collection_name) const {
if (root == nullptr) {
return;
}
@ -2059,15 +2065,13 @@ void Index::recursive_filter(uint32_t*& filter_ids,
uint32_t* l_filter_ids = nullptr;
uint32_t l_filter_ids_length = 0;
if (root->left != nullptr) {
recursive_filter(l_filter_ids, l_filter_ids_length, root->left,
enable_short_circuit);
recursive_filter(l_filter_ids, l_filter_ids_length, root->left,collection_name);
}
uint32_t* r_filter_ids = nullptr;
uint32_t r_filter_ids_length = 0;
if (root->right != nullptr) {
recursive_filter(r_filter_ids, r_filter_ids_length, root->right,
enable_short_circuit);
recursive_filter(r_filter_ids, r_filter_ids_length, root->right,collection_name);
}
uint32_t* filtered_results = nullptr;
@ -2088,7 +2092,7 @@ void Index::recursive_filter(uint32_t*& filter_ids,
return;
}
do_filtering(root);
do_filtering(root, collection_name);
filter_ids_length = root->match_index_ids.first;
filter_ids = root->match_index_ids.second;
@ -2099,7 +2103,7 @@ void Index::recursive_filter(uint32_t*& filter_ids,
void Index::adaptive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const filter_tree_root,
const bool enable_short_circuit) const {
const std::string& collection_name) const {
if (filter_tree_root == nullptr) {
return;
}
@ -2109,24 +2113,25 @@ void Index::adaptive_filter(uint32_t*& filter_ids,
(*filter_tree_root->metrics).and_operator_count > 0 &&
// If there are more || in the filter tree than &&, we'll not gain much by rearranging the filter tree.
((float) (*filter_tree_root->metrics).or_operator_count / (float) (*filter_tree_root->metrics).and_operator_count < 0.5)) {
rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root);
rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name);
} else {
recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false);
recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name);
}
}
void Index::do_filtering_with_lock(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* filter_tree_root) const {
filter_node_t* filter_tree_root,
const std::string& collection_name) const {
std::shared_lock lock(mutex);
adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, false);
adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name);
}
void Index::do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& reference_index_ids,
filter_node_t* filter_tree_root,
const std::string& reference_helper_field_name) const {
std::shared_lock lock(mutex);
adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false);
adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root);
std::vector<uint32> vector;
vector.reserve(reference_index_ids.first);
@ -2142,7 +2147,7 @@ void Index::do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& ref
std::copy(vector.begin(), vector.end(), reference_index_ids.second);
}
void Index::run_search(search_args* search_params) {
void Index::run_search(search_args* search_params, const std::string& collection_name) {
search(search_params->field_query_tokens,
search_params->search_fields,
search_params->match_type,
@ -2175,7 +2180,8 @@ void Index::run_search(search_args* search_params) {
search_params->split_join_tokens,
search_params->vector_query,
search_params->facet_sample_percent,
search_params->facet_sample_threshold);
search_params->facet_sample_threshold,
collection_name);
}
void Index::collate_included_ids(const std::vector<token_t>& q_included_tokens,
@ -2625,7 +2631,8 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits, const enable_t split_join_tokens,
const vector_query_t& vector_query,
size_t facet_sample_percent, size_t facet_sample_threshold) const {
size_t facet_sample_percent, size_t facet_sample_threshold,
const std::string& collection_name) const {
// process the filters
@ -2634,7 +2641,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
std::shared_lock lock(mutex);
adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, true);
adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name);
if (filter_tree_root != nullptr && filter_ids_length == 0) {
delete [] filter_ids;
@ -4730,7 +4737,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
field_values[i] = &seq_id_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root, true);
adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root);
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
geopoint_indices.push_back(i);

View File

@ -299,7 +299,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) {
collectionManager.drop_collection("Products");
}
TEST_F(CollectionJoinTest, FilterByReferenceField) {
TEST_F(CollectionJoinTest, FilterByReferenceField_SingleMatch) {
auto schema_json =
R"({
"name": "Products",
@ -404,4 +404,149 @@ TEST_F(CollectionJoinTest, FilterByReferenceField) {
ASSERT_EQ(1, result["found"].get<size_t>());
ASSERT_EQ(1, result["hits"].size());
ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get<std::string>());
// collectionManager.drop_collection("Customers");
// collectionManager.drop_collection("Products");
}
// Verifies reference-field filtering when a join yields MULTIPLE matches:
// a many-to-many link collection (Links) references both Users.user_id and
// Repos.repo_id, and a search on Users filtered by $Links(repo_id:=repo_b)
// must return every user linked to repo_b that also matches the query.
TEST_F(CollectionJoinTest, FilterByReferenceField_MultipleMatch) {
// --- Users collection: four users, two of whose names start with "R". ---
auto schema_json =
R"({
"name": "Users",
"fields": [
{"name": "user_id", "type": "string"},
{"name": "user_name", "type": "string"}
]
})"_json;
std::vector<nlohmann::json> documents = {
R"({
"user_id": "user_a",
"user_name": "Roshan"
})"_json,
R"({
"user_id": "user_b",
"user_name": "Ruby"
})"_json,
R"({
"user_id": "user_c",
"user_name": "Joe"
})"_json,
R"({
"user_id": "user_d",
"user_name": "Aby"
})"_json
};
auto collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
// Log the failure detail before the assert below aborts the test.
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
// --- Repos collection: three repos referenced by the Links collection. ---
schema_json =
R"({
"name": "Repos",
"fields": [
{"name": "repo_id", "type": "string"},
{"name": "repo_content", "type": "string"}
]
})"_json;
documents = {
R"({
"repo_id": "repo_a",
"repo_content": "body1"
})"_json,
R"({
"repo_id": "repo_b",
"repo_content": "body2"
})"_json,
R"({
"repo_id": "repo_c",
"repo_content": "body3"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
// --- Links collection: join table with references into both Users and Repos.
// repo_b is linked to user_a, user_b and user_d (three link documents). ---
schema_json =
R"({
"name": "Links",
"fields": [
{"name": "repo_id", "type": "string", "reference": "Repos.repo_id"},
{"name": "user_id", "type": "string", "reference": "Users.user_id"}
]
})"_json;
documents = {
R"({
"repo_id": "repo_a",
"user_id": "user_b"
})"_json,
R"({
"repo_id": "repo_a",
"user_id": "user_c"
})"_json,
R"({
"repo_id": "repo_b",
"user_id": "user_a"
})"_json,
R"({
"repo_id": "repo_b",
"user_id": "user_b"
})"_json,
R"({
"repo_id": "repo_b",
"user_id": "user_d"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_a"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_b"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_c"
})"_json,
R"({
"repo_id": "repo_c",
"user_id": "user_d"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
auto coll = collectionManager.get_collection("Users");
// Search for users linked to repo_b
// Of repo_b's three linked users (Roshan, Ruby, Aby), only Roshan and Ruby
// match the prefix query "R" on user_name — so exactly 2 hits are expected.
auto result = coll->search("R", {"user_name"}, "$Links(repo_id:=repo_b)", {}, {}, {0},
10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD).get();
ASSERT_EQ(2, result["found"].get<size_t>());
ASSERT_EQ(2, result["hits"].size());
// NOTE(review): this order (user_b before user_a) relies on the engine's
// default tie-break ranking for equal-score matches — confirm it is stable.
ASSERT_EQ("user_b", result["hits"][0]["document"]["user_id"].get<std::string>());
ASSERT_EQ("user_a", result["hits"][1]["document"]["user_id"].get<std::string>());
// NOTE(review): drop_collection calls are commented out, mirroring the
// sibling FilterByReferenceField test — presumably because dropping
// collections that are referenced by others is not yet handled; confirm.
// collectionManager.drop_collection("Users");
// collectionManager.drop_collection("Repos");
// collectionManager.drop_collection("Links");
}