From 5c5f43195c759f9b4255ea7e4f6f7b461fd1d01e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 27 Jan 2023 12:57:13 +0530 Subject: [PATCH] Add `Index::rearranging_recursive_filter`. --- include/field.h | 2 + include/index.h | 20 ++-- src/index.cpp | 156 ++++++++++++++++++++++------- test/collection_filtering_test.cpp | 56 +++++++++++ 4 files changed, 188 insertions(+), 46 deletions(-) diff --git a/include/field.h b/include/field.h index ee69bff5..34d0c6e7 100644 --- a/include/field.h +++ b/include/field.h @@ -536,6 +536,7 @@ struct filter_node_t { bool isOperator; filter_node_t* left; filter_node_t* right; + std::pair match_index_ids; filter_node_t(filter filter_exp) : filter_exp(std::move(filter_exp)), @@ -552,6 +553,7 @@ struct filter_node_t { right(right) {} ~filter_node_t() { + delete[] match_index_ids.second; delete left; delete right; } diff --git a/include/index.h b/include/index.h index a6b161cd..f8690e99 100644 --- a/include/index.h +++ b/include/index.h @@ -99,7 +99,7 @@ struct search_args { std::vector field_query_tokens; std::vector search_fields; const text_match_type_t match_type; - const filter_node_t* filter_tree_root; + filter_node_t* filter_tree_root; std::vector& facets; std::vector>& included_ids; std::vector excluded_ids; @@ -484,14 +484,16 @@ private: uint32_t*& ids, size_t& ids_len) const; - void do_filtering(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t const* const root) const; + void do_filtering(filter_node_t* const root) const; + + void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const; void recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t const* const root, - const bool enable_short_circuit) const; + filter_node_t* const root, + const bool enable_short_circuit = false) const; + + void get_filter_matches(filter_node_t* const root, std::vector>& vec) const; void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; @@ -653,7 +655,7 @@ public: void search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t const* const& filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, @@ -713,10 +715,10 @@ public: void do_filtering_with_lock( uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t const* const& filter_tree_root) const; + filter_node_t* filter_tree_root) const; void do_reference_filtering_with_lock(std::pair& reference_index_ids, - filter_node_t const* const& filter_tree_root, + filter_node_t* filter_tree_root, const std::string& reference_field_name) const; void refresh_schemas(const std::vector& new_fields, const std::vector& del_fields); diff --git a/src/index.cpp b/src/index.cpp index 64ad5a66..5e722fdd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1661,11 +1661,9 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, ids = out; } -void Index::do_filtering(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - filter_node_t const* const root) const { +void Index::do_filtering(filter_node_t* const root) const { // auto begin = std::chrono::high_resolution_clock::now(); -/**/ const filter a_filter = root->filter_exp; + const filter a_filter = root->filter_exp; bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); if (is_referenced_filter) { @@ -1673,16 +1671,12 @@ void Index::do_filtering(uint32_t*& filter_ids, auto& cm = CollectionManager::get_instance(); auto collection = cm.get_collection(a_filter.referenced_collection_name); - std::pair documents; auto op = collection->get_reference_filter_ids(a_filter.field_name, cm.get_collection_with_id(collection_id)->get_name(), - documents); + root->match_index_ids); if (!op.ok()) { return; } - - filter_ids_length = documents.first; - filter_ids = documents.second; return; } @@ -1695,17 +1689,9 @@ void Index::do_filtering(uint32_t*& filter_ids, std::sort(result_ids.begin(), result_ids.end()); - if (filter_ids_length == 0) { - filter_ids = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), filter_ids); - filter_ids_length = result_ids.size(); - } else { - uint32_t* filtered_results = nullptr; - filter_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, &result_ids[0], - result_ids.size(), &filtered_results); - delete[] filter_ids; - filter_ids = filtered_results; - } + root->match_index_ids.second = new uint32[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), root->match_index_ids.second); + root->match_index_ids.first = result_ids.size(); return; } @@ -2005,8 +1991,8 @@ void Index::do_filtering(uint32_t*& filter_ids, result_ids_len = to_include_ids_len; } - filter_ids = result_ids; - filter_ids_length = result_ids_len; + root->match_index_ids.first = result_ids_len; + root->match_index_ids.second = result_ids; /*long long int timeMillis = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() @@ -2015,38 +2001,131 @@ void Index::do_filtering(uint32_t*& filter_ids, LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ } -void Index::recursive_filter(uint32_t*& filter_ids, - uint32_t& filter_ids_length, - const filter_node_t* root, - const bool enable_short_circuit) const { +void Index::get_filter_matches(filter_node_t* const root, std::vector>& vec) const { if (root == nullptr) { return; } + if (root->isOperator && root->filter_operator == OR) { + uint32_t* l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + recursive_filter(l_filter_ids, l_filter_ids_length, root->left); + } + + uint32_t* r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + recursive_filter(r_filter_ids, r_filter_ids_length, root->right); + } + + root->match_index_ids.first = ArrayUtils::or_scalar( + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &(root->match_index_ids.second)); + + delete[] l_filter_ids; + delete[] r_filter_ids; + + vec.emplace_back(root->match_index_ids.first, root); + } else if (root->left == nullptr && root->right == nullptr) { + do_filtering(root); + vec.emplace_back(root->match_index_ids.first, root); + } else { + get_filter_matches(root->left, vec); + get_filter_matches(root->right, vec); + } +} + +void evaluate_filter_tree(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + bool is_rearranged, + std::vector>& vec, + size_t& index) { + if (root == nullptr) { + return; + } + + if (root->isOperator) { + if (root->filter_operator == AND) { + + uint32_t* l_filter_ids = nullptr; + uint32_t l_filter_ids_length = 0; + if (root->left != nullptr) { + evaluate_filter_tree(l_filter_ids, l_filter_ids_length, root->left, is_rearranged, vec, index); + } + + uint32_t* r_filter_ids = nullptr; + uint32_t r_filter_ids_length = 0; + if (root->right != nullptr) { + evaluate_filter_tree(r_filter_ids, r_filter_ids_length, root->right, is_rearranged, vec, index); + } + + root->match_index_ids.first = ArrayUtils::and_scalar( + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &(root->match_index_ids.second)); + + filter_ids_length = root->match_index_ids.first; + filter_ids = root->match_index_ids.second; + } else { + filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first; + filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second; + index++; + } + } else if (root->left == nullptr && root->right == nullptr) { + filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first; + filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second; + index++; + } else { + // malformed + } +} + +void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const { + std::vector> vec; + get_filter_matches(root, vec); + + bool should_rearrange = vec.size() > 2; + if (should_rearrange) { + std::sort(vec.begin(), vec.end(), + [](const std::pair& lhs, const std::pair& rhs) { + return lhs.first < rhs.first; + }); + } + + size_t index = 0; + evaluate_filter_tree(filter_ids, filter_ids_length, root, should_rearrange, vec, index); +} + +void Index::recursive_filter(uint32_t*& filter_ids, + uint32_t& filter_ids_length, + filter_node_t* const root, + const bool enable_short_circuit) const { + if (root == nullptr) { + return; + } uint32_t* l_filter_ids = nullptr; uint32_t l_filter_ids_length = 0; if (root->left != nullptr) { recursive_filter(l_filter_ids, l_filter_ids_length, root->left, enable_short_circuit); } - uint32_t* r_filter_ids = nullptr; uint32_t r_filter_ids_length = 0; if (root->right != nullptr) { recursive_filter(r_filter_ids, r_filter_ids_length, root->right, enable_short_circuit); } - if (root->isOperator) { uint32_t* filtered_results = nullptr; if (root->filter_operator == AND) { filter_ids_length = ArrayUtils::and_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &filtered_results); + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &filtered_results); } else { filter_ids_length = ArrayUtils::or_scalar( - l_filter_ids, l_filter_ids_length, r_filter_ids, - r_filter_ids_length, &filtered_results); + l_filter_ids, l_filter_ids_length, r_filter_ids, + r_filter_ids_length, &filtered_results); } delete[] l_filter_ids; @@ -2054,7 +2133,10 @@ void Index::recursive_filter(uint32_t*& filter_ids, filter_ids = filtered_results; } else if (root->left == nullptr && root->right == nullptr) { - do_filtering(filter_ids, filter_ids_length, root); + do_filtering(root); + filter_ids_length = root->match_index_ids.first; + filter_ids = root->match_index_ids.second; + root->match_index_ids.second = nullptr; } else { // malformed } @@ -2062,13 +2144,13 @@ void Index::recursive_filter(uint32_t*& filter_ids, void Index::do_filtering_with_lock(uint32_t*& filter_ids, uint32_t& filter_ids_length, - filter_node_t const* const& filter_tree_root) const { + filter_node_t* filter_tree_root) const { std::shared_lock lock(mutex); recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false); } void Index::do_reference_filtering_with_lock(std::pair& reference_index_ids, - filter_node_t const* const& filter_tree_root, + filter_node_t* filter_tree_root, const std::string& reference_field_name) const { std::shared_lock lock(mutex); recursive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false); @@ -2077,7 +2159,7 @@ void Index::do_reference_filtering_with_lock(std::pair& ref vector.reserve(reference_index_ids.first); for (uint32_t i = 0; i < reference_index_ids.first; i++) { - auto filtered_doc_id = *(reference_index_ids.second + i); + auto filtered_doc_id = reference_index_ids.second[i]; // Extract the sequence_id from the reference field. vector.push_back(sort_index.at(reference_field_name)->at(filtered_doc_id)); @@ -2550,7 +2632,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name void Index::search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t const* const& filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 39194688..d253bf4e 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -2536,3 +2536,59 @@ TEST_F(CollectionFilteringTest, FilteringAfterUpsertOnArrayWithSymbolsToIndex) { collectionManager.drop_collection("coll1"); } + +TEST_F(CollectionFilteringTest, ComplexFilterQuery) { + nlohmann::json schema_json = + R"({ + "name": "ComplexFilterQueryCollection", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int32"}, + {"name": "years", "type": "int32[]"}, + {"name": "rating", "type": "float"} + ] + })"_json; + + auto op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(op.ok()); + auto coll = op.get(); + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::string json_line; + while (std::getline(infile, json_line)) { + auto add_op = coll->add(json_line); + ASSERT_TRUE(add_op.ok()); + } + infile.close(); + + std::vector sort_fields_desc = {sort_by("rating", "DESC")}; + nlohmann::json results = coll->search("Jeremy", {"name"}, "(rating:>=0 && years:>2000) && age:>50", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["hits"].size()); + + results = coll->search("Jeremy", {"name"}, "(age:>50 || rating:>5) && years:<2000", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(2, results["hits"].size()); + + std::vector ids = {"4", "3"}; + for (size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results = coll->search("Jeremy", {"name"}, "(age:<50 && rating:10) || (years:>2000 && rating:<5)", + {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(1, results["hits"].size()); + + ids = {"0"}; + for (size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + collectionManager.drop_collection("ComplexFilterQueryCollection"); +} \ No newline at end of file