#pragma once #include #include #include #include #include "option.h" #include "text_embedder_tokenizer.h" #include "text_embedder_remote.h" class TextEmbedder { public: // Constructor for local or public models TextEmbedder(const std::string& model_path); // Constructor for remote models TextEmbedder(const nlohmann::json& model_config, size_t num_dims, const bool has_custom_dims = false); ~TextEmbedder(); embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2); std::vector batch_embed(const std::vector& inputs, const size_t remote_embedding_batch_size = 200, const size_t remote_embedding_timeout_ms = 60000, const size_t remote_embedding_num_tries = 2); const std::string& get_vocab_file_name() const; const size_t get_num_dim() const; bool is_remote() { return remote_embedder_ != nullptr; } Option validate(); std::shared_ptr get_session() { return session_; } std::shared_ptr get_env() { return env_; } const TokenizerType get_tokenizer_type() { return tokenizer_->get_tokenizer_type(); } bool update_remote_embedder_apikey(const std::string& api_key) { return remote_embedder_->update_api_key(api_key); } private: std::shared_ptr session_; std::shared_ptr env_; encoded_input_t Encode(const std::string& text); batch_encoded_input_t batch_encode(const std::vector& inputs); std::unique_ptr tokenizer_; std::unique_ptr remote_embedder_; std::string vocab_file_name; static std::vector mean_pooling(const std::vector>& input, const std::vector& attention_mask); std::string output_tensor_name; size_t num_dim; std::mutex mutex_; };