diff --git a/include/collection.h b/include/collection.h index bb0f581a..dfaa4f38 100644 --- a/include/collection.h +++ b/include/collection.h @@ -446,7 +446,7 @@ public: const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, const std::string& error_prefix, const bool& is_reference_array, - const bool& nest_ref_doc); + const ref_include::strategy_enum& strategy); static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, const std::string& parent_name = "", diff --git a/include/collection_manager.h b/include/collection_manager.h index 874b8055..889391f5 100644 --- a/include/collection_manager.h +++ b/include/collection_manager.h @@ -222,9 +222,9 @@ public: ref_include_collection_names_t*& reference_collection_names); // Separate out the reference includes into `ref_include_fields_vec`. - static void _initialize_ref_include_fields_vec(const std::string& filter_query, - std::vector& include_fields_vec, - std::vector& ref_include_fields_vec); + static Option _initialize_ref_include_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& ref_include_fields_vec); void add_referenced_in_backlog(const std::string& collection_name, reference_pair&& pair); diff --git a/include/field.h b/include/field.h index d563404d..b06e07c3 100644 --- a/include/field.h +++ b/include/field.h @@ -510,18 +510,34 @@ namespace sort_field_const { } namespace ref_include { - static const std::string merge = "merge"; - static const std::string nest = "nest"; + static const std::string merge_string = "merge"; + static const std::string nest_string = "nest"; + static const std::string nest_array_string = "nest_array"; + + enum strategy_enum {merge = 0, nest, nest_array}; + + static Option string_to_enum(const std::string& strategy) { + if (strategy == merge_string) { + return Option(merge); + } else if (strategy == nest_string) { + return Option(nest); + } else if (strategy == nest_array_string) { + return Option(nest_array); + } + + return Option(400, "Unknown include strategy `" + strategy + "`. " + "Valid options are `merge`, `nest`, `nest_array`."); + } } struct ref_include_fields { std::string collection_name; std::string fields; std::string alias; - bool nest_ref_doc = true; + ref_include::strategy_enum strategy = ref_include::nest; // In case we have nested join. - std::vector nested_join_includes; + std::vector nested_join_includes = {}; }; struct hnsw_index_t; diff --git a/src/collection.cpp b/src/collection.cpp index e2564a40..83131a6b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -4870,9 +4870,9 @@ Option Collection::include_references(nlohmann::json& doc, const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, const std::string& error_prefix, const bool& is_reference_array, - const bool& nest_ref_doc) { + const ref_include::strategy_enum& strategy) { // One-to-one relation. - if (!is_reference_array && references.count == 1) { + if (strategy != ref_include::nest_array && !is_reference_array && references.count == 1) { auto ref_doc_seq_id = references.docs[0]; nlohmann::json ref_doc; @@ -4893,7 +4893,7 @@ Option Collection::include_references(nlohmann::json& doc, return Option(true); } - if (nest_ref_doc) { + if (strategy == ref_include::nest) { auto key = alias.empty() ? ref_collection_name : alias; doc[key] = ref_doc; } else { @@ -4931,7 +4931,7 @@ Option Collection::include_references(nlohmann::json& doc, continue; } - if (nest_ref_doc) { + if (strategy == ref_include::nest || strategy == ref_include::nest_array) { auto key = alias.empty() ? ref_collection_name : alias; if (doc.contains(key) && !doc[key].is_array()) { return Option(400, "Could not include the reference document of `" + ref_collection_name + @@ -5116,7 +5116,7 @@ Option Collection::prune_doc(nlohmann::json& doc, reference_filter_results.at(ref_collection_name), ref_include_fields_full, ref_exclude_fields_full, error_prefix, ref_collection->get_schema().at(field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); } else if (doc_has_reference) { auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection->name); if (!get_reference_field_op.ok()) { @@ -5151,7 +5151,7 @@ Option Collection::prune_doc(nlohmann::json& doc, include_references_op = include_references(doc[keys[0]][i], ref_include.collection_name, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, - false, ref_include.nest_ref_doc); + false, ref_include.strategy); if (!include_references_op.ok()) { return include_references_op; } @@ -5167,7 +5167,7 @@ Option Collection::prune_doc(nlohmann::json& doc, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, collection->search_schema.at(field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); result.docs = nullptr; } } else { @@ -5181,7 +5181,7 @@ Option Collection::prune_doc(nlohmann::json& doc, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, collection->search_schema.at(field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); result.docs = nullptr; } } else if (joined_coll_has_reference) { @@ -5217,7 +5217,7 @@ Option Collection::prune_doc(nlohmann::json& doc, ref_collection.get(), ref_include.alias, result, ref_include_fields_full, ref_exclude_fields_full, error_prefix, joined_collection->get_schema().at(reference_field_name).is_array(), - ref_include.nest_ref_doc); + ref_include.strategy); result.docs = nullptr; } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 7f535745..8151c095 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1029,10 +1029,17 @@ Option parse_nested_include(const std::string& include_field_exp, } // ... $inventory(qty:merge) as inventory ... + auto include_strategy = ref_include::nest_string; + auto strategy_enum = ref_include::nest; if (colon_pos < closing_parenthesis_pos) { // Merge strategy is specified. - auto include_strategy = include_field_exp.substr(colon_pos + 1, closing_parenthesis_pos - colon_pos - 1); + include_strategy = include_field_exp.substr(colon_pos + 1, closing_parenthesis_pos - colon_pos - 1); StringUtils::trim(include_strategy); - nest_ref_doc = include_strategy == ref_include::nest; + + auto string_to_enum_op = ref_include::string_to_enum(include_strategy); + if (!string_to_enum_op.ok()) { + return Option(400, "Error parsing `" + include_field_exp + "`: " + string_to_enum_op.error()); + } + strategy_enum = string_to_enum_op.get(); if (index < colon_pos) { ref_fields += include_field_exp.substr(index, colon_pos - index); @@ -1051,9 +1058,11 @@ Option parse_nested_include(const std::string& include_field_exp, // For an alias `foo`, // In case of "merge" reference doc, we need append `foo.` to all the top level keys of reference doc. // In case of "nest" reference doc, `foo` becomes the key with reference doc as value. + nest_ref_doc = strategy_enum == ref_include::nest || strategy_enum == ref_include::nest_array; ref_alias = !ref_alias.empty() ? (StringUtils::trim(ref_alias) + (nest_ref_doc ? "" : ".")) : ""; - ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, nest_ref_doc}); + ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, + strategy_enum}); ref_include_fields_vec.back().nested_join_includes = std::move(nested_ref_include_fields_vec); // Referenced collection in filter_by is already mentioned in ref_include_fields. @@ -1069,16 +1078,16 @@ Option parse_nested_include(const std::string& include_field_exp, return Option(true); } -void CollectionManager::_initialize_ref_include_fields_vec(const std::string& filter_query, - std::vector& include_fields_vec, - std::vector& ref_include_fields_vec) { +Option CollectionManager::_initialize_ref_include_fields_vec(const std::string& filter_query, + std::vector& include_fields_vec, + std::vector& ref_include_fields_vec) { ref_include_collection_names_t* ref_include_coll_names = nullptr; CollectionManager::_get_reference_collection_names(filter_query, ref_include_coll_names); std::unique_ptr guard(ref_include_coll_names); std::vector result_include_fields_vec; auto wildcard_include_all = true; - for (auto include_field_exp: include_fields_vec) { + for (auto const& include_field_exp: include_fields_vec) { if (include_field_exp[0] != '$') { if (include_field_exp == "*") { continue; @@ -1092,6 +1101,9 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi // Nested reference include. if (include_field_exp.find('$', 1) != std::string::npos) { auto parse_op = parse_nested_include(include_field_exp, ref_include_coll_names, ref_include_fields_vec); + if (!parse_op.ok()) { + return parse_op; + } continue; } @@ -1105,20 +1117,29 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi auto ref_collection_name = ref_include.substr(1, parenthesis_index - 1); auto ref_fields = ref_include.substr(parenthesis_index + 1, ref_include.size() - parenthesis_index - 2); - auto nest_ref_doc = true; + auto include_strategy = ref_include::nest_string; + auto strategy_enum = ref_include::nest; auto colon_pos = ref_fields.find(':'); if (colon_pos != std::string::npos) { - auto include_strategy = ref_fields.substr(colon_pos + 1, ref_fields.size() - colon_pos - 1); + include_strategy = ref_fields.substr(colon_pos + 1, ref_fields.size() - colon_pos - 1); StringUtils::trim(include_strategy); - nest_ref_doc = include_strategy == ref_include::nest; + + auto string_to_enum_op = ref_include::string_to_enum(include_strategy); + if (!string_to_enum_op.ok()) { + return Option(400, "Error parsing `" + include_field_exp + "`: " + string_to_enum_op.error()); + } + strategy_enum = string_to_enum_op.get(); + ref_fields = ref_fields.substr(0, colon_pos); } // For an alias `foo`, // In case of "merge" reference doc, we need append `foo.` to all the top level keys of reference doc. // In case of "nest" reference doc, `foo` becomes the key with reference doc as value. + auto const& nest_ref_doc = strategy_enum == ref_include::nest || strategy_enum == ref_include::nest_array; auto ref_alias = !alias.empty() ? (StringUtils::trim(alias) + (nest_ref_doc ? "" : ".")) : ""; - ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, nest_ref_doc}); + ref_include_fields_vec.emplace_back(ref_include_fields{ref_collection_name, ref_fields, ref_alias, + strategy_enum}); auto open_paren_pos = include_field_exp.find('('); if (open_paren_pos == std::string::npos) { @@ -1141,7 +1162,7 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi auto ref_includes = std::ref(ref_include_fields_vec); while (ref_include_coll_names != nullptr) { for (const auto &reference_collection_name: ref_include_coll_names->collection_names) { - ref_includes.get().emplace_back(ref_include_fields{reference_collection_name, "", "", true}); + ref_includes.get().emplace_back(ref_include_fields{reference_collection_name, "", "", ref_include::nest}); } ref_include_coll_names = ref_include_coll_names->nested_include; @@ -1157,6 +1178,8 @@ void CollectionManager::_initialize_ref_include_fields_vec(const std::string& fi } include_fields_vec = std::move(result_include_fields_vec); + + return Option(true); } Option CollectionManager::do_search(std::map& req_params, @@ -1542,7 +1565,10 @@ Option CollectionManager::do_search(std::map& re per_page = 0; } - _initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + auto initialize_op = _initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + if (!initialize_op.ok()) { + return initialize_op; + } include_fields.insert(include_fields_vec.begin(), include_fields_vec.end()); exclude_fields.insert(exclude_fields_vec.begin(), exclude_fields_vec.end()); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index a7df7d27..f254a00a 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -1965,6 +1965,34 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"].count("product_id")); ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"].count("product_price")); + req_params = { + {"collection", "Products"}, + {"q", "*"}, + {"query_by", "product_name"}, + {"filter_by", "$Customers(customer_id:=customer_a && product_price:<100)"}, + {"include_fields", "*, $Customers(*:nest_array) as Customers"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + // No fields are mentioned in `include_fields`, should include all fields of Products and Customers by default. + ASSERT_EQ(7, res_obj["hits"][0]["document"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_description")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("embedding")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("rating")); + // In nest_array strategy we return the referenced docs in an array. + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("customer_id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("customer_name")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("product_id")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Customers"][0].count("product_price")); + req_params = { {"collection", "Products"}, {"q", "*"}, diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index e03c4e88..a4178dc2 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1464,94 +1464,127 @@ TEST_F(CollectionManagerTest, InitializeRefIncludeFields) { std::string filter_query = ""; std::vector include_fields_vec; std::vector ref_include_fields_vec; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + auto initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_TRUE(ref_include_fields_vec.empty()); filter_query = "$foo(bar:baz)"; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("foo", ref_include_fields_vec[0].collection_name); ASSERT_TRUE(ref_include_fields_vec[0].fields.empty()); ASSERT_TRUE(ref_include_fields_vec[0].alias.empty()); - ASSERT_TRUE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); ref_include_fields_vec.clear(); + filter_query = ""; + include_fields_vec = {"$Customers(product_price: foo) as customers"}; + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_FALSE(initialize_op.ok()); + ASSERT_EQ("Error parsing `$Customers(product_price: foo) as customers`: Unknown include strategy `foo`. " + "Valid options are `merge`, `nest`, `nest_array`.", initialize_op.error()); + filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; include_fields_vec = {"$Customers(product_price: merge) as customers"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); ASSERT_EQ("customers.", ref_include_fields_vec[0].alias); - ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); + ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); + ref_include_fields_vec.clear(); + + filter_query = "$Customers(customer_id:=customer_a && (product_price:>100 && product_price:<200))"; + include_fields_vec = {"$Customers(product_price: nest_array) as customers"}; + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); + ASSERT_EQ(1, ref_include_fields_vec.size()); + ASSERT_EQ("Customers", ref_include_fields_vec[0].collection_name); + ASSERT_EQ("product_price", ref_include_fields_vec[0].fields); + ASSERT_EQ("customers", ref_include_fields_vec[0].alias); + ASSERT_EQ(ref_include::nest_array, ref_include_fields_vec[0].strategy); ASSERT_TRUE(ref_include_fields_vec[0].nested_join_includes.empty()); ref_include_fields_vec.clear(); filter_query = "$product_variants( $inventory($retailers(location:(33.865,-118.375,100 km))))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory: nest) as variants"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); ASSERT_EQ("title,", ref_include_fields_vec[0].fields); ASSERT_EQ("variants", ref_include_fields_vec[0].alias); - ASSERT_TRUE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); auto nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); nested_join_includes = ref_include_fields_vec[0].nested_join_includes[0].nested_join_includes; ASSERT_EQ("retailers", nested_join_includes[0].collection_name); ASSERT_TRUE(nested_join_includes[0].fields.empty()); ASSERT_TRUE(nested_join_includes[0].alias.empty()); - ASSERT_TRUE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::nest, ref_include_fields_vec[0].strategy); ref_include_fields_vec.clear(); filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory," " $retailers(title): merge) as variants"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); ASSERT_EQ("title,", ref_include_fields_vec[0].fields); ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); - ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); ASSERT_EQ("retailers", nested_join_includes[1].collection_name); ASSERT_EQ("title", nested_join_includes[1].fields); ASSERT_TRUE(nested_join_includes[1].alias.empty()); - ASSERT_TRUE(nested_join_includes[1].nest_ref_doc); + ASSERT_EQ(ref_include::nest, nested_join_includes[1].strategy); ref_include_fields_vec.clear(); filter_query = "$product_variants( $inventory(id:*) && $retailers(location:(33.865,-118.375,100 km)))"; include_fields_vec = {"$product_variants(title, $inventory(qty:merge) as inventory, description," " $retailers(title), foo: merge) as variants"}; - CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec); + initialize_op = CollectionManager::_initialize_ref_include_fields_vec(filter_query, include_fields_vec, + ref_include_fields_vec); + ASSERT_TRUE(initialize_op.ok()); ASSERT_EQ(1, ref_include_fields_vec.size()); ASSERT_EQ("product_variants", ref_include_fields_vec[0].collection_name); ASSERT_EQ("title, description, foo", ref_include_fields_vec[0].fields); ASSERT_EQ("variants.", ref_include_fields_vec[0].alias); - ASSERT_FALSE(ref_include_fields_vec[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, ref_include_fields_vec[0].strategy); nested_join_includes = ref_include_fields_vec[0].nested_join_includes; ASSERT_EQ("inventory", nested_join_includes[0].collection_name); ASSERT_EQ("qty", nested_join_includes[0].fields); ASSERT_EQ("inventory.", nested_join_includes[0].alias); - ASSERT_FALSE(nested_join_includes[0].nest_ref_doc); + ASSERT_EQ(ref_include::merge, nested_join_includes[0].strategy); ASSERT_EQ("retailers", nested_join_includes[1].collection_name); ASSERT_EQ("title", nested_join_includes[1].fields); ASSERT_TRUE(nested_join_includes[1].alias.empty()); - ASSERT_TRUE(nested_join_includes[1].nest_ref_doc); + ASSERT_EQ(ref_include::nest, nested_join_includes[1].strategy); ref_include_fields_vec.clear(); }