Refactor remote/local text embedder initialization.

This commit is contained in:
Kishore Nallan 2023-08-03 15:24:59 +05:30
parent 956d596e43
commit cc9af18d9c
6 changed files with 43 additions and 35 deletions

View File

@ -14,15 +14,16 @@ class TextEmbedder {
// Constructor for local or public models
TextEmbedder(const std::string& model_path);
// Constructor for remote models
TextEmbedder(const nlohmann::json& model_config);
TextEmbedder(const nlohmann::json& model_config, size_t num_dims);
~TextEmbedder();
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200);
const std::string& get_vocab_file_name() const;
const size_t get_num_dim() const;
// Reports whether this embedder has a remote embedder attached, i.e. whether
// remote_embedder_ is non-null.
bool is_remote() {
    const bool has_remote_backend = (remote_embedder_ != nullptr);
    return has_remote_backend;
}
Option<bool> validate(size_t& num_dims);
Option<bool> validate();
private:
std::unique_ptr<Ort::Session> session_;
Ort::Env env_;
@ -33,5 +34,6 @@ class TextEmbedder {
std::string vocab_file_name;
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
std::string output_tensor_name;
size_t num_dim;
std::mutex mutex_;
};

View File

@ -36,7 +36,6 @@ public:
TextEmbedderManager& operator=(const TextEmbedderManager&) = delete;
Option<TextEmbedder*> get_text_embedder(const nlohmann::json& model_config);
Option<bool> init_text_embedder(const nlohmann::json& model_config, size_t& num_dim);
void delete_text_embedder(const std::string& model_path);
void delete_all_text_embedders();
@ -69,9 +68,9 @@ public:
bool is_public_model(const std::string& model_name);
static bool is_remote_model(const std::string& model_name);
static Option<bool> validate_and_init_remote_model(const nlohmann::json& model_config, size_t& num_dims);
static Option<bool> validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims);
static Option<bool> validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims);
Option<bool> validate_and_init_remote_model(const nlohmann::json& model_config, size_t& num_dims);
Option<bool> validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims);
Option<bool> validate_and_init_model(const nlohmann::json& model_config, size_t& num_dims);
private:
TextEmbedderManager() = default;

View File

@ -80,13 +80,14 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection
size_t num_dim = 0;
auto& model_config = field_obj[fields::embed][fields::model_config];
auto res = TextEmbedderManager::validate_and_init_model(model_config, num_dim);
auto res = TextEmbedderManager::get_instance().validate_and_init_model(model_config, num_dim);
if(!res.ok()) {
const std::string& model_name = model_config["model_name"].get<std::string>();
LOG(ERROR) << "Error initializing model: " << model_name << ", error: " << res.error();
continue;
}
field_obj[fields::num_dim] = num_dim;
LOG(INFO) << "Model init done.";
}

View File

@ -1115,7 +1115,7 @@ Option<bool> field::validate_and_init_embed_fields(const std::vector<std::pair<s
const auto& model_config = field_json[fields::embed][fields::model_config];
size_t num_dim = 0;
auto res = TextEmbedderManager::validate_and_init_model(model_config, num_dim);
auto res = TextEmbedderManager::get_instance().validate_and_init_model(model_config, num_dim);
if(!res.ok()) {
return Option<bool>(res.code(), res.error());
}

View File

@ -8,7 +8,7 @@
#include <dlfcn.h>
TextEmbedder::TextEmbedder(const std::string& model_name) {
// create environment
// create environment for local model
Ort::SessionOptions session_options;
auto providers = Ort::GetAvailableProviders();
for(auto& provider : providers) {
@ -50,14 +50,15 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) {
Ort::AllocatorWithDefaultOptions allocator;
output_tensor_name = std::string(session_->GetOutputNameAllocated(i, allocator).get());
num_dim = shape[2];
break;
}
}
}
TextEmbedder::TextEmbedder(const nlohmann::json& model_config) {
TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims) {
const std::string& model_name = model_config["model_name"].get<std::string>();
LOG(INFO) << "Initializing embedding model: " << model_name;
LOG(INFO) << "Initializing remote embedding model: " << model_name;
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
if(model_namespace == "openai") {
@ -78,6 +79,8 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config) {
remote_embedder_ = std::make_unique<GCPEmbedder>(project_id, model_name, access_token, refresh_token, client_id, client_secret);
}
num_dim = num_dims;
}
@ -267,7 +270,7 @@ batch_encoded_input_t TextEmbedder::batch_encode(const std::vector<std::string>&
return encoded_inputs;
}
Option<bool> TextEmbedder::validate(size_t& num_dims) {
Option<bool> TextEmbedder::validate() {
if(session_->GetInputCount() != 3 && session_->GetInputCount() != 2) {
LOG(ERROR) << "Invalid model: input count is not 3 or 2";
return Option<bool>(400, "Invalid model: input count is not 3 or 2");
@ -300,7 +303,6 @@ Option<bool> TextEmbedder::validate(size_t& num_dims) {
for (size_t i = 0; i < output_tensor_count; i++) {
auto shape = session_->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
if (shape.size() == 3 && shape[0] == -1 && shape[1] == -1 && shape[2] > 0) {
num_dims = shape[2];
found_output_tensor = true;
break;
}
@ -313,3 +315,7 @@ Option<bool> TextEmbedder::validate(size_t& num_dims) {
return Option<bool>(true);
}
const size_t TextEmbedder::get_num_dim() const {
return num_dim;
}

View File

@ -42,7 +42,13 @@ Option<bool> TextEmbedderManager::validate_and_init_remote_model(const nlohmann:
return Option<bool>(400, "Invalid model namespace");
}
return TextEmbedderManager::get_instance().init_text_embedder(model_config, num_dims);
std::unique_lock<std::mutex> lock(text_embedders_mutex);
auto text_embedder_it = text_embedders.find(model_name);
if(text_embedder_it == text_embedders.end()) {
text_embedders.emplace(model_name, std::make_shared<TextEmbedder>(model_config, num_dims));
}
return Option<bool>(true);
}
Option<bool> TextEmbedderManager::validate_and_init_local_model(const nlohmann::json& model_config, size_t& num_dims) {
@ -53,9 +59,8 @@ Option<bool> TextEmbedderManager::validate_and_init_local_model(const nlohmann::
return public_model_op;
}
Ort::SessionOptions session_options;
Ort::Env env;
std::string abs_path = TextEmbedderManager::get_absolute_model_path(TextEmbedderManager::get_model_name_without_namespace(model_name));
std::string abs_path = TextEmbedderManager::get_absolute_model_path(
TextEmbedderManager::get_model_name_without_namespace(model_name));
if(!std::filesystem::exists(abs_path)) {
LOG(ERROR) << "Model file not found: " << abs_path;
@ -92,30 +97,25 @@ Option<bool> TextEmbedderManager::validate_and_init_local_model(const nlohmann::
return Option<bool>(400, "Invalid model type");
}
}
return TextEmbedderManager::get_instance().init_text_embedder(model_config, num_dims);
}
Option<bool> TextEmbedderManager::init_text_embedder(const nlohmann::json& model_config, size_t& num_dim) {
std::unique_lock<std::mutex> lock(text_embedders_mutex);
const std::string& model_name = model_config.at("model_name");
auto text_embedder_it = text_embedders.find(model_name);
if(text_embedder_it == text_embedders.end()) {
if(is_remote_model(model_name)) {
text_embedders.emplace(model_name, std::make_shared<TextEmbedder>(model_config));
} else {
const std::shared_ptr<TextEmbedder>& embedder = std::make_shared<TextEmbedder>(
get_model_name_without_namespace(model_name));
auto validate_op = embedder->validate(num_dim);
if(!validate_op.ok()) {
return validate_op;
}
text_embedders.emplace(model_name, embedder);
}
if(text_embedder_it != text_embedders.end()) {
num_dims = text_embedder_it->second->get_num_dim();
return Option<bool>(true);
}
const std::shared_ptr<TextEmbedder>& embedder = std::make_shared<TextEmbedder>(
get_model_name_without_namespace(model_name));
auto validate_op = embedder->validate();
if(!validate_op.ok()) {
return validate_op;
}
num_dims = embedder->get_num_dim();
text_embedders.emplace(model_name, embedder);
return Option<bool>(true);
}