Refactor Collection::validate_and_standardize_sort_fields_with_lock.

This commit is contained in:
Harpreet Sangar 2023-10-24 11:21:57 +05:30
parent e6876fa147
commit 31938490ae
2 changed files with 37 additions and 231 deletions

View File

@ -215,18 +215,21 @@ private:
static std::vector<char> to_char_array(const std::vector<std::string>& strs);
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
bool is_wildcard_query, const bool is_vector_query,
const std::string& query, bool is_group_by_query = false,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_tries = 2) const;
Option<bool> validate_and_standardize_sort_fields_with_lock(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
bool is_wildcard_query,const bool is_vector_query,
bool is_group_by_query = false) const;
const std::string& query, bool is_group_by_query = false,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_tries = 2) const;
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query,
const bool is_vector_query,
const std::string& query, bool is_group_by_query = false,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_tries = 2,
const bool is_reference_sort = false) const;
Option<bool> persist_collection_meta();

View File

@ -842,13 +842,26 @@ void Collection::curate_results(string& actual_query, const string& filter_query
}
}
Option<bool> Collection::validate_and_standardize_sort_fields_with_lock(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query,
const bool is_vector_query,
const std::string& query, const bool is_group_by_query,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_tries) const {
std::shared_lock lock(mutex);
return validate_and_standardize_sort_fields(sort_fields, sort_fields_std, is_wildcard_query, is_vector_query,
query, is_group_by_query, remote_embedding_timeout_ms, true);
}
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query,
const bool is_vector_query,
const std::string& query, const bool is_group_by_query,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_tries) const {
const size_t remote_embedding_num_tries,
const bool is_reference_sort) const {
uint32_t eval_sort_count = 0;
size_t num_sort_expressions = 0;
@ -879,7 +892,10 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
ref_sort_fields_std,
is_wildcard_query,
is_vector_query,
is_group_by_query);
query,
is_group_by_query,
remote_embedding_timeout_ms,
remote_embedding_num_tries);
if (!sort_validation_op.ok()) {
return Option<bool>(sort_validation_op.code(), "Referenced collection `" + ref_collection_name + "`: " +
sort_validation_op.error());
@ -1192,6 +1208,14 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_fields_std.emplace_back(sort_field_std);
}
if (is_reference_sort) {
if (eval_sort_count > 1) {
std::string message = "Only one sorting eval expression is allowed.";
return Option<bool>(422, message);
}
return Option<bool>(true);
}
/*
1. Empty: [match_score, dsf] upstream
2. ONE : [usf, match_score]
@ -1249,227 +1273,6 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
return Option<bool>(true);
}
Option<bool> Collection::validate_and_standardize_sort_fields_with_lock(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query,
const bool is_vector_query,
const bool is_group_by_query) const {
std::shared_lock lock(mutex);
for(const auto & sort_field : sort_fields) {
sort_by sort_field_std(sort_field.name, sort_field.order);
if(sort_field_std.name.back() == ')') {
// check if this is a geo field or text match field
size_t paran_start = 0;
while(paran_start < sort_field_std.name.size() && sort_field_std.name[paran_start] != '(') {
paran_start++;
}
const std::string& actual_field_name = sort_field_std.name.substr(0, paran_start);
const auto field_it = search_schema.find(actual_field_name);
if(actual_field_name == sort_field_const::text_match) {
std::vector<std::string> match_parts;
const std::string& match_config = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2);
StringUtils::split(match_config, match_parts, ":");
if(match_parts.size() != 2 || match_parts[0] != "buckets") {
return Option<bool>(400, "Invalid sorting parameter passed for _text_match.");
}
if(!StringUtils::is_uint32_t(match_parts[1])) {
return Option<bool>(400, "Invalid value passed for _text_match `buckets` configuration.");
}
sort_field_std.name = actual_field_name;
sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
} else {
if(field_it == search_schema.end()) {
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
return Option<bool>(404, error);
}
std::string error = "Bad syntax for sorting field `" + actual_field_name + "`";
if(!field_it.value().is_geopoint()) {
// check for null value order
const std::string& sort_params_str = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() -
paran_start - 2);
std::vector<std::string> param_parts;
StringUtils::split(sort_params_str, param_parts, ":");
if(param_parts.size() != 2) {
return Option<bool>(400, error);
}
if(param_parts[0] != sort_field_const::missing_values) {
return Option<bool>(400, error);
}
auto missing_values_op = magic_enum::enum_cast<sort_by::missing_values_t>(param_parts[1]);
if(missing_values_op.has_value()) {
sort_field_std.missing_values = missing_values_op.value();
} else {
return Option<bool>(400, error);
}
}
else {
const std::string& geo_coordstr = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2);
// e.g. geopoint_field(lat1, lng1, exclude_radius: 10 miles)
std::vector<std::string> geo_parts;
StringUtils::split(geo_coordstr, geo_parts, ",");
if(geo_parts.size() != 2 && geo_parts.size() != 3) {
return Option<bool>(400, error);
}
if(!StringUtils::is_float(geo_parts[0]) || !StringUtils::is_float(geo_parts[1])) {
return Option<bool>(400, error);
}
if(geo_parts.size() == 3) {
// try to parse the exclude radius option
bool is_exclude_option = false;
if(StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) {
is_exclude_option = true;
} else if(StringUtils::begins_with(geo_parts[2], sort_field_const::precision)) {
is_exclude_option = false;
} else {
return Option<bool>(400, error);
}
std::vector<std::string> param_parts;
StringUtils::split(geo_parts[2], param_parts, ":");
if(param_parts.size() != 2) {
return Option<bool>(400, error);
}
// param_parts[1] is the value, in either "20km" or "20 km" format
if(param_parts[1].size() < 2) {
return Option<bool>(400, error);
}
std::string unit = param_parts[1].substr(param_parts[1].size()-2, 2);
if(unit != "km" && unit != "mi") {
return Option<bool>(400, "Sort field's parameter unit must be either `km` or `mi`.");
}
std::vector<std::string> dist_values;
StringUtils::split(param_parts[1], dist_values, unit);
if(dist_values.size() != 1) {
return Option<bool>(400, error);
}
if(!StringUtils::is_float(dist_values[0])) {
return Option<bool>(400, error);
}
int32_t value_meters;
if(unit == "km") {
value_meters = std::stof(dist_values[0]) * 1000;
} else if(unit == "mi") {
value_meters = std::stof(dist_values[0]) * 1609.34;
} else {
return Option<bool>(400, "Sort field's parameter "
"unit must be either `km` or `mi`.");
}
if(value_meters <= 0) {
return Option<bool>(400, "Sort field's parameter must be a positive number.");
}
if(is_exclude_option) {
sort_field_std.exclude_radius = value_meters;
} else {
sort_field_std.geo_precision = value_meters;
}
}
double lat = std::stod(geo_parts[0]);
double lng = std::stod(geo_parts[1]);
int64_t lat_lng = GeoPoint::pack_lat_lng(lat, lng);
sort_field_std.geopoint = lat_lng;
}
sort_field_std.name = actual_field_name;
}
} else if (sort_field.name == sort_field_const::eval) {
auto const& count = sort_field.eval_expressions.size();
sort_field_std.eval.filter_trees = new filter_node_t[count];
std::unique_ptr<filter_node_t []> filter_trees_guard(sort_field_std.eval.filter_trees);
for (uint32_t j = 0; j < count; j++) {
auto const& filter_exp = sort_field.eval_expressions[j];
if (filter_exp.empty()) {
return Option<bool>(400, "The eval expression in sort_by is empty.");
}
filter_node_t* filter_tree_root = nullptr;
Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", filter_tree_root);
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);
if (!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}
sort_field_std.eval.filter_trees[j] = std::move(*filter_tree_root);
}
sort_field_std.name = sort_field.name;
sort_field_std.eval_expressions = sort_field.eval_expressions;
sort_field_std.eval.scores = sort_field.eval.scores;
sort_fields_std.emplace_back(sort_field_std);
filter_trees_guard.release();
continue;
}
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
std::string error = "Could not find a field named `" + sort_field_std.name +
"` in the schema for sorting.";
return Option<bool>(404, error);
}
}
if(sort_field_std.name == sort_field_const::group_found && is_group_by_query == false) {
std::string error = "group_by parameters should not be empty when using sort_by group_found";
return Option<bool>(404, error);
}
if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) {
std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.";
return Option<bool>(404, error);
}
StringUtils::toupper(sort_field_std.order);
if(sort_field_std.order != sort_field_const::asc && sort_field_std.order != sort_field_const::desc) {
std::string error = "Order for field` " + sort_field_std.name + "` should be either ASC or DESC.";
return Option<bool>(400, error);
}
sort_fields_std.emplace_back(sort_field_std);
}
return Option<bool>(true);
}
Option<bool> Collection::extract_field_name(const std::string& field_name,
const tsl::htrie_map<char, field>& search_schema,
std::vector<std::string>& processed_search_fields,