From 3da18ea6d51b94091e76039b65c9aaa963c6063f Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sat, 21 Oct 2023 00:28:51 +0300 Subject: [PATCH] Add support for larger embedding models --- include/text_embedder_manager.h | 2 ++ src/text_embedder_manager.cpp | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index 2681e0b9..68b79df9 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -16,6 +16,7 @@ struct text_embedding_model { std::string model_md5; std::string vocab_file_name; std::string vocab_md5; + std::string data_file_md5; TokenizerType tokenizer_type; std::string indexing_prefix = ""; std::string query_prefix = ""; @@ -56,6 +57,7 @@ public: static const std::string get_absolute_vocab_path(const std::string& model_name, const std::string& vocab_file_name); static const std::string get_absolute_config_path(const std::string& model_name); static const std::string get_model_url(const text_embedding_model& model); + static const std::string get_model_data_url(const text_embedding_model& model); static const std::string get_vocab_url(const text_embedding_model& model); static Option get_public_model_config(const std::string& model_name); static const std::string get_model_name_without_namespace(const std::string& model_name); diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp index baa99ebf..3adbf82a 100644 --- a/src/text_embedder_manager.cpp +++ b/src/text_embedder_manager.cpp @@ -148,6 +148,10 @@ void TextEmbedderManager::delete_text_embedder(const std::string& model_path) { if (text_embedders.find(model_path) != text_embedders.end()) { text_embedders.erase(model_path); } + + if (public_models.find(model_path) != public_models.end()) { + public_models.erase(model_path); + } } void TextEmbedderManager::delete_all_text_embedders() { @@ -252,6 +256,16 @@ Option TextEmbedderManager::download_public_model(const text_embedding_mod return Option(400, "Failed to download model file"); } } + + if(!model.data_file_md5.empty()) { + if(!check_md5(get_absolute_model_path(actual_model_name) + "_data", model.data_file_md5)) { + long res = httpClient.download_file(get_model_data_url(model), get_absolute_model_path(actual_model_name) + "_data"); + if(res != 200) { + LOG(INFO) << "Failed to download public model data file: " << model.model_name; + return Option(400, "Failed to download model data file"); + } + } + } if(!check_md5(get_absolute_vocab_path(actual_model_name, model.vocab_file_name), model.vocab_md5)) { long res = httpClient.download_file(get_vocab_url(model), get_absolute_vocab_path(actual_model_name, model.vocab_file_name)); @@ -350,6 +364,10 @@ text_embedding_model::text_embedding_model(const nlohmann::json& json) { if(json.count("query_prefix") != 0) { query_prefix = json.at("query_prefix").get(); } + + if(json.count("data_md5") != 0) { + data_file_md5 = json.at("data_md5").get(); + } } @@ -388,6 +406,10 @@ const std::string TextEmbedderManager::get_model_url(const text_embedding_model& return MODELS_REPO_URL + model.model_name + "/model.onnx"; } +const std::string TextEmbedderManager::get_model_data_url(const text_embedding_model& model) { + return MODELS_REPO_URL + model.model_name + "/model.onnx_data"; +} + const std::string TextEmbedderManager::get_vocab_url(const text_embedding_model& model) { return MODELS_REPO_URL + model.model_name + "/" + model.vocab_file_name; }