diff --git a/include/collection.h b/include/collection.h index 051af42c..8501c3d1 100644 --- a/include/collection.h +++ b/include/collection.h @@ -701,9 +701,9 @@ public: Option truncate_after_top_k(const std::string& field_name, size_t k); - void reference_populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, - std::vector& sort_fields_std, - std::array*, 3>& field_values) const; + Option reference_populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, + std::vector& sort_fields_std, + std::array*, 3>& field_values) const; int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const; diff --git a/include/index.h b/include/index.h index f738b521..c18a0127 100644 --- a/include/index.h +++ b/include/index.h @@ -796,13 +796,13 @@ public: const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, uint32_t*& filter_ids, uint32_t& filter_ids_length, const std::vector& curated_ids_sorted) const; - void populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, - std::vector& sort_fields_std, - std::array*, 3>& field_values) const; + Option populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, + std::vector& sort_fields_std, + std::array*, 3>& field_values) const; - void populate_sort_mapping_with_lock(int* sort_order, std::vector& geopoint_indices, - std::vector& sort_fields_std, - std::array*, 3>& field_values) const; + Option populate_sort_mapping_with_lock(int* sort_order, std::vector& geopoint_indices, + std::vector& sort_fields_std, + std::array*, 3>& field_values) const; int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const; diff --git a/src/collection.cpp b/src/collection.cpp index 7ea00e61..a988cf3c 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -6828,12 +6828,12 @@ Option Collection::truncate_after_top_k(const string &field_name, size_t k return Option(true); } -void Collection::reference_populate_sort_mapping(int *sort_order, std::vector &geopoint_indices, - std::vector &sort_fields_std, - std::array *, 3> &field_values) - const { +Option Collection::reference_populate_sort_mapping(int *sort_order, std::vector &geopoint_indices, + std::vector &sort_fields_std, + std::array *, 3> &field_values) + const { std::shared_lock lock(mutex); - index->populate_sort_mapping_with_lock(sort_order, geopoint_indices, sort_fields_std, field_values); + return index->populate_sort_mapping_with_lock(sort_order, geopoint_indices, sort_fields_std, field_values); } int64_t Collection::reference_string_sort_score(const string &field_name, const uint32_t& seq_id) const { diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 798a1f2d..79d1c8ae 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -1036,6 +1036,11 @@ bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vectorname; auto coll = cm.get_collection(collection_name); + if (coll == nullptr) { + status = Option(400, "Collection `" + collection_name + "` not found."); + validity = invalid; + return; + } + bool is_referenced = coll->referenced_in.count(ref_collection_name) > 0, has_reference = ref_collection->is_referenced_in(collection_name); if (!is_referenced && !has_reference) { diff --git a/src/index.cpp b/src/index.cpp index 6e240e9d..7af87362 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2807,7 +2807,10 @@ Option Index::search(std::vector& field_query_tokens, cons int sort_order[3]; // 1 or -1 based on DESC or ASC respectively std::array*, 3> field_values; std::vector geopoint_indices; - populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values); + auto populate_op = populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values); + if (!populate_op.ok()) { + return populate_op; + } // Prepare excluded document IDs that we can later remove from the result set uint32_t* excluded_result_ids = nullptr; @@ -5006,6 +5009,10 @@ Option Index::compute_sort_scores(const std::vector& sort_fields, bool found = false; uint32_t index = 0; auto const& eval = sort_fields[0].eval; + if (eval.eval_ids_vec.size() != count || eval.eval_ids_count_vec.size() != count) { + return Option(400, "Eval expressions count does not match the ids count."); + } + for (; index < count; index++) { // ref_seq_id(s) can be unordered. uint32_t ref_filter_index = 0; @@ -5140,7 +5147,6 @@ Option Index::compute_sort_scores(const std::vector& sort_fields, } scores[1] = found ? eval.scores[index] : 0; - LOG(INFO) << "seq_id: " << seq_id << " ref_seq_id: " << ref_seq_id << " score: " << scores[1] << " index: " << index; } else if(field_values[1] == &vector_distance_sentinel_value) { scores[1] = float_to_int64_t(vector_distance); } else if(field_values[1] == &vector_query_sentinel_value) { @@ -6124,9 +6130,9 @@ Option Index::search_wildcard(filter_node_t const* const& filter_tree_root return Option(true); } -void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, - std::vector& sort_fields_std, - std::array*, 3>& field_values) const { +Option Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, + std::vector& sort_fields_std, + std::array*, 3>& field_values) const { for (size_t i = 0; i < sort_fields_std.size(); i++) { if (!sort_fields_std[i].reference_collection_name.empty()) { auto& cm = CollectionManager::get_instance(); @@ -6138,8 +6144,11 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint ref_sort_fields_std.emplace_back(sort_fields_std[i]); ref_sort_fields_std.front().reference_collection_name.clear(); std::array*, 3> ref_field_values; - ref_collection->reference_populate_sort_mapping(ref_sort_order, ref_geopoint_indices, - ref_sort_fields_std, ref_field_values); + auto populate_op = ref_collection->reference_populate_sort_mapping(ref_sort_order, ref_geopoint_indices, + ref_sort_fields_std, ref_field_values); + if (!populate_op.ok()) { + return populate_op; + } sort_order[i] = ref_sort_order[0]; if (!ref_geopoint_indices.empty()) { @@ -6171,7 +6180,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint search_begin_us, search_stop_us); auto filter_init_op = filter_result_iterator.init_status(); if (!filter_init_op.ok()) { - return; + return filter_init_op; } filter_result_iterator.compute_iterators(); @@ -6200,13 +6209,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint } } } + + return Option(true); } -void Index::populate_sort_mapping_with_lock(int* sort_order, std::vector& geopoint_indices, - std::vector& sort_fields_std, - std::array*, 3>& field_values) const { +Option Index::populate_sort_mapping_with_lock(int* sort_order, std::vector& geopoint_indices, + std::vector& sort_fields_std, + std::array*, 3>& field_values) const { std::shared_lock lock(mutex); - populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values); + return populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values); } int Index::get_bounded_typo_cost(const size_t max_cost, const std::string& token, const size_t token_len, diff --git a/test/collection_join_test.cpp b/test/collection_join_test.cpp index 6e797cdf..ec72c958 100644 --- a/test/collection_join_test.cpp +++ b/test/collection_join_test.cpp @@ -5904,6 +5904,17 @@ TEST_F(CollectionJoinTest, SortByReference) { ASSERT_EQ("2", res_obj["hits"][3]["document"].at("id")); ASSERT_EQ("6", res_obj["hits"][4]["document"].at("id")); ASSERT_EQ("1", res_obj["hits"][5]["document"].at("id")); + + req_params = { + {"collection", "product"}, + {"q", "tablet"}, + {"query_by", "name"}, + {"filter_by", "$stock(id: *)"}, + {"sort_by", "_eval($stock(store_1:true || store_2:true)):desc"} + }; + search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts); + ASSERT_FALSE(search_op.ok()); + ASSERT_EQ("Parameter `sort_by` is malformed.", search_op.error()); } TEST_F(CollectionJoinTest, FilterByReferenceAlias) {