Add support for larger embedding models

This commit is contained in:
ozanarmagan 2023-10-21 00:28:51 +03:00
parent dfa2872bdf
commit 3da18ea6d5
2 changed files with 24 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ struct text_embedding_model {
std::string model_md5;
std::string vocab_file_name;
std::string vocab_md5;
std::string data_file_md5;
TokenizerType tokenizer_type;
std::string indexing_prefix = "";
std::string query_prefix = "";
@@ -56,6 +57,7 @@ public:
static const std::string get_absolute_vocab_path(const std::string& model_name, const std::string& vocab_file_name);
static const std::string get_absolute_config_path(const std::string& model_name);
static const std::string get_model_url(const text_embedding_model& model);
static const std::string get_model_data_url(const text_embedding_model& model);
static const std::string get_vocab_url(const text_embedding_model& model);
static Option<nlohmann::json> get_public_model_config(const std::string& model_name);
static const std::string get_model_name_without_namespace(const std::string& model_name);

View File

@@ -148,6 +148,10 @@ void TextEmbedderManager::delete_text_embedder(const std::string& model_path) {
if (text_embedders.find(model_path) != text_embedders.end()) {
text_embedders.erase(model_path);
}
if (public_models.find(model_path) != public_models.end()) {
public_models.erase(model_path);
}
}
void TextEmbedderManager::delete_all_text_embedders() {
@@ -252,6 +256,16 @@ Option<bool> TextEmbedderManager::download_public_model(const text_embedding_mod
return Option<bool>(400, "Failed to download model file");
}
}
if(!model.data_file_md5.empty()) {
if(!check_md5(get_absolute_model_path(actual_model_name) + "_data", model.data_file_md5)) {
long res = httpClient.download_file(get_model_data_url(model), get_absolute_model_path(actual_model_name) + "_data");
if(res != 200) {
LOG(INFO) << "Failed to download public model data file: " << model.model_name;
return Option<bool>(400, "Failed to download model data file");
}
}
}
if(!check_md5(get_absolute_vocab_path(actual_model_name, model.vocab_file_name), model.vocab_md5)) {
long res = httpClient.download_file(get_vocab_url(model), get_absolute_vocab_path(actual_model_name, model.vocab_file_name));
@@ -350,6 +364,10 @@ text_embedding_model::text_embedding_model(const nlohmann::json& json) {
if(json.count("query_prefix") != 0) {
query_prefix = json.at("query_prefix").get<std::string>();
}
if(json.count("data_md5") != 0) {
data_file_md5 = json.at("data_md5").get<std::string>();
}
}
@@ -388,6 +406,10 @@ const std::string TextEmbedderManager::get_model_url(const text_embedding_model&
return MODELS_REPO_URL + model.model_name + "/model.onnx";
}
// Builds the download URL for a model's external weight file. Large ONNX
// models store tensors outside the graph in a companion "<model>.onnx_data"
// file, which lives next to model.onnx in the models repository.
const std::string TextEmbedderManager::get_model_data_url(const text_embedding_model& model) {
    std::string data_url = MODELS_REPO_URL;
    data_url += model.model_name;
    data_url += "/model.onnx_data";
    return data_url;
}
// Builds the download URL for a model's tokenizer vocabulary file; the file
// name comes from the model's config (model.vocab_file_name) rather than
// being fixed, since different tokenizer types ship different vocab files.
const std::string TextEmbedderManager::get_vocab_url(const text_embedding_model& model) {
    std::string vocab_url = MODELS_REPO_URL;
    vocab_url += model.model_name;
    vocab_url += "/";
    vocab_url += model.vocab_file_name;
    return vocab_url;
}