From 16d6a5cbf05e4c0eb6abc1e37add76ff0f4153eb Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 3 Feb 2023 14:30:17 +0530 Subject: [PATCH] Fix double locking of collection mutex. --- include/index.h | 24 ++++-- src/collection.cpp | 4 +- src/index.cpp | 61 +++++++------- test/collection_join_test.cpp | 147 +++++++++++++++++++++++++++++++++- 4 files changed, 198 insertions(+), 38 deletions(-) diff --git a/include/index.h b/include/index.h index e142e6ae..e0875935 100644 --- a/include/index.h +++ b/include/index.h @@ -484,21 +484,27 @@ private: uint32_t*& ids, size_t& ids_len) const; - void do_filtering(filter_node_t* const root) const; + void do_filtering(filter_node_t* const root, const std::string& collection_name) const; - void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const; + void rearranging_recursive_filter (uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const; void recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root, - const bool enable_short_circuit = false) const; + const std::string& collection_name) const; void adaptive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const filter_tree_root, - const bool enable_short_circuit = false) const; + const std::string& collection_name = "") const; - void get_filter_matches(filter_node_t* const root, std::vector>& vec) const; + void get_filter_matches(filter_node_t* const root, + std::vector>& vec, + const std::string& collection_name) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; @@ -656,7 +662,7 @@ public: // Public operations - void run_search(search_args* search_params); + void run_search(search_args* search_params, const std::string& collection_name); void search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, @@ -679,7 +685,8 @@ public: size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, enable_t split_join_tokens, - const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold) const; + const vector_query_t& vector_query, size_t facet_sample_percent, size_t facet_sample_threshold, + const std::string& collection_name) const; void remove_field(uint32_t seq_id, const nlohmann::json& document, const std::string& field_name); @@ -720,7 +727,8 @@ public: void do_filtering_with_lock( uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t* filter_tree_root) const; + filter_node_t* filter_tree_root, + const std::string& collection_name) const; void do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, diff --git a/src/collection.cpp b/src/collection.cpp index 2e2b8b36..cba0d545 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1449,7 +1449,7 @@ Option Collection::search(const std::string & raw_query, filter_curated_hits, split_join_tokens, vector_query, facet_sample_percent, facet_sample_threshold); - index->run_search(search_params); + index->run_search(search_params, name); // for grouping we have to re-aggregate @@ -2405,7 +2405,7 @@ Option Collection::get_filter_ids(const std::string & filter_query, uint32_t* filter_ids = nullptr; uint32_t filter_ids_len = 0; - index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root); + index->do_filtering_with_lock(filter_ids, filter_ids_len, filter_tree_root, name); index_ids.emplace_back(filter_ids_len, filter_ids); delete filter_tree_root; diff --git a/src/index.cpp b/src/index.cpp index bf4987fd..a14b5274 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1617,7 +1617,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, ids = out; } -void Index::do_filtering(filter_node_t* const root) const { +void Index::do_filtering(filter_node_t* const root, const std::string& collection_name) const { // auto begin = std::chrono::high_resolution_clock::now(); const filter a_filter = root->filter_exp; @@ -1628,7 +1628,7 @@ void Index::do_filtering(filter_node_t* const root) const { auto collection = cm.get_collection(a_filter.referenced_collection_name); auto op = collection->get_reference_filter_ids(a_filter.field_name, - cm.get_collection_with_id(collection_id)->get_name(), + collection_name, root->match_index_ids); if (!op.ok()) { return; @@ -1957,26 +1957,29 @@ void Index::do_filtering(filter_node_t* const root) const { LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ } -void Index::get_filter_matches(filter_node_t* const root, std::vector>& vec) const { +void Index::get_filter_matches(filter_node_t* const root, + std::vector>& vec, + const std::string& collection_name) const { if (root == nullptr) { return; } if (root->isOperator) { if (root->filter_operator == AND) { - get_filter_matches(root->left, vec); - get_filter_matches(root->right, vec); + get_filter_matches(root->left, vec, collection_name); + get_filter_matches(root->right, vec, collection_name); } else { uint32_t *l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { - rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + rearranging_recursive_filter(l_filter_ids, l_filter_ids_length, root->left, collection_name); } uint32_t *r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right); + rearranging_recursive_filter(r_filter_ids, r_filter_ids_length, root->right, collection_name); } root->match_index_ids.first = ArrayUtils::or_scalar( @@ -1992,7 +1995,7 @@ void Index::get_filter_matches(filter_node_t* const root, std::vectormatch_index_ids.first, root); } @@ -2031,9 +2034,12 @@ void evaluate_rearranged_filter_tree(uint32_t*& filter_ids, } } -void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const { +void Index::rearranging_recursive_filter(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const std::string& collection_name) const { std::vector> vec; - get_filter_matches(root, vec); + get_filter_matches(root, vec, collection_name); std::sort(vec.begin(), vec.end(), [](const std::pair& lhs, const std::pair& rhs) { @@ -2050,7 +2056,7 @@ void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter void Index::recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root, - const bool enable_short_circuit) const { + const std::string& collection_name) const { if (root == nullptr) { return; } @@ -2059,15 +2065,13 @@ void Index::recursive_filter(uint32_t*& filter_ids, uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { - recursive_filter(l_filter_ids, l_filter_ids_length, root->left, - enable_short_circuit); + recursive_filter(l_filter_ids, l_filter_ids_length, root->left,collection_name); } uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { - recursive_filter(r_filter_ids, r_filter_ids_length, root->right, - enable_short_circuit); + recursive_filter(r_filter_ids, r_filter_ids_length, root->right,collection_name); } uint32_t* filtered_results = nullptr; @@ -2088,7 +2092,7 @@ void Index::recursive_filter(uint32_t*& filter_ids, return; } - do_filtering(root); + do_filtering(root, collection_name); filter_ids_length = root->match_index_ids.first; filter_ids = root->match_index_ids.second; @@ -2099,7 +2103,7 @@ void Index::recursive_filter(uint32_t*& filter_ids, void Index::adaptive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const filter_tree_root, - const bool enable_short_circuit) const { + const std::string& collection_name) const { if (filter_tree_root == nullptr) { return; } @@ -2109,24 +2113,25 @@ void Index::adaptive_filter(uint32_t*& filter_ids, (*filter_tree_root->metrics).and_operator_count > 0 && // If there are more || in the filter tree than &&, we'll not gain much by rearranging the filter tree. ((float) (*filter_tree_root->metrics).or_operator_count / (float) (*filter_tree_root->metrics).and_operator_count < 0.5)) { - rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root); + rearranging_recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); } else { - recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); + recursive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); } } void Index::do_filtering_with_lock(uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t* filter_tree_root) const { + filter_node_t* filter_tree_root, + const std::string& collection_name) const { std::shared_lock lock(mutex); - adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, false); + adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); } void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, filter_node_t* filter_tree_root, const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); - adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); + adaptive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root); std::vector vector; vector.reserve(reference_index_ids.first); @@ -2142,7 +2147,7 @@ void Index::do_reference_filtering_with_lock(std::pair& ref std::copy(vector.begin(), vector.end(), reference_index_ids.second); } -void Index::run_search(search_args* search_params) { +void Index::run_search(search_args* search_params, const std::string& collection_name) { search(search_params->field_query_tokens, search_params->search_fields, search_params->match_type, @@ -2175,7 +2180,8 @@ void Index::run_search(search_args* search_params) { search_params->split_join_tokens, search_params->vector_query, search_params->facet_sample_percent, - search_params->facet_sample_threshold); + search_params->facet_sample_threshold, + collection_name); } void Index::collate_included_ids(const std::vector& q_included_tokens, @@ -2625,7 +2631,8 @@ void Index::search(std::vector& field_query_tokens, const std::v const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, const enable_t split_join_tokens, const vector_query_t& vector_query, - size_t facet_sample_percent, size_t facet_sample_threshold) const { + size_t facet_sample_percent, size_t facet_sample_threshold, + const std::string& collection_name) const { // process the filters @@ -2634,7 +2641,7 @@ void Index::search(std::vector& field_query_tokens, const std::v std::shared_lock lock(mutex); - adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, true); + adaptive_filter(filter_ids, filter_ids_length, filter_tree_root, collection_name); if (filter_tree_root != nullptr && filter_ids_length == 0) { delete [] filter_ids; @@ -4730,7 +4737,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root, true); + adaptive_filter(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filter_tree_root); } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 7d45523a..ab1936d2 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -299,7 +299,7 @@ TEST_F(CollectionJoinTest, IndexDocumentHavingReferenceField) { collectionManager.drop_collection("Products"); } -TEST_F(CollectionJoinTest, FilterByReferenceField) { +TEST_F(CollectionJoinTest, FilterByReferenceField_SingleMatch) { auto schema_json = R"({ "name": "Products", @@ -404,4 +404,149 @@ TEST_F(CollectionJoinTest, FilterByReferenceField) { ASSERT_EQ(1, result["found"].get()); ASSERT_EQ(1, result["hits"].size()); ASSERT_EQ("soap", result["hits"][0]["document"]["product_name"].get()); + +// collectionManager.drop_collection("Customers"); +// collectionManager.drop_collection("Products"); } + +TEST_F(CollectionJoinTest, FilterByReferenceField_MultipleMatch) { + auto schema_json = + R"({ + "name": "Users", + "fields": [ + {"name": "user_id", "type": "string"}, + {"name": "user_name", "type": "string"} + ] + })"_json; + std::vector documents = { + R"({ + "user_id": "user_a", + "user_name": "Roshan" + })"_json, + R"({ + "user_id": "user_b", + "user_name": "Ruby" + })"_json, + R"({ + "user_id": "user_c", + "user_name": "Joe" + })"_json, + R"({ + "user_id": "user_d", + "user_name": "Aby" + })"_json + }; + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Repos", + "fields": [ + {"name": "repo_id", "type": "string"}, + {"name": "repo_content", "type": "string"} + ] + })"_json; + documents = { + R"({ + "repo_id": "repo_a", + "repo_content": "body1" + })"_json, + R"({ + "repo_id": "repo_b", + "repo_content": "body2" + })"_json, + R"({ + "repo_id": "repo_c", + "repo_content": "body3" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Links", + "fields": [ + {"name": "repo_id", "type": "string", "reference": "Repos.repo_id"}, + {"name": "user_id", "type": "string", "reference": "Users.user_id"} + ] + })"_json; + documents = { + R"({ + "repo_id": "repo_a", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_a", + "user_id": "user_c" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_a" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_d" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_a" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_c" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_d" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + auto coll = collectionManager.get_collection("Users"); + + // Search for users linked to repo_b + auto result = coll->search("R", {"user_name"}, "$Links(repo_id:=repo_b)", {}, {}, {0}, + 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD).get(); + + ASSERT_EQ(2, result["found"].get()); + ASSERT_EQ(2, result["hits"].size()); + ASSERT_EQ("user_b", result["hits"][0]["document"]["user_id"].get()); + ASSERT_EQ("user_a", result["hits"][1]["document"]["user_id"].get()); + +// collectionManager.drop_collection("Users"); +// collectionManager.drop_collection("Repos"); +// collectionManager.drop_collection("Links"); +} \ No newline at end of file