#pragma once #include #include #include #include #include "option.h" #include "text_embedder_tokenizer.h" #include "text_embedder_remote.h" class TextEmbedder { public: // Constructor for local or public models TextEmbedder(const std::string& model_path); // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); Option> Embed(const std::string& text); Option>> batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { return remote_embedder_ != nullptr; } static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); private: std::unique_ptr session_; Ort::Env env_; encoded_input_t Encode(const std::string& text); batch_encoded_input_t batch_encode(const std::vector& inputs); std::unique_ptr tokenizer_; std::unique_ptr remote_embedder_; std::string vocab_file_name; static std::vector mean_pooling(const std::vector>& input); std::string output_tensor_name; static Option validate_remote_model(const nlohmann::json& model_config, unsigned int& num_dims); static Option validate_local_or_public_model(const nlohmann::json& model_config, unsigned int& num_dims); std::mutex mutex_; };