diff --git a/include/collection.h b/include/collection.h index 9dd69595..2f79bc68 100644 --- a/include/collection.h +++ b/include/collection.h @@ -207,7 +207,8 @@ private: Option validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, - bool is_wildcard_query, bool is_group_by_query = false) const; + bool is_wildcard_query,const bool is_vector_query, + bool is_group_by_query = false) const; Option persist_collection_meta(); @@ -397,7 +398,7 @@ public: nlohmann::json get_summary_json() const; - size_t batch_index_in_memory(std::vector& index_records); + size_t batch_index_in_memory(std::vector& index_records, const bool generate_embeddings = true); Option add(const std::string & json_str, const index_operation_t& operation=CREATE, const std::string& id="", @@ -462,7 +463,8 @@ public: const text_match_type_t match_type = max_score, const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, - const size_t page_offset = UINT32_MAX) const; + const size_t page_offset = UINT32_MAX, + const size_t vector_query_hits = 250) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/field.h b/include/field.h index a0eca2af..5ad381ba 100644 --- a/include/field.h +++ b/include/field.h @@ -550,6 +550,8 @@ namespace sort_field_const { static const std::string precision = "precision"; static const std::string missing_values = "missing_values"; + + static const std::string vector_distance = "_vector_distance"; } struct sort_by { diff --git a/include/http_server.h b/include/http_server.h index b6465e43..1ced92e7 100644 --- a/include/http_server.h +++ b/include/http_server.h @@ -51,7 +51,10 @@ public: bool is_res_start = true; h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS; - h2o_iovec_t res_body{}; + + std::string res_body; + h2o_iovec_t res_buff; + h2o_iovec_t res_content_type{}; int status = 0; const char* reason = nullptr; @@ -64,8 +67,10 @@ 
public: } } - void set_response(uint32_t status_code, const std::string& content_type, const std::string& body) { - res_body = h2o_strdup(&req->pool, body.c_str(), SIZE_MAX); + void set_response(uint32_t status_code, const std::string& content_type, std::string& body) { + std::string().swap(res_body); + res_body = std::move(body); + res_buff = h2o_iovec_t{.base = res_body.data(), .len = res_body.size()}; if(is_res_start) { res_content_type = h2o_strdup(&req->pool, content_type.c_str(), SIZE_MAX); diff --git a/include/index.h b/include/index.h index 38c02238..8db2945f 100644 --- a/include/index.h +++ b/include/index.h @@ -202,6 +202,9 @@ struct index_record { nlohmann::json new_doc; // new *full* document to be stored into disk nlohmann::json del_doc; // document containing the fields that should be deleted + nlohmann::json embedding_res; // embedding result + int embedding_status_code; // embedding status code + index_operation_t operation; bool is_update; @@ -344,6 +347,7 @@ private: static spp::sparse_hash_map eval_sentinel_value; static spp::sparse_hash_map geo_sentinel_value; static spp::sparse_hash_map str_sentinel_value; + static spp::sparse_hash_map vector_distance_sentinel_value; // Internal utility functions @@ -523,7 +527,7 @@ private: void initialize_facet_indexes(const field& facet_field); - static Option batch_embed_fields(std::vector& documents, + static void batch_embed_fields(std::vector& documents, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema); @@ -661,7 +665,7 @@ public: const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, - const bool do_validation); + const bool do_validation, const bool generate_embeddings = true); static size_t batch_memory_index(Index *index, std::vector& iter_batch, @@ -671,7 +675,7 @@ public: const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, - const bool do_validation); 
+ const bool do_validation, const bool generate_embeddings = true); void index_field_in_memory(const field& afield, std::vector& iter_batch); @@ -929,7 +933,7 @@ public: size_t filter_index, int64_t max_field_match_score, int64_t* scores, - int64_t& match_score_index) const; + int64_t& match_score_index, float vector_distance = 0) const; void process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, diff --git a/include/text_embedder.h b/include/text_embedder.h index 4bae7746..cff7c7c7 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,8 +16,8 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - Option> Embed(const std::string& text); - Option>> batch_embed(const std::vector& inputs); + embedding_res_t Embed(const std::string& text); + std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { return remote_embedder_ != nullptr; diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 99f9146f..652c0b61 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -9,6 +9,18 @@ +struct embedding_res_t { + std::vector embedding; + nlohmann::json error = nlohmann::json::object(); + int status_code; + bool success; + + embedding_res_t(const std::vector& embedding) : embedding(embedding), success(true) {} + + embedding_res_t(int status_code, const nlohmann::json& error) : error(error), success(false), status_code(status_code) {} +}; + + class RemoteEmbedder { protected: @@ -16,11 +28,12 @@ class RemoteEmbedder { static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; public: - virtual Option> Embed(const std::string& text) 
= 0; - virtual Option>> batch_embed(const std::vector& inputs) = 0; + virtual embedding_res_t Embed(const std::string& text) = 0; + virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; } + virtual ~RemoteEmbedder() = default; }; @@ -34,8 +47,8 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - Option> Embed(const std::string& text) override; - Option>> batch_embed(const std::vector& inputs) override; + embedding_res_t Embed(const std::string& text) override; + std::vector batch_embed(const std::vector& inputs) override; }; @@ -49,8 +62,8 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - Option> Embed(const std::string& text) override; - Option>> batch_embed(const std::vector& inputs) override; + embedding_res_t Embed(const std::string& text) override; + std::vector batch_embed(const std::vector& inputs) override; }; @@ -75,8 +88,8 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - Option> Embed(const std::string& text) override; - Option>> batch_embed(const std::vector& inputs) override; + embedding_res_t Embed(const std::string& text) override; + std::vector batch_embed(const std::vector& inputs) override; }; diff --git a/include/text_embedder_tokenizer.h b/include/text_embedder_tokenizer.h index a9b3a41b..e3fc737d 100644 --- a/include/text_embedder_tokenizer.h +++ 
b/include/text_embedder_tokenizer.h @@ -30,6 +30,7 @@ class TextEmbeddingTokenizer { private: public: virtual encoded_input_t Encode(const std::string& text) = 0; + virtual ~TextEmbeddingTokenizer() = default; }; class BertTokenizerWrapper : public TextEmbeddingTokenizer { diff --git a/include/topster.h b/include/topster.h index 8e10abcb..2ca2363a 100644 --- a/include/topster.h +++ b/include/topster.h @@ -15,6 +15,11 @@ struct KV { uint64_t key{}; uint64_t distinct_key{}; int64_t scores[3]{}; // match score + 2 custom attributes + + // only to be used in hybrid search + float vector_distance = 0.0f; + int64_t text_match_score = 0; + reference_filter_result_t* reference_filter_result = nullptr; // to be used only in final aggregation @@ -43,6 +48,9 @@ struct KV { query_indices = kv.query_indices; kv.query_indices = nullptr; + + vector_distance = kv.vector_distance; + text_match_score = kv.text_match_score; } KV& operator=(KV&& kv) noexcept { @@ -60,6 +68,9 @@ struct KV { delete[] query_indices; query_indices = kv.query_indices; kv.query_indices = nullptr; + + vector_distance = kv.vector_distance; + text_match_score = kv.text_match_score; } return *this; @@ -80,6 +91,9 @@ struct KV { delete[] query_indices; query_indices = kv.query_indices; kv.query_indices = nullptr; + + vector_distance = kv.vector_distance; + text_match_score = kv.text_match_score; } return *this; diff --git a/include/validator.h b/include/validator.h index a8a4a8f9..a03ebd04 100644 --- a/include/validator.h +++ b/include/validator.h @@ -16,7 +16,7 @@ public: const index_operation_t op, const bool is_update, const std::string& fallback_field_type, - const DIRTY_VALUES& dirty_values); + const DIRTY_VALUES& dirty_values, const bool validate_embedding_fields = true); static Option coerce_element(const field& a_field, nlohmann::json& document, diff --git a/src/collection.cpp b/src/collection.cpp index 83d5266f..e4c1201e 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -562,12 +562,22 @@ void 
Collection::batch_index(std::vector& index_records, std::vect if(!index_record.indexed.ok()) { res["document"] = json_out[index_record.position]; res["error"] = index_record.indexed.error(); + if (!index_record.embedding_res.empty()) { + res["embedding_error"] = nlohmann::json::object(); + res["embedding_error"] = index_record.embedding_res; + res["error"] = index_record.embedding_res["error"]; + } res["code"] = index_record.indexed.code(); } } else { res["success"] = false; res["document"] = json_out[index_record.position]; res["error"] = index_record.indexed.error(); + if (!index_record.embedding_res.empty()) { + res["embedding_error"] = nlohmann::json::object(); + res["error"] = index_record.embedding_res["error"]; + res["embedding_error"] = index_record.embedding_res; + } res["code"] = index_record.indexed.code(); } @@ -599,11 +609,11 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t return Option<>(200); } -size_t Collection::batch_index_in_memory(std::vector& index_records) { +size_t Collection::batch_index_in_memory(std::vector& index_records, const bool generate_embeddings) { std::unique_lock lock(mutex); size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field, search_schema, embedding_fields, fallback_field_type, - token_separators, symbols_to_index, true); + token_separators, symbols_to_index, true, generate_embeddings); num_documents += num_indexed; return num_indexed; } @@ -732,6 +742,7 @@ void Collection::curate_results(string& actual_query, const string& filter_query Option Collection::validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, const bool is_wildcard_query, + const bool is_vector_query, const bool is_group_by_query) const { size_t num_sort_expressions = 0; @@ -906,7 +917,7 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != 
sort_field_const::eval && - sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found) { + sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) { const auto field_it = search_schema.find(sort_field_std.name); if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { @@ -920,6 +931,11 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< std::string error = "group_by parameters should not be empty when using sort_by group_found"; return Option(404, error); } + + if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) { + std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search."; + return Option(404, error); + } StringUtils::toupper(sort_field_std.order); @@ -942,6 +958,10 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc); } + if(is_vector_query) { + sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc); + } + if(!default_sorting_field.empty()) { sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc); } else { @@ -950,9 +970,15 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } bool found_match_score = false; + bool found_vector_distance = false; for(const auto & sort_field : sort_fields_std) { if(sort_field.name == sort_field_const::text_match) { found_match_score = true; + } + if(sort_field.name == sort_field_const::vector_distance) { + found_vector_distance = true; + } + if(found_match_score && found_vector_distance) { break; } } @@ -961,6 +987,10 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< 
sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc); } + if(!found_vector_distance && is_vector_query && sort_fields.size() < 3) { + sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc); + } + if(sort_fields_std.size() > 3) { std::string message = "Only upto 3 sort_by fields can be specified."; return Option(422, message); @@ -1077,7 +1107,8 @@ Option Collection::search(std::string raw_query, const text_match_type_t match_type, const size_t facet_sample_percent, const size_t facet_sample_threshold, - const size_t page_offset) const { + const size_t page_offset, + const size_t vector_query_hits) const { std::shared_lock lock(mutex); @@ -1132,6 +1163,7 @@ Option Collection::search(std::string raw_query, group_limit = 0; } + vector_query_t vector_query; if(!vector_query_str.empty()) { auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this); @@ -1185,8 +1217,8 @@ Option Collection::search(std::string raw_query, // } if(raw_query == "*") { - std::string error = "Wildcard query is not supported for embedding fields."; - return Option(400, error); + // ignore embedding field if query is a wildcard + continue; } TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); @@ -1206,14 +1238,18 @@ Option Collection::search(std::string raw_query, std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; auto embedding_op = embedder->Embed(embed_query); - if(!embedding_op.ok()) { - return Option(400, embedding_op.error()); + if(!embedding_op.success) { + if(!embedding_op.error["error"].get().empty()) { + return Option(400, embedding_op.error["error"].get()); + } else { + return Option(400, embedding_op.error.dump()); + } } - - std::vector embedding = embedding_op.get(); + std::vector embedding = embedding_op.embedding; vector_query._reset(); vector_query.values = embedding; vector_query.field_name 
= field_name; + vector_query.k = vector_query_hits; continue; } @@ -1443,10 +1479,11 @@ Option Collection::search(std::string raw_query, bool is_wildcard_query = (query == "*"); bool is_group_by_query = group_by_fields.size() > 0; + bool is_vector_query = !vector_query.field_name.empty(); if(curated_sort_by.empty()) { auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, - sort_fields_std, is_wildcard_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1458,7 +1495,7 @@ Option Collection::search(std::string raw_query, } auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, - sort_fields_std, is_wildcard_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1888,7 +1925,10 @@ Option Collection::search(std::string raw_query, populate_text_match_info(wrapper_doc["text_match_info"], field_order_kv->scores[field_order_kv->match_score_index], match_type); } else { - wrapper_doc["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]); + wrapper_doc["hybrid_search_info"] = nlohmann::json::object(); + wrapper_doc["hybrid_search_info"]["rank_fusion_score"] = Index::int64_t_to_float(field_order_kv->scores[field_order_kv->match_score_index]); + wrapper_doc["hybrid_search_info"]["text_match_score"] = field_order_kv->text_match_score; + wrapper_doc["hybrid_search_info"]["vector_distance"] = field_order_kv->vector_distance; } } @@ -1906,7 +1946,7 @@ Option Collection::search(std::string raw_query, } if(!vector_query.field_name.empty() && query == "*") { - wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]); + wrapper_doc["vector_distance"] = 
field_order_kv->vector_distance; } hits_array.push_back(wrapper_doc); @@ -3708,6 +3748,10 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_field_names.push_back(f.name); } + if(f.embed.count(fields::from) != 0) { + embedding_fields.emplace(f.name, f); + } + fields.push_back(f); } @@ -3762,9 +3806,9 @@ Option Collection::batch_alter_data(const std::vector& alter_fields index->remove(seq_id, rec.doc, del_fields, true); } } - + Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields, - fallback_field_type, token_separators, symbols_to_index, true); + fallback_field_type, token_separators, symbols_to_index, true, false); iter_batch.clear(); } @@ -3869,6 +3913,21 @@ Option Collection::alter(nlohmann::json& alter_payload) { } } + + // hide credentials in the alter payload return + for(auto& field_json : alter_payload["fields"]) { + if(field_json[fields::embed].count(fields::model_config) != 0) { + hide_credential(field_json[fields::embed][fields::model_config], "api_key"); + hide_credential(field_json[fields::embed][fields::model_config], "access_token"); + hide_credential(field_json[fields::embed][fields::model_config], "refresh_token"); + hide_credential(field_json[fields::embed][fields::model_config], "client_id"); + hide_credential(field_json[fields::embed][fields::model_config], "client_secret"); + hide_credential(field_json[fields::embed][fields::model_config], "project_id"); + } + } + + + return Option(true); } @@ -4212,6 +4271,65 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, } } + for(const auto& kv: schema_changes["fields"].items()) { + // validate embedding fields externally + auto& field_json = kv.value(); + if(field_json.count(fields::embed) != 0 && !field_json[fields::embed].empty()) { + if(!field_json[fields::embed].is_object()) { + return Option(400, "Property `" + fields::embed + "` must be an object."); + } + + 
if(field_json[fields::embed].count(fields::from) == 0) { + return Option(400, "Property `" + fields::embed + "` must contain a `" + fields::from + "` property."); + } + + if(!field_json[fields::embed][fields::from].is_array()) { + return Option(400, "Property `" + fields::embed + "." + fields::from + "` must be an array."); + } + + if(field_json[fields::embed][fields::from].empty()) { + return Option(400, "Property `" + fields::embed + "." + fields::from + "` must have at least one element."); + } + + for(auto& embed_from_field : field_json[fields::embed][fields::from]) { + if(!embed_from_field.is_string()) { + return Option(400, "Property `" + fields::embed + "." + fields::from + "` must contain only field names as strings."); + } + } + + if(field_json[fields::type] != field_types::FLOAT_ARRAY) { + return Option(400, "Property `" + fields::embed + "." + fields::from + "` is only allowed on a float array field."); + } + + for(auto& embed_from_field : field_json[fields::embed][fields::from]) { + bool flag = false; + for(const auto& field : search_schema) { + if(field.name == embed_from_field) { + if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) { + return Option(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields."); + } + flag = true; + break; + } + } + if(!flag) { + for(const auto& other_kv: schema_changes["fields"].items()) { + if(other_kv.value()["name"] == embed_from_field) { + if(other_kv.value()[fields::type] != field_types::STRING && other_kv.value()[fields::type] != field_types::STRING_ARRAY) { + return Option(400, "Property `" + fields::embed + "." + fields::from + "` can only refer to string or string array fields."); + } + flag = true; + break; + } + } + } + if(!flag) { + return Option(400, "Property `" + fields::embed + "." 
+ fields::from + "` can only refer to string or string array fields."); + } + } + } + } + if(num_auto_detect_fields > 1) { return Option(400, "There can be only one field named `.*`."); } @@ -4819,10 +4937,10 @@ void Collection::process_remove_field_for_embedding_fields(const field& the_fiel void Collection::hide_credential(nlohmann::json& json, const std::string& credential_name) { if(json.count(credential_name) != 0) { - // hide api key with * except first 3 chars + // hide api key with * except first 5 chars std::string credential_name_str = json[credential_name]; - if(credential_name_str.size() > 3) { - json[credential_name] = credential_name_str.replace(3, credential_name_str.size() - 3, credential_name_str.size() - 3, '*'); + if(credential_name_str.size() > 5) { + json[credential_name] = credential_name_str.replace(5, credential_name_str.size() - 5, credential_name_str.size() - 5, '*'); } else { json[credential_name] = credential_name_str.replace(0, credential_name_str.size(), credential_name_str.size(), '*'); } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index b0c76f69..5323e1b0 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -669,6 +669,7 @@ Option CollectionManager::do_search(std::map& re const char *MAX_FACET_VALUES = "max_facet_values"; const char *VECTOR_QUERY = "vector_query"; + const char *VECTOR_QUERY_HITS = "vector_query_hits"; const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -821,6 +822,7 @@ Option CollectionManager::do_search(std::map& re size_t max_extra_suffix = INT16_MAX; bool enable_highlight_v1 = true; text_match_type_t match_type = max_score; + size_t vector_query_hits = 250; size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -847,6 +849,7 @@ Option CollectionManager::do_search(std::map& re {FILTER_CURATED_HITS, &filter_curated_hits_option}, {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, 
+ {VECTOR_QUERY_HITS, &vector_query_hits}, }; std::unordered_map str_values = { @@ -1383,7 +1386,7 @@ Option CollectionManager::load_collection(const nlohmann::json &collection // batch must match atleast the number of shards if(exceeds_batch_mem_threshold || (num_valid_docs % batch_size == 0) || last_record) { size_t num_records = index_records.size(); - size_t num_indexed = collection->batch_index_in_memory(index_records); + size_t num_indexed = collection->batch_index_in_memory(index_records, false); batch_doc_str_size = 0; if(num_indexed != num_records) { diff --git a/src/core_api.cpp b/src/core_api.cpp index 7de42f6a..5d04ad07 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -928,14 +928,42 @@ bool post_add_document(const std::shared_ptr& req, const std::shared_p const index_operation_t operation = get_index_operation(req->params[ACTION]); const auto& dirty_values = collection->parse_dirty_values_option(req->params[DIRTY_VALUES_PARAM]); - Option inserted_doc_op = collection->add(req->body, operation, "", dirty_values); + nlohmann::json document; + std::vector json_lines = {req->body}; + const nlohmann::json& inserted_doc_op = collection->add_many(json_lines, document, operation, "", dirty_values, false, false); - if(!inserted_doc_op.ok()) { - res->set(inserted_doc_op.code(), inserted_doc_op.error()); + if(!inserted_doc_op["success"].get()) { + nlohmann::json res_doc; + + try { + res_doc = nlohmann::json::parse(json_lines[0]); + } catch(const std::exception& e) { + LOG(ERROR) << "JSON error: " << e.what(); + res->set_400("Bad JSON."); + return false; + } + + res->status_code = res_doc["code"].get(); + // erase keys from res_doc except error and embedding_error + for(auto it = res_doc.begin(); it != res_doc.end(); ) { + if(it.key() != "error" && it.key() != "embedding_error") { + it = res_doc.erase(it); + } else { + ++it; + } + } + + // rename error to message if not empty and exists + if(res_doc.count("error") != 0 && !res_doc["error"].get().empty()) { + 
res_doc["message"] = res_doc["error"]; + res_doc.erase("error"); + } + + res->body = res_doc.dump(); return false; } - res->set_201(inserted_doc_op.get().dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore)); + res->set_201(document.dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore)); return true; } diff --git a/src/http_client.cpp b/src/http_client.cpp index ef11b57d..7b3c274b 100644 --- a/src/http_client.cpp +++ b/src/http_client.cpp @@ -279,6 +279,8 @@ size_t HttpClient::curl_write_async_done(void *context, curl_socket_t item) { if(!req_res->res->is_alive) { // underlying client request is dead, don't try to send anymore data + // also, close the socket as we've overridden the close socket handler! + close(item); return 0; } diff --git a/src/http_server.cpp b/src/http_server.cpp index 6a704153..813ee450 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -504,6 +504,9 @@ int HttpServer::catch_all_handler(h2o_handler_t *_h2o_handler, h2o_req_t *req) { ); *allocated_generator = custom_gen; + // ensures that the first response need not wait for previous chunk to be done sending + response->notify(); + //LOG(INFO) << "Init res: " << custom_gen->response << ", ref count: " << custom_gen->response.use_count(); if(root_resource == "multi_search") { @@ -544,7 +547,9 @@ bool HttpServer::is_write_request(const std::string& root_resource, const std::s return false; } - bool write_free_request = (root_resource == "multi_search" || root_resource == "operations"); + bool write_free_request = (root_resource == "multi_search" || root_resource == "proxy" || + root_resource == "operations"); + if(!write_free_request && (http_method == "POST" || http_method == "PUT" || http_method == "DELETE" || http_method == "PATCH")) { @@ -632,9 +637,6 @@ int HttpServer::async_req_cb(void *ctx, int is_end_stream) { if(request->first_chunk_aggregate) { request->first_chunk_aggregate = false; - - // ensures that the first response need not wait for previous chunk 
to be done sending - response->notify(); } // default value for last_chunk_aggregate is false @@ -821,7 +823,7 @@ void HttpServer::stream_response(stream_response_state_t& state) { h2o_start_response(req, state.generator); } - h2o_send(req, &state.res_body, 1, H2O_SEND_STATE_FINAL); + h2o_send(req, &state.res_buff, 1, H2O_SEND_STATE_FINAL); h2o_dispose_request(req); return ; @@ -833,13 +835,13 @@ void HttpServer::stream_response(stream_response_state_t& state) { h2o_start_response(req, state.generator); } - if(state.res_body.len == 0 && state.send_state != H2O_SEND_STATE_FINAL) { + if(state.res_buff.len == 0 && state.send_state != H2O_SEND_STATE_FINAL) { // without this guard, http streaming will break state.generator->proceed(state.generator, req); return; } - h2o_send(req, &state.res_body, 1, state.send_state); + h2o_send(req, &state.res_buff, 1, state.send_state); //LOG(INFO) << "stream_response after send"; } diff --git a/src/index.cpp b/src/index.cpp index d2fbf3f5..f7cd7235 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -43,6 +43,7 @@ spp::sparse_hash_map Index::seq_id_sentinel_value; spp::sparse_hash_map Index::eval_sentinel_value; spp::sparse_hash_map Index::geo_sentinel_value; spp::sparse_hash_map Index::str_sentinel_value; +spp::sparse_hash_map Index::vector_distance_sentinel_value; struct token_posting_t { uint32_t token_id; @@ -427,10 +428,10 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, - const bool do_validation) { + const bool do_validation, const bool generate_embeddings) { // runs in a partitioned thread - std::vector docs_to_embed; + std::vector records_to_embed; for(size_t i = 0; i < batch_size; i++) { index_record& index_rec = iter_batch[batch_start_index + i]; @@ -453,7 +454,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite index_rec.operation, index_rec.is_update, 
fallback_field_type, - index_rec.dirty_values); + index_rec.dirty_values, generate_embeddings); if(!validation_op.ok()) { index_rec.index_failure(validation_op.code(), validation_op.error()); @@ -467,14 +468,16 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite index_rec.new_doc, index_rec.del_doc); scrub_reindex_doc(search_schema, index_rec.doc, index_rec.del_doc, index_rec.old_doc); - for(auto& field: index_rec.doc.items()) { - for(auto& embedding_field : embedding_fields) { - if(!embedding_field.embed[fields::from].is_null()) { - auto embed_from_vector = embedding_field.embed[fields::from].get>(); - for(auto& embed_from: embed_from_vector) { - if(embed_from == field.key()) { - docs_to_embed.push_back(&index_rec.new_doc); - break; + if(generate_embeddings) { + for(auto& field: index_rec.doc.items()) { + for(auto& embedding_field : embedding_fields) { + if(!embedding_field.embed[fields::from].is_null()) { + auto embed_from_vector = embedding_field.embed[fields::from].get>(); + for(auto& embed_from: embed_from_vector) { + if(embed_from == field.key()) { + records_to_embed.push_back(&index_rec); + break; + } } } } @@ -482,7 +485,9 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite } } else { handle_doc_ops(search_schema, index_rec.doc, index_rec.old_doc); - docs_to_embed.push_back(&index_rec.doc); + if(generate_embeddings) { + records_to_embed.push_back(&index_rec); + } } compute_token_offsets_facets(index_rec, search_schema, token_separators, symbols_to_index); @@ -512,13 +517,8 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite index_rec.index_failure(400, e.what()); } } - - auto embed_op = batch_embed_fields(docs_to_embed, embedding_fields, search_schema); - if(!embed_op.ok()) { - for(size_t i = 0; i < batch_size; i++) { - index_record& index_rec = iter_batch[batch_start_index + i]; - index_rec.index_failure(embed_op.code(), embed_op.error()); - } + if(generate_embeddings) { + 
batch_embed_fields(records_to_embed, embedding_fields, search_schema); } } @@ -529,7 +529,7 @@ size_t Index::batch_memory_index(Index *index, std::vector& iter_b const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, - const bool do_validation) { + const bool do_validation, const bool generate_embeddings) { const size_t concurrency = 4; const size_t num_threads = std::min(concurrency, iter_batch.size()); @@ -559,7 +559,7 @@ size_t Index::batch_memory_index(Index *index, std::vector& iter_b index->thread_pool->enqueue([&, batch_index, batch_len]() { write_log_index = local_write_log_index; validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema, - embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation); + embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, generate_embeddings); std::unique_lock lock(m_process); num_processed++; @@ -2530,12 +2530,12 @@ Option Index::search(std::vector& field_query_tokens, cons } int64_t scores[3] = {0}; - scores[0] = -float_to_int64_t(vec_dist_score); int64_t match_score_index = -1; - //LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first; + compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0, 0, scores, match_score_index, vec_dist_score); KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr); + kv.vector_distance = vec_dist_score; int ret = topster->add(&kv); if(group_limit != 0 && ret < 2) { @@ -2756,7 +2756,9 @@ Option Index::search(std::vector& field_query_tokens, cons VectorFilterFunctor filterFunctor(filter_result_iterator); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; - auto k = std::max(vector_query.k, fetch_size); + // use k as 100 by default for ensuring results stability in pagination + size_t default_k = 100; 
+ auto k = std::max(vector_query.k, default_k); if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); @@ -2792,29 +2794,62 @@ Option Index::search(std::vector& field_query_tokens, cons continue; } // (1 / rank_of_document) * WEIGHT) + result->text_match_score = result->scores[result->match_score_index]; result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT); } - for(int i = 0; i < vec_results.size(); i++) { - auto& result = vec_results[i]; - auto doc_id = result.first; + std::vector vec_search_ids; // list of IDs found only in vector search + for(size_t res_index = 0; res_index < vec_results.size(); res_index++) { + auto& vec_result = vec_results[res_index]; + auto doc_id = vec_result.first; auto result_it = topster->kv_map.find(doc_id); - if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) { + if(result_it != topster->kv_map.end()) { + if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) { + continue; + } + + // result overlaps with keyword search: we have to combine the scores + auto result = result_it->second; // old_score + (1 / rank_of_document) * WEIGHT) - result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT)); + result->vector_distance = vec_result.second; + result->scores[result->match_score_index] = float_to_int64_t( + (int64_t_to_float(result->scores[result->match_score_index])) + + ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT)); + + for(size_t i = 0;i < 3; i++) { + if(field_values[i] == &vector_distance_sentinel_value) { + result->scores[i] = float_to_int64_t(vec_result.second); + } + + if(sort_order[i] == -1) { + result->scores[i] = -result->scores[i]; + } + } + } else { - int64_t scores[3] = {0}; + // Result has been found only in vector 
search: we have to add it to both KV and result_ids // (1 / rank_of_document) * WEIGHT) - scores[0] = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT); - int64_t match_score_index = 0; + int64_t scores[3] = {0}; + int64_t match_score = float_to_int64_t((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT); + int64_t match_score_index = -1; + compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second); KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores); + kv.vector_distance = vec_result.second; topster->add(&kv); - ++all_result_ids_len; + vec_search_ids.push_back(doc_id); } } + + if(!vec_search_ids.empty()) { + uint32_t* new_all_result_ids = nullptr; + all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &vec_search_ids[0], + vec_search_ids.size(), &new_all_result_ids); + delete[] all_result_ids; + all_result_ids = new_all_result_ids; + } } } @@ -3778,7 +3813,7 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i std::array*, 3> field_values, const std::vector& geopoint_indices, uint32_t seq_id, size_t filter_index, int64_t max_field_match_score, - int64_t* scores, int64_t& match_score_index) const { + int64_t* scores, int64_t& match_score_index, float vector_distance) const { int64_t geopoint_distances[3]; @@ -3873,6 +3908,8 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } scores[0] = int64_t(found); + } else if(field_values[0] == &vector_distance_sentinel_value) { + scores[0] = float_to_int64_t(vector_distance); } else { auto it = field_values[0]->find(seq_id); scores[0] = (it == field_values[0]->end()) ? 
default_score : it->second; @@ -3929,6 +3966,8 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } scores[1] = int64_t(found); + } else if(field_values[1] == &vector_distance_sentinel_value) { + scores[1] = float_to_int64_t(vector_distance); } else { auto it = field_values[1]->find(seq_id); scores[1] = (it == field_values[1]->end()) ? default_score : it->second; @@ -3981,6 +4020,8 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } scores[2] = int64_t(found); + } else if(field_values[2] == &vector_distance_sentinel_value) { + scores[2] = float_to_int64_t(vector_distance); } else { auto it = field_values[2]->find(seq_id); scores[2] = (it == field_values[2]->end()) ? default_score : it->second; @@ -4679,15 +4720,14 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root); auto filter_init_op = filter_result_iterator.init_status(); if (!filter_init_op.ok()) { return; } - sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids); - + } else if(sort_fields_std[i].name == sort_field_const::vector_distance) { + field_values[i] = &vector_distance_sentinel_value; } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); @@ -6091,13 +6131,26 @@ bool Index::common_results_exist(std::vector& leaves, bool must_match } -Option Index::batch_embed_fields(std::vector& documents, +void Index::batch_embed_fields(std::vector& records, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema) { for(const auto& field : embedding_fields) 
{ - std::vector> texts_to_embed; + std::vector> texts_to_embed; auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]); - for(auto& document : documents) { + for(auto& record : records) { + if(!record->indexed.ok()) { + continue; + } + nlohmann::json* document; + if(record->is_update) { + document = &record->new_doc; + } else { + document = &record->doc; + } + + if(document == nullptr) { + continue; + } std::string text = indexing_prefix; auto embed_from = field.embed[fields::from].get>(); for(const auto& field_name : embed_from) { @@ -6110,8 +6163,8 @@ Option Index::batch_embed_fields(std::vector& documents, } } } - if(!text.empty()) { - texts_to_embed.push_back(std::make_pair(document, text)); + if(text != indexing_prefix) { + texts_to_embed.push_back(std::make_pair(record, text)); } } @@ -6123,13 +6176,15 @@ Option Index::batch_embed_fields(std::vector& documents, auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]); if(!embedder_op.ok()) { - return Option(400, embedder_op.error()); + LOG(ERROR) << "Error while getting embedder for model: " << field.embed[fields::model_config]; + LOG(ERROR) << "Error: " << embedder_op.error(); + return; } // sort texts by length std::sort(texts_to_embed.begin(), texts_to_embed.end(), - [](const std::pair& a, - const std::pair& b) { + [](const std::pair& a, + const std::pair& b) { return a.second.size() < b.second.size(); }); @@ -6139,19 +6194,24 @@ Option Index::batch_embed_fields(std::vector& documents, texts.push_back(text_to_embed.second); } - auto embedding_op = embedder_op.get()->batch_embed(texts); - if(!embedding_op.ok()) { - return Option(400, embedding_op.error()); - } + auto embeddings = embedder_op.get()->batch_embed(texts); - auto embeddings = embedding_op.get(); for(size_t i = 0; i < embeddings.size(); i++) { - auto& embedding = embeddings[i]; - auto& document = texts_to_embed[i].first; - (*document)[field.name] = 
embedding; + auto& embedding_res = embeddings[i]; + if(!embedding_res.success) { + texts_to_embed[i].first->embedding_res = embedding_res.error; + texts_to_embed[i].first->index_failure(embedding_res.status_code, ""); + continue; + } + nlohmann::json* document; + if(texts_to_embed[i].first->is_update) { + document = &texts_to_embed[i].first->new_doc; + } else { + document = &texts_to_embed[i].first->doc; + } + (*document)[field.name] = embedding_res.embedding; } } - return Option(true); } /* diff --git a/src/raft_server.cpp b/src/raft_server.cpp index 920b6b34..0ca0b7de 100644 --- a/src/raft_server.cpp +++ b/src/raft_server.cpp @@ -194,7 +194,8 @@ void ReplicationState::write(const std::shared_ptr& request, const std auto resource_check = cached_resource_stat_t::get_instance().has_enough_resources(raft_dir_path, config->get_disk_used_max_percentage(), config->get_memory_used_max_percentage()); - if (resource_check != cached_resource_stat_t::OK && request->http_method != "DELETE") { + if (resource_check != cached_resource_stat_t::OK && + request->http_method != "DELETE" && request->path_without_query != "/health") { response->set_422("Rejecting write: running out of resource type: " + std::string(magic_enum::enum_name(resource_check))); response->final = true; @@ -967,6 +968,36 @@ bool ReplicationState::get_ext_snapshot_succeeded() { std::string ReplicationState::get_leader_url() const { std::shared_lock lock(node_mutex); + if(!node) { + const Option & refreshed_nodes_op = Config::fetch_nodes_config(config->get_nodes()); + + if(!refreshed_nodes_op.ok()) { + LOG(WARNING) << "Error while fetching peer configuration: " << refreshed_nodes_op.error(); + return ""; + } + + const std::string& nodes_config = ReplicationState::to_nodes_config(peering_endpoint, + Config::get_instance().get_api_port(), + + refreshed_nodes_op.get()); + std::vector peers; + braft::Configuration peer_config; + peer_config.parse_from(nodes_config); + peer_config.list_peers(&peers); + + 
if(peers.empty()) { + LOG(WARNING) << "No peers found in nodes config: " << nodes_config; + return ""; + } + + + const std::string protocol = api_uses_ssl ? "https" : "http"; + std::string url = get_node_url_path(peers[0].to_string(), "/", protocol); + + LOG(INFO) << "Returning first peer as leader URL: " << url; + return url; + } + if(node->leader_id().is_empty()) { LOG(ERROR) << "Could not get leader status, as node does not have a leader!"; return ""; diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 9e168424..b2a67607 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) { TextEmbedder::TextEmbedder(const nlohmann::json& model_config) { auto model_name = model_config["model_name"].get(); - LOG(INFO) << "Loading model from remote: " << model_name; + LOG(INFO) << "Initializing remote embedding model: " << model_name; auto model_namespace = TextEmbedderManager::get_model_namespace(model_name); if(model_namespace == "openai") { @@ -83,7 +83,7 @@ std::vector TextEmbedder::mean_pooling(const std::vector> TextEmbedder::Embed(const std::string& text) { +embedding_res_t TextEmbedder::Embed(const std::string& text) { if(is_remote()) { return remote_embedder_->Embed(text); } else { @@ -129,12 +129,12 @@ Option> TextEmbedder::Embed(const std::string& text) { } auto pooled_output = mean_pooling(output); - return Option>(pooled_output); + return embedding_res_t(pooled_output); } } -Option>> TextEmbedder::batch_embed(const std::vector& inputs) { - std::vector> outputs; +std::vector TextEmbedder::batch_embed(const std::vector& inputs) { + std::vector outputs; if(!is_remote()) { for(int i = 0; i < inputs.size(); i += 8) { auto input_batch = std::vector(inputs.begin() + i, inputs.begin() + std::min(i + 8, static_cast(inputs.size()))); @@ -193,7 +193,7 @@ Option>> TextEmbedder::batch_embed(const std::vec // if seq length is 0, return empty vector if(input_shapes[0][1] == 0) { 
for(int i = 0; i < input_batch.size(); i++) { - outputs.push_back(std::vector()); + outputs.push_back(embedding_res_t(400, nlohmann::json({{"error", "Invalid input: empty sequence"}}))); } continue; } @@ -211,17 +211,14 @@ Option>> TextEmbedder::batch_embed(const std::vec } output.push_back(output_row); } - outputs.push_back(mean_pooling(output)); + outputs.push_back(embedding_res_t(mean_pooling(output))); } } } else { - auto embed_op = remote_embedder_->batch_embed(inputs); - if(!embed_op.ok()) { - return Option>>(embed_op.code(), embed_op.error()); - } - outputs = embed_op.get(); + outputs = std::move(remote_embedder_->batch_embed(inputs)); } - return Option>>(outputs); + + return outputs; } TextEmbedder::~TextEmbedder() { diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp index de68640d..e7e0cf51 100644 --- a/src/text_embedder_manager.cpp +++ b/src/text_embedder_manager.cpp @@ -94,7 +94,6 @@ const std::string& TextEmbedderManager::get_model_dir() { } TextEmbedderManager::~TextEmbedderManager() { - delete_all_text_embedders(); } const std::string TextEmbedderManager::get_absolute_model_path(const std::string& model_name) { @@ -117,8 +116,9 @@ const bool TextEmbedderManager::check_md5(const std::string& file_path, const st std::stringstream ss,res; ss << stream.rdbuf(); MD5((unsigned char*)ss.str().c_str(), ss.str().length(), md5); - for (int i = 0; i < MD5_DIGEST_LENGTH; i++) { - res << std::hex << (int)md5[i]; + // convert md5 to hex string with leading zeros + for(int i = 0; i < MD5_DIGEST_LENGTH; i++) { + res << std::hex << std::setfill('0') << std::setw(2) << (int)md5[i]; } return res.str() == target_md5; } diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 083aa2e1..2e12f3b9 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -13,15 +13,16 @@ Option RemoteEmbedder::validate_string_properties(const nlohmann::json& mo long RemoteEmbedder::call_remote_api(const std::string& method, 
const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers) { - if(raft_server == nullptr) { + if(raft_server == nullptr || raft_server->get_leader_url().empty()) { if(method == "GET") { - return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 10000, true); + return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 100000, true); } else if(method == "POST") { - return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 10000, true); + return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 100000, true); } else { return 400; } } + auto leader_url = raft_server->get_leader_url(); leader_url += "proxy"; nlohmann::json req_body; @@ -103,7 +104,7 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -Option> OpenAIEmbedder::Embed(const std::string& text) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; @@ -116,15 +117,23 @@ Option> OpenAIEmbedder::Embed(const std::string& text) { auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if (res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); - if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { - return Option>(400, "OpenAI API error: " + res); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["response"] = json_res; + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "OpenAI API error: 
" + json_res["error"]["message"].get(); } - return Option>(400, "OpenAI API error: " + res); + return embedding_res_t(res_code, embedding_res); } - return Option>(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); + + return embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); } -Option>> OpenAIEmbedder::batch_embed(const std::vector& inputs) { +std::vector OpenAIEmbedder::batch_embed(const std::vector& inputs) { nlohmann::json req_body; req_body["input"] = inputs; // remove "openai/" prefix @@ -137,20 +146,35 @@ Option>> OpenAIEmbedder::batch_embed(const std::v auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if(res_code != 200) { + std::vector outputs; + nlohmann::json json_res = nlohmann::json::parse(res); - if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { - return Option>>(400, "OpenAI API error: " + res); + LOG(INFO) << "OpenAI API error: " << json_res.dump(); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["response"] = json_res; + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + embedding_res["request"]["body"]["input"] = std::vector{inputs[0]}; + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); } - return Option>>(400, res); + + for(size_t i = 0; i < inputs.size(); i++) { + embedding_res["request"]["body"]["input"][0] = inputs[i]; + outputs.push_back(embedding_res_t(res_code, embedding_res)); + } + return outputs; } nlohmann::json res_json = nlohmann::json::parse(res); - std::vector> outputs; + std::vector outputs; for(auto& data : res_json["data"]) { - outputs.push_back(data["embedding"].get>()); + 
outputs.push_back(embedding_res_t(data["embedding"].get>())); } - return Option>>(outputs); + return outputs; } @@ -198,7 +222,7 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -Option> GoogleEmbedder::Embed(const std::string& text) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; @@ -210,27 +234,30 @@ Option> GoogleEmbedder::Embed(const std::string& text) { if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); - if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { - return Option>(400, "Google API error: " + res); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["response"] = json_res; + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get(); } - return Option>(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); + return embedding_res_t(res_code, embedding_res); } - return Option>(nlohmann::json::parse(res)["embedding"]["value"].get>()); + return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); } -Option>> GoogleEmbedder::batch_embed(const std::vector& inputs) { - std::vector> outputs; +std::vector GoogleEmbedder::batch_embed(const std::vector& inputs) { + std::vector outputs; for(auto& input : inputs) { auto res = Embed(input); - if(!res.ok()) { - return Option>>(res.code(), res.error()); - } - outputs.push_back(res.get()); + outputs.push_back(res); } - return Option>>(outputs); + return outputs; } @@ -298,7 +325,7 @@ Option GCPEmbedder::is_model_valid(const 
nlohmann::json& model_config, uns return Option(true); } -Option> GCPEmbedder::Embed(const std::string& text) { +embedding_res_t GCPEmbedder::Embed(const std::string& text) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -316,7 +343,9 @@ Option> GCPEmbedder::Embed(const std::string& text) { if(res_code == 401) { auto refresh_op = generate_access_token(refresh_token, client_id, client_secret); if(!refresh_op.ok()) { - return Option>(refresh_op.code(), refresh_op.error()); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["error"] = refresh_op.error(); + return embedding_res_t(refresh_op.code(), embedding_res); } access_token = refresh_op.get(); // retry @@ -327,32 +356,32 @@ Option> GCPEmbedder::Embed(const std::string& text) { if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); - if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { - return Option>(400, "GCP API error: " + res); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["response"] = json_res; + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get(); } - return Option>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); + return embedding_res_t(res_code, embedding_res); } nlohmann::json res_json = nlohmann::json::parse(res); - return Option>(res_json["predictions"][0]["embeddings"]["values"].get>()); + return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get>()); } -Option>> GCPEmbedder::batch_embed(const std::vector& inputs) { +std::vector GCPEmbedder::batch_embed(const 
std::vector& inputs) { // GCP API has a limit of 5 instances per request if(inputs.size() > 5) { - std::vector> res; + std::vector res; for(size_t i = 0; i < inputs.size(); i += 5) { auto batch_res = batch_embed(std::vector(inputs.begin() + i, inputs.begin() + std::min(i + 5, inputs.size()))); - if(!batch_res.ok()) { - LOG(INFO) << "Batch embedding failed: " << batch_res.error(); - return Option>>(batch_res.code(), batch_res.error()); - } - auto batch = batch_res.get(); - res.insert(res.end(), batch.begin(), batch.end()); + res.insert(res.end(), batch_res.begin(), batch_res.end()); } - auto opt = Option>>(res); - return opt; + return res; } nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); @@ -371,7 +400,13 @@ Option>> GCPEmbedder::batch_embed(const std::vect if(res_code == 401) { auto refresh_op = generate_access_token(refresh_token, client_id, client_secret); if(!refresh_op.ok()) { - return Option>>(refresh_op.code(), refresh_op.error()); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["error"] = refresh_op.error(); + std::vector outputs; + for(size_t i = 0; i < inputs.size(); i++) { + outputs.push_back(embedding_res_t(refresh_op.code(), embedding_res)); + } + return outputs; } access_token = refresh_op.get(); // retry @@ -382,19 +417,29 @@ Option>> GCPEmbedder::batch_embed(const std::vect if(res_code != 200) { nlohmann::json json_res = nlohmann::json::parse(res); - if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { - return Option>>(400, "GCP API error: " + res); + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["response"] = json_res; + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + 
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get(); } - return Option>>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); + std::vector outputs; + for(size_t i = 0; i < inputs.size(); i++) { + outputs.push_back(embedding_res_t(res_code, embedding_res)); + } + return outputs; } nlohmann::json res_json = nlohmann::json::parse(res); - std::vector> outputs; + std::vector outputs; for(const auto& prediction : res_json["predictions"]) { - outputs.push_back(prediction["embeddings"]["values"].get>()); + outputs.push_back(embedding_res_t(prediction["embeddings"]["values"].get>())); } - return Option>>(outputs); + return outputs; } Option GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) { diff --git a/src/validator.cpp b/src/validator.cpp index 0caa1302..d2396511 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -606,7 +606,7 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, const index_operation_t op, const bool is_update, const std::string& fallback_field_type, - const DIRTY_VALUES& dirty_values) { + const DIRTY_VALUES& dirty_values, const bool validate_embedding_fields) { bool missing_default_sort_field = (!default_sorting_field.empty() && document.count(default_sorting_field) == 0); @@ -653,10 +653,12 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, } } - // validate embedding fields - auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update); - if(!validate_embed_op.ok()) { - return Option<>(validate_embed_op.code(), validate_embed_op.error()); + if(validate_embedding_fields) { + // validate embedding fields + auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update); + if(!validate_embed_op.ok()) { + return Option<>(validate_embed_op.code(), validate_embed_op.error()); + } } return 
Option<>(200); diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index f021be45..f081a3f1 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1629,4 +1629,46 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { embedding_fields_map = coll->get_embedding_fields(); ASSERT_EQ(0, embedding_fields_map.size()); +} + +TEST_F(CollectionSchemaChangeTest, DropAndReindexEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + + ASSERT_TRUE(op.ok()); + + // drop the embedding field and reindex + nlohmann::json schema_without_embedding = R"({ + "fields": [ + {"name": "embedding", "drop": true}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + auto update_op = op.get()->alter(schema_without_embedding); + + ASSERT_TRUE(update_op.ok()); + + auto embedding_fields_map = op.get()->get_embedding_fields(); + + ASSERT_EQ(1, embedding_fields_map.size()); + + // try adding a document + nlohmann::json doc; + doc["name"] = "hello"; + auto add_op = op.get()->add(doc.dump()); + + ASSERT_TRUE(add_op.ok()); + auto added_doc = add_op.get(); + + ASSERT_EQ(384, added_doc["embedding"].get>().size()); } \ No newline at end of file diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 7598a589..95c255e9 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -2246,3 +2246,145 @@ TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) { ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); 
} } + + +TEST_F(CollectionSortingTest, AscendingVectorDistance) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + std::vector> points = { + {3.0, 4.0}, + {9.0, 21.0}, + {8.0, 15.0}, + {1.0, 1.0}, + {5.0, 7.0} + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector sort_fields = { + sort_by("_vector_distance", "asc"), + }; + + auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "points:([8.0, 15.0])").get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector expected_ids = {"2", "1", "4", "0", "3"}; + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); + } +} + + +TEST_F(CollectionSortingTest, DescendingVectorDistance) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + std::vector> points = { + {3.0, 4.0}, + {9.0, 21.0}, + {8.0, 15.0}, + {1.0, 1.0}, + {5.0, 7.0} + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + 
ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector sort_fields = { + sort_by("_vector_distance", "DESC"), + }; + + auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "points:([8.0, 15.0])").get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector expected_ids = {"3", "0", "4", "1", "2"}; + + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); + } +} + + +TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + std::vector> points = { + {1.0, 1.0}, + {2.0, 2.0}, + {3.0, 3.0}, + {4.0, 4.0}, + {5.0, 5.0}, + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector sort_fields = { + sort_by("_vector_distance", "desc"), + }; + + + + auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + + ASSERT_FALSE(results.ok()); + + ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error()); +} \ No newline at end of file diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 930c8176..9bc20f83 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4790,9 +4790,9 @@ TEST_F(CollectionTest, 
HybridSearchRankFusionTest) { ASSERT_EQ("butterfly", search_res["hits"][1]["document"]["name"].get()); ASSERT_EQ("butterball", search_res["hits"][2]["document"]["name"].get()); - ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["rank_fusion_score"].get()); - ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["rank_fusion_score"].get()); - ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get()); + ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get()); } TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { @@ -4813,8 +4813,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { spp::sparse_hash_set dummy_include_exclude; auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); - ASSERT_FALSE(search_res_op.ok()); - ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error()); + ASSERT_TRUE(search_res_op.ok()); } TEST_F(CollectionTest, CreateModelDirIfNotExists) { @@ -5060,7 +5059,7 @@ TEST_F(CollectionTest, HideOpenAIApiKey) { ASSERT_TRUE(op.ok()); auto summary = op.get()->get_summary_json(); // hide api key with * after first 3 characters - ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*')); + ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get(), api_key.replace(5, api_key.size() - 5, api_key.size() - 5, '*')); } TEST_F(CollectionTest, PrefixSearchDisabledForOpenAI) { @@ 
-5134,3 +5133,48 @@ TEST_F(CollectionTest, MoreThanOneEmbeddingField) {
 
     ASSERT_EQ("Only one embedding field is allowed in the query.", search_res_op.error());
 }
+
+TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) {  // empty "names" gives a null embedding; adding a name via update gives a 384-element one
+    nlohmann::json schema = R"({
+        "name": "objects",
+        "fields": [
+            {"name": "names", "type": "string[]"},
+            {"name": "embedding", "type":"float[]", "embed":{"from": ["names"], "model_config": {"model_name": "ts/e5-small"}}}
+        ]
+    })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+
+    auto op = collectionManager.create_collection(schema);
+    ASSERT_TRUE(op.ok());
+
+    auto coll = op.get();
+
+    nlohmann::json doc;
+    doc["names"] = nlohmann::json::array();
+
+    // try adding
+    auto add_op = coll->add(doc.dump());
+
+    ASSERT_TRUE(add_op.ok());
+
+    ASSERT_TRUE(add_op.get()["embedding"].is_null());
+
+    // try updating
+    auto id = add_op.get()["id"];
+    doc["names"].push_back("butter");
+    std::string dirty_values;
+
+
+    auto update_op = coll->update_matching_filter("id:=" + id.get<std::string>(), doc.dump(), dirty_values);
+    ASSERT_TRUE(update_op.ok());
+    ASSERT_EQ(1, update_op.get()["num_updated"]);
+
+
+    auto get_op = coll->get(id);
+    ASSERT_TRUE(get_op.ok());
+
+    ASSERT_FALSE(get_op.get()["embedding"].is_null());
+
+    ASSERT_EQ(384, get_op.get()["embedding"].size());
+}
\ No newline at end of file
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index 995bddd7..b6b598dc 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -711,12 +711,42 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
                                  4, {off}, 32767, 32767, 2,
                                  false, true, "vec:(" + dummy_vec_string +")");
     ASSERT_EQ(true, results_op.ok());
-
-
     ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
     ASSERT_EQ(1, results_op.get()["hits"].size());
 }
 
+TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {  // "zzz" is not in the only document ("john doe"); one hit and its facet counts are still expected
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "fields": [
+            {"name": "name", "type": "string", "facet": true},
+            {"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
+        ]
+    })"_json;
+
+    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    nlohmann::json doc;
+    doc["name"] = "john doe";
+    ASSERT_TRUE(coll1->add(doc.dump()).ok());
+
+    auto results_op = coll1->search("zzz", {"name", "vec"}, "", {"name"}, {}, {0}, 20, 1, FREQUENCY, {true},
+                                    Index::DROP_TOKENS_THRESHOLD,
+                                    spp::sparse_hash_set<std::string>(),
+                                    spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                                    "", 10, {}, {}, {}, 0,
+                                    "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
+                                    fallback,
+                                    4, {off}, 32767, 32767, 2);
+    ASSERT_EQ(true, results_op.ok());
+    ASSERT_EQ(1, results_op.get()["found"].get<size_t>());
+    ASSERT_EQ(1, results_op.get()["hits"].size());
+    ASSERT_EQ(1, results_op.get()["facet_counts"].size());
+    ASSERT_EQ(4, results_op.get()["facet_counts"][0].size());
+    ASSERT_EQ("name", results_op.get()["facet_counts"][0]["field_name"]);
+}
+
 TEST_F(CollectionVectorTest, DistanceThresholdTest) {
     nlohmann::json schema = R"({
         "name": "test",