diff --git a/include/collection.h b/include/collection.h
index 2f79bc68..ff44506b 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -385,7 +385,7 @@ public:
                                                  const std::vector<sort_by>& sort_by_fields);
 
     void batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out, size_t &num_indexed,
-                     const bool& return_doc, const bool& return_id);
+                     const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size = 200);
 
     bool is_exceeding_memory_threshold() const;
 
@@ -398,7 +398,7 @@ public:
 
     nlohmann::json get_summary_json() const;
 
-    size_t batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings = true);
+    size_t batch_index_in_memory(std::vector<index_record>& index_records, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);
 
     Option<nlohmann::json> add(const std::string & json_str,
                                const index_operation_t& operation=CREATE, const std::string& id="",
@@ -407,7 +407,7 @@ public:
     nlohmann::json add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
                             const index_operation_t& operation=CREATE, const std::string& id="",
                             const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT,
-                            const bool& return_doc=false, const bool& return_id=false);
+                            const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=100);
 
     Option<nlohmann::json> update_matching_filter(const std::string& filter_query,
                                                   const std::string & json_str,
diff --git a/include/index.h b/include/index.h
index 5c1b77e0..410952aa 100644
--- a/include/index.h
+++ b/include/index.h
@@ -545,7 +545,7 @@ private:
 
     static void batch_embed_fields(std::vector<index_record*>& documents,
                                    const tsl::htrie_map<char, field>& embedding_fields,
-                                   const tsl::htrie_map<char, field>& search_schema);
+                                   const tsl::htrie_map<char, field>& search_schema, const size_t remote_embedding_batch_size = 200);
 
 public:
     // for limiting number of results on multiple candidates / query rewrites
@@ -678,7 +678,7 @@ public:
                                const std::string& fallback_field_type,
                                const std::vector<char>& token_separators,
                                const std::vector<char>& symbols_to_index,
-                               const bool do_validation, const bool generate_embeddings = true);
+                               const bool do_validation, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);
 
     static size_t batch_memory_index(Index *index,
                                      std::vector<index_record>& iter_batch,
@@ -688,7 +688,7 @@ public:
                                      const std::string& fallback_field_type,
                                      const std::vector<char>& token_separators,
                                      const std::vector<char>& symbols_to_index,
-                                     const bool do_validation, const bool generate_embeddings = true);
+                                     const bool do_validation, const size_t remote_embedding_batch_size = 200, const bool generate_embeddings = true);
 
     void index_field_in_memory(const field& afield, std::vector<index_record>& iter_batch);
 
diff --git a/include/text_embedder.h b/include/text_embedder.h
index cff7c7c7..f5419de9 100644
--- a/include/text_embedder.h
+++ b/include/text_embedder.h
@@ -17,7 +17,7 @@ class TextEmbedder {
     TextEmbedder(const nlohmann::json& model_config);
     ~TextEmbedder();
    embedding_res_t Embed(const std::string& text);
-    std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
+    std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 0);
     const std::string& get_vocab_file_name() const;
     bool is_remote() {
         return remote_embedder_ != nullptr;
diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h
index 652c0b61..d3a9bd3d 100644
--- a/include/text_embedder_remote.h
+++ b/include/text_embedder_remote.h
@@ -29,7 +29,7 @@ class RemoteEmbedder {
        static inline ReplicationState* raft_server = nullptr;
    public:
        virtual embedding_res_t Embed(const std::string& text) = 0;
-       virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) = 0;
+       virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) = 0;
        static void init(ReplicationState* rs) {
            raft_server = rs;
        }
@@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder {
        OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
        static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
        embedding_res_t Embed(const std::string& text) override;
-       std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
+       std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
 };
 
 
@@ -63,7 +63,7 @@ class GoogleEmbedder : public RemoteEmbedder {
        GoogleEmbedder(const std::string& google_api_key);
        static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
        embedding_res_t Embed(const std::string& text) override;
-       std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
+       std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
 };
 
 
@@ -89,7 +89,7 @@ class GCPEmbedder : public RemoteEmbedder {
                    const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
        static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
        embedding_res_t Embed(const std::string& text) override;
-       std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
+       std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200) override;
 };
 
diff --git a/src/collection.cpp b/src/collection.cpp
index c55bee7d..a617ad96 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -295,7 +295,8 @@ Option<nlohmann::json> Collection::add(const std::string & json_str,
 
 nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
                                     const index_operation_t& operation, const std::string& id,
-                                    const DIRTY_VALUES& dirty_values, const bool& return_doc, const bool& return_id) {
+                                    const DIRTY_VALUES& dirty_values, const bool& return_doc, const bool& return_id,
+                                    const size_t remote_embedding_batch_size) {
     //LOG(INFO) << "Memory ratio. Max = " << max_memory_ratio << ", Used = " << SystemMetrics::used_memory_ratio();
     std::vector<index_record> index_records;
 
@@ -385,7 +386,7 @@ nlohmann::json Collection::add_many(std::vector<std::string>& json_lines, nlohma
 
         if((i+1) % index_batch_size == 0 || i == json_lines.size()-1 || repeated_doc) {
-            batch_index(index_records, json_lines, num_indexed, return_doc, return_id);
+            batch_index(index_records, json_lines, num_indexed, return_doc, return_id, remote_embedding_batch_size);
 
             // to return the document for the single doc add cases
             if(index_records.size() == 1) {
@@ -502,9 +503,9 @@ bool Collection::is_exceeding_memory_threshold() const {
 }
 
 void Collection::batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out,
-                             size_t &num_indexed, const bool& return_doc, const bool& return_id) {
+                             size_t &num_indexed, const bool& return_doc, const bool& return_id, const size_t remote_embedding_batch_size) {
 
-    batch_index_in_memory(index_records);
+    batch_index_in_memory(index_records, remote_embedding_batch_size);
 
     // store only documents that were indexed in-memory successfully
     for(auto& index_record: index_records) {
@@ -608,11 +609,11 @@ Option<uint32_t> Collection::index_in_memory(nlohmann::json &document, uint32_t
     return Option<>(200);
 }
 
-size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const bool generate_embeddings) {
+size_t Collection::batch_index_in_memory(std::vector<index_record>& index_records, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
     std::unique_lock lock(mutex);
     size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field,
                                                    search_schema, embedding_fields, fallback_field_type,
-                                                   token_separators, symbols_to_index, true, generate_embeddings);
+                                                   token_separators, symbols_to_index, true, remote_embedding_batch_size, generate_embeddings);
     num_documents += num_indexed;
     return num_indexed;
 }
@@ -3807,7 +3808,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
         }
 
         Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
-                                  fallback_field_type, token_separators, symbols_to_index, true, false);
+                                  fallback_field_type, token_separators, symbols_to_index, true, 100, false);
 
         iter_batch.clear();
     }
diff --git a/src/core_api.cpp b/src/core_api.cpp
index 5d04ad07..59ee5dec 100644
--- a/src/core_api.cpp
+++ b/src/core_api.cpp
@@ -745,11 +745,16 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
     const char *DIRTY_VALUES = "dirty_values";
     const char *RETURN_DOC = "return_doc";
     const char *RETURN_ID = "return_id";
+    const char *REMOTE_EMBEDDING_BATCH_SIZE = "remote_embedding_batch_size";
 
     if(req->params.count(BATCH_SIZE) == 0) {
         req->params[BATCH_SIZE] = "40";
     }
 
+    if(req->params.count(REMOTE_EMBEDDING_BATCH_SIZE) == 0) {
+        req->params[REMOTE_EMBEDDING_BATCH_SIZE] = "200";
+    }
+
     if(req->params.count(ACTION) == 0) {
         req->params[ACTION] = "create";
     }
@@ -796,6 +801,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
     }
 
     const size_t IMPORT_BATCH_SIZE = std::stoi(req->params[BATCH_SIZE]);
+    const size_t REMOTE_EMBEDDING_BATCH_SIZE_VAL = std::stoi(req->params[REMOTE_EMBEDDING_BATCH_SIZE]);
 
     if(IMPORT_BATCH_SIZE == 0) {
         res->final = true;
@@ -804,6 +810,13 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
         return false;
     }
 
+    if(REMOTE_EMBEDDING_BATCH_SIZE_VAL == 0) {
+        res->final = true;
+        res->set_400("Parameter `" + std::string(REMOTE_EMBEDDING_BATCH_SIZE) + "` must be a positive integer.");
+        stream_response(req, res);
+        return false;
+    }
+
     if(req->body_index == 0) {
         // will log for every major chunk of request body
         //LOG(INFO) << "Import, req->body.size=" << req->body.size() << ", batch_size=" << IMPORT_BATCH_SIZE;
@@ -873,7 +886,7 @@ bool post_import_documents(const std::shared_ptr<http_req>& req, const std::shar
     const bool& return_doc = req->params[RETURN_DOC] == "true";
     const bool& return_id = req->params[RETURN_ID] == "true";
     nlohmann::json json_res = collection->add_many(json_lines, document, operation, "",
-                                                   dirty_values, return_doc, return_id);
+                                                   dirty_values, return_doc, return_id, REMOTE_EMBEDDING_BATCH_SIZE_VAL);
 
     //const std::string& import_summary_json = json_res->dump();
     //response_stream << import_summary_json << "\n";
diff --git a/src/index.cpp b/src/index.cpp
index 4e6ac914..1b87578f 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -416,7 +416,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
                                     const std::string& fallback_field_type,
                                     const std::vector<char>& token_separators,
                                     const std::vector<char>& symbols_to_index,
-                                    const bool do_validation, const bool generate_embeddings) {
+                                    const bool do_validation, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
     // runs in a partitioned thread
     std::vector<index_record*> records_to_embed;
 
@@ -505,7 +505,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
         }
     }
     if(generate_embeddings) {
-        batch_embed_fields(records_to_embed, embedding_fields, search_schema);
+        batch_embed_fields(records_to_embed, embedding_fields, search_schema, remote_embedding_batch_size);
     }
 }
 
@@ -516,7 +516,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
                                  const std::string& fallback_field_type,
                                  const std::vector<char>& token_separators,
                                  const std::vector<char>& symbols_to_index,
-                                 const bool do_validation, const bool generate_embeddings) {
+                                 const bool do_validation, const size_t remote_embedding_batch_size, const bool generate_embeddings) {
 
     const size_t concurrency = 4;
     const size_t num_threads = std::min(concurrency, iter_batch.size());
@@ -546,7 +546,7 @@ size_t Index::batch_memory_index(Index *index, std::vector<index_record>& iter_b
         index->thread_pool->enqueue([&, batch_index, batch_len]() {
             write_log_index = local_write_log_index;
             validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema,
-                                    embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, generate_embeddings);
+                                    embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation, remote_embedding_batch_size, generate_embeddings);
 
             std::unique_lock lock(m_process);
             num_processed++;
@@ -6468,7 +6468,7 @@ bool Index::common_results_exist(std::vector<art_leaf*>& leaves, bool must_match
 
 void Index::batch_embed_fields(std::vector<index_record*>& records,
                                const tsl::htrie_map<char, field>& embedding_fields,
-                               const tsl::htrie_map<char, field>& search_schema) {
+                               const tsl::htrie_map<char, field>& search_schema, const size_t remote_embedding_batch_size) {
     for(const auto& field : embedding_fields) {
         std::vector<std::pair<index_record*, std::string>> texts_to_embed;
         auto indexing_prefix = TextEmbedderManager::get_instance().get_indexing_prefix(field.embed[fields::model_config]);
@@ -6529,7 +6529,7 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
             texts.push_back(text_to_embed.second);
         }
 
-        auto embeddings = embedder_op.get()->batch_embed(texts);
+        auto embeddings = embedder_op.get()->batch_embed(texts, remote_embedding_batch_size);
 
         for(size_t i = 0; i < embeddings.size(); i++) {
             auto& embedding_res = embeddings[i];
diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp
index b2a67607..e36d3683 100644
--- a/src/text_embedder.cpp
+++ b/src/text_embedder.cpp
@@ -133,7 +133,7 @@ embedding_res_t TextEmbedder::Embed(const std::string& text) {
     }
 }
 
-std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
    std::vector<embedding_res_t> outputs;
    if(!is_remote()) {
        for(int i = 0; i < inputs.size(); i += 8) {
@@ -215,7 +215,7 @@ std::vector<embedding_res_t> TextEmbedder::batch_embed(const std::vector<std::string
            }
        }
    } else {
-       outputs = std::move(remote_embedder_->batch_embed(inputs));
+       outputs = std::move(remote_embedder_->batch_embed(inputs, remote_embedding_batch_size));
    }
 
    return outputs;
diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp
index 4e844871..edc7e3b5 100644
--- a/src/text_embedder_remote.cpp
+++ b/src/text_embedder_remote.cpp
@@ -182,7 +182,17 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
    }
 }
 
-std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
+    // call recursively if inputs larger than remote_embedding_batch_size
+    if(inputs.size() > remote_embedding_batch_size) {
+        std::vector<embedding_res_t> outputs;
+        for(size_t i = 0; i < inputs.size(); i += remote_embedding_batch_size) {
+            auto batch = std::vector<std::string>(inputs.begin() + i, inputs.begin() + std::min(i + remote_embedding_batch_size, inputs.size()));
+            auto batch_outputs = batch_embed(batch, remote_embedding_batch_size);
+            outputs.insert(outputs.end(), batch_outputs.begin(), batch_outputs.end());
+        }
+        return outputs;
+    }
    nlohmann::json req_body;
    req_body["input"] = inputs;
    // remove "openai/" prefix
@@ -326,7 +336,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
            nlohmann::json json_res = nlohmann::json::parse(res);
        } catch (const std::exception& e) {
            json_res = nlohmann::json::object();
-           json_res["error"] = "Malformed response from Google API."
+           json_res["error"] = "Malformed response from Google API.";
        }
        nlohmann::json embedding_res = nlohmann::json::object();
        embedding_res["response"] = json_res;
@@ -356,7 +366,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
 }
 
 
-std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
    std::vector<embedding_res_t> outputs;
    for(auto& input : inputs) {
        auto res = Embed(input);
@@ -513,7 +523,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) {
 }
 
 
-std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
+std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size) {
    // GCP API has a limit of 5 instances per request
    if(inputs.size() > 5) {
        std::vector<embedding_res_t> res;
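
Note on the new parameter: of the three remote embedders, only OpenAIEmbedder::batch_embed actually splits oversized input lists by remote_embedding_batch_size; GoogleEmbedder embeds inputs one at a time and GCPEmbedder is hard-capped at 5 instances per request, so both accept the parameter without using it. For reference, the chunking pattern used in the OpenAI path can be sketched standalone. The snippet below is illustrative only (embed_chunk and the free-standing batch_embed are hypothetical stand-ins, not part of this patch); it assumes batch_size > 0, which the patch enforces via the 400 response in core_api.cpp.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for one remote API call; returns one result per input.
static std::vector<float> embed_chunk(const std::vector<std::string>& chunk) {
    return std::vector<float>(chunk.size(), 0.0f);
}

// Same pattern as the patched OpenAIEmbedder::batch_embed: walk `inputs` in
// slices of at most `batch_size` and concatenate the per-chunk results.
static std::vector<float> batch_embed(const std::vector<std::string>& inputs, size_t batch_size) {
    std::vector<float> outputs;
    for(size_t i = 0; i < inputs.size(); i += batch_size) { // batch_size must be > 0
        const size_t last = std::min(i + batch_size, inputs.size());
        const std::vector<std::string> chunk(inputs.begin() + i, inputs.begin() + last);
        const std::vector<float> chunk_outputs = embed_chunk(chunk);
        outputs.insert(outputs.end(), chunk_outputs.begin(), chunk_outputs.end());
    }
    return outputs;
}

int main() {
    const std::vector<std::string> inputs(7, "some text");
    std::cout << batch_embed(inputs, 3).size() << "\n"; // prints 7: chunks of 3, 3, 1
    return 0;
}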