diff --git a/include/collection.h b/include/collection.h index 9d2202bc..4f8cd435 100644 --- a/include/collection.h +++ b/include/collection.h @@ -386,6 +386,7 @@ public: static Option add_reference_fields(nlohmann::json& doc, Collection *const ref_collection, + const std::string& alias, const reference_filter_result_t& references, const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, @@ -396,7 +397,7 @@ public: size_t depth = 0, const std::map& reference_filter_results = {}, Collection *const collection = nullptr, const uint32_t& seq_id = 0, - const std::vector& ref_include_fields_vec = {}); + const std::vector& ref_include_fields_vec = {}); const Index* _get_index() const; @@ -496,7 +497,7 @@ public: const size_t remote_embedding_num_tries = 2, const std::string& stopwords_set="", const std::vector& facet_return_parent = {}, - const std::vector& ref_include_fields_vec = {}) const; + const std::vector& ref_include_fields_vec = {}) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/field.h b/include/field.h index aa54842f..e8fcd1e7 100644 --- a/include/field.h +++ b/include/field.h @@ -487,6 +487,11 @@ namespace sort_field_const { static const std::string vector_distance = "_vector_distance"; } +struct ref_include_fields { + std::string expression; + std::string alias; +}; + struct sort_by { enum missing_values_t { first, diff --git a/src/collection.cpp b/src/collection.cpp index 192fd036..7496d680 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1389,7 +1389,7 @@ Option Collection::search(std::string raw_query, const size_t remote_embedding_num_tries, const std::string& stopwords_set, const std::vector& facet_return_parent, - const std::vector& ref_include_fields_vec) const { + const std::vector& ref_include_fields_vec) const { std::shared_lock lock(mutex); @@ -4314,6 +4314,7 @@ void Collection::remove_flat_fields(nlohmann::json& document) { 
Option Collection::add_reference_fields(nlohmann::json& doc, Collection *const ref_collection, + const std::string& alias, const reference_filter_result_t& references, const tsl::htrie_set& ref_include_fields_full, const tsl::htrie_set& ref_exclude_fields_full, @@ -4328,11 +4329,21 @@ Option Collection::add_reference_fields(nlohmann::json& doc, return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); } + remove_flat_fields(ref_doc); + auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); if (!prune_op.ok()) { return Option(prune_op.code(), error_prefix + prune_op.error()); } + if (!alias.empty()) { + auto temp_doc = ref_doc; + ref_doc.clear(); + for (const auto &item: temp_doc.items()) { + ref_doc[alias + item.key()] = item.value(); + } + } + doc.update(ref_doc); return Option(true); } @@ -4347,11 +4358,21 @@ Option Collection::add_reference_fields(nlohmann::json& doc, return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); } + remove_flat_fields(ref_doc); + auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); if (!prune_op.ok()) { return Option(prune_op.code(), error_prefix + prune_op.error()); } + if (!alias.empty()) { + auto temp_doc = ref_doc; + ref_doc.clear(); + for (const auto &item: temp_doc.items()) { + ref_doc[alias + item.key()] = item.value(); + } + } + for (auto ref_doc_it = ref_doc.begin(); ref_doc_it != ref_doc.end(); ref_doc_it++) { // Add the values of ref_doc as JSON array into doc. 
doc[ref_doc_it.key()] += ref_doc_it.value(); @@ -4367,7 +4388,7 @@ Option Collection::prune_doc(nlohmann::json& doc, const std::string& parent_name, size_t depth, const std::map& reference_filter_results, Collection *const collection, const uint32_t& seq_id, - const std::vector& ref_includes) { + const std::vector& ref_includes) { // doc can only be an object auto it = doc.begin(); while(it != doc.end()) { @@ -4443,7 +4464,8 @@ Option Collection::prune_doc(nlohmann::json& doc, it++; } - for (auto const& ref: ref_includes) { + for (auto const& ref_include: ref_includes) { + auto const& ref = ref_include.expression; size_t parenthesis_index = ref.find('('); auto ref_collection_name = ref.substr(1, parenthesis_index - 1); @@ -4507,7 +4529,7 @@ Option Collection::prune_doc(nlohmann::json& doc, Option add_reference_fields_op = Option(true); if (has_filter_reference) { - add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), + add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), ref_include.alias, reference_filter_results.at(ref_collection_name), ref_include_fields_full, ref_exclude_fields_full, error_prefix); @@ -4523,7 +4545,7 @@ Option Collection::prune_doc(nlohmann::json& doc, } reference_filter_result_t r{1, new uint32[1]{get_reference_doc_id_op.get()}}; - add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), r, + add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), ref_include.alias, r, ref_include_fields_full, ref_exclude_fields_full, error_prefix); } else if (joined_coll_has_reference) { @@ -4552,7 +4574,7 @@ Option Collection::prune_doc(nlohmann::json& doc, r.docs[i] = op.get(); } - add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), r, + add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), ref_include.alias, r, ref_include_fields_full, ref_exclude_fields_full, error_prefix); } diff --git a/src/collection_manager.cpp 
b/src/collection_manager.cpp index 0c1793ea..f7727688 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -9,6 +9,7 @@ #include "logger.h" #include "magic_enum.hpp" #include "stopwords_manager.h" +#include "field.h" constexpr const size_t CollectionManager::DEFAULT_NUM_MEMORY_SHARDS; @@ -838,43 +839,50 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte // Separate out the reference includes into `ref_include_fields_vec`. void initialize_ref_include_fields_vec(const std::string& filter_query, std::vector& include_fields_vec, - std::vector& ref_include_fields_vec) { + std::vector& ref_include_fields_vec) { std::set reference_collection_names; CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names); std::vector result_include_fields_vec; auto wildcard_include_all = true; - for (auto include_field: include_fields_vec) { - if (include_field[0] != '$') { - if (include_field == "*") { + for (auto include_field_exp: include_fields_vec) { + if (include_field_exp[0] != '$') { + if (include_field_exp == "*") { continue; } wildcard_include_all = false; - result_include_fields_vec.emplace_back(include_field); + result_include_fields_vec.emplace_back(include_field_exp); continue; } - auto open_paren_pos = include_field.find('('); + auto as_pos = include_field_exp.find(" as "); + auto ref_include = include_field_exp.substr(0, as_pos), + alias = (as_pos == std::string::npos) ? "" : + include_field_exp.substr(as_pos + 4, include_field_exp.size() - (as_pos + 4)); + + // For an alias `foo`, we need append `foo.` to all the top level keys of reference doc. + ref_include_fields_vec.emplace_back(ref_include_fields{ref_include, alias.empty() ? 
alias : + StringUtils::trim(alias) + "."}); + + auto open_paren_pos = include_field_exp.find('('); if (open_paren_pos == std::string::npos) { continue; } - auto reference_collection_name = include_field.substr(1, open_paren_pos - 1); + auto reference_collection_name = include_field_exp.substr(1, open_paren_pos - 1); StringUtils::trim(reference_collection_name); if (reference_collection_name.empty()) { continue; } - ref_include_fields_vec.emplace_back(include_field); - // Referenced collection in filter_query is already mentioned in ref_include_fields. reference_collection_names.erase(reference_collection_name); } // Get all the fields of the referenced collection in the filter but not mentioned in include_fields. for (const auto &reference_collection_name: reference_collection_names) { - ref_include_fields_vec.emplace_back("$" + reference_collection_name + "(*)"); + ref_include_fields_vec.emplace_back(ref_include_fields{"$" + reference_collection_name + "(*)", ""}); } // Since no field of the collection is mentioned in include_fields, get all the fields. 
@@ -1048,7 +1056,7 @@ Option CollectionManager::do_search(std::map& re std::vector include_fields_vec; std::vector exclude_fields_vec; - std::vector ref_include_fields_vec; + std::vector ref_include_fields_vec; spp::sparse_hash_set include_fields; spp::sparse_hash_set exclude_fields; diff --git a/src/string_utils.cpp b/src/string_utils.cpp index 6b320fc2..a9296cd5 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -516,6 +516,14 @@ Option StringUtils::split_include_fields(const std::string& include_fields } include_field = include_fields.substr(range_pos, (end - range_pos) + 1); + + comma_pos = include_fields.find(',', end); + auto as_pos = include_fields.find(" as ", end); + if (as_pos != std::string::npos && as_pos < comma_pos) { + auto alias = include_fields.substr(as_pos, (comma_pos - as_pos)); + end += alias.size() + 1; + include_field += (" " + trim(alias)); + } } else { end = comma_pos; include_field = include_fields.substr(start, end - start); diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index e1a6e2e4..bc8dea8c 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -1511,6 +1511,26 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); + // Add alias using `as` + req_params = { + {"collection", "Customers"}, + {"q", "Joe"}, + {"query_by", "customer_name"}, + {"filter_by", "product_price:<100"}, + {"include_fields", "$Products(product_name) as p, product_price"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(1, res_obj["found"].get()); + ASSERT_EQ(1, res_obj["hits"].size()); + ASSERT_EQ(2, res_obj["hits"][0]["document"].size()); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("p.product_name")); + 
ASSERT_EQ("soap", res_obj["hits"][0]["document"].at("p.product_name")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); + ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); + schema_json = R"({ "name": "Users", @@ -1649,13 +1669,19 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { "name": "Organizations", "fields": [ {"name": "org_id", "type": "string"}, - {"name": "org_name", "type": "string"} - ] + {"name": "name", "type": "object"}, + {"name": "name.first", "type": "string"}, + {"name": "name.last", "type": "string"} + ], + "enable_nested_fields": true })"_json; documents = { R"({ "org_id": "org_a", - "org_name": "Typesense" + "name": { + "first": "type", + "last": "sense" + } })"_json }; collection_create_op = collectionManager.create_collection(schema_json); @@ -1708,7 +1734,7 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { {"q", "R"}, {"query_by", "user_name"}, {"filter_by", "$Participants(org_id:=org_a) && $Links(repo_id:=repo_b)"}, - {"include_fields", "user_id, user_name, $Repos(repo_content), $Organizations(org_name)"}, + {"include_fields", "user_id, user_name, $Repos(repo_content), $Organizations(name) as org"}, {"exclude_fields", "$Participants(*), $Links(*), "} }; search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); @@ -1718,14 +1744,18 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { ASSERT_EQ(2, res_obj["found"].get()); ASSERT_EQ(2, res_obj["hits"].size()); ASSERT_EQ(4, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("user_b", res_obj["hits"][0]["document"].at("user_id")); ASSERT_EQ("Ruby", res_obj["hits"][0]["document"].at("user_name")); ASSERT_EQ("body2", res_obj["hits"][0]["document"].at("repo_content")); - ASSERT_EQ("Typesense", res_obj["hits"][0]["document"].at("org_name")); + ASSERT_EQ("type", res_obj["hits"][0]["document"]["org.name"].at("first")); + ASSERT_EQ("sense", 
res_obj["hits"][0]["document"]["org.name"].at("last")); + ASSERT_EQ("user_a", res_obj["hits"][1]["document"].at("user_id")); ASSERT_EQ("Roshan", res_obj["hits"][1]["document"].at("user_name")); ASSERT_EQ("body2", res_obj["hits"][1]["document"].at("repo_content")); - ASSERT_EQ("Typesense", res_obj["hits"][1]["document"].at("org_name")); + ASSERT_EQ("type", res_obj["hits"][1]["document"]["org.name"].at("first")); + ASSERT_EQ("sense", res_obj["hits"][1]["document"]["org.name"].at("last")); } TEST_F(CollectionJoinTest, CascadeDeletion) { diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index 1ee9e6ac..bd33a169 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -419,4 +419,12 @@ TEST(StringUtilsTest, SplitIncludeFields) { include_fields = "id, $Collection(title, pref*), count, "; tokens = {"id", "$Collection(title, pref*)", "count"}; splitIncludeTestHelper(include_fields, tokens); + + include_fields = "$Collection(title, pref*) as coll"; + tokens = {"$Collection(title, pref*) as coll"}; + splitIncludeTestHelper(include_fields, tokens); + + include_fields = "id, $Collection(title, pref*) as coll , count, "; + tokens = {"id", "$Collection(title, pref*) as coll", "count"}; + splitIncludeTestHelper(include_fields, tokens); }