Refactor remote embedding code.

This commit is contained in:
Kishore Nallan 2023-07-11 12:42:06 +05:30
parent 04fd68b1a1
commit 01801f4ea9
8 changed files with 83 additions and 74 deletions

View File

@ -407,7 +407,7 @@ public:
nlohmann::json add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
const index_operation_t& operation=CREATE, const std::string& id="",
const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT,
const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=100);
const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=200);
Option<nlohmann::json> update_matching_filter(const std::string& filter_query,
const std::string & json_str,

View File

@ -24,6 +24,9 @@ struct http_proxy_res_t {
class HttpProxy {
// singleton class for http proxy
public:
// Default timeout, in milliseconds, applied to each proxied call (30000*4 = 120000 ms).
// send() uses it unless the request headers carry a "timeout_ms" entry.
static const size_t default_timeout_ms = 30000*4;
// Default upper bound on the number of attempts send() makes for one request.
// send() uses it unless the request headers carry a "num_try" entry.
static const size_t default_num_try = 2;
static HttpProxy& get_instance() {
static HttpProxy instance;
return instance;
@ -32,13 +35,13 @@ class HttpProxy {
void operator=(const HttpProxy&) = delete;
HttpProxy(HttpProxy&&) = delete;
void operator=(HttpProxy&&) = delete;
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map<std::string, std::string>& headers);
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& req_body, std::unordered_map<std::string, std::string>& req_headers);
private:
HttpProxy();
~HttpProxy() = default;
http_proxy_res_t call(const std::string& url, const std::string& method,
const std::string& body = "", const std::unordered_map<std::string, std::string>& headers = {},
const size_t timeout_ms = 30000);
http_proxy_res_t call(const std::string& url, const std::string& method,
const std::string& req_body = "", const std::unordered_map<std::string, std::string>& req_headers = {},
const size_t timeout_ms = default_timeout_ms);
// lru cache for http requests

View File

@ -25,7 +25,7 @@ struct embedding_res_t {
class RemoteEmbedder {
protected:
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers);
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
static inline ReplicationState* raft_server = nullptr;
public:
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;

View File

@ -3825,7 +3825,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
}
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
fallback_field_type, token_separators, symbols_to_index, true, 100, false);
fallback_field_type, token_separators, symbols_to_index, true, 200, false);
iter_batch.clear();
}

View File

@ -153,13 +153,25 @@ long HttpClient::perform_curl(CURL *curl, std::map<std::string, std::string>& re
if (res != CURLE_OK) {
char* url = nullptr;
curl_easy_getinfo(curl, CURLINFO_EFFECTIVE_URL, &url);
LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res);
// Map the curl failure onto the HTTP-like status code returned to the caller:
// 408 for a timeout, 500 for every other transport error.
long status_code = 0;
if(res == CURLE_OPERATION_TIMEDOUT) {
    // Initialise both out-values so the log line below stays well-defined even when
    // a curl_easy_getinfo() call fails and leaves them untouched.
    double total_time = 0;
    char* http_method = nullptr;
    curl_easy_getinfo(curl, CURLINFO_TOTAL_TIME, &total_time);
    // curl_easy_getinfo() stores its result through the pointer it receives, so it must be
    // given the ADDRESS of http_method; passing the (uninitialised) pointer value itself
    // makes curl write to an arbitrary location.
    curl_easy_getinfo(curl, CURLINFO_EFFECTIVE_METHOD, &http_method);
    // Streaming a null char* is undefined behaviour, hence the empty-string fallback.
    LOG(ERROR) << "CURL timeout. Time taken: " << total_time << ", URL: "
               << (http_method != nullptr ? http_method : "") << " " << url;
    status_code = 408;
} else {
    LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res);
    status_code = 500;
}
curl_easy_cleanup(curl);
curl_slist_free_all(chunk);
if(res == CURLE_OPERATION_TIMEDOUT) {
return 408;
}
return 500;
return status_code;
}
long http_code = 500;

View File

@ -8,17 +8,18 @@ HttpProxy::HttpProxy() : cache(30s){
}
http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method,
const std::string& body, const std::unordered_map<std::string, std::string>& headers,
http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method,
const std::string& req_body,
const std::unordered_map<std::string, std::string>& req_headers,
const size_t timeout_ms) {
HttpClient& client = HttpClient::get_instance();
http_proxy_res_t res;
if(method == "GET") {
res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms);
res.status_code = client.get_response(url, res.body, res.headers, req_headers, timeout_ms);
} else if(method == "POST") {
res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms);
res.status_code = client.post_response(url, req_body, res.body, res.headers, req_headers, timeout_ms);
} else if(method == "PUT") {
res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms);
res.status_code = client.put_response(url, req_body, res.body, res.headers, timeout_ms);
} else if(method == "DELETE") {
res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms);
} else {
@ -31,26 +32,27 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth
}
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map<std::string, std::string>& headers) {
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& req_body,
std::unordered_map<std::string, std::string>& req_headers) {
// check if url is in cache
uint64_t key = StringUtils::hash_wy(url.c_str(), url.size());
key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size()));
key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));
key = StringUtils::hash_combine(key, StringUtils::hash_wy(req_body.c_str(), req_body.size()));
size_t timeout_ms = 30000;
size_t num_try = 2;
size_t timeout_ms = default_timeout_ms;
size_t num_try = default_num_try;
if(headers.find("timeout_ms") != headers.end()){
timeout_ms = std::stoul(headers.at("timeout_ms"));
headers.erase("timeout_ms");
if(req_headers.find("timeout_ms") != req_headers.end()){
timeout_ms = std::stoul(req_headers.at("timeout_ms"));
req_headers.erase("timeout_ms");
}
if(headers.find("num_try") != headers.end()){
num_try = std::stoul(headers.at("num_try"));
headers.erase("num_try");
if(req_headers.find("num_try") != req_headers.end()){
num_try = std::stoul(req_headers.at("num_try"));
req_headers.erase("num_try");
}
for(auto& header : headers){
for(auto& header : req_headers){
key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size()));
key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size()));
}
@ -60,11 +62,13 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth
http_proxy_res_t res;
for(size_t i = 0; i < num_try; i++){
res = call(url, method, body, headers, timeout_ms);
res = call(url, method, req_body, req_headers, timeout_ms);
if(res.status_code != 408 && res.status_code < 500){
break;
}
LOG(ERROR) << "Proxy call failed, status_code: " << res.status_code << ", num_try: " << num_try;
}
if(res.status_code == 408){

View File

@ -970,37 +970,12 @@ std::string ReplicationState::get_leader_url() const {
std::shared_lock lock(node_mutex);
if(!node) {
const Option<std::string> & refreshed_nodes_op = Config::fetch_nodes_config(config->get_nodes());
if(!refreshed_nodes_op.ok()) {
LOG(WARNING) << "Error while fetching peer configuration: " << refreshed_nodes_op.error();
return "";
}
const std::string& nodes_config = ReplicationState::to_nodes_config(peering_endpoint,
Config::get_instance().get_api_port(),
refreshed_nodes_op.get());
std::vector<braft::PeerId> peers;
braft::Configuration peer_config;
peer_config.parse_from(nodes_config);
peer_config.list_peers(&peers);
if(peers.empty()) {
LOG(WARNING) << "No peers found in nodes config: " << nodes_config;
return "";
}
const std::string protocol = api_uses_ssl ? "https" : "http";
std::string url = get_node_url_path(peers[0].to_string(), "/", protocol);
LOG(INFO) << "Returning first peer as leader URL: " << url;
return url;
LOG(ERROR) << "Could not get leader url as node is not initialized!";
return "";
}
if(node->leader_id().is_empty()) {
LOG(ERROR) << "Could not get leader status, as node does not have a leader!";
LOG(ERROR) << "Could not get leader url, as node does not have a leader!";
return "";
}

View File

@ -1,3 +1,4 @@
#include <http_proxy.h>
#include "text_embedder_remote.h"
#include "text_embedder_manager.h"
@ -11,35 +12,49 @@ Option<bool> RemoteEmbedder::validate_string_properties(const nlohmann::json& mo
return Option<bool>(true);
}
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body,
std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers) {
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body,
std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers) {
if(raft_server == nullptr || raft_server->get_leader_url().empty()) {
if(method == "GET") {
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 45000, true);
} else if(method == "POST") {
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 45000, true);
// call proxy's internal send() directly
if(method == "GET" || method == "POST") {
auto proxy_res = HttpProxy::get_instance().send(url, method, req_body, req_headers);
res_body = std::move(proxy_res.body);
res_headers = std::move(proxy_res.headers);
return proxy_res.status_code;
} else {
return 400;
}
}
auto leader_url = raft_server->get_leader_url();
leader_url += "proxy";
nlohmann::json req_body;
req_body["method"] = method;
req_body["url"] = url;
req_body["body"] = body;
req_body["headers"] = req_headers;
return HttpClient::get_instance().post_response(leader_url, req_body.dump(), res_body, headers, {}, 45000, true);
auto proxy_url = raft_server->get_leader_url() + "proxy";
nlohmann::json proxy_req_body;
proxy_req_body["method"] = method;
proxy_req_body["url"] = url;
proxy_req_body["body"] = req_body;
proxy_req_body["headers"] = req_headers;
size_t per_call_timeout_ms = HttpProxy::default_timeout_ms;
size_t num_try = HttpProxy::default_num_try;
if(res_headers.find("timeout_ms") != res_headers.end()){
per_call_timeout_ms = std::stoul(res_headers.at("timeout_ms"));
}
if(res_headers.find("num_try") != res_headers.end()){
num_try = std::stoul(res_headers.at("num_try"));
}
size_t proxy_call_timeout_ms = (per_call_timeout_ms * num_try) + 1000;
return HttpClient::get_instance().post_response(proxy_url, proxy_req_body.dump(), res_body, res_headers, {},
proxy_call_timeout_ms, true);
}
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
}
Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims) {
auto validate_properties = validate_string_properties(model_config, {"model_name", "api_key"});