#pragma once #include #include #include #include #include "option.h" #include "text_embedder_tokenizer.h" #include "text_embedder_remote.h" class TextEmbedder { public: // Constructor for local or public models TextEmbedder(const std::string& model_path); // Constructor for remote models TextEmbedder(const nlohmann::json& model_config, size_t num_dims); ~TextEmbedder(); embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2); std::vector batch_embed(const std::vector& inputs, const size_t remote_embedding_batch_size = 200); const std::string& get_vocab_file_name() const; const size_t get_num_dim() const; bool is_remote() { return remote_embedder_ != nullptr; } Option validate(); private: std::unique_ptr session_; Ort::Env env_; encoded_input_t Encode(const std::string& text); batch_encoded_input_t batch_encode(const std::vector& inputs); std::unique_ptr tokenizer_; std::unique_ptr remote_embedder_; std::string vocab_file_name; static std::vector mean_pooling(const std::vector>& input); std::string output_tensor_name; size_t num_dim; std::mutex mutex_; };