mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 13:42:26 +08:00
Refactor remoted embedding code.
This commit is contained in:
parent
04fd68b1a1
commit
01801f4ea9
@ -407,7 +407,7 @@ public:
|
||||
nlohmann::json add_many(std::vector<std::string>& json_lines, nlohmann::json& document,
|
||||
const index_operation_t& operation=CREATE, const std::string& id="",
|
||||
const DIRTY_VALUES& dirty_values=DIRTY_VALUES::COERCE_OR_REJECT,
|
||||
const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=100);
|
||||
const bool& return_doc=false, const bool& return_id=false, const size_t remote_embedding_batch_size=200);
|
||||
|
||||
Option<nlohmann::json> update_matching_filter(const std::string& filter_query,
|
||||
const std::string & json_str,
|
||||
|
@ -24,6 +24,9 @@ struct http_proxy_res_t {
|
||||
class HttpProxy {
|
||||
// singleton class for http proxy
|
||||
public:
|
||||
static const size_t default_timeout_ms = 30000*4;
|
||||
static const size_t default_num_try = 2;
|
||||
|
||||
static HttpProxy& get_instance() {
|
||||
static HttpProxy instance;
|
||||
return instance;
|
||||
@ -32,13 +35,13 @@ class HttpProxy {
|
||||
void operator=(const HttpProxy&) = delete;
|
||||
HttpProxy(HttpProxy&&) = delete;
|
||||
void operator=(HttpProxy&&) = delete;
|
||||
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map<std::string, std::string>& headers);
|
||||
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& req_body, std::unordered_map<std::string, std::string>& req_headers);
|
||||
private:
|
||||
HttpProxy();
|
||||
~HttpProxy() = default;
|
||||
http_proxy_res_t call(const std::string& url, const std::string& method,
|
||||
const std::string& body = "", const std::unordered_map<std::string, std::string>& headers = {},
|
||||
const size_t timeout_ms = 30000);
|
||||
http_proxy_res_t call(const std::string& url, const std::string& method,
|
||||
const std::string& req_body = "", const std::unordered_map<std::string, std::string>& req_headers = {},
|
||||
const size_t timeout_ms = default_timeout_ms);
|
||||
|
||||
|
||||
// lru cache for http requests
|
||||
|
@ -25,7 +25,7 @@ struct embedding_res_t {
|
||||
class RemoteEmbedder {
|
||||
protected:
|
||||
static Option<bool> validate_string_properties(const nlohmann::json& model_config, const std::vector<std::string>& properties);
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers);
|
||||
static long call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body, std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers);
|
||||
static inline ReplicationState* raft_server = nullptr;
|
||||
public:
|
||||
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
|
||||
|
@ -3825,7 +3825,7 @@ Option<bool> Collection::batch_alter_data(const std::vector<field>& alter_fields
|
||||
}
|
||||
|
||||
Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields,
|
||||
fallback_field_type, token_separators, symbols_to_index, true, 100, false);
|
||||
fallback_field_type, token_separators, symbols_to_index, true, 200, false);
|
||||
|
||||
iter_batch.clear();
|
||||
}
|
||||
|
@ -153,13 +153,25 @@ long HttpClient::perform_curl(CURL *curl, std::map<std::string, std::string>& re
|
||||
if (res != CURLE_OK) {
|
||||
char* url = nullptr;
|
||||
curl_easy_getinfo(curl, CURLINFO_EFFECTIVE_URL, &url);
|
||||
LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res);
|
||||
|
||||
long status_code = 0;
|
||||
|
||||
if(res == CURLE_OPERATION_TIMEDOUT) {
|
||||
double total_time;
|
||||
char* http_method;
|
||||
curl_easy_getinfo(curl, CURLINFO_TOTAL_TIME, &total_time);
|
||||
curl_easy_getinfo(curl, CURLINFO_EFFECTIVE_METHOD, http_method);
|
||||
LOG(ERROR) << "CURL timeout. Time taken: " << total_time << ", URL: " << http_method << " " << url;
|
||||
status_code = 408;
|
||||
} else {
|
||||
LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res);
|
||||
status_code = 500;
|
||||
}
|
||||
|
||||
curl_easy_cleanup(curl);
|
||||
curl_slist_free_all(chunk);
|
||||
if(res == CURLE_OPERATION_TIMEDOUT) {
|
||||
return 408;
|
||||
}
|
||||
return 500;
|
||||
|
||||
return status_code;
|
||||
}
|
||||
|
||||
long http_code = 500;
|
||||
|
@ -8,17 +8,18 @@ HttpProxy::HttpProxy() : cache(30s){
|
||||
}
|
||||
|
||||
|
||||
http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method,
|
||||
const std::string& body, const std::unordered_map<std::string, std::string>& headers,
|
||||
http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method,
|
||||
const std::string& req_body,
|
||||
const std::unordered_map<std::string, std::string>& req_headers,
|
||||
const size_t timeout_ms) {
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
http_proxy_res_t res;
|
||||
if(method == "GET") {
|
||||
res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms);
|
||||
res.status_code = client.get_response(url, res.body, res.headers, req_headers, timeout_ms);
|
||||
} else if(method == "POST") {
|
||||
res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms);
|
||||
res.status_code = client.post_response(url, req_body, res.body, res.headers, req_headers, timeout_ms);
|
||||
} else if(method == "PUT") {
|
||||
res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms);
|
||||
res.status_code = client.put_response(url, req_body, res.body, res.headers, timeout_ms);
|
||||
} else if(method == "DELETE") {
|
||||
res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms);
|
||||
} else {
|
||||
@ -31,26 +32,27 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth
|
||||
}
|
||||
|
||||
|
||||
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map<std::string, std::string>& headers) {
|
||||
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& req_body,
|
||||
std::unordered_map<std::string, std::string>& req_headers) {
|
||||
// check if url is in cache
|
||||
uint64_t key = StringUtils::hash_wy(url.c_str(), url.size());
|
||||
key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size()));
|
||||
key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));
|
||||
key = StringUtils::hash_combine(key, StringUtils::hash_wy(req_body.c_str(), req_body.size()));
|
||||
|
||||
size_t timeout_ms = 30000;
|
||||
size_t num_try = 2;
|
||||
size_t timeout_ms = default_timeout_ms;
|
||||
size_t num_try = default_num_try;
|
||||
|
||||
if(headers.find("timeout_ms") != headers.end()){
|
||||
timeout_ms = std::stoul(headers.at("timeout_ms"));
|
||||
headers.erase("timeout_ms");
|
||||
if(req_headers.find("timeout_ms") != req_headers.end()){
|
||||
timeout_ms = std::stoul(req_headers.at("timeout_ms"));
|
||||
req_headers.erase("timeout_ms");
|
||||
}
|
||||
|
||||
if(headers.find("num_try") != headers.end()){
|
||||
num_try = std::stoul(headers.at("num_try"));
|
||||
headers.erase("num_try");
|
||||
if(req_headers.find("num_try") != req_headers.end()){
|
||||
num_try = std::stoul(req_headers.at("num_try"));
|
||||
req_headers.erase("num_try");
|
||||
}
|
||||
|
||||
for(auto& header : headers){
|
||||
for(auto& header : req_headers){
|
||||
key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size()));
|
||||
key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size()));
|
||||
}
|
||||
@ -60,11 +62,13 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth
|
||||
|
||||
http_proxy_res_t res;
|
||||
for(size_t i = 0; i < num_try; i++){
|
||||
res = call(url, method, body, headers, timeout_ms);
|
||||
res = call(url, method, req_body, req_headers, timeout_ms);
|
||||
|
||||
if(res.status_code != 408 && res.status_code < 500){
|
||||
break;
|
||||
}
|
||||
|
||||
LOG(ERROR) << "Proxy call failed, status_code: " << res.status_code << ", num_try: " << num_try;
|
||||
}
|
||||
|
||||
if(res.status_code == 408){
|
||||
|
@ -970,37 +970,12 @@ std::string ReplicationState::get_leader_url() const {
|
||||
std::shared_lock lock(node_mutex);
|
||||
|
||||
if(!node) {
|
||||
const Option<std::string> & refreshed_nodes_op = Config::fetch_nodes_config(config->get_nodes());
|
||||
|
||||
if(!refreshed_nodes_op.ok()) {
|
||||
LOG(WARNING) << "Error while fetching peer configuration: " << refreshed_nodes_op.error();
|
||||
return "";
|
||||
}
|
||||
|
||||
const std::string& nodes_config = ReplicationState::to_nodes_config(peering_endpoint,
|
||||
Config::get_instance().get_api_port(),
|
||||
|
||||
refreshed_nodes_op.get());
|
||||
std::vector<braft::PeerId> peers;
|
||||
braft::Configuration peer_config;
|
||||
peer_config.parse_from(nodes_config);
|
||||
peer_config.list_peers(&peers);
|
||||
|
||||
if(peers.empty()) {
|
||||
LOG(WARNING) << "No peers found in nodes config: " << nodes_config;
|
||||
return "";
|
||||
}
|
||||
|
||||
|
||||
const std::string protocol = api_uses_ssl ? "https" : "http";
|
||||
std::string url = get_node_url_path(peers[0].to_string(), "/", protocol);
|
||||
|
||||
LOG(INFO) << "Returning first peer as leader URL: " << url;
|
||||
return url;
|
||||
LOG(ERROR) << "Could not get leader url as node is not initialized!";
|
||||
return "";
|
||||
}
|
||||
|
||||
if(node->leader_id().is_empty()) {
|
||||
LOG(ERROR) << "Could not get leader status, as node does not have a leader!";
|
||||
LOG(ERROR) << "Could not get leader url, as node does not have a leader!";
|
||||
return "";
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
#include <http_proxy.h>
|
||||
#include "text_embedder_remote.h"
|
||||
#include "text_embedder_manager.h"
|
||||
|
||||
@ -11,35 +12,49 @@ Option<bool> RemoteEmbedder::validate_string_properties(const nlohmann::json& mo
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body,
|
||||
std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers) {
|
||||
long RemoteEmbedder::call_remote_api(const std::string& method, const std::string& url, const std::string& req_body, std::string& res_body,
|
||||
std::map<std::string, std::string>& res_headers, std::unordered_map<std::string, std::string>& req_headers) {
|
||||
|
||||
if(raft_server == nullptr || raft_server->get_leader_url().empty()) {
|
||||
if(method == "GET") {
|
||||
return HttpClient::get_instance().get_response(url, res_body, headers, req_headers, 45000, true);
|
||||
} else if(method == "POST") {
|
||||
return HttpClient::get_instance().post_response(url, body, res_body, headers, req_headers, 45000, true);
|
||||
// call proxy's internal send() directly
|
||||
if(method == "GET" || method == "POST") {
|
||||
auto proxy_res = HttpProxy::get_instance().send(url, method, req_body, req_headers);
|
||||
res_body = std::move(proxy_res.body);
|
||||
res_headers = std::move(proxy_res.headers);
|
||||
return proxy_res.status_code;
|
||||
} else {
|
||||
return 400;
|
||||
}
|
||||
}
|
||||
|
||||
auto leader_url = raft_server->get_leader_url();
|
||||
leader_url += "proxy";
|
||||
nlohmann::json req_body;
|
||||
req_body["method"] = method;
|
||||
req_body["url"] = url;
|
||||
req_body["body"] = body;
|
||||
req_body["headers"] = req_headers;
|
||||
return HttpClient::get_instance().post_response(leader_url, req_body.dump(), res_body, headers, {}, 45000, true);
|
||||
auto proxy_url = raft_server->get_leader_url() + "proxy";
|
||||
nlohmann::json proxy_req_body;
|
||||
proxy_req_body["method"] = method;
|
||||
proxy_req_body["url"] = url;
|
||||
proxy_req_body["body"] = req_body;
|
||||
proxy_req_body["headers"] = req_headers;
|
||||
|
||||
size_t per_call_timeout_ms = HttpProxy::default_timeout_ms;
|
||||
size_t num_try = HttpProxy::default_num_try;
|
||||
|
||||
if(res_headers.find("timeout_ms") != res_headers.end()){
|
||||
per_call_timeout_ms = std::stoul(res_headers.at("timeout_ms"));
|
||||
}
|
||||
|
||||
if(res_headers.find("num_try") != res_headers.end()){
|
||||
num_try = std::stoul(res_headers.at("num_try"));
|
||||
}
|
||||
|
||||
size_t proxy_call_timeout_ms = (per_call_timeout_ms * num_try) + 1000;
|
||||
|
||||
return HttpClient::get_instance().post_response(proxy_url, proxy_req_body.dump(), res_body, res_headers, {},
|
||||
proxy_call_timeout_ms, true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims) {
|
||||
auto validate_properties = validate_string_properties(model_config, {"model_name", "api_key"});
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user