diff --git a/include/analytics_manager.h b/include/analytics_manager.h index 353085b6..4207f7cf 100644 --- a/include/analytics_manager.h +++ b/include/analytics_manager.h @@ -24,10 +24,11 @@ private: void to_json(nlohmann::json& obj) const { obj["name"] = name; + obj["type"] = POPULAR_QUERIES_TYPE; obj["params"] = nlohmann::json::object(); - obj["params"]["suggestion_collection"] = suggestion_collection; - obj["params"]["query_collections"] = query_collections; obj["params"]["limit"] = limit; + obj["params"]["source"]["collections"] = query_collections; + obj["params"]["destination"]["collection"] = suggestion_collection; } }; @@ -48,7 +49,9 @@ private: Option remove_popular_queries_index(const std::string& name); - Option create_popular_queries_index(nlohmann::json &payload, bool write_to_disk); + Option create_popular_queries_index(nlohmann::json &payload, + bool upsert, + bool write_to_disk); public: @@ -69,12 +72,14 @@ public: Option list_rules(); - Option create_rule(nlohmann::json& payload, bool write_to_disk = true); + Option get_rule(const std::string& name); + + Option create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk); Option remove_rule(const std::string& name); void add_suggestion(const std::string& query_collection, - std::string& query, const bool live_query, const std::string& user_id); + std::string& query, bool live_query, const std::string& user_id); void stop(); diff --git a/include/collection.h b/include/collection.h index 9fc304fc..33d97850 100644 --- a/include/collection.h +++ b/include/collection.h @@ -465,7 +465,9 @@ public: const size_t facet_sample_threshold = 0, const size_t page_offset = 0, facet_index_type_t facet_index_type = HASH, - const size_t vector_query_hits = 250) const; + const size_t vector_query_hits = 250, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_try = 2) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff 
--git a/include/core_api.h b/include/core_api.h index 780d3d1b..a13b1db8 100644 --- a/include/core_api.h +++ b/include/core_api.h @@ -147,8 +147,12 @@ bool post_create_event(const std::shared_ptr& req, const std::shared_p bool get_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); +bool get_analytics_rule(const std::shared_ptr& req, const std::shared_ptr& res); + bool post_create_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); +bool put_upsert_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); + bool del_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); // Misc helpers diff --git a/include/http_proxy.h b/include/http_proxy.h index 851daa91..13519ef2 100644 --- a/include/http_proxy.h +++ b/include/http_proxy.h @@ -32,11 +32,13 @@ class HttpProxy { void operator=(const HttpProxy&) = delete; HttpProxy(HttpProxy&&) = delete; void operator=(HttpProxy&&) = delete; - http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers); + http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers); private: HttpProxy(); ~HttpProxy() = default; - http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}); + http_proxy_res_t call(const std::string& url, const std::string& method, + const std::string& body = "", const std::unordered_map& headers = {}, + const size_t timeout_ms = 30000); // lru cache for http requests diff --git a/include/text_embedder.h b/include/text_embedder.h index cff7c7c7..e8f1de57 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text); + embedding_res_t 
Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 652c0b61..b72b26d5 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -28,7 +28,8 @@ class RemoteEmbedder { static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; public: - virtual embedding_res_t Embed(const std::string& text) = 0; + virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; + virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; @@ -47,8 +48,9 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -59,11 +61,13 @@ class GoogleEmbedder : public RemoteEmbedder { inline static constexpr short GOOGLE_EMBEDDING_DIM = 768; inline static constexpr char* GOOGLE_CREATE_EMBEDDING = 
"https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key="; std::string google_api_key; + public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -88,8 +92,9 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/src/analytics_manager.cpp b/src/analytics_manager.cpp index 4851dd46..dcb6a64a 100644 --- a/src/analytics_manager.cpp +++ b/src/analytics_manager.cpp @@ -5,7 +5,7 @@ #include "http_client.h" #include "collection_manager.h" -Option AnalyticsManager::create_rule(nlohmann::json& payload, bool write_to_disk) { +Option AnalyticsManager::create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk) { /* Sample payload: @@ -37,16 +37,23 @@ Option AnalyticsManager::create_rule(nlohmann::json& payload, bool write_t } if(payload["type"] == POPULAR_QUERIES_TYPE) { - 
return create_popular_queries_index(payload, write_to_disk); + return create_popular_queries_index(payload, upsert, write_to_disk); } return Option(400, "Invalid type."); } -Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool write_to_disk) { +Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool upsert, bool write_to_disk) { // params and name are validated upstream - const auto& params = payload["params"]; const std::string& suggestion_config_name = payload["name"].get(); + bool already_exists = suggestion_configs.find(suggestion_config_name) != suggestion_configs.end(); + + if(!upsert && already_exists) { + return Option(400, "There's already another configuration with the name `" + + suggestion_config_name + "`."); + } + + const auto& params = payload["params"]; if(!params.contains("source") || !params["source"].is_object()) { return Option(400, "Bad or missing source."); @@ -56,18 +63,12 @@ Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payl return Option(400, "Bad or missing destination."); } - size_t limit = 1000; if(params.contains("limit") && params["limit"].is_number_integer()) { limit = params["limit"].get(); } - if(suggestion_configs.find(suggestion_config_name) != suggestion_configs.end()) { - return Option(400, "There's already another configuration with the name `" + - suggestion_config_name + "`."); - } - if(!params["source"].contains("collections") || !params["source"]["collections"].is_array()) { return Option(400, "Must contain a valid list of source collections."); } @@ -93,6 +94,14 @@ Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payl std::unique_lock lock(mutex); + if(already_exists) { + // remove the previous configuration with same name (upsert) + Option remove_op = remove_popular_queries_index(suggestion_config_name); + if(!remove_op.ok()) { + return Option(500, "Error erasing the existing configuration."); + } + } +
suggestion_configs.emplace(suggestion_config_name, suggestion_config); for(const auto& query_coll: suggestion_config.query_collections) { @@ -130,13 +139,25 @@ Option AnalyticsManager::list_rules() { for(const auto& suggestion_config: suggestion_configs) { nlohmann::json rule; suggestion_config.second.to_json(rule); - rule["type"] = POPULAR_QUERIES_TYPE; rules["rules"].push_back(rule); } return Option(rules); } +Option AnalyticsManager::get_rule(const string& name) { + nlohmann::json rule; + std::unique_lock lock(mutex); + + auto suggestion_config_it = suggestion_configs.find(name); + if(suggestion_config_it == suggestion_configs.end()) { + return Option(404, "Rule not found."); + } + + suggestion_config_it->second.to_json(rule); + return Option(rule); +} + Option AnalyticsManager::remove_rule(const string &name) { std::unique_lock lock(mutex); diff --git a/src/collection.cpp b/src/collection.cpp index 6d644b31..a2be1b06 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1108,7 +1108,9 @@ Option Collection::search(std::string raw_query, const size_t facet_sample_threshold, const size_t page_offset, facet_index_type_t facet_index_type, - const size_t vector_query_hits) const { + const size_t vector_query_hits, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_try) const { std::shared_lock lock(mutex); @@ -1234,10 +1236,15 @@ Option Collection::search(std::string raw_query, std::string error = "Prefix search is not supported for remote embedders. 
Please set `prefix=false` as an additional search parameter to disable prefix searching."; return Option(400, error); } + + if(remote_embedding_num_try == 0) { + std::string error = "`remote-embedding-num-try` must be greater than 0."; + return Option(400, error); + } } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query); + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f8d32b4e..21a1d803 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -285,6 +285,17 @@ Option CollectionManager::load(const size_t collection_batch_size, const s iter->Next(); } + // restore query suggestions configs + std::vector analytics_config_jsons; + store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX, + std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`", + analytics_config_jsons); + + for(const auto& analytics_config_json: analytics_config_jsons) { + nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json); + AnalyticsManager::get_instance().create_rule(analytics_config, false, false); + } + delete iter; LOG(INFO) << "Loaded " << num_collections << " collection(s)."; @@ -671,6 +682,9 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY = "vector_query"; const char *VECTOR_QUERY_HITS = "vector_query_hits"; + const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; + const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try"; + const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -824,6 +838,9 @@ Option CollectionManager::do_search(std::map& re text_match_type_t match_type = max_score; 
size_t vector_query_hits = 250; + size_t remote_embedding_timeout_ms = 30000; + size_t remote_embedding_num_try = 2; + size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -850,6 +867,8 @@ Option CollectionManager::do_search(std::map& re {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, &vector_query_hits}, + {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, + {REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try}, }; std::unordered_map str_values = { @@ -1062,7 +1081,11 @@ Option CollectionManager::do_search(std::map& re match_type, facet_sample_percent, facet_sample_threshold, - offset + offset, + HASH, + vector_query_hits, + remote_embedding_timeout_ms, + remote_embedding_num_try ); uint64_t timeMillis = std::chrono::duration_cast( @@ -1312,17 +1335,6 @@ Option CollectionManager::load_collection(const nlohmann::json &collection collection->add_synonym(collection_synonym, false); } - // restore query suggestions configs - std::vector analytics_config_jsons; - cm.store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX, - std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`", - analytics_config_jsons); - - for(const auto& analytics_config_json: analytics_config_jsons) { - nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json); - AnalyticsManager::get_instance().create_rule(analytics_config, false); - } - // Fetch records from the store and re-create memory index const std::string seq_id_prefix = collection->get_seq_id_collection_prefix(); std::string upper_bound_key = collection->get_seq_id_collection_prefix() + "`"; // cannot inline this diff --git a/src/core_api.cpp b/src/core_api.cpp index 5d04ad07..b2b993c8 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -2078,13 +2078,25 @@ bool post_create_event(const std::shared_ptr& req, const std::shared_p bool get_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { 
auto rules_op = AnalyticsManager::get_instance().list_rules(); - if(rules_op.ok()) { - res->set_200(rules_op.get().dump()); - return true; + if(!rules_op.ok()) { + res->set(rules_op.code(), rules_op.error()); + return false; } - res->set(rules_op.code(), rules_op.error()); - return false; + res->set_200(rules_op.get().dump()); + return true; +} + +bool get_analytics_rule(const std::shared_ptr& req, const std::shared_ptr& res) { + auto rules_op = AnalyticsManager::get_instance().get_rule(req->params["name"]); + + if(!rules_op.ok()) { + res->set(rules_op.code(), rules_op.error()); + return false; + } + + res->set_200(rules_op.get().dump()); + return true; } bool post_create_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { @@ -2098,7 +2110,7 @@ bool post_create_analytics_rules(const std::shared_ptr& req, const std return false; } - auto op = AnalyticsManager::get_instance().create_rule(req_json); + auto op = AnalyticsManager::get_instance().create_rule(req_json, false, true); if(!op.ok()) { res->set(op.code(), op.error()); @@ -2109,6 +2121,29 @@ bool post_create_analytics_rules(const std::shared_ptr& req, const std return true; } +bool put_upsert_analytics_rules(const std::shared_ptr &req, const std::shared_ptr &res) { + nlohmann::json req_json; + + try { + req_json = nlohmann::json::parse(req->body); + } catch(const std::exception& e) { + LOG(ERROR) << "JSON error: " << e.what(); + res->set_400("Bad JSON."); + return false; + } + + req_json["name"] = req->params["name"]; + auto op = AnalyticsManager::get_instance().create_rule(req_json, true, true); + + if(!op.ok()) { + res->set(op.code(), op.error()); + return false; + } + + res->set_200(req_json.dump()); + return true; +} + bool del_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { auto op = AnalyticsManager::get_instance().remove_rule(req->params["name"]); if(!op.ok()) { @@ -2116,11 +2151,13 @@ bool del_analytics_rules(const std::shared_ptr& req, const std::shared 
return false; } - res->set_200(R"({"ok": true)"); + nlohmann::json res_json; + res_json["name"] = req->params["name"]; + + res->set_200(res_json.dump()); return true; } - bool post_proxy(const std::shared_ptr& req, const std::shared_ptr& res) { HttpProxy& proxy = HttpProxy::get_instance(); @@ -2180,4 +2217,4 @@ bool post_proxy(const std::shared_ptr& req, const std::shared_ptrset_200(response.body); return true; -} \ No newline at end of file +} diff --git a/src/http_client.cpp b/src/http_client.cpp index c6e4b7fd..9b9a3239 100644 --- a/src/http_client.cpp +++ b/src/http_client.cpp @@ -156,6 +156,9 @@ long HttpClient::perform_curl(CURL *curl, std::map& re LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res); curl_easy_cleanup(curl); curl_slist_free_all(chunk); + if(res == CURLE_OPERATION_TIMEDOUT) { + return 408; + } return 500; } diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index eb562849..3134841a 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -8,17 +8,19 @@ HttpProxy::HttpProxy() : cache(30s){ } -http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, + const std::string& body, const std::unordered_map& headers, + const size_t timeout_ms) { HttpClient& client = HttpClient::get_instance(); http_proxy_res_t res; if(method == "GET") { - res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms); } else if(method == "POST") { - res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms); } else if(method == "PUT") { - res.status_code = client.put_response(url, body, res.body, 
res.headers, 20 * 1000); + res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms); } else if(method == "DELETE") { - res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000); + res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms); } else { res.status_code = 400; nlohmann::json j; @@ -29,11 +31,25 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth } -http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers) { // check if url is in cache uint64_t key = StringUtils::hash_wy(url.c_str(), url.size()); key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size())); key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size())); + + size_t timeout_ms = 30000; + size_t num_try = 2; + + if(headers.find("timeout_ms") != headers.end()){ + timeout_ms = std::stoul(headers.at("timeout_ms")); + headers.erase("timeout_ms"); + } + + if(headers.find("num_try") != headers.end()){ + num_try = std::stoul(headers.at("num_try")); + headers.erase("num_try"); + } + for(auto& header : headers){ key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size())); key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size())); @@ -42,22 +58,23 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth return cache[key]; } - auto res = call(url, method, body, headers); + http_proxy_res_t res; + for(size_t i = 0; i < num_try; i++){ + res = call(url, method, body, headers, timeout_ms); - if(res.status_code == 500){ - // retry - res = call(url, method, body, headers); + if(res.status_code != 408 && res.status_code < 
500){ + break; + } } - if(res.status_code == 500){ + if(res.status_code == 408){ nlohmann::json j; j["message"] = "Timeout while calling remote server. Please try again later."; res.body = j.dump(); } - // add to cache - if(res.status_code != 500){ + if(res.status_code == 200){ cache.insert(key, res); } diff --git a/src/index.cpp b/src/index.cpp index ffbbbfea..98bb2025 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -5781,7 +5781,9 @@ void Index::get_doc_changes(const index_operation_t op, const tsl::htrie_mapget("/analytics/rules", get_analytics_rules); + server->get("/analytics/rules/:name", get_analytics_rule); server->post("/analytics/rules", post_create_analytics_rules); + server->put("/analytics/rules/:name", put_upsert_analytics_rules); server->del("/analytics/rules/:name", del_analytics_rules); server->post("/analytics/events", post_create_event); diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index b2a67607..2055ea77 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text); + return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedding_num_try); } else { // Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index bb984e86..f93c6713 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -59,15 +59,30 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, headers["Authorization"] = "Bearer " + api_key; std::string res; auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers); + + if(res_code == 408) { + return Option(408, "OpenAI API timeout."); + } + if (res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return
Option(400, "OpenAI API error: " + res); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "OpenAI API error: " + res); } return Option(400, "OpenAI API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - auto models_json = nlohmann::json::parse(res); + nlohmann::json models_json; + try { + models_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from OpenAI API."); + } bool found = false; // extract model name by removing "openai/" prefix auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name); @@ -88,49 +103,57 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, std::string embedding_res; headers["Content-Type"] = "application/json"; - res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers); + if(res_code == 408) { + return Option(408, "OpenAI API timeout."); + } + if (res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(embedding_res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(embedding_res); + } catch (const std::exception& e) { + return Option(400, "OpenAI API error: " + embedding_res); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "OpenAI API error: " + embedding_res); } return Option(400, "OpenAI API error: " + json_res["error"]["message"].get()); } - - auto embedding = nlohmann::json::parse(embedding_res)["data"][0]["embedding"].get>(); + std::vector embedding; + try { + embedding = nlohmann::json::parse(embedding_res)["data"][0]["embedding"].get>(); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from OpenAI API."); + } num_dims =
embedding.size(); return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::string res; nlohmann::json req_body; - req_body["input"] = text; + req_body["input"] = std::vector{text}; // remove "openai/" prefix req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if (res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); + } + try { + embedding_res_t embedding_res = embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); + return embedding_res; + } catch (const std::exception& e) { + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } - - return embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); } std::vector OpenAIEmbedder::batch_embed(const std::vector& inputs) { @@ -147,20 +170,7 @@ 
std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; - - nlohmann::json json_res = nlohmann::json::parse(res); - LOG(INFO) << "OpenAI API error: " << json_res.dump(); - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["request"]["body"]["input"] = std::vector{inputs[0]}; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); - } - + nlohmann::json embedding_res = get_error_json(req_body, res_code, res); for(size_t i = 0; i < inputs.size(); i++) { embedding_res["request"]["body"]["input"][0] = inputs[i]; outputs.push_back(embedding_res_t(res_code, embedding_res)); @@ -168,7 +178,18 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; + for(size_t i = 0; i < inputs.size(); i++) { + embedding_res["request"]["body"]["input"][0] = inputs[i]; + outputs.push_back(embedding_res_t(500, embedding_res)); + } + return outputs; + } std::vector outputs; for(auto& data : res_json["data"]) { outputs.push_back(embedding_res_t(data["embedding"].get>())); @@ -178,6 +199,36 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector 0 && embedding_res["request"]["body"]["input"].get>().size() > 1) { + auto vec = embedding_res["request"]["body"]["input"].get>(); + vec.resize(1); + embedding_res["request"]["body"]["input"] = vec; + } + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); + } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } + + return embedding_res; +} + + GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : 
google_api_key(google_api_key) { } @@ -210,22 +261,38 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + json_res = nlohmann::json::object(); + json_res["error"] = "Malformed response from Google API."; + } + if(res_code == 408) { + return Option(408, "Google API timeout."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "Google API error: " + res); } + return Option(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - num_dims = nlohmann::json::parse(res)["embedding"]["value"].get>().size(); + try { + num_dims = nlohmann::json::parse(res)["embedding"]["value"].get>().size(); + } catch (const std::exception& e) { + return Option(500, "Got malformed response from Google API."); + } return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -233,20 +300,14 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); - nlohmann::json embedding_res 
= nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get(); - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } - return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); + try { + return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); + } catch (const std::exception& e) { + return embedding_res_t(500, get_error_json(req_body, res_code, res)); + } } @@ -260,6 +321,30 @@ std::vector GoogleEmbedder::batch_embed(const std::vector(); + } + if(res_code == 408) { + embedding_res["error"] = "Google API timeout."; + } + + return embedding_res; +} + GCPEmbedder::GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) : @@ -302,14 +387,26 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers); if(res_code != 200) { + if(res_code == 408) { + return Option(408, "GCP API timeout."); + } - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } return
Option(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - - auto res_json = nlohmann::json::parse(res); + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } if(res_json.count("predictions") == 0 || res_json["predictions"].size() == 0 || res_json["predictions"][0].count("embeddings") == 0) { LOG(INFO) << "Invalid response from GCP API: " << res_json.dump(); return Option(400, "GCP API error: Invalid response"); @@ -325,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -334,6 +431,8 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { std::unordered_map headers; headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::map res_headers; std::string res; @@ -355,24 +454,17 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { } if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "GCP API error: " 
+ json_res["error"]["message"].get(); - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); + } + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } - - nlohmann::json res_json = nlohmann::json::parse(res); return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get>()); } - std::vector GCPEmbedder::batch_embed(const std::vector& inputs) { // GCP API has a limit of 5 instances per request if(inputs.size() > 5) { @@ -416,24 +508,24 @@ std::vector GCPEmbedder::batch_embed(const std::vector(); - } + auto embedding_res = get_error_json(req_body, res_code, res); std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { outputs.push_back(embedding_res_t(res_code, embedding_res)); } return outputs; } - - nlohmann::json res_json = nlohmann::json::parse(res); + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + nlohmann::json embedding_res = get_error_json(req_body, res_code, res); + std::vector outputs; + for(size_t i = 0; i < inputs.size(); i++) { + outputs.push_back(embedding_res_t(400, embedding_res)); + } + return outputs; + } std::vector outputs; for(const auto& prediction : res_json["predictions"]) { outputs.push_back(embedding_res_t(prediction["embeddings"]["values"].get>())); @@ -442,6 +534,34 @@ std::vector GCPEmbedder::batch_embed(const std::vector(); + } else { + embedding_res["error"] = "Malformed response from GCP API."; + } + + if(res_code == 408) { + embedding_res["error"] = "GCP API timeout."; + } + + return embedding_res; +} + Option GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) { std::unordered_map headers; headers["Content-Type"] = "application/x-www-form-urlencoded"; @@ 
-453,17 +573,27 @@ Option GCPEmbedder::generate_access_token(const std::string& refres auto res_code = call_remote_api("POST", GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res = nlohmann::json::parse(res); + nlohmann::json json_res; + try { + json_res = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { return Option(400, "GCP API error: " + res); } + if(res_code == 408) { + return Option(408, "GCP API timeout."); + } return Option(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get()); } - - nlohmann::json res_json = nlohmann::json::parse(res); + nlohmann::json res_json; + try { + res_json = nlohmann::json::parse(res); + } catch (const std::exception& e) { + return Option(400, "Got malformed response from GCP API."); + } std::string access_token = res_json["access_token"].get(); return Option(access_token); } - - diff --git a/test/analytics_manager_test.cpp b/test/analytics_manager_test.cpp index 0030f221..a86bf809 100644 --- a/test/analytics_manager_test.cpp +++ b/test/analytics_manager_test.cpp @@ -78,7 +78,7 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) { } })"_json; - auto create_op = analyticsManager.create_rule(analytics_rule); + auto create_op = analyticsManager.create_rule(analytics_rule, false, true); ASSERT_TRUE(create_op.ok()); std::string q = "foobar"; @@ -88,4 +88,97 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) { auto userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"]; ASSERT_EQ(1, userQueries.size()); ASSERT_EQ("foobar", userQueries[0].query); + + // add another query which is more popular + q = "buzzfoo"; + analyticsManager.add_suggestion("titles", q, true, "1"); + analyticsManager.add_suggestion("titles", q, true, "2"); + analyticsManager.add_suggestion("titles", q, true, "3"); + + 
popularQueries = analyticsManager.get_popular_queries(); + userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"]; + ASSERT_EQ(2, userQueries.size()); + ASSERT_EQ("foobar", userQueries[0].query); + ASSERT_EQ("buzzfoo", userQueries[1].query); } + +TEST_F(AnalyticsManagerTest, GetAndDeleteSuggestions) { + nlohmann::json analytics_rule = R"({ + "name": "top_search_queries", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queries" + } + } + })"_json; + + auto create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_TRUE(create_op.ok()); + + analytics_rule = R"({ + "name": "top_search_queries2", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queries" + } + } + })"_json; + + create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_TRUE(create_op.ok()); + + auto rules = analyticsManager.list_rules().get()["rules"]; + ASSERT_EQ(2, rules.size()); + + ASSERT_TRUE(analyticsManager.get_rule("top_search_queries").ok()); + ASSERT_TRUE(analyticsManager.get_rule("top_search_queries2").ok()); + + auto missing_rule_op = analyticsManager.get_rule("top_search_queriesX"); + ASSERT_FALSE(missing_rule_op.ok()); + ASSERT_EQ(404, missing_rule_op.code()); + ASSERT_EQ("Rule not found.", missing_rule_op.error()); + + // upsert rule that already exists + analytics_rule = R"({ + "name": "top_search_queries2", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queriesUpdated" + } + } + })"_json; + create_op = analyticsManager.create_rule(analytics_rule, true, true); + ASSERT_TRUE(create_op.ok()); + auto existing_rule = analyticsManager.get_rule("top_search_queries2").get(); + ASSERT_EQ("top_queriesUpdated", 
existing_rule["params"]["destination"]["collection"].get()); + + // reject when upsert is not enabled + create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_FALSE(create_op.ok()); + ASSERT_EQ("There's already another configuration with the name `top_search_queries2`.", create_op.error()); + + // try deleting both rules + analyticsManager.remove_rule("top_search_queries"); + analyticsManager.remove_rule("top_search_queries2"); + + missing_rule_op = analyticsManager.get_rule("top_search_queries"); + ASSERT_FALSE(missing_rule_op.ok()); + missing_rule_op = analyticsManager.get_rule("top_search_queries2"); + ASSERT_FALSE(missing_rule_op.ok()); +} + diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 1a0f0267..68980d21 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -1321,6 +1321,21 @@ TEST_F(CollectionSpecificMoreTest, UpdateArrayWithNullValue) { auto results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(0, results["found"].get()); + // update document with no value (optional field) with a null value + auto doc3 = R"({ + "id": "2" + })"_json; + + ASSERT_TRUE(coll1->add(doc3.dump(), CREATE).ok()); + results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["found"].get()); + + doc_update = R"({ + "id": "2", + "tags": null + })"_json; + ASSERT_TRUE(coll1->add(doc_update.dump(), UPDATE).ok()); + // via upsert doc_update = R"({ diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 4da66c4a..1c90367e 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -5177,4 +5177,45 @@ TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) { ASSERT_FALSE(get_op.get()["embedding"].is_null()); ASSERT_EQ(384, get_op.get()["embedding"].size()); +} + + +TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) { + 
std::string partial_json = R"({ + "results": [ + { + "embedding": [ + 0.0, + 0.0, + 0.0 + ], + "text": "butter" + }, + { + "embedding": [ + 0.0, + 0.0, + 0.0 + ], + "text": "butterball" + }, + { + "embedding": [ + 0.0, + 0.0)"; + + nlohmann::json req_body = R"({ + "inputs": [ + "butter", + "butterball", + "butterfly" + ] + })"_json; + + OpenAIEmbedder embedder("", ""); + + auto res = embedder.get_error_json(req_body, 200, partial_json); + + ASSERT_EQ(res["response"]["error"], "Malformed response from OpenAI API."); + ASSERT_EQ(res["request"]["body"], req_body); } \ No newline at end of file diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index d24b44bf..80790fd2 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -1137,4 +1137,27 @@ TEST_F(CoreAPIUtilsTest, TestProxyInvalid) { ASSERT_EQ(400, resp->status_code); ASSERT_EQ("Headers must be a JSON object.", nlohmann::json::parse(resp->body)["message"]); +} + + + +TEST_F(CoreAPIUtilsTest, TestProxyTimeout) { + nlohmann::json body; + + auto req = std::make_shared(); + auto resp = std::make_shared(nullptr); + + // test with url as empty string + body["url"] = "https://typesense.org/docs/"; + body["method"] = "GET"; + body["headers"] = nlohmann::json::object(); + body["headers"]["timeout_ms"] = "1"; + body["headers"]["num_retry"] = "1"; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(408, resp->status_code); + ASSERT_EQ("Server error on remote server. Please try again later.", nlohmann::json::parse(resp->body)["message"]); } \ No newline at end of file