diff --git a/include/collection.h b/include/collection.h
index 8d1cea01..c6592b77 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -215,8 +215,6 @@ private:
     void populate_text_match_info(nlohmann::json& info, uint64_t match_score,
                                   const text_match_type_t match_type) const;
 
-    static void remove_flat_fields(nlohmann::json& document);
-
     bool handle_highlight_text(std::string& text, bool normalise, const field &search_field,
                                const std::vector<char>& symbols_to_index, const std::vector<char>& token_separators,
                                highlight_t& highlight, StringUtils & string_utils, bool use_word_tokenizer,
@@ -251,6 +249,11 @@ private:
 
     static uint64_t extract_bits(uint64_t value, unsigned lsb_offset, unsigned n);
 
+    Option<bool> populate_include_exclude_fields(const spp::sparse_hash_set<std::string>& include_fields,
+                                                 const spp::sparse_hash_set<std::string>& exclude_fields,
+                                                 tsl::htrie_set<char>& include_fields_full,
+                                                 tsl::htrie_set<char>& exclude_fields_full) const;
+
 public:
 
     enum {MAX_ARRAY_MATCHES = 5};
@@ -337,6 +340,8 @@ public:
     Option<uint32_t> index_in_memory(nlohmann::json & document, uint32_t seq_id,
                                      const index_operation_t op, const DIRTY_VALUES& dirty_values);
 
+    static void remove_flat_fields(nlohmann::json& document);
+
     static void prune_doc(nlohmann::json& doc, const tsl::htrie_set<char>& include_names,
                           const tsl::htrie_set<char>& exclude_names,
                           const std::string& parent_name = "", size_t depth = 0);
@@ -377,6 +382,11 @@ public:
                           std::string& req_dirty_values,
                           const int batch_size = 1000);
 
+    Option<bool> populate_include_exclude_fields_lk(const spp::sparse_hash_set<std::string>& include_fields,
+                                                    const spp::sparse_hash_set<std::string>& exclude_fields,
+                                                    tsl::htrie_set<char>& include_fields_full,
+                                                    tsl::htrie_set<char>& exclude_fields_full) const;
+
     Option<nlohmann::json> search(const std::string & query, const std::vector<std::string> & search_fields,
                                   const std::string & filter_query, const std::vector<std::string> & facet_fields,
                                   const std::vector<sort_by> & sort_fields, const std::vector<uint32_t>& num_typos,
diff --git a/include/core_api_utils.h b/include/core_api_utils.h
index 4e057035..cc266576 100644
--- a/include/core_api_utils.h
+++ b/include/core_api_utils.h
@@ -22,8 +22,8 @@ struct export_state_t: public req_state_t {
     Collection* collection;
     std::vector<std::pair<size_t, uint32_t*>> index_ids;
     std::vector<size_t> offsets;
-    std::set<std::string> include_fields;
-    std::set<std::string> exclude_fields;
+    tsl::htrie_set<char> include_fields;
+    tsl::htrie_set<char> exclude_fields;
     size_t export_batch_size = 100;
     std::string* res_body;
 
diff --git a/src/collection.cpp b/src/collection.cpp
index df7cfd30..9f9540ee 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -1113,49 +1113,16 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
         }
     }
 
-    std::vector<std::string> include_fields_vec;
-    std::vector<std::string> exclude_fields_vec;
     tsl::htrie_set<char> include_fields_full;
     tsl::htrie_set<char> exclude_fields_full;
 
-    for(auto& f_name: include_fields) {
-        auto field_op = extract_field_name(f_name, search_schema, include_fields_vec, false, enable_nested_fields);
-        if(!field_op.ok()) {
-            if(field_op.code() == 404) {
-                // field need not be part of schema to be included (could be a stored value in the doc)
-                include_fields_vec.push_back(f_name);
-                continue;
-            }
-            return Option<nlohmann::json>(field_op.code(), field_op.error());
-        }
+    auto include_exclude_op = populate_include_exclude_fields(include_fields, exclude_fields,
+                                                              include_fields_full, exclude_fields_full);
+
+    if(!include_exclude_op.ok()) {
+        return Option<nlohmann::json>(include_exclude_op.code(), include_exclude_op.error());
     }
 
-    for(auto& f_name: exclude_fields) {
-        if(f_name == "out_of") {
-            // `out_of` is strictly a meta-field, but we handle it since it's useful
-            continue;
-        }
-
-        auto field_op = extract_field_name(f_name, search_schema, exclude_fields_vec, false, enable_nested_fields);
-        if(!field_op.ok()) {
-            if(field_op.code() == 404) {
-                // field need not be part of schema to be excluded (could be a stored value in the doc)
-                exclude_fields_vec.push_back(f_name);
-                continue;
-            }
-            return Option<nlohmann::json>(field_op.code(), field_op.error());
-        }
-    }
-
-    for(auto& f_name: include_fields_vec) {
-        include_fields_full.insert(f_name);
-    }
-
-    for(auto& f_name: exclude_fields_vec) {
-        exclude_fields_full.insert(f_name);
-    }
-
-    // process weights for search fields
     std::vector<std::string> reordered_search_fields;
     std::vector<search_field_t> weighted_search_fields;
 
@@ -4327,4 +4294,60 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector<facet>& facets) {
     }
 
     return Option<bool>(true);
+}
+
+Option<bool> Collection::populate_include_exclude_fields(const spp::sparse_hash_set<std::string>& include_fields,
+                                                         const spp::sparse_hash_set<std::string>& exclude_fields,
+                                                         tsl::htrie_set<char>& include_fields_full,
+                                                         tsl::htrie_set<char>& exclude_fields_full) const {
+
+    std::vector<std::string> include_fields_vec;
+    std::vector<std::string> exclude_fields_vec;
+
+    for(auto& f_name: include_fields) {
+        auto field_op = extract_field_name(f_name, search_schema, include_fields_vec, false, enable_nested_fields);
+        if(!field_op.ok()) {
+            if(field_op.code() == 404) {
+                // field need not be part of schema to be included (could be a stored value in the doc)
+                include_fields_vec.push_back(f_name);
+                continue;
+            }
+            return Option<bool>(field_op.code(), field_op.error());
+        }
+    }
+
+    for(auto& f_name: exclude_fields) {
+        if(f_name == "out_of") {
+            // `out_of` is strictly a meta-field, but we handle it since it's useful
+            continue;
+        }
+
+        auto field_op = extract_field_name(f_name, search_schema, exclude_fields_vec, false, enable_nested_fields);
+        if(!field_op.ok()) {
+            if(field_op.code() == 404) {
+                // field need not be part of schema to be excluded (could be a stored value in the doc)
+                exclude_fields_vec.push_back(f_name);
+                continue;
+            }
+            return Option<bool>(field_op.code(), field_op.error());
+        }
+    }
+
+    for(auto& f_name: include_fields_vec) {
+        include_fields_full.insert(f_name);
+    }
+
+    for(auto& f_name: exclude_fields_vec) {
+        exclude_fields_full.insert(f_name);
+    }
+
+    return Option<bool>(true);
+}
+
+Option<bool> Collection::populate_include_exclude_fields_lk(const spp::sparse_hash_set<std::string>& include_fields,
+                                                            const spp::sparse_hash_set<std::string>& exclude_fields,
+                                                            tsl::htrie_set<char>& include_fields_full,
+                                                            tsl::htrie_set<char>& exclude_fields_full) const {
+    std::shared_lock lock(mutex);
+    return populate_include_exclude_fields(include_fields, exclude_fields, include_fields_full, exclude_fields_full);
 }
\ No newline at end of file
diff --git a/src/core_api.cpp b/src/core_api.cpp
index 9ea14369..48839d3d 100644
--- a/src/core_api.cpp
+++ b/src/core_api.cpp
@@ -603,6 +603,8 @@ bool get_export_documents(const std::shared_ptr<http_req>& req, const std::share
     req->data = export_state;
 
     std::string simple_filter_query;
+    spp::sparse_hash_set<std::string> exclude_fields;
+    spp::sparse_hash_set<std::string> include_fields;
 
     if(req->params.count(FILTER_BY) != 0) {
         simple_filter_query = req->params[FILTER_BY];
@@ -611,15 +613,24 @@ bool get_export_documents(const std::shared_ptr<http_req>& req, const std::share
     if(req->params.count(INCLUDE_FIELDS) != 0) {
         std::vector<std::string> include_fields_vec;
         StringUtils::split(req->params[INCLUDE_FIELDS], include_fields_vec, ",");
-        export_state->include_fields = std::set<std::string>(include_fields_vec.begin(), include_fields_vec.end());
+        include_fields = spp::sparse_hash_set<std::string>(include_fields_vec.begin(), include_fields_vec.end());
     }
 
     if(req->params.count(EXCLUDE_FIELDS) != 0) {
         std::vector<std::string> exclude_fields_vec;
         StringUtils::split(req->params[EXCLUDE_FIELDS], exclude_fields_vec, ",");
-        export_state->exclude_fields = std::set<std::string>(exclude_fields_vec.begin(), exclude_fields_vec.end());
+        exclude_fields = spp::sparse_hash_set<std::string>(exclude_fields_vec.begin(), exclude_fields_vec.end());
     }
 
+    // report invalid include/exclude field names instead of silently ignoring them
+    auto include_exclude_op = collection->populate_include_exclude_fields_lk(include_fields, exclude_fields,
+                                                                             export_state->include_fields,
+                                                                             export_state->exclude_fields);
+    if(!include_exclude_op.ok()) {
+        res->set(include_exclude_op.code(), include_exclude_op.error());
+        return false;
+    }
+
     if(req->params.count(BATCH_SIZE) != 0 && StringUtils::is_uint32_t(req->params[BATCH_SIZE])) {
         export_state->export_batch_size = std::stoul(req->params[BATCH_SIZE]);
     }
@@ -659,20 +670,8 @@ bool get_export_documents(const std::shared_ptr<http_req>& req, const std::share
             res->body += it->value().ToString();
         } else {
             nlohmann::json doc = nlohmann::json::parse(it->value().ToString());
-            nlohmann::json filtered_doc;
-            for(const auto& kv: doc.items()) {
-                bool must_include = export_state->include_fields.empty() ||
-                                    (export_state->include_fields.count(kv.key()) != 0);
-
-                bool must_exclude = !export_state->exclude_fields.empty() &&
-                                    (export_state->exclude_fields.count(kv.key()) != 0);
-
-                if(must_include && !must_exclude) {
-                    filtered_doc[kv.key()] = kv.value();
-                }
-            }
-
-            res->body += filtered_doc.dump();
+            Collection::prune_doc(doc, export_state->include_fields, export_state->exclude_fields);
+            res->body += doc.dump();
         }
 
         it->Next();
diff --git a/src/core_api_utils.cpp b/src/core_api_utils.cpp
index 7c88adec..fda1777b 100644
--- a/src/core_api_utils.cpp
+++ b/src/core_api_utils.cpp
@@ -66,20 +66,9 @@ Option<bool> stateful_export_docs(export_state_t* export_state, size_t batch_siz
         if(export_state->include_fields.empty() && export_state->exclude_fields.empty()) {
             export_state->res_body->append(doc.dump());
         } else {
-            nlohmann::json filtered_doc;
-            for(const auto& kv: doc.items()) {
-                bool must_include = export_state->include_fields.empty() ||
-                                    (export_state->include_fields.count(kv.key()) != 0);
-
-                bool must_exclude = !export_state->exclude_fields.empty() &&
-                                    (export_state->exclude_fields.count(kv.key()) != 0);
-
-                if(must_include && !must_exclude) {
-                    filtered_doc[kv.key()] = kv.value();
-                }
-            }
-
-            export_state->res_body->append(filtered_doc.dump());
+            Collection::remove_flat_fields(doc);
+            Collection::prune_doc(doc, export_state->include_fields, export_state->exclude_fields);
+            export_state->res_body->append(doc.dump());
         }
 
         export_state->res_body->append("\n");
diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp
index d8d8787b..2545fe13 100644
--- a/test/core_api_utils_test.cpp
+++ b/test/core_api_utils_test.cpp
@@ -559,3 +559,124 @@ TEST_F(CoreAPIUtilsTest, ExportWithFilter) {
     ASSERT_TRUE(done);
     ASSERT_EQ('}', export_state.res_body->back());
 }
+
+TEST_F(CoreAPIUtilsTest, ExportIncludeExcludeFields) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "enable_nested_fields": true,
+        "fields": [
+            {"name": "name", "type": "object" },
+            {"name": "points", "type": "int32" }
+        ]
+    })"_json;
+
+    auto op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(op.ok());
+    Collection* coll1 = op.get();
+
+    auto doc1 = R"({
+        "name": {"first": "John", "last": "Smith"},
+        "points": 100
+    })"_json;
+
+    auto add_op = coll1->add(doc1.dump(), CREATE);
+    ASSERT_TRUE(add_op.ok());
+
+    std::shared_ptr<http_req> req = std::make_shared<http_req>();
+    std::shared_ptr<http_res> res = std::make_shared<http_res>(nullptr);
+    req->params["collection"] = "coll1";
+
+    // include fields
+
+    req->params["include_fields"] = "name.last";
+
+    get_export_documents(req, res);
+
+    std::vector<std::string> res_strs;
+    StringUtils::split(res->body, res_strs, "\n");
+    nlohmann::json doc = nlohmann::json::parse(res_strs[0]);
+    ASSERT_EQ(1, doc.size());
+    ASSERT_EQ(1, doc.count("name"));
+    ASSERT_EQ(1, doc["name"].count("last"));
+
+    // exclude fields
+
+    delete dynamic_cast<export_state_t*>(req->data);
+    req->data = nullptr;
+    res->body.clear();
+    req->params.erase("include_fields");
+    req->params["exclude_fields"] = "name.last";
+    get_export_documents(req, res);
+
+    res_strs.clear();
+    StringUtils::split(res->body, res_strs, "\n");
+    doc = nlohmann::json::parse(res_strs[0]);
+    ASSERT_EQ(3, doc.size());
+    ASSERT_EQ(1, doc.count("id"));
+    ASSERT_EQ(1, doc.count("points"));
+    ASSERT_EQ(1, doc.count("name"));
+    ASSERT_EQ(1, doc["name"].count("first"));
+
+    collectionManager.drop_collection("coll1");
+}
+
+TEST_F(CoreAPIUtilsTest, ExportIncludeExcludeFieldsWithFilter) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "enable_nested_fields": true,
+        "fields": [
+            {"name": "name", "type": "object" },
+            {"name": "points", "type": "int32" }
+        ]
+    })"_json;
+
+    auto op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(op.ok());
+    Collection* coll1 = op.get();
+
+    auto doc1 = R"({
+        "name": {"first": "John", "last": "Smith"},
+        "points": 100
+    })"_json;
+
+    auto add_op = coll1->add(doc1.dump(), CREATE);
+    ASSERT_TRUE(add_op.ok());
+
+    std::shared_ptr<http_req> req = std::make_shared<http_req>();
+    std::shared_ptr<http_res> res = std::make_shared<http_res>(nullptr);
+    req->params["collection"] = "coll1";
+
+    // include fields
+
+    req->params["include_fields"] = "name.last";
+    req->params["filter_by"] = "points:>=0";
+
+    get_export_documents(req, res);
+
+    std::vector<std::string> res_strs;
+    StringUtils::split(res->body, res_strs, "\n");
+    nlohmann::json doc = nlohmann::json::parse(res_strs[0]);
+    ASSERT_EQ(1, doc.size());
+    ASSERT_EQ(1, doc.count("name"));
+    ASSERT_EQ(1, doc["name"].count("last"));
+
+    // exclude fields
+
+    delete dynamic_cast<export_state_t*>(req->data);
+    req->data = nullptr;
+    res->body.clear();
+    req->params.erase("include_fields");
+    req->params["exclude_fields"] = "name.last";
+    get_export_documents(req, res);
+
+    res_strs.clear();
+    StringUtils::split(res->body, res_strs, "\n");
+    doc = nlohmann::json::parse(res_strs[0]);
+    ASSERT_EQ(3, doc.size());
+    ASSERT_EQ(1, doc.count("id"));
+    ASSERT_EQ(1, doc.count("points"));
+    ASSERT_EQ(1, doc.count("name"));
+    ASSERT_EQ(1, doc["name"].count("first"));
+
+    collectionManager.drop_collection("coll1");
+}
\ No newline at end of file