// Text-embedding header: declares encoded_input_t and TextEmbedder.
// The code line below is kept byte-identical. Only these comments were added.
//
// Declared interface (as visible on the line below):
//   encoded_input_t
//     A plain aggregate with three std::vector members:
//     input_ids, token_type_ids and attention_mask.
//     It is the return type of TextEmbedder::Encode().
//   TextEmbedder(const std::string& model_path)
//     Constructor taking a model path.
//   ~TextEmbedder()
//     User-declared destructor. Its body is not visible here.
//   Embed(const std::string& text)
//     Returns a std::vector for the given text.
//     Presumably the embedding values; confirm the element type.
//   static is_model_valid(model_path, num_dims)
//     Returns bool.
//     Reports a dimension count through the unsigned int& out-parameter.
//   Private members
//     Ort::Session* session_ and Ort::Env env_: presumably ONNX Runtime C++ API types.
//     Encode(text): returns an encoded_input_t.
//     BertTokenizer* tokenizer_.
//     static mean_pooling(...): takes a vector of vectors, returns a vector.
//     output_tensor_name: a std::string.
//
// NOTE(review): this file appears to be damaged. Evidence on the line below:
//   * Every #include has lost its operand.
//   * Every std::vector has lost its template argument list.
//   * A stray ">>" remains in mean_pooling's parameter.
//   It looks as if all text between '<' and '>' was stripped.
//
// NOTE(review): the whole header is on ONE physical line.
//   A preprocessing directive ends at the newline, so everything after
//   "#pragma once" is part of that single directive, not declarations.
//   Including this header therefore declares neither encoded_input_t
//   nor TextEmbedder.
//
// TODO(review): restore the header from version control, or from the matching
//   .cpp, rather than guessing. It needs:
//   * one directive per line;
//   * the original include targets;
//   * the template arguments.
//   Unconfirmed guesses to check against the .cpp:
//   * input_ids, token_type_ids, attention_mask look like
//     std::vector<int64_t> (BERT-style model inputs).
//   * Embed / mean_pooling look like std::vector<float>, with a
//     const std::vector<std::vector<float>>& parameter.
//
// NOTE(review): possible double delete on copy.
//   TextEmbedder declares a destructor and owns raw pointers
//   (session_, tokenizer_), but declares no copy constructor and no
//   copy assignment.
//   If ~TextEmbedder deletes these pointers (verify in the .cpp), an
//   implicit copy would share them and delete them twice.
//   Fix options: '= delete' the copy operations, or hold the members
//   in std::unique_ptr.
//
// NOTE(review): the single-argument constructor is not 'explicit'.
//   A std::string can therefore convert implicitly to a TextEmbedder.
#pragma once #include #include #include struct encoded_input_t { std::vector input_ids; std::vector token_type_ids; std::vector attention_mask; }; class TextEmbedder { public: TextEmbedder(const std::string& model_path); ~TextEmbedder(); std::vector Embed(const std::string& text); static bool is_model_valid(const std::string& model_path, unsigned int& num_dims); private: Ort::Session* session_; Ort::Env env_; encoded_input_t Encode(const std::string& text); BertTokenizer* tokenizer_; static std::vector mean_pooling(const std::vector>& input); std::string output_tensor_name; };