From e614895ea4ec62f66ed87baa412140cba34562fd Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 6 Jul 2023 13:10:32 +0300 Subject: [PATCH] Add extra check for timeouts --- src/http_client.cpp | 3 +++ src/text_embedder_remote.cpp | 42 ++++++++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/http_client.cpp b/src/http_client.cpp index c6e4b7fd..9b9a3239 100644 --- a/src/http_client.cpp +++ b/src/http_client.cpp @@ -156,6 +156,9 @@ long HttpClient::perform_curl(CURL *curl, std::map& re LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res); curl_easy_cleanup(curl); curl_slist_free_all(chunk); + if(res == CURLE_OPERATION_TIMEDOUT) { + return 408; + } return 500; } diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 931da129..f13f86ed 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -59,6 +59,11 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, headers["Authorization"] = "Bearer " + api_key; std::string res; auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers); + + if(res_code == 408) { + return Option(408, "OpenAI API timeout."); + } + if (res_code != 200) { nlohmann::json json_res; try { @@ -98,9 +103,13 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, std::string embedding_res; headers["Content-Type"] = "application/json"; - res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + if(res_code == 408) { + return Option(408, "OpenAI API timeout."); + } + if (res_code != 200) { nlohmann::json json_res; try { @@ -139,7 +148,7 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { try { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { - return embedding_res_t(400, "OpenAI API error: " + res); + } nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; @@ -151,6 +160,9 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } return embedding_res_t(res_code, embedding_res); } try { @@ -188,7 +200,6 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector OpenAIEmbedder::batch_embed(const std::vector(); } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } for(size_t i = 0; i < inputs.size(); i++) { embedding_res["request"]["body"]["input"][0] = inputs[i]; @@ -271,11 +285,14 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, try { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { - return Option(400, "Google API error: " + res); + } + if(res_code == 408) { + return Option(408, "Google API timeout."); } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "Google API error: " + res); } + return Option(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } @@ -315,6 +332,9 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get(); } + if(res_code == 408) { + embedding_res["error"] = "Google API timeout."; + } return embedding_res_t(res_code, embedding_res); } try { @@ -389,6 +409,9 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns } catch (const std::exception& e) { return Option(400, "Got malformed response from GCP API."); } + if(json_res == 408) { + return Option(408, "GCP API timeout."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } @@ -463,6 +486,10 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { } else { embedding_res["error"] = "Malformed response from GCP API."; } + + if(res_code == 408) { + embedding_res["error"] = "GCP API timeout."; + } return embedding_res_t(res_code, embedding_res); } nlohmann::json res_json; @@ -542,6 +569,10 @@ std::vector GCPEmbedder::batch_embed(const std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { outputs.push_back(embedding_res_t(res_code, embedding_res)); @@ -592,6 +623,9 @@ Option GCPEmbedder::generate_access_token(const std::string& refres if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } + if(res_code == 408) { + return Option(408, "GCP API timeout."); + } return Option(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } nlohmann::json res_json;