From d569cbd75e31843b2226ccafacbb8876a39d75a7 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 2 Jul 2023 20:43:31 +0530 Subject: [PATCH 01/11] Fix bug with offset/pagination affecting vector search. # Conflicts: # include/collection.h # src/index.cpp --- include/collection.h | 2 +- include/index.h | 2 +- src/collection.cpp | 12 +++++++++++- src/collection_manager.cpp | 16 ++++------------ test/collection_test.cpp | 20 ++++++++++---------- test/core_api_utils_test.cpp | 27 ++++++++++++++++++++++++--- 6 files changed, 51 insertions(+), 28 deletions(-) diff --git a/include/collection.h b/include/collection.h index 2f79bc68..65d51c0c 100644 --- a/include/collection.h +++ b/include/collection.h @@ -463,7 +463,7 @@ public: const text_match_type_t match_type = max_score, const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, - const size_t page_offset = UINT32_MAX, + const size_t page_offset = 0, const size_t vector_query_hits = 250) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/index.h b/include/index.h index 5c1b77e0..2d87fa45 100644 --- a/include/index.h +++ b/include/index.h @@ -150,7 +150,7 @@ struct search_args { filter_node_t* filter_tree_root, std::vector& facets, std::vector>& included_ids, std::vector excluded_ids, std::vector& sort_fields_std, facet_query_t facet_query, const std::vector& num_typos, - size_t max_facet_values, size_t max_hits, size_t per_page, size_t page, token_ordering token_order, + size_t max_facet_values, size_t max_hits, size_t per_page, size_t offset, token_ordering token_order, const std::vector& prefixes, size_t drop_tokens_threshold, size_t typo_tokens_threshold, const std::vector& group_by_fields, size_t group_limit, const string& default_sorting_field, bool prioritize_exact_match, diff --git a/src/collection.cpp b/src/collection.cpp index c55bee7d..05bfb824 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ 
-1392,7 +1392,17 @@ Option Collection::search(std::string raw_query, return Option(422, message); } - size_t offset = (page != 0) ? (per_page * (page - 1)) : page_offset; + size_t offset = 0; + + if(page == 0 && page_offset != 0) { + // if only offset is set, use that + offset = page_offset; + } else { + // if both are set or none set, use page value (default is 1) + size_t actual_page = (page == 0) ? 1 : page; + offset = (per_page * (actual_page - 1)); + } + size_t fetch_size = offset + per_page; if(fetch_size > limit_hits) { diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 5323e1b0..f8d32b4e 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -783,7 +783,7 @@ Option CollectionManager::do_search(std::map& re std::vector sort_fields; size_t per_page = 10; size_t page = 0; - size_t offset = UINT32_MAX; + size_t offset = 0; token_ordering token_order = NOT_SET; std::string vector_query; @@ -978,14 +978,6 @@ Option CollectionManager::do_search(std::map& re per_page = 0; } - if(!req_params[PAGE].empty() && page == 0 && offset == UINT32_MAX) { - return Option(422, "Parameter `page` must be an integer of value greater than 0."); - } - - if(req_params[PAGE].empty() && req_params[OFFSET].empty()) { - page = 1; - } - include_fields.insert(include_fields_vec.begin(), include_fields_vec.end()); exclude_fields.insert(exclude_fields_vec.begin(), exclude_fields_vec.end()); @@ -1097,10 +1089,10 @@ Option CollectionManager::do_search(std::map& re result["search_time_ms"] = timeMillis; } - if(page != 0) { - result["page"] = page; - } else { + if(page == 0 && offset != 0) { result["offset"] = offset; + } else { + result["page"] = (page == 0) ? 
1 : page; } results_json_str = result.dump(-1, ' ', false, nlohmann::detail::error_handler_t::ignore); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index e49e89f6..13d1bdab 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -1054,11 +1054,11 @@ TEST_F(CollectionTest, KeywordQueryReturnsResultsBasedOnPerPageParam) { ASSERT_EQ(422, res_op.code()); ASSERT_STREQ("Only upto 250 hits can be fetched per page.", res_op.error().c_str()); - // when page number is not valid - res_op = coll_mul_fields->search("w", query_fields, "", facets, sort_fields, {0}, 10, 0, - FREQUENCY, {true}, 1000, empty, empty, 10); - ASSERT_FALSE(res_op.ok()); - ASSERT_EQ(422, res_op.code()); + // when page number is zero, use the first page + results = coll_mul_fields->search("w", query_fields, "", facets, sort_fields, {0}, 3, 0, + FREQUENCY, {true}, 1000, empty, empty, 10).get(); + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ(6, results["found"].get()); // do pagination @@ -3026,11 +3026,11 @@ TEST_F(CollectionTest, WildcardQueryReturnsResultsBasedOnPerPageParam) { ASSERT_EQ(422, res_op.code()); ASSERT_STREQ("Only upto 250 hits can be fetched per page.", res_op.error().c_str()); - // when page number is not valid - res_op = collection->search("*", query_fields, "", facets, sort_fields, {0}, 10, 0, - FREQUENCY, {false}, 1000, empty, empty, 10); - ASSERT_FALSE(res_op.ok()); - ASSERT_EQ(422, res_op.code()); + // when page number is 0, just fetch first page + results = collection->search("*", query_fields, "", facets, sort_fields, {0}, 10, 0, + FREQUENCY, {false}, 1000, empty, empty, 10).get(); + ASSERT_EQ(10, results["hits"].size()); + ASSERT_EQ(25, results["found"].get()); // do pagination diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index 95af7546..d24b44bf 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -768,7 +768,7 @@ TEST_F(CoreAPIUtilsTest, SearchPagination) { ASSERT_EQ(400, 
results["code"].get()); ASSERT_EQ("Parameter `offset` must be an unsigned integer.", results["error"].get()); - // when page is 0 and no offset is sent + // when page is 0 and offset is NOT sent, we will treat as page=1 search.clear(); req->params.clear(); body["searches"] = nlohmann::json::array(); @@ -782,8 +782,29 @@ TEST_F(CoreAPIUtilsTest, SearchPagination) { post_multi_search(req, res); results = nlohmann::json::parse(res->body)["results"][0]; - ASSERT_EQ(422, results["code"].get()); - ASSERT_EQ("Parameter `page` must be an integer of value greater than 0.", results["error"].get()); + ASSERT_EQ(10, results["hits"].size()); + ASSERT_EQ(1, results["page"].get()); + ASSERT_EQ(0, results.count("offset")); + + // when both page and offset are sent, use page + search.clear(); + req->params.clear(); + body["searches"] = nlohmann::json::array(); + search["collection"] = "coll1"; + search["q"] = "title"; + search["page"] = "2"; + search["offset"] = "30"; + search["query_by"] = "name"; + search["sort_by"] = "points:desc"; + body["searches"].push_back(search); + req->body = body.dump(); + + post_multi_search(req, res); + results = nlohmann::json::parse(res->body)["results"][0]; + ASSERT_EQ(10, results["hits"].size()); + ASSERT_EQ(2, results["page"].get()); + ASSERT_EQ(0, results.count("offset")); + } TEST_F(CoreAPIUtilsTest, ExportWithFilter) { From 974280a4d580906472d3eccea42d9a83c5ce1f31 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 4 Jul 2023 15:55:12 +0530 Subject: [PATCH 02/11] Fix response streaming regression. Introduced during export memory usage fix. 
--- include/http_data.h | 46 +++++++++++++++++++++++++++++++++ include/http_server.h | 60 +++++-------------------------------------- src/http_server.cpp | 10 +++++--- 3 files changed, 59 insertions(+), 57 deletions(-) diff --git a/include/http_data.h b/include/http_data.h index c435e8d7..d96dee69 100644 --- a/include/http_data.h +++ b/include/http_data.h @@ -205,6 +205,49 @@ public: virtual ~req_state_t() = default; }; +struct stream_response_state_t { +private: + + h2o_req_t* req = nullptr; + +public: + + bool is_req_early_exit = false; + + bool is_res_start = true; + h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS; + + std::string res_body; + h2o_iovec_t res_buff; + + std::string res_content_type; + int status = 0; + const char* reason = nullptr; + + h2o_generator_t* generator = nullptr; + + void set_response(uint32_t status_code, const std::string& content_type, std::string& body) { + std::string().swap(res_body); + res_body = std::move(body); + res_buff = h2o_iovec_t{.base = res_body.data(), .len = res_body.size()}; + + if(is_res_start) { + res_content_type = std::move(content_type); + status = (int)status_code; + reason = http_res::get_status_reason(status_code); + is_res_start = false; + } + } + + void set_req(h2o_req_t* _req) { + req = _req; + } + + h2o_req_t* get_req() { + return req; + } +}; + struct http_req { static constexpr const char* AUTH_HEADER = "x-typesense-api-key"; static constexpr const char* USER_HEADER = "x-typesense-user-id"; @@ -248,6 +291,9 @@ struct http_req { std::atomic is_diposed; std::string client_ip = "0.0.0.0"; + // stores http lib related datastructures to avoid race conditions between indexing and http write threads + stream_response_state_t res_state; + http_req(): _req(nullptr), route_hash(1), first_chunk_aggregate(true), last_chunk_aggregate(false), chunk_len(0), body_index(0), data(nullptr), ready(false), log_index(0), diff --git a/include/http_server.h b/include/http_server.h index 1ced92e7..b83ae775 100644 --- 
a/include/http_server.h +++ b/include/http_server.h @@ -40,50 +40,6 @@ struct h2o_custom_generator_t { } }; -struct stream_response_state_t { -private: - - h2o_req_t* req = nullptr; - -public: - - bool is_req_early_exit = false; - - bool is_res_start = true; - h2o_send_state_t send_state = H2O_SEND_STATE_IN_PROGRESS; - - std::string res_body; - h2o_iovec_t res_buff; - - h2o_iovec_t res_content_type{}; - int status = 0; - const char* reason = nullptr; - - h2o_generator_t* generator = nullptr; - - explicit stream_response_state_t(h2o_req_t* _req): req(_req) { - if(req != nullptr) { - is_res_start = (req->res.status == 0); - } - } - - void set_response(uint32_t status_code, const std::string& content_type, std::string& body) { - std::string().swap(res_body); - res_body = std::move(body); - res_buff = h2o_iovec_t{.base = res_body.data(), .len = res_body.size()}; - - if(is_res_start) { - res_content_type = h2o_strdup(&req->pool, content_type.c_str(), SIZE_MAX); - status = status_code; - reason = http_res::get_status_reason(status_code); - } - } - - h2o_req_t* get_req() { - return req; - } -}; - struct deferred_req_res_t { const std::shared_ptr req; const std::shared_ptr res; @@ -110,13 +66,9 @@ public: // used to manage lifecycle of async actions const bool destroy_after_use; - // stores http lib related datastructures to avoid race conditions between indexing and http write threads - stream_response_state_t res_state; - async_req_res_t(const std::shared_ptr& h_req, const std::shared_ptr& h_res, const bool destroy_after_use) : - req(h_req), res(h_res), destroy_after_use(destroy_after_use), - res_state((std::shared_lock(res->mres), h_req->is_diposed ? 
nullptr : h_req->_req)) { + req(h_req), res(h_res), destroy_after_use(destroy_after_use) { std::shared_lock lk(res->mres); @@ -124,12 +76,10 @@ public: return; } - // ***IMPORTANT*** - // We limit writing to fields of `res_state.req` to prevent race conditions with http thread - // Check `HttpServer::stream_response()` for overlapping writes. - h2o_custom_generator_t* res_generator = static_cast(res->generator.load()); + auto& res_state = req->res_state; + res_state.set_req(h_req->is_diposed ? nullptr : h_req->_req); res_state.is_req_early_exit = (res_generator->rpath->async_req && res->final && !req->last_chunk_aggregate); res_state.send_state = res->final ? H2O_SEND_STATE_FINAL : H2O_SEND_STATE_IN_PROGRESS; res_state.generator = (res_generator == nullptr) ? nullptr : &res_generator->h2o_generator; @@ -147,6 +97,10 @@ public: void res_notify() { return res->notify(); } + + stream_response_state_t& get_res_state() { + return req->res_state; + } }; struct defer_processing_t { diff --git a/src/http_server.cpp b/src/http_server.cpp index 813ee450..661a83ba 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -808,9 +808,11 @@ void HttpServer::stream_response(stream_response_state_t& state) { h2o_req_t* req = state.get_req(); - if(state.is_res_start) { + bool start_of_res = (req->res.status == 0); + + if(start_of_res) { h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, NULL, - state.res_content_type.base, state.res_content_type.len); + state.res_content_type.data(), state.res_content_type.size()); req->res.status = state.status; req->res.reason = state.reason; } @@ -829,7 +831,7 @@ void HttpServer::stream_response(stream_response_state_t& state) { return ; } - if (state.is_res_start) { + if (start_of_res) { /*LOG(INFO) << "h2o_start_response, content_type=" << state.res_content_type << ",response.status_code=" << state.res_status_code;*/ h2o_start_response(req, state.generator); @@ -969,7 +971,7 @@ bool 
HttpServer::on_stream_response_message(void *data) { // NOTE: access to `req` and `res` objects must be synchronized and wrapped by `req_res` if(req_res->is_alive()) { - stream_response(req_res->res_state); + stream_response(req_res->get_res_state()); } else { // serialized request or generator has been disposed (underlying request is probably dead) req_res->req_notify(); From 437732c89d6ff00ad90062db7d26077c6c079970 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 6 Jul 2023 10:31:41 +0530 Subject: [PATCH 03/11] Remove stray logging in test. --- test/tokenizer_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tokenizer_test.cpp b/test/tokenizer_test.cpp index ff416f7d..64a34066 100644 --- a/test/tokenizer_test.cpp +++ b/test/tokenizer_test.cpp @@ -238,7 +238,7 @@ TEST(TokenizerTest, ShouldTokenizeLocaleText) { // window used to locate the starting offset for snippet on the text while(tokenizer.next(raw_token, raw_token_index, tok_start, tok_end)) { - LOG(INFO) << "tok_start: " << tok_start; + //LOG(INFO) << "tok_start: " << tok_start; } return ; From 2d4221c1f7f09b9d2e0e952ec8a9dbce7afcbca7 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 6 Jul 2023 20:33:48 +0530 Subject: [PATCH 04/11] Check for presence of key before getting from old doc during update. --- src/index.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/index.cpp b/src/index.cpp index 4e6ac914..155958ce 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6268,7 +6268,9 @@ void Index::get_doc_changes(const index_operation_t op, const tsl::htrie_map Date: Thu, 6 Jul 2023 21:06:18 +0530 Subject: [PATCH 05/11] Add test for null value + update of missing optional field. 
--- test/collection_specific_more_test.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index f53ce1a7..e3f82f69 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -1311,6 +1311,21 @@ TEST_F(CollectionSpecificMoreTest, UpdateArrayWithNullValue) { auto results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(0, results["found"].get()); + // update document with no value (optional field) with a null value + auto doc3 = R"({ + "id": "2" + })"_json; + + ASSERT_TRUE(coll1->add(doc3.dump(), CREATE).ok()); + results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["found"].get()); + + doc_update = R"({ + "id": "2", + "tags": null + })"_json; + ASSERT_TRUE(coll1->add(doc_update.dump(), UPDATE).ok()); + // via upsert doc_update = R"({ From 8e1f6caaf15b6a596a21928b03c66cf5094b3ad3 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 7 Jul 2023 12:26:54 +0530 Subject: [PATCH 06/11] Add upsert + get APIs for analytics rules. 
--- include/analytics_manager.h | 15 ++++-- include/core_api.h | 4 ++ src/analytics_manager.cpp | 43 +++++++++++---- src/collection_manager.cpp | 22 ++++---- src/core_api.cpp | 52 ++++++++++++++---- src/main/typesense_server.cpp | 2 + test/analytics_manager_test.cpp | 95 ++++++++++++++++++++++++++++++++- 7 files changed, 196 insertions(+), 37 deletions(-) diff --git a/include/analytics_manager.h b/include/analytics_manager.h index 353085b6..4207f7cf 100644 --- a/include/analytics_manager.h +++ b/include/analytics_manager.h @@ -24,10 +24,11 @@ private: void to_json(nlohmann::json& obj) const { obj["name"] = name; + obj["type"] = POPULAR_QUERIES_TYPE; obj["params"] = nlohmann::json::object(); - obj["params"]["suggestion_collection"] = suggestion_collection; - obj["params"]["query_collections"] = query_collections; obj["params"]["limit"] = limit; + obj["params"]["source"]["collections"] = query_collections; + obj["params"]["destination"]["collection"] = suggestion_collection; } }; @@ -48,7 +49,9 @@ private: Option remove_popular_queries_index(const std::string& name); - Option create_popular_queries_index(nlohmann::json &payload, bool write_to_disk); + Option create_popular_queries_index(nlohmann::json &payload, + bool upsert, + bool write_to_disk); public: @@ -69,12 +72,14 @@ public: Option list_rules(); - Option create_rule(nlohmann::json& payload, bool write_to_disk = true); + Option get_rule(const std::string& name); + + Option create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk); Option remove_rule(const std::string& name); void add_suggestion(const std::string& query_collection, - std::string& query, const bool live_query, const std::string& user_id); + std::string& query, bool live_query, const std::string& user_id); void stop(); diff --git a/include/core_api.h b/include/core_api.h index 780d3d1b..a13b1db8 100644 --- a/include/core_api.h +++ b/include/core_api.h @@ -147,8 +147,12 @@ bool post_create_event(const std::shared_ptr& req, const 
std::shared_p bool get_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); +bool get_analytics_rule(const std::shared_ptr& req, const std::shared_ptr& res); + bool post_create_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); +bool put_upsert_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); + bool del_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res); // Misc helpers diff --git a/src/analytics_manager.cpp b/src/analytics_manager.cpp index 4851dd46..dcb6a64a 100644 --- a/src/analytics_manager.cpp +++ b/src/analytics_manager.cpp @@ -5,7 +5,7 @@ #include "http_client.h" #include "collection_manager.h" -Option AnalyticsManager::create_rule(nlohmann::json& payload, bool write_to_disk) { +Option AnalyticsManager::create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk) { /* Sample payload: @@ -37,16 +37,23 @@ Option AnalyticsManager::create_rule(nlohmann::json& payload, bool write_t } if(payload["type"] == POPULAR_QUERIES_TYPE) { - return create_popular_queries_index(payload, write_to_disk); + return create_popular_queries_index(payload, upsert, write_to_disk); } return Option(400, "Invalid type."); } -Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool write_to_disk) { +Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool upsert, bool write_to_disk) { // params and name are validated upstream - const auto& params = payload["params"]; const std::string& suggestion_config_name = payload["name"].get(); + bool already_exists = suggestion_configs.find(suggestion_config_name) != suggestion_configs.end(); + + if(!upsert && already_exists) { + return Option(400, "There's already another configuration with the name `" + + suggestion_config_name + "`."); + } + + const auto& params = payload["params"]; if(!params.contains("source") || !params["source"].is_object()) { return Option(400, "Bad or missing source."); @@ 
-56,18 +63,12 @@ Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payl return Option(400, "Bad or missing destination."); } - size_t limit = 1000; if(params.contains("limit") && params["limit"].is_number_integer()) { limit = params["limit"].get(); } - if(suggestion_configs.find(suggestion_config_name) != suggestion_configs.end()) { - return Option(400, "There's already another configuration with the name `" + - suggestion_config_name + "`."); - } - if(!params["source"].contains("collections") || !params["source"]["collections"].is_array()) { return Option(400, "Must contain a valid list of source collections."); } @@ -93,6 +94,14 @@ Option AnalyticsManager::create_popular_queries_index(nlohmann::json &payl std::unique_lock lock(mutex); + if(already_exists) { + // remove the previous configuration with same name (upsert) + Option remove_op = remove_popular_queries_index(suggestion_config_name); + if(!remove_op.ok()) { + return Option(500, "Error erasing the existing configuration."); + } + } + suggestion_configs.emplace(suggestion_config_name, suggestion_config); for(const auto& query_coll: suggestion_config.query_collections) { @@ -130,13 +139,25 @@ Option AnalyticsManager::list_rules() { for(const auto& suggestion_config: suggestion_configs) { nlohmann::json rule; suggestion_config.second.to_json(rule); - rule["type"] = POPULAR_QUERIES_TYPE; rules["rules"].push_back(rule); } return Option(rules); } +Option AnalyticsManager::get_rule(const string& name) { + nlohmann::json rule; + std::unique_lock lock(mutex); + + auto suggestion_config_it = suggestion_configs.find(name); + if(suggestion_config_it == suggestion_configs.end()) { + return Option(404, "Rule not found."); + } + + suggestion_config_it->second.to_json(rule); + return Option(rule); +} + Option AnalyticsManager::remove_rule(const string &name) { std::unique_lock lock(mutex); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f8d32b4e..6c281595 100644 ---
a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -285,6 +285,17 @@ Option CollectionManager::load(const size_t collection_batch_size, const s iter->Next(); } + // restore query suggestions configs + std::vector analytics_config_jsons; + store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX, + std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`", + analytics_config_jsons); + + for(const auto& analytics_config_json: analytics_config_jsons) { + nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json); + AnalyticsManager::get_instance().create_rule(analytics_config, false, false); + } + delete iter; LOG(INFO) << "Loaded " << num_collections << " collection(s)."; @@ -1312,17 +1323,6 @@ Option CollectionManager::load_collection(const nlohmann::json &collection collection->add_synonym(collection_synonym, false); } - // restore query suggestions configs - std::vector analytics_config_jsons; - cm.store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX, - std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`", - analytics_config_jsons); - - for(const auto& analytics_config_json: analytics_config_jsons) { - nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json); - AnalyticsManager::get_instance().create_rule(analytics_config, false); - } - // Fetch records from the store and re-create memory index const std::string seq_id_prefix = collection->get_seq_id_collection_prefix(); std::string upper_bound_key = collection->get_seq_id_collection_prefix() + "`"; // cannot inline this diff --git a/src/core_api.cpp b/src/core_api.cpp index 5d04ad07..0c9cd0e5 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -2078,13 +2078,25 @@ bool post_create_event(const std::shared_ptr& req, const std::shared_p bool get_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { auto rules_op = AnalyticsManager::get_instance().list_rules(); - if(rules_op.ok()) { - res->set_200(rules_op.get().dump()); - 
return true; + if(!rules_op.ok()) { + res->set(rules_op.code(), rules_op.error()); + return false; } - res->set(rules_op.code(), rules_op.error()); - return false; + res->set_200(rules_op.get().dump()); + return true; +} + +bool get_analytics_rule(const std::shared_ptr& req, const std::shared_ptr& res) { + auto rules_op = AnalyticsManager::get_instance().get_rule(req->params["name"]); + + if(!rules_op.ok()) { + res->set(rules_op.code(), rules_op.error()); + return false; + } + + res->set_200(rules_op.get().dump()); + return true; } bool post_create_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { @@ -2098,7 +2110,7 @@ bool post_create_analytics_rules(const std::shared_ptr& req, const std return false; } - auto op = AnalyticsManager::get_instance().create_rule(req_json); + auto op = AnalyticsManager::get_instance().create_rule(req_json, false, true); if(!op.ok()) { res->set(op.code(), op.error()); @@ -2109,6 +2121,29 @@ bool post_create_analytics_rules(const std::shared_ptr& req, const std return true; } +bool put_upsert_analytics_rules(const std::shared_ptr &req, const std::shared_ptr &res) { + nlohmann::json req_json; + + try { + req_json = nlohmann::json::parse(req->body); + } catch(const std::exception& e) { + LOG(ERROR) << "JSON error: " << e.what(); + res->set_400("Bad JSON."); + return false; + } + + req_json["name"] = req->params["name"]; + auto op = AnalyticsManager::get_instance().create_rule(req_json, true, true); + + if(!op.ok()) { + res->set(op.code(), op.error()); + return false; + } + + res->set_200(req_json.dump()); + return true; +} + bool del_analytics_rules(const std::shared_ptr& req, const std::shared_ptr& res) { auto op = AnalyticsManager::get_instance().remove_rule(req->params["name"]); if(!op.ok()) { @@ -2116,11 +2151,10 @@ bool del_analytics_rules(const std::shared_ptr& req, const std::shared return false; } - res->set_200(R"({"ok": true)"); + res->set_200(R"({"ok": true})"); return true; } - bool post_proxy(const 
std::shared_ptr& req, const std::shared_ptr& res) { HttpProxy& proxy = HttpProxy::get_instance(); @@ -2180,4 +2214,4 @@ bool post_proxy(const std::shared_ptr& req, const std::shared_ptrset_200(response.body); return true; -} \ No newline at end of file +} diff --git a/src/main/typesense_server.cpp b/src/main/typesense_server.cpp index d4df4f2a..8a47e1bb 100644 --- a/src/main/typesense_server.cpp +++ b/src/main/typesense_server.cpp @@ -70,7 +70,9 @@ void master_server_routes() { // analytics server->get("/analytics/rules", get_analytics_rules); + server->get("/analytics/rules/:name", get_analytics_rule); server->post("/analytics/rules", post_create_analytics_rules); + server->put("/analytics/rules/:name", put_upsert_analytics_rules); server->del("/analytics/rules/:name", del_analytics_rules); server->post("/analytics/events", post_create_event); diff --git a/test/analytics_manager_test.cpp b/test/analytics_manager_test.cpp index 0030f221..a86bf809 100644 --- a/test/analytics_manager_test.cpp +++ b/test/analytics_manager_test.cpp @@ -78,7 +78,7 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) { } })"_json; - auto create_op = analyticsManager.create_rule(analytics_rule); + auto create_op = analyticsManager.create_rule(analytics_rule, false, true); ASSERT_TRUE(create_op.ok()); std::string q = "foobar"; @@ -88,4 +88,97 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) { auto userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"]; ASSERT_EQ(1, userQueries.size()); ASSERT_EQ("foobar", userQueries[0].query); + + // add another query which is more popular + q = "buzzfoo"; + analyticsManager.add_suggestion("titles", q, true, "1"); + analyticsManager.add_suggestion("titles", q, true, "2"); + analyticsManager.add_suggestion("titles", q, true, "3"); + + popularQueries = analyticsManager.get_popular_queries(); + userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"]; + ASSERT_EQ(2, userQueries.size()); + ASSERT_EQ("foobar", 
userQueries[0].query); + ASSERT_EQ("buzzfoo", userQueries[1].query); } + +TEST_F(AnalyticsManagerTest, GetAndDeleteSuggestions) { + nlohmann::json analytics_rule = R"({ + "name": "top_search_queries", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queries" + } + } + })"_json; + + auto create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_TRUE(create_op.ok()); + + analytics_rule = R"({ + "name": "top_search_queries2", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queries" + } + } + })"_json; + + create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_TRUE(create_op.ok()); + + auto rules = analyticsManager.list_rules().get()["rules"]; + ASSERT_EQ(2, rules.size()); + + ASSERT_TRUE(analyticsManager.get_rule("top_search_queries").ok()); + ASSERT_TRUE(analyticsManager.get_rule("top_search_queries2").ok()); + + auto missing_rule_op = analyticsManager.get_rule("top_search_queriesX"); + ASSERT_FALSE(missing_rule_op.ok()); + ASSERT_EQ(404, missing_rule_op.code()); + ASSERT_EQ("Rule not found.", missing_rule_op.error()); + + // upsert rule that already exists + analytics_rule = R"({ + "name": "top_search_queries2", + "type": "popular_queries", + "params": { + "limit": 100, + "source": { + "collections": ["titles"] + }, + "destination": { + "collection": "top_queriesUpdated" + } + } + })"_json; + create_op = analyticsManager.create_rule(analytics_rule, true, true); + ASSERT_TRUE(create_op.ok()); + auto existing_rule = analyticsManager.get_rule("top_search_queries2").get(); + ASSERT_EQ("top_queriesUpdated", existing_rule["params"]["destination"]["collection"].get()); + + // reject when upsert is not enabled + create_op = analyticsManager.create_rule(analytics_rule, false, true); + ASSERT_FALSE(create_op.ok()); + ASSERT_EQ("There's 
already another configuration with the name `top_search_queries2`.", create_op.error()); + + // try deleting both rules + analyticsManager.remove_rule("top_search_queries"); + analyticsManager.remove_rule("top_search_queries2"); + + missing_rule_op = analyticsManager.get_rule("top_search_queries"); + ASSERT_FALSE(missing_rule_op.ok()); + missing_rule_op = analyticsManager.get_rule("top_search_queries2"); + ASSERT_FALSE(missing_rule_op.ok()); +} + From 29830e1d58825175fab4548e2faeb93495241e86 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 13:41:38 +0300 Subject: [PATCH 07/11] Create virtual functions to generate error JSONs from responses --- include/text_embedder_remote.h | 4 + src/http_proxy.cpp | 4 +- src/text_embedder_remote.cpp | 242 +++++++++++++-------------------- 3 files changed, 100 insertions(+), 150 deletions(-) diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 652c0b61..46c32778 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -27,6 +27,7 @@ class RemoteEmbedder { static Option validate_string_properties(const nlohmann::json& model_config, const std::vector& properties); static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; + virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; public: virtual embedding_res_t Embed(const std::string& text) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; @@ -44,6 +45,7 @@ class OpenAIEmbedder : public RemoteEmbedder { std::string openai_model_path; static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models"; static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings"; + nlohmann::json get_error_json(const nlohmann::json& 
req_body, long res_code, const std::string& res_body) override; public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); @@ -59,6 +61,7 @@ class GoogleEmbedder : public RemoteEmbedder { inline static constexpr short GOOGLE_EMBEDDING_DIM = 768; inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key="; std::string google_api_key; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); @@ -84,6 +87,7 @@ class GCPEmbedder : public RemoteEmbedder { static std::string get_gcp_embedding_url(const std::string& project_id, const std::string& model_name) { return GCP_EMBEDDING_BASE_URL + project_id + GCP_EMBEDDING_PATH + model_name + GCP_EMBEDDING_PREDICT; } + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index 4ed4db5a..518cd45b 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -44,7 +44,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth auto res = call(url, method, body, headers); - if(res.status_code == 408){ + if(res.status_code >= 500 || res.status_code == 408){ // retry res = call(url, method, body, headers); } @@ -57,7 +57,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth // add to cache - if(res.status_code != 408){ + if(res.status_code 
== 200){ cache.insert(key, res); } diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 4e844871..a54315e1 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -144,41 +144,13 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); if (res_code != 200) { - nlohmann::json json_res; - try { - json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from OpenAI API."; - - } - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); - } - if(res_code == 408) { - embedding_res["error"] = "OpenAI API timeout."; - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } try { embedding_res_t embedding_res = embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); return embedding_res; } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["error"] = "Malformed response from OpenAI API."; - - return embedding_res_t(500, embedding_res); + return 
embedding_res_t(500, get_error_json(req_body, res_code, res)); } } @@ -196,46 +168,19 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector outputs; - - nlohmann::json json_res; - try { - json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from OpenAI API."; - } - LOG(INFO) << "OpenAI API error: " << json_res.dump(); - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["request"]["body"]["input"] = std::vector{inputs[0]}; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); - } - if(res_code == 408) { - embedding_res["error"] = "OpenAI API timeout."; - } - + nlohmann::json embedding_res = get_error_json(req_body, res_code, res); for(size_t i = 0; i < inputs.size(); i++) { embedding_res["request"]["body"]["input"][0] = inputs[i]; outputs.push_back(embedding_res_t(res_code, embedding_res)); } return outputs; } + nlohmann::json res_json; try { res_json = nlohmann::json::parse(res); } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["request"]["body"]["input"] = std::vector{inputs[0]}; - embedding_res["error"] = "Malformed response from OpenAI API"; + nlohmann::json embedding_res = get_error_json(req_body, res_code, res); std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { 
embedding_res["request"]["body"]["input"][0] = inputs[i]; @@ -252,6 +197,36 @@ std::vector OpenAIEmbedder::batch_embed(const std::vector 0 && embedding_res["request"]["body"]["input"].get>().size() > 1) { + auto vec = embedding_res["request"]["body"]["input"].get>(); + vec.resize(1); + embedding_res["request"]["body"]["input"] = vec; + } + if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { + embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get(); + } + if(res_code == 408) { + embedding_res["error"] = "OpenAI API timeout."; + } + + return embedding_res; +} + + GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : google_api_key(google_api_key) { } @@ -321,37 +296,13 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) { auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers); if(res_code != 200) { - nlohmann::json json_res; - try { - nlohmann::json json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from Google API." 
- } - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get(); - } - if(res_code == 408) { - embedding_res["error"] = "Google API timeout."; - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } + try { return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get>()); } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING; - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["error"] = "Malformed response from Google API."; - return embedding_res_t(500, embedding_res); + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } } @@ -366,6 +317,30 @@ std::vector GoogleEmbedder::batch_embed(const std::vector(); + } + if(res_code == 408) { + embedding_res["error"] = "Google API timeout."; + } + + return embedding_res; +} + GCPEmbedder::GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) : @@ -473,46 +448,17 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { } if(res_code != 200) { - nlohmann::json json_res; - try { - json_res = nlohmann::json::parse(res); - } catch (const std::exception& e) { - json_res = nlohmann::json::object(); - json_res["error"] = "Malformed response from GCP 
API."; - } - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["response"] = json_res; - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) { - embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get(); - } else { - embedding_res["error"] = "Malformed response from GCP API."; - } - - if(res_code == 408) { - embedding_res["error"] = "GCP API timeout."; - } - return embedding_res_t(res_code, embedding_res); + return embedding_res_t(res_code, get_error_json(req_body, res_code, res)); } nlohmann::json res_json; try { res_json = nlohmann::json::parse(res); } catch (const std::exception& e) { - nlohmann::json embedding_res = nlohmann::json::object(); - embedding_res["request"] = nlohmann::json::object(); - embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name); - embedding_res["request"]["method"] = "POST"; - embedding_res["request"]["body"] = req_body; - embedding_res["error"] = "Malformed response from GCP API."; - return embedding_res_t(500, embedding_res); + return embedding_res_t(500, get_error_json(req_body, res_code, res)); } return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get>()); } - std::vector GCPEmbedder::batch_embed(const std::vector& inputs) { // GCP API has a limit of 5 instances per request if(inputs.size() > 5) { @@ -556,28 +502,7 @@ std::vector GCPEmbedder::batch_embed(const std::vector(); - } else { - embedding_res["error"] = "Malformed response from GCP API."; - } - - if(res_code == 408) { - embedding_res["error"] = "GCP API timeout."; - } + auto embedding_res = get_error_json(req_body, res_code, res); std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { 
outputs.push_back(embedding_res_t(res_code, embedding_res)); @@ -588,12 +513,7 @@ std::vector GCPEmbedder::batch_embed(const std::vector outputs; for(size_t i = 0; i < inputs.size(); i++) { outputs.push_back(embedding_res_t(400, embedding_res)); @@ -608,6 +528,34 @@ std::vector GCPEmbedder::batch_embed(const std::vector(); + } else { + embedding_res["error"] = "Malformed response from GCP API."; + } + + if(res_code == 408) { + embedding_res["error"] = "GCP API timeout."; + } + + return embedding_res; +} + Option GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) { std::unordered_map headers; headers["Content-Type"] = "application/x-www-form-urlencoded"; @@ -643,5 +591,3 @@ Option GCPEmbedder::generate_access_token(const std::string& refres return Option(access_token); } - - From 10c0070fecd635b86f362aabc726540e2538bbc7 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 14:45:53 +0300 Subject: [PATCH 08/11] Add test for catching partial JSON --- include/text_embedder_remote.h | 9 ++++---- test/collection_test.cpp | 41 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 46c32778..18892afa 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -27,8 +27,8 @@ class RemoteEmbedder { static Option validate_string_properties(const nlohmann::json& model_config, const std::vector& properties); static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map& headers, const std::unordered_map& req_headers); static inline ReplicationState* raft_server = nullptr; - virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; public: + virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long 
res_code, const std::string& res_body) = 0; virtual embedding_res_t Embed(const std::string& text) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { @@ -45,12 +45,12 @@ class OpenAIEmbedder : public RemoteEmbedder { std::string openai_model_path; static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models"; static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings"; - nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -61,12 +61,13 @@ class GoogleEmbedder : public RemoteEmbedder { inline static constexpr short GOOGLE_EMBEDDING_DIM = 768; inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key="; std::string google_api_key; - nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; + public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -87,13 +88,13 @@ class GCPEmbedder : public RemoteEmbedder { static std::string get_gcp_embedding_url(const std::string& project_id, const std::string& model_name) { 
return GCP_EMBEDDING_BASE_URL + project_id + GCP_EMBEDDING_PATH + model_name + GCP_EMBEDDING_PREDICT; } - nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; public: GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); embedding_res_t Embed(const std::string& text) override; std::vector batch_embed(const std::vector& inputs) override; + nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index e49e89f6..9e054d0a 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -5176,4 +5176,45 @@ TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) { ASSERT_FALSE(get_op.get()["embedding"].is_null()); ASSERT_EQ(384, get_op.get()["embedding"].size()); +} + + +TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) { + std::string partial_json = R"({ + "results": [ + { + "embedding": [ + 0.0, + 0.0, + 0.0 + ], + "text": "butter" + }, + { + "embedding": [ + 0.0, + 0.0, + 0.0 + ], + "text": "butterball" + }, + { + "embedding": [ + 0.0, + 0.0)"; + + nlohmann::json req_body = R"({ + "inputs": [ + "butter", + "butterball", + "butterfly" + ] + })"_json; + + OpenAIEmbedder embedder("", ""); + + auto res = embedder.get_error_json(req_body, 200, partial_json); + + ASSERT_EQ(res["response"]["error"], "Malformed response from OpenAI API."); + ASSERT_EQ(res["request"]["body"], req_body); } \ No newline at end of file From 6f7efa5d729914c03a0e9c6e61ed465c651a8e00 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 17:01:29 +0300 Subject: [PATCH 09/11] Add timeout and num retries as search parameter for remote 
embedding --- include/collection.h | 4 +++- include/http_proxy.h | 6 ++++-- include/text_embedder.h | 2 +- include/text_embedder_remote.h | 8 +++---- src/collection.cpp | 6 ++++-- src/collection_manager.cpp | 13 +++++++++++- src/http_proxy.cpp | 39 ++++++++++++++++++++++++---------- src/text_embedder.cpp | 4 ++-- src/text_embedder_remote.cpp | 14 ++++++++---- test/core_api_utils_test.cpp | 23 ++++++++++++++++++++ 10 files changed, 91 insertions(+), 28 deletions(-) diff --git a/include/collection.h b/include/collection.h index 2f79bc68..ac987606 100644 --- a/include/collection.h +++ b/include/collection.h @@ -464,7 +464,9 @@ public: const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, const size_t page_offset = UINT32_MAX, - const size_t vector_query_hits = 250) const; + const size_t vector_query_hits = 250, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_retry = 2) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/http_proxy.h b/include/http_proxy.h index 851daa91..13519ef2 100644 --- a/include/http_proxy.h +++ b/include/http_proxy.h @@ -32,11 +32,13 @@ class HttpProxy { void operator=(const HttpProxy&) = delete; HttpProxy(HttpProxy&&) = delete; void operator=(HttpProxy&&) = delete; - http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers); + http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers); private: HttpProxy(); ~HttpProxy() = default; - http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map& headers = {}); + http_proxy_res_t call(const std::string& url, const std::string& method, + const std::string& body = "", const std::unordered_map& headers = {}, + const size_t timeout_ms = 30000); // lru cache 
for http requests diff --git a/include/text_embedder.h b/include/text_embedder.h index cff7c7c7..e1f73287 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text); + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_max_retries = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index 18892afa..ffe8d6a0 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -29,7 +29,7 @@ class RemoteEmbedder { static inline ReplicationState* raft_server = nullptr; public: virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; - virtual embedding_res_t Embed(const std::string& text) = 0; + virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; @@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) 
override; }; @@ -65,7 +65,7 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -92,7 +92,7 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/src/collection.cpp b/src/collection.cpp index c55bee7d..b9bb6a3b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1107,7 +1107,9 @@ Option Collection::search(std::string raw_query, const size_t facet_sample_percent, const size_t facet_sample_threshold, const size_t page_offset, - const size_t vector_query_hits) const { + const size_t vector_query_hits, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_retry) const { std::shared_lock lock(mutex); @@ -1236,7 +1238,7 @@ Option Collection::search(std::string raw_query, } std::string embed_query = 
embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query); + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retry); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 5323e1b0..91a348b7 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -671,6 +671,9 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY = "vector_query"; const char *VECTOR_QUERY_HITS = "vector_query_hits"; + const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; + const char* REMOTE_EMBEDDING_NUM_RETRY = "remote_embedding_num_retry"; + const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -824,6 +827,9 @@ Option CollectionManager::do_search(std::map& re text_match_type_t match_type = max_score; size_t vector_query_hits = 250; + size_t remote_embedding_timeout_ms = 30000; + size_t remote_embedding_num_retry = 2; + size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -850,6 +856,8 @@ Option CollectionManager::do_search(std::map& re {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, &vector_query_hits}, + {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, + {REMOTE_EMBEDDING_NUM_RETRY, &remote_embedding_num_retry}, }; std::unordered_map str_values = { @@ -1070,7 +1078,10 @@ Option CollectionManager::do_search(std::map& re match_type, facet_sample_percent, facet_sample_threshold, - offset + offset, + vector_query_hits, + remote_embedding_timeout_ms, + remote_embedding_num_retry ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index 518cd45b..23acdef5 100644 --- 
a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -8,17 +8,19 @@ HttpProxy::HttpProxy() : cache(30s){ } -http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, + const std::string& body, const std::unordered_map& headers, + const size_t timeout_ms) { HttpClient& client = HttpClient::get_instance(); http_proxy_res_t res; if(method == "GET") { - res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms); } else if(method == "POST") { - res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000); + res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms); } else if(method == "PUT") { - res.status_code = client.put_response(url, body, res.body, res.headers, 20 * 1000); + res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms); } else if(method == "DELETE") { - res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000); + res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms); } else { res.status_code = 400; nlohmann::json j; @@ -29,11 +31,25 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth } -http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map& headers) { +http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map& headers) { // check if url is in cache uint64_t key = StringUtils::hash_wy(url.c_str(), url.size()); key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size())); key = StringUtils::hash_combine(key, 
StringUtils::hash_wy(body.c_str(), body.size())); + + size_t timeout_ms = 30000; + size_t num_retry = 2; + + if(headers.find("timeout_ms") != headers.end()){ + timeout_ms = std::stoul(headers.at("timeout_ms")); + headers.erase("timeout_ms"); + } + + if(headers.find("num_retry") != headers.end()){ + num_retry = std::stoul(headers.at("num_retry")); + headers.erase("num_retry"); + } + for(auto& header : headers){ key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size())); key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size())); @@ -42,11 +58,13 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth return cache[key]; } - auto res = call(url, method, body, headers); + http_proxy_res_t res; + for(size_t i = 0; i < num_retry; i++){ + res = call(url, method, body, headers, timeout_ms); - if(res.status_code >= 500 || res.status_code == 408){ - // retry - res = call(url, method, body, headers); + if(res.status_code != 408 && res.status_code < 500){ + break; + } } if(res.status_code == 408){ @@ -54,7 +72,6 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth j["message"] = "Server error on remote server. 
Please try again later."; res.body = j.dump(); } - // add to cache if(res.status_code == 200){ diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index b2a67607..0e1e9385 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text); + return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedder_max_retries); } else { // Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index a54315e1..0155a1ac 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -132,14 +132,16 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_retry"] = std::to_string(remote_embedder_num_retry); std::string res; nlohmann::json req_body; - req_body["input"] = text; + req_body["input"] = std::vector{text}; // remove "openai/" prefix req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path); auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); @@ -285,10 +287,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { 
std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_retry"] = std::to_string(remote_embedder_num_retry); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -418,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -427,6 +431,8 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) { std::unordered_map headers; headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; + headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); + headers["num_retry"] = std::to_string(remote_embedder_num_retry); std::map res_headers; std::string res; diff --git a/test/core_api_utils_test.cpp b/test/core_api_utils_test.cpp index 95af7546..57e88db3 100644 --- a/test/core_api_utils_test.cpp +++ b/test/core_api_utils_test.cpp @@ -1116,4 +1116,27 @@ TEST_F(CoreAPIUtilsTest, TestProxyInvalid) { ASSERT_EQ(400, resp->status_code); ASSERT_EQ("Headers must be a JSON object.", nlohmann::json::parse(resp->body)["message"]); +} + + + +TEST_F(CoreAPIUtilsTest, TestProxyTimeout) { + nlohmann::json body; + + auto req = std::make_shared(); + auto resp = std::make_shared(nullptr); + + // test that a 1 ms timeout on a valid url triggers a 408 from the proxy + body["url"] = "https://typesense.org/docs/"; + body["method"] = "GET"; + body["headers"] = nlohmann::json::object(); + body["headers"]["timeout_ms"] = "1"; + body["headers"]["num_retry"] = "1"; + + req->body = body.dump(); + + post_proxy(req, resp); + + ASSERT_EQ(408, resp->status_code); + ASSERT_EQ("Server error on remote
server. Please try again later.", nlohmann::json::parse(resp->body)["message"]); } \ No newline at end of file From 8b1aa13ffe9de938786a28e9dfe8042829cf3105 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 17:11:25 +0300 Subject: [PATCH 10/11] Change param name to 'remote_embedding_num_try' --- include/collection.h | 2 +- include/text_embedder.h | 2 +- src/collection.cpp | 9 +++++++-- src/collection_manager.cpp | 8 ++++---- src/http_proxy.cpp | 10 +++++----- src/text_embedder_remote.cpp | 12 ++++++------ 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/include/collection.h b/include/collection.h index ac987606..dcc3a50b 100644 --- a/include/collection.h +++ b/include/collection.h @@ -466,7 +466,7 @@ public: const size_t page_offset = UINT32_MAX, const size_t vector_query_hits = 250, const size_t remote_embedding_timeout_ms = 30000, - const size_t remote_embedding_num_retry = 2) const; + const size_t remote_embedding_num_try = 2) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/text_embedder.h b/include/text_embedder.h index e1f73287..520c7c1d 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_max_retries = 2); + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_try = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/src/collection.cpp b/src/collection.cpp index b9bb6a3b..c650aaea 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1109,7 +1109,7 @@ Option Collection::search(std::string raw_query, const size_t 
page_offset, const size_t vector_query_hits, const size_t remote_embedding_timeout_ms, - const size_t remote_embedding_num_retry) const { + const size_t remote_embedding_num_try) const { std::shared_lock lock(mutex); @@ -1235,10 +1235,15 @@ Option Collection::search(std::string raw_query, std::string error = "Prefix search is not supported for remote embedders. Please set `prefix=false` as an additional search parameter to disable prefix searching."; return Option(400, error); } + + if(remote_embedding_num_try == 0) { + std::string error = "`remote-embedding-num-try` must be greater than 0."; + return Option(400, error); + } } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; - auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_retry); + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try); if(!embedding_op.success) { if(!embedding_op.error["error"].get().empty()) { return Option(400, embedding_op.error["error"].get()); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index 91a348b7..aa2c26dc 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -672,7 +672,7 @@ Option CollectionManager::do_search(std::map& re const char *VECTOR_QUERY_HITS = "vector_query_hits"; const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; - const char* REMOTE_EMBEDDING_NUM_RETRY = "remote_embedding_num_retry"; + const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try"; const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -828,7 +828,7 @@ Option CollectionManager::do_search(std::map& re size_t vector_query_hits = 250; size_t remote_embedding_timeout_ms = 30000; - size_t remote_embedding_num_retry = 2; + size_t remote_embedding_num_try = 2; size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -857,7 +857,7 @@ Option 
CollectionManager::do_search(std::map& re {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, {VECTOR_QUERY_HITS, &vector_query_hits}, {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, - {REMOTE_EMBEDDING_NUM_RETRY, &remote_embedding_num_retry}, + {REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try}, }; std::unordered_map str_values = { @@ -1081,7 +1081,7 @@ Option CollectionManager::do_search(std::map& re offset, vector_query_hits, remote_embedding_timeout_ms, - remote_embedding_num_retry + remote_embedding_num_try ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/src/http_proxy.cpp b/src/http_proxy.cpp index 23acdef5..3134841a 100644 --- a/src/http_proxy.cpp +++ b/src/http_proxy.cpp @@ -38,16 +38,16 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size())); size_t timeout_ms = 30000; - size_t num_retry = 2; + size_t num_try = 2; if(headers.find("timeout_ms") != headers.end()){ timeout_ms = std::stoul(headers.at("timeout_ms")); headers.erase("timeout_ms"); } - if(headers.find("num_retry") != headers.end()){ - num_retry = std::stoul(headers.at("num_retry")); - headers.erase("num_retry"); + if(headers.find("num_try") != headers.end()){ + num_try = std::stoul(headers.at("num_try")); + headers.erase("num_try"); } for(auto& header : headers){ @@ -59,7 +59,7 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth } http_proxy_res_t res; - for(size_t i = 0; i < num_retry; i++){ + for(size_t i = 0; i < num_try; i++){ res = call(url, method, body, headers, timeout_ms); if(res.status_code != 408 && res.status_code < 500){ diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 0155a1ac..13a73575 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -132,13 +132,13 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } 
-embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_retry"] = std::to_string(remote_embedder_num_retry); + headers["num_try"] = std::to_string(remote_embedder_num_try); std::string res; nlohmann::json req_body; req_body["input"] = std::vector{text}; @@ -287,12 +287,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_retry"] = std::to_string(remote_embedder_num_retry); + headers["num_try"] = std::to_string(remote_embedder_num_try); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -422,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_retry) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ 
-432,7 +432,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_ headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_retry"] = std::to_string(remote_embedder_num_retry); + headers["num_try"] = std::to_string(remote_embedder_num_try); std::map res_headers; std::string res; From a9de01f16184bf2a42ed1ec018daa7e3c5931647 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Jul 2023 18:33:53 +0300 Subject: [PATCH 11/11] Standardize variable names --- include/text_embedder.h | 2 +- include/text_embedder_remote.h | 8 ++++---- src/text_embedder.cpp | 4 ++-- src/text_embedder_remote.cpp | 12 ++++++------ 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/text_embedder.h b/include/text_embedder.h index 520c7c1d..e8f1de57 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -16,7 +16,7 @@ class TextEmbedder { // Constructor for remote models TextEmbedder(const nlohmann::json& model_config); ~TextEmbedder(); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_try = 2); + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2); std::vector batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; bool is_remote() { diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h index ffe8d6a0..b72b26d5 100644 --- a/include/text_embedder_remote.h +++ b/include/text_embedder_remote.h @@ -29,7 +29,7 @@ class RemoteEmbedder { static inline ReplicationState* raft_server = nullptr; public: virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0; - virtual embedding_res_t Embed(const std::string& text, const size_t 
remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) = 0; + virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) = 0; virtual std::vector batch_embed(const std::vector& inputs) = 0; static void init(ReplicationState* rs) { raft_server = rs; @@ -48,7 +48,7 @@ class OpenAIEmbedder : public RemoteEmbedder { public: OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -65,7 +65,7 @@ class GoogleEmbedder : public RemoteEmbedder { public: GoogleEmbedder(const std::string& google_api_key); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; @@ -92,7 +92,7 @@ class GCPEmbedder : public RemoteEmbedder { GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token, const std::string& refresh_token, const 
std::string& client_id, const std::string& client_secret); static Option is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims); - embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedder_num_retry = 2) override; + embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override; std::vector batch_embed(const std::vector& inputs) override; nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override; }; diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 0e1e9385..2055ea77 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -83,9 +83,9 @@ std::vector TextEmbedder::mean_pooling(const std::vectorEmbed(text, remote_embedder_timeout_ms, remote_embedder_max_retries); + return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedding_num_try); } else { // Cannot run same model in parallel, so lock the mutex std::lock_guard lock(mutex_); diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp index 13a73575..f93c6713 100644 --- a/src/text_embedder_remote.cpp +++ b/src/text_embedder_remote.cpp @@ -132,13 +132,13 @@ Option OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { +embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { std::unordered_map headers; std::map res_headers; headers["Authorization"] = "Bearer " + api_key; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_try"] = std::to_string(remote_embedder_num_try); + headers["num_try"] = 
std::to_string(remote_embedding_num_try); std::string res; nlohmann::json req_body; req_body["input"] = std::vector{text}; @@ -287,12 +287,12 @@ Option GoogleEmbedder::is_model_valid(const nlohmann::json& model_config, return Option(true); } -embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { +embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { std::unordered_map headers; std::map res_headers; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_try"] = std::to_string(remote_embedder_num_try); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::string res; nlohmann::json req_body; req_body["text"] = text; @@ -422,7 +422,7 @@ Option GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns return Option(true); } -embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedder_num_try) { +embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) { nlohmann::json req_body; req_body["instances"] = nlohmann::json::array(); nlohmann::json instance; @@ -432,7 +432,7 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_ headers["Authorization"] = "Bearer " + access_token; headers["Content-Type"] = "application/json"; headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms); - headers["num_try"] = std::to_string(remote_embedder_num_try); + headers["num_try"] = std::to_string(remote_embedding_num_try); std::map res_headers; std::string res;