mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 05:32:30 +08:00
Fix to check model type for custom models
This commit is contained in:
parent
5bcbb2832e
commit
85ed9090b2
@ -186,10 +186,20 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!config["model_type"].is_string() || !config["vocab_file_name"].is_string()) {
|
||||
LOG(ERROR) << "Invalid config file: " << TextEmbedderManager::get_absolute_config_path(model_name);
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!std::filesystem::exists(TextEmbedderManager::get_model_subdir(model_name) + "/" + config["vocab_file_name"].get<std::string>())) {
|
||||
LOG(ERROR) << "Vocab file not found: " << TextEmbedderManager::get_model_subdir(model_name) + "/" + config["vocab_file_name"].get<std::string>();
|
||||
return false;
|
||||
}
|
||||
|
||||
if(config["model_type"].get<std::string>() != "bert" && config["model_type"].get<std::string>() != "xlm_roberta" && config["model_type"].get<std::string>() != "distilbert") {
|
||||
LOG(ERROR) << "Invalid model type: " << config["model_type"].get<std::string>();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Session session(env, abs_path.c_str(), session_options);
|
||||
|
Loading…
x
Reference in New Issue
Block a user