mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 13:42:26 +08:00
sort_by
nested join collection field. (#1793)
* sort_by nested join collection field. * Refactor reference sort logic.
This commit is contained in:
parent
b92bba5f3f
commit
6cb6a07891
@ -601,6 +601,7 @@ struct sort_by {
|
||||
eval_t eval;
|
||||
|
||||
std::string reference_collection_name;
|
||||
std::vector<std::string> nested_join_collection_names;
|
||||
sort_vector_query_t vector_query;
|
||||
|
||||
sort_by(const std::string & name, const std::string & order):
|
||||
@ -636,10 +637,15 @@ struct sort_by {
|
||||
missing_values = other.missing_values;
|
||||
eval = other.eval;
|
||||
reference_collection_name = other.reference_collection_name;
|
||||
nested_join_collection_names = other.nested_join_collection_names;
|
||||
vector_query = other.vector_query;
|
||||
}
|
||||
|
||||
sort_by& operator=(const sort_by& other) {
|
||||
if (&other == this) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
name = other.name;
|
||||
eval_expressions = other.eval_expressions;
|
||||
order = other.order;
|
||||
@ -651,8 +657,13 @@ struct sort_by {
|
||||
missing_values = other.missing_values;
|
||||
eval = other.eval;
|
||||
reference_collection_name = other.reference_collection_name;
|
||||
nested_join_collection_names = other.nested_join_collection_names;
|
||||
return *this;
|
||||
}
|
||||
|
||||
[[nodiscard]] inline bool is_nested_join_sort_by() const {
|
||||
return !nested_join_collection_names.empty();
|
||||
}
|
||||
};
|
||||
|
||||
class GeoPoint {
|
||||
|
@ -1002,6 +1002,10 @@ public:
|
||||
bool enable_typos_for_numerical_tokens,
|
||||
bool enable_typos_for_alpha_numerical_tokens) const;
|
||||
|
||||
Option<bool> ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id,
|
||||
bool& reference_found, const std::map<basic_string<char>, reference_filter_result_t>& references,
|
||||
const std::string& collection_name) const;
|
||||
|
||||
Option<bool> compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
|
||||
|
@ -1234,8 +1234,24 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
is_group_by_query,
|
||||
remote_embedding_timeout_ms,
|
||||
remote_embedding_num_tries);
|
||||
|
||||
std::vector<std::string> nested_join_coll_names;
|
||||
for (auto const& coll_name: _sort_field.nested_join_collection_names) {
|
||||
auto coll = cm.get_collection(coll_name);
|
||||
if (coll == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + coll_name + "` in `sort_by` not found.");
|
||||
}
|
||||
// `CollectionManager::get_collection` accounts for collection alias being used and provides pointer to the
|
||||
// original collection.
|
||||
nested_join_coll_names.emplace_back(coll->name);
|
||||
}
|
||||
|
||||
for (auto& ref_sort_field_std: ref_sort_fields_std) {
|
||||
ref_sort_field_std.reference_collection_name = ref_collection_name;
|
||||
ref_sort_field_std.nested_join_collection_names.insert(ref_sort_field_std.nested_join_collection_names.begin(),
|
||||
nested_join_coll_names.begin(),
|
||||
nested_join_coll_names.end());
|
||||
|
||||
sort_fields_std.emplace_back(ref_sort_field_std);
|
||||
}
|
||||
|
||||
|
@ -894,9 +894,82 @@ bool parse_eval(const std::string& sort_by_str, uint32_t& index, std::vector<sor
|
||||
return true;
|
||||
}
|
||||
|
||||
bool parse_nested_join_sort_by_str(const std::string& sort_by_str, uint32_t& index, const std::string& parent_coll_name,
|
||||
std::vector<sort_by>& sort_fields) {
|
||||
if (sort_by_str[index] != '$') {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string sort_field_expr;
|
||||
char prev_non_space_char = '`';
|
||||
|
||||
auto open_paren_pos = sort_by_str.find('(', index);
|
||||
if (open_paren_pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto const& collection_name = sort_by_str.substr(index + 1, open_paren_pos - index - 1);
|
||||
index = open_paren_pos;
|
||||
int paren_count = 1;
|
||||
while (++index < sort_by_str.size() && paren_count > 0) {
|
||||
if (sort_by_str[index] == '(') {
|
||||
paren_count++;
|
||||
} else if (sort_by_str[index] == ')' && --paren_count == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (sort_by_str[index] == '$' && (prev_non_space_char == '`' || prev_non_space_char == ',')) {
|
||||
// Nested join sort_by
|
||||
|
||||
// Process the sort fields provided up until now.
|
||||
if (!sort_field_expr.empty()) {
|
||||
sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", "");
|
||||
auto& collection_names = sort_fields.back().nested_join_collection_names;
|
||||
collection_names.insert(collection_names.begin(), parent_coll_name);
|
||||
collection_names.emplace_back(collection_name);
|
||||
|
||||
sort_field_expr.clear();
|
||||
}
|
||||
|
||||
auto prev_size = sort_fields.size();
|
||||
if (!parse_nested_join_sort_by_str(sort_by_str, index, collection_name, sort_fields)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (; prev_size < sort_fields.size(); prev_size++) {
|
||||
auto& collection_names = sort_fields[prev_size].nested_join_collection_names;
|
||||
collection_names.insert(collection_names.begin(), parent_coll_name);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
sort_field_expr += sort_by_str[index];
|
||||
if (sort_by_str[index] != ' ') {
|
||||
prev_non_space_char = sort_by_str[index];
|
||||
}
|
||||
}
|
||||
if (paren_count != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!sort_field_expr.empty()) {
|
||||
sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", "");
|
||||
auto& collection_names = sort_fields.back().nested_join_collection_names;
|
||||
collection_names.insert(collection_names.begin(), parent_coll_name);
|
||||
collection_names.emplace_back(collection_name);
|
||||
}
|
||||
|
||||
// Skip the space in between the sort_by expressions.
|
||||
while (index + 1 < sort_by_str.size() && (sort_by_str[index + 1] == ' ' || sort_by_str[index + 1] == ',')) {
|
||||
index++;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector<sort_by>& sort_fields) {
|
||||
std::string sort_field_expr;
|
||||
char prev_non_space_char = 'a';
|
||||
char prev_non_space_char = '`';
|
||||
|
||||
for(uint32_t i=0; i < sort_by_str.size(); i++) {
|
||||
if (sort_field_expr.empty()) {
|
||||
@ -906,27 +979,50 @@ bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector<s
|
||||
if (open_paren_pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
sort_field_expr = sort_by_str.substr(i, open_paren_pos - i + 1);
|
||||
|
||||
auto const& collection_name = sort_by_str.substr(i + 1, open_paren_pos - i - 1);
|
||||
i = open_paren_pos;
|
||||
int paren_count = 1;
|
||||
while (++i < sort_by_str.size() && paren_count > 0) {
|
||||
if (sort_by_str[i] == '(') {
|
||||
paren_count++;
|
||||
} else if (sort_by_str[i] == ')') {
|
||||
paren_count--;
|
||||
} else if (sort_by_str[i] == ')' && --paren_count == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (sort_by_str[i] == '$' && (prev_non_space_char == '`' || prev_non_space_char == ',')) {
|
||||
// Nested join sort_by
|
||||
|
||||
// Process the sort fields provided up until now. Doing this step to maintain the order of sort_by
|
||||
// as specified. Eg, `$Customers(product_price:DESC, $foo(bar:asc))` should result into
|
||||
// {`$Customers(product_price:DESC)`, `$Customers($foo(bar:asc))`} and not the other way around.
|
||||
if (!sort_field_expr.empty()) {
|
||||
sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", "");
|
||||
sort_field_expr.clear();
|
||||
}
|
||||
|
||||
if (!parse_nested_join_sort_by_str(sort_by_str, i, collection_name, sort_fields)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
sort_field_expr += sort_by_str[i];
|
||||
if (sort_by_str[i] != ' ') {
|
||||
prev_non_space_char = sort_by_str[i];
|
||||
}
|
||||
}
|
||||
if (paren_count != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
sort_fields.emplace_back(sort_field_expr, "");
|
||||
sort_field_expr = "";
|
||||
if (!sort_field_expr.empty()) {
|
||||
sort_fields.emplace_back("$" + collection_name + "(" + sort_field_expr + ")", "");
|
||||
sort_field_expr.clear();
|
||||
}
|
||||
|
||||
// Skip the space in between the sort_by expressions.
|
||||
while (i + 1 < sort_by_str.size() && sort_by_str[i + 1] == ' ') {
|
||||
while (i + 1 < sort_by_str.size() && (sort_by_str[i + 1] == ' ' || sort_by_str[i + 1] == ',')) {
|
||||
i++;
|
||||
}
|
||||
continue;
|
||||
|
449
src/index.cpp
449
src/index.cpp
@ -4675,6 +4675,192 @@ Option<bool> Index::search_across_fields(const std::vector<token_t>& query_token
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Index::ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id,
|
||||
bool& reference_found, const std::map<basic_string<char>, reference_filter_result_t>& references,
|
||||
const std::string& collection_name) const {
|
||||
auto const& ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const& multiple_references_error_message = "Multiple references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
auto const& no_references_error_message = "No references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
|
||||
if (sort_field.is_nested_join_sort_by()) {
|
||||
// Get the reference doc_id by following through all the nested join collections.
|
||||
ref_seq_id = seq_id;
|
||||
std::string prev_coll_name = collection_name;
|
||||
for (const auto &coll_name: sort_field.nested_join_collection_names) {
|
||||
// Joined on ref collection
|
||||
if (references.count(coll_name) > 0) {
|
||||
auto const& count = references.at(coll_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
break;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(coll_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(coll_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + coll_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(prev_coll_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(prev_coll_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
|
||||
auto prev_coll = cm.get_collection(prev_coll_name);
|
||||
if (prev_coll == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + prev_coll_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto sort_index_op = prev_coll->get_sort_index_value_with_lock(field_name, ref_seq_id);
|
||||
if (!sort_index_op.ok()) {
|
||||
if (sort_index_op.code() == 400) {
|
||||
return Option<bool>(400, sort_index_op.error());
|
||||
}
|
||||
reference_found = false;
|
||||
break;
|
||||
} else {
|
||||
ref_seq_id = sort_index_op.get();
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
return Option<bool>(400, no_references_error_message);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
break;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prev_coll_name = coll_name;
|
||||
}
|
||||
} else if (references.count(ref_collection_name) > 0) { // Joined on ref collection
|
||||
auto const& count = references.at(ref_collection_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(ref_collection_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(ref_collection_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + ref_collection_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(collection_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
if (sort_index.count(field_name) == 0) {
|
||||
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
|
||||
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
|
||||
reference_found = false;
|
||||
} else {
|
||||
ref_seq_id = sort_index.at(field_name)->at(seq_id);
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
return Option<bool>(400, no_references_error_message);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices,
|
||||
@ -4744,89 +4930,10 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
|
||||
// In case of reference sort_by, we need to get the sort score of the reference doc id.
|
||||
if (!sort_fields[0].reference_collection_name.empty()) {
|
||||
auto& sort_field = sort_fields[0];
|
||||
auto const& ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const& multiple_references_error_message = "Multiple references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
auto const& no_references_error_message = "No references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
|
||||
// Joined on ref collection
|
||||
if (references.count(ref_collection_name) > 0) {
|
||||
auto const& count = references.at(ref_collection_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(ref_collection_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(ref_collection_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + ref_collection_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(collection_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
if (sort_index.count(field_name) == 0) {
|
||||
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
|
||||
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
|
||||
reference_found = false;
|
||||
} else {
|
||||
ref_seq_id = sort_index.at(field_name)->at(seq_id);
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
return Option<bool>(400, no_references_error_message);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[0], seq_id, ref_seq_id, reference_found,
|
||||
references, collection_name);
|
||||
if (!ref_compute_op.ok()) {
|
||||
return ref_compute_op;
|
||||
}
|
||||
}
|
||||
|
||||
@ -4932,89 +5039,10 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
|
||||
// In case of reference sort_by, we need to get the sort score of the reference doc id.
|
||||
if (!sort_fields[1].reference_collection_name.empty()) {
|
||||
auto& sort_field = sort_fields[1];
|
||||
auto const& ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const& multiple_references_error_message = "Multiple references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
auto const& no_references_error_message = "No references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
|
||||
// Joined on ref collection
|
||||
if (references.count(ref_collection_name) > 0) {
|
||||
auto const& count = references.at(ref_collection_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(ref_collection_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(ref_collection_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + ref_collection_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(collection_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
if (sort_index.count(field_name) == 0) {
|
||||
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
|
||||
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
|
||||
reference_found = false;
|
||||
} else {
|
||||
ref_seq_id = sort_index.at(field_name)->at(seq_id);
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
return Option<bool>(400, no_references_error_message);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[1], seq_id, ref_seq_id, reference_found,
|
||||
references, collection_name);
|
||||
if (!ref_compute_op.ok()) {
|
||||
return ref_compute_op;
|
||||
}
|
||||
}
|
||||
|
||||
@ -5118,89 +5146,10 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
|
||||
// In case of reference sort_by, we need to get the sort score of the reference doc id.
|
||||
if (!sort_fields[2].reference_collection_name.empty()) {
|
||||
auto& sort_field = sort_fields[2];
|
||||
auto const& ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const& multiple_references_error_message = "Multiple references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
auto const& no_references_error_message = "No references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
|
||||
// Joined on ref collection
|
||||
if (references.count(ref_collection_name) > 0) {
|
||||
auto const& count = references.at(ref_collection_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(ref_collection_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(ref_collection_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + ref_collection_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(collection_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
if (sort_index.count(field_name) == 0) {
|
||||
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
|
||||
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
|
||||
reference_found = false;
|
||||
} else {
|
||||
ref_seq_id = sort_index.at(field_name)->at(seq_id);
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
return Option<bool>(400, no_references_error_message);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[2], seq_id, ref_seq_id, reference_found,
|
||||
references, collection_name);
|
||||
if (!ref_compute_op.ok()) {
|
||||
return ref_compute_op;
|
||||
}
|
||||
}
|
||||
|
||||
@ -7657,7 +7606,7 @@ Option<uint32_t> Index::get_sort_index_value_with_lock(const std::string& collec
|
||||
return Option<uint32_t>(400, "Cannot sort on `" + field_name + "` in the collection, `" + collection_name +
|
||||
"` is `" + search_schema.at(field_name).type + "`.");
|
||||
} else if (sort_index.count(field_name) == 0 || sort_index.at(field_name)->count(seq_id) == 0) {
|
||||
return Option<uint32_t>(400, "Could not find `" + field_name + "` value for doc `" +
|
||||
return Option<uint32_t>(404, "Could not find `" + field_name + "` value for doc `" +
|
||||
std::to_string(seq_id) + "`.");;
|
||||
}
|
||||
|
||||
|
@ -5544,29 +5544,6 @@ TEST_F(CollectionJoinTest, SortByReference) {
|
||||
ASSERT_FALSE(search_op.ok());
|
||||
ASSERT_EQ("Multiple references found to sort by on `Customers.product_price`.", search_op.error());
|
||||
|
||||
schema_json =
|
||||
R"({
|
||||
"name": "Ads",
|
||||
"fields": [
|
||||
{"name": "id", "type": "string"}
|
||||
]
|
||||
})"_json;
|
||||
documents = {
|
||||
R"({
|
||||
"id": "ad_a"
|
||||
})"_json
|
||||
};
|
||||
collection_create_op = collectionManager.create_collection(schema_json);
|
||||
ASSERT_TRUE(collection_create_op.ok());
|
||||
|
||||
for (auto const &json: documents) {
|
||||
auto add_op = collection_create_op.get()->add(json.dump());
|
||||
if (!add_op.ok()) {
|
||||
LOG(INFO) << add_op.error();
|
||||
}
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
schema_json =
|
||||
R"({
|
||||
"name": "Structures",
|
||||
@ -5596,6 +5573,35 @@ TEST_F(CollectionJoinTest, SortByReference) {
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
schema_json =
|
||||
R"({
|
||||
"name": "Ads",
|
||||
"fields": [
|
||||
{"name": "id", "type": "string"},
|
||||
{"name": "structure", "type": "string", "reference": "Structures.id"}
|
||||
]
|
||||
})"_json;
|
||||
documents = {
|
||||
R"({
|
||||
"id": "ad_a",
|
||||
"structure": "struct_b"
|
||||
})"_json,
|
||||
R"({
|
||||
"id": "ad_b",
|
||||
"structure": "struct_a"
|
||||
})"_json
|
||||
};
|
||||
collection_create_op = collectionManager.create_collection(schema_json);
|
||||
ASSERT_TRUE(collection_create_op.ok());
|
||||
|
||||
for (auto const &json: documents) {
|
||||
auto add_op = collection_create_op.get()->add(json.dump());
|
||||
if (!add_op.ok()) {
|
||||
LOG(INFO) << add_op.error();
|
||||
}
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
schema_json =
|
||||
R"({
|
||||
"name": "Candidates",
|
||||
@ -5606,13 +5612,16 @@ TEST_F(CollectionJoinTest, SortByReference) {
|
||||
})"_json;
|
||||
documents = {
|
||||
R"({
|
||||
"structure": "struct_a"
|
||||
"structure": "struct_b"
|
||||
})"_json,
|
||||
R"({
|
||||
"ad": "ad_a"
|
||||
})"_json,
|
||||
R"({
|
||||
"structure": "struct_b"
|
||||
"structure": "struct_a"
|
||||
})"_json,
|
||||
R"({
|
||||
"ad": "ad_b"
|
||||
})"_json
|
||||
};
|
||||
collection_create_op = collectionManager.create_collection(schema_json);
|
||||
@ -5630,25 +5639,67 @@ TEST_F(CollectionJoinTest, SortByReference) {
|
||||
{"collection", "Candidates"},
|
||||
{"q", "*"},
|
||||
{"filter_by", "$Ads(id:*) || $Structures(id:*)"},
|
||||
{"sort_by", "$Structures(name: asc)"}
|
||||
{"sort_by", "$Structures(name: asc)"},
|
||||
{"include_fields", "$Ads($Structures(*))"}
|
||||
};
|
||||
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
|
||||
ASSERT_TRUE(search_op.ok());
|
||||
|
||||
res_obj = nlohmann::json::parse(json_res);
|
||||
ASSERT_EQ(3, res_obj["found"].get<size_t>());
|
||||
ASSERT_EQ(3, res_obj["hits"].size());
|
||||
ASSERT_EQ("2", res_obj["hits"][0]["document"].at("id"));
|
||||
ASSERT_EQ(4, res_obj["found"].get<size_t>());
|
||||
ASSERT_EQ(4, res_obj["hits"].size());
|
||||
ASSERT_EQ("0", res_obj["hits"][0]["document"].at("id"));
|
||||
ASSERT_EQ("bar", res_obj["hits"][0]["document"]["Structures"].at("name"));
|
||||
ASSERT_EQ(0, res_obj["hits"][0]["document"].count("Ads"));
|
||||
|
||||
ASSERT_EQ("0", res_obj["hits"][1]["document"].at("id"));
|
||||
ASSERT_EQ("2", res_obj["hits"][1]["document"].at("id"));
|
||||
ASSERT_EQ("foo", res_obj["hits"][1]["document"]["Structures"].at("name"));
|
||||
ASSERT_EQ(0, res_obj["hits"][1]["document"].count("Ads"));
|
||||
|
||||
ASSERT_EQ("1", res_obj["hits"][2]["document"].at("id"));
|
||||
ASSERT_EQ("3", res_obj["hits"][2]["document"].at("id"));
|
||||
ASSERT_EQ(0, res_obj["hits"][2]["document"].count("Structures"));
|
||||
ASSERT_EQ(1, res_obj["hits"][2]["document"].count("Ads"));
|
||||
ASSERT_EQ(1, res_obj["hits"][2]["document"]["Ads"].count("Structures"));
|
||||
ASSERT_EQ("foo", res_obj["hits"][2]["document"]["Ads"]["Structures"]["name"]);
|
||||
|
||||
ASSERT_EQ("1", res_obj["hits"][3]["document"].at("id"));
|
||||
ASSERT_EQ(0, res_obj["hits"][3]["document"].count("Structures"));
|
||||
ASSERT_EQ(1, res_obj["hits"][3]["document"].count("Ads"));
|
||||
ASSERT_EQ(1, res_obj["hits"][3]["document"]["Ads"].count("Structures"));
|
||||
ASSERT_EQ("bar", res_obj["hits"][3]["document"]["Ads"]["Structures"]["name"]);
|
||||
|
||||
req_params = {
|
||||
{"collection", "Candidates"},
|
||||
{"q", "*"},
|
||||
{"filter_by", "$Ads(id:*) || $Structures(id:*)"},
|
||||
{"sort_by", "$Ads($Structures(name: asc))"},
|
||||
{"include_fields", "$Ads($Structures(*))"}
|
||||
};
|
||||
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
|
||||
ASSERT_TRUE(search_op.ok());
|
||||
|
||||
res_obj = nlohmann::json::parse(json_res);
|
||||
ASSERT_EQ(4, res_obj["found"].get<size_t>());
|
||||
ASSERT_EQ(4, res_obj["hits"].size());
|
||||
ASSERT_EQ("1", res_obj["hits"][0]["document"].at("id"));
|
||||
ASSERT_EQ(0, res_obj["hits"][0]["document"].count("Structures"));
|
||||
ASSERT_EQ(1, res_obj["hits"][0]["document"].count("Ads"));
|
||||
ASSERT_EQ(1, res_obj["hits"][0]["document"]["Ads"].count("Structures"));
|
||||
ASSERT_EQ("bar", res_obj["hits"][0]["document"]["Ads"]["Structures"]["name"]);
|
||||
|
||||
ASSERT_EQ("3", res_obj["hits"][1]["document"].at("id"));
|
||||
ASSERT_EQ(0, res_obj["hits"][1]["document"].count("Structures"));
|
||||
ASSERT_EQ(1, res_obj["hits"][1]["document"].count("Ads"));
|
||||
ASSERT_EQ(1, res_obj["hits"][1]["document"]["Ads"].count("Structures"));
|
||||
ASSERT_EQ("foo", res_obj["hits"][1]["document"]["Ads"]["Structures"]["name"]);
|
||||
|
||||
ASSERT_EQ("2", res_obj["hits"][2]["document"].at("id"));
|
||||
ASSERT_EQ("foo", res_obj["hits"][2]["document"]["Structures"].at("name"));
|
||||
ASSERT_EQ(0, res_obj["hits"][2]["document"].count("Ads"));
|
||||
|
||||
ASSERT_EQ("0", res_obj["hits"][3]["document"].at("id"));
|
||||
ASSERT_EQ("bar", res_obj["hits"][3]["document"]["Structures"].at("name"));
|
||||
ASSERT_EQ(0, res_obj["hits"][3]["document"].count("Ads"));
|
||||
}
|
||||
|
||||
TEST_F(CollectionJoinTest, FilterByReferenceAlias) {
|
||||
|
@ -1309,6 +1309,45 @@ TEST_F(CollectionManagerTest, ParseSortByClause) {
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
ASSERT_EQ("$foo( _eval(brand:nike && foo:bar):DESC,points:desc)", sort_fields[0].name);
|
||||
|
||||
sort_fields.clear();
|
||||
sort_by_parsed = CollectionManager::parse_sort_by_str("$Customers(product_price:DESC, $foo(bar:asc))", sort_fields);
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
ASSERT_EQ(2, sort_fields.size());
|
||||
ASSERT_EQ("$Customers(product_price:DESC, )", sort_fields[0].name);
|
||||
ASSERT_EQ("$foo(bar:asc)", sort_fields[1].name);
|
||||
ASSERT_TRUE(sort_fields[1].is_nested_join_sort_by());
|
||||
ASSERT_EQ(2, sort_fields[1].nested_join_collection_names.size());
|
||||
ASSERT_EQ("Customers", sort_fields[1].nested_join_collection_names[0]);
|
||||
ASSERT_EQ("foo", sort_fields[1].nested_join_collection_names[1]);
|
||||
|
||||
sort_fields.clear();
|
||||
sort_by_parsed = CollectionManager::parse_sort_by_str("$foo($bar($baz(field:asc)))", sort_fields);
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
ASSERT_EQ(1, sort_fields.size());
|
||||
ASSERT_EQ("$baz(field:asc)", sort_fields[0].name);
|
||||
ASSERT_TRUE(sort_fields[0].is_nested_join_sort_by());
|
||||
ASSERT_EQ(3, sort_fields[0].nested_join_collection_names.size());
|
||||
ASSERT_EQ("foo", sort_fields[0].nested_join_collection_names[0]);
|
||||
ASSERT_EQ("bar", sort_fields[0].nested_join_collection_names[1]);
|
||||
ASSERT_EQ("baz", sort_fields[0].nested_join_collection_names[2]);
|
||||
|
||||
sort_fields.clear();
|
||||
sort_by_parsed = CollectionManager::parse_sort_by_str("$Customers(product_price:DESC, $foo($bar( _eval(brand:nike && foo:bar):DESC), baz:asc))", sort_fields);
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
ASSERT_EQ(3, sort_fields.size());
|
||||
ASSERT_EQ("$Customers(product_price:DESC, )", sort_fields[0].name);
|
||||
ASSERT_EQ("$bar( _eval(brand:nike && foo:bar):DESC)", sort_fields[1].name);
|
||||
ASSERT_TRUE(sort_fields[1].is_nested_join_sort_by());
|
||||
ASSERT_EQ(3, sort_fields[1].nested_join_collection_names.size());
|
||||
ASSERT_EQ("Customers", sort_fields[1].nested_join_collection_names[0]);
|
||||
ASSERT_EQ("foo", sort_fields[1].nested_join_collection_names[1]);
|
||||
ASSERT_EQ("bar", sort_fields[1].nested_join_collection_names[2]);
|
||||
ASSERT_EQ("$foo(baz:asc)", sort_fields[2].name);
|
||||
ASSERT_TRUE(sort_fields[2].is_nested_join_sort_by());
|
||||
ASSERT_EQ(2, sort_fields[2].nested_join_collection_names.size());
|
||||
ASSERT_EQ("Customers", sort_fields[2].nested_join_collection_names[0]);
|
||||
ASSERT_EQ("foo", sort_fields[2].nested_join_collection_names[1]);
|
||||
|
||||
sort_fields.clear();
|
||||
sort_by_parsed = CollectionManager::parse_sort_by_str("", sort_fields);
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
|
Loading…
x
Reference in New Issue
Block a user