diff --git a/include/field.h b/include/field.h index 8ebd55ba..09b77576 100644 --- a/include/field.h +++ b/include/field.h @@ -601,6 +601,7 @@ struct sort_by { eval_t eval; std::string reference_collection_name; + std::vector nested_join_collection_names; sort_vector_query_t vector_query; sort_by(const std::string & name, const std::string & order): @@ -636,10 +637,15 @@ struct sort_by { missing_values = other.missing_values; eval = other.eval; reference_collection_name = other.reference_collection_name; + nested_join_collection_names = other.nested_join_collection_names; vector_query = other.vector_query; } sort_by& operator=(const sort_by& other) { + if (&other == this) { + return *this; + } + name = other.name; eval_expressions = other.eval_expressions; order = other.order; @@ -651,8 +657,13 @@ struct sort_by { missing_values = other.missing_values; eval = other.eval; reference_collection_name = other.reference_collection_name; + nested_join_collection_names = other.nested_join_collection_names; return *this; } + + [[nodiscard]] inline bool is_nested_join_sort_by() const { + return !nested_join_collection_names.empty(); + } }; class GeoPoint { diff --git a/include/index.h b/include/index.h index 51144adc..ee415de2 100644 --- a/include/index.h +++ b/include/index.h @@ -1002,6 +1002,10 @@ public: bool enable_typos_for_numerical_tokens, bool enable_typos_for_alpha_numerical_tokens) const; + Option ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id, + bool& reference_found, const std::map, reference_filter_result_t>& references, + const std::string& collection_name) const; + Option compute_sort_scores(const std::vector& sort_fields, const int* sort_order, std::array*, 3> field_values, const std::vector& geopoint_indices, uint32_t seq_id, diff --git a/src/collection.cpp b/src/collection.cpp index 9af72dcc..282eca0b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1234,8 +1234,24 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries); + + std::vector nested_join_coll_names; + for (auto const& coll_name: _sort_field.nested_join_collection_names) { + auto coll = cm.get_collection(coll_name); + if (coll == nullptr) { + return Option(400, "Referenced collection `" + coll_name + "` in `sort_by` not found."); + } + // `CollectionManager::get_collection` accounts for collection alias being used and provides pointer to the + // original collection. + nested_join_coll_names.emplace_back(coll->name); + } + for (auto& ref_sort_field_std: ref_sort_fields_std) { ref_sort_field_std.reference_collection_name = ref_collection_name; + ref_sort_field_std.nested_join_collection_names.insert(ref_sort_field_std.nested_join_collection_names.begin(), + nested_join_coll_names.begin(), + nested_join_coll_names.end()); + sort_fields_std.emplace_back(ref_sort_field_std); } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 6d3fb453..5e189d46 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -894,9 +894,82 @@ bool parse_eval(const std::string& sort_by_str, uint32_t& index, std::vector& sort_fields) { + if (sort_by_str[index] != '$') { + return false; + } + + std::string sort_field_expr; + char prev_non_space_char = '`'; + + auto open_paren_pos = sort_by_str.find('(', index); + if (open_paren_pos == std::string::npos) { + return false; + } + + auto const& collection_name = sort_by_str.substr(index + 1, open_paren_pos - index - 1); + index = open_paren_pos; + int paren_count = 1; + while (++index < sort_by_str.size() && paren_count > 0) { + if (sort_by_str[index] == '(') { + paren_count++; + } else if (sort_by_str[index] == ')' && --paren_count == 0) { + break; + } + + if (sort_by_str[index] == '$' && (prev_non_space_char == '`' || prev_non_space_char == ',')) { + // Nested join sort_by + + // Process the sort fields provided up until now. + if (!sort_field_expr.empty()) { + sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", ""); + auto& collection_names = sort_fields.back().nested_join_collection_names; + collection_names.insert(collection_names.begin(), parent_coll_name); + collection_names.emplace_back(collection_name); + + sort_field_expr.clear(); + } + + auto prev_size = sort_fields.size(); + if (!parse_nested_join_sort_by_str(sort_by_str, index, collection_name, sort_fields)) { + return false; + } + + for (; prev_size < sort_fields.size(); prev_size++) { + auto& collection_names = sort_fields[prev_size].nested_join_collection_names; + collection_names.insert(collection_names.begin(), parent_coll_name); + } + + continue; + } + sort_field_expr += sort_by_str[index]; + if (sort_by_str[index] != ' ') { + prev_non_space_char = sort_by_str[index]; + } + } + if (paren_count != 0) { + return false; + } + + if (!sort_field_expr.empty()) { + sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", ""); + auto& collection_names = sort_fields.back().nested_join_collection_names; + collection_names.insert(collection_names.begin(), parent_coll_name); + collection_names.emplace_back(collection_name); + } + + // Skip the space in between the sort_by expressions. + while (index + 1 < sort_by_str.size() && (sort_by_str[index + 1] == ' ' || sort_by_str[index + 1] == ',')) { + index++; + } + + return true; +} + bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector& sort_fields) { std::string sort_field_expr; - char prev_non_space_char = 'a'; + char prev_non_space_char = '`'; for(uint32_t i=0; i < sort_by_str.size(); i++) { if (sort_field_expr.empty()) { @@ -906,27 +979,50 @@ bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector 0) { if (sort_by_str[i] == '(') { paren_count++; - } else if (sort_by_str[i] == ')') { - paren_count--; + } else if (sort_by_str[i] == ')' && --paren_count == 0) { + break; + } + + if (sort_by_str[i] == '$' && (prev_non_space_char == '`' || prev_non_space_char == ',')) { + // Nested join sort_by + + // Process the sort fields provided up until now. Doing this step to maintain the order of sort_by + // as specified. Eg, `$Customers(product_price:DESC, $foo(bar:asc))` should result into + // {`$Customers(product_price:DESC)`, `$Customers($foo(bar:asc))`} and not the other way around. + if (!sort_field_expr.empty()) { + sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", ""); + sort_field_expr.clear(); + } + + if (!parse_nested_join_sort_by_str(sort_by_str, i, collection_name, sort_fields)) { + return false; + } + + continue; } sort_field_expr += sort_by_str[i]; + if (sort_by_str[i] != ' ') { + prev_non_space_char = sort_by_str[i]; + } } if (paren_count != 0) { return false; } - sort_fields.emplace_back(sort_field_expr, ""); - sort_field_expr = ""; + if (!sort_field_expr.empty()) { + sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", ""); + sort_field_expr.clear(); + } // Skip the space in between the sort_by expressions. - while (i + 1 < sort_by_str.size() && sort_by_str[i + 1] == ' ') { + while (i + 1 < sort_by_str.size() && (sort_by_str[i + 1] == ' ' || sort_by_str[i + 1] == ',')) { i++; } continue; diff --git a/src/index.cpp b/src/index.cpp index 31cd4056..eb15d422 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4675,6 +4675,192 @@ Option Index::search_across_fields(const std::vector& query_token return Option(true); } +Option Index::ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id, + bool& reference_found, const std::map, reference_filter_result_t>& references, + const std::string& collection_name) const { + auto const& ref_collection_name = sort_field.reference_collection_name; + auto const& multiple_references_error_message = "Multiple references found to sort by on `" + + ref_collection_name + "." + sort_field.name + "`."; + auto const& no_references_error_message = "No references found to sort by on `" + + ref_collection_name + "." + sort_field.name + "`."; + + if (sort_field.is_nested_join_sort_by()) { + // Get the reference doc_id by following through all the nested join collections. + ref_seq_id = seq_id; + std::string prev_coll_name = collection_name; + for (const auto &coll_name: sort_field.nested_join_collection_names) { + // Joined on ref collection + if (references.count(coll_name) > 0) { + auto const& count = references.at(coll_name).count; + if (count == 0) { + reference_found = false; + break; + } else if (count == 1) { + ref_seq_id = references.at(coll_name).docs[0]; + } else { + return Option(400, multiple_references_error_message); + } + } else { + auto& cm = CollectionManager::get_instance(); + auto ref_collection = cm.get_collection(coll_name); + if (ref_collection == nullptr) { + return Option(400, "Referenced collection `" + coll_name + + "` in `sort_by` not found."); + } + + // Current collection has a reference. + if (ref_collection->is_referenced_in(prev_coll_name)) { + auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(prev_coll_name); + if (!get_reference_field_op.ok()) { + return Option(get_reference_field_op.code(), get_reference_field_op.error()); + } + auto const& field_name = get_reference_field_op.get(); + + auto prev_coll = cm.get_collection(prev_coll_name); + if (prev_coll == nullptr) { + return Option(400, "Referenced collection `" + prev_coll_name + + "` in `sort_by` not found."); + } + + auto sort_index_op = prev_coll->get_sort_index_value_with_lock(field_name, ref_seq_id); + if (!sort_index_op.ok()) { + if (sort_index_op.code() == 400) { + return Option(400, sort_index_op.error()); + } + reference_found = false; + break; + } else { + ref_seq_id = sort_index_op.get(); + } + } + // Joined collection has a reference + else { + std::string joined_coll_having_reference; + for (const auto &reference: references) { + if (ref_collection->is_referenced_in(reference.first)) { + joined_coll_having_reference = reference.first; + break; + } + } + + if (joined_coll_having_reference.empty()) { + return Option(400, no_references_error_message); + } + + auto joined_collection = cm.get_collection(joined_coll_having_reference); + if (joined_collection == nullptr) { + return Option(400, "Referenced collection `" + joined_coll_having_reference + + "` in `sort_by` not found."); + } + + auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); + if (!reference_field_name_op.ok()) { + return Option(reference_field_name_op.code(), reference_field_name_op.error()); + } + + auto const& reference_field_name = reference_field_name_op.get(); + auto const& reference = references.at(joined_coll_having_reference); + auto const& count = reference.count; + + if (count == 0) { + reference_found = false; + break; + } else if (count == 1) { + auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name, + reference.docs[0]); + if (!op.ok()) { + return Option(op.code(), op.error()); + } + + ref_seq_id = op.get(); + } else { + return Option(400, multiple_references_error_message); + } + } + } + + prev_coll_name = coll_name; + } + } else if (references.count(ref_collection_name) > 0) { // Joined on ref collection + auto const& count = references.at(ref_collection_name).count; + if (count == 0) { + reference_found = false; + } else if (count == 1) { + ref_seq_id = references.at(ref_collection_name).docs[0]; + } else { + return Option(400, multiple_references_error_message); + } + } else { + auto& cm = CollectionManager::get_instance(); + auto ref_collection = cm.get_collection(ref_collection_name); + if (ref_collection == nullptr) { + return Option(400, "Referenced collection `" + ref_collection_name + + "` in `sort_by` not found."); + } + + // Current collection has a reference. + if (ref_collection->is_referenced_in(collection_name)) { + auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name); + if (!get_reference_field_op.ok()) { + return Option(get_reference_field_op.code(), get_reference_field_op.error()); + } + auto const& field_name = get_reference_field_op.get(); + if (sort_index.count(field_name) == 0) { + return Option(400, "Could not find `" + field_name + "` in sort_index."); + } else if (sort_index.at(field_name)->count(seq_id) == 0) { + reference_found = false; + } else { + ref_seq_id = sort_index.at(field_name)->at(seq_id); + } + } + // Joined collection has a reference + else { + std::string joined_coll_having_reference; + for (const auto &reference: references) { + if (ref_collection->is_referenced_in(reference.first)) { + joined_coll_having_reference = reference.first; + break; + } + } + + if (joined_coll_having_reference.empty()) { + return Option(400, no_references_error_message); + } + + auto joined_collection = cm.get_collection(joined_coll_having_reference); + if (joined_collection == nullptr) { + return Option(400, "Referenced collection `" + joined_coll_having_reference + + "` in `sort_by` not found."); + } + + auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); + if (!reference_field_name_op.ok()) { + return Option(reference_field_name_op.code(), reference_field_name_op.error()); + } + + auto const& reference_field_name = reference_field_name_op.get(); + auto const& reference = references.at(joined_coll_having_reference); + auto const& count = reference.count; + + if (count == 0) { + reference_found = false; + } else if (count == 1) { + auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name, + reference.docs[0]); + if (!op.ok()) { + return Option(op.code(), op.error()); + } + + ref_seq_id = op.get(); + } else { + return Option(400, multiple_references_error_message); + } + } + } + + return Option(true); +} + Option Index::compute_sort_scores(const std::vector& sort_fields, const int* sort_order, std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -4744,89 +4930,10 @@ Option Index::compute_sort_scores(const std::vector& sort_fields, // In case of reference sort_by, we need to get the sort score of the reference doc id. if (!sort_fields[0].reference_collection_name.empty()) { - auto& sort_field = sort_fields[0]; - auto const& ref_collection_name = sort_field.reference_collection_name; - auto const& multiple_references_error_message = "Multiple references found to sort by on `" + - ref_collection_name + "." + sort_field.name + "`."; - auto const& no_references_error_message = "No references found to sort by on `" + - ref_collection_name + "." + sort_field.name + "`."; - - // Joined on ref collection - if (references.count(ref_collection_name) > 0) { - auto const& count = references.at(ref_collection_name).count; - if (count == 0) { - reference_found = false; - } else if (count == 1) { - ref_seq_id = references.at(ref_collection_name).docs[0]; - } else { - return Option(400, multiple_references_error_message); - } - } else { - auto& cm = CollectionManager::get_instance(); - auto ref_collection = cm.get_collection(ref_collection_name); - if (ref_collection == nullptr) { - return Option(400, "Referenced collection `" + ref_collection_name + - "` in `sort_by` not found."); - } - - // Current collection has a reference. - if (ref_collection->is_referenced_in(collection_name)) { - auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name); - if (!get_reference_field_op.ok()) { - return Option(get_reference_field_op.code(), get_reference_field_op.error()); - } - auto const& field_name = get_reference_field_op.get(); - if (sort_index.count(field_name) == 0) { - return Option(400, "Could not find `" + field_name + "` in sort_index."); - } else if (sort_index.at(field_name)->count(seq_id) == 0) { - reference_found = false; - } else { - ref_seq_id = sort_index.at(field_name)->at(seq_id); - } - } - // Joined collection has a reference - else { - std::string joined_coll_having_reference; - for (const auto &reference: references) { - if (ref_collection->is_referenced_in(reference.first)) { - joined_coll_having_reference = reference.first; - break; - } - } - - if (joined_coll_having_reference.empty()) { - return Option(400, no_references_error_message); - } - - auto joined_collection = cm.get_collection(joined_coll_having_reference); - if (joined_collection == nullptr) { - return Option(400, "Referenced collection `" + joined_coll_having_reference + - "` in `sort_by` not found."); - } - - auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); - if (!reference_field_name_op.ok()) { - return Option(reference_field_name_op.code(), reference_field_name_op.error()); - } - - auto const& reference_field_name = reference_field_name_op.get(); - auto const& reference = references.at(joined_coll_having_reference); - auto const& count = reference.count; - - if (count == 0) { - reference_found = false; - } else if (count == 1) { - auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name, - reference.docs[0]); - if (!op.ok()) { - return Option(op.code(), op.error()); - } - - ref_seq_id = op.get(); - } else { - return Option(400, multiple_references_error_message); - } - } + auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[0], seq_id, ref_seq_id, reference_found, + references, collection_name); + if (!ref_compute_op.ok()) { + return ref_compute_op; } } @@ -4932,89 +5039,10 @@ Option Index::compute_sort_scores(const std::vector& sort_fields, // In case of reference sort_by, we need to get the sort score of the reference doc id. if (!sort_fields[1].reference_collection_name.empty()) { - auto& sort_field = sort_fields[1]; - auto const& ref_collection_name = sort_field.reference_collection_name; - auto const& multiple_references_error_message = "Multiple references found to sort by on `" + - ref_collection_name + "." + sort_field.name + "`."; - auto const& no_references_error_message = "No references found to sort by on `" + - ref_collection_name + "." + sort_field.name + "`."; - - // Joined on ref collection - if (references.count(ref_collection_name) > 0) { - auto const& count = references.at(ref_collection_name).count; - if (count == 0) { - reference_found = false; - } else if (count == 1) { - ref_seq_id = references.at(ref_collection_name).docs[0]; - } else { - return Option(400, multiple_references_error_message); - } - } else { - auto& cm = CollectionManager::get_instance(); - auto ref_collection = cm.get_collection(ref_collection_name); - if (ref_collection == nullptr) { - return Option(400, "Referenced collection `" + ref_collection_name + - "` in `sort_by` not found."); - } - - // Current collection has a reference. - if (ref_collection->is_referenced_in(collection_name)) { - auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name); - if (!get_reference_field_op.ok()) { - return Option(get_reference_field_op.code(), get_reference_field_op.error()); - } - auto const& field_name = get_reference_field_op.get(); - if (sort_index.count(field_name) == 0) { - return Option(400, "Could not find `" + field_name + "` in sort_index."); - } else if (sort_index.at(field_name)->count(seq_id) == 0) { - reference_found = false; - } else { - ref_seq_id = sort_index.at(field_name)->at(seq_id); - } - } - // Joined collection has a reference - else { - std::string joined_coll_having_reference; - for (const auto &reference: references) { - if (ref_collection->is_referenced_in(reference.first)) { - joined_coll_having_reference = reference.first; - break; - } - } - - if (joined_coll_having_reference.empty()) { - return Option(400, no_references_error_message); - } - - auto joined_collection = cm.get_collection(joined_coll_having_reference); - if (joined_collection == nullptr) { - return Option(400, "Referenced collection `" + joined_coll_having_reference + - "` in `sort_by` not found."); - } - - auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); - if (!reference_field_name_op.ok()) { - return Option(reference_field_name_op.code(), reference_field_name_op.error()); - } - - auto const& reference_field_name = reference_field_name_op.get(); - auto const& reference = references.at(joined_coll_having_reference); - auto const& count = reference.count; - - if (count == 0) { - reference_found = false; - } else if (count == 1) { - auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name, - reference.docs[0]); - if (!op.ok()) { - return Option(op.code(), op.error()); - } - - ref_seq_id = op.get(); - } else { - return Option(400, multiple_references_error_message); - } - } + auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[1], seq_id, ref_seq_id, reference_found, + references, collection_name); + if (!ref_compute_op.ok()) { + return ref_compute_op; } } @@ -5118,89 +5146,10 @@ Option Index::compute_sort_scores(const std::vector& sort_fields, // In case of reference sort_by, we need to get the sort score of the reference doc id. if (!sort_fields[2].reference_collection_name.empty()) { - auto& sort_field = sort_fields[2]; - auto const& ref_collection_name = sort_field.reference_collection_name; - auto const& multiple_references_error_message = "Multiple references found to sort by on `" + - ref_collection_name + "." + sort_field.name + "`."; - auto const& no_references_error_message = "No references found to sort by on `" + - ref_collection_name + "." + sort_field.name + "`."; - - // Joined on ref collection - if (references.count(ref_collection_name) > 0) { - auto const& count = references.at(ref_collection_name).count; - if (count == 0) { - reference_found = false; - } else if (count == 1) { - ref_seq_id = references.at(ref_collection_name).docs[0]; - } else { - return Option(400, multiple_references_error_message); - } - } else { - auto& cm = CollectionManager::get_instance(); - auto ref_collection = cm.get_collection(ref_collection_name); - if (ref_collection == nullptr) { - return Option(400, "Referenced collection `" + ref_collection_name + - "` in `sort_by` not found."); - } - - // Current collection has a reference. - if (ref_collection->is_referenced_in(collection_name)) { - auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name); - if (!get_reference_field_op.ok()) { - return Option(get_reference_field_op.code(), get_reference_field_op.error()); - } - auto const& field_name = get_reference_field_op.get(); - if (sort_index.count(field_name) == 0) { - return Option(400, "Could not find `" + field_name + "` in sort_index."); - } else if (sort_index.at(field_name)->count(seq_id) == 0) { - reference_found = false; - } else { - ref_seq_id = sort_index.at(field_name)->at(seq_id); - } - } - // Joined collection has a reference - else { - std::string joined_coll_having_reference; - for (const auto &reference: references) { - if (ref_collection->is_referenced_in(reference.first)) { - joined_coll_having_reference = reference.first; - break; - } - } - - if (joined_coll_having_reference.empty()) { - return Option(400, no_references_error_message); - } - - auto joined_collection = cm.get_collection(joined_coll_having_reference); - if (joined_collection == nullptr) { - return Option(400, "Referenced collection `" + joined_coll_having_reference + - "` in `sort_by` not found."); - } - - auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference); - if (!reference_field_name_op.ok()) { - return Option(reference_field_name_op.code(), reference_field_name_op.error()); - } - - auto const& reference_field_name = reference_field_name_op.get(); - auto const& reference = references.at(joined_coll_having_reference); - auto const& count = reference.count; - - if (count == 0) { - reference_found = false; - } else if (count == 1) { - auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name, - reference.docs[0]); - if (!op.ok()) { - return Option(op.code(), op.error()); - } - - ref_seq_id = op.get(); - } else { - return Option(400, multiple_references_error_message); - } - } + auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[2], seq_id, ref_seq_id, reference_found, + references, collection_name); + if (!ref_compute_op.ok()) { + return ref_compute_op; } } @@ -7657,7 +7606,7 @@ Option Index::get_sort_index_value_with_lock(const std::string& collec return Option(400, "Cannot sort on `" + field_name + "` in the collection, `" + collection_name + "` is `" + search_schema.at(field_name).type + "`."); } else if (sort_index.count(field_name) == 0 || sort_index.at(field_name)->count(seq_id) == 0) { - return Option(400, "Could not find `" + field_name + "` value for doc `" + + return Option(404, "Could not find `" + field_name + "` value for doc `" + std::to_string(seq_id) + "`.");; } diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 01560769..5e7e6b9c 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -5544,29 +5544,6 @@ TEST_F(CollectionJoinTest, SortByReference) { ASSERT_FALSE(search_op.ok()); ASSERT_EQ("Multiple references found to sort by on `Customers.product_price`.", search_op.error()); - schema_json = - R"({ - "name": "Ads", - "fields": [ - {"name": "id", "type": "string"} - ] - })"_json; - documents = { - R"({ - "id": "ad_a" - })"_json - }; - collection_create_op = collectionManager.create_collection(schema_json); - ASSERT_TRUE(collection_create_op.ok()); - - for (auto const &json: documents) { - auto add_op = collection_create_op.get()->add(json.dump()); - if (!add_op.ok()) { - LOG(INFO) << add_op.error(); - } - ASSERT_TRUE(add_op.ok()); - } - schema_json = R"({ "name": "Structures", @@ -5596,6 +5573,35 @@ TEST_F(CollectionJoinTest, SortByReference) { ASSERT_TRUE(add_op.ok()); } + schema_json = + R"({ + "name": "Ads", + "fields": [ + {"name": "id", "type": "string"}, + {"name": "structure", "type": "string", "reference": "Structures.id"} + ] + })"_json; + documents = { + R"({ + "id": "ad_a", + "structure": "struct_b" + })"_json, + R"({ + "id": "ad_b", + "structure": "struct_a" + })"_json + }; + collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + + for (auto const &json: documents) { + auto add_op = collection_create_op.get()->add(json.dump()); + if (!add_op.ok()) { + LOG(INFO) << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); + } + schema_json = R"({ "name": "Candidates", @@ -5606,13 +5612,16 @@ TEST_F(CollectionJoinTest, SortByReference) { })"_json; documents = { R"({ - "structure": "struct_a" + "structure": "struct_b" })"_json, R"({ "ad": "ad_a" })"_json, R"({ - "structure": "struct_b" + "structure": "struct_a" + })"_json, + R"({ + "ad": "ad_b" })"_json }; collection_create_op = collectionManager.create_collection(schema_json); @@ -5630,25 +5639,67 @@ TEST_F(CollectionJoinTest, SortByReference) { {"collection", "Candidates"}, {"q", "*"}, {"filter_by", "$Ads(id:*) || $Structures(id:*)"}, - {"sort_by", "$Structures(name: asc)"} + {"sort_by", "$Structures(name: asc)"}, + {"include_fields", "$Ads($Structures(*))"} }; search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); ASSERT_TRUE(search_op.ok()); res_obj = nlohmann::json::parse(json_res); - ASSERT_EQ(3, res_obj["found"].get()); - ASSERT_EQ(3, res_obj["hits"].size()); - ASSERT_EQ("2", res_obj["hits"][0]["document"].at("id")); + ASSERT_EQ(4, res_obj["found"].get()); + ASSERT_EQ(4, res_obj["hits"].size()); + ASSERT_EQ("0", res_obj["hits"][0]["document"].at("id")); ASSERT_EQ("bar", res_obj["hits"][0]["document"]["Structures"].at("name")); ASSERT_EQ(0, res_obj["hits"][0]["document"].count("Ads")); - ASSERT_EQ("0", res_obj["hits"][1]["document"].at("id")); + ASSERT_EQ("2", res_obj["hits"][1]["document"].at("id")); ASSERT_EQ("foo", res_obj["hits"][1]["document"]["Structures"].at("name")); ASSERT_EQ(0, res_obj["hits"][1]["document"].count("Ads")); - ASSERT_EQ("1", res_obj["hits"][2]["document"].at("id")); + ASSERT_EQ("3", res_obj["hits"][2]["document"].at("id")); ASSERT_EQ(0, res_obj["hits"][2]["document"].count("Structures")); ASSERT_EQ(1, res_obj["hits"][2]["document"].count("Ads")); + ASSERT_EQ(1, res_obj["hits"][2]["document"]["Ads"].count("Structures")); + ASSERT_EQ("foo", res_obj["hits"][2]["document"]["Ads"]["Structures"]["name"]); + + ASSERT_EQ("1", res_obj["hits"][3]["document"].at("id")); + ASSERT_EQ(0, res_obj["hits"][3]["document"].count("Structures")); + ASSERT_EQ(1, res_obj["hits"][3]["document"].count("Ads")); + ASSERT_EQ(1, res_obj["hits"][3]["document"]["Ads"].count("Structures")); + ASSERT_EQ("bar", res_obj["hits"][3]["document"]["Ads"]["Structures"]["name"]); + + req_params = { + {"collection", "Candidates"}, + {"q", "*"}, + {"filter_by", "$Ads(id:*) || $Structures(id:*)"}, + {"sort_by", "$Ads($Structures(name: asc))"}, + {"include_fields", "$Ads($Structures(*))"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_TRUE(search_op.ok()); + + res_obj = nlohmann::json::parse(json_res); + ASSERT_EQ(4, res_obj["found"].get()); + ASSERT_EQ(4, res_obj["hits"].size()); + ASSERT_EQ("1", res_obj["hits"][0]["document"].at("id")); + ASSERT_EQ(0, res_obj["hits"][0]["document"].count("Structures")); + ASSERT_EQ(1, res_obj["hits"][0]["document"].count("Ads")); + ASSERT_EQ(1, res_obj["hits"][0]["document"]["Ads"].count("Structures")); + ASSERT_EQ("bar", res_obj["hits"][0]["document"]["Ads"]["Structures"]["name"]); + + ASSERT_EQ("3", res_obj["hits"][1]["document"].at("id")); + ASSERT_EQ(0, res_obj["hits"][1]["document"].count("Structures")); + ASSERT_EQ(1, res_obj["hits"][1]["document"].count("Ads")); + ASSERT_EQ(1, res_obj["hits"][1]["document"]["Ads"].count("Structures")); + ASSERT_EQ("foo", res_obj["hits"][1]["document"]["Ads"]["Structures"]["name"]); + + ASSERT_EQ("2", res_obj["hits"][2]["document"].at("id")); + ASSERT_EQ("foo", res_obj["hits"][2]["document"]["Structures"].at("name")); + ASSERT_EQ(0, res_obj["hits"][2]["document"].count("Ads")); + + ASSERT_EQ("0", res_obj["hits"][3]["document"].at("id")); + ASSERT_EQ("bar", res_obj["hits"][3]["document"]["Structures"].at("name")); + ASSERT_EQ(0, res_obj["hits"][3]["document"].count("Ads")); } TEST_F(CollectionJoinTest, FilterByReferenceAlias) { diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index ad985251..64861318 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -1309,6 +1309,45 @@ TEST_F(CollectionManagerTest, ParseSortByClause) { ASSERT_TRUE(sort_by_parsed); ASSERT_EQ("$foo( _eval(brand:nike && foo:bar):DESC,points:desc)", sort_fields[0].name); + sort_fields.clear(); + sort_by_parsed = CollectionManager::parse_sort_by_str("$Customers(product_price:DESC, $foo(bar:asc))", sort_fields); + ASSERT_TRUE(sort_by_parsed); + ASSERT_EQ(2, sort_fields.size()); + ASSERT_EQ("$Customers(product_price:DESC, )", sort_fields[0].name); + ASSERT_EQ("$foo(bar:asc)", sort_fields[1].name); + ASSERT_TRUE(sort_fields[1].is_nested_join_sort_by()); + ASSERT_EQ(2, sort_fields[1].nested_join_collection_names.size()); + ASSERT_EQ("Customers", sort_fields[1].nested_join_collection_names[0]); + ASSERT_EQ("foo", sort_fields[1].nested_join_collection_names[1]); + + sort_fields.clear(); + sort_by_parsed = CollectionManager::parse_sort_by_str("$foo($bar($baz(field:asc)))", sort_fields); + ASSERT_TRUE(sort_by_parsed); + ASSERT_EQ(1, sort_fields.size()); + ASSERT_EQ("$baz(field:asc)", sort_fields[0].name); + ASSERT_TRUE(sort_fields[0].is_nested_join_sort_by()); + ASSERT_EQ(3, sort_fields[0].nested_join_collection_names.size()); + ASSERT_EQ("foo", sort_fields[0].nested_join_collection_names[0]); + ASSERT_EQ("bar", sort_fields[0].nested_join_collection_names[1]); + ASSERT_EQ("baz", sort_fields[0].nested_join_collection_names[2]); + + sort_fields.clear(); + sort_by_parsed = CollectionManager::parse_sort_by_str("$Customers(product_price:DESC, $foo($bar( _eval(brand:nike && foo:bar):DESC), baz:asc))", sort_fields); + ASSERT_TRUE(sort_by_parsed); + ASSERT_EQ(3, sort_fields.size()); + ASSERT_EQ("$Customers(product_price:DESC, )", sort_fields[0].name); + ASSERT_EQ("$bar( _eval(brand:nike && foo:bar):DESC)", sort_fields[1].name); + ASSERT_TRUE(sort_fields[1].is_nested_join_sort_by()); + ASSERT_EQ(3, sort_fields[1].nested_join_collection_names.size()); + ASSERT_EQ("Customers", sort_fields[1].nested_join_collection_names[0]); + ASSERT_EQ("foo", sort_fields[1].nested_join_collection_names[1]); + ASSERT_EQ("bar", sort_fields[1].nested_join_collection_names[2]); + ASSERT_EQ("$foo(baz:asc)", sort_fields[2].name); + ASSERT_TRUE(sort_fields[2].is_nested_join_sort_by()); + ASSERT_EQ(2, sort_fields[2].nested_join_collection_names.size()); + ASSERT_EQ("Customers", sort_fields[2].nested_join_collection_names[0]); + ASSERT_EQ("foo", sort_fields[2].nested_join_collection_names[1]); + sort_fields.clear(); sort_by_parsed = CollectionManager::parse_sort_by_str("", sort_fields); ASSERT_TRUE(sort_by_parsed);