diff --git a/include/text_embedder.h b/include/text_embedder.h index 660e6ae4..ca64aa52 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -14,15 +14,16 @@ class TextEmbedder { // Constructor for local or public models TextEmbedder(const std::string& model_path); // Constructor for remote models - TextEmbedder(const nlohmann::json& model_config); + TextEmbedder(const nlohmann::json& model_config, size_t num_dims); ~TextEmbedder(); embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2); std::vector batch_embed(const std::vector& inputs, const size_t remote_embedding_batch_size = 200); const std::string& get_vocab_file_name() const; + const size_t get_num_dim() const; bool is_remote() { return remote_embedder_ != nullptr; } - Option validate(size_t& num_dims); + Option validate(); private: std::unique_ptr session_; Ort::Env env_; @@ -33,5 +34,6 @@ class TextEmbedder { std::string vocab_file_name; static std::vector mean_pooling(const std::vector>& input); std::string output_tensor_name; + size_t num_dim; std::mutex mutex_; }; diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index b9158305..543e8f91 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -36,7 +36,6 @@ public: TextEmbedderManager& operator=(const TextEmbedderManager&) = delete; Option get_text_embedder(const nlohmann::json& model_config); - Option init_text_embedder(const nlohmann::json& model_config, size_t& num_dim); void delete_text_embedder(const std::string& model_path); void delete_all_text_embedders(); @@ -69,9 +68,9 @@ public: bool is_public_model(const std::string& model_name); static bool is_remote_model(const std::string& model_name); - static Option validate_and_init_remote_model(const nlohmann::json& model_config, size_t& num_dims); - static Option validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims); - static Option validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims); + Option validate_and_init_remote_model(const nlohmann::json& model_config, size_t& num_dims); + Option validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims); + Option validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims); private: TextEmbedderManager() = default; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 3610eeb5..9aec9e5c 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -80,13 +80,14 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection size_t num_dim = 0; auto& model_config = field_obj[fields::embed][fields::model_config]; - auto res = TextEmbedderManager::validate_and_init_model(model_config, num_dim); + auto res = TextEmbedderManager::get_instance().validate_and_init_model(model_config, num_dim); if(!res.ok()) { const std::string& model_name = model_config["model_name"].get(); LOG(ERROR) << "Error initializing model: " << model_name << ", error: " << res.error(); continue; } + field_obj[fields::num_dim] = num_dim; LOG(INFO) << "Model init done."; } diff --git a/src/field.cpp b/src/field.cpp index dc82a021..ca7d5149 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -1115,7 +1115,7 @@ Option field::validate_and_init_embed_fields(const std::vector(res.code(), res.error()); } diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index cdfa5eba..61244054 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -8,7 +8,7 @@ #include TextEmbedder::TextEmbedder(const std::string& model_name) { - // create environment + // create environment for local model Ort::SessionOptions session_options; auto providers = Ort::GetAvailableProviders(); for(auto& provider : providers) { @@ -50,14 +50,15 @@ TextEmbedder::TextEmbedder(const std::string& model_name) { if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) { Ort::AllocatorWithDefaultOptions allocator; output_tensor_name = std::string(session_->GetOutputNameAllocated(i, allocator).get()); + num_dim = shape[2]; break; } } } -TextEmbedder::TextEmbedder(const nlohmann::json& model_config) { +TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims) { const std::string& model_name = model_config["model_name"].get(); - LOG(INFO) << "Initializing embedding model: " << model_name; + LOG(INFO) << "Initializing remote embedding model: " << model_name; auto model_namespace = TextEmbedderManager::get_model_namespace(model_name); if(model_namespace == "openai") { @@ -78,6 +79,8 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config) { remote_embedder_ = std::make_unique(project_id, model_name, access_token, refresh_token, client_id, client_secret); } + + num_dim = num_dims; } @@ -267,7 +270,7 @@ batch_encoded_input_t TextEmbedder::batch_encode(const std::vector& return encoded_inputs; } -Option TextEmbedder::validate(size_t& num_dims) { +Option TextEmbedder::validate() { if(session_->GetInputCount() != 3 && session_->GetInputCount() != 2) { LOG(ERROR) << "Invalid model: input count is not 3 or 2"; return Option(400, "Invalid model: input count is not 3 or 2"); @@ -300,7 +303,6 @@ Option TextEmbedder::validate(size_t& num_dims) { for (size_t i = 0; i < output_tensor_count; i++) { auto shape = session_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) { - num_dims = shape[2]; found_output_tensor = true; break; } @@ -313,3 +315,7 @@ Option TextEmbedder::validate(size_t& num_dims) { return Option(true); } + +const size_t TextEmbedder::get_num_dim() const { + return num_dim; +} diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp index 392fc61e..ac2c110f 100644 --- a/src/text_embedder_manager.cpp +++ b/src/text_embedder_manager.cpp @@ -42,7 +42,13 @@ Option TextEmbedderManager::validate_and_init_remote_model(const nlohmann: return Option(400, "Invalid model namespace"); } - return TextEmbedderManager::get_instance().init_text_embedder(model_config, num_dims); + std::unique_lock lock(text_embedders_mutex); + auto text_embedder_it = text_embedders.find(model_name); + if(text_embedder_it == text_embedders.end()) { + text_embedders.emplace(model_name, std::make_shared(model_config, num_dims)); + } + + return Option(true); } Option TextEmbedderManager::validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims) { @@ -53,9 +59,8 @@ Option TextEmbedderManager::validate_and_init_local_model(const nlohmann:: return public_model_op; } - Ort::SessionOptions session_options; - Ort::Env env; - std::string abs_path = TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::get_model_name_without_namespace(model_name)); + std::string abs_path = TextEmbedderManager::get_absolute_model_path( + TextEmbedderManager::get_model_name_without_namespace(model_name)); if(!std::filesystem::exists(abs_path)) { LOG(ERROR) << "Model file not found: " << abs_path; @@ -92,30 +97,25 @@ Option TextEmbedderManager::validate_and_init_local_model(const nlohmann:: return Option(400, "Invalid model type"); } } - - return TextEmbedderManager::get_instance().init_text_embedder(model_config, num_dims); -} -Option TextEmbedderManager::init_text_embedder(const nlohmann::json& model_config, size_t& num_dim) { std::unique_lock lock(text_embedders_mutex); - const std::string& model_name = model_config.at("model_name"); auto text_embedder_it = text_embedders.find(model_name); - if(text_embedder_it == text_embedders.end()) { - if(is_remote_model(model_name)) { - text_embedders.emplace(model_name, std::make_shared(model_config)); - } else { - const std::shared_ptr& embedder = std::make_shared( - get_model_name_without_namespace(model_name)); - auto validate_op = embedder->validate(num_dim); - if(!validate_op.ok()) { - return validate_op; - } - - text_embedders.emplace(model_name, embedder); - } + if(text_embedder_it != text_embedders.end()) { + num_dims = text_embedder_it->second->get_num_dim(); + return Option(true); } + const std::shared_ptr& embedder = std::make_shared( + get_model_name_without_namespace(model_name)); + + auto validate_op = embedder->validate(); + if(!validate_op.ok()) { + return validate_op; + } + + num_dims = embedder->get_num_dim(); + text_embedders.emplace(model_name, embedder); return Option(true); }