From 65dbeb6e75ea97a795df4a4a96327bd5b6c1ddd7 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Wed, 5 Jul 2023 01:28:22 +0300 Subject: [PATCH 01/15] Add timeout and num retries as search params --- include/collection.h | 4 +++- include/http_proxy.h | 4 ++-- include/text_embedder.h | 2 +- include/text_embedder_remote.h | 8 +++---- src/collection.cpp | 6 +++-- src/collection_manager.cpp | 13 +++++++++- src/http_proxy.cpp | 44 ++++++++++++++++++++++++++-------- src/text_embedder.cpp | 4 ++-- src/text_embedder_remote.cpp | 14 ++++++++--- 9 files changed, 73 insertions(+), 26 deletions(-) diff --git a/include/collection.h b/include/collection.h index 2f79bc68..f4c930f4 100644 --- a/include/collection.h +++ b/include/collection.h @@ -464,7 +464,9 @@ public: const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, const size_t page_offset = UINT32_MAX, - const size_t vector_query_hits = 250) const; + const size_t vector_query_hits = 250, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_retries = 2) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/http_proxy.h b/include/http_proxy.h index 851daa91..232576f6 100644 --- a/include/http_proxy.h +++ b/include/http_proxy.h @@ -32,11 +32,11 @@ class HttpProxy { void operator=(const HttpProxy&) = delete; HttpProxy(HttpProxy&&) = delete; void operator=(HttpProxy&&) = delete; - http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers); + http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers); private: HttpProxy(); ~HttpProxy() = default; - http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}); + http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}, const size_t timeout_ms = 30000); // lru cache for http requests diff --git a/include/text_embedder.h b/include/text_embedder.h index cff7c7c7..4d74292e 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text); + embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 652c0b61..3f9937c3 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -28,7 +28,7 @@ class RemoteEmbedder { static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; public: - virtual embedding_res_t Embed(const std::string& text) = 0; + virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; 
@@ -47,7 +47,7 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) override; std::vector batch_embed(const std::vector& inputs) override; }; @@ -62,7 +62,7 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) override; std::vector batch_embed(const std::vector& inputs) override; }; @@ -88,7 +88,7 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) override; std::vector batch_embed(const std::vector& inputs) override; }; diff --git a/src/collection.cpp b/src/collection.cpp index c55bee7d..8c193fc0 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1107,7 +1107,9 @@ Option Collection::search(std::string raw_query, const size_t facet_sample_percent, const size_t facet_sample_threshold, const size_t page_offset, - const size_t vector_query_hits) const { + const size_t vector_query_hits, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_retries) const { std::shared_lock lock(mutex); @@ -1236,7 +1238,7 @@ Option Collection::search(std::string raw_query, } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query); + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retries); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 5323e1b0..d86ec2a1 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -671,6 +671,9 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY = "vector_query"; const char *VECTOR_QUERY_HITS = "vector_query_hits"; + const char *REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; + const char *REMOTE_EMBEDDING_NUM_RETRIES = "remote_embedding_num_retries"; + const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -824,6 +827,9 @@ Option CollectionManager::do_search(std::map& re text_match_type_t match_type = max_score; size_t vector_query_hits = 250; + size_t remote_embedding_timeout_ms = 30000; + size_t remote_embedding_num_retries = 2; + size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -850,6 +856,8 @@ Option 
CollectionManager::do_search(std::map& re {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, &vector_query_hits}, + {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, + {REMOTE_EMBEDDING_NUM_RETRIES, &remote_embedding_num_retries} }; std::unordered_map str_values = { @@ -1070,7 +1078,10 @@ Option CollectionManager::do_search(std::map& re match_type, facet_sample_percent, facet_sample_threshold, - offset + offset, + vector_query_hits, + remote_embedding_timeout_ms, + remote_embedding_num_retries ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index eb562849..d44ee06a 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -8,17 +8,19 @@ HttpProxy::HttpProxy() : cache(30s){ } -http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, + const std::unordered_map& headers, const size_t timeout_ms) { HttpClient& client = HttpClient::get_instance(); http_proxy_res_t res; + if(method == "GET") { - res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms); } else if(method == "POST") { - res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms); } else if(method == "PUT") { - res.status_code = client.put_response(url, body, res.body, res.headers, 20 * 1000); + res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms); } else if(method == "DELETE") { - res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000); + res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms); } else { res.status_code = 400; nlohmann::json j; @@ -29,7 +31,7 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth } -http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers) { // check if url is in cache uint64_t key = StringUtils::hash_wy(url.c_str(), url.size()); key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size())); @@ -42,13 +44,34 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth return cache[key]; } - auto res = call(url, method, body, headers); + size_t timeout_ms = 30000; + size_t num_retries = 2; - if(res.status_code == 500){ - // retry - res = call(url, method, body, headers); + if(headers.find("remote_embedding_timeout_ms") != headers.end()) { + std::stringstream ss(headers["remote_embedding_timeout_ms"]); + ss >> timeout_ms; + + headers.erase("remote_embedding_timeout_ms"); } + + if(headers.find("remote_embedding_num_retries") != headers.end()) { + std::stringstream ss(headers["remote_embedding_num_retries"]); + ss >> num_retries; + + headers.erase("remote_embedding_num_retries"); + } + + http_proxy_res_t res; + for(size_t i = 0;i < num_retries;i++) { + res = call(url, method, body, headers, timeout_ms); + + if(res.status_code != 500) { + break; + } 
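// (Note, not part of this commit: send() reads the retry count from a
// "remote_embedding_num_retries" pseudo-header, but the embedders added later in
// this same patch populate a header spelled "remote_embedding_num_tries", so the
// retry count never actually overrides the default of 2 set above; only the
// timeout header takes effect. The whole commit is reverted in PATCH 02/15 below.)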
+ } + + if(res.status_code == 500){ nlohmann::json j; j["message"] = "Server error on remote server. Please try again later."; @@ -61,5 +84,6 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth cache.insert(key, res); } + return res; } \ No newline at end of file diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index b2a67607..c235ecf8 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text); + return remote_embedder_->Embed(text, remote_embedding_timeout_ms, remote_embedding_num_tries); } else { // Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index bb984e86..336c7a74 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -104,11 +104,13 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedding_timeout_ms, const int remote_embedding_num_tries) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; + headers["remote_embedding_timeout_ms"] = std::to_string(remote_embedding_timeout_ms); + headers["remote_embedding_num_tries"] = std::to_string(remote_embedding_num_tries); std::string res; nlohmann::json req_body; req_body["input"] = text; @@ -170,10 +172,12 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; + for(auto& data : res_json["data"]) { outputs.push_back(embedding_res_t(data["embedding"].get>())); } + return outputs; } @@ -222,10 +226,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedding_timeout_ms, const int remote_embedding_num_tries) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; + headers["remote_embedding_timeout_ms"] = std::to_string(remote_embedding_timeout_ms); + headers["remote_embedding_num_tries"] = std::to_string(remote_embedding_num_tries); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -325,7 +331,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedding_timeout_ms, const int remote_embedding_num_tries) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -334,6 +340,8 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { std::unordered_map headers; headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; + headers["remote_embedding_timeout_ms"] = std::to_string(remote_embedding_timeout_ms); + headers["remote_embedding_num_tries"] = std::to_string(remote_embedding_num_tries); std::map res_headers; std::string res; From 2ae06c5824c7dddf3436263b42e109514ca5fafd Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 6 Jul 2023 10:49:26 +0300 Subject: [PATCH 02/15] Revert "Add timeout and num retries as search params" This reverts 
commit 65dbeb6e75ea97a795df4a4a96327bd5b6c1ddd7. --- include/collection.h | 4 +--- include/http_proxy.h | 4 ++-- include/text_embedder.h | 2 +- include/text_embedder_remote.h | 8 +++---- src/collection.cpp | 6 ++--- src/collection_manager.cpp | 13 +--------- src/http_proxy.cpp | 44 ++++++++-------------------------- src/text_embedder.cpp | 4 ++-- src/text_embedder_remote.cpp | 14 +++-------- 9 files changed, 26 insertions(+), 73 deletions(-) diff --git a/include/collection.h b/include/collection.h index f4c930f4..2f79bc68 100644 --- a/include/collection.h +++ b/include/collection.h @@ -464,9 +464,7 @@ public: const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, const size_t page_offset = UINT32_MAX, - const size_t vector_query_hits = 250, - const size_t remote_embedding_timeout_ms = 30000, - const size_t remote_embedding_num_retries = 2) const; + const size_t vector_query_hits = 250) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/http_proxy.h b/include/http_proxy.h index 232576f6..851daa91 100644 --- a/include/http_proxy.h +++ b/include/http_proxy.h @@ -32,11 +32,11 @@ class HttpProxy { void operator=(const HttpProxy&) = delete; HttpProxy(HttpProxy&&) = delete; void operator=(HttpProxy&&) = delete; - http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers); + http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers); private: HttpProxy(); ~HttpProxy() = default; - http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}, const size_t timeout_ms = 30000); + http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}); // lru cache for http requests diff --git a/include/text_embedder.h b/include/text_embedder.h index 4d74292e..cff7c7c7 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2); + embedding_res_t Embed(const std::string& text); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 3f9937c3..652c0b61 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -28,7 +28,7 @@ class RemoteEmbedder { static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; public: - virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) = 0; + virtual embedding_res_t Embed(const std::string& text) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; @@ -47,7 +47,7 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& 
api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) override; + embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; }; @@ -62,7 +62,7 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) override; + embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; }; @@ -88,7 +88,7 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedding_timeout_ms = 30000, const int remote_embedding_num_tries = 2) override; + embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; }; diff --git a/src/collection.cpp b/src/collection.cpp index 8c193fc0..c55bee7d 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1107,9 +1107,7 @@ Option Collection::search(std::string raw_query, const size_t facet_sample_percent, const size_t facet_sample_threshold, const size_t page_offset, - const size_t vector_query_hits, - const size_t remote_embedding_timeout_ms, - const size_t remote_embedding_num_retries) const { + const size_t vector_query_hits) const { std::shared_lock lock(mutex); @@ -1238,7 +1236,7 @@ Option Collection::search(std::string raw_query, } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retries); + auto embedding_op = embedder->Embed(embed_query); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index d86ec2a1..5323e1b0 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -671,9 +671,6 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY = "vector_query"; const char *VECTOR_QUERY_HITS = "vector_query_hits"; - const char *REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; - const char *REMOTE_EMBEDDING_NUM_RETRIES = "remote_embedding_num_retries"; - const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -827,9 +824,6 @@ Option CollectionManager::do_search(std::map& re text_match_type_t match_type = max_score; size_t vector_query_hits = 250; - size_t remote_embedding_timeout_ms = 30000; - size_t remote_embedding_num_retries = 2; - size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -856,8 +850,6 @@ Option CollectionManager::do_search(std::map& re {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, 
&vector_query_hits}, - {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, - {REMOTE_EMBEDDING_NUM_RETRIES, &remote_embedding_num_retries} }; std::unordered_map str_values = { @@ -1078,10 +1070,7 @@ Option CollectionManager::do_search(std::map& re match_type, facet_sample_percent, facet_sample_threshold, - offset, - vector_query_hits, - remote_embedding_timeout_ms, - remote_embedding_num_retries + offset ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index d44ee06a..eb562849 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -8,19 +8,17 @@ HttpProxy::HttpProxy() : cache(30s){ } -http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, - const std::unordered_map& headers, const size_t timeout_ms) { +http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { HttpClient& client = HttpClient::get_instance(); http_proxy_res_t res; - if(method == "GET") { - res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms); + res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000); } else if(method == "POST") { - res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms); + res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000); } else if(method == "PUT") { - res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms); + res.status_code = client.put_response(url, body, res.body, res.headers, 20 * 1000); } else if(method == "DELETE") { - res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms); + res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000); } else { res.status_code = 400; nlohmann::json j; @@ -31,7 +29,7 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth } -http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers) { +http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { // check if url is in cache uint64_t key = StringUtils::hash_wy(url.c_str(), url.size()); key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size())); @@ -44,34 +42,13 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth return cache[key]; } - size_t timeout_ms = 30000; - size_t num_retries = 2; + auto res = call(url, method, body, headers); - if(headers.find("remote_embedding_timeout_ms") != headers.end()) { - std::stringstream ss(headers["remote_embedding_timeout_ms"]); - ss >> timeout_ms; - - headers.erase("remote_embedding_timeout_ms"); + if(res.status_code == 500){ + // retry + res = call(url, method, body, headers); } - - if(headers.find("remote_embedding_num_retries") != headers.end()) { - std::stringstream ss(headers["remote_embedding_num_retries"]); - ss >> num_retries; - - headers.erase("remote_embedding_num_retries"); - } - - http_proxy_res_t res; - for(size_t i = 0;i < num_retries;i++) { - res = call(url, method, body, headers, timeout_ms); - - if(res.status_code != 500) { - break; - } - } - - if(res.status_code == 500){ nlohmann::json j; j["message"] = "Server error on remote server. 
Please try again later."; @@ -84,6 +61,5 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth cache.insert(key, res); } - return res; } \ No newline at end of file diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index c235ecf8..b2a67607 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text, remote_embedding_timeout_ms, remote_embedding_num_tries); + return remote_embedder_->Embed(text); } else { // Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 336c7a74..bb984e86 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -104,13 +104,11 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedding_timeout_ms, const int remote_embedding_num_tries) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; - headers["remote_embedding_timeout_ms"] = std::to_string(remote_embedding_timeout_ms); - headers["remote_embedding_num_tries"] = std::to_string(remote_embedding_num_tries); std::string res; nlohmann::json req_body; req_body["input"] = text; @@ -172,12 +170,10 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; - for(auto& data : res_json["data"]) { outputs.push_back(embedding_res_t(data["embedding"].get>())); } - return outputs; } @@ -226,12 +222,10 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedding_timeout_ms, const int remote_embedding_num_tries) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; - headers["remote_embedding_timeout_ms"] = std::to_string(remote_embedding_timeout_ms); - headers["remote_embedding_num_tries"] = std::to_string(remote_embedding_num_tries); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -331,7 +325,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedding_timeout_ms, const int remote_embedding_num_tries) { +embedding_res_t GCPEmbedder::Embed(const std::string& text) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -340,8 +334,6 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_ std::unordered_map headers; headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; - headers["remote_embedding_timeout_ms"] = std::to_string(remote_embedding_timeout_ms); - headers["remote_embedding_num_tries"] = std::to_string(remote_embedding_num_tries); std::map res_headers; std::string res; From 01a17a697271e5f6c4d77a4ebbdbe75c973394c7 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 6 Jul 2023 11:48:59 +0300 Subject: [PATCH 03/15] Fix JSON parsing errors --- src/text_embedder_remote.cpp | 193 ++++++++++++++++++++++++++++++----- 1 file changed, 166 insertions(+), 27 
deletions(-) diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index bb984e86..931da129 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -60,14 +60,24 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, std::string res; auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers); if (res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "OpenAI API error: " + res); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "OpenAI API error: " + res); } return Option(400, "OpenAI API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - auto models_json = nlohmann::json::parse(res); + nlohmann::json models_json; + try { + models_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from OpenAI API."); + } bool found = false; // extract model name by removing "openai/" prefix auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name); @@ -92,14 +102,23 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, if (res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(embedding_res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(embedding_res); + } catch (const std::exception& e) { + return Option(400, "OpenAI API error: " + embedding_res); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "OpenAI API error: " + embedding_res); } return Option(400, "OpenAI API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - - auto embedding = nlohmann::json::parse(embedding_res)["data"][0]["embedding"].get>(); + std::vector embedding; + try { + embedding = nlohmann::json::parse(embedding_res)["data"][0]["embedding"].get>(); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from OpenAI API."); + } num_dims = embedding.size(); return Option(true); } @@ -116,7 +135,12 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if (res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return embedding_res_t(400, "OpenAI API error: " + res); + } nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; embedding_res["request"] = nlohmann::json::object(); @@ -129,8 +153,19 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { } return embedding_res_t(res_code, embedding_res); } + try { + embedding_res_t embedding_res = embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); + return embedding_res; + } catch (const std::exception& e) { + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + 
embedding_res["error"] = "Malformed response from OpenAI API."; - return embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); + return embedding_res_t(500, embedding_res); + } } std::vector OpenAIEmbedder::batch_embed(const std::vector& inputs) { @@ -148,7 +183,13 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + json_res = nlohmann::json::object(); + json_res["error"] = "OpenAI API error: " + res; + } LOG(INFO) << "OpenAI API error: " << json_res.dump(); nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; @@ -167,8 +208,24 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector{inputs[0]}; + embedding_res["error"] = "Malformed response from OpenAI API"; + std::vector outputs; + for(size_t i = 0; i < inputs.size(); i++) { + embedding_res["request"]["body"]["input"][0] = inputs[i]; + outputs.push_back(embedding_res_t(500, embedding_res)); + } + return outputs; + } std::vector outputs; for(auto& data : res_json["data"]) { outputs.push_back(embedding_res_t(data["embedding"].get>())); @@ -210,14 +267,23 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Google API error: " + res); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "Google API error: " + res); } return Option(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - num_dims = nlohmann::json::parse(res)["embedding"]["value"].get>().size(); + try { + num_dims = nlohmann::json::parse(res)["embedding"]["value"].get>().size(); + } catch (const std::exception& e) { + return Option(500, "Got malformed response from Google API."); + } return Option(true); } @@ -233,7 +299,13 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + nlohmann::json json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + json_res = nlohmann::json::object(); + json_res["error"] = res; + } nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; embedding_res["request"] = nlohmann::json::object(); @@ -245,8 +317,17 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { } return embedding_res_t(res_code, embedding_res); } - - return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); + try { + return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); + } catch (const std::exception& e) { + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + embedding_res["error"] = 
"Malformed response from Google API."; + return embedding_res_t(500, embedding_res); + } } @@ -302,14 +383,23 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } return Option(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - - auto res_json = nlohmann::json::parse(res); + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } if(res_json.count("predictions") == 0 || res_json["predictions"].size() == 0 || res_json["predictions"][0].count("embeddings") == 0) { LOG(INFO) << "Invalid response from GCP API: " << res_json.dump(); return Option(400, "GCP API error: Invalid response"); @@ -355,7 +445,13 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { } if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + json_res = nlohmann::json::object(); + json_res["error"] = "Got malformed response from GCP API."; + } nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; embedding_res["request"] = nlohmann::json::object(); @@ -364,11 +460,23 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { embedding_res["request"]["body"] = req_body; if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get(); + } else { + embedding_res["error"] = "Malformed response from GCP API."; } return embedding_res_t(res_code, embedding_res); } - - nlohmann::json res_json = nlohmann::json::parse(res); + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + nlohmann::json embedding_res = nlohmann::json::object(); + embedding_res["request"] = nlohmann::json::object(); + embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); + embedding_res["request"]["method"] = "POST"; + embedding_res["request"]["body"] = req_body; + embedding_res["error"] = "Malformed response from GCP API."; + return embedding_res_t(500, embedding_res); + } return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get>()); } @@ -416,7 +524,13 @@ std::vector GCPEmbedder::batch_embed(const std::vector GCPEmbedder::batch_embed(const std::vector(); + } else { + embedding_res["error"] = "Malformed response from GCP API."; } std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { @@ -432,8 +548,22 @@ std::vector GCPEmbedder::batch_embed(const std::vector outputs; + for(size_t i = 0; i < inputs.size(); i++) { + outputs.push_back(embedding_res_t(400, embedding_res)); + } + return outputs; + } std::vector outputs; for(const auto& prediction : res_json["predictions"]) { outputs.push_back(embedding_res_t(prediction["embeddings"]["values"].get>())); @@ -453,14 
+583,23 @@ Option GCPEmbedder::generate_access_token(const std::string& refres auto res_code = call_remote_api("POST", GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } return Option(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - - nlohmann::json res_json = nlohmann::json::parse(res); + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } std::string access_token = res_json["access_token"].get(); return Option(access_token); From e614895ea4ec62f66ed87baa412140cba34562fd Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 6 Jul 2023 13:10:32 +0300 Subject: [PATCH 04/15] Add extra check for timeouts --- src/http_client.cpp | 3 +++ src/text_embedder_remote.cpp | 42 ++++++++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/http_client.cpp b/src/http_client.cpp index c6e4b7fd..9b9a3239 100644 --- a/src/http_client.cpp +++ b/src/http_client.cpp @@ -156,6 +156,9 @@ long HttpClient::perform_curl(CURL *curl, std::map& re LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res); curl_easy_cleanup(curl); curl_slist_free_all(chunk); + if(res == CURLE_OPERATION_TIMEDOUT) { + return 408; + } return 500; } diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 931da129..f13f86ed 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -59,6 +59,11 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, headers["Authorization"] = "Bearer " + api_key; std::string res; auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers); + + if(res_code == 408) { + return Option(408, "OpenAI API timeout."); + } + if (res_code != 200) { nlohmann::json json_res; try { @@ -98,9 +103,13 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, std::string embedding_res; headers["Content-Type"] = "application/json"; - res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + if(res_code == 408) { + return Option(408, "OpenAI API timeout."); + } + if (res_code != 200) { nlohmann::json json_res; try { @@ -139,7 +148,7 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { try { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { - return embedding_res_t(400, "OpenAI API error: " + res); + } nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; @@ -151,6 +160,9 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } return embedding_res_t(res_code, embedding_res); } try { 
@@ -188,7 +200,6 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector OpenAIEmbedder::batch_embed(const std::vector(); } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } for(size_t i = 0; i < inputs.size(); i++) { embedding_res["request"]["body"]["input"][0] = inputs[i]; @@ -271,11 +285,14 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, try { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { - return Option(400, "Google API error: " + res); + } + if(res_code == 408) { + return Option(408, "Google API timeout."); } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "Google API error: " + res); } + return Option(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } @@ -315,6 +332,9 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get(); } + if(res_code == 408) { + embedding_res["error"] = "Google API timeout."; + } return embedding_res_t(res_code, embedding_res); } try { @@ -389,6 +409,9 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns } catch (const std::exception& e) { return Option(400, "Got malformed response from GCP API."); } + if(json_res == 408) { + return Option(408, "GCP API timeout."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } @@ -463,6 +486,10 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { } else { embedding_res["error"] = "Malformed response from GCP API."; } + + if(res_code == 408) { + embedding_res["error"] = "GCP API timeout."; + } return embedding_res_t(res_code, embedding_res); } nlohmann::json res_json; @@ -542,6 +569,10 @@ std::vector GCPEmbedder::batch_embed(const std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { outputs.push_back(embedding_res_t(res_code, embedding_res)); @@ -592,6 +623,9 @@ Option GCPEmbedder::generate_access_token(const std::string& refres if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } + if(res_code == 408) { + return Option(408, "GCP API timeout."); + } return Option(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } nlohmann::json res_json; From 2d4221c1f7f09b9d2e0e952ec8a9dbce7afcbca7 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 6 Jul 2023 20:33:48 +0530 Subject: [PATCH 05/15] Check for presence of key before getting from old doc during update. --- src/index.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/index.cpp b/src/index.cpp index 4e6ac914..155958ce 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6268,7 +6268,9 @@ void Index::get_doc_changes(const index_operation_t op, const tsl::htrie_map Date: Thu, 6 Jul 2023 21:06:18 +0530 Subject: [PATCH 06/15] Add test for null value + update of missing optional field. 
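Background for PATCH 05 above and the test added in PATCH 06: with nlohmann::json, the non-const operator[] silently creates a null entry when the requested key is absent, so reading a missing optional field straight out of the old document during an update can fabricate a null value where the field was simply not present. The snippet below is a standalone illustration of that library behaviour, not code from either patch; the "tags" field name is borrowed from the test and everything else is invented.

    // Minimal sketch of why the presence check in PATCH 05 is needed (illustrative only).
    #include <cassert>
    #include <nlohmann/json.hpp>

    int main() {
        nlohmann::json old_doc = nlohmann::json::parse(R"({"id": "2"})"); // optional "tags" absent
        nlohmann::json del_doc = nlohmann::json::object();

        if(old_doc.count("tags") != 0) {             // guard, as the commit now does
            del_doc["tags"] = old_doc["tags"];
        }
        assert(del_doc.count("tags") == 0);          // nothing copied for the missing field

        // Without the guard, old_doc["tags"] would have inserted a null into old_doc
        // and that null would have been copied into del_doc.
        return 0;
    }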
--- test/collection_specific_more_test.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index f53ce1a7..e3f82f69 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -1311,6 +1311,21 @@ TEST_F(CollectionSpecificMoreTest, UpdateArrayWithNullValue) { auto results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(0, results["found"].get()); + // update document with no value (optional field) with a null value + auto doc3 = R"({ + "id": "2" + })"_json; + + ASSERT_TRUE(coll1->add(doc3.dump(), CREATE).ok()); + results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["found"].get()); + + doc_update = R"({ + "id": "2", + "tags": null + })"_json; + ASSERT_TRUE(coll1->add(doc_update.dump(), UPDATE).ok()); + // via upsert doc_update = R"({ From e7887c5efa7879b020eaf9b173f85f91064ef372 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 6 Jul 2023 21:53:03 +0300 Subject: [PATCH 07/15] Update timeout response code to 408 in proxy --- src/http_proxy.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index eb562849..4ed4db5a 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -44,12 +44,12 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth auto res = call(url, method, body, headers); - if(res.status_code == 500){ + if(res.status_code == 408){ // retry res = call(url, method, body, headers); } - if(res.status_code == 500){ + if(res.status_code == 408){ nlohmann::json j; j["message"] = "Server error on remote server. Please try again later."; res.body = j.dump(); @@ -57,7 +57,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth // add to cache - if(res.status_code != 500){ + if(res.status_code != 408){ cache.insert(key, res); } From a769eeb0a758a6c29b553654bf86b327c951ed82 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 6 Jul 2023 22:02:02 +0300 Subject: [PATCH 08/15] Add error messages for JSON parse errors --- src/text_embedder_remote.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index f13f86ed..4e844871 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -148,6 +148,8 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { try { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { + json_res = nlohmann::json::object(); + json_res["error"] = "Malformed response from OpenAI API."; } nlohmann::json embedding_res = nlohmann::json::object(); @@ -200,6 +202,7 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, try { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { + json_res = nlohmann::json::object(); + json_res["error"] = "Malformed response from Google API."; } if(res_code == 408) { return Option(408, "Google API timeout."); @@ -321,7 +326,7 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { nlohmann::json json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { json_res = nlohmann::json::object(); - json_res["error"] = res; + json_res["error"] = "Malformed response from Google API." 
} nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; @@ -473,7 +478,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { json_res = nlohmann::json::parse(res); } catch (const std::exception& e) { json_res = nlohmann::json::object(); - json_res["error"] = "Got malformed response from GCP API."; + json_res["error"] = "Malformed response from GCP API."; } nlohmann::json embedding_res = nlohmann::json::object(); embedding_res["response"] = json_res; @@ -556,7 +561,7 @@ std::vector GCPEmbedder::batch_embed(const std::vector Date: Fri, 7 Jul 2023 12:26:54 +0530 Subject: [PATCH 09/15] Add upsert + get APIs for analytics rules. --- include/analytics_manager.h | 15 ++++-- include/core_api.h | 4 ++ src/analytics_manager.cpp | 43 +++++++++++---- src/collection_manager.cpp | 22 ++++---- src/core_api.cpp | 52 ++++++++++++++---- src/main/typesense_server.cpp | 2 + test/analytics_manager_test.cpp | 95 ++++++++++++++++++++++++++++++++- 7 files changed, 196 insertions(+), 37 deletions(-) diff --git a/include/analytics_manager.h b/include/analytics_manager.h index 353085b6..4207f7cf 100644 --- a/include/analytics_manager.h +++ b/include/analytics_manager.h @@ -24,10 +24,11 @@ private: void to_json(nlohmann::json& obj) const { obj["name"] = name; + obj["type"] = POPULAR_QUERIES_TYPE; obj["params"] = nlohmann::json::object(); - obj["params"]["suggestion_collection"] = suggestion_collection; - obj["params"]["query_collections"] = query_collections; obj["params"]["limit"] = limit; + obj["params"]["source"]["collections"] = query_collections; + obj["params"]["destination"]["collection"] = suggestion_collection; } }; @@ -48,7 +49,9 @@ private: Option remove_popular_queries_index(const std::string& name); - Option create_popular_queries_index(nlohmann::json &payload, bool write_to_disk); + Option create_popular_queries_index(nlohmann::json &payload, + bool upsert, + bool write_to_disk); public: @@ -69,12 +72,14 @@ public: Option list_rules(); - Option create_rule(nlohmann::json& payload, bool write_to_disk = true); + Option get_rule(const std::string& name); + + Option create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk); Option remove_rule(const std::string& name); void add_suggestion(const std::string& query_collection, - std::string& query, const bool live_query, const std::string& user_id); + std::string& query, bool live_query, const std::string& user_id); void stop(); diff --git a/include/core_api.h b/include/core_api.h index 780d3d1b..a13b1db8 100644 --- a/include/core_api.h +++ b/include/core_api.h @@ -147,8 +147,12 @@ bool post_create_event(const std::shared_ptr& req, const std::shared_p bool get_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); +bool get_analytics_rule(const std::shared_ptr& req, const std::shared_ptr& res); + bool post_create_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); +bool put_upsert_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); + bool del_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); // Misc helpers diff --git a/src/analytics_manager.cpp b/src/analytics_manager.cpp index 4851dd46..dcb6a64a 100644 --- a/src/analytics_manager.cpp +++ b/src/analytics_manager.cpp @@ -5,7 +5,7 @@ #include "http_client.h" #include "collection_manager.h" -Option AnalyticsManager::create_rule(nlohmann::json& payload, bool write_to_disk) { +Option AnalyticsManager::create_rule(nlohmann::json& payload, bool upsert, bool 
write_to_disk) { /* Sample payload: @@ -37,16 +37,23 @@ Option AnalyticsManager::create_rule(nlohmann::json& payload, bool write_t } if(payload["type"] == POPULAR_QUERIES_TYPE) { - return create_popular_queries_index(payload, write_to_disk); + return create_popular_queries_index(payload, upsert, write_to_disk); } return Option(400, "Invalid type."); } -Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool write_to_disk) { +Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool upsert, bool write_to_disk) { // params and name are validated upstream - const auto& params = payload["params"]; const std::string& suggestion_config_name = payload["name"].get(); + bool already_exists = suggestion_configs.find(suggestion_config_name) != suggestion_configs.end(); + + if(!upsert && already_exists) { + return Option(400, "There's already another configuration with the name `" + + suggestion_config_name + "`."); + } + + const auto& params = payload["params"]; if(!params.contains("source") || !params["source"].is_object()) { return Option(400, "Bad or missing source."); @@ -56,18 +63,12 @@ Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payl return Option(400, "Bad or missing destination."); } - size_t limit = 1000; if(params.contains("limit") && params["limit"].is_number_integer()) { limit = params["limit"].get(); } - if(suggestion_configs.find(suggestion_config_name) != suggestion_configs.end()) { - return Option(400, "There's already another configuration with the name `" + - suggestion_config_name + "`."); - } - if(!params["source"].contains("collections") || !params["source"]["collections"].is_array()) { return Option(400, "Must contain a valid list of source collections."); } @@ -93,6 +94,14 @@ Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payl std::unique_lock lock(mutex); + if(already_exists) { + // remove the previous configuration with same name (upsert) + Option remove_op = remove_popular_queries_index(suggestion_config_name); + if(!remove_op.ok()) { + return Option(500, "Error erasing the existing configuration.");; + } + } + suggestion_configs.emplace(suggestion_config_name, suggestion_config); for(const auto& query_coll: suggestion_config.query_collections) { @@ -130,13 +139,25 @@ Option AnalyticsManager::list_rules() { for(const auto& suggestion_config: suggestion_configs) { nlohmann::json rule; suggestion_config.second.to_json(rule); - rule["type"] = POPULAR_QUERIES_TYPE; rules["rules"].push_back(rule); } return Option(rules); } +Option AnalyticsManager::get_rule(const string& name) { + nlohmann::json rule; + std::unique_lock lock(mutex); + + auto suggestion_config_it = suggestion_configs.find(name); + if(suggestion_config_it == suggestion_configs.end()) { + return Option(404, "Rule not found."); + } + + suggestion_config_it->second.to_json(rule); + return Option(rule); +} + Option AnalyticsManager::remove_rule(const string &name) { std::unique_lock lock(mutex); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f8d32b4e..6c281595 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -285,6 +285,17 @@ Option CollectionManager::load(const size_t collection_batch_size, const s iter->Next(); } + // restore query suggestions configs + std::vector analytics_config_jsons; + store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX, + std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`", + analytics_config_jsons); + + for(const auto& 
analytics_config_json: analytics_config_jsons) { + nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json); + AnalyticsManager::get_instance().create_rule(analytics_config, false, false); + } + delete iter; LOG(INFO) << "Loaded " << num_collections << " collection(s)."; @@ -1312,17 +1323,6 @@ Option CollectionManager::load_collection(const nlohmann::json &collection collection->add_synonym(collection_synonym, false); } - // restore query suggestions configs - std::vector analytics_config_jsons; - cm.store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX, - std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`", - analytics_config_jsons); - - for(const auto& analytics_config_json: analytics_config_jsons) { - nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json); - AnalyticsManager::get_instance().create_rule(analytics_config, false); - } - // Fetch records from the store and re-create memory index const std::string seq_id_prefix = collection->get_seq_id_collection_prefix(); std::string upper_bound_key = collection->get_seq_id_collection_prefix() + "`"; // cannot inline this diff --git a/src/core_api.cpp b/src/core_api.cpp index 5d04ad07..0c9cd0e5 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -2078,13 +2078,25 @@ bool post_create_event(const std::shared_ptr& req, const std::shared_p bool get_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { auto rules_op = AnalyticsManager::get_instance().list_rules(); - if(rules_op.ok()) { - res->set_200(rules_op.get().dump()); - return true; + if(!rules_op.ok()) { + res->set(rules_op.code(), rules_op.error()); + return false; } - res->set(rules_op.code(), rules_op.error()); - return false; + res->set_200(rules_op.get().dump()); + return true; +} + +bool get_analytics_rule(const std::shared_ptr& req, const std::shared_ptr& res) { + auto rules_op = AnalyticsManager::get_instance().get_rule(req->params["name"]); + + if(!rules_op.ok()) { + res->set(rules_op.code(), rules_op.error()); + return false; + } + + res->set_200(rules_op.get().dump()); + return true; } bool post_create_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { @@ -2098,7 +2110,7 @@ bool post_create_analytics_rules(const std::shared_ptr& req, const std return false; } - auto op = AnalyticsManager::get_instance().create_rule(req_json); + auto op = AnalyticsManager::get_instance().create_rule(req_json, false, true); if(!op.ok()) { res->set(op.code(), op.error()); @@ -2109,6 +2121,29 @@ bool post_create_analytics_rules(const std::shared_ptr& req, const std return true; } +bool put_upsert_analytics_rules(const std::shared_ptr &req, const std::shared_ptr &res) { + nlohmann::json req_json; + + try { + req_json = nlohmann::json::parse(req->body); + } catch(const std::exception& e) { + LOG(ERROR) << "JSON error: " << e.what(); + res->set_400("Bad JSON."); + return false; + } + + req_json["name"] = req->params["name"]; + auto op = AnalyticsManager::get_instance().create_rule(req_json, true, true); + + if(!op.ok()) { + res->set(op.code(), op.error()); + return false; + } + + res->set_200(req_json.dump()); + return true; +} + bool del_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { auto op = AnalyticsManager::get_instance().remove_rule(req->params["name"]); if(!op.ok()) { @@ -2116,11 +2151,10 @@ bool del_analytics_rules(const std::shared_ptr& req, const std::shared return false; } - res->set_200(R"({"ok": true)"); + res->set_200(R"({"ok": true})"); return true; } - bool 
post_proxy(const std::shared_ptr& req, const std::shared_ptr& res) { HttpProxy& proxy = HttpProxy::get_instance(); @@ -2180,4 +2214,4 @@ bool post_proxy(const std::shared_ptr& req, const std::shared_ptrset_200(response.body); return true; -} \ No newline at end of file +} diff --git a/src/main/typesense_server.cpp b/src/main/typesense_server.cpp index d4df4f2a..8a47e1bb 100644 --- a/src/main/typesense_server.cpp +++ b/src/main/typesense_server.cpp @@ -70,7 +70,9 @@ void master_server_routes() { // analytics server->get("/analytics/rules", get_analytics_rules); + server->get("/analytics/rules/:name", get_analytics_rule); server->post("/analytics/rules", post_create_analytics_rules); + server->put("/analytics/rules/:name", put_upsert_analytics_rules); server->del("/analytics/rules/:name", del_analytics_rules); server->post("/analytics/events", post_create_event); diff --git a/test/analytics_manager_test.cpp b/test/analytics_manager_test.cpp index 0030f221..a86bf809 100644 --- a/test/analytics_manager_test.cpp +++ b/test/analytics_manager_test.cpp @@ -78,7 +78,7 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) { } })"_json; - auto create_op = analyticsManager.create_rule(analytics_rule); + auto create_op = analyticsManager.create_rule(analytics_rule, false, true); ASSERT_TRUE(create_op.ok()); std::string q = "foobar"; @@ -88,4 +88,97 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) { auto userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"]; ASSERT_EQ(1, userQueries.size()); ASSERT_EQ("foobar", userQueries[0].query); + + // add another query which is more popular + q = "buzzfoo"; + analyticsManager.add_suggestion("titles", q, true, "1"); + analyticsManager.add_suggestion("titles", q, true, "2"); + analyticsManager.add_suggestion("titles", q, true, "3"); + + popularQueries = analyticsManager.get_popular_queries(); + userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"]; + ASSERT_EQ(2, userQueries.size()); + ASSERT_EQ("foobar", userQueries[0].query); + ASSERT_EQ("buzzfoo", userQueries[1].query); } + +TEST_F(AnalyticsManagerTest, GetAndDeleteSuggestions) { + nlohmann::json analytics_rule = R"({ + "name": "top_search_queries", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queries" + } + } + })"_json; + + auto create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_TRUE(create_op.ok()); + + analytics_rule = R"({ + "name": "top_search_queries2", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queries" + } + } + })"_json; + + create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_TRUE(create_op.ok()); + + auto rules = analyticsManager.list_rules().get()["rules"]; + ASSERT_EQ(2, rules.size()); + + ASSERT_TRUE(analyticsManager.get_rule("top_search_queries").ok()); + ASSERT_TRUE(analyticsManager.get_rule("top_search_queries2").ok()); + + auto missing_rule_op = analyticsManager.get_rule("top_search_queriesX"); + ASSERT_FALSE(missing_rule_op.ok()); + ASSERT_EQ(404, missing_rule_op.code()); + ASSERT_EQ("Rule not found.", missing_rule_op.error()); + + // upsert rule that already exists + analytics_rule = R"({ + "name": "top_search_queries2", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queriesUpdated" + } + } + })"_json; + 
create_op = analyticsManager.create_rule(analytics_rule, true, true); + ASSERT_TRUE(create_op.ok()); + auto existing_rule = analyticsManager.get_rule("top_search_queries2").get(); + ASSERT_EQ("top_queriesUpdated", existing_rule["params"]["destination"]["collection"].get()); + + // reject when upsert is not enabled + create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_FALSE(create_op.ok()); + ASSERT_EQ("There's already another configuration with the name `top_search_queries2`.", create_op.error()); + + // try deleting both rules + analyticsManager.remove_rule("top_search_queries"); + analyticsManager.remove_rule("top_search_queries2"); + + missing_rule_op = analyticsManager.get_rule("top_search_queries"); + ASSERT_FALSE(missing_rule_op.ok()); + missing_rule_op = analyticsManager.get_rule("top_search_queries2"); + ASSERT_FALSE(missing_rule_op.ok()); +} + From 29830e1d58825175fab4548e2faeb93495241e86 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 13:41:38 +0300 Subject: [PATCH 10/15] Create virtual functions to generate error JSONs from responses --- include/text_embedder_remote.h | 4 + src/http_proxy.cpp | 4 +- src/text_embedder_remote.cpp | 242 +++++++++++++-------------------- 3 files changed, 100 insertions(+), 150 deletions(-) diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 652c0b61..46c32778 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -27,6 +27,7 @@ class RemoteEmbedder { static Option validate_string_properties(const nlohmann::json& model_config, const std::vector& properties); static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; + virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; public: virtual embedding_res_t Embed(const std::string& text) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; @@ -44,6 +45,7 @@ class OpenAIEmbedder : public RemoteEmbedder { std::string openai_model_path; static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models"; static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings"; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); @@ -59,6 +61,7 @@ class GoogleEmbedder : public RemoteEmbedder { inline static constexpr short GOOGLE_EMBEDDING_DIM = 768; inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key="; std::string google_api_key; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); @@ -84,6 +87,7 @@ class GCPEmbedder : public RemoteEmbedder { static std::string get_gcp_embedding_url(const std::string& project_id, const std::string& model_name) { return GCP_EMBEDDING_BASE_URL + project_id + GCP_EMBEDDING_PATH + model_name + GCP_EMBEDDING_PREDICT; } + nlohmann::json 
get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index 4ed4db5a..518cd45b 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -44,7 +44,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth auto res = call(url, method, body, headers); - if(res.status_code == 408){ + if(res.status_code >= 500 || res.status_code == 408){ // retry res = call(url, method, body, headers); } @@ -57,7 +57,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth // add to cache - if(res.status_code != 408){ + if(res.status_code == 200){ cache.insert(key, res); } diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 4e844871..a54315e1 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -144,41 +144,13 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if (res_code != 200) { - nlohmann::json json_res; - try { - json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from OpenAI API."; - - } - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); - } - if(res_code == 408) { - embedding_res["error"] = "OpenAI API timeout."; - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } try { embedding_res_t embedding_res = embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); return embedding_res; } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["error"] = "Malformed response from OpenAI API."; - - return embedding_res_t(500, embedding_res); + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } } @@ -196,46 +168,19 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; - - nlohmann::json json_res; - try { - json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from OpenAI API."; - } - LOG(INFO) << "OpenAI API error: " << json_res.dump(); - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - 
embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["request"]["body"]["input"] = std::vector{inputs[0]}; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); - } - if(res_code == 408) { - embedding_res["error"] = "OpenAI API timeout."; - } - + nlohmann::json embedding_res = get_error_json(req_body, res_code, res); for(size_t i = 0; i < inputs.size(); i++) { embedding_res["request"]["body"]["input"][0] = inputs[i]; outputs.push_back(embedding_res_t(res_code, embedding_res)); } return outputs; } + nlohmann::json res_json; try { res_json = nlohmann::json::parse(res); } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["request"]["body"]["input"] = std::vector{inputs[0]}; - embedding_res["error"] = "Malformed response from OpenAI API"; + nlohmann::json embedding_res = get_error_json(req_body, res_code, res); std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { embedding_res["request"]["body"]["input"][0] = inputs[i]; @@ -252,6 +197,36 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector 0 && embedding_res["request"]["body"]["input"].get>().size() > 1) { + auto vec = embedding_res["request"]["body"]["input"].get>(); + vec.resize(1); + embedding_res["request"]["body"]["input"] = vec; + } + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); + } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } + + return embedding_res; +} + + GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : google_api_key(google_api_key) { } @@ -321,37 +296,13 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res; - try { - nlohmann::json json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from Google API." 
- } - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get(); - } - if(res_code == 408) { - embedding_res["error"] = "Google API timeout."; - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } + try { return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["error"] = "Malformed response from Google API."; - return embedding_res_t(500, embedding_res); + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } } @@ -366,6 +317,30 @@ std::vector GoogleEmbedder::batch_embed(const std::vector(); + } + if(res_code == 408) { + embedding_res["error"] = "Google API timeout."; + } + + return embedding_res; +} + GCPEmbedder::GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) : @@ -473,46 +448,17 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { } if(res_code != 200) { - nlohmann::json json_res; - try { - json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from GCP API."; - } - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get(); - } else { - embedding_res["error"] = "Malformed response from GCP API."; - } - - if(res_code == 408) { - embedding_res["error"] = "GCP API timeout."; - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } nlohmann::json res_json; try { res_json = nlohmann::json::parse(res); } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["error"] = "Malformed response from GCP API."; - return embedding_res_t(500, embedding_res); + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get>()); } - std::vector GCPEmbedder::batch_embed(const std::vector& inputs) { // GCP API has a limit of 5 
instances per request if(inputs.size() > 5) { @@ -556,28 +502,7 @@ std::vector GCPEmbedder::batch_embed(const std::vector(); - } else { - embedding_res["error"] = "Malformed response from GCP API."; - } - - if(res_code == 408) { - embedding_res["error"] = "GCP API timeout."; - } + auto embedding_res = get_error_json(req_body, res_code, res); std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { outputs.push_back(embedding_res_t(res_code, embedding_res)); @@ -588,12 +513,7 @@ std::vector GCPEmbedder::batch_embed(const std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { outputs.push_back(embedding_res_t(400, embedding_res)); @@ -608,6 +528,34 @@ std::vector GCPEmbedder::batch_embed(const std::vector(); + } else { + embedding_res["error"] = "Malformed response from GCP API."; + } + + if(res_code == 408) { + embedding_res["error"] = "GCP API timeout."; + } + + return embedding_res; +} + Option GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) { std::unordered_map headers; headers["Content-Type"] = "application/x-www-form-urlencoded"; @@ -643,5 +591,3 @@ Option GCPEmbedder::generate_access_token(const std::string& refres return Option(access_token); } - - From 10c0070fecd635b86f362aabc726540e2538bbc7 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 14:45:53 +0300 Subject: [PATCH 11/15] Add test for catching partial JSON --- include/text_embedder_remote.h | 9 ++++---- test/collection_test.cpp | 41 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 46c32778..18892afa 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -27,8 +27,8 @@ class RemoteEmbedder { static Option validate_string_properties(const nlohmann::json& model_config, const std::vector& properties); static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; - virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; public: + virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; virtual embedding_res_t Embed(const std::string& text) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { @@ -45,12 +45,12 @@ class OpenAIEmbedder : public RemoteEmbedder { std::string openai_model_path; static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models"; static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings"; - nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -61,12 +61,13 @@ class GoogleEmbedder : public RemoteEmbedder { inline static constexpr short GOOGLE_EMBEDDING_DIM = 768; inline static 
constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key="; std::string google_api_key; - nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; + public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -87,13 +88,13 @@ class GCPEmbedder : public RemoteEmbedder { static std::string get_gcp_embedding_url(const std::string& project_id, const std::string& model_name) { return GCP_EMBEDDING_BASE_URL + project_id + GCP_EMBEDDING_PATH + model_name + GCP_EMBEDDING_PREDICT; } - nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index e49e89f6..9e054d0a 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -5176,4 +5176,45 @@ TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) { ASSERT_FALSE(get_op.get()["embedding"].is_null()); ASSERT_EQ(384, get_op.get()["embedding"].size()); +} + + +TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) { + std::string partial_json = R"({ + "results": [ + { + "embedding": [ + 0.0, + 0.0, + 0.0 + ], + "text": "butter" + }, + { + "embedding": [ + 0.0, + 0.0, + 0.0 + ], + "text": "butterball" + }, + { + "embedding": [ + 0.0, + 0.0)"; + + nlohmann::json req_body = R"({ + "inputs": [ + "butter", + "butterball", + "butterfly" + ] + })"_json; + + OpenAIEmbedder embedder("", ""); + + auto res = embedder.get_error_json(req_body, 200, partial_json); + + ASSERT_EQ(res["response"]["error"], "Malformed response from OpenAI API."); + ASSERT_EQ(res["request"]["body"], req_body); } \ No newline at end of file From 6f7efa5d729914c03a0e9c6e61ed465c651a8e00 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 17:01:29 +0300 Subject: [PATCH 12/15] Add timeout and num retries as search parameter for remote embedding --- include/collection.h | 4 +++- include/http_proxy.h | 6 ++++-- include/text_embedder.h | 2 +- include/text_embedder_remote.h | 8 +++---- src/collection.cpp | 6 ++++-- src/collection_manager.cpp | 13 +++++++++++- src/http_proxy.cpp | 39 ++++++++++++++++++++++++---------- src/text_embedder.cpp | 4 ++-- src/text_embedder_remote.cpp | 14 ++++++++---- test/core_api_utils_test.cpp | 23 ++++++++++++++++++++ 10 files changed, 91 insertions(+), 28 deletions(-) diff --git a/include/collection.h b/include/collection.h index 2f79bc68..ac987606 100644 --- a/include/collection.h +++ b/include/collection.h @@ -464,7 +464,9 @@ public: const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, 
const size_t page_offset = UINT32_MAX, - const size_t vector_query_hits = 250) const; + const size_t vector_query_hits = 250, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_retry = 2) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/http_proxy.h b/include/http_proxy.h index 851daa91..13519ef2 100644 --- a/include/http_proxy.h +++ b/include/http_proxy.h @@ -32,11 +32,13 @@ class HttpProxy { void operator=(const HttpProxy&) = delete; HttpProxy(HttpProxy&&) = delete; void operator=(HttpProxy&&) = delete; - http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers); + http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers); private: HttpProxy(); ~HttpProxy() = default; - http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}); + http_proxy_res_t call(const std::string& url, const std::string& method, + const std::string& body = "", const std::unordered_map& headers = {}, + const size_t timeout_ms = 30000); // lru cache for http requests diff --git a/include/text_embedder.h b/include/text_embedder.h index cff7c7c7..e1f73287 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text); + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_max_retries = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 18892afa..ffe8d6a0 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -29,7 +29,7 @@ class RemoteEmbedder { static inline ReplicationState* raft_server = nullptr; public: virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; - virtual embedding_res_t Embed(const std::string& text) = 0; + virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; @@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -65,7 +65,7 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) 
override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -92,7 +92,7 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/src/collection.cpp b/src/collection.cpp index c55bee7d..b9bb6a3b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1107,7 +1107,9 @@ Option Collection::search(std::string raw_query, const size_t facet_sample_percent, const size_t facet_sample_threshold, const size_t page_offset, - const size_t vector_query_hits) const { + const size_t vector_query_hits, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_retry) const { std::shared_lock lock(mutex); @@ -1236,7 +1238,7 @@ Option Collection::search(std::string raw_query, } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query); + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retry); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 5323e1b0..91a348b7 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -671,6 +671,9 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY = "vector_query"; const char *VECTOR_QUERY_HITS = "vector_query_hits"; + const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; + const char* REMOTE_EMBEDDING_NUM_RETRY = "remote_embedding_num_retry"; + const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -824,6 +827,9 @@ Option CollectionManager::do_search(std::map& re text_match_type_t match_type = max_score; size_t vector_query_hits = 250; + size_t remote_embedding_timeout_ms = 30000; + size_t remote_embedding_num_retry = 2; + size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -850,6 +856,8 @@ Option CollectionManager::do_search(std::map& re {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, &vector_query_hits}, + {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, + {REMOTE_EMBEDDING_NUM_RETRY, &remote_embedding_num_retry}, }; std::unordered_map str_values = { @@ -1070,7 +1078,10 @@ Option CollectionManager::do_search(std::map& re match_type, facet_sample_percent, facet_sample_threshold, - offset + offset, + vector_query_hits, + remote_embedding_timeout_ms, + remote_embedding_num_retry ); uint64_t 
timeMillis = std::chrono::duration_cast( diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index 518cd45b..23acdef5 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -8,17 +8,19 @@ HttpProxy::HttpProxy() : cache(30s){ } -http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, + const std::string& body, const std::unordered_map& headers, + const size_t timeout_ms) { HttpClient& client = HttpClient::get_instance(); http_proxy_res_t res; if(method == "GET") { - res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms); } else if(method == "POST") { - res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms); } else if(method == "PUT") { - res.status_code = client.put_response(url, body, res.body, res.headers, 20 * 1000); + res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms); } else if(method == "DELETE") { - res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000); + res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms); } else { res.status_code = 400; nlohmann::json j; @@ -29,11 +31,25 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth } -http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers) { // check if url is in cache uint64_t key = StringUtils::hash_wy(url.c_str(), url.size()); key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size())); key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size())); + + size_t timeout_ms = 30000; + size_t num_retry = 2; + + if(headers.find("timeout_ms") != headers.end()){ + timeout_ms = std::stoul(headers.at("timeout_ms")); + headers.erase("timeout_ms"); + } + + if(headers.find("num_retry") != headers.end()){ + num_retry = std::stoul(headers.at("num_retry")); + headers.erase("num_retry"); + } + for(auto& header : headers){ key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size())); key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size())); @@ -42,11 +58,13 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth return cache[key]; } - auto res = call(url, method, body, headers); + http_proxy_res_t res; + for(size_t i = 0; i < num_retry; i++){ + res = call(url, method, body, headers, timeout_ms); - if(res.status_code >= 500 || res.status_code == 408){ - // retry - res = call(url, method, body, headers); + if(res.status_code != 408 && res.status_code < 500){ + break; + } } if(res.status_code == 408){ @@ -54,7 +72,6 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth j["message"] = "Server error on remote server. 
Please try again later."; res.body = j.dump(); } - // add to cache if(res.status_code == 200){ diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index b2a67607..0e1e9385 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text); + return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedder_max_retries); } else { // Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index a54315e1..0155a1ac 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -132,14 +132,16 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_retry"] = std::to_string(remote_embedder_num_retry); std::string res; nlohmann::json req_body; - req_body["input"] = text; + req_body["input"] = std::vector{text}; // remove "openai/" prefix req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); @@ -285,10 +287,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_retry"] = std::to_string(remote_embedder_num_retry); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -418,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -427,6 +431,8 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { std::unordered_map headers; headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_retry"] = std::to_string(remote_embedder_num_retry); std::map res_headers; std::string res; diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index 95af7546..57e88db3 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -1116,4 +1116,27 @@ TEST_F(CoreAPIUtilsTest, TestProxyInvalid) { ASSERT_EQ(400, resp->status_code); ASSERT_EQ("Headers must be a JSON object.", nlohmann::json::parse(resp->body)["message"]); +} + + + +TEST_F(CoreAPIUtilsTest, TestProxyTimeout) { + nlohmann::json body; + + auto req 
= std::make_shared(); + auto resp = std::make_shared(nullptr); + + // test with url as empty string + body["url"] = "https://typesense.org/docs/"; + body["method"] = "GET"; + body["headers"] = nlohmann::json::object(); + body["headers"]["timeout_ms"] = "1"; + body["headers"]["num_retry"] = "1"; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(408, resp->status_code); + ASSERT_EQ("Server error on remote server. Please try again later.", nlohmann::json::parse(resp->body)["message"]); } \ No newline at end of file From 8b1aa13ffe9de938786a28e9dfe8042829cf3105 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 17:11:25 +0300 Subject: [PATCH 13/15] Change param name to 'remote_embedding_num_try' --- include/collection.h | 2 +- include/text_embedder.h | 2 +- src/collection.cpp | 9 +++++++-- src/collection_manager.cpp | 8 ++++---- src/http_proxy.cpp | 10 +++++----- src/text_embedder_remote.cpp | 12 ++++++------ 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/include/collection.h b/include/collection.h index ac987606..dcc3a50b 100644 --- a/include/collection.h +++ b/include/collection.h @@ -466,7 +466,7 @@ public: const size_t page_offset = UINT32_MAX, const size_t vector_query_hits = 250, const size_t remote_embedding_timeout_ms = 30000, - const size_t remote_embedding_num_retry = 2) const; + const size_t remote_embedding_num_try = 2) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/text_embedder.h b/include/text_embedder.h index e1f73287..520c7c1d 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_max_retries = 2); + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_try = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/src/collection.cpp b/src/collection.cpp index b9bb6a3b..c650aaea 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1109,7 +1109,7 @@ Option Collection::search(std::string raw_query, const size_t page_offset, const size_t vector_query_hits, const size_t remote_embedding_timeout_ms, - const size_t remote_embedding_num_retry) const { + const size_t remote_embedding_num_try) const { std::shared_lock lock(mutex); @@ -1235,10 +1235,15 @@ Option Collection::search(std::string raw_query, std::string error = "Prefix search is not supported for remote embedders. 
Please set `prefix=false` as an additional search parameter to disable prefix searching."; return Option(400, error); } + + if(remote_embedding_num_try == 0) { + std::string error = "`remote_embedding_num_try` must be greater than 0."; + return Option(400, error); + } } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retry); + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 91a348b7..aa2c26dc 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -672,7 +672,7 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY_HITS = "vector_query_hits"; const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; - const char* REMOTE_EMBEDDING_NUM_RETRY = "remote_embedding_num_retry"; + const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try"; const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -828,7 +828,7 @@ Option CollectionManager::do_search(std::map& re size_t vector_query_hits = 250; size_t remote_embedding_timeout_ms = 30000; - size_t remote_embedding_num_retry = 2; + size_t remote_embedding_num_try = 2; size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -857,7 +857,7 @@ Option CollectionManager::do_search(std::map& re {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, &vector_query_hits}, {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, - {REMOTE_EMBEDDING_NUM_RETRY, &remote_embedding_num_retry}, + {REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try}, }; std::unordered_map str_values = { @@ -1081,7 +1081,7 @@ Option CollectionManager::do_search(std::map& re offset, vector_query_hits, remote_embedding_timeout_ms, - remote_embedding_num_retry + remote_embedding_num_try ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index 23acdef5..3134841a 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -38,16 +38,16 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size())); size_t timeout_ms = 30000; - size_t num_retry = 2; + size_t num_try = 2; if(headers.find("timeout_ms") != headers.end()){ timeout_ms = std::stoul(headers.at("timeout_ms")); headers.erase("timeout_ms"); } - if(headers.find("num_retry") != headers.end()){ - num_retry = std::stoul(headers.at("num_retry")); - headers.erase("num_retry"); + if(headers.find("num_try") != headers.end()){ + num_try = std::stoul(headers.at("num_try")); + headers.erase("num_try"); } for(auto& header : headers){ @@ -59,7 +59,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth } http_proxy_res_t res; - for(size_t i = 0; i < num_retry; i++){ + for(size_t i = 0; i < num_try; i++){ res = call(url, method, body, headers, timeout_ms); if(res.status_code != 408 && res.status_code < 500){ diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 0155a1ac..13a73575 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -132,13 +132,13 @@ Option
OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_retry"] = std::to_string(remote_embedder_num_retry); + headers["num_try"] = std::to_string(remote_embedder_num_try); std::string res; nlohmann::json req_body; req_body["input"] = std::vector{text}; @@ -287,12 +287,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_retry"] = std::to_string(remote_embedder_num_retry); + headers["num_try"] = std::to_string(remote_embedder_num_try); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -422,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -432,7 +432,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_ headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_retry"] = std::to_string(remote_embedder_num_retry); + headers["num_try"] = std::to_string(remote_embedder_num_try); std::map res_headers; std::string res; From a9de01f16184bf2a42ed1ec018daa7e3c5931647 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 18:33:53 +0300 Subject: [PATCH 14/15] Standardize variable names --- include/text_embedder.h | 2 +- include/text_embedder_remote.h | 8 ++++---- src/text_embedder.cpp | 4 ++-- src/text_embedder_remote.cpp | 12 ++++++------ 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/text_embedder.h b/include/text_embedder.h index 520c7c1d..e8f1de57 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_try = 2); + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2); std::vector batch_embed(const std::vector& inputs); const std::string& 
get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index ffe8d6a0..b72b26d5 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -29,7 +29,7 @@ class RemoteEmbedder { static inline ReplicationState* raft_server = nullptr; public: virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; - virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) = 0; + virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; @@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -65,7 +65,7 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -92,7 +92,7 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 0e1e9385..2055ea77 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text, remote_embedder_timeout_ms, remote_embedder_max_retries); + return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedding_num_try); } else { // 
Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 13a73575..f93c6713 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -132,13 +132,13 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_try"] = std::to_string(remote_embedder_num_try); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::string res; nlohmann::json req_body; req_body["input"] = std::vector{text}; @@ -287,12 +287,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_try"] = std::to_string(remote_embedder_num_try); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -422,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -432,7 +432,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_ headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_try"] = std::to_string(remote_embedder_num_try); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::map res_headers; std::string res; From 9b8754b13d12d461c647986a34bf5d53cb8a4d0f Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sat, 8 Jul 2023 06:41:00 +0530 Subject: [PATCH 15/15] Return name of rule as deletion response. --- src/core_api.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/core_api.cpp b/src/core_api.cpp index 0c9cd0e5..b2b993c8 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -2151,7 +2151,10 @@ bool del_analytics_rules(const std::shared_ptr& req, const std::shared return false; } - res->set_200(R"({"ok": true})"); + nlohmann::json res_json; + res_json["name"] = req->params["name"]; + + res->set_200(res_json.dump()); return true; }
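A few usage notes on what this series adds, starting with the analytics rules. AnalyticsManager::create_rule now takes an explicit upsert flag plus a persistence flag, and the manager gains a single-rule get_rule lookup. A minimal sketch of the manager-level calls, mirroring the tests above; the Option<> return types are elided with auto, and reading the third create_rule argument as "persist to the store" is an assumption based on the loader passing false for it:

    nlohmann::json rule = R"({
        "name": "top_search_queries",
        "type": "popular_queries",
        "params": {
            "limit": 100,
            "source": { "collections": ["titles"] },
            "destination": { "collection": "top_queries" }
        }
    })"_json;

    auto& manager = AnalyticsManager::get_instance();

    // create: upsert=false, persist=true; a second create with the same name is rejected
    auto create_op = manager.create_rule(rule, false, true);

    // upsert: the same call with upsert=true replaces the stored rule in place
    rule["params"]["destination"]["collection"] = "top_queries_v2";
    auto upsert_op = manager.create_rule(rule, true, true);

    // single-rule lookup: returns a 404 "Rule not found." Option for unknown names
    auto rule_op = manager.get_rule("top_search_queries");

    // removal: the HTTP handler now echoes {"name": "top_search_queries"} back on success
    auto del_op = manager.remove_rule("top_search_queries");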
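The same flow is reachable over HTTP through the new GET /analytics/rules/:name and PUT /analytics/rules/:name routes. A sketch that drives the handlers directly, the way the proxy tests above do; the types inside make_shared<> are not visible in the diff, so the http_req/http_res names here are assumptions:

    auto req = std::make_shared<http_req>();
    auto resp = std::make_shared<http_res>(nullptr);

    req->params["name"] = "top_search_queries";
    req->body = rule.dump();                 // rule JSON from the previous sketch; the handler
                                             // overwrites "name" with the :name path parameter
    put_upsert_analytics_rules(req, resp);   // PUT /analytics/rules/:name -> 200 with the upserted rule

    get_analytics_rule(req, resp);           // GET /analytics/rules/:name -> 200 with the stored rule
    del_analytics_rules(req, resp);          // DELETE /analytics/rules/:name -> 200 with {"name": "top_search_queries"}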
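On the vector search side, patches 12–14 settle on remote_embedding_timeout_ms and remote_embedding_num_try as the per-request knobs. A sketch of the request-parameter map as CollectionManager::do_search receives it; the query and field names are placeholders:

    std::map<std::string, std::string> req_params;
    req_params["q"]        = "running shoes";
    req_params["query_by"] = "embedding";                    // an auto-embedding field backed by a remote model
    req_params["remote_embedding_timeout_ms"] = "5000";      // per-attempt timeout for the remote call
    req_params["remote_embedding_num_try"]    = "3";         // total attempts; must be greater than 0

    // do_search() picks these up via REMOTE_EMBEDDING_TIMEOUT_MS / REMOTE_EMBEDDING_NUM_TRY and threads
    // them through Collection::search() into TextEmbedder::Embed(embed_query, timeout_ms, num_try).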
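Lower down, the embedders pass those two values to HttpProxy::send as pseudo-headers ("timeout_ms" and "num_try"), which send strips off before forwarding, retrying on 408 and 5xx responses and caching only 200s. A sketch of a direct caller, assuming the header is named http_proxy.h and treating the model name as a placeholder:

    #include "http_proxy.h"                         // assumed include path for the HttpProxy singleton

    nlohmann::json req_body;
    req_body["input"] = std::vector<std::string>{"hello world"};
    req_body["model"] = "text-embedding-ada-002";   // placeholder model name

    std::unordered_map<std::string, std::string> headers;
    headers["Content-Type"] = "application/json";
    headers["timeout_ms"]   = "5000";               // consumed by send(), never forwarded upstream
    headers["num_try"]      = "2";                  // total attempts; retries on 408 and >= 500

    http_proxy_res_t res = HttpProxy::get_instance().send(
            "https://api.openai.com/v1/embeddings", "POST", req_body.dump(), headers);

    if(res.status_code == 408) {
        // every attempt timed out; the body is {"message": "Server error on remote server. Please try again later."}
    }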
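When a remote call still fails, patch 10 routes every embedder through get_error_json, so the embedding_res_t error payload carries the upstream response, the request that produced it (with batch inputs trimmed to the first element), and an optional top-level "error" message. A sketch of a caller logging it, following the pattern Collection::search uses; the surrounding embedder, embed_query, and parameter variables are assumed from that context:

    auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try);

    if(!embedding_op.success) {
        const nlohmann::json& err = embedding_op.error;

        // set for API-reported failures ("OpenAI API error: ...") and for timeouts ("OpenAI API timeout.")
        if(err.contains("error") && !err["error"].get<std::string>().empty()) {
            LOG(ERROR) << err["error"].get<std::string>();
        }

        // the raw upstream payload and the request that produced it are kept for debugging
        LOG(ERROR) << "request: " << err["request"].dump() << ", response: " << err["response"].dump();
    }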
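The CatchPartialResponseFromRemoteEmbedding test exists because an upstream API can return HTTP 200 with a body that was cut off mid-stream. The detection is simply nlohmann::json's parse failure; a self-contained illustration, with a truncated string invented for the example:

    #include <nlohmann/json.hpp>
    #include <iostream>
    #include <string>

    int main() {
        // a response that was cut off mid-stream
        std::string partial = R"({"data": [{"embedding": [0.1, 0.2)";

        nlohmann::json parsed;
        try {
            parsed = nlohmann::json::parse(partial);
        } catch(const std::exception& e) {
            // nlohmann::json::parse throws on truncated input; the embedders catch this and
            // surface "Malformed response from ... API." through get_error_json()
            parsed = nlohmann::json::object();
            parsed["error"] = "Malformed response from OpenAI API.";
        }

        std::cout << parsed.dump() << std::endl;
        return 0;
    }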
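Finally, the GCP batch embedder caps each request at 5 instances, so larger input sets have to be sent in slices. A generic illustration of that chunking pattern; the batch_size argument and the process_batch callback are stand-ins rather than the embedder's actual helpers:

    #include <algorithm>
    #include <functional>
    #include <string>
    #include <vector>

    // send `inputs` in slices of at most `batch_size`, invoking one remote call per slice
    void embed_in_batches(const std::vector<std::string>& inputs, size_t batch_size,
                          const std::function<void(const std::vector<std::string>&)>& process_batch) {
        for(size_t start = 0; start < inputs.size(); start += batch_size) {
            const size_t end = std::min(inputs.size(), start + batch_size);
            std::vector<std::string> chunk(inputs.begin() + start, inputs.begin() + end);
            process_batch(chunk);   // e.g. one GCP predict request per chunk of at most 5 texts
        }
    }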