mirror of
https://github.com/typesense/typesense.git
synced 2025-05-18 20:52:50 +08:00
Fix crash when reference expression is passed in _eval
. (#1863)
This commit is contained in:
parent
05b0faa955
commit
65151cd51b
@ -701,9 +701,9 @@ public:
|
||||
|
||||
Option<bool> truncate_after_top_k(const std::string& field_name, size_t k);
|
||||
|
||||
void reference_populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
|
||||
Option<bool> reference_populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
|
||||
|
||||
int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const;
|
||||
|
||||
|
@ -796,13 +796,13 @@ public:
|
||||
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length, const std::vector<uint32_t>& curated_ids_sorted) const;
|
||||
|
||||
void populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
|
||||
Option<bool> populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
|
||||
|
||||
void populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
|
||||
Option<bool> populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
|
||||
|
||||
int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const;
|
||||
|
||||
|
@ -6828,12 +6828,12 @@ Option<bool> Collection::truncate_after_top_k(const string &field_name, size_t k
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
void Collection::reference_populate_sort_mapping(int *sort_order, std::vector<size_t> &geopoint_indices,
|
||||
std::vector<sort_by> &sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32> *, 3> &field_values)
|
||||
const {
|
||||
Option<bool> Collection::reference_populate_sort_mapping(int *sort_order, std::vector<size_t> &geopoint_indices,
|
||||
std::vector<sort_by> &sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32> *, 3> &field_values)
|
||||
const {
|
||||
std::shared_lock lock(mutex);
|
||||
index->populate_sort_mapping_with_lock(sort_order, geopoint_indices, sort_fields_std, field_values);
|
||||
return index->populate_sort_mapping_with_lock(sort_order, geopoint_indices, sort_fields_std, field_values);
|
||||
}
|
||||
|
||||
int64_t Collection::reference_string_sort_score(const string &field_name, const uint32_t& seq_id) const {
|
||||
|
@ -1036,6 +1036,11 @@ bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector<s
|
||||
i = open_paren_pos;
|
||||
while(sort_by_str[++i] == ' ');
|
||||
|
||||
if (sort_by_str[i] == '$' && sort_by_str.find('(', i) != std::string::npos) {
|
||||
// Reference expression inside `_eval()`
|
||||
return false;
|
||||
}
|
||||
|
||||
auto result = sort_by_str[i] == '[' ? parse_multi_eval(sort_by_str, i, sort_fields) :
|
||||
parse_eval(sort_by_str, --i, sort_fields);
|
||||
if (!result) {
|
||||
|
@ -793,6 +793,12 @@ void filter_result_iterator_t::init() {
|
||||
ref_collection_name = ref_collection->name;
|
||||
|
||||
auto coll = cm.get_collection(collection_name);
|
||||
if (coll == nullptr) {
|
||||
status = Option<bool>(400, "Collection `" + collection_name + "` not found.");
|
||||
validity = invalid;
|
||||
return;
|
||||
}
|
||||
|
||||
bool is_referenced = coll->referenced_in.count(ref_collection_name) > 0,
|
||||
has_reference = ref_collection->is_referenced_in(collection_name);
|
||||
if (!is_referenced && !has_reference) {
|
||||
|
@ -2807,7 +2807,10 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
int sort_order[3]; // 1 or -1 based on DESC or ASC respectively
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values;
|
||||
std::vector<size_t> geopoint_indices;
|
||||
populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
|
||||
auto populate_op = populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
|
||||
if (!populate_op.ok()) {
|
||||
return populate_op;
|
||||
}
|
||||
|
||||
// Prepare excluded document IDs that we can later remove from the result set
|
||||
uint32_t* excluded_result_ids = nullptr;
|
||||
@ -5006,6 +5009,10 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
bool found = false;
|
||||
uint32_t index = 0;
|
||||
auto const& eval = sort_fields[0].eval;
|
||||
if (eval.eval_ids_vec.size() != count || eval.eval_ids_count_vec.size() != count) {
|
||||
return Option<bool>(400, "Eval expressions count does not match the ids count.");
|
||||
}
|
||||
|
||||
for (; index < count; index++) {
|
||||
// ref_seq_id(s) can be unordered.
|
||||
uint32_t ref_filter_index = 0;
|
||||
@ -5140,7 +5147,6 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
}
|
||||
|
||||
scores[1] = found ? eval.scores[index] : 0;
|
||||
LOG(INFO) << "seq_id: " << seq_id << " ref_seq_id: " << ref_seq_id << " score: " << scores[1] << " index: " << index;
|
||||
} else if(field_values[1] == &vector_distance_sentinel_value) {
|
||||
scores[1] = float_to_int64_t(vector_distance);
|
||||
} else if(field_values[1] == &vector_query_sentinel_value) {
|
||||
@ -6124,9 +6130,9 @@ Option<bool> Index::search_wildcard(filter_node_t const* const& filter_tree_root
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
|
||||
Option<bool> Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
|
||||
for (size_t i = 0; i < sort_fields_std.size(); i++) {
|
||||
if (!sort_fields_std[i].reference_collection_name.empty()) {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
@ -6138,8 +6144,11 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
|
||||
ref_sort_fields_std.emplace_back(sort_fields_std[i]);
|
||||
ref_sort_fields_std.front().reference_collection_name.clear();
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> ref_field_values;
|
||||
ref_collection->reference_populate_sort_mapping(ref_sort_order, ref_geopoint_indices,
|
||||
ref_sort_fields_std, ref_field_values);
|
||||
auto populate_op = ref_collection->reference_populate_sort_mapping(ref_sort_order, ref_geopoint_indices,
|
||||
ref_sort_fields_std, ref_field_values);
|
||||
if (!populate_op.ok()) {
|
||||
return populate_op;
|
||||
}
|
||||
|
||||
sort_order[i] = ref_sort_order[0];
|
||||
if (!ref_geopoint_indices.empty()) {
|
||||
@ -6171,7 +6180,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
|
||||
search_begin_us, search_stop_us);
|
||||
auto filter_init_op = filter_result_iterator.init_status();
|
||||
if (!filter_init_op.ok()) {
|
||||
return;
|
||||
return filter_init_op;
|
||||
}
|
||||
|
||||
filter_result_iterator.compute_iterators();
|
||||
@ -6200,13 +6209,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
void Index::populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
|
||||
Option<bool> Index::populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
|
||||
std::shared_lock lock(mutex);
|
||||
populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
|
||||
return populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
|
||||
}
|
||||
|
||||
int Index::get_bounded_typo_cost(const size_t max_cost, const std::string& token, const size_t token_len,
|
||||
|
@ -5904,6 +5904,17 @@ TEST_F(CollectionJoinTest, SortByReference) {
|
||||
ASSERT_EQ("2", res_obj["hits"][3]["document"].at("id"));
|
||||
ASSERT_EQ("6", res_obj["hits"][4]["document"].at("id"));
|
||||
ASSERT_EQ("1", res_obj["hits"][5]["document"].at("id"));
|
||||
|
||||
req_params = {
|
||||
{"collection", "product"},
|
||||
{"q", "tablet"},
|
||||
{"query_by", "name"},
|
||||
{"filter_by", "$stock(id: *)"},
|
||||
{"sort_by", "_eval($stock(store_1:true || store_2:true)):desc"}
|
||||
};
|
||||
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
|
||||
ASSERT_FALSE(search_op.ok());
|
||||
ASSERT_EQ("Parameter `sort_by` is malformed.", search_op.error());
|
||||
}
|
||||
|
||||
TEST_F(CollectionJoinTest, FilterByReferenceAlias) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user