diff --git a/include/collection.h b/include/collection.h index 205737d7..bfc6d0cc 100644 --- a/include/collection.h +++ b/include/collection.h @@ -283,8 +283,6 @@ private: Option get_reference_doc_id(const std::string& ref_collection_name, const uint32_t& seq_id) const; - Option get_reference_field(const std::string& ref_collection_name) const; - static void hide_credential(nlohmann::json& json, const std::string& credential_name); public: @@ -380,6 +378,13 @@ public: static void remove_flat_fields(nlohmann::json& document); + static Option add_reference_fields(nlohmann::json& doc, + Collection *const ref_collection, + const reference_filter_result_t& references, + const tsl::htrie_set& ref_include_fields_full, + const tsl::htrie_set& ref_exclude_fields_full, + const std::string& error_prefix); + static Option prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, const std::string& parent_name = "", size_t depth = 0, @@ -489,7 +494,7 @@ public: Option get_reference_filter_ids(const std::string& filter_query, filter_result_t& filter_result, - const std::string& collection_name) const; + const std::string& reference_field_name) const; Option get(const std::string & id) const; @@ -580,6 +585,14 @@ public: std::array*, 3>& field_values) const; int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const; + + bool is_referenced_in(const std::string& collection_name) const; + + Option get_reference_field(const std::string& collection_name) const; + + Option get_sort_indexed_field_value(const std::string& field_name, const uint32_t& seq_id) const; + + friend class filter_result_iterator_t; }; template diff --git a/include/index.h b/include/index.h index b4b6da0c..0e45dffc 100644 --- a/include/index.h +++ b/include/index.h @@ -728,6 +728,8 @@ public: int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const; + Option get_sort_indexed_field_value(const std::string& field_name, const uint32_t& seq_id) const; + static void remove_matched_tokens(std::vector& tokens, const std::set& rule_token_set) ; void compute_facet_infos(const std::vector& facets, facet_query_t& facet_query, diff --git a/src/collection.cpp b/src/collection.cpp index 792ee826..0e17d4db 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1322,7 +1322,8 @@ Option Collection::extract_field_name(const std::string& field_name, std::string error = "No string or string array field found matching the pattern `" + field_name + "` in the schema."; return Option(404, error); } else if (!field_found) { - std::string error = is_wildcard ? "No field found matching the pattern `" : "Could not find a field named `" + field_name + "` in the schema."; + std::string error = is_wildcard ? "No field found matching the pattern `" : "Could not find a field named `" + + field_name + "` in the schema."; return Option(404, error); } @@ -2958,44 +2959,15 @@ Option Collection::get_filter_ids(const std::string& filter_query, filter_ return index->do_filtering_with_lock(filter_tree_root, filter_result, name); } -Option Collection::get_reference_doc_id(const std::string& ref_collection_name, const uint32_t& seq_id) const { - auto get_reference_field_op = get_reference_field(ref_collection_name); - if (!get_reference_field_op.ok()) { - return Option(get_reference_field_op.code(), get_reference_field_op.error()); - } - - auto field_name = get_reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX; - return index->get_reference_doc_id_with_lock(field_name, seq_id); -} - -Option Collection::get_reference_field(const std::string& ref_collection_name) const { - std::string reference_field_name; - for (auto const& pair: reference_fields) { - auto reference_pair = pair.second; - if (reference_pair.collection == ref_collection_name) { - reference_field_name = pair.first; - break; - } - } - - if (reference_field_name.empty()) { - return Option(400, "Could not find any field in `" + name + "` referencing the collection `" - + ref_collection_name + "`."); - } - - return Option(reference_field_name); +Option Collection::get_reference_doc_id(const std::string& ref_field_name, const uint32_t& seq_id) const { + return index->get_reference_doc_id_with_lock(ref_field_name, seq_id); } Option Collection::get_reference_filter_ids(const std::string & filter_query, filter_result_t& filter_result, - const std::string & collection_name) const { + const std::string& reference_field_name) const { std::shared_lock lock(mutex); - auto reference_field_op = get_reference_field(collection_name); - if (!reference_field_op.ok()) { - return Option(reference_field_op.code(), reference_field_op.error()); - } - const std::string doc_id_prefix = std::to_string(collection_id) + "_" + DOC_ID_PREFIX + "_"; filter_node_t* filter_tree_root = nullptr; Option parse_op = filter::parse_filter_query(filter_query, search_schema, @@ -3006,9 +2978,7 @@ Option Collection::get_reference_filter_ids(const std::string & filter_que return parse_op; } - // Reference helper field has the sequence id of other collection's documents. - auto field_name = reference_field_op.get() + REFERENCE_HELPER_FIELD_SUFFIX; - return index->do_reference_filtering_with_lock(filter_tree_root, filter_result, name, field_name); + return index->do_reference_filtering_with_lock(filter_tree_root, filter_result, name, reference_field_name); } bool Collection::facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, @@ -4311,6 +4281,55 @@ void Collection::remove_flat_fields(nlohmann::json& document) { } } +Option Collection::add_reference_fields(nlohmann::json& doc, + Collection *const ref_collection, + const reference_filter_result_t& references, + const tsl::htrie_set& ref_include_fields_full, + const tsl::htrie_set& ref_exclude_fields_full, + const std::string& error_prefix) { + // One-to-one relation. + if (references.count == 1) { + auto ref_doc_seq_id = references.docs[0]; + + nlohmann::json ref_doc; + auto get_doc_op = ref_collection->get_document_from_store(ref_doc_seq_id, ref_doc); + if (!get_doc_op.ok()) { + return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); + } + + auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); + if (!prune_op.ok()) { + return Option(prune_op.code(), error_prefix + prune_op.error()); + } + + doc.update(ref_doc); + return Option(true); + } + + // One-to-many relation. + for (uint32_t i = 0; i < references.count; i++) { + auto ref_doc_seq_id = references.docs[i]; + + nlohmann::json ref_doc; + auto get_doc_op = ref_collection->get_document_from_store(ref_doc_seq_id, ref_doc); + if (!get_doc_op.ok()) { + return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); + } + + auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); + if (!prune_op.ok()) { + return Option(prune_op.code(), error_prefix + prune_op.error()); + } + + for (auto ref_doc_it = ref_doc.begin(); ref_doc_it != ref_doc.end(); ref_doc_it++) { + // Add the values of ref_doc as JSON array into doc. + doc[ref_doc_it.key()] += ref_doc_it.value(); + } + } + + return Option(true); +} + Option Collection::prune_doc(nlohmann::json& doc, const tsl::htrie_set& include_names, const tsl::htrie_set& exclude_names, @@ -4406,6 +4425,32 @@ Option Collection::prune_doc(nlohmann::json& doc, return Option(400, "Referenced collection `" + ref_collection_name + "` in include_fields not found."); } + auto const joined_on_ref_collection = reference_filter_results.count(ref_collection_name) > 0, + has_filter_reference = (joined_on_ref_collection && + reference_filter_results.at(ref_collection_name).count > 0); + auto doc_has_reference = false, joined_coll_has_reference = false; + + // Reference include_by without join, check if doc itself contains the reference. + if (!joined_on_ref_collection && collection != nullptr) { + doc_has_reference = ref_collection->is_referenced_in(collection->name); + } + + std::string joined_coll_having_reference; + // Check if the joined collection has a reference. + if (!joined_on_ref_collection && !doc_has_reference) { + for (const auto &reference_filter_result: reference_filter_results) { + joined_coll_has_reference = ref_collection->is_referenced_in(reference_filter_result.first); + if (joined_coll_has_reference) { + joined_coll_having_reference = reference_filter_result.first; + break; + } + } + } + + if (!has_filter_reference && !doc_has_reference && !joined_coll_has_reference) { + continue; + } + std::vector ref_include_fields_vec, ref_exclude_fields_vec; StringUtils::split(reference_fields, ref_include_fields_vec, ","); auto exclude_reference_it = exclude_names.equal_prefix_range("$" + ref_collection_name); @@ -4430,78 +4475,60 @@ Option Collection::prune_doc(nlohmann::json& doc, return Option(include_exclude_op.code(), error_prefix + include_exclude_op.error()); } - bool has_filter_reference = reference_filter_results.count(ref_collection_name) > 0; - if (!has_filter_reference) { - if (collection == nullptr) { + Option add_reference_fields_op = Option(true); + if (has_filter_reference) { + add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), + reference_filter_results.at(ref_collection_name), + ref_include_fields_full, ref_exclude_fields_full, + error_prefix); + } else if (doc_has_reference) { + auto get_reference_field_op = ref_collection->get_reference_field(collection->name); + if (!get_reference_field_op.ok()) { continue; } - - // Reference include_by without join, check if doc itself contains the reference. - auto get_reference_doc_id_op = collection->get_reference_doc_id(ref_collection_name, seq_id); + auto const& field_name = get_reference_field_op.get(); + auto get_reference_doc_id_op = collection->get_reference_doc_id(field_name, seq_id); if (!get_reference_doc_id_op.ok()) { continue; } - auto ref_doc_seq_id = get_reference_doc_id_op.get(); - - nlohmann::json ref_doc; - auto get_doc_op = ref_collection->get_document_from_store(ref_doc_seq_id, ref_doc); - if (!get_doc_op.ok()) { - return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); + reference_filter_result_t r{1, new uint32[1]{get_reference_doc_id_op.get()}}; + add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), r, + ref_include_fields_full, ref_exclude_fields_full, + error_prefix); + } else if (joined_coll_has_reference) { + auto joined_collection = cm.get_collection(joined_coll_having_reference); + if (joined_collection == nullptr) { + continue; } - auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); - if (!prune_op.ok()) { - return Option(prune_op.code(), error_prefix + prune_op.error()); + auto reference_field_name_op = ref_collection->get_reference_field(joined_coll_having_reference); + if (!reference_field_name_op.ok()) { + continue; } - doc.update(ref_doc); - continue; + auto const& reference_field_name = reference_field_name_op.get(); + auto const& reference_filter_result = reference_filter_results.at(joined_coll_having_reference); + auto const& count = reference_filter_result.count; + reference_filter_result_t r{count, new uint32[count]}; + + for (uint32_t i = 0; i < count; i++) { + auto op = joined_collection->get_sort_indexed_field_value(reference_field_name, + reference_filter_result.docs[i]); + if (!op.ok()) { + return Option(op.code(), error_prefix + op.error()); + } + + r.docs[i] = op.get(); + } + + add_reference_fields_op = add_reference_fields(doc, ref_collection.get(), r, + ref_include_fields_full, ref_exclude_fields_full, + error_prefix); } - if (has_filter_reference && reference_filter_results.at(ref_collection_name).count == 0) { - continue; - } - - auto const& reference_filter_result = reference_filter_results.at(ref_collection_name); - // One-to-one relation. - if (reference_filter_result.count == 1) { - auto ref_doc_seq_id = reference_filter_result.docs[0]; - - nlohmann::json ref_doc; - auto get_doc_op = ref_collection->get_document_from_store(ref_doc_seq_id, ref_doc); - if (!get_doc_op.ok()) { - return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); - } - - auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); - if (!prune_op.ok()) { - return Option(prune_op.code(), error_prefix + prune_op.error()); - } - - doc.update(ref_doc); - continue; - } - - // One-to-many relation. - for (uint32_t i = 0; i < reference_filter_result.count; i++) { - auto ref_doc_seq_id = reference_filter_result.docs[i]; - - nlohmann::json ref_doc; - auto get_doc_op = ref_collection->get_document_from_store(ref_doc_seq_id, ref_doc); - if (!get_doc_op.ok()) { - return Option(get_doc_op.code(), error_prefix + get_doc_op.error()); - } - - auto prune_op = prune_doc(ref_doc, ref_include_fields_full, ref_exclude_fields_full); - if (!prune_op.ok()) { - return Option(prune_op.code(), error_prefix + prune_op.error()); - } - - for (auto ref_doc_it = ref_doc.begin(); ref_doc_it != ref_doc.end(); ref_doc_it++) { - // Add the values of ref_doc as JSON array into doc. - doc[ref_doc_it.key()] += ref_doc_it.value(); - } + if (!add_reference_fields_op.ok()) { + return add_reference_fields_op; } } @@ -5385,3 +5412,24 @@ int64_t Collection::reference_string_sort_score(const string &field_name, const std::shared_lock lock(mutex); return index->reference_string_sort_score(field_name, seq_id); } + +bool Collection::is_referenced_in(const std::string& collection_name) const { + std::shared_lock lock(mutex); + return referenced_in.count(collection_name) > 0; +} + +Option Collection::get_reference_field(const std::string& collection_name) const { + std::shared_lock lock(mutex); + + if (referenced_in.count(collection_name) == 0) { + return Option(400, "Could not find any field in `" + name + "` referencing the collection `" + + collection_name + "`."); + } + + return Option(referenced_in.at(collection_name)); +} + +Option Collection::get_sort_indexed_field_value(const std::string& field_name, const uint32_t& seq_id) const { + std::shared_lock lock(mutex); + return index->get_sort_indexed_field_value(field_name, seq_id); +} diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 229d0797..5c892c7f 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -592,16 +592,26 @@ void filter_result_iterator_t::init() { if (is_referenced_filter) { // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(a_filter.referenced_collection_name); - if (collection == nullptr) { - status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + auto const& ref_collection_name = a_filter.referenced_collection_name; + auto ref_collection = cm.get_collection(ref_collection_name); + if (ref_collection == nullptr) { + status = Option(400, "Referenced collection `" + ref_collection_name + "` not found."); is_valid = false; return; } - auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, - filter_result, - collection_name); + auto coll = cm.get_collection(collection_name); + if (coll->referenced_in.count(ref_collection_name) == 0 || coll->referenced_in.at(ref_collection_name).empty()) { + status = Option(400, "Could not find a reference to `" + collection_name + "` in `" + + ref_collection_name + "` collection."); + is_valid = false; + return; + } + + auto const& field_name = coll->referenced_in.at(ref_collection_name); + auto reference_filter_op = ref_collection->get_reference_filter_ids(a_filter.field_name, + filter_result, + field_name); if (!reference_filter_op.ok()) { status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name + "` collection: " + reference_filter_op.error()); diff --git a/src/index.cpp b/src/index.cpp index 02ca40f5..af6c62fd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6420,6 +6420,16 @@ int64_t Index::reference_string_sort_score(const string &field_name, const uint3 return str_sort_index.at(field_name)->rank(seq_id); } +Option Index::get_sort_indexed_field_value(const string& field_name, const uint32_t& seq_id) const { + std::shared_lock lock(mutex); + if (sort_index.count(field_name) == 0 || sort_index.at(field_name)->count(seq_id) == 0) { + return Option(400, "Could not find `" + field_name + "` value for doc `" + std::to_string(seq_id) + + "`."); + } + + return Option(sort_index.at(field_name)->at(seq_id)); +} + /* // https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon // NOTE: polygon and point should have been transformed with `transform_for_180th_meridian` diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 40c607ca..cffa7345 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -1044,7 +1044,7 @@ TEST_F(CollectionJoinTest, FilterByNReferences) { collectionManager.drop_collection("Links"); } -TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference_SingleMatch) { +TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) { auto schema_json = R"({ "name": "Products", @@ -1504,6 +1504,222 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference_SingleMatch) { ASSERT_EQ("soap", res_obj["hits"][0]["document"].at("product_name")); ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price")); ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price")); + + schema_json = + R"({ + "name": "Users", + "fields": [ + {"name": "user_id", "type": "string"}, + {"name": "user_name", "type": "string"} + ] + })"_json; + documents = { + R"({ + "user_id": "user_a", + "user_name": "Roshan" + })"_json, + R"({ + "user_id": "user_b", + "user_name": "Ruby" + })"_json, + R"({ + "user_id": "user_c", + "user_name": "Joe" + })"_json, + R"({ + "user_id": "user_d", + "user_name": "Aby" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Repos", + "fields": [ + {"name": "repo_id", "type": "string"}, + {"name": "repo_content", "type": "string"}, + {"name": "repo_stars", "type": "int32"}, + {"name": "repo_is_private", "type": "bool"} + ] + })"_json; + documents = { + R"({ + "repo_id": "repo_a", + "repo_content": "body1", + "repo_stars": 431, + "repo_is_private": true + })"_json, + R"({ + "repo_id": "repo_b", + "repo_content": "body2", + "repo_stars": 4562, + "repo_is_private": false + })"_json, + R"({ + "repo_id": "repo_c", + "repo_content": "body3", + "repo_stars": 945, + "repo_is_private": false + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Links", + "fields": [ + {"name": "repo_id", "type": "string", "reference": "Repos.repo_id"}, + {"name": "user_id", "type": "string", "reference": "Users.user_id"} + ] + })"_json; + documents = { + R"({ + "repo_id": "repo_a", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_a", + "user_id": "user_c" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_a" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_b", + "user_id": "user_d" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_a" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_b" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_c" + })"_json, + R"({ + "repo_id": "repo_c", + "user_id": "user_d" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Organizations", + "fields": [ + {"name": "org_id", "type": "string"}, + {"name": "org_name", "type": "string"} + ] + })"_json; + documents = { + R"({ + "org_id": "org_a", + "org_name": "Typesense" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + schema_json = + R"({ + "name": "Participants", + "fields": [ + {"name": "user_id", "type": "string", "reference": "Users.user_id"}, + {"name": "org_id", "type": "string", "reference": "Organizations.org_id"} + ] + })"_json; + documents = { + R"({ + "user_id": "user_a", + "org_id": "org_a" + })"_json, + R"({ + "user_id": "user_b", + "org_id": "org_a" + })"_json, + R"({ + "user_id": "user_d", + "org_id": "org_a" + })"_json, + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + + // Search for users within an organization with access to a particular repo. + req_params = { + {"collection", "Users"}, + {"q", "R"}, + {"query_by", "user_name"}, + {"filter_by", "$Participants(org_id:=org_a) && $Links(repo_id:=repo_b)"}, + {"include_fields", "user_id, user_name, $Repos(repo_content), $Organizations(org_name)"}, + {"exclude_fields", "$Participants(*), $Links(*), "} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(2, res_obj["found"].get()); + ASSERT_EQ(2, res_obj["hits"].size()); + ASSERT_EQ(4, res_obj["hits"][0]["document"].size()); + ASSERT_EQ("user_b", res_obj["hits"][0]["document"].at("user_id")); + ASSERT_EQ("Ruby", res_obj["hits"][0]["document"].at("user_name")); + ASSERT_EQ("body2", res_obj["hits"][0]["document"].at("repo_content")); + ASSERT_EQ("Typesense", res_obj["hits"][0]["document"].at("org_name")); + ASSERT_EQ("user_a", res_obj["hits"][1]["document"].at("user_id")); + ASSERT_EQ("Roshan", res_obj["hits"][1]["document"].at("user_name")); + ASSERT_EQ("body2", res_obj["hits"][1]["document"].at("repo_content")); + ASSERT_EQ("Typesense", res_obj["hits"][1]["document"].at("org_name")); } TEST_F(CollectionJoinTest, CascadeDeletion) {