diff --git a/src/conversation_model.cpp b/src/conversation_model.cpp index 1bef50eb..9ab87ebc 100644 --- a/src/conversation_model.cpp +++ b/src/conversation_model.cpp @@ -612,10 +612,19 @@ Option vLLMConversationModel::validate_model(const nlohmann::json& model_c if(!model_config["vllm_url"].is_string()) { return Option(400, "vLLM URL is not a string"); } + std::unordered_map headers; std::map res_headers; std::string res; + + if(model_config.count("api_key") != 0) { + if(!model_config["api_key"].is_string()) { + return Option(400, "API key is not a string"); + } + headers["Authorization"] = "Bearer " + model_config["api_key"].get(); + } + auto res_code = RemoteEmbedder::call_remote_api("GET", get_list_models_url(model_config["vllm_url"]), "", res, res_headers, headers); if(res_code == 408) { @@ -713,6 +722,11 @@ Option vLLMConversationModel::get_answer(const std::string& context req_body["messages"].push_back(message); std::string res; + + if(model_config.count("api_key") != 0) { + headers["Authorization"] = "Bearer " + model_config["api_key"].get(); + } + auto res_code = RemoteEmbedder::call_remote_api("POST", get_chat_completion_url(vllm_url), req_body.dump(), res, res_headers, headers); if(res_code == 408) { @@ -795,6 +809,10 @@ Option vLLMConversationModel::get_standalone_question(const nlohman req_body["messages"].push_back(message); + if(model_config.count("api_key") != 0) { + headers["Authorization"] = "Bearer " + model_config["api_key"].get(); + } + auto res_code = RemoteEmbedder::call_remote_api("POST", get_chat_completion_url(vllm_url), req_body.dump(), res, res_headers, headers); if(res_code == 408) {