Fix to check model type for custom models

This commit is contained in:
ozanarmagan 2023-05-04 13:04:10 +03:00
parent 5bcbb2832e
commit 85ed9090b2

View File

@ -186,10 +186,20 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n
return false;
}
if(!config["model_type"].is_string() || !config["vocab_file_name"].is_string()) {
LOG(ERROR) << "Invalid config file: " << TextEmbedderManager::get_absolute_config_path(model_name);
return false;
}
if(!std::filesystem::exists(TextEmbedderManager::get_model_subdir(model_name) + "/" + config["vocab_file_name"].get<std::string>())) {
LOG(ERROR) << "Vocab file not found: " << TextEmbedderManager::get_model_subdir(model_name) + "/" + config["vocab_file_name"].get<std::string>();
return false;
}
if(config["model_type"].get<std::string>() != "bert" && config["model_type"].get<std::string>() != "xlm_roberta" && config["model_type"].get<std::string>() != "distilbert") {
LOG(ERROR) << "Invalid model type: " << config["model_type"].get<std::string>();
return false;
}
}
Ort::Session session(env, abs_path.c_str(), session_options);