From 3a510ccde5067fcd2b9e1baf3122aa3682c1475e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 12 Dec 2023 19:57:09 +0530 Subject: [PATCH 1/5] Parse nested join. --- src/collection_manager.cpp | 4 ++++ test/collection_manager_test.cpp | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 3939162f..9f407d3c 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -951,6 +951,10 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte reference_collection_names.clear(); return; } + + // Need to process the filter expression inside parenthesis in case of nested join. + auto sub_filter_query = filter_query.substr(open_paren_pos + 1, i - open_paren_pos - 2); + _get_reference_collection_names(sub_filter_query, reference_collection_names); } else { while (i + 1 < size && filter_query[++i] != ':'); if (i >= size) { diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index e5f62ff0..1bf90dfd 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1425,6 +1425,15 @@ TEST_F(CollectionManagerTest, GetReferenceCollectionNames) { ASSERT_EQ(1, reference_collection_names.count(item)); } reference_collection_names.clear(); + + filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; + result = {"product_variants", "inventory", "retailers"}; + CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); + ASSERT_EQ(3, reference_collection_names.size()); + for (const auto &item: result) { + ASSERT_EQ(1, reference_collection_names.count(item)); + } + reference_collection_names.clear(); } TEST_F(CollectionManagerTest, ReferencedInBacklog) { From 4bd1619cd630edc5816077c89c5f4fa4e14c8c8f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 20 Dec 2023 13:11:20 +0530 Subject: [PATCH 2/5] Parse nested reference `include_fields`. --- include/collection_manager.h | 16 ++- include/field.h | 5 +- include/string_utils.h | 4 + src/collection_manager.cpp | 145 ++++++++++++++++++++++--- src/string_utils.cpp | 116 +++++++++++++------- test/collection_manager_test.cpp | 179 +++++++++++++++++++++++++------ test/string_utils_test.cpp | 30 +++++- 7 files changed, 406 insertions(+), 89 deletions(-) diff --git a/include/collection_manager.h b/include/collection_manager.h index c8465fee..874b8055 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -113,6 +113,15 @@ public: CollectionManager(CollectionManager const&) = delete; void operator=(CollectionManager const&) = delete; + struct ref_include_collection_names_t { + std::set collection_names; + ref_include_collection_names_t* nested_include = nullptr; + + ~ref_include_collection_names_t() { + delete nested_include; + } + }; + static Collection* init_collection(const nlohmann::json & collection_meta, const uint32_t collection_next_seq_id, Store* store, @@ -210,7 +219,12 @@ public: Option delete_preset(const std::string & preset_name); static void _get_reference_collection_names(const std::string& filter_query, - std::set& reference_collection_names); + ref_include_collection_names_t*& reference_collection_names); + + // Separate out the reference includes into `ref_include_fields_vec`. + static void _initialize_ref_include_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& ref_include_fields_vec); void add_referenced_in_backlog(const std::string& collection_name, reference_pair&& pair); diff --git a/include/field.h b/include/field.h index 808c1398..d563404d 100644 --- a/include/field.h +++ b/include/field.h @@ -518,7 +518,10 @@ struct ref_include_fields { std::string collection_name; std::string fields; std::string alias; - bool nest_ref_doc = false; + bool nest_ref_doc = true; + + // In case we have nested join. + std::vector nested_join_includes; }; struct hnsw_index_t; diff --git a/include/string_utils.h b/include/string_utils.h index 11f16967..0568a29e 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -337,4 +337,8 @@ struct StringUtils { static Option split_include_fields(const std::string& include_fields, std::vector& tokens); static size_t get_occurence_count(const std::string& str, char symbol); + + static Option split_reference_include_fields(const std::string& include_fields, + size_t& index, + std::string& token); }; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 9f407d3c..7f535745 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -910,7 +910,11 @@ Option add_unsigned_int_list_param(const std::string& param_name, const st } void CollectionManager::_get_reference_collection_names(const std::string& filter_query, - std::set& reference_collection_names) { + ref_include_collection_names_t*& ref_include) { + if (ref_include == nullptr) { + ref_include = new ref_include_collection_names_t(); + } + auto size = filter_query.size(); for (uint32_t i = 0; i < size;) { auto c = filter_query[i]; @@ -918,7 +922,7 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte i++; } else if (c == '&' || c == '|') { if (i + 1 >= size || (c == '&' && filter_query[i+1] != '&') || (c == '|' && filter_query[i+1] != '|')) { - reference_collection_names.clear(); + ref_include->collection_names.clear(); return; } i += 2; @@ -927,14 +931,14 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte if (c == '$') { auto open_paren_pos = filter_query.find('(', ++i); if (open_paren_pos == std::string::npos) { - reference_collection_names.clear(); + ref_include->collection_names.clear(); return; } auto reference_collection_name = filter_query.substr(i, open_paren_pos - i); StringUtils::trim(reference_collection_name); if (!reference_collection_name.empty()) { - reference_collection_names.insert(reference_collection_name); + ref_include->collection_names.insert(reference_collection_name); } i = open_paren_pos; @@ -948,17 +952,19 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte } if (parenthesis_count != 0) { - reference_collection_names.clear(); + ref_include->collection_names.clear(); return; } // Need to process the filter expression inside parenthesis in case of nested join. auto sub_filter_query = filter_query.substr(open_paren_pos + 1, i - open_paren_pos - 2); - _get_reference_collection_names(sub_filter_query, reference_collection_names); + if (sub_filter_query.find('$') != std::string::npos) { + _get_reference_collection_names(sub_filter_query, ref_include->nested_include); + } } else { while (i + 1 < size && filter_query[++i] != ':'); if (i >= size) { - reference_collection_names.clear(); + ref_include->collection_names.clear(); return; } @@ -976,11 +982,99 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte } } -// Separate out the reference includes into `ref_include_fields_vec`. -void initialize_ref_include_fields_vec(const std::string& filter_query, std::vector& include_fields_vec, - std::vector& ref_include_fields_vec) { - std::set reference_collection_names; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); +Option parse_nested_include(const std::string& include_field_exp, + CollectionManager::ref_include_collection_names_t* const ref_include_coll_names, + std::vector& ref_include_fields_vec) { + // Format: $ref_collection_name(field_1, field_2, $nested_ref_coll(nested_field_1: nested_include_strategy) as nested_ref_alias: include_strategy) as ref_alias + size_t index = 0; + while (index < include_field_exp.size()) { + auto parenthesis_index = include_field_exp.find('('); + auto ref_collection_name = include_field_exp.substr(index + 1, parenthesis_index - index - 1); + bool nest_ref_doc = true; + std::string ref_fields, ref_alias; + + index = parenthesis_index + 1; + auto nested_include_pos = include_field_exp.find('$', parenthesis_index); + auto closing_parenthesis_pos = include_field_exp.find(')', parenthesis_index); + auto colon_pos = include_field_exp.find(':', index); + size_t comma_pos; + std::vector nested_ref_include_fields_vec; + if (nested_include_pos < closing_parenthesis_pos) { + // Nested reference include. + // "... $product_variants(title, $inventory(qty:merge) as inventory :nest) as variants ..." + do { + ref_fields += include_field_exp.substr(index, nested_include_pos - index); + StringUtils::trim(ref_fields); + index = nested_include_pos; + std::string nested_include_field_exp; + auto split_op = StringUtils::split_reference_include_fields(include_field_exp, index, + nested_include_field_exp); + if (!split_op.ok()) { + return split_op; + } + + auto parse_op = parse_nested_include(nested_include_field_exp, + ref_include_coll_names == nullptr ? nullptr : ref_include_coll_names->nested_include, + nested_ref_include_fields_vec); + if (!parse_op.ok()) { + return parse_op; + } + + nested_include_pos = include_field_exp.find('$', index); + closing_parenthesis_pos = include_field_exp.find(')', index); + colon_pos = include_field_exp.find(':', index); + comma_pos = include_field_exp.find(',', index); + index = std::min(std::min(closing_parenthesis_pos, colon_pos), comma_pos) + 1; + } while(index < include_field_exp.size() && nested_include_pos < closing_parenthesis_pos); + } + + // ... $inventory(qty:merge) as inventory ... + if (colon_pos < closing_parenthesis_pos) { // Merge strategy is specified. + auto include_strategy = include_field_exp.substr(colon_pos + 1, closing_parenthesis_pos - colon_pos - 1); + StringUtils::trim(include_strategy); + nest_ref_doc = include_strategy == ref_include::nest; + + if (index < colon_pos) { + ref_fields += include_field_exp.substr(index, colon_pos - index); + } + } else if (index < closing_parenthesis_pos) { + ref_fields += include_field_exp.substr(index, closing_parenthesis_pos - index); + } + StringUtils::trim(ref_fields); + + auto as_pos = include_field_exp.find(" as ", index); + comma_pos = include_field_exp.find(',', index); + if (as_pos != std::string::npos && as_pos < comma_pos) { + ref_alias = include_field_exp.substr(as_pos + 4, comma_pos - as_pos - 4); + } + + // For an alias `foo`, + // In case of "merge" reference doc, we need append `foo.` to all the top level keys of reference doc. + // In case of "nest" reference doc, `foo` becomes the key with reference doc as value. + ref_alias = !ref_alias.empty() ? (StringUtils::trim(ref_alias) + (nest_ref_doc ? "" : ".")) : ""; + + ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, nest_ref_doc}); + ref_include_fields_vec.back().nested_join_includes = std::move(nested_ref_include_fields_vec); + + // Referenced collection in filter_by is already mentioned in ref_include_fields. + if (ref_include_coll_names != nullptr) { + ref_include_coll_names->collection_names.erase(ref_collection_name); + } + if (comma_pos == std::string::npos) { + break; + } + index = comma_pos + 1; + } + + return Option(true); +} + +void CollectionManager::_initialize_ref_include_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& ref_include_fields_vec) { + ref_include_collection_names_t* ref_include_coll_names = nullptr; + CollectionManager::_get_reference_collection_names(filter_query, ref_include_coll_names); + std::unique_ptr guard(ref_include_coll_names); std::vector result_include_fields_vec; auto wildcard_include_all = true; @@ -995,6 +1089,12 @@ void initialize_ref_include_fields_vec(const std::string& filter_query, std::vec continue; } + // Nested reference include. + if (include_field_exp.find('$', 1) != std::string::npos) { + auto parse_op = parse_nested_include(include_field_exp, ref_include_coll_names, ref_include_fields_vec); + continue; + } + // Format: $ref_collection_name(field_1, field_2: include_strategy) as ref_alias auto as_pos = include_field_exp.find(" as "); auto ref_include = include_field_exp.substr(0, as_pos); @@ -1031,13 +1131,24 @@ void initialize_ref_include_fields_vec(const std::string& filter_query, std::vec continue; } - // Referenced collection in filter_query is already mentioned in ref_include_fields. - reference_collection_names.erase(reference_collection_name); + // Referenced collection in filter_by is already mentioned in ref_include_fields. + if (ref_include_coll_names != nullptr) { + ref_include_coll_names->collection_names.erase(reference_collection_name); + } } // Get all the fields of the referenced collection in the filter but not mentioned in include_fields. - for (const auto &reference_collection_name: reference_collection_names) { - ref_include_fields_vec.emplace_back(ref_include_fields{reference_collection_name, "", "", true}); + auto ref_includes = std::ref(ref_include_fields_vec); + while (ref_include_coll_names != nullptr) { + for (const auto &reference_collection_name: ref_include_coll_names->collection_names) { + ref_includes.get().emplace_back(ref_include_fields{reference_collection_name, "", "", true}); + } + + ref_include_coll_names = ref_include_coll_names->nested_include; + if (ref_includes.get().empty()) { + break; + } + ref_includes = std::ref(ref_includes.get().front().nested_join_includes); } // Since no field of the collection is mentioned in include_fields, get all the fields. @@ -1431,7 +1542,7 @@ Option CollectionManager::do_search(std::map& re per_page = 0; } - initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + _initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); include_fields.insert(include_fields_vec.begin(), include_fields_vec.end()); exclude_fields.insert(exclude_fields_vec.begin(), exclude_fields_vec.end()); diff --git a/src/string_utils.cpp b/src/string_utils.cpp index 109c5d09..cd74e96a 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -492,49 +492,87 @@ Option StringUtils::tokenize_filter_query(const std::string& filter_query, return Option(true); } +Option StringUtils::split_reference_include_fields(const std::string& include_fields, + size_t& index, + std::string& token) { + auto ref_include_error = Option(400, "Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`."); + auto const& size = include_fields.size(); + size_t start_index = index; + while(++index < size && include_fields[index] != '(') {} + + if (index >= size) { + return ref_include_error; + } + + // In case of nested join, the reference include field could have parenthesis inside it. + int parenthesis_count = 1; + while (++index < size && parenthesis_count > 0) { + if (include_fields[index] == '(') { + parenthesis_count++; + } else if (include_fields[index] == ')') { + parenthesis_count--; + } + } + + if (parenthesis_count != 0) { + return ref_include_error; + } + + // In case of nested reference include, we might end up with one of the following scenarios: + // $ref_include( $nested_ref_include(foo :merge)as nest ) as ref + // ...^ + // $ref_include( $nested_ref_include(foo :merge)as nest, bar ) as ref + // ...^ + // $ref_include( $nested_ref_include(foo :merge)as nest :merge ) as ref + // ...^ + auto closing_parenthesis_pos = include_fields.find(')', index); + auto comma_pos = include_fields.find(',', index); + auto colon_pos = include_fields.find(':', index); + auto alias_start_pos = include_fields.find(" as ", index); + auto alias_end_pos = std::min(std::min(closing_parenthesis_pos, comma_pos), colon_pos); + std::string alias; + if (alias_start_pos != std::string::npos && alias_start_pos < alias_end_pos) { + alias = include_fields.substr(alias_start_pos, alias_end_pos - alias_start_pos); + } + + token = include_fields.substr(start_index, index - start_index) + " " + trim(alias); + trim(token); + + index = alias_end_pos; + return Option(true); +} + Option StringUtils::split_include_fields(const std::string& include_fields, std::vector& tokens) { - size_t start = 0, end = 0, size = include_fields.size(); - std::string include_field; - - while (true) { - auto range_pos = include_fields.find('$', start); - auto comma_pos = include_fields.find(',', start); - - if (range_pos == std::string::npos && comma_pos == std::string::npos) { - if (start < size - 1) { - include_field = include_fields.substr(start, size - start); - include_field = trim(include_field); - if (!include_field.empty()) { - tokens.push_back(include_field); - } + std::string token; + auto const& size = include_fields.size(); + for (size_t i = 0; i < size;) { + auto c = include_fields[i]; + if (c == ' ') { + i++; + continue; + } else if (c == '$') { // Reference include + std::string ref_include_token; + auto split_op = split_reference_include_fields(include_fields, i, ref_include_token); + if (!split_op.ok()) { + return split_op; } + + tokens.push_back(ref_include_token); + continue; + } + + auto comma_pos = include_fields.find(',', i); + token = include_fields.substr(i, (comma_pos == std::string::npos ? size : comma_pos) - i); + trim(token); + if (!token.empty()) { + tokens.push_back(token); + } + + if (comma_pos == std::string::npos) { break; - } else if (range_pos < comma_pos) { - end = include_fields.find(')', range_pos); - if (end == std::string::npos || end < include_fields.find('(', range_pos)) { - return Option(400, "Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`."); - } - - include_field = include_fields.substr(range_pos, (end - range_pos) + 1); - - comma_pos = include_fields.find(',', end); - auto as_pos = include_fields.find(" as ", end); - if (as_pos != std::string::npos && as_pos < comma_pos) { - auto alias = include_fields.substr(as_pos, (comma_pos - as_pos)); - end += alias.size() + 1; - include_field += (" " + trim(alias)); - } - } else { - end = comma_pos; - include_field = include_fields.substr(start, end - start); } - - include_field = trim(include_field); - if (!include_field.empty()) { - tokens.push_back(include_field); - } - - start = end + 1; + i = comma_pos + 1; + token.clear(); } return Option(true); diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 1bf90dfd..e03c4e88 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1368,13 +1368,19 @@ TEST_F(CollectionManagerTest, CloneCollection) { TEST_F(CollectionManagerTest, GetReferenceCollectionNames) { std::string filter_query = ""; - std::set reference_collection_names; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_TRUE(reference_collection_names.empty()); + CollectionManager::ref_include_collection_names_t* ref_includes = nullptr; + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_TRUE(ref_includes->collection_names.empty()); + ASSERT_EQ(nullptr, ref_includes->nested_include); + delete ref_includes; + ref_includes = nullptr; filter_query = "foo"; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_TRUE(reference_collection_names.empty()); + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_TRUE(ref_includes->collection_names.empty()); + ASSERT_EQ(nullptr, ref_includes->nested_include); + delete ref_includes; + ref_includes = nullptr; nlohmann::json schema = R"({ "name": "coll1", @@ -1400,40 +1406,153 @@ TEST_F(CollectionManagerTest, GetReferenceCollectionNames) { ASSERT_EQ(search_op_bool.error(), "Could not parse the filter query."); filter_query = "foo:bar"; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_TRUE(reference_collection_names.empty()); + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_TRUE(ref_includes->collection_names.empty()); + ASSERT_EQ(nullptr, ref_includes->nested_include); + delete ref_includes; + ref_includes = nullptr; filter_query = "$foo(bar:baz) & age: <5"; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_TRUE(reference_collection_names.empty()); + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_TRUE(ref_includes->collection_names.empty()); + ASSERT_EQ(nullptr, ref_includes->nested_include); + delete ref_includes; + ref_includes = nullptr; filter_query = "$foo(bar:baz)"; - std::vector result = {"foo"}; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_EQ(1, reference_collection_names.size()); - for (const auto &item: result) { - ASSERT_EQ(1, reference_collection_names.count(item)); - } - reference_collection_names.clear(); + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_EQ(1, ref_includes->collection_names.size()); + ASSERT_EQ(1, ref_includes->collection_names.count("foo")); + ASSERT_EQ(nullptr, ref_includes->nested_include); + delete ref_includes; + ref_includes = nullptr; filter_query = "((age: <5 || age: >10) && category:= [shoes]) &&" " $Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; - result = {"Customers"}; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_EQ(1, reference_collection_names.size()); - for (const auto &item: result) { - ASSERT_EQ(1, reference_collection_names.count(item)); - } - reference_collection_names.clear(); + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_EQ(1, ref_includes->collection_names.size()); + ASSERT_EQ(1, ref_includes->collection_names.count("Customers")); + ASSERT_EQ(nullptr, ref_includes->nested_include); + delete ref_includes; + ref_includes = nullptr; filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; - result = {"product_variants", "inventory", "retailers"}; - CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); - ASSERT_EQ(3, reference_collection_names.size()); - for (const auto &item: result) { - ASSERT_EQ(1, reference_collection_names.count(item)); - } - reference_collection_names.clear(); + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_EQ(1, ref_includes->collection_names.size()); + ASSERT_EQ(1, ref_includes->collection_names.count("product_variants")); + ASSERT_EQ(1, ref_includes->nested_include->collection_names.size()); + ASSERT_EQ(1, ref_includes->nested_include->collection_names.count("inventory")); + ASSERT_EQ(1, ref_includes->nested_include->nested_include->collection_names.size()); + ASSERT_EQ(1, ref_includes->nested_include->nested_include->collection_names.count("retailers")); + ASSERT_EQ(nullptr, ref_includes->nested_include->nested_include->nested_include); + delete ref_includes; + ref_includes = nullptr; + + filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; + CollectionManager::_get_reference_collection_names(filter_query, ref_includes); + ASSERT_EQ(1, ref_includes->collection_names.size()); + ASSERT_EQ(1, ref_includes->collection_names.count("product_variants")); + ASSERT_EQ(2, ref_includes->nested_include->collection_names.size()); + ASSERT_EQ(1, ref_includes->nested_include->collection_names.count("inventory")); + ASSERT_EQ(1, ref_includes->nested_include->collection_names.count("retailers")); + ASSERT_EQ(nullptr, ref_includes->nested_include->nested_include); + delete ref_includes; + ref_includes = nullptr; +} + +TEST_F(CollectionManagerTest, InitializeRefIncludeFields) { + std::string filter_query = ""; + std::vector include_fields_vec; + std::vector ref_include_fields_vec; + CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + ASSERT_TRUE(ref_include_fields_vec.empty()); + + filter_query = "$foo(bar:baz)"; + CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("foo", ref_include_fields_vec[0].collection_name); + ASSERT_TRUE(ref_include_fields_vec[0].fields.empty()); + ASSERT_TRUE(ref_include_fields_vec[0].alias.empty()); + ASSERT_TRUE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); + ref_include_fields_vec.clear(); + + filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; + include_fields_vec = {"$Customers(product_price: merge) as customers"}; + CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); + ASSERT_EQ("customers.", ref_include_fields_vec[0].alias); + ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); + ref_include_fields_vec.clear(); + + filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; + include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory: nest) as variants"}; + CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("title,", ref_include_fields_vec[0].fields); + ASSERT_EQ("variants", ref_include_fields_vec[0].alias); + ASSERT_TRUE(ref_include_fields_vec[0].nest_ref_doc); + + auto nested_join_includes = ref_include_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_join_includes[0].collection_name); + ASSERT_EQ("qty", nested_join_includes[0].fields); + ASSERT_EQ("inventory.", nested_join_includes[0].alias); + ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + + nested_join_includes = ref_include_fields_vec[0].nested_join_includes[0].nested_join_includes; + ASSERT_EQ("retailers", nested_join_includes[0].collection_name); + ASSERT_TRUE(nested_join_includes[0].fields.empty()); + ASSERT_TRUE(nested_join_includes[0].alias.empty()); + ASSERT_TRUE(nested_join_includes[0].nest_ref_doc); + ref_include_fields_vec.clear(); + + filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; + include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory," + " $retailers(title): merge) as variants"}; + CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("title,", ref_include_fields_vec[0].fields); + ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); + ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + + nested_join_includes = ref_include_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_join_includes[0].collection_name); + ASSERT_EQ("qty", nested_join_includes[0].fields); + ASSERT_EQ("inventory.", nested_join_includes[0].alias); + ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + + ASSERT_EQ("retailers", nested_join_includes[1].collection_name); + ASSERT_EQ("title", nested_join_includes[1].fields); + ASSERT_TRUE(nested_join_includes[1].alias.empty()); + ASSERT_TRUE(nested_join_includes[1].nest_ref_doc); + ref_include_fields_vec.clear(); + + filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; + include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory, description," + " $retailers(title), foo: merge) as variants"}; + CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("title, description, foo", ref_include_fields_vec[0].fields); + ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); + ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + + nested_join_includes = ref_include_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_join_includes[0].collection_name); + ASSERT_EQ("qty", nested_join_includes[0].fields); + ASSERT_EQ("inventory.", nested_join_includes[0].alias); + ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + + ASSERT_EQ("retailers", nested_join_includes[1].collection_name); + ASSERT_EQ("title", nested_join_includes[1].fields); + ASSERT_TRUE(nested_join_includes[1].alias.empty()); + ASSERT_TRUE(nested_join_includes[1].nest_ref_doc); + ref_include_fields_vec.clear(); } TEST_F(CollectionManagerTest, ReferencedInBacklog) { diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index bd33a169..44e39bb9 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -412,10 +412,14 @@ TEST(StringUtilsTest, SplitIncludeFields) { std::string include_fields; std::vector tokens; - include_fields = "id, title, count"; + include_fields = " id, title , count "; tokens = {"id", "title", "count"}; splitIncludeTestHelper(include_fields, tokens); + include_fields = "id, $Collection(title, pref*),count"; + tokens = {"id", "$Collection(title, pref*)", "count"}; + splitIncludeTestHelper(include_fields, tokens); + include_fields = "id, $Collection(title, pref*), count, "; tokens = {"id", "$Collection(title, pref*)", "count"}; splitIncludeTestHelper(include_fields, tokens); @@ -427,4 +431,28 @@ TEST(StringUtilsTest, SplitIncludeFields) { include_fields = "id, $Collection(title, pref*) as coll , count, "; tokens = {"id", "$Collection(title, pref*) as coll", "count"}; splitIncludeTestHelper(include_fields, tokens); + + include_fields = "$Collection(title, pref*: merge) as coll"; + tokens = {"$Collection(title, pref*: merge) as coll"}; + splitIncludeTestHelper(include_fields, tokens); + + include_fields = "$product_variants(id,$inventory(qty,sku,$retailer(id,title: merge) as retailer_info)) as variants"; + tokens = {"$product_variants(id,$inventory(qty,sku,$retailer(id,title: merge) as retailer_info)) as variants"}; + splitIncludeTestHelper(include_fields, tokens); +} + +TEST(StringUtilsTest, SplitReferenceIncludeFields) { + std::string include_fields = "$retailer(id,title: merge) as retailer_info:merge) as variants, foo", token; + size_t index = 0; + auto tokenize_op = StringUtils::split_reference_include_fields(include_fields, index, token); + ASSERT_TRUE(tokenize_op.ok()); + ASSERT_EQ("$retailer(id,title: merge) as retailer_info", token); + ASSERT_EQ(":merge) as variants, foo", include_fields.substr(index)); + + include_fields = "$inventory(qty,sku,$retailer(id,title: merge) as retailer_info) as inventory) as variants, foo"; + index = 0; + tokenize_op = StringUtils::split_reference_include_fields(include_fields, index, token); + ASSERT_TRUE(tokenize_op.ok()); + ASSERT_EQ("$inventory(qty,sku,$retailer(id,title: merge) as retailer_info) as inventory", token); + ASSERT_EQ(") as variants, foo", include_fields.substr(index)); } From feb1041f656d61490bd81556d40f08943b7206ed Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 26 Dec 2023 17:09:37 +0530 Subject: [PATCH 3/5] Add `nest_array` reference include strategy. --- include/collection.h | 2 +- include/collection_manager.h | 6 +-- include/field.h | 24 ++++++++++-- src/collection.cpp | 18 ++++----- src/collection_manager.cpp | 52 ++++++++++++++++++------- test/collection_join_test.cpp | 28 +++++++++++++ test/collection_manager_test.cpp | 67 ++++++++++++++++++++++++-------- 7 files changed, 150 insertions(+), 47 deletions(-) diff --git a/include/collection.h b/include/collection.h index bb0f581a..dfaa4f38 100644 --- a/include/collection.h +++ b/include/collection.h @@ -446,7 +446,7 @@ public: const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, const std::string& error_prefix, const bool& is_reference_array, - const bool& nest_ref_doc); + const ref_include::strategy_enum& strategy); static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, const std::string& parent_name = "", diff --git a/include/collection_manager.h b/include/collection_manager.h index 874b8055..889391f5 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -222,9 +222,9 @@ public: ref_include_collection_names_t*& reference_collection_names); // Separate out the reference includes into `ref_include_fields_vec`. - static void _initialize_ref_include_fields_vec(const std::string& filter_query, - std::vector& include_fields_vec, - std::vector& ref_include_fields_vec); + static Option _initialize_ref_include_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& ref_include_fields_vec); void add_referenced_in_backlog(const std::string& collection_name, reference_pair&& pair); diff --git a/include/field.h b/include/field.h index d563404d..b06e07c3 100644 --- a/include/field.h +++ b/include/field.h @@ -510,18 +510,34 @@ namespace sort_field_const { } namespace ref_include { - static const std::string merge = "merge"; - static const std::string nest = "nest"; + static const std::string merge_string = "merge"; + static const std::string nest_string = "nest"; + static const std::string nest_array_string = "nest_array"; + + enum strategy_enum {merge = 0, nest, nest_array}; + + static Option string_to_enum(const std::string& strategy) { + if (strategy == merge_string) { + return Option(merge); + } else if (strategy == nest_string) { + return Option(nest); + } else if (strategy == nest_array_string) { + return Option(nest_array); + } + + return Option(400, "Unknown include strategy `" + strategy + "`. " + "Valid options are `merge`, `nest`, `nest_array`."); + } } struct ref_include_fields { std::string collection_name; std::string fields; std::string alias; - bool nest_ref_doc = true; + ref_include::strategy_enum strategy = ref_include::nest; // In case we have nested join. - std::vector nested_join_includes; + std::vector nested_join_includes = {}; }; struct hnsw_index_t; diff --git a/src/collection.cpp b/src/collection.cpp index e2564a40..83131a6b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -4870,9 +4870,9 @@ Option Collection::include_references(nlohmann::json& doc, const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, const std::string& error_prefix, const bool& is_reference_array, - const bool& nest_ref_doc) { + const ref_include::strategy_enum& strategy) { // One-to-one relation. - if (!is_reference_array && references.count == 1) { + if (strategy != ref_include::nest_array && !is_reference_array && references.count == 1) { auto ref_doc_seq_id = references.docs[0]; nlohmann::json ref_doc; @@ -4893,7 +4893,7 @@ Option Collection::include_references(nlohmann::json& doc, return Option(true); } - if (nest_ref_doc) { + if (strategy == ref_include::nest) { auto key = alias.empty() ? ref_collection_name : alias; doc[key] = ref_doc; } else { @@ -4931,7 +4931,7 @@ Option Collection::include_references(nlohmann::json& doc, continue; } - if (nest_ref_doc) { + if (strategy == ref_include::nest || strategy == ref_include::nest_array) { auto key = alias.empty() ? ref_collection_name : alias; if (doc.contains(key) && !doc[key].is_array()) { return Option(400, "Could not include the reference document of `" + ref_collection_name + @@ -5116,7 +5116,7 @@ Option Collection::prune_doc(nlohmann::json& doc, reference_filter_results.at(ref_collection_name), ref_include_fields_full, ref_exclude_fields_full, error_prefix, ref_collection->get_schema().at(field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); } else if (doc_has_reference) { auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection->name); if (!get_reference_field_op.ok()) { @@ -5151,7 +5151,7 @@ Option Collection::prune_doc(nlohmann::json& doc, include_references_op = include_references(doc[keys[0]][i], ref_include.collection_name, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, - false, ref_include.nest_ref_doc); + false, ref_include.strategy); if (!include_references_op.ok()) { return include_references_op; } @@ -5167,7 +5167,7 @@ Option Collection::prune_doc(nlohmann::json& doc, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, collection->search_schema.at(field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); result.docs = nullptr; } } else { @@ -5181,7 +5181,7 @@ Option Collection::prune_doc(nlohmann::json& doc, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, collection->search_schema.at(field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); result.docs = nullptr; } } else if (joined_coll_has_reference) { @@ -5217,7 +5217,7 @@ Option Collection::prune_doc(nlohmann::json& doc, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, joined_collection->get_schema().at(reference_field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); result.docs = nullptr; } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 7f535745..8151c095 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1029,10 +1029,17 @@ Option parse_nested_include(const std::string& include_field_exp, } // ... $inventory(qty:merge) as inventory ... + auto include_strategy = ref_include::nest_string; + auto strategy_enum = ref_include::nest; if (colon_pos < closing_parenthesis_pos) { // Merge strategy is specified. - auto include_strategy = include_field_exp.substr(colon_pos + 1, closing_parenthesis_pos - colon_pos - 1); + include_strategy = include_field_exp.substr(colon_pos + 1, closing_parenthesis_pos - colon_pos - 1); StringUtils::trim(include_strategy); - nest_ref_doc = include_strategy == ref_include::nest; + + auto string_to_enum_op = ref_include::string_to_enum(include_strategy); + if (!string_to_enum_op.ok()) { + return Option(400, "Error parsing `" + include_field_exp + "`: " + string_to_enum_op.error()); + } + strategy_enum = string_to_enum_op.get(); if (index < colon_pos) { ref_fields += include_field_exp.substr(index, colon_pos - index); @@ -1051,9 +1058,11 @@ Option parse_nested_include(const std::string& include_field_exp, // For an alias `foo`, // In case of "merge" reference doc, we need append `foo.` to all the top level keys of reference doc. // In case of "nest" reference doc, `foo` becomes the key with reference doc as value. + nest_ref_doc = strategy_enum == ref_include::nest || strategy_enum == ref_include::nest_array; ref_alias = !ref_alias.empty() ? (StringUtils::trim(ref_alias) + (nest_ref_doc ? "" : ".")) : ""; - ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, nest_ref_doc}); + ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, + strategy_enum}); ref_include_fields_vec.back().nested_join_includes = std::move(nested_ref_include_fields_vec); // Referenced collection in filter_by is already mentioned in ref_include_fields. @@ -1069,16 +1078,16 @@ Option parse_nested_include(const std::string& include_field_exp, return Option(true); } -void CollectionManager::_initialize_ref_include_fields_vec(const std::string& filter_query, - std::vector& include_fields_vec, - std::vector& ref_include_fields_vec) { +Option CollectionManager::_initialize_ref_include_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& ref_include_fields_vec) { ref_include_collection_names_t* ref_include_coll_names = nullptr; CollectionManager::_get_reference_collection_names(filter_query, ref_include_coll_names); std::unique_ptr guard(ref_include_coll_names); std::vector result_include_fields_vec; auto wildcard_include_all = true; - for (auto include_field_exp: include_fields_vec) { + for (auto const& include_field_exp: include_fields_vec) { if (include_field_exp[0] != '$') { if (include_field_exp == "*") { continue; @@ -1092,6 +1101,9 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi // Nested reference include. if (include_field_exp.find('$', 1) != std::string::npos) { auto parse_op = parse_nested_include(include_field_exp, ref_include_coll_names, ref_include_fields_vec); + if (!parse_op.ok()) { + return parse_op; + } continue; } @@ -1105,20 +1117,29 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi auto ref_collection_name = ref_include.substr(1, parenthesis_index - 1); auto ref_fields = ref_include.substr(parenthesis_index + 1, ref_include.size() - parenthesis_index - 2); - auto nest_ref_doc = true; + auto include_strategy = ref_include::nest_string; + auto strategy_enum = ref_include::nest; auto colon_pos = ref_fields.find(':'); if (colon_pos != std::string::npos) { - auto include_strategy = ref_fields.substr(colon_pos + 1, ref_fields.size() - colon_pos - 1); + include_strategy = ref_fields.substr(colon_pos + 1, ref_fields.size() - colon_pos - 1); StringUtils::trim(include_strategy); - nest_ref_doc = include_strategy == ref_include::nest; + + auto string_to_enum_op = ref_include::string_to_enum(include_strategy); + if (!string_to_enum_op.ok()) { + return Option(400, "Error parsing `" + include_field_exp + "`: " + string_to_enum_op.error()); + } + strategy_enum = string_to_enum_op.get(); + ref_fields = ref_fields.substr(0, colon_pos); } // For an alias `foo`, // In case of "merge" reference doc, we need append `foo.` to all the top level keys of reference doc. // In case of "nest" reference doc, `foo` becomes the key with reference doc as value. + auto const& nest_ref_doc = strategy_enum == ref_include::nest || strategy_enum == ref_include::nest_array; auto ref_alias = !alias.empty() ? (StringUtils::trim(alias) + (nest_ref_doc ? "" : ".")) : ""; - ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, nest_ref_doc}); + ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, + strategy_enum}); auto open_paren_pos = include_field_exp.find('('); if (open_paren_pos == std::string::npos) { @@ -1141,7 +1162,7 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi auto ref_includes = std::ref(ref_include_fields_vec); while (ref_include_coll_names != nullptr) { for (const auto &reference_collection_name: ref_include_coll_names->collection_names) { - ref_includes.get().emplace_back(ref_include_fields{reference_collection_name, "", "", true}); + ref_includes.get().emplace_back(ref_include_fields{reference_collection_name, "", "", ref_include::nest}); } ref_include_coll_names = ref_include_coll_names->nested_include; @@ -1157,6 +1178,8 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi } include_fields_vec = std::move(result_include_fields_vec); + + return Option(true); } Option CollectionManager::do_search(std::map& req_params, @@ -1542,7 +1565,10 @@ Option CollectionManager::do_search(std::map& re per_page = 0; } - _initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + auto initialize_op = _initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + if (!initialize_op.ok()) { + return initialize_op; + } include_fields.insert(include_fields_vec.begin(), include_fields_vec.end()); exclude_fields.insert(exclude_fields_vec.begin(), exclude_fields_vec.end()); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index a7df7d27..f254a00a 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -1965,6 +1965,34 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"].count("product_id")); ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"].count("product_price")); + req_params = { + {"collection", "Products"}, + {"q", "*"}, + {"query_by", "product_name"}, + {"filter_by", "$Customers(customer_id:=customer_a && product_price:<100)"}, + {"include_fields", "*, $Customers(*:nest_array) as Customers"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + // No fields are mentioned in `include_fields`, should include all fields of Products and Customers by default. + ASSERT_EQ(7, res_obj["hits"][0]["document"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_description")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("embedding")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("rating")); + // In nest_array strategy we return the referenced docs in an array. + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("customer_id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("customer_name")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("product_id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("product_price")); + req_params = { {"collection", "Products"}, {"q", "*"}, diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index e03c4e88..a4178dc2 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1464,94 +1464,127 @@ TEST_F(CollectionManagerTest, InitializeRefIncludeFields) { std::string filter_query = ""; std::vector include_fields_vec; std::vector ref_include_fields_vec; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + auto initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_TRUE(ref_include_fields_vec.empty()); filter_query = "$foo(bar:baz)"; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("foo", ref_include_fields_vec[0].collection_name); ASSERT_TRUE(ref_include_fields_vec[0].fields.empty()); ASSERT_TRUE(ref_include_fields_vec[0].alias.empty()); - ASSERT_TRUE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); ref_include_fields_vec.clear(); + filter_query = ""; + include_fields_vec = {"$Customers(product_price: foo) as customers"}; + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_FALSE(initialize_op.ok()); + ASSERT_EQ("Error parsing `$Customers(product_price: foo) as customers`: Unknown include strategy `foo`. " + "Valid options are `merge`, `nest`, `nest_array`.", initialize_op.error()); + filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; include_fields_vec = {"$Customers(product_price: merge) as customers"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); ASSERT_EQ("customers.", ref_include_fields_vec[0].alias); - ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); + ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); + ref_include_fields_vec.clear(); + + filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; + include_fields_vec = {"$Customers(product_price: nest_array) as customers"}; + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); + ASSERT_EQ("customers", ref_include_fields_vec[0].alias); + ASSERT_EQ(ref_include::nest_array, ref_include_fields_vec[0].strategy); ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); ref_include_fields_vec.clear(); filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory: nest) as variants"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); ASSERT_EQ("title,", ref_include_fields_vec[0].fields); ASSERT_EQ("variants", ref_include_fields_vec[0].alias); - ASSERT_TRUE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); auto nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); nested_join_includes = ref_include_fields_vec[0].nested_join_includes[0].nested_join_includes; ASSERT_EQ("retailers", nested_join_includes[0].collection_name); ASSERT_TRUE(nested_join_includes[0].fields.empty()); ASSERT_TRUE(nested_join_includes[0].alias.empty()); - ASSERT_TRUE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); ref_include_fields_vec.clear(); filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory," " $retailers(title): merge) as variants"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); ASSERT_EQ("title,", ref_include_fields_vec[0].fields); ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); - ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); ASSERT_EQ("retailers", nested_join_includes[1].collection_name); ASSERT_EQ("title", nested_join_includes[1].fields); ASSERT_TRUE(nested_join_includes[1].alias.empty()); - ASSERT_TRUE(nested_join_includes[1].nest_ref_doc); + ASSERT_EQ(ref_include::nest, nested_join_includes[1].strategy); ref_include_fields_vec.clear(); filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory, description," " $retailers(title), foo: merge) as variants"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); ASSERT_EQ("title, description, foo", ref_include_fields_vec[0].fields); ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); - ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); ASSERT_EQ("retailers", nested_join_includes[1].collection_name); ASSERT_EQ("title", nested_join_includes[1].fields); ASSERT_TRUE(nested_join_includes[1].alias.empty()); - ASSERT_TRUE(nested_join_includes[1].nest_ref_doc); + ASSERT_EQ(ref_include::nest, nested_join_includes[1].strategy); ref_include_fields_vec.clear(); } From a11d0f7a652fa5bf2e932f698bc9ab5ca7b7356f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 28 Dec 2023 15:12:45 +0530 Subject: [PATCH 4/5] Fix `parse_nested_include`. --- src/collection_manager.cpp | 1 + test/collection_manager_test.cpp | 26 +++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 8151c095..1758a69e 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1049,6 +1049,7 @@ Option parse_nested_include(const std::string& include_field_exp, } StringUtils::trim(ref_fields); + index = closing_parenthesis_pos; auto as_pos = include_field_exp.find(" as ", index); comma_pos = include_field_exp.find(',', index); if (as_pos != std::string::npos && as_pos < comma_pos) { diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index a4178dc2..4a7da68a 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1515,6 +1515,30 @@ TEST_F(CollectionManagerTest, InitializeRefIncludeFields) { ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); ref_include_fields_vec.clear(); + filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; + include_fields_vec = {"$product_variants(id,$inventory(qty,sku,$retailers(id,title)))"}; + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("id,", ref_include_fields_vec[0].fields); + ASSERT_TRUE(ref_include_fields_vec[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); + + auto nested_join_includes = ref_include_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_join_includes[0].collection_name); + ASSERT_EQ("qty,sku,", nested_join_includes[0].fields); + ASSERT_TRUE(nested_join_includes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, nested_join_includes[0].strategy); + + nested_join_includes = ref_include_fields_vec[0].nested_join_includes[0].nested_join_includes; + ASSERT_EQ("retailers", nested_join_includes[0].collection_name); + ASSERT_EQ("id,title", nested_join_includes[0].fields); + ASSERT_TRUE(nested_join_includes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); + ref_include_fields_vec.clear(); + filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory: nest) as variants"}; initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, @@ -1526,7 +1550,7 @@ TEST_F(CollectionManagerTest, InitializeRefIncludeFields) { ASSERT_EQ("variants", ref_include_fields_vec[0].alias); ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); - auto nested_join_includes = ref_include_fields_vec[0].nested_join_includes; + nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); From 08c0a087a1d29b03ff737e2b5990c69c66b11dd0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 2 Jan 2024 14:14:21 +0530 Subject: [PATCH 5/5] Support nested join. --- include/collection.h | 23 +- include/collection_manager.h | 9 +- include/field.h | 7 +- include/filter_result_iterator.h | 71 +++- include/index.h | 7 +- include/string_utils.h | 8 +- src/collection.cpp | 476 +++++++++++++------------- src/collection_manager.cpp | 163 ++++++--- src/filter_result_iterator.cpp | 46 +-- src/index.cpp | 473 ++++++++++++++++++++++++-- src/string_utils.cpp | 45 +-- test/collection_join_test.cpp | 559 ++++++++++++++++++++++++++++++- test/collection_manager_test.cpp | 295 ++++++++++------ test/string_utils_test.cpp | 48 ++- 14 files changed, 1732 insertions(+), 498 deletions(-) diff --git a/include/collection.h b/include/collection.h index dfaa4f38..d4253f80 100644 --- a/include/collection.h +++ b/include/collection.h @@ -438,22 +438,23 @@ public: static void remove_reference_helper_fields(nlohmann::json& document); - static Option include_references(nlohmann::json& doc, - const std::string& ref_collection_name, - Collection *const ref_collection, - const std::string& alias, - const reference_filter_result_t& references, - const tsl::htrie_set& ref_include_fields_full, - const tsl::htrie_set& ref_exclude_fields_full, - const std::string& error_prefix, const bool& is_reference_array, - const ref_include::strategy_enum& strategy); + static Option prune_ref_doc(nlohmann::json& doc, + const reference_filter_result_t& references, + const tsl::htrie_set& ref_include_fields_full, + const tsl::htrie_set& ref_exclude_fields_full, + const bool& is_reference_array, + const ref_include_exclude_fields& ref_include_exclude); + + static Option include_references(nlohmann::json& doc, const uint32_t& seq_id, Collection *const collection, + const std::map& reference_filter_results, + const std::vector& ref_include_exclude_fields_vec); static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, const std::string& parent_name = "", size_t depth = 0, const std::map& reference_filter_results = {}, Collection *const collection = nullptr, const uint32_t& seq_id = 0, - const std::vector& ref_include_fields_vec = {}); + const std::vector& ref_include_exclude_fields_vec = {}); const Index* _get_index() const; @@ -558,7 +559,7 @@ public: const size_t remote_embedding_num_tries = 2, const std::string& stopwords_set="", const std::vector& facet_return_parent = {}, - const std::vector& ref_include_fields_vec = {}, + const std::vector& ref_include_exclude_fields_vec = {}, const std::string& drop_tokens_mode = "right_to_left", const bool prioritize_num_matching_fields = true, const bool group_missing_values = true, diff --git a/include/collection_manager.h b/include/collection_manager.h index 889391f5..44484df1 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -221,10 +221,11 @@ public: static void _get_reference_collection_names(const std::string& filter_query, ref_include_collection_names_t*& reference_collection_names); - // Separate out the reference includes into `ref_include_fields_vec`. - static Option _initialize_ref_include_fields_vec(const std::string& filter_query, - std::vector& include_fields_vec, - std::vector& ref_include_fields_vec); + // Separate out the reference includes and excludes into `ref_include_exclude_fields_vec`. + static Option _initialize_ref_include_exclude_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& exclude_fields_vec, + std::vector& ref_include_exclude_fields_vec); void add_referenced_in_backlog(const std::string& collection_name, reference_pair&& pair); diff --git a/include/field.h b/include/field.h index b06e07c3..20c34427 100644 --- a/include/field.h +++ b/include/field.h @@ -530,14 +530,15 @@ namespace ref_include { } } -struct ref_include_fields { +struct ref_include_exclude_fields { std::string collection_name; - std::string fields; + std::string include_fields; + std::string exclude_fields; std::string alias; ref_include::strategy_enum strategy = ref_include::nest; // In case we have nested join. - std::vector nested_join_includes = {}; + std::vector nested_join_includes = {}; }; struct hnsw_index_t; diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 77da1bd8..1e89b1b7 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -14,8 +14,14 @@ struct filter_node_t; struct reference_filter_result_t { uint32_t count = 0; uint32_t* docs = nullptr; + bool is_reference_array_field = true; - explicit reference_filter_result_t(uint32_t count = 0, uint32_t* docs = nullptr) : count(count), docs(docs) {} + // In case of nested join, references can further have references. + std::map* coll_to_references = nullptr; + + explicit reference_filter_result_t(uint32_t count = 0, uint32_t* docs = nullptr, + bool is_reference_array_field = true) : count(count), docs(docs), + is_reference_array_field(is_reference_array_field) {} reference_filter_result_t(const reference_filter_result_t& obj) { if (&obj == this) { @@ -25,6 +31,9 @@ struct reference_filter_result_t { count = obj.count; docs = new uint32_t[count]; memcpy(docs, obj.docs, count * sizeof(uint32_t)); + is_reference_array_field = obj.is_reference_array_field; + + copy_references(obj, *this); } reference_filter_result_t& operator=(const reference_filter_result_t& obj) noexcept { @@ -35,7 +44,9 @@ struct reference_filter_result_t { count = obj.count; docs = new uint32_t[count]; memcpy(docs, obj.docs, count * sizeof(uint32_t)); + is_reference_array_field = obj.is_reference_array_field; + copy_references(obj, *this); return *this; } @@ -46,26 +57,38 @@ struct reference_filter_result_t { count = obj.count; docs = obj.docs; + coll_to_references = obj.coll_to_references; + is_reference_array_field = obj.is_reference_array_field; + // Set default values in obj. + obj.count = 0; obj.docs = nullptr; + obj.coll_to_references = nullptr; + obj.is_reference_array_field = true; return *this; } ~reference_filter_result_t() { delete[] docs; + delete[] coll_to_references; } + + static void copy_references(const reference_filter_result_t& from, reference_filter_result_t& to); }; struct single_filter_result_t { uint32_t seq_id = 0; // Collection name -> Reference filter result std::map reference_filter_results = {}; + bool is_reference_array_field = true; single_filter_result_t() = default; - single_filter_result_t(uint32_t seq_id, std::map&& reference_filter_results) : - seq_id(seq_id), reference_filter_results(std::move(reference_filter_results)) {} + single_filter_result_t(uint32_t seq_id, std::map&& reference_filter_results, + bool is_reference_array_field = true) : + seq_id(seq_id), reference_filter_results(std::move(reference_filter_results)), + is_reference_array_field(is_reference_array_field) {} single_filter_result_t(const single_filter_result_t& obj) { if (&obj == this) { @@ -73,6 +96,7 @@ struct single_filter_result_t { } seq_id = obj.seq_id; + is_reference_array_field = obj.is_reference_array_field; // Copy every collection's reference. for (const auto &item: obj.reference_filter_results) { @@ -80,6 +104,45 @@ struct single_filter_result_t { reference_filter_results[ref_coll_name] = item.second; } } + + single_filter_result_t(single_filter_result_t&& obj) { + if (&obj == this) { + return; + } + + seq_id = obj.seq_id; + is_reference_array_field = obj.is_reference_array_field; + reference_filter_results = std::move(obj.reference_filter_results); + } + + single_filter_result_t& operator=(const single_filter_result_t& obj) noexcept { + if (&obj == this) { + return *this; + } + + seq_id = obj.seq_id; + is_reference_array_field = obj.is_reference_array_field; + + // Copy every collection's reference. + for (const auto &item: obj.reference_filter_results) { + auto& ref_coll_name = item.first; + reference_filter_results[ref_coll_name] = item.second; + } + + return *this; + } + + single_filter_result_t& operator=(single_filter_result_t&& obj) noexcept { + if (&obj == this) { + return *this; + } + + seq_id = obj.seq_id; + is_reference_array_field = obj.is_reference_array_field; + reference_filter_results = std::move(obj.reference_filter_results); + + return *this; + } }; struct filter_result_t { @@ -127,6 +190,8 @@ struct filter_result_t { docs = obj.docs; coll_to_references = obj.coll_to_references; + // Set default values in obj. + obj.count = 0; obj.docs = nullptr; obj.coll_to_references = nullptr; diff --git a/include/index.h b/include/index.h index db881e87..c767d280 100644 --- a/include/index.h +++ b/include/index.h @@ -786,12 +786,15 @@ public: filter_result_t& filter_result, const std::string& collection_name = "") const; - Option do_reference_filtering_with_lock(filter_node_t* const filter_tree_root, filter_result_t& filter_result, - const std::string& collection_name, + const std::string& ref_collection_name, const std::string& reference_helper_field_name) const; + Option do_filtering_with_reference_ids(const std::string& reference_helper_field_name, + const std::string& ref_collection_name, + filter_result_t&& ref_filter_result) const; + void refresh_schemas(const std::vector& new_fields, const std::vector& del_fields); // the following methods are not synchronized because their parent calls are synchronized or they are const/static diff --git a/include/string_utils.h b/include/string_utils.h index 0568a29e..e3d834e6 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -334,11 +334,11 @@ struct StringUtils { static Option tokenize_filter_query(const std::string& filter_query, std::queue& tokens); - static Option split_include_fields(const std::string& include_fields, std::vector& tokens); + static Option split_include_exclude_fields(const std::string& include_exclude_fields, + std::vector& tokens); static size_t get_occurence_count(const std::string& str, char symbol); - static Option split_reference_include_fields(const std::string& include_fields, - size_t& index, - std::string& token); + static Option split_reference_include_exclude_fields(const std::string& include_fields, + size_t& index, std::string& token); }; diff --git a/src/collection.cpp b/src/collection.cpp index 83131a6b..95563136 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1611,7 +1611,7 @@ Option Collection::search(std::string raw_query, const size_t remote_embedding_num_tries, const std::string& stopwords_set, const std::vector& facet_return_parent, - const std::vector& ref_include_fields_vec, + const std::vector& ref_include_exclude_fields_vec, const std::string& drop_tokens_mode, const bool prioritize_num_matching_fields, const bool group_missing_values, @@ -2547,7 +2547,7 @@ Option Collection::search(std::string raw_query, 0, field_order_kv->reference_filter_results, const_cast(this), get_seq_id_from_key(seq_id_key), - ref_include_fields_vec); + ref_include_exclude_fields_vec); if (!prune_op.ok()) { return Option(prune_op.code(), prune_op.error()); } @@ -4862,15 +4862,23 @@ void Collection::remove_reference_helper_fields(nlohmann::json& document) { } } -Option Collection::include_references(nlohmann::json& doc, - const std::string& ref_collection_name, - Collection *const ref_collection, - const std::string& alias, - const reference_filter_result_t& references, - const tsl::htrie_set& ref_include_fields_full, - const tsl::htrie_set& ref_exclude_fields_full, - const std::string& error_prefix, const bool& is_reference_array, - const ref_include::strategy_enum& strategy) { +Option Collection::prune_ref_doc(nlohmann::json& doc, + const reference_filter_result_t& references, + const tsl::htrie_set& ref_include_fields_full, + const tsl::htrie_set& ref_exclude_fields_full, + const bool& is_reference_array, + const ref_include_exclude_fields& ref_include_exclude) { + auto const& ref_collection_name = ref_include_exclude.collection_name; + auto& cm = CollectionManager::get_instance(); + auto ref_collection = cm.get_collection(ref_collection_name); + if (ref_collection == nullptr) { + return Option(400, "Referenced collection `" + ref_collection_name + "` in `include_fields` not found."); + } + + auto const& alias = ref_include_exclude.alias; + auto const& strategy = ref_include_exclude.strategy; + auto error_prefix = "Referenced collection `" + ref_collection_name + "`: "; + // One-to-one relation. if (strategy != ref_include::nest_array && !is_reference_array && references.count == 1) { auto ref_doc_seq_id = references.docs[0]; @@ -4889,23 +4897,33 @@ Option Collection::include_references(nlohmann::json& doc, return Option(prune_op.code(), error_prefix + prune_op.error()); } - if (ref_doc.empty()) { - return Option(true); + auto const key = alias.empty() ? ref_collection_name : alias; + auto const& nest_ref_doc = (strategy == ref_include::nest); + if (!ref_doc.empty()) { + if (nest_ref_doc) { + doc[key] = ref_doc; + } else { + if (!alias.empty()) { + auto temp_doc = ref_doc; + ref_doc.clear(); + for (const auto &item: temp_doc.items()) { + ref_doc[alias + item.key()] = item.value(); + } + } + doc.update(ref_doc); + } } - if (strategy == ref_include::nest) { - auto key = alias.empty() ? ref_collection_name : alias; - doc[key] = ref_doc; - } else { - if (!alias.empty()) { - auto temp_doc = ref_doc; - ref_doc.clear(); - for (const auto &item: temp_doc.items()) { - ref_doc[alias + item.key()] = item.value(); - } + // Include nested join references. + if (!ref_include_exclude.nested_join_includes.empty() && !references.coll_to_references->empty()) { + auto nested_include_exclude_op = include_references(nest_ref_doc ? doc[key] : doc, ref_doc_seq_id, + ref_collection.get(), references.coll_to_references[0], + ref_include_exclude.nested_join_includes); + if (!nested_include_exclude_op.ok()) { + return nested_include_exclude_op; } - doc.update(ref_doc); } + return Option(true); } @@ -4927,34 +4945,210 @@ Option Collection::include_references(nlohmann::json& doc, return Option(prune_op.code(), error_prefix + prune_op.error()); } - if (ref_doc.empty()) { - continue; - } - - if (strategy == ref_include::nest || strategy == ref_include::nest_array) { - auto key = alias.empty() ? ref_collection_name : alias; - if (doc.contains(key) && !doc[key].is_array()) { - return Option(400, "Could not include the reference document of `" + ref_collection_name + - "` collection. Expected `" + key + "` to be an array. Try " + - (alias.empty() ? "adding an" : "renaming the") + " alias."); - } - - doc[key] += ref_doc; - } else { - for (auto ref_doc_it = ref_doc.begin(); ref_doc_it != ref_doc.end(); ref_doc_it++) { - auto const& ref_doc_key = ref_doc_it.key(); - auto const& doc_key = alias + ref_doc_key; - if (doc.contains(doc_key) && !doc[doc_key].is_array()) { - return Option(400, "Could not include the value of `" + ref_doc_key + - "` key of the reference document of `" + ref_collection_name + - "` collection. Expected `" + doc_key + "` to be an array. Try " + + std::string key; + auto const& nest_ref_doc = (strategy == ref_include::nest || strategy == ref_include::nest_array); + if (!ref_doc.empty()) { + if (nest_ref_doc) { + key = alias.empty() ? ref_collection_name : alias; + if (doc.contains(key) && !doc[key].is_array()) { + return Option(400, "Could not include the reference document of `" + ref_collection_name + + "` collection. Expected `" + key + "` to be an array. Try " + (alias.empty() ? "adding an" : "renaming the") + " alias."); } - // Add the values of ref_doc as JSON array into doc. - doc[doc_key] += ref_doc_it.value(); + doc[key] += ref_doc; + } else { + for (auto ref_doc_it = ref_doc.begin(); ref_doc_it != ref_doc.end(); ref_doc_it++) { + auto const& ref_doc_key = ref_doc_it.key(); + key = alias + ref_doc_key; + if (doc.contains(key) && !doc[key].is_array()) { + return Option(400, "Could not include the value of `" + ref_doc_key + + "` key of the reference document of `" + ref_collection_name + + "` collection. Expected `" + key + "` to be an array. Try " + + (alias.empty() ? "adding an" : "renaming the") + " alias."); + } + + // Add the values of ref_doc as JSON array into doc. + doc[key] += ref_doc_it.value(); + } } } + + // Include nested join references. + if (!ref_include_exclude.nested_join_includes.empty() && + references.coll_to_references != nullptr && !references.coll_to_references->empty()) { + auto nested_include_exclude_op = include_references(nest_ref_doc ? doc[key].at(i) : doc, ref_doc_seq_id, + ref_collection.get(), references.coll_to_references[i], + ref_include_exclude.nested_join_includes); + if (!nested_include_exclude_op.ok()) { + return nested_include_exclude_op; + } + } + } + + return Option(true); +} + +Option Collection::include_references(nlohmann::json& doc, const uint32_t& seq_id, Collection *const collection, + const std::map& reference_filter_results, + const std::vector& ref_include_exclude_fields_vec) { + for (auto const& ref_include_exclude: ref_include_exclude_fields_vec) { + auto const& ref_collection_name = ref_include_exclude.collection_name; + + auto& cm = CollectionManager::get_instance(); + auto ref_collection = cm.get_collection(ref_collection_name); + if (ref_collection == nullptr) { + return Option(400, "Referenced collection `" + ref_collection_name + "` in `include_fields` not found."); + } + + auto const joined_on_ref_collection = reference_filter_results.count(ref_collection_name) > 0, + has_filter_reference = (joined_on_ref_collection && + reference_filter_results.at(ref_collection_name).count > 0); + auto doc_has_reference = false, joined_coll_has_reference = false; + + // Reference include_by without join, check if doc itself contains the reference. + if (!joined_on_ref_collection && collection != nullptr) { + doc_has_reference = ref_collection->is_referenced_in(collection->name); + } + + std::string joined_coll_having_reference; + // Check if the joined collection has a reference. + if (!joined_on_ref_collection && !doc_has_reference) { + for (const auto &reference_filter_result: reference_filter_results) { + joined_coll_has_reference = ref_collection->is_referenced_in(reference_filter_result.first); + if (joined_coll_has_reference) { + joined_coll_having_reference = reference_filter_result.first; + break; + } + } + } + + if (!has_filter_reference && !doc_has_reference && !joined_coll_has_reference) { + continue; + } + + std::vector ref_include_fields_vec, ref_exclude_fields_vec; + StringUtils::split(ref_include_exclude.include_fields, ref_include_fields_vec, ","); + StringUtils::split(ref_include_exclude.exclude_fields, ref_exclude_fields_vec, ","); + + spp::sparse_hash_set ref_include_fields, ref_exclude_fields; + ref_include_fields.insert(ref_include_fields_vec.begin(), ref_include_fields_vec.end()); + ref_exclude_fields.insert(ref_exclude_fields_vec.begin(), ref_exclude_fields_vec.end()); + + tsl::htrie_set ref_include_fields_full, ref_exclude_fields_full; + auto include_exclude_op = ref_collection->populate_include_exclude_fields_lk(ref_include_fields, + ref_exclude_fields, + ref_include_fields_full, + ref_exclude_fields_full); + auto error_prefix = "Referenced collection `" + ref_collection_name + "`: "; + if (!include_exclude_op.ok()) { + return Option(include_exclude_op.code(), error_prefix + include_exclude_op.error()); + } + + Option prune_doc_op = Option(true); + auto const& ref_collection_alias = ref_include_exclude.alias; + if (has_filter_reference) { + auto const& ref_filter_result = reference_filter_results.at(ref_collection_name); + prune_doc_op = prune_ref_doc(doc, ref_filter_result, ref_include_fields_full, ref_exclude_fields_full, + ref_filter_result.is_reference_array_field, ref_include_exclude); + } else if (doc_has_reference) { + auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection->name); + if (!get_reference_field_op.ok()) { + continue; + } + auto const& field_name = get_reference_field_op.get(); + if (collection->search_schema.count(field_name) == 0) { + continue; + } + + if (collection->object_reference_helper_fields.count(field_name) != 0) { + std::vector keys; + StringUtils::split(field_name, keys, "."); + if (!doc.contains(keys[0])) { + return Option(400, "Could not find `" + keys[0] + + "` in the document to include the referenced document."); + } + + if (doc[keys[0]].is_array()) { + for (uint32_t i = 0; i < doc[keys[0]].size(); i++) { + uint32_t ref_doc_id; + auto op = collection->get_object_array_related_id(field_name, seq_id, i, ref_doc_id); + if (!op.ok()) { + if (op.code() == 404) { // field_name is not indexed. + break; + } else { // No reference found for this object. + continue; + } + } + + reference_filter_result_t result(1, new uint32_t[1]{ref_doc_id}); + prune_doc_op = prune_ref_doc(doc[keys[0]][i], result, + ref_include_fields_full, ref_exclude_fields_full, + false, ref_include_exclude); + if (!prune_doc_op.ok()) { + return prune_doc_op; + } + } + } else { + std::vector ids; + auto get_references_op = collection->get_related_ids(field_name, seq_id, ids); + if (!get_references_op.ok()) { + continue; + } + reference_filter_result_t result(ids.size(), &ids[0]); + prune_doc_op = prune_ref_doc(doc[keys[0]], result, ref_include_fields_full, ref_exclude_fields_full, + collection->search_schema.at(field_name).is_array(), ref_include_exclude); + result.docs = nullptr; + } + } else { + std::vector ids; + auto get_references_op = collection->get_related_ids(field_name, seq_id, ids); + if (!get_references_op.ok()) { + continue; + } + reference_filter_result_t result(ids.size(), &ids[0]); + prune_doc_op = prune_ref_doc(doc, result, ref_include_fields_full, ref_exclude_fields_full, + collection->search_schema.at(field_name).is_array(), ref_include_exclude); + result.docs = nullptr; + } + } else if (joined_coll_has_reference) { + auto joined_collection = cm.get_collection(joined_coll_having_reference); + if (joined_collection == nullptr) { + continue; + } + + auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); + if (!reference_field_name_op.ok() || joined_collection->get_schema().count(reference_field_name_op.get()) == 0) { + continue; + } + + auto const& reference_field_name = reference_field_name_op.get(); + auto const& reference_filter_result = reference_filter_results.at(joined_coll_having_reference); + auto const& count = reference_filter_result.count; + std::vector ids; + ids.reserve(count); + for (uint32_t i = 0; i < count; i++) { + joined_collection->get_related_ids_with_lock(reference_field_name, reference_filter_result.docs[i], ids); + } + if (ids.empty()) { + continue; + } + + gfx::timsort(ids.begin(), ids.end()); + ids.erase(unique(ids.begin(), ids.end()), ids.end()); + + reference_filter_result_t result; + result.count = ids.size(); + result.docs = &ids[0]; + prune_doc_op = prune_ref_doc(doc, result, ref_include_fields_full, ref_exclude_fields_full, + joined_collection->get_schema().at(reference_field_name).is_array(), + ref_include_exclude); + result.docs = nullptr; + } + + if (!prune_doc_op.ok()) { + return prune_doc_op; + } } return Option(true); @@ -4966,7 +5160,7 @@ Option Collection::prune_doc(nlohmann::json& doc, const std::string& parent_name, size_t depth, const std::map& reference_filter_results, Collection *const collection, const uint32_t& seq_id, - const std::vector& ref_includes) { + const std::vector& ref_include_exclude_fields_vec) { // doc can only be an object auto it = doc.begin(); while(it != doc.end()) { @@ -5042,191 +5236,7 @@ Option Collection::prune_doc(nlohmann::json& doc, it++; } - for (auto const& ref_include: ref_includes) { - auto const& ref_collection_name = ref_include.collection_name; - - auto& cm = CollectionManager::get_instance(); - auto ref_collection = cm.get_collection(ref_collection_name); - if (ref_collection == nullptr) { - return Option(400, "Referenced collection `" + ref_collection_name + "` in `include_fields` not found."); - } - - auto const joined_on_ref_collection = reference_filter_results.count(ref_collection_name) > 0, - has_filter_reference = (joined_on_ref_collection && - reference_filter_results.at(ref_collection_name).count > 0); - auto doc_has_reference = false, joined_coll_has_reference = false; - - // Reference include_by without join, check if doc itself contains the reference. - if (!joined_on_ref_collection && collection != nullptr) { - doc_has_reference = ref_collection->is_referenced_in(collection->name); - } - - std::string joined_coll_having_reference; - // Check if the joined collection has a reference. - if (!joined_on_ref_collection && !doc_has_reference) { - for (const auto &reference_filter_result: reference_filter_results) { - joined_coll_has_reference = ref_collection->is_referenced_in(reference_filter_result.first); - if (joined_coll_has_reference) { - joined_coll_having_reference = reference_filter_result.first; - break; - } - } - } - - if (!has_filter_reference && !doc_has_reference && !joined_coll_has_reference) { - continue; - } - - std::vector ref_include_fields_vec, ref_exclude_fields_vec; - StringUtils::split(ref_include.fields, ref_include_fields_vec, ","); - auto exclude_reference_it = exclude_names.equal_prefix_range("$" + ref_collection_name); - if (exclude_reference_it.first != exclude_reference_it.second) { - auto ref_exclude = exclude_reference_it.first.key(); - auto parenthesis_index = ref_exclude.find('('); - auto reference_fields = ref_exclude.substr(parenthesis_index + 1, ref_exclude.size() - parenthesis_index - 2); - StringUtils::split(reference_fields, ref_exclude_fields_vec, ","); - } - - spp::sparse_hash_set ref_include_fields, ref_exclude_fields; - ref_include_fields.insert(ref_include_fields_vec.begin(), ref_include_fields_vec.end()); - ref_exclude_fields.insert(ref_exclude_fields_vec.begin(), ref_exclude_fields_vec.end()); - - tsl::htrie_set ref_include_fields_full, ref_exclude_fields_full; - auto include_exclude_op = ref_collection->populate_include_exclude_fields(ref_include_fields, - ref_exclude_fields, - ref_include_fields_full, - ref_exclude_fields_full); - auto error_prefix = "Referenced collection `" + ref_collection_name + "`: "; - if (!include_exclude_op.ok()) { - return Option(include_exclude_op.code(), error_prefix + include_exclude_op.error()); - } - - Option include_references_op = Option(true); - if (has_filter_reference) { - auto get_reference_field_op = collection->get_referenced_in_field(ref_collection_name); - if (!get_reference_field_op.ok()) { - continue; - } - auto const& field_name = get_reference_field_op.get(); - if (ref_collection->search_schema.count(field_name) == 0) { - continue; - } - include_references_op = include_references(doc, ref_include.collection_name, - ref_collection.get(), ref_include.alias, - reference_filter_results.at(ref_collection_name), - ref_include_fields_full, ref_exclude_fields_full, error_prefix, - ref_collection->get_schema().at(field_name).is_array(), - ref_include.strategy); - } else if (doc_has_reference) { - auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection->name); - if (!get_reference_field_op.ok()) { - continue; - } - auto const& field_name = get_reference_field_op.get(); - if (collection->search_schema.count(field_name) == 0) { - continue; - } - - if (collection->object_reference_helper_fields.count(field_name) != 0) { - std::vector keys; - StringUtils::split(field_name, keys, "."); - if (!doc.contains(keys[0])) { - return Option(400, "Could not find `" + keys[0] + - "` in the document to include the referenced document."); - } - - if (doc[keys[0]].is_array()) { - for (uint32_t i = 0; i < doc[keys[0]].size(); i++) { - uint32_t ref_doc_id; - auto op = collection->get_object_array_related_id(field_name, seq_id, i, ref_doc_id); - if (!op.ok()) { - if (op.code() == 404) { // field_name is not indexed. - break; - } else { // No reference found for this object. - continue; - } - } - - reference_filter_result_t result(1, new uint32_t[1]{ref_doc_id}); - include_references_op = include_references(doc[keys[0]][i], ref_include.collection_name, - ref_collection.get(), ref_include.alias, result, - ref_include_fields_full, ref_exclude_fields_full, error_prefix, - false, ref_include.strategy); - if (!include_references_op.ok()) { - return include_references_op; - } - } - } else { - std::vector ids; - auto get_references_op = collection->get_related_ids(field_name, seq_id, ids); - if (!get_references_op.ok()) { - continue; - } - reference_filter_result_t result(ids.size(), &ids[0]); - include_references_op = include_references(doc[keys[0]], ref_include.collection_name, - ref_collection.get(), ref_include.alias, result, - ref_include_fields_full, ref_exclude_fields_full, error_prefix, - collection->search_schema.at(field_name).is_array(), - ref_include.strategy); - result.docs = nullptr; - } - } else { - std::vector ids; - auto get_references_op = collection->get_related_ids(field_name, seq_id, ids); - if (!get_references_op.ok()) { - continue; - } - reference_filter_result_t result(ids.size(), &ids[0]); - include_references_op = include_references(doc, ref_include.collection_name, - ref_collection.get(), ref_include.alias, result, - ref_include_fields_full, ref_exclude_fields_full, error_prefix, - collection->search_schema.at(field_name).is_array(), - ref_include.strategy); - result.docs = nullptr; - } - } else if (joined_coll_has_reference) { - auto joined_collection = cm.get_collection(joined_coll_having_reference); - if (joined_collection == nullptr) { - continue; - } - - auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); - if (!reference_field_name_op.ok() || joined_collection->get_schema().count(reference_field_name_op.get()) == 0) { - continue; - } - - auto const& reference_field_name = reference_field_name_op.get(); - auto const& reference_filter_result = reference_filter_results.at(joined_coll_having_reference); - auto const& count = reference_filter_result.count; - std::vector ids; - ids.reserve(count); - for (uint32_t i = 0; i < count; i++) { - joined_collection->get_related_ids_with_lock(reference_field_name, reference_filter_result.docs[i], ids); - } - if (ids.empty()) { - continue; - } - - gfx::timsort(ids.begin(), ids.end()); - ids.erase(unique(ids.begin(), ids.end()), ids.end()); - - reference_filter_result_t result; - result.count = ids.size(); - result.docs = &ids[0]; - include_references_op = include_references(doc, ref_include.collection_name, - ref_collection.get(), ref_include.alias, result, - ref_include_fields_full, ref_exclude_fields_full, error_prefix, - joined_collection->get_schema().at(reference_field_name).is_array(), - ref_include.strategy); - result.docs = nullptr; - } - - if (!include_references_op.ok()) { - return include_references_op; - } - } - - return Option(true); + return include_references(doc, seq_id, collection, reference_filter_results, ref_include_exclude_fields_vec); } Option Collection::validate_alter_payload(nlohmann::json& schema_changes, diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 1758a69e..aa4184c9 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -982,9 +982,61 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte } } +Option parse_nested_exclude(const std::string& exclude_field_exp, + std::unordered_map& ref_excludes) { + // Format: $ref_collection_name(field_1, field_2, $nested_ref_coll(nested_field_1)) + size_t index = 0; + while (index < exclude_field_exp.size()) { + auto parenthesis_index = exclude_field_exp.find('('); + auto ref_collection_name = exclude_field_exp.substr(index + 1, parenthesis_index - index - 1); + std::string ref_fields; + + index = parenthesis_index + 1; + auto nested_exclude_pos = exclude_field_exp.find('$', parenthesis_index); + auto closing_parenthesis_pos = exclude_field_exp.find(')', parenthesis_index); + size_t comma_pos; + if (nested_exclude_pos < closing_parenthesis_pos) { + // Nested reference exclude. + // "... $product_variants(title, $inventory(qty)) ..." + do { + ref_fields += exclude_field_exp.substr(index, nested_exclude_pos - index); + StringUtils::trim(ref_fields); + index = nested_exclude_pos; + std::string nested_exclude_field_exp; + auto split_op = StringUtils::split_reference_include_exclude_fields(exclude_field_exp, index, + nested_exclude_field_exp); + if (!split_op.ok()) { + return split_op; + } + + auto parse_op = parse_nested_exclude(nested_exclude_field_exp, ref_excludes); + if (!parse_op.ok()) { + return parse_op; + } + + nested_exclude_pos = exclude_field_exp.find('$', index); + closing_parenthesis_pos = exclude_field_exp.find(')', index); + comma_pos = exclude_field_exp.find(',', index); + index = std::min(closing_parenthesis_pos, comma_pos) + 1; + } while (index < exclude_field_exp.size() && nested_exclude_pos < closing_parenthesis_pos); + } + + // ... $inventory(qty) ... + if (index < closing_parenthesis_pos) { + ref_fields += exclude_field_exp.substr(index, closing_parenthesis_pos - index); + } + StringUtils::trim(ref_fields); + + ref_excludes[ref_collection_name] = ref_fields; + index = closing_parenthesis_pos + 1; + } + + return Option(true); +} + Option parse_nested_include(const std::string& include_field_exp, CollectionManager::ref_include_collection_names_t* const ref_include_coll_names, - std::vector& ref_include_fields_vec) { + std::vector& ref_include_exclude_fields_vec) { // Format: $ref_collection_name(field_1, field_2, $nested_ref_coll(nested_field_1: nested_include_strategy) as nested_ref_alias: include_strategy) as ref_alias size_t index = 0; while (index < include_field_exp.size()) { @@ -998,7 +1050,7 @@ Option parse_nested_include(const std::string& include_field_exp, auto closing_parenthesis_pos = include_field_exp.find(')', parenthesis_index); auto colon_pos = include_field_exp.find(':', index); size_t comma_pos; - std::vector nested_ref_include_fields_vec; + std::vector nested_ref_include_exclude_fields_vec; if (nested_include_pos < closing_parenthesis_pos) { // Nested reference include. // "... $product_variants(title, $inventory(qty:merge) as inventory :nest) as variants ..." @@ -1007,15 +1059,15 @@ Option parse_nested_include(const std::string& include_field_exp, StringUtils::trim(ref_fields); index = nested_include_pos; std::string nested_include_field_exp; - auto split_op = StringUtils::split_reference_include_fields(include_field_exp, index, - nested_include_field_exp); + auto split_op = StringUtils::split_reference_include_exclude_fields(include_field_exp, index, + nested_include_field_exp); if (!split_op.ok()) { return split_op; } auto parse_op = parse_nested_include(nested_include_field_exp, ref_include_coll_names == nullptr ? nullptr : ref_include_coll_names->nested_include, - nested_ref_include_fields_vec); + nested_ref_include_exclude_fields_vec); if (!parse_op.ok()) { return parse_op; } @@ -1062,11 +1114,11 @@ Option parse_nested_include(const std::string& include_field_exp, nest_ref_doc = strategy_enum == ref_include::nest || strategy_enum == ref_include::nest_array; ref_alias = !ref_alias.empty() ? (StringUtils::trim(ref_alias) + (nest_ref_doc ? "" : ".")) : ""; - ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, - strategy_enum}); - ref_include_fields_vec.back().nested_join_includes = std::move(nested_ref_include_fields_vec); + ref_include_exclude_fields_vec.emplace_back(ref_include_exclude_fields{ref_collection_name, ref_fields, "", + ref_alias, strategy_enum}); + ref_include_exclude_fields_vec.back().nested_join_includes = std::move(nested_ref_include_exclude_fields_vec); - // Referenced collection in filter_by is already mentioned in ref_include_fields. + // Referenced collection in filter_by is already mentioned in include_fields. if (ref_include_coll_names != nullptr) { ref_include_coll_names->collection_names.erase(ref_collection_name); } @@ -1079,9 +1131,10 @@ Option parse_nested_include(const std::string& include_field_exp, return Option(true); } -Option CollectionManager::_initialize_ref_include_fields_vec(const std::string& filter_query, - std::vector& include_fields_vec, - std::vector& ref_include_fields_vec) { +Option CollectionManager::_initialize_ref_include_exclude_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& exclude_fields_vec, + std::vector& ref_include_exclude_fields_vec) { ref_include_collection_names_t* ref_include_coll_names = nullptr; CollectionManager::_get_reference_collection_names(filter_query, ref_include_coll_names); std::unique_ptr guard(ref_include_coll_names); @@ -1101,7 +1154,7 @@ Option CollectionManager::_initialize_ref_include_fields_vec(const std::st // Nested reference include. if (include_field_exp.find('$', 1) != std::string::npos) { - auto parse_op = parse_nested_include(include_field_exp, ref_include_coll_names, ref_include_fields_vec); + auto parse_op = parse_nested_include(include_field_exp, ref_include_coll_names, ref_include_exclude_fields_vec); if (!parse_op.ok()) { return parse_op; } @@ -1139,46 +1192,77 @@ Option CollectionManager::_initialize_ref_include_fields_vec(const std::st // In case of "nest" reference doc, `foo` becomes the key with reference doc as value. auto const& nest_ref_doc = strategy_enum == ref_include::nest || strategy_enum == ref_include::nest_array; auto ref_alias = !alias.empty() ? (StringUtils::trim(alias) + (nest_ref_doc ? "" : ".")) : ""; - ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, - strategy_enum}); + ref_include_exclude_fields_vec.emplace_back(ref_include_exclude_fields{ref_collection_name, ref_fields, "", + ref_alias, strategy_enum}); - auto open_paren_pos = include_field_exp.find('('); - if (open_paren_pos == std::string::npos) { - continue; - } - - auto reference_collection_name = include_field_exp.substr(1, open_paren_pos - 1); - StringUtils::trim(reference_collection_name); - if (reference_collection_name.empty()) { - continue; - } - - // Referenced collection in filter_by is already mentioned in ref_include_fields. + // Referenced collection in filter_by is already mentioned in include_fields. if (ref_include_coll_names != nullptr) { - ref_include_coll_names->collection_names.erase(reference_collection_name); + ref_include_coll_names->collection_names.erase(ref_collection_name); } } - // Get all the fields of the referenced collection in the filter but not mentioned in include_fields. - auto ref_includes = std::ref(ref_include_fields_vec); + // Get all the fields of the referenced collection mentioned in the filter_by but not in include_fields. + auto references = std::ref(ref_include_exclude_fields_vec); while (ref_include_coll_names != nullptr) { for (const auto &reference_collection_name: ref_include_coll_names->collection_names) { - ref_includes.get().emplace_back(ref_include_fields{reference_collection_name, "", "", ref_include::nest}); + references.get().emplace_back(ref_include_exclude_fields{reference_collection_name, "", "", ""}); } ref_include_coll_names = ref_include_coll_names->nested_include; - if (ref_includes.get().empty()) { + if (references.get().empty()) { break; } - ref_includes = std::ref(ref_includes.get().front().nested_join_includes); + references = std::ref(references.get().front().nested_join_includes); } - // Since no field of the collection is mentioned in include_fields, get all the fields. + std::unordered_map ref_excludes; + std::vector result_exclude_fields_vec; + for (const auto& exclude_field_exp: exclude_fields_vec) { + if (exclude_field_exp[0] != '$') { + result_exclude_fields_vec.emplace_back(exclude_field_exp); + continue; + } + + // Nested reference exclude. + if (exclude_field_exp.find('$', 1) != std::string::npos) { + auto parse_op = parse_nested_exclude(exclude_field_exp, ref_excludes); + if (!parse_op.ok()) { + return parse_op; + } + continue; + } + + // Format: $ref_collection_name(field_1, field_2) + auto parenthesis_index = exclude_field_exp.find('('); + auto ref_collection_name = exclude_field_exp.substr(1, parenthesis_index - 1); + auto ref_fields = exclude_field_exp.substr(parenthesis_index + 1, exclude_field_exp.size() - parenthesis_index - 2); + if (!ref_fields.empty()) { + ref_excludes[ref_collection_name] = ref_fields; + } + } + + if (!ref_excludes.empty()) { + references = std::ref(ref_include_exclude_fields_vec); + while (!references.get().empty()) { + for (auto& ref_include_exclude: references.get()) { + if (ref_excludes.count(ref_include_exclude.collection_name) == 0) { + continue; + } + + ref_include_exclude.exclude_fields = ref_excludes[ref_include_exclude.collection_name]; + } + + references = std::ref(references.get().front().nested_join_includes); + } + } + + // Since no field of the collection being searched is mentioned in include_fields, include all the fields. if (wildcard_include_all) { result_include_fields_vec.clear(); } include_fields_vec = std::move(result_include_fields_vec); + exclude_fields_vec = std::move(result_exclude_fields_vec); return Option(true); } @@ -1356,7 +1440,7 @@ Option CollectionManager::do_search(std::map& re std::vector include_fields_vec; std::vector exclude_fields_vec; - std::vector ref_include_fields_vec; + std::vector ref_include_exclude_fields_vec; spp::sparse_hash_set include_fields; spp::sparse_hash_set exclude_fields; @@ -1540,8 +1624,8 @@ Option CollectionManager::do_search(std::map& re if(key == FACET_BY){ StringUtils::split_facet(val, *find_str_list_it->second); } - else if(key == INCLUDE_FIELDS){ - auto op = StringUtils::split_include_fields(val, *find_str_list_it->second); + else if(key == INCLUDE_FIELDS || key == EXCLUDE_FIELDS){ + auto op = StringUtils::split_include_exclude_fields(val, *find_str_list_it->second); if (!op.ok()) { return op; } @@ -1566,7 +1650,8 @@ Option CollectionManager::do_search(std::map& re per_page = 0; } - auto initialize_op = _initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + auto initialize_op = _initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, exclude_fields_vec, + ref_include_exclude_fields_vec); if (!initialize_op.ok()) { return initialize_op; } @@ -1662,7 +1747,7 @@ Option CollectionManager::do_search(std::map& re remote_embedding_num_tries, stopwords_set, facet_return_parent, - ref_include_fields_vec, + ref_include_exclude_fields_vec, drop_tokens_mode_str, prioritize_num_matching_fields, group_missing_values, diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index b3254aaf..44867fec 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -14,23 +14,31 @@ #include "posting.h" #include "collection_manager.h" -void filter_result_t::copy_references(const filter_result_t& from, filter_result_t& to) { - if (from.coll_to_references == nullptr) { +void copy_references_helper(const std::map* from, + std::map*& to, const uint32_t& count) { + if (from == nullptr) { return; } - auto const& count = from.count; - to.coll_to_references = new std::map[count] {}; + to = new std::map[count] {}; for (uint32_t i = 0; i < count; i++) { - if (from.coll_to_references[i].empty()) { + if (from[i].empty()) { continue; } - auto& ref = to.coll_to_references[i]; - ref.insert(from.coll_to_references[i].begin(), from.coll_to_references[i].end()); + auto& ref = to[i]; + ref.insert(from[i].begin(), from[i].end()); } } +void reference_filter_result_t::copy_references(const reference_filter_result_t& from, reference_filter_result_t& to) { + return copy_references_helper(from.coll_to_references, to.coll_to_references, from.count); +} + +void filter_result_t::copy_references(const filter_result_t& from, filter_result_t& to) { + return copy_references_helper(from.coll_to_references, to.coll_to_references, from.count); +} + void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { auto lenA = a.count, lenB = b.count; if (lenA == 0 || lenB == 0) { @@ -660,26 +668,16 @@ void filter_result_iterator_t::init() { return; } - std::vector values; - for (uint32_t i = 0; i < result.count; i++) { - values.push_back(std::to_string(result.docs[i])); - } - - filter filter_exp = {get_reference_field_op.get(), std::move(values), - std::vector(result.count, EQUALS)}; - - auto filter_tree_root = new filter_node_t(filter_exp); - std::unique_ptr filter_tree_root_guard(filter_tree_root); - - auto fit = filter_result_iterator_t(collection_name, index, filter_tree_root); - auto filter_init_op = fit.init_status(); - if (!filter_init_op.ok()) { - status = Option(filter_init_op.code(), filter_init_op.error()); + auto const& reference_helper_field_name = get_reference_field_op.get(); + auto op = index->do_filtering_with_reference_ids(reference_helper_field_name, ref_collection_name, + std::move(result)); + if (!op.ok()) { + status = Option(op.code(), op.error()); validity = invalid; return; } - filter_result = std::move(fit.filter_result); + filter_result = op.get(); } if (filter_result.count == 0) { @@ -1304,6 +1302,7 @@ int filter_result_iterator_t::is_valid(uint32_t id) { } seq_id = id; + and_filter_iterators(); return 1; } else { validity = (left_it->validity == valid || right_it->validity == valid) ? valid : invalid; @@ -1327,6 +1326,7 @@ int filter_result_iterator_t::is_valid(uint32_t id) { } seq_id = id; + or_filter_iterators(); return 1; } } diff --git a/src/index.cpp b/src/index.cpp index 0287f443..d2eb01b6 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1801,37 +1801,160 @@ Option Index::do_filtering_with_lock(filter_node_t* const filter_tree_root return filter_init_op; } - filter_result.count = filter_result_iterator.to_filter_id_array(filter_result.docs); + if (filter_result_iterator.reference.empty()) { + filter_result.count = filter_result_iterator.to_filter_id_array(filter_result.docs); + return Option(true); + } + filter_result_iterator.compute_result(); + if (filter_result_iterator.approx_filter_ids_length == 0) { + return Option(true); + } + + uint32_t count = filter_result_iterator.approx_filter_ids_length, dummy; + auto ref_filter_result = new filter_result_t(); + std::unique_ptr ref_filter_result_guard(ref_filter_result); + filter_result_iterator.get_n_ids(count, dummy, nullptr, 0, ref_filter_result); + + if (filter_result_iterator.validity == filter_result_iterator_t::timed_out) { + return Option(true); + } + + filter_result = std::move(*ref_filter_result); return Option(true); } +void aggregate_nested_references(single_filter_result_t *const reference_result, + reference_filter_result_t& ref_filter_result) { + // Add reference doc id in result. + auto temp_docs = new uint32_t[ref_filter_result.count + 1]; + std::copy(ref_filter_result.docs, ref_filter_result.docs + ref_filter_result.count, temp_docs); + temp_docs[ref_filter_result.count] = reference_result->seq_id; + + delete[] ref_filter_result.docs; + ref_filter_result.docs = temp_docs; + ref_filter_result.count++; + ref_filter_result.is_reference_array_field = false; + + // Add references of the reference doc id in result. + auto& references = ref_filter_result.coll_to_references; + auto temp_references = new std::map[ref_filter_result.count] {}; + for (uint32_t i = 0; i < ref_filter_result.count - 1; i++) { + temp_references[i] = std::move(references[i]); + } + temp_references[ref_filter_result.count - 1] = std::move(reference_result->reference_filter_results); + + delete[] references; + references = temp_references; +} + Option Index::do_reference_filtering_with_lock(filter_node_t* const filter_tree_root, filter_result_t& filter_result, - const std::string& collection_name, + const std::string& ref_collection_name, const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); - auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root, - search_begin_us, search_stop_us); - auto filter_init_op = filter_result_iterator.init_status(); + auto ref_filter_result_iterator = filter_result_iterator_t(ref_collection_name, this, filter_tree_root, + search_begin_us, search_stop_us); + auto filter_init_op = ref_filter_result_iterator.init_status(); if (!filter_init_op.ok()) { return filter_init_op; } - uint32_t* reference_docs = nullptr; - uint32_t count = filter_result_iterator.to_filter_id_array(reference_docs); - std::unique_ptr docs_guard(reference_docs); - - if (count == 0) { + ref_filter_result_iterator.compute_result(); + if (ref_filter_result_iterator.approx_filter_ids_length == 0) { return Option(true); } - if (search_schema.at(reference_helper_field_name).is_singular()) { + uint32_t count = ref_filter_result_iterator.approx_filter_ids_length, dummy; + auto ref_filter_result = new filter_result_t(); + std::unique_ptr ref_filter_result_guard(ref_filter_result); + ref_filter_result_iterator.get_n_ids(count, dummy, nullptr, 0, ref_filter_result); + + if (ref_filter_result_iterator.validity == filter_result_iterator_t::timed_out) { + return Option(true); + } + + uint32_t* reference_docs = ref_filter_result->docs; + ref_filter_result->docs = nullptr; + std::unique_ptr docs_guard(reference_docs); + + auto const is_nested_join = !ref_filter_result_iterator.reference.empty(); + if (search_schema.at(reference_helper_field_name).is_singular()) { // Only one reference per doc. + if (sort_index.count(reference_helper_field_name) == 0) { + return Option(400, "`" + reference_helper_field_name + "` is not present in sort index."); + } + auto const& ref_index = *sort_index.at(reference_helper_field_name); + + if (is_nested_join) { + // In case of nested join, we need to collect all the doc ids from the reference ids along with their references. + std::vector> id_pairs; + std::unordered_set unique_doc_ids; + + for (uint32_t i = 0; i < count; i++) { + auto& reference_doc_id = reference_docs[i]; + auto reference_doc_references = std::move(ref_filter_result->coll_to_references[i]); + if (ref_index.count(reference_doc_id) == 0) { // Reference field might be optional. + continue; + } + auto doc_id = ref_index.at(reference_doc_id); + + id_pairs.emplace_back(std::make_pair(doc_id, new single_filter_result_t(reference_doc_id, + std::move(reference_doc_references), + false))); + unique_doc_ids.insert(doc_id); + } + + if (id_pairs.empty()) { + return Option(true); + } + + std::sort(id_pairs.begin(), id_pairs.end(), [](auto const& left, auto const& right) { + return left.first < right.first; + }); + + filter_result.count = unique_doc_ids.size(); + filter_result.docs = new uint32_t[unique_doc_ids.size()]; + filter_result.coll_to_references = new std::map[unique_doc_ids.size()] {}; + + reference_filter_result_t previous_doc_references; + for (uint32_t i = 0, previous_doc = id_pairs[0].first + 1, result_index = 0; i < id_pairs.size(); i++) { + auto const& current_doc = id_pairs[i].first; + auto& reference_result = id_pairs[i].second; + + if (current_doc != previous_doc) { + filter_result.docs[result_index] = current_doc; + if (result_index > 0) { + std::map references; + references[ref_collection_name] = std::move(previous_doc_references); + filter_result.coll_to_references[result_index - 1] = std::move(references); + } + + result_index++; + previous_doc = current_doc; + aggregate_nested_references(reference_result, previous_doc_references); + } else { + aggregate_nested_references(reference_result, previous_doc_references); + } + } + + if (previous_doc_references.count != 0) { + std::map references; + references[ref_collection_name] = std::move(previous_doc_references); + filter_result.coll_to_references[filter_result.count - 1] = std::move(references); + } + + for (auto &item: id_pairs) { + delete item.second; + } + + return Option(true); + } + // Collect all the doc ids from the reference ids. std::vector> id_pairs; std::unordered_set unique_doc_ids; - auto const& ref_index = *sort_index.at(reference_helper_field_name); + for (uint32_t i = 0; i < count; i++) { auto& reference_doc_id = reference_docs[i]; if (ref_index.count(reference_doc_id) == 0) { // Reference field might be optional. @@ -1839,10 +1962,14 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter } auto doc_id = ref_index.at(reference_doc_id); - id_pairs.emplace_back(std::pair(doc_id, reference_doc_id)); + id_pairs.emplace_back(std::make_pair(doc_id, reference_doc_id)); unique_doc_ids.insert(doc_id); } + if (id_pairs.empty()) { + return Option(true); + } + std::sort(id_pairs.begin(), id_pairs.end(), [](auto const& left, auto const& right) { return left.first < right.first; }); @@ -1861,9 +1988,11 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter if (result_index > 0) { auto& reference_result = filter_result.coll_to_references[result_index - 1]; - auto r = reference_filter_result_t(previous_doc_references.size(), new uint32_t[previous_doc_references.size()]); + auto r = reference_filter_result_t(previous_doc_references.size(), + new uint32_t[previous_doc_references.size()], + false); std::copy(previous_doc_references.begin(), previous_doc_references.end(), r.docs); - reference_result[collection_name] = std::move(r); + reference_result[ref_collection_name] = std::move(r); previous_doc_references.clear(); } @@ -1879,41 +2008,317 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter if (!previous_doc_references.empty()) { auto& reference_result = filter_result.coll_to_references[filter_result.count - 1]; - auto r = reference_filter_result_t(previous_doc_references.size(), new uint32_t[previous_doc_references.size()]); + auto r = reference_filter_result_t(previous_doc_references.size(), + new uint32_t[previous_doc_references.size()], + false); std::copy(previous_doc_references.begin(), previous_doc_references.end(), r.docs); - reference_result[collection_name] = std::move(r); + reference_result[ref_collection_name] = std::move(r); } return Option(true); } - size_t ids_len = 0; - uint32_t *ids = nullptr; + // Multiple references per doc. + if (reference_index.count(reference_helper_field_name) == 0) { + return Option(400, "`" + reference_helper_field_name + "` is not present in reference index."); + } auto& ref_index = *reference_index.at(reference_helper_field_name); + + if (is_nested_join) { + // In case of nested join, we need to collect all the doc ids from the reference ids along with their references. + std::vector> id_pairs; + std::unordered_set unique_doc_ids; + + for (uint32_t i = 0; i < count; i++) { + auto& reference_doc_id = reference_docs[i]; + auto reference_doc_references = std::move(ref_filter_result->coll_to_references[i]); + size_t doc_ids_len = 0; + uint32_t* doc_ids = nullptr; + + ref_index.search(EQUALS, reference_doc_id, &doc_ids, doc_ids_len); + + for (size_t j = 0; j < doc_ids_len; j++) { + auto doc_id = doc_ids[j]; + auto reference_doc_references_copy = reference_doc_references; + id_pairs.emplace_back(std::make_pair(doc_id, new single_filter_result_t(reference_doc_id, + std::move(reference_doc_references_copy), + false))); + unique_doc_ids.insert(doc_id); + } + delete[] doc_ids; + } + + if (id_pairs.empty()) { + return Option(true); + } + + std::sort(id_pairs.begin(), id_pairs.end(), [](auto const& left, auto const& right) { + return left.first < right.first; + }); + + filter_result.count = unique_doc_ids.size(); + filter_result.docs = new uint32_t[unique_doc_ids.size()]; + filter_result.coll_to_references = new std::map[unique_doc_ids.size()] {}; + + reference_filter_result_t previous_doc_references; + for (uint32_t i = 0, previous_doc = id_pairs[0].first + 1, result_index = 0; i < id_pairs.size(); i++) { + auto const& current_doc = id_pairs[i].first; + auto& reference_result = id_pairs[i].second; + + if (current_doc != previous_doc) { + filter_result.docs[result_index] = current_doc; + if (result_index > 0) { + std::map references; + references[ref_collection_name] = std::move(previous_doc_references); + filter_result.coll_to_references[result_index - 1] = std::move(references); + } + + result_index++; + previous_doc = current_doc; + aggregate_nested_references(reference_result, previous_doc_references); + } else { + aggregate_nested_references(reference_result, previous_doc_references); + } + } + + if (previous_doc_references.count != 0) { + std::map references; + references[ref_collection_name] = std::move(previous_doc_references); + filter_result.coll_to_references[filter_result.count - 1] = std::move(references); + } + + for (auto &item: id_pairs) { + delete item.second; + } + + return Option(true); + } + + std::vector> id_pairs; + std::unordered_set unique_doc_ids; + for (uint32_t i = 0; i < count; i++) { auto& reference_doc_id = reference_docs[i]; - ref_index.search(EQUALS, reference_doc_id, &ids, ids_len); + size_t doc_ids_len = 0; + uint32_t* doc_ids = nullptr; + + ref_index.search(EQUALS, reference_doc_id, &doc_ids, doc_ids_len); + + for (size_t j = 0; j < doc_ids_len; j++) { + auto doc_id = doc_ids[j]; + id_pairs.emplace_back(std::make_pair(doc_id, reference_doc_id)); + unique_doc_ids.insert(doc_id); + } + delete[] doc_ids; } - filter_result.count = ids_len; - filter_result.docs = new uint32_t[ids_len]; - filter_result.coll_to_references = new std::map[ids_len] {}; - - auto& num_index = *numerical_index.at(reference_helper_field_name); - for (size_t i = 0; i < ids_len; i++) { - filter_result.docs[i] = ids[i]; - - reference_filter_result_t reference_result; - size_t len = 0; - num_index.search(EQUALS, ids[i], &reference_result.docs, len); - reference_result.count = len; - filter_result.coll_to_references[i][collection_name] = std::move(reference_result); + if (id_pairs.empty()) { + return Option(true); + } + + std::sort(id_pairs.begin(), id_pairs.end(), [](auto const& left, auto const& right) { + return left.first < right.first; + }); + + filter_result.count = unique_doc_ids.size(); + filter_result.docs = new uint32_t[unique_doc_ids.size()]; + filter_result.coll_to_references = new std::map[unique_doc_ids.size()] {}; + + std::vector previous_doc_references; + for (uint32_t i = 0, previous_doc = id_pairs[0].first + 1, result_index = 0; i < id_pairs.size(); i++) { + auto const& current_doc = id_pairs[i].first; + auto const& reference_doc_id = id_pairs[i].second; + + if (current_doc != previous_doc) { + filter_result.docs[result_index] = current_doc; + if (result_index > 0) { + auto& reference_result = filter_result.coll_to_references[result_index - 1]; + + auto r = reference_filter_result_t(previous_doc_references.size(), new uint32_t[previous_doc_references.size()]); + std::copy(previous_doc_references.begin(), previous_doc_references.end(), r.docs); + reference_result[ref_collection_name] = std::move(r); + + previous_doc_references.clear(); + } + + result_index++; + previous_doc = current_doc; + previous_doc_references.push_back(reference_doc_id); + } else { + previous_doc_references.push_back(reference_doc_id); + } + } + + if (!previous_doc_references.empty()) { + auto& reference_result = filter_result.coll_to_references[filter_result.count - 1]; + + auto r = reference_filter_result_t(previous_doc_references.size(), new uint32_t[previous_doc_references.size()]); + std::copy(previous_doc_references.begin(), previous_doc_references.end(), r.docs); + reference_result[ref_collection_name] = std::move(r); } - delete [] ids; return Option(true); } +Option Index::do_filtering_with_reference_ids(const std::string& reference_helper_field_name, + const std::string& ref_collection_name, + filter_result_t&& ref_filter_result) const { + filter_result_t filter_result; + auto const& count = ref_filter_result.count; + auto const& reference_docs = ref_filter_result.docs; + auto const is_nested_join = ref_filter_result.coll_to_references != nullptr; + + if (count == 0) { + return Option(filter_result); + } + + if (numerical_index.count(reference_helper_field_name) == 0) { + return Option(400, "`" + reference_helper_field_name + "` is not present in index."); + } + auto num_tree = numerical_index.at(reference_helper_field_name); + + if (is_nested_join) { + // In case of nested join, we need to collect all the doc ids from the reference ids along with their references. + std::vector> id_pairs; + std::unordered_set unique_doc_ids; + + for (uint32_t i = 0; i < count; i++) { + auto& reference_doc_id = reference_docs[i]; + auto reference_doc_references = std::move(ref_filter_result.coll_to_references[i]); + size_t doc_ids_len = 0; + uint32_t* doc_ids = nullptr; + + num_tree->search(NUM_COMPARATOR::EQUALS, reference_doc_id, &doc_ids, doc_ids_len); + + for (size_t j = 0; j < doc_ids_len; j++) { + auto doc_id = doc_ids[j]; + auto reference_doc_references_copy = reference_doc_references; + id_pairs.emplace_back(std::make_pair(doc_id, new single_filter_result_t(reference_doc_id, + std::move(reference_doc_references_copy), + false))); + unique_doc_ids.insert(doc_id); + } + + delete[] doc_ids; + } + + if (id_pairs.empty()) { + return Option(filter_result); + } + + std::sort(id_pairs.begin(), id_pairs.end(), [](auto const& left, auto const& right) { + return left.first < right.first; + }); + + filter_result.count = unique_doc_ids.size(); + filter_result.docs = new uint32_t[unique_doc_ids.size()]; + filter_result.coll_to_references = new std::map[unique_doc_ids.size()] {}; + + reference_filter_result_t previous_doc_references; + for (uint32_t i = 0, previous_doc = id_pairs[0].first + 1, result_index = 0; i < id_pairs.size(); i++) { + auto const& current_doc = id_pairs[i].first; + auto& reference_result = id_pairs[i].second; + + if (current_doc != previous_doc) { + filter_result.docs[result_index] = current_doc; + if (result_index > 0) { + std::map references; + references[ref_collection_name] = std::move(previous_doc_references); + filter_result.coll_to_references[result_index - 1] = std::move(references); + } + + result_index++; + previous_doc = current_doc; + aggregate_nested_references(reference_result, previous_doc_references); + } else { + aggregate_nested_references(reference_result, previous_doc_references); + } + } + + if (previous_doc_references.count != 0) { + std::map references; + references[ref_collection_name] = std::move(previous_doc_references); + filter_result.coll_to_references[filter_result.count - 1] = std::move(references); + } + + for (auto &item: id_pairs) { + delete item.second; + } + + return Option(filter_result); + } + + // Collect all the doc ids from the reference ids. + std::vector> id_pairs; + std::unordered_set unique_doc_ids; + + for (uint32_t i = 0; i < count; i++) { + auto& reference_doc_id = reference_docs[i]; + size_t doc_ids_len = 0; + uint32_t* doc_ids = nullptr; + + num_tree->search(NUM_COMPARATOR::EQUALS, reference_doc_id, &doc_ids, doc_ids_len); + + for (size_t j = 0; j < doc_ids_len; j++) { + auto doc_id = doc_ids[j]; + id_pairs.emplace_back(std::make_pair(doc_id, reference_doc_id)); + unique_doc_ids.insert(doc_id); + } + delete[] doc_ids; + } + + if (id_pairs.empty()) { + return Option(filter_result); + } + + std::sort(id_pairs.begin(), id_pairs.end(), [](auto const& left, auto const& right) { + return left.first < right.first; + }); + + filter_result.count = unique_doc_ids.size(); + filter_result.docs = new uint32_t[unique_doc_ids.size()]; + filter_result.coll_to_references = new std::map[unique_doc_ids.size()] {}; + + std::vector previous_doc_references; + for (uint32_t i = 0, previous_doc = id_pairs[0].first + 1, result_index = 0; i < id_pairs.size(); i++) { + auto const& current_doc = id_pairs[i].first; + auto const& reference_doc_id = id_pairs[i].second; + + if (current_doc != previous_doc) { + filter_result.docs[result_index] = current_doc; + if (result_index > 0) { + auto& reference_result = filter_result.coll_to_references[result_index - 1]; + + auto r = reference_filter_result_t(previous_doc_references.size(), + new uint32_t[previous_doc_references.size()], + false); + std::copy(previous_doc_references.begin(), previous_doc_references.end(), r.docs); + reference_result[ref_collection_name] = std::move(r); + + previous_doc_references.clear(); + } + + result_index++; + previous_doc = current_doc; + previous_doc_references.push_back(reference_doc_id); + } else { + previous_doc_references.push_back(reference_doc_id); + } + } + + if (!previous_doc_references.empty()) { + auto& reference_result = filter_result.coll_to_references[filter_result.count - 1]; + + auto r = reference_filter_result_t(previous_doc_references.size(), + new uint32_t[previous_doc_references.size()], + false); + std::copy(previous_doc_references.begin(), previous_doc_references.end(), r.docs); + reference_result[ref_collection_name] = std::move(r); + } + + return Option(filter_result); +} + Option Index::run_search(search_args* search_params, const std::string& collection_name, facet_index_type_t facet_index_type) { return search(search_params->field_query_tokens, diff --git a/src/string_utils.cpp b/src/string_utils.cpp index cd74e96a..781929ab 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -492,24 +492,24 @@ Option StringUtils::tokenize_filter_query(const std::string& filter_query, return Option(true); } -Option StringUtils::split_reference_include_fields(const std::string& include_fields, - size_t& index, - std::string& token) { - auto ref_include_error = Option(400, "Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`."); - auto const& size = include_fields.size(); +Option StringUtils::split_reference_include_exclude_fields(const std::string& include_exclude_fields, + size_t& index, std::string& token) { + auto ref_include_error = Option(400, "Invalid reference `" + include_exclude_fields + "` in include_fields/" + "exclude_fields, expected `$CollectionName(fieldA, ...)`."); + auto const& size = include_exclude_fields.size(); size_t start_index = index; - while(++index < size && include_fields[index] != '(') {} + while(++index < size && include_exclude_fields[index] != '(') {} if (index >= size) { return ref_include_error; } - // In case of nested join, the reference include field could have parenthesis inside it. + // In case of nested join, the reference include/exclude field could have parenthesis inside it. int parenthesis_count = 1; while (++index < size && parenthesis_count > 0) { - if (include_fields[index] == '(') { + if (include_exclude_fields[index] == '(') { parenthesis_count++; - } else if (include_fields[index] == ')') { + } else if (include_exclude_fields[index] == ')') { parenthesis_count--; } } @@ -525,34 +525,35 @@ Option StringUtils::split_reference_include_fields(const std::string& incl // ...^ // $ref_include( $nested_ref_include(foo :merge)as nest :merge ) as ref // ...^ - auto closing_parenthesis_pos = include_fields.find(')', index); - auto comma_pos = include_fields.find(',', index); - auto colon_pos = include_fields.find(':', index); - auto alias_start_pos = include_fields.find(" as ", index); + auto closing_parenthesis_pos = include_exclude_fields.find(')', index); + auto comma_pos = include_exclude_fields.find(',', index); + auto colon_pos = include_exclude_fields.find(':', index); + auto alias_start_pos = include_exclude_fields.find(" as ", index); auto alias_end_pos = std::min(std::min(closing_parenthesis_pos, comma_pos), colon_pos); std::string alias; if (alias_start_pos != std::string::npos && alias_start_pos < alias_end_pos) { - alias = include_fields.substr(alias_start_pos, alias_end_pos - alias_start_pos); + alias = include_exclude_fields.substr(alias_start_pos, alias_end_pos - alias_start_pos); } - token = include_fields.substr(start_index, index - start_index) + " " + trim(alias); + token = include_exclude_fields.substr(start_index, index - start_index) + " " + trim(alias); trim(token); index = alias_end_pos; return Option(true); } -Option StringUtils::split_include_fields(const std::string& include_fields, std::vector& tokens) { +Option StringUtils::split_include_exclude_fields(const std::string& include_exclude_fields, + std::vector& tokens) { std::string token; - auto const& size = include_fields.size(); + auto const& size = include_exclude_fields.size(); for (size_t i = 0; i < size;) { - auto c = include_fields[i]; + auto c = include_exclude_fields[i]; if (c == ' ') { i++; continue; - } else if (c == '$') { // Reference include + } else if (c == '$') { // Reference include/exclude std::string ref_include_token; - auto split_op = split_reference_include_fields(include_fields, i, ref_include_token); + auto split_op = split_reference_include_exclude_fields(include_exclude_fields, i, ref_include_token); if (!split_op.ok()) { return split_op; } @@ -561,8 +562,8 @@ Option StringUtils::split_include_fields(const std::string& include_fields continue; } - auto comma_pos = include_fields.find(',', i); - token = include_fields.substr(i, (comma_pos == std::string::npos ? size : comma_pos) - i); + auto comma_pos = include_exclude_fields.find(',', i); + token = include_exclude_fields.substr(i, (comma_pos == std::string::npos ? size : comma_pos) - i); trim(token); if (!token.empty()) { tokens.push_back(token); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index f254a00a..2fb3c103 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -1827,6 +1827,559 @@ TEST_F(CollectionJoinTest, FilterByNReferences) { collectionManager.drop_collection("Links"); } +TEST_F(CollectionJoinTest, FilterByNestedReferences) { + auto schema_json = + R"({ + "name": "Coll_A", + "fields": [ + {"name": "title", "type": "string"} + ] + })"_json; + std::vector documents = { + R"({ + "title": "coll_a_0" + })"_json, + R"({ + "title": "coll_a_1" + })"_json + }; + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + schema_json = + R"({ + "name": "Coll_B", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "ref_coll_a", "type": "string", "reference": "Coll_A.id"} + ] + })"_json; + documents = { + R"({ + "title": "coll_b_0", + "ref_coll_a": "1" + })"_json, + R"({ + "title": "coll_b_1", + "ref_coll_a": "0" + })"_json, + R"({ + "title": "coll_b_2", + "ref_coll_a": "0" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + schema_json = + R"({ + "name": "Coll_C", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "ref_coll_b", "type": "string[]", "reference": "Coll_B.id"} + ] + })"_json; + documents = { + R"({ + "title": "coll_c_0", + "ref_coll_b": ["0"] + })"_json, + R"({ + "title": "coll_c_1", + "ref_coll_b": ["1"] + })"_json, + R"({ + "title": "coll_c_2", + "ref_coll_b": ["0", "1"] + })"_json, + R"({ + "title": "coll_c_3", + "ref_coll_b": ["2"] + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + std::map req_params = { + {"collection", "Coll_A"}, + {"q", "*"}, + {"filter_by", "$Coll_B($Coll_C(id: [1, 3]))"}, + {"include_fields", "title, $Coll_B(title, $Coll_C(title))"} + }; + nlohmann::json embedded_params; + std::string json_res; + auto now_ts = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count(); + + auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + nlohmann::json res_obj = nlohmann::json::parse(json_res); + // coll_b_1 <- coll_c_1 + // coll_a_0 < + // coll_b_2 <- coll_c_3 + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("coll_a_0", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["Coll_B"].size()); + ASSERT_EQ("coll_b_1", res_obj["hits"][0]["document"]["Coll_B"][0]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_B"][0]["Coll_C"].size()); + ASSERT_EQ("coll_c_1", res_obj["hits"][0]["document"]["Coll_B"][0]["Coll_C"][0]["title"]); + ASSERT_EQ("coll_b_2", res_obj["hits"][0]["document"]["Coll_B"][1]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_B"][1]["Coll_C"].size()); + ASSERT_EQ("coll_c_3", res_obj["hits"][0]["document"]["Coll_B"][1]["Coll_C"][0]["title"]); + + req_params = { + {"collection", "Coll_A"}, + {"q", "*"}, + {"filter_by", "$Coll_B($Coll_C(id: != 0))"}, + {"include_fields", "title, $Coll_B(title, $Coll_C(title):nest_array)"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + // coll_a_1 <- coll_b_0 <- coll_c_2 + // + // coll_b_1 <- coll_c_1, coll_c_2 + // coll_a_0 < + // coll_b_2 <- coll_c_3 + ASSERT_EQ(2, res_obj["found"].get()); + ASSERT_EQ(2, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("coll_a_1", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_B"].size()); + ASSERT_EQ("coll_b_0", res_obj["hits"][0]["document"]["Coll_B"][0]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_B"][0]["Coll_C"].size()); + ASSERT_EQ("coll_c_2", res_obj["hits"][0]["document"]["Coll_B"][0]["Coll_C"][0]["title"]); + + ASSERT_EQ("coll_a_0", res_obj["hits"][1]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["Coll_B"].size()); + ASSERT_EQ("coll_b_1", res_obj["hits"][1]["document"]["Coll_B"][0]["title"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["Coll_B"][0]["Coll_C"].size()); + ASSERT_EQ("coll_c_1", res_obj["hits"][1]["document"]["Coll_B"][0]["Coll_C"][0]["title"]); + ASSERT_EQ("coll_c_2", res_obj["hits"][1]["document"]["Coll_B"][0]["Coll_C"][1]["title"]); + ASSERT_EQ("coll_b_2", res_obj["hits"][1]["document"]["Coll_B"][1]["title"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Coll_B"][1]["Coll_C"].size()); + ASSERT_EQ("coll_c_3", res_obj["hits"][1]["document"]["Coll_B"][1]["Coll_C"][0]["title"]); + + req_params = { + {"collection", "Coll_C"}, + {"q", "*"}, + {"filter_by", "$Coll_B($Coll_A(id: 0))"}, + {"include_fields", "title, $Coll_B(title, $Coll_A(title))"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + // coll_c_3 -> coll_b_2 -> coll_a_0 + // + // coll_c_2 -> coll_b_1 -> coll_a_0 + // + // coll_c_1 -> coll_b_1 -> coll_a_0 + ASSERT_EQ(3, res_obj["found"].get()); + ASSERT_EQ(3, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("coll_c_3", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["Coll_B"].size()); + ASSERT_EQ("coll_b_2", res_obj["hits"][0]["document"]["Coll_B"]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_B"]["Coll_A"].size()); + ASSERT_EQ("coll_a_0", res_obj["hits"][0]["document"]["Coll_B"]["Coll_A"]["title"]); + + ASSERT_EQ(2, res_obj["hits"][1]["document"].size()); + ASSERT_EQ("coll_c_2", res_obj["hits"][1]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["Coll_B"].size()); + ASSERT_EQ("coll_b_1", res_obj["hits"][1]["document"]["Coll_B"]["title"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Coll_B"]["Coll_A"].size()); + ASSERT_EQ("coll_a_0", res_obj["hits"][1]["document"]["Coll_B"]["Coll_A"]["title"]); + + ASSERT_EQ(2, res_obj["hits"][2]["document"].size()); + ASSERT_EQ("coll_c_1", res_obj["hits"][2]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][2]["document"]["Coll_B"].size()); + ASSERT_EQ("coll_b_1", res_obj["hits"][2]["document"]["Coll_B"]["title"]); + ASSERT_EQ(1, res_obj["hits"][2]["document"]["Coll_B"]["Coll_A"].size()); + ASSERT_EQ("coll_a_0", res_obj["hits"][2]["document"]["Coll_B"]["Coll_A"]["title"]); + + schema_json = + R"({ + "name": "Coll_D", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "ref_coll_c", "type": "string[]", "reference": "Coll_C.id"} + ] + })"_json; + documents = { + R"({ + "title": "coll_d_0", + "ref_coll_c": [] + })"_json, + R"({ + "title": "coll_d_1", + "ref_coll_c": ["1", "3"] + })"_json, + R"({ + "title": "coll_d_2", + "ref_coll_c": ["2", "3"] + })"_json, + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + req_params = { + {"collection", "Coll_B"}, + {"q", "*"}, + {"filter_by", "$Coll_C($Coll_D(id: *))"}, + {"include_fields", "title, $Coll_C(title, $Coll_D(title:nest_array):nest_array)"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + // coll_b_2 <- coll_c_3 <- coll_d_1, coll_d_2 + // + // coll_c_1 <- coll_d_1 + // coll_b_1 < + // coll_c_2 <- coll_d_2 + // + // coll_b_0 <- coll_c_2 <- coll_d_2 + ASSERT_EQ(3, res_obj["found"].get()); + ASSERT_EQ(3, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("coll_b_2", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_C"].size()); + ASSERT_EQ("coll_c_3", res_obj["hits"][0]["document"]["Coll_C"][0]["title"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["Coll_C"][0]["Coll_D"].size()); + ASSERT_EQ("coll_d_1", res_obj["hits"][0]["document"]["Coll_C"][0]["Coll_D"][0]["title"]); + ASSERT_EQ("coll_d_2", res_obj["hits"][0]["document"]["Coll_C"][0]["Coll_D"][1]["title"]); + + ASSERT_EQ(2, res_obj["hits"][1]["document"].size()); + ASSERT_EQ("coll_b_1", res_obj["hits"][1]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["Coll_C"].size()); + ASSERT_EQ("coll_c_1", res_obj["hits"][1]["document"]["Coll_C"][0]["title"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Coll_C"][0]["Coll_D"].size()); + ASSERT_EQ("coll_d_1", res_obj["hits"][1]["document"]["Coll_C"][0]["Coll_D"][0]["title"]); + ASSERT_EQ("coll_c_2", res_obj["hits"][1]["document"]["Coll_C"][1]["title"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Coll_C"][1]["Coll_D"].size()); + ASSERT_EQ("coll_d_2", res_obj["hits"][1]["document"]["Coll_C"][1]["Coll_D"][0]["title"]); + + ASSERT_EQ(2, res_obj["hits"][2]["document"].size()); + ASSERT_EQ("coll_b_0", res_obj["hits"][2]["document"]["title"]); + ASSERT_EQ(1, res_obj["hits"][2]["document"]["Coll_C"].size()); + ASSERT_EQ("coll_c_2", res_obj["hits"][2]["document"]["Coll_C"][0]["title"]); + ASSERT_EQ(1, res_obj["hits"][2]["document"]["Coll_C"][0]["Coll_D"].size()); + ASSERT_EQ("coll_d_2", res_obj["hits"][2]["document"]["Coll_C"][0]["Coll_D"][0]["title"]); + + req_params = { + {"collection", "Coll_D"}, + {"q", "*"}, + {"filter_by", "$Coll_C($Coll_B(id: [0, 1]))"}, + {"include_fields", "title, $Coll_C(title, $Coll_B(title:nest_array):nest_array)"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + LOG(INFO) << res_obj.dump(); + // coll_d_2 -> coll_c_2 -> coll_b_0, coll_b_1 + // + // coll_d_1 -> coll_c_1 -> coll_b_1 + ASSERT_EQ(2, res_obj["found"].get()); + ASSERT_EQ(2, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("coll_d_2", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Coll_C"].size()); + ASSERT_EQ("coll_c_2", res_obj["hits"][0]["document"]["Coll_C"][0]["title"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["Coll_C"][0]["Coll_B"].size()); + ASSERT_EQ("coll_b_0", res_obj["hits"][0]["document"]["Coll_C"][0]["Coll_B"][0]["title"]); + ASSERT_EQ("coll_b_1", res_obj["hits"][0]["document"]["Coll_C"][0]["Coll_B"][1]["title"]); + + ASSERT_EQ(2, res_obj["hits"][1]["document"].size()); + ASSERT_EQ("coll_d_1", res_obj["hits"][1]["document"]["title"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Coll_C"].size()); + ASSERT_EQ("coll_c_1", res_obj["hits"][1]["document"]["Coll_C"][0]["title"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Coll_C"][0]["Coll_B"].size()); + ASSERT_EQ("coll_b_1", res_obj["hits"][1]["document"]["Coll_C"][0]["Coll_B"][0]["title"]); + + schema_json = + R"({ + "name": "products", + "fields": [ + {"name": "title", "type": "string"} + ] + })"_json; + documents = { + R"({ + "title": "shampoo" + })"_json, + R"({ + "title": "soap" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "product_variants", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "product_id", "type": "string", "reference": "products.id"} + ] + })"_json; + documents = { + R"({ + "title": "panteen", + "product_id": "0" + })"_json, + R"({ + "title": "loreal", + "product_id": "0" + })"_json, + R"({ + "title": "pears", + "product_id": "1" + })"_json, + R"({ + "title": "lifebuoy", + "product_id": "1" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "retailers", + "fields": [ + {"name": "title", "type": "string"}, + {"name": "location", "type": "geopoint"} + ] + })"_json; + documents = { + R"({ + "title": "retailer 1", + "location": [48.872576479306765, 2.332291112241466] + })"_json, + R"({ + "title": "retailer 2", + "location": [48.888286721920934, 2.342340862419206] + })"_json, + R"({ + "title": "retailer 3", + "location": [48.87538726829884, 2.296113163780903] + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "inventory", + "fields": [ + {"name": "qty", "type": "int32"}, + {"name": "retailer_id", "type": "string", "reference": "retailers.id"}, + {"name": "product_variant_id", "type": "string", "reference": "product_variants.id"} + ] + })"_json; + documents = { + R"({ + "qty": "1", + "retailer_id": "0", + "product_variant_id": "0" + })"_json, + R"({ + "qty": "2", + "retailer_id": "0", + "product_variant_id": "1" + })"_json, + R"({ + "qty": "3", + "retailer_id": "0", + "product_variant_id": "2" + })"_json, + R"({ + "qty": "4", + "retailer_id": "0", + "product_variant_id": "3" + })"_json, + R"({ + "qty": "5", + "retailer_id": "1", + "product_variant_id": "0" + })"_json, + R"({ + "qty": "6", + "retailer_id": "1", + "product_variant_id": "1" + })"_json, + R"({ + "qty": "7", + "retailer_id": "1", + "product_variant_id": "2" + })"_json, + R"({ + "qty": "8", + "retailer_id": "1", + "product_variant_id": "3" + })"_json, + R"({ + "qty": "9", + "retailer_id": "2", + "product_variant_id": "0" + })"_json, + R"({ + "qty": "10", + "retailer_id": "2", + "product_variant_id": "1" + })"_json, + R"({ + "qty": "11", + "retailer_id": "2", + "product_variant_id": "2" + })"_json, + R"({ + "qty": "12", + "retailer_id": "2", + "product_variant_id": "3" + })"_json, + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + req_params = { + {"collection", "products"}, + {"q", "*"}, + {"filter_by", "$product_variants($inventory($retailers(location:(48.87538726829884, 2.296113163780903,1 km))))"}, + {"include_fields", "$product_variants(id,$inventory(qty,sku,$retailers(id,title)))"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(2, res_obj["found"].get()); + ASSERT_EQ(2, res_obj["hits"].size()); + ASSERT_EQ("1", res_obj["hits"][0]["document"]["id"]); + ASSERT_EQ("soap", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["product_variants"].size()); + + ASSERT_EQ("2", res_obj["hits"][0]["document"]["product_variants"][0]["id"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["product_variants"][0]["inventory"].size()); + ASSERT_EQ(11, res_obj["hits"][0]["document"]["product_variants"][0]["inventory"]["qty"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["product_variants"][0]["inventory"]["retailers"].size()); + ASSERT_EQ("2", res_obj["hits"][0]["document"]["product_variants"][0]["inventory"]["retailers"]["id"]); + ASSERT_EQ("retailer 3", res_obj["hits"][0]["document"]["product_variants"][0]["inventory"]["retailers"]["title"]); + + ASSERT_EQ("3", res_obj["hits"][0]["document"]["product_variants"][1]["id"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["product_variants"][1]["inventory"].size()); + ASSERT_EQ(12, res_obj["hits"][0]["document"]["product_variants"][1]["inventory"]["qty"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["product_variants"][1]["inventory"]["retailers"].size()); + ASSERT_EQ("2", res_obj["hits"][0]["document"]["product_variants"][1]["inventory"]["retailers"]["id"]); + ASSERT_EQ("retailer 3", res_obj["hits"][0]["document"]["product_variants"][1]["inventory"]["retailers"]["title"]); + + ASSERT_EQ("0", res_obj["hits"][1]["document"]["id"]); + ASSERT_EQ("shampoo", res_obj["hits"][1]["document"]["title"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["product_variants"].size()); + + ASSERT_EQ("0", res_obj["hits"][1]["document"]["product_variants"][0]["id"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["product_variants"][0]["inventory"].size()); + ASSERT_EQ(9, res_obj["hits"][1]["document"]["product_variants"][0]["inventory"]["qty"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["product_variants"][0]["inventory"]["retailers"].size()); + ASSERT_EQ("2", res_obj["hits"][1]["document"]["product_variants"][0]["inventory"]["retailers"]["id"]); + ASSERT_EQ("retailer 3", res_obj["hits"][1]["document"]["product_variants"][0]["inventory"]["retailers"]["title"]); + + ASSERT_EQ("1", res_obj["hits"][1]["document"]["product_variants"][1]["id"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["product_variants"][1]["inventory"].size()); + ASSERT_EQ(10, res_obj["hits"][1]["document"]["product_variants"][1]["inventory"]["qty"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["product_variants"][1]["inventory"]["retailers"].size()); + ASSERT_EQ("2", res_obj["hits"][1]["document"]["product_variants"][1]["inventory"]["retailers"]["id"]); + ASSERT_EQ("retailer 3", res_obj["hits"][1]["document"]["product_variants"][1]["inventory"]["retailers"]["title"]); + + req_params = { + {"collection", "products"}, + {"q", "*"}, + {"filter_by", "$product_variants($inventory($retailers(id: [0, 1]) && qty: [4..5]))"}, + {"include_fields", "$product_variants(id,$inventory(qty,sku,$retailers(id,title)))"}, + {"exclude_fields", "$product_variants($inventory($retailers(id)))"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(2, res_obj["found"].get()); + ASSERT_EQ(2, res_obj["hits"].size()); + ASSERT_EQ("1", res_obj["hits"][0]["document"]["id"]); + ASSERT_EQ("soap", res_obj["hits"][0]["document"]["title"]); + ASSERT_EQ("3", res_obj["hits"][0]["document"]["product_variants"]["id"]); + ASSERT_EQ(2, res_obj["hits"][0]["document"]["product_variants"]["inventory"].size()); + ASSERT_EQ(4, res_obj["hits"][0]["document"]["product_variants"]["inventory"]["qty"]); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["product_variants"]["inventory"]["retailers"].size()); + ASSERT_EQ("retailer 1", res_obj["hits"][0]["document"]["product_variants"]["inventory"]["retailers"]["title"]); + + ASSERT_EQ("0", res_obj["hits"][1]["document"]["id"]); + ASSERT_EQ("shampoo", res_obj["hits"][1]["document"]["title"]); + ASSERT_EQ("0", res_obj["hits"][1]["document"]["product_variants"]["id"]); + ASSERT_EQ(2, res_obj["hits"][1]["document"]["product_variants"]["inventory"].size()); + ASSERT_EQ(5, res_obj["hits"][1]["document"]["product_variants"]["inventory"]["qty"]); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["product_variants"]["inventory"]["retailers"].size()); + ASSERT_EQ("retailer 2", res_obj["hits"][1]["document"]["product_variants"]["inventory"]["retailers"]["title"]); +} + TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { auto schema_json = R"({ @@ -1926,12 +2479,14 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); ASSERT_FALSE(search_op.ok()); - ASSERT_EQ("Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`.", search_op.error()); + ASSERT_EQ("Invalid reference `$foo.bar` in include_fields/exclude_fields, expected `$CollectionName(fieldA, ...)`.", + search_op.error()); req_params["include_fields"] = "$foo(bar"; search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); ASSERT_FALSE(search_op.ok()); - ASSERT_EQ("Invalid reference in include_fields, expected `$CollectionName(fieldA, ...)`.", search_op.error()); + ASSERT_EQ("Invalid reference `$foo(bar` in include_fields/exclude_fields, expected `$CollectionName(fieldA, ...)`.", + search_op.error()); req_params["include_fields"] = "$foo(bar)"; search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 4a7da68a..29373f2a 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1460,156 +1460,241 @@ TEST_F(CollectionManagerTest, GetReferenceCollectionNames) { ref_includes = nullptr; } -TEST_F(CollectionManagerTest, InitializeRefIncludeFields) { +TEST_F(CollectionManagerTest, InitializeRefIncludeExcludeFields) { std::string filter_query = ""; - std::vector include_fields_vec; - std::vector ref_include_fields_vec; - auto initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + std::vector include_fields_vec, exclude_fields_vec; + std::vector ref_include_exclude_fields_vec; + auto initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_TRUE(ref_include_fields_vec.empty()); + ASSERT_TRUE(ref_include_exclude_fields_vec.empty()); filter_query = "$foo(bar:baz)"; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + exclude_fields_vec = {"$foo(bar)"}; + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("foo", ref_include_fields_vec[0].collection_name); - ASSERT_TRUE(ref_include_fields_vec[0].fields.empty()); - ASSERT_TRUE(ref_include_fields_vec[0].alias.empty()); - ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); - ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); - ref_include_fields_vec.clear(); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("foo", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].include_fields.empty()); + ASSERT_EQ("bar", ref_include_exclude_fields_vec[0].exclude_fields); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].nested_join_includes.empty()); + ref_include_exclude_fields_vec.clear(); + exclude_fields_vec.clear(); filter_query = ""; include_fields_vec = {"$Customers(product_price: foo) as customers"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_FALSE(initialize_op.ok()); ASSERT_EQ("Error parsing `$Customers(product_price: foo) as customers`: Unknown include strategy `foo`. " "Valid options are `merge`, `nest`, `nest_array`.", initialize_op.error()); filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; include_fields_vec = {"$Customers(product_price: merge) as customers"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); - ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); - ASSERT_EQ("customers.", ref_include_fields_vec[0].alias); - ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); - ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); - ref_include_fields_vec.clear(); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("Customers", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("product_price", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_EQ("customers.", ref_include_exclude_fields_vec[0].alias); + ASSERT_EQ(ref_include::merge, ref_include_exclude_fields_vec[0].strategy); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].nested_join_includes.empty()); + ref_include_exclude_fields_vec.clear(); filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; include_fields_vec = {"$Customers(product_price: nest_array) as customers"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); - ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); - ASSERT_EQ("customers", ref_include_fields_vec[0].alias); - ASSERT_EQ(ref_include::nest_array, ref_include_fields_vec[0].strategy); - ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); - ref_include_fields_vec.clear(); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("Customers", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("product_price", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_EQ("customers", ref_include_exclude_fields_vec[0].alias); + ASSERT_EQ(ref_include::nest_array, ref_include_exclude_fields_vec[0].strategy); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].nested_join_includes.empty()); + ref_include_exclude_fields_vec.clear(); filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; include_fields_vec = {"$product_variants(id,$inventory(qty,sku,$retailers(id,title)))"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); - ASSERT_EQ("id,", ref_include_fields_vec[0].fields); - ASSERT_TRUE(ref_include_fields_vec[0].alias.empty()); - ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("id,", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); - auto nested_join_includes = ref_include_fields_vec[0].nested_join_includes; - ASSERT_EQ("inventory", nested_join_includes[0].collection_name); - ASSERT_EQ("qty,sku,", nested_join_includes[0].fields); - ASSERT_TRUE(nested_join_includes[0].alias.empty()); - ASSERT_EQ(ref_include::nest, nested_join_includes[0].strategy); + auto nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_include_excludes[0].collection_name); + ASSERT_EQ("qty,sku,", nested_include_excludes[0].include_fields); + ASSERT_TRUE(nested_include_excludes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, nested_include_excludes[0].strategy); - nested_join_includes = ref_include_fields_vec[0].nested_join_includes[0].nested_join_includes; - ASSERT_EQ("retailers", nested_join_includes[0].collection_name); - ASSERT_EQ("id,title", nested_join_includes[0].fields); - ASSERT_TRUE(nested_join_includes[0].alias.empty()); - ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); - ref_include_fields_vec.clear(); + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes[0].nested_join_includes; + ASSERT_EQ("retailers", nested_include_excludes[0].collection_name); + ASSERT_EQ("id,title", nested_include_excludes[0].include_fields); + ASSERT_TRUE(nested_include_excludes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + ref_include_exclude_fields_vec.clear(); filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory: nest) as variants"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); - ASSERT_EQ("title,", ref_include_fields_vec[0].fields); - ASSERT_EQ("variants", ref_include_fields_vec[0].alias); - ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("title,", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_EQ("variants", ref_include_exclude_fields_vec[0].alias); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); - nested_join_includes = ref_include_fields_vec[0].nested_join_includes; - ASSERT_EQ("inventory", nested_join_includes[0].collection_name); - ASSERT_EQ("qty", nested_join_includes[0].fields); - ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_include_excludes[0].collection_name); + ASSERT_EQ("qty", nested_include_excludes[0].include_fields); + ASSERT_EQ("inventory.", nested_include_excludes[0].alias); + ASSERT_EQ(ref_include::merge, nested_include_excludes[0].strategy); - nested_join_includes = ref_include_fields_vec[0].nested_join_includes[0].nested_join_includes; - ASSERT_EQ("retailers", nested_join_includes[0].collection_name); - ASSERT_TRUE(nested_join_includes[0].fields.empty()); - ASSERT_TRUE(nested_join_includes[0].alias.empty()); - ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); - ref_include_fields_vec.clear(); + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes[0].nested_join_includes; + ASSERT_EQ("retailers", nested_include_excludes[0].collection_name); + ASSERT_TRUE(nested_include_excludes[0].include_fields.empty()); + ASSERT_TRUE(nested_include_excludes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + ref_include_exclude_fields_vec.clear(); filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory," " $retailers(title): merge) as variants"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); - ASSERT_EQ("title,", ref_include_fields_vec[0].fields); - ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); - ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("title,", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_EQ("variants.", ref_include_exclude_fields_vec[0].alias); + ASSERT_EQ(ref_include::merge, ref_include_exclude_fields_vec[0].strategy); - nested_join_includes = ref_include_fields_vec[0].nested_join_includes; - ASSERT_EQ("inventory", nested_join_includes[0].collection_name); - ASSERT_EQ("qty", nested_join_includes[0].fields); - ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_include_excludes[0].collection_name); + ASSERT_EQ("qty", nested_include_excludes[0].include_fields); + ASSERT_EQ("inventory.", nested_include_excludes[0].alias); + ASSERT_EQ(ref_include::merge, nested_include_excludes[0].strategy); - ASSERT_EQ("retailers", nested_join_includes[1].collection_name); - ASSERT_EQ("title", nested_join_includes[1].fields); - ASSERT_TRUE(nested_join_includes[1].alias.empty()); - ASSERT_EQ(ref_include::nest, nested_join_includes[1].strategy); - ref_include_fields_vec.clear(); + ASSERT_EQ("retailers", nested_include_excludes[1].collection_name); + ASSERT_EQ("title", nested_include_excludes[1].include_fields); + ASSERT_TRUE(nested_include_excludes[1].alias.empty()); + ASSERT_EQ(ref_include::nest, nested_include_excludes[1].strategy); + ref_include_exclude_fields_vec.clear(); filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory, description," " $retailers(title), foo: merge) as variants"}; - initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, - ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); ASSERT_TRUE(initialize_op.ok()); - ASSERT_EQ(1, ref_include_fields_vec.size()); - ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); - ASSERT_EQ("title, description, foo", ref_include_fields_vec[0].fields); - ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); - ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("title, description, foo", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_EQ("variants.", ref_include_exclude_fields_vec[0].alias); + ASSERT_EQ(ref_include::merge, ref_include_exclude_fields_vec[0].strategy); - nested_join_includes = ref_include_fields_vec[0].nested_join_includes; - ASSERT_EQ("inventory", nested_join_includes[0].collection_name); - ASSERT_EQ("qty", nested_join_includes[0].fields); - ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_include_excludes[0].collection_name); + ASSERT_EQ("qty", nested_include_excludes[0].include_fields); + ASSERT_EQ("inventory.", nested_include_excludes[0].alias); + ASSERT_EQ(ref_include::merge, nested_include_excludes[0].strategy); - ASSERT_EQ("retailers", nested_join_includes[1].collection_name); - ASSERT_EQ("title", nested_join_includes[1].fields); - ASSERT_TRUE(nested_join_includes[1].alias.empty()); - ASSERT_EQ(ref_include::nest, nested_join_includes[1].strategy); - ref_include_fields_vec.clear(); + ASSERT_EQ("retailers", nested_include_excludes[1].collection_name); + ASSERT_EQ("title", nested_include_excludes[1].include_fields); + ASSERT_TRUE(nested_include_excludes[1].alias.empty()); + ASSERT_EQ(ref_include::nest, nested_include_excludes[1].strategy); + ref_include_exclude_fields_vec.clear(); + + filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; + include_fields_vec.clear(); + exclude_fields_vec = {"$Customers(product_price)"}; + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); + ASSERT_TRUE(initialize_op.ok()); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("Customers", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].include_fields.empty()); + ASSERT_EQ("product_price", ref_include_exclude_fields_vec[0].exclude_fields); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].nested_join_includes.empty()); + ref_include_exclude_fields_vec.clear(); + + filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; + include_fields_vec.clear(); + exclude_fields_vec = {"$product_variants(title, $inventory(qty), description, $retailers(title), foo)"}; + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); + ASSERT_TRUE(initialize_op.ok()); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].include_fields.empty()); + ASSERT_EQ("title, description, foo", ref_include_exclude_fields_vec[0].exclude_fields); + ASSERT_TRUE(ref_include_exclude_fields_vec[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_include_excludes[0].collection_name); + ASSERT_TRUE(nested_include_excludes[0].include_fields.empty()); + ASSERT_EQ("qty", nested_include_excludes[0].exclude_fields); + ASSERT_TRUE(nested_include_excludes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, nested_include_excludes[0].strategy); + + ASSERT_EQ("retailers", nested_include_excludes[1].collection_name); + ASSERT_TRUE(nested_include_excludes[1].include_fields.empty()); + ASSERT_EQ("title", nested_include_excludes[1].exclude_fields); + ASSERT_TRUE(nested_include_excludes[1].alias.empty()); + ASSERT_EQ(ref_include::nest, nested_include_excludes[1].strategy); + ref_include_exclude_fields_vec.clear(); + + filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; + include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory: nest) as variants"}; + exclude_fields_vec = {"$product_variants(title, $inventory(qty, $retailers(title)))"}; + initialize_op = CollectionManager::_initialize_ref_include_exclude_fields_vec(filter_query, include_fields_vec, + exclude_fields_vec, + ref_include_exclude_fields_vec); + ASSERT_TRUE(initialize_op.ok()); + ASSERT_EQ(1, ref_include_exclude_fields_vec.size()); + ASSERT_EQ("product_variants", ref_include_exclude_fields_vec[0].collection_name); + ASSERT_EQ("title,", ref_include_exclude_fields_vec[0].include_fields); + ASSERT_EQ("title,", ref_include_exclude_fields_vec[0].exclude_fields); + ASSERT_EQ("variants", ref_include_exclude_fields_vec[0].alias); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes; + ASSERT_EQ("inventory", nested_include_excludes[0].collection_name); + ASSERT_EQ("qty", nested_include_excludes[0].include_fields); + ASSERT_EQ("qty,", nested_include_excludes[0].exclude_fields); + ASSERT_EQ("inventory.", nested_include_excludes[0].alias); + ASSERT_EQ(ref_include::merge, nested_include_excludes[0].strategy); + + nested_include_excludes = ref_include_exclude_fields_vec[0].nested_join_includes[0].nested_join_includes; + ASSERT_EQ("retailers", nested_include_excludes[0].collection_name); + ASSERT_TRUE(nested_include_excludes[0].include_fields.empty()); + ASSERT_EQ("title", nested_include_excludes[0].exclude_fields); + ASSERT_TRUE(nested_include_excludes[0].alias.empty()); + ASSERT_EQ(ref_include::nest, ref_include_exclude_fields_vec[0].strategy); + ref_include_exclude_fields_vec.clear(); } TEST_F(CollectionManagerTest, ReferencedInBacklog) { diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index 44e39bb9..1093343b 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -398,9 +398,9 @@ TEST(StringUtilsTest, TokenizeFilterQuery) { tokenizeTestHelper(filter_query, tokenList); } -void splitIncludeTestHelper(const std::string& include_fields, const std::vector& expected) { +void splitIncludeExcludeTestHelper(const std::string& include_exclude_fields, const std::vector& expected) { std::vector output; - auto tokenize_op = StringUtils::split_include_fields(include_fields, output); + auto tokenize_op = StringUtils::split_include_exclude_fields(include_exclude_fields, output); ASSERT_TRUE(tokenize_op.ok()); ASSERT_EQ(expected.size(), output.size()); for (auto i = 0; i < output.size(); i++) { @@ -408,51 +408,73 @@ void splitIncludeTestHelper(const std::string& include_fields, const std::vector } } -TEST(StringUtilsTest, SplitIncludeFields) { +TEST(StringUtilsTest, SplitIncludeExcludeFields) { std::string include_fields; std::vector tokens; include_fields = " id, title , count "; tokens = {"id", "title", "count"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); include_fields = "id, $Collection(title, pref*),count"; tokens = {"id", "$Collection(title, pref*)", "count"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); include_fields = "id, $Collection(title, pref*), count, "; tokens = {"id", "$Collection(title, pref*)", "count"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); include_fields = "$Collection(title, pref*) as coll"; tokens = {"$Collection(title, pref*) as coll"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); include_fields = "id, $Collection(title, pref*) as coll , count, "; tokens = {"id", "$Collection(title, pref*) as coll", "count"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); include_fields = "$Collection(title, pref*: merge) as coll"; tokens = {"$Collection(title, pref*: merge) as coll"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); include_fields = "$product_variants(id,$inventory(qty,sku,$retailer(id,title: merge) as retailer_info)) as variants"; tokens = {"$product_variants(id,$inventory(qty,sku,$retailer(id,title: merge) as retailer_info)) as variants"}; - splitIncludeTestHelper(include_fields, tokens); + splitIncludeExcludeTestHelper(include_fields, tokens); + + std::string exclude_fields = " id, title, $Collection(title), count,"; + tokens = {"id", "title", "$Collection(title)", "count"}; + splitIncludeExcludeTestHelper(exclude_fields, tokens); + + exclude_fields = " id, title , count, $Collection(title), $product_variants(id,$inventory(qty,sku,$retailer(id,title)))"; + tokens = {"id", "title", "count", "$Collection(title)", "$product_variants(id,$inventory(qty,sku,$retailer(id,title)))"}; + splitIncludeExcludeTestHelper(exclude_fields, tokens); } -TEST(StringUtilsTest, SplitReferenceIncludeFields) { +TEST(StringUtilsTest, SplitReferenceIncludeExcludeFields) { std::string include_fields = "$retailer(id,title: merge) as retailer_info:merge) as variants, foo", token; size_t index = 0; - auto tokenize_op = StringUtils::split_reference_include_fields(include_fields, index, token); + auto tokenize_op = StringUtils::split_reference_include_exclude_fields(include_fields, index, token); ASSERT_TRUE(tokenize_op.ok()); ASSERT_EQ("$retailer(id,title: merge) as retailer_info", token); ASSERT_EQ(":merge) as variants, foo", include_fields.substr(index)); include_fields = "$inventory(qty,sku,$retailer(id,title: merge) as retailer_info) as inventory) as variants, foo"; index = 0; - tokenize_op = StringUtils::split_reference_include_fields(include_fields, index, token); + tokenize_op = StringUtils::split_reference_include_exclude_fields(include_fields, index, token); ASSERT_TRUE(tokenize_op.ok()); ASSERT_EQ("$inventory(qty,sku,$retailer(id,title: merge) as retailer_info) as inventory", token); ASSERT_EQ(") as variants, foo", include_fields.substr(index)); + + std::string exclude_fields = "$Collection(title), $product_variants(id,$inventory(qty,sku,$retailer(id,title)))"; + index = 0; + tokenize_op = StringUtils::split_reference_include_exclude_fields(exclude_fields, index, token); + ASSERT_TRUE(tokenize_op.ok()); + ASSERT_EQ("$Collection(title)", token); + ASSERT_EQ(", $product_variants(id,$inventory(qty,sku,$retailer(id,title)))", exclude_fields.substr(index)); + + exclude_fields = "$inventory(qty,sku,$retailer(id,title)), foo)"; + index = 0; + tokenize_op = StringUtils::split_reference_include_exclude_fields(exclude_fields, index, token); + ASSERT_TRUE(tokenize_op.ok()); + ASSERT_EQ("$inventory(qty,sku,$retailer(id,title))", token); + ASSERT_EQ(", foo)", exclude_fields.substr(index)); }