Add extra check for timeouts

ozanarmagan 2023-07-06 13:10:32 +03:00
parent 01a17a6972
commit e614895ea4
2 changed files with 41 additions and 4 deletions
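Note: the change below maps libcurl's CURLE_OPERATION_TIMEDOUT result to an HTTP-style 408 status in HttpClient::perform_curl, and each remote embedder then reports 408 as an explicit timeout error instead of a generic failure. A minimal sketch of that mapping, assuming only the libcurl symbols (the helper name is illustrative and not part of this commit):

#include <curl/curl.h>

// Illustrative helper: translate a libcurl result code into the HTTP-style
// status that the embedder code below branches on. The real perform_curl()
// also logs the failure and frees the curl handle and header list first.
long status_from_curl_result(CURLcode res) {
    if(res == CURLE_OK) {
        return 200;  // the real code reads the actual HTTP status via curl_easy_getinfo
    }
    if(res == CURLE_OPERATION_TIMEDOUT) {
        return 408;  // request timed out -> surfaced as 408 to callers
    }
    return 500;      // any other transport failure -> generic 500
}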

View File

@@ -156,6 +156,9 @@ long HttpClient::perform_curl(CURL *curl, std::map<std::string, std::string>& re
LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res);
curl_easy_cleanup(curl);
curl_slist_free_all(chunk);
if(res == CURLE_OPERATION_TIMEDOUT) {
return 408;
}
return 500;
}
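The timeout caught above is whatever deadline HttpClient sets on the curl easy handle; those option values are not part of this diff, so the values below are only an illustration of how CURLE_OPERATION_TIMEDOUT is produced:

#include <curl/curl.h>

// Purely illustrative timeouts; the actual values used by HttpClient are not shown in this commit.
void example_timeout_setup() {
    CURL* curl = curl_easy_init();
    if(curl) {
        curl_easy_setopt(curl, CURLOPT_URL, "https://api.openai.com/v1/models");
        curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT_MS, 5000L);  // connect-phase deadline
        curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, 30000L);        // whole-request deadline
        CURLcode res = curl_easy_perform(curl);
        if(res == CURLE_OPERATION_TIMEDOUT) {
            // the case perform_curl() now reports as 408 instead of 500
        }
        curl_easy_cleanup(curl);
    }
}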

View File

@@ -59,6 +59,11 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
headers["Authorization"] = "Bearer " + api_key;
std::string res;
auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);
if(res_code == 408) {
return Option<bool>(408, "OpenAI API timeout.");
}
if (res_code != 200) {
nlohmann::json json_res;
try {
@@ -98,9 +103,13 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
std::string embedding_res;
headers["Content-Type"] = "application/json";
res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
if(res_code == 408) {
return Option<bool>(408, "OpenAI API timeout.");
}
if (res_code != 200) {
nlohmann::json json_res;
try {
@@ -139,7 +148,7 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return embedding_res_t(400, "OpenAI API error: " + res);
}
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
@@ -151,6 +160,9 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
}
if(res_code == 408) {
embedding_res["error"] = "OpenAI API timeout.";
}
return embedding_res_t(res_code, embedding_res);
}
try {
@@ -188,7 +200,6 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
json_res = nlohmann::json::object();
json_res["error"] = "OpenAI API error: " + res;
}
LOG(INFO) << "OpenAI API error: " << json_res.dump();
nlohmann::json embedding_res = nlohmann::json::object();
@@ -201,6 +212,9 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
}
if(res_code == 408) {
embedding_res["error"] = "OpenAI API timeout.";
}
for(size_t i = 0; i < inputs.size(); i++) {
embedding_res["request"]["body"]["input"][0] = inputs[i];
@@ -271,11 +285,14 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "Google API error: " + res);
}
if(res_code == 408) {
return Option<bool>(408, "Google API timeout.");
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "Google API error: " + res);
}
return Option<bool>(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
@@ -315,6 +332,9 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get<std::string>();
}
if(res_code == 408) {
embedding_res["error"] = "Google API timeout.";
}
return embedding_res_t(res_code, embedding_res);
}
try {
@@ -389,6 +409,9 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
} catch (const std::exception& e) {
return Option<bool>(400, "Got malformed response from GCP API.");
}
if(res_code == 408) {
return Option<bool>(408, "GCP API timeout.");
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "GCP API error: " + res);
}
@@ -463,6 +486,10 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) {
} else {
embedding_res["error"] = "Malformed response from GCP API.";
}
if(res_code == 408) {
embedding_res["error"] = "GCP API timeout.";
}
return embedding_res_t(res_code, embedding_res);
}
nlohmann::json res_json;
@@ -542,6 +569,10 @@ std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::str
} else {
embedding_res["error"] = "Malformed response from GCP API.";
}
if(res_code == 408) {
embedding_res["error"] = "GCP API timeout.";
}
std::vector<embedding_res_t> outputs;
for(size_t i = 0; i < inputs.size(); i++) {
outputs.push_back(embedding_res_t(res_code, embedding_res));
@@ -592,6 +623,9 @@ Option<std::string> GCPEmbedder::generate_access_token(const std::string& refres
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<std::string>(400, "GCP API error: " + res);
}
if(res_code == 408) {
return Option<std::string>(408, "GCP API timeout.");
}
return Option<std::string>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
nlohmann::json res_json;