Fix "include_fields": "*, " only including indexed fields in the response.

This commit is contained in:
Harpreet Sangar 2023-09-13 13:09:59 +05:30
parent 270561e06b
commit 81af3643f7
5 changed files with 80 additions and 37 deletions

View File

@ -395,7 +395,8 @@ public:
const tsl::htrie_set<char>& exclude_names, const std::string& parent_name = "",
size_t depth = 0,
const std::map<std::string, reference_filter_result_t>& reference_filter_results = {},
Collection *const collection = nullptr, const uint32_t& seq_id = 0);
Collection *const collection = nullptr, const uint32_t& seq_id = 0,
const std::vector<std::string>& ref_include_fields_vec = {});
const Index* _get_index() const;
@ -494,7 +495,8 @@ public:
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_tries = 2,
const std::string& stopwords_set="",
const std::vector<std::string>& facet_return_parent = {}) const;
const std::vector<std::string>& facet_return_parent = {},
const std::vector<std::string>& ref_include_fields_vec = {}) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

View File

@ -1388,7 +1388,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_tries,
const std::string& stopwords_set,
const std::vector<std::string>& facet_return_parent) const {
const std::vector<std::string>& facet_return_parent,
const std::vector<std::string>& ref_include_fields_vec) const {
std::shared_lock lock(mutex);
@ -2214,7 +2215,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
"",
0,
field_order_kv->reference_filter_results,
const_cast<Collection *>(this), get_seq_id_from_key(seq_id_key));
const_cast<Collection *>(this), get_seq_id_from_key(seq_id_key),
ref_include_fields_vec);
if (!prune_op.ok()) {
return Option<nlohmann::json>(prune_op.code(), prune_op.error());
}
@ -4364,7 +4366,8 @@ Option<bool> Collection::prune_doc(nlohmann::json& doc,
const tsl::htrie_set<char>& exclude_names,
const std::string& parent_name, size_t depth,
const std::map<std::string, reference_filter_result_t>& reference_filter_results,
Collection *const collection, const uint32_t& seq_id) {
Collection *const collection, const uint32_t& seq_id,
const std::vector<std::string>& ref_includes) {
// doc can only be an object
auto it = doc.begin();
while(it != doc.end()) {
@ -4440,9 +4443,7 @@ Option<bool> Collection::prune_doc(nlohmann::json& doc,
it++;
}
auto include_reference_it_pair = include_names.equal_prefix_range("$");
for (auto reference = include_reference_it_pair.first; reference != include_reference_it_pair.second; reference++) {
auto ref = reference.key();
for (auto const& ref: ref_includes) {
size_t parenthesis_index = ref.find('(');
auto ref_collection_name = ref.substr(1, parenthesis_index - 1);
@ -4484,9 +4485,9 @@ Option<bool> Collection::prune_doc(nlohmann::json& doc,
StringUtils::split(reference_fields, ref_include_fields_vec, ",");
auto exclude_reference_it = exclude_names.equal_prefix_range("$" + ref_collection_name);
if (exclude_reference_it.first != exclude_reference_it.second) {
ref = exclude_reference_it.first.key();
parenthesis_index = ref.find('(');
reference_fields = ref.substr(parenthesis_index + 1, ref.size() - parenthesis_index - 2);
auto ref_exclude = exclude_reference_it.first.key();
parenthesis_index = ref_exclude.find('(');
reference_fields = ref_exclude.substr(parenthesis_index + 1, ref_exclude.size() - parenthesis_index - 2);
StringUtils::split(reference_fields, ref_exclude_fields_vec, ",");
}

View File

@ -836,21 +836,22 @@ void CollectionManager::_get_reference_collection_names(const std::string& filte
}
}
void initialize_include_fields_vec(const std::string& filter_query, std::vector<std::string>& include_fields_vec) {
if (filter_query.empty()) {
return;
}
// Separate out the reference includes into `ref_include_fields_vec`.
void initialize_ref_include_fields_vec(const std::string& filter_query, std::vector<std::string>& include_fields_vec,
std::vector<std::string>& ref_include_fields_vec) {
std::set<std::string> reference_collection_names;
CollectionManager::_get_reference_collection_names(filter_query, reference_collection_names);
if (reference_collection_names.empty()) {
return;
}
bool non_reference_include_found = false;
std::vector<std::string> result_include_fields_vec;
auto wildcard_include_all = true;
for (auto include_field: include_fields_vec) {
if (include_field[0] != '$') {
non_reference_include_found = true;
if (include_field == "*") {
continue;
}
wildcard_include_all = false;
result_include_fields_vec.emplace_back(include_field);
continue;
}
@ -865,19 +866,23 @@ void initialize_include_fields_vec(const std::string& filter_query, std::vector<
continue;
}
// Referenced collection in filter_query is already mentioned in include_fields.
ref_include_fields_vec.emplace_back(include_field);
// Referenced collection in filter_query is already mentioned in ref_include_fields.
reference_collection_names.erase(reference_collection_name);
}
// Get all the fields of the referenced collection in the filter but not mentioned in include_fields.
for (const auto &reference_collection_name: reference_collection_names) {
include_fields_vec.emplace_back("$" + reference_collection_name + "(*)");
ref_include_fields_vec.emplace_back("$" + reference_collection_name + "(*)");
}
// Since no field of the collection is mentioned in include_fields, get all the fields.
if (!include_fields_vec.empty() && !non_reference_include_found) {
include_fields_vec.emplace_back("*");
if (wildcard_include_all) {
result_include_fields_vec.clear();
}
include_fields_vec = std::move(result_include_fields_vec);
}
Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& req_params,
@ -1043,6 +1048,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
std::vector<std::string> include_fields_vec;
std::vector<std::string> exclude_fields_vec;
std::vector<std::string> ref_include_fields_vec;
spp::sparse_hash_set<std::string> include_fields;
spp::sparse_hash_set<std::string> exclude_fields;
@ -1235,7 +1241,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
per_page = 0;
}
initialize_include_fields_vec(filter_query, include_fields_vec);
initialize_ref_include_fields_vec(filter_query, include_fields_vec, ref_include_fields_vec);
include_fields.insert(include_fields_vec.begin(), include_fields_vec.end());
exclude_fields.insert(exclude_fields_vec.begin(), exclude_fields_vec.end());
@ -1326,7 +1332,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
remote_embedding_timeout_ms,
remote_embedding_num_tries,
stopwords_set,
facet_return_parent);
facet_return_parent,
ref_include_fields_vec);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin).count();

View File

@ -1167,7 +1167,8 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
// No fields are mentioned in `include_fields`, should include all fields of Products and Customers by default.
ASSERT_EQ(8, res_obj["hits"][0]["document"].size());
ASSERT_EQ(9, res_obj["hits"][0]["document"].size());
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("id"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_description"));
@ -1191,7 +1192,8 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
// No fields of Products collection are mentioned in `include_fields`, should include all of its fields by default.
ASSERT_EQ(4, res_obj["hits"][0]["document"].size());
ASSERT_EQ(5, res_obj["hits"][0]["document"].size());
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("id"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_description"));
@ -1210,7 +1212,7 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
ASSERT_EQ(5, res_obj["hits"][0]["document"].size());
ASSERT_EQ(6, res_obj["hits"][0]["document"].size());
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price"));
@ -1227,7 +1229,7 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
ASSERT_EQ(6, res_obj["hits"][0]["document"].size());
ASSERT_EQ(7, res_obj["hits"][0]["document"].size());
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
ASSERT_EQ(73.5, res_obj["hits"][0]["document"].at("product_price"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("customer_id"));
@ -1246,8 +1248,8 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
// 4 fields in Products document and 2 fields from Customers document
ASSERT_EQ(6, res_obj["hits"][0]["document"].size());
// 5 fields in Products document and 2 fields from Customers document
ASSERT_EQ(7, res_obj["hits"][0]["document"].size());
req_params = {
{"collection", "Products"},
@ -1262,8 +1264,9 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
// 4 fields in Products document and 2 fields from Customers document
ASSERT_EQ(6, res_obj["hits"][0]["document"].size());
// 5 fields in Products document and 2 fields from Customers document
ASSERT_EQ(7, res_obj["hits"][0]["document"].size());
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_price"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id_sequence_id"));
req_params = {
@ -1280,8 +1283,8 @@ TEST_F(CollectionJoinTest, IncludeExcludeFieldsByReference) {
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
// 4 fields in Products document and 1 fields from Customers document
ASSERT_EQ(5, res_obj["hits"][0]["document"].size());
// 5 fields in Products document and 1 fields from Customers document
ASSERT_EQ(6, res_obj["hits"][0]["document"].size());
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_id"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_name"));
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("product_description"));

View File

@ -2879,6 +2879,36 @@ TEST_F(CollectionSpecificTest, NonIndexField) {
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ(1, coll1->_get_index()->_get_search_index().size());
std::map<std::string, std::string> req_params = {
{"collection", "coll1"},
{"q", "*"},
{"include_fields", "*, "}
};
nlohmann::json embedded_params;
std::string json_res;
auto now_ts = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count();
collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
results = nlohmann::json::parse(json_res);
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ(3, results["hits"][0].at("document").size());
ASSERT_EQ(1, results["hits"][0].at("document").count("description"));
req_params = {
{"collection", "coll1"},
{"q", "*"},
{"include_fields", "*, title"} // Adding a field name overrides include all wildcard
};
collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
results = nlohmann::json::parse(json_res);
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ(1, results["hits"][0].at("document").size());
ASSERT_EQ(1, results["hits"][0].at("document").count("title"));
collectionManager.drop_collection("coll1");
}