mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 21:52:23 +08:00
Refactor Collection::validate_and_standardize_sort_fields_with_lock
.
This commit is contained in:
parent
e6876fa147
commit
31938490ae
@ -215,18 +215,21 @@ private:
|
||||
|
||||
static std::vector<char> to_char_array(const std::vector<std::string>& strs);
|
||||
|
||||
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
bool is_wildcard_query, const bool is_vector_query,
|
||||
const std::string& query, bool is_group_by_query = false,
|
||||
const size_t remote_embedding_timeout_ms = 30000,
|
||||
const size_t remote_embedding_num_tries = 2) const;
|
||||
|
||||
Option<bool> validate_and_standardize_sort_fields_with_lock(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
bool is_wildcard_query,const bool is_vector_query,
|
||||
bool is_group_by_query = false) const;
|
||||
const std::string& query, bool is_group_by_query = false,
|
||||
const size_t remote_embedding_timeout_ms = 30000,
|
||||
const size_t remote_embedding_num_tries = 2) const;
|
||||
|
||||
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
const bool is_wildcard_query,
|
||||
const bool is_vector_query,
|
||||
const std::string& query, bool is_group_by_query = false,
|
||||
const size_t remote_embedding_timeout_ms = 30000,
|
||||
const size_t remote_embedding_num_tries = 2,
|
||||
const bool is_reference_sort = false) const;
|
||||
|
||||
Option<bool> persist_collection_meta();
|
||||
|
||||
|
@ -842,13 +842,26 @@ void Collection::curate_results(string& actual_query, const string& filter_query
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Locking entry point for sort-field validation.
 *
 * Acquires a shared lock on the collection's `mutex` and then delegates to
 * validate_and_standardize_sort_fields() with `is_reference_sort` set to true.
 *
 * @param sort_fields                  Sort clauses as supplied by the caller.
 * @param sort_fields_std              Output: receives the standardized sort clauses.
 * @param is_wildcard_query            Forwarded unchanged to the delegate.
 * @param is_vector_query              Forwarded unchanged to the delegate.
 * @param query                        Forwarded unchanged to the delegate.
 * @param is_group_by_query            Forwarded unchanged to the delegate.
 * @param remote_embedding_timeout_ms  Forwarded unchanged to the delegate.
 * @param remote_embedding_num_tries   Forwarded unchanged to the delegate.
 * @return The Option<bool> produced by validate_and_standardize_sort_fields().
 */
Option<bool> Collection::validate_and_standardize_sort_fields_with_lock(const std::vector<sort_by> & sort_fields,
                                                                        std::vector<sort_by>& sort_fields_std,
                                                                        const bool is_wildcard_query,
                                                                        const bool is_vector_query,
                                                                        const std::string& query, const bool is_group_by_query,
                                                                        const size_t remote_embedding_timeout_ms,
                                                                        const size_t remote_embedding_num_tries) const {
    std::shared_lock lock(mutex);

    // Every positional argument must be forwarded: `is_reference_sort` is the NINTH parameter of the
    // delegate. Omitting `remote_embedding_num_tries` would make `true` bind to the retry count
    // (as 1) and leave `is_reference_sort` at its default of false.
    return validate_and_standardize_sort_fields(sort_fields, sort_fields_std, is_wildcard_query, is_vector_query,
                                                query, is_group_by_query, remote_embedding_timeout_ms,
                                                remote_embedding_num_tries, true);
}
|
||||
|
||||
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
const bool is_wildcard_query,
|
||||
const bool is_vector_query,
|
||||
const std::string& query, const bool is_group_by_query,
|
||||
const size_t remote_embedding_timeout_ms,
|
||||
const size_t remote_embedding_num_tries) const {
|
||||
const size_t remote_embedding_num_tries,
|
||||
const bool is_reference_sort) const {
|
||||
|
||||
uint32_t eval_sort_count = 0;
|
||||
size_t num_sort_expressions = 0;
|
||||
@ -879,7 +892,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
ref_sort_fields_std,
|
||||
is_wildcard_query,
|
||||
is_vector_query,
|
||||
is_group_by_query);
|
||||
query,
|
||||
is_group_by_query,
|
||||
remote_embedding_timeout_ms,
|
||||
remote_embedding_num_tries);
|
||||
if (!sort_validation_op.ok()) {
|
||||
return Option<bool>(sort_validation_op.code(), "Referenced collection `" + ref_collection_name + "`: " +
|
||||
sort_validation_op.error());
|
||||
@ -1192,6 +1208,14 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
sort_fields_std.emplace_back(sort_field_std);
|
||||
}
|
||||
|
||||
if (is_reference_sort) {
|
||||
if (eval_sort_count > 1) {
|
||||
std::string message = "Only one sorting eval expression is allowed.";
|
||||
return Option<bool>(422, message);
|
||||
}
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
/*
|
||||
1. Empty: [match_score, dsf] upstream
|
||||
2. ONE : [usf, match_score]
|
||||
@ -1249,227 +1273,6 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
/**
 * Five-argument overload: validates each requested sort clause under a shared lock on the
 * collection's `mutex` and appends a standardized copy of it to `sort_fields_std`.
 *
 * Handled clause shapes:
 *   - `_text_match(buckets:N)`             -> sets `text_match_buckets`
 *   - `field(missing_values:<enum>)`       -> sets `missing_values` (non-geopoint fields)
 *   - `geo_field(lat, lng[, option:value])` -> sets `geopoint` and optionally `exclude_radius`
 *                                             or `geo_precision` (value suffixed with `km` / `mi`)
 *   - the eval clause                      -> parses every eval expression into a filter tree
 *   - plain field names and the special names checked below
 *
 * @param sort_fields        Sort clauses to validate.
 * @param sort_fields_std    Output: standardized clauses are appended in input order.
 * @param is_wildcard_query  Not read by this overload.
 * @param is_vector_query    Must be true for a `vector_distance` sort to be accepted.
 * @param is_group_by_query  Must be true for a `group_found` sort to be accepted.
 * @return Option<bool>(true) on success, otherwise an Option carrying a 400/404 (or the filter
 *         parser's) error code and a message. Clauses appended before the failure remain in
 *         `sort_fields_std`.
 */
Option<bool> Collection::validate_and_standardize_sort_fields_with_lock(const std::vector<sort_by> & sort_fields,
                                                                        std::vector<sort_by>& sort_fields_std,
                                                                        const bool is_wildcard_query,
                                                                        const bool is_vector_query,
                                                                        const bool is_group_by_query) const {
    std::shared_lock lock(mutex);

    for(const auto& sort_field : sort_fields) {
        sort_by sort_field_std(sort_field.name, sort_field.order);
        const std::string& raw_name = sort_field.name;

        // back() on an empty string is undefined behaviour, so check emptiness first. An empty
        // name then falls through to the schema lookup further below.
        if(!raw_name.empty() && raw_name.back() == ')') {
            // Parameterised clause: split `name(params)` at the first '('.
            const size_t paran_start = std::min(raw_name.find('('), raw_name.size());
            const bool has_open_paren = paran_start < raw_name.size();

            const std::string actual_field_name = raw_name.substr(0, paran_start);
            const auto field_it = search_schema.find(actual_field_name);

            if(actual_field_name == sort_field_const::text_match) {
                // Reaching here implies '(' was found: without it the name would still end in ')'.
                const std::string match_config = raw_name.substr(paran_start + 1,
                                                                 raw_name.size() - paran_start - 2);
                std::vector<std::string> match_parts;
                StringUtils::split(match_config, match_parts, ":");

                if(match_parts.size() != 2 || match_parts[0] != "buckets") {
                    return Option<bool>(400, "Invalid sorting parameter passed for _text_match.");
                }

                if(!StringUtils::is_uint32_t(match_parts[1])) {
                    return Option<bool>(400, "Invalid value passed for _text_match `buckets` configuration.");
                }

                sort_field_std.name = actual_field_name;
                sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
            } else {
                if(field_it == search_schema.end()) {
                    std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
                    return Option<bool>(404, error);
                }

                const std::string error = "Bad syntax for sorting field `" + actual_field_name + "`";

                // A trailing ')' without any '(' would make the substr() below throw
                // std::out_of_range; report it as a syntax error instead.
                if(!has_open_paren) {
                    return Option<bool>(400, error);
                }

                const std::string sort_params_str = raw_name.substr(paran_start + 1,
                                                                    raw_name.size() - paran_start - 2);

                if(field_it.value().is_geopoint()) {
                    // e.g. geopoint_field(lat1, lng1, exclude_radius: 10 miles)
                    std::vector<std::string> geo_parts;
                    StringUtils::split(sort_params_str, geo_parts, ",");

                    if(geo_parts.size() != 2 && geo_parts.size() != 3) {
                        return Option<bool>(400, error);
                    }

                    if(!StringUtils::is_float(geo_parts[0]) || !StringUtils::is_float(geo_parts[1])) {
                        return Option<bool>(400, error);
                    }

                    if(geo_parts.size() == 3) {
                        // Third part is either the exclude-radius or the precision option.
                        bool is_exclude_option = false;

                        if(StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) {
                            is_exclude_option = true;
                        } else if(!StringUtils::begins_with(geo_parts[2], sort_field_const::precision)) {
                            return Option<bool>(400, error);
                        }

                        std::vector<std::string> param_parts;
                        StringUtils::split(geo_parts[2], param_parts, ":");

                        if(param_parts.size() != 2) {
                            return Option<bool>(400, error);
                        }

                        // param_parts[1] is the value, in either "20km" or "20 km" format
                        const std::string& distance_str = param_parts[1];

                        if(distance_str.size() < 2) {
                            return Option<bool>(400, error);
                        }

                        const std::string unit = distance_str.substr(distance_str.size() - 2, 2);

                        if(unit != "km" && unit != "mi") {
                            return Option<bool>(400, "Sort field's parameter unit must be either `km` or `mi`.");
                        }

                        std::vector<std::string> dist_values;
                        StringUtils::split(distance_str, dist_values, unit);

                        if(dist_values.size() != 1) {
                            return Option<bool>(400, error);
                        }

                        if(!StringUtils::is_float(dist_values[0])) {
                            return Option<bool>(400, error);
                        }

                        // `unit` is known to be "km" or "mi" here; convert the distance to metres.
                        int32_t value_meters;

                        if(unit == "km") {
                            value_meters = std::stof(dist_values[0]) * 1000;
                        } else {
                            value_meters = std::stof(dist_values[0]) * 1609.34;
                        }

                        if(value_meters <= 0) {
                            return Option<bool>(400, "Sort field's parameter must be a positive number.");
                        }

                        if(is_exclude_option) {
                            sort_field_std.exclude_radius = value_meters;
                        } else {
                            sort_field_std.geo_precision = value_meters;
                        }
                    }

                    const double lat = std::stod(geo_parts[0]);
                    const double lng = std::stod(geo_parts[1]);
                    sort_field_std.geopoint = GeoPoint::pack_lat_lng(lat, lng);
                } else {
                    // Non-geopoint field: only `missing_values:<enum>` is accepted.
                    std::vector<std::string> param_parts;
                    StringUtils::split(sort_params_str, param_parts, ":");

                    if(param_parts.size() != 2) {
                        return Option<bool>(400, error);
                    }

                    if(param_parts[0] != sort_field_const::missing_values) {
                        return Option<bool>(400, error);
                    }

                    auto missing_values_op = magic_enum::enum_cast<sort_by::missing_values_t>(param_parts[1]);
                    if(!missing_values_op.has_value()) {
                        return Option<bool>(400, error);
                    }

                    sort_field_std.missing_values = missing_values_op.value();
                }

                sort_field_std.name = actual_field_name;
            }
        } else if(sort_field.name == sort_field_const::eval) {
            const size_t count = sort_field.eval_expressions.size();
            sort_field_std.eval.filter_trees = new filter_node_t[count];

            // Frees the array on any early return; released once the clause has been appended.
            std::unique_ptr<filter_node_t []> filter_trees_guard(sort_field_std.eval.filter_trees);

            for(uint32_t j = 0; j < count; j++) {
                const auto& filter_exp = sort_field.eval_expressions[j];
                if(filter_exp.empty()) {
                    return Option<bool>(400, "The eval expression in sort_by is empty.");
                }

                filter_node_t* filter_tree_root = nullptr;
                Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
                                                                          store, "", filter_tree_root);
                std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);

                if(!parse_filter_op.ok()) {
                    return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
                }

                sort_field_std.eval.filter_trees[j] = std::move(*filter_tree_root);
            }

            sort_field_std.name = sort_field.name;
            sort_field_std.eval_expressions = sort_field.eval_expressions;
            sort_field_std.eval.scores = sort_field.eval.scores;
            sort_fields_std.emplace_back(sort_field_std);
            filter_trees_guard.release();
            continue;
        }

        // Every name other than the special ones must be a sortable, indexed schema field.
        const bool is_special_name = sort_field_std.name == sort_field_const::text_match ||
                                     sort_field_std.name == sort_field_const::eval ||
                                     sort_field_std.name == sort_field_const::seq_id ||
                                     sort_field_std.name == sort_field_const::group_found ||
                                     sort_field_std.name == sort_field_const::vector_distance;

        if(!is_special_name) {
            const auto field_it = search_schema.find(sort_field_std.name);
            if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
                std::string error = "Could not find a field named `" + sort_field_std.name +
                                    "` in the schema for sorting.";
                return Option<bool>(404, error);
            }
        }

        if(sort_field_std.name == sort_field_const::group_found && !is_group_by_query) {
            std::string error = "group_by parameters should not be empty when using sort_by group_found";
            return Option<bool>(404, error);
        }

        if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) {
            std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.";
            return Option<bool>(404, error);
        }

        // The order is upper-cased before being compared against the asc / desc constants.
        StringUtils::toupper(sort_field_std.order);

        if(sort_field_std.order != sort_field_const::asc && sort_field_std.order != sort_field_const::desc) {
            std::string error = "Order for field` " + sort_field_std.name + "` should be either ASC or DESC.";
            return Option<bool>(400, error);
        }

        sort_fields_std.emplace_back(sort_field_std);
    }

    return Option<bool>(true);
}
|
||||
|
||||
Option<bool> Collection::extract_field_name(const std::string& field_name,
|
||||
const tsl::htrie_map<char, field>& search_schema,
|
||||
std::vector<std::string>& processed_search_fields,
|
||||
|
Loading…
x
Reference in New Issue
Block a user