From 31938490aeb930766ea48c335f99e121e138dd8b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 24 Oct 2023 11:21:57 +0530 Subject: [PATCH] Refactor `Collection::validate_and_standardize_sort_fields_with_lock`. --- include/collection.h | 19 ++-- src/collection.cpp | 249 +++++-------------------------------------- 2 files changed, 37 insertions(+), 231 deletions(-) diff --git a/include/collection.h b/include/collection.h index 044c2438..030cd319 100644 --- a/include/collection.h +++ b/include/collection.h @@ -215,18 +215,21 @@ private: static std::vector to_char_array(const std::vector& strs); - Option validate_and_standardize_sort_fields(const std::vector & sort_fields, - std::vector& sort_fields_std, - bool is_wildcard_query, const bool is_vector_query, - const std::string& query, bool is_group_by_query = false, - const size_t remote_embedding_timeout_ms = 30000, - const size_t remote_embedding_num_tries = 2) const; - Option validate_and_standardize_sort_fields_with_lock(const std::vector & sort_fields, std::vector& sort_fields_std, bool is_wildcard_query,const bool is_vector_query, - bool is_group_by_query = false) const; + const std::string& query, bool is_group_by_query = false, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_tries = 2) const; + Option validate_and_standardize_sort_fields(const std::vector & sort_fields, + std::vector& sort_fields_std, + const bool is_wildcard_query, + const bool is_vector_query, + const std::string& query, bool is_group_by_query = false, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_tries = 2, + const bool is_reference_sort = false) const; Option persist_collection_meta(); diff --git a/src/collection.cpp b/src/collection.cpp index 34532af0..05ed1149 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -842,13 +842,26 @@ void Collection::curate_results(string& actual_query, const string& filter_query } } +Option 
Collection::validate_and_standardize_sort_fields_with_lock(const std::vector & sort_fields, + std::vector& sort_fields_std, + const bool is_wildcard_query, + const bool is_vector_query, + const std::string& query, const bool is_group_by_query, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_tries) const { + std::shared_lock lock(mutex); + return validate_and_standardize_sort_fields(sort_fields, sort_fields_std, is_wildcard_query, is_vector_query, + query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries, /*is_reference_sort=*/true); +} + Option Collection::validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, const bool is_wildcard_query, const bool is_vector_query, const std::string& query, const bool is_group_by_query, const size_t remote_embedding_timeout_ms, - const size_t remote_embedding_num_tries) const { + const size_t remote_embedding_num_tries, + const bool is_reference_sort) const { uint32_t eval_sort_count = 0; size_t num_sort_expressions = 0; @@ -879,7 +892,10 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< ref_sort_fields_std, is_wildcard_query, is_vector_query, - is_group_by_query); + query, + is_group_by_query, + remote_embedding_timeout_ms, + remote_embedding_num_tries); if (!sort_validation_op.ok()) { return Option(sort_validation_op.code(), "Referenced collection `" + ref_collection_name + "`: " + sort_validation_op.error()); @@ -1192,6 +1208,14 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_fields_std.emplace_back(sort_field_std); } + if (is_reference_sort) { + if (eval_sort_count > 1) { + std::string message = "Only one sorting eval expression is allowed."; + return Option(422, message); + } + return Option(true); + } + /* 1. Empty: [match_score, dsf] upstream 2.
ONE : [usf, match_score] @@ -1249,227 +1273,6 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< return Option(true); } -Option Collection::validate_and_standardize_sort_fields_with_lock(const std::vector & sort_fields, - std::vector& sort_fields_std, - const bool is_wildcard_query, - const bool is_vector_query, - const bool is_group_by_query) const { - std::shared_lock lock(mutex); - - for(const auto & sort_field : sort_fields) { - sort_by sort_field_std(sort_field.name, sort_field.order); - - if(sort_field_std.name.back() == ')') { - // check if this is a geo field or text match field - size_t paran_start = 0; - while(paran_start < sort_field_std.name.size() && sort_field_std.name[paran_start] != '(') { - paran_start++; - } - - const std::string& actual_field_name = sort_field_std.name.substr(0, paran_start); - const auto field_it = search_schema.find(actual_field_name); - - if(actual_field_name == sort_field_const::text_match) { - std::vector match_parts; - const std::string& match_config = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2); - StringUtils::split(match_config, match_parts, ":"); - if(match_parts.size() != 2 || match_parts[0] != "buckets") { - return Option(400, "Invalid sorting parameter passed for _text_match."); - } - - if(!StringUtils::is_uint32_t(match_parts[1])) { - return Option(400, "Invalid value passed for _text_match `buckets` configuration."); - } - - sort_field_std.name = actual_field_name; - sort_field_std.text_match_buckets = std::stoll(match_parts[1]); - - } else { - if(field_it == search_schema.end()) { - std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting."; - return Option(404, error); - } - - std::string error = "Bad syntax for sorting field `" + actual_field_name + "`"; - - if(!field_it.value().is_geopoint()) { - // check for null value order - const std::string& sort_params_str = 
sort_field_std.name.substr(paran_start + 1, - sort_field_std.name.size() - - paran_start - 2); - - std::vector param_parts; - StringUtils::split(sort_params_str, param_parts, ":"); - - if(param_parts.size() != 2) { - return Option(400, error); - } - - if(param_parts[0] != sort_field_const::missing_values) { - return Option(400, error); - } - - auto missing_values_op = magic_enum::enum_cast(param_parts[1]); - if(missing_values_op.has_value()) { - sort_field_std.missing_values = missing_values_op.value(); - } else { - return Option(400, error); - } - } - - else { - const std::string& geo_coordstr = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2); - - // e.g. geopoint_field(lat1, lng1, exclude_radius: 10 miles) - - std::vector geo_parts; - StringUtils::split(geo_coordstr, geo_parts, ","); - - if(geo_parts.size() != 2 && geo_parts.size() != 3) { - return Option(400, error); - } - - if(!StringUtils::is_float(geo_parts[0]) || !StringUtils::is_float(geo_parts[1])) { - return Option(400, error); - } - - if(geo_parts.size() == 3) { - // try to parse the exclude radius option - bool is_exclude_option = false; - - if(StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) { - is_exclude_option = true; - } else if(StringUtils::begins_with(geo_parts[2], sort_field_const::precision)) { - is_exclude_option = false; - } else { - return Option(400, error); - } - - std::vector param_parts; - StringUtils::split(geo_parts[2], param_parts, ":"); - - if(param_parts.size() != 2) { - return Option(400, error); - } - - // param_parts[1] is the value, in either "20km" or "20 km" format - - if(param_parts[1].size() < 2) { - return Option(400, error); - } - - std::string unit = param_parts[1].substr(param_parts[1].size()-2, 2); - - if(unit != "km" && unit != "mi") { - return Option(400, "Sort field's parameter unit must be either `km` or `mi`."); - } - - std::vector dist_values; - StringUtils::split(param_parts[1], dist_values, unit); 
- - if(dist_values.size() != 1) { - return Option(400, error); - } - - if(!StringUtils::is_float(dist_values[0])) { - return Option(400, error); - } - - int32_t value_meters; - - if(unit == "km") { - value_meters = std::stof(dist_values[0]) * 1000; - } else if(unit == "mi") { - value_meters = std::stof(dist_values[0]) * 1609.34; - } else { - return Option(400, "Sort field's parameter " - "unit must be either `km` or `mi`."); - } - - if(value_meters <= 0) { - return Option(400, "Sort field's parameter must be a positive number."); - } - - if(is_exclude_option) { - sort_field_std.exclude_radius = value_meters; - } else { - sort_field_std.geo_precision = value_meters; - } - } - - double lat = std::stod(geo_parts[0]); - double lng = std::stod(geo_parts[1]); - int64_t lat_lng = GeoPoint::pack_lat_lng(lat, lng); - sort_field_std.geopoint = lat_lng; - } - - sort_field_std.name = actual_field_name; - } - } else if (sort_field.name == sort_field_const::eval) { - auto const& count = sort_field.eval_expressions.size(); - sort_field_std.eval.filter_trees = new filter_node_t[count]; - std::unique_ptr filter_trees_guard(sort_field_std.eval.filter_trees); - - for (uint32_t j = 0; j < count; j++) { - auto const& filter_exp = sort_field.eval_expressions[j]; - if (filter_exp.empty()) { - return Option(400, "The eval expression in sort_by is empty."); - } - - filter_node_t* filter_tree_root = nullptr; - Option parse_filter_op = filter::parse_filter_query(filter_exp, search_schema, - store, "", filter_tree_root); - std::unique_ptr filter_tree_root_guard(filter_tree_root); - - if (!parse_filter_op.ok()) { - return Option(parse_filter_op.code(), "Error parsing eval expression in sort_by clause."); - } - - sort_field_std.eval.filter_trees[j] = std::move(*filter_tree_root); - } - - sort_field_std.name = sort_field.name; - sort_field_std.eval_expressions = sort_field.eval_expressions; - sort_field_std.eval.scores = sort_field.eval.scores; - sort_fields_std.emplace_back(sort_field_std); - 
filter_trees_guard.release(); - continue; - } - - if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval && - sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) { - - const auto field_it = search_schema.find(sort_field_std.name); - if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { - std::string error = "Could not find a field named `" + sort_field_std.name + - "` in the schema for sorting."; - return Option(404, error); - } - } - - if(sort_field_std.name == sort_field_const::group_found && is_group_by_query == false) { - std::string error = "group_by parameters should not be empty when using sort_by group_found"; - return Option(404, error); - } - - if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) { - std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search."; - return Option(404, error); - } - - StringUtils::toupper(sort_field_std.order); - - if(sort_field_std.order != sort_field_const::asc && sort_field_std.order != sort_field_const::desc) { - std::string error = "Order for field` " + sort_field_std.name + "` should be either ASC or DESC."; - return Option(400, error); - } - - sort_fields_std.emplace_back(sort_field_std); - } - - return Option(true); -} - Option Collection::extract_field_name(const std::string& field_name, const tsl::htrie_map& search_schema, std::vector& processed_search_fields,