Merge branch 'v0.25-join' into v0.26-facets

# Conflicts:
#	include/collection.h
#	src/collection.cpp
This commit is contained in:
Kishore Nallan 2023-07-08 09:27:45 +05:30
commit 7a23770a47
20 changed files with 572 additions and 151 deletions

View File

@ -24,10 +24,11 @@ private:
void to_json(nlohmann::json& obj) const {
obj["name"] = name;
obj["type"] = POPULAR_QUERIES_TYPE;
obj["params"] = nlohmann::json::object();
obj["params"]["suggestion_collection"] = suggestion_collection;
obj["params"]["query_collections"] = query_collections;
obj["params"]["limit"] = limit;
obj["params"]["source"]["collections"] = query_collections;
obj["params"]["destination"]["collection"] = suggestion_collection;
}
};
@ -48,7 +49,9 @@ private:
Option<bool> remove_popular_queries_index(const std::string& name);
Option<bool> create_popular_queries_index(nlohmann::json &payload, bool write_to_disk);
Option<bool> create_popular_queries_index(nlohmann::json &payload,
bool upsert,
bool write_to_disk);
public:
@ -69,12 +72,14 @@ public:
Option<nlohmann::json> list_rules();
Option<bool> create_rule(nlohmann::json& payload, bool write_to_disk = true);
Option<nlohmann::json> get_rule(const std::string& name);
Option<bool> create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk);
Option<bool> remove_rule(const std::string& name);
void add_suggestion(const std::string& query_collection,
std::string& query, const bool live_query, const std::string& user_id);
std::string& query, bool live_query, const std::string& user_id);
void stop();

View File

@ -465,7 +465,9 @@ public:
const size_t facet_sample_threshold = 0,
const size_t page_offset = 0,
facet_index_type_t facet_index_type = HASH,
const size_t vector_query_hits = 250) const;
const size_t vector_query_hits = 250,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_try = 2) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const;

View File

@ -147,8 +147,12 @@ bool post_create_event(const std::shared_ptr<http_req>& req, const std::shared_p
bool get_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
bool get_analytics_rule(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
bool post_create_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
bool put_upsert_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
bool del_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res);
// Misc helpers

View File

@ -32,11 +32,13 @@ class HttpProxy {
void operator=(const HttpProxy&) = delete;
HttpProxy(HttpProxy&&) = delete;
void operator=(HttpProxy&&) = delete;
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers);
http_proxy_res_t send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map<std::string, std::string>& headers);
private:
HttpProxy();
~HttpProxy() = default;
http_proxy_res_t call(const std::string& url, const std::string& method, const std::string& body = "", const std::unordered_map<std::string, std::string>& headers = {});
http_proxy_res_t call(const std::string& url, const std::string& method,
const std::string& body = "", const std::unordered_map<std::string, std::string>& headers = {},
const size_t timeout_ms = 30000);
// lru cache for http requests

View File

@ -16,7 +16,7 @@ class TextEmbedder {
// Constructor for remote models
TextEmbedder(const nlohmann::json& model_config);
~TextEmbedder();
embedding_res_t Embed(const std::string& text);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2);
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs);
const std::string& get_vocab_file_name() const;
bool is_remote() {

View File

@ -28,7 +28,8 @@ class RemoteEmbedder {
static long call_remote_api(const std::string& method, const std::string& url, const std::string& body, std::string& res_body, std::map<std::string, std::string>& headers, const std::unordered_map<std::string, std::string>& req_headers);
static inline ReplicationState* raft_server = nullptr;
public:
virtual embedding_res_t Embed(const std::string& text) = 0;
virtual nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) = 0;
virtual embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) = 0;
virtual std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) = 0;
static void init(ReplicationState* rs) {
raft_server = rs;
@ -47,8 +48,9 @@ class OpenAIEmbedder : public RemoteEmbedder {
public:
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
embedding_res_t Embed(const std::string& text) override;
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
};
@ -59,11 +61,13 @@ class GoogleEmbedder : public RemoteEmbedder {
inline static constexpr short GOOGLE_EMBEDDING_DIM = 768;
inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=";
std::string google_api_key;
public:
GoogleEmbedder(const std::string& google_api_key);
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
embedding_res_t Embed(const std::string& text) override;
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
};
@ -88,8 +92,9 @@ class GCPEmbedder : public RemoteEmbedder {
GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret);
static Option<bool> is_model_valid(const nlohmann::json& model_config, unsigned int& num_dims);
embedding_res_t Embed(const std::string& text) override;
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs) override;
nlohmann::json get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) override;
};

View File

@ -5,7 +5,7 @@
#include "http_client.h"
#include "collection_manager.h"
Option<bool> AnalyticsManager::create_rule(nlohmann::json& payload, bool write_to_disk) {
Option<bool> AnalyticsManager::create_rule(nlohmann::json& payload, bool upsert, bool write_to_disk) {
/*
Sample payload:
@ -37,16 +37,23 @@ Option<bool> AnalyticsManager::create_rule(nlohmann::json& payload, bool write_t
}
if(payload["type"] == POPULAR_QUERIES_TYPE) {
return create_popular_queries_index(payload, write_to_disk);
return create_popular_queries_index(payload, upsert, write_to_disk);
}
return Option<bool>(400, "Invalid type.");
}
Option<bool> AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool write_to_disk) {
Option<bool> AnalyticsManager::create_popular_queries_index(nlohmann::json &payload, bool upsert, bool write_to_disk) {
// params and name are validated upstream
const auto& params = payload["params"];
const std::string& suggestion_config_name = payload["name"].get<std::string>();
bool already_exists = suggestion_configs.find(suggestion_config_name) != suggestion_configs.end();
if(!upsert && already_exists) {
return Option<bool>(400, "There's already another configuration with the name `" +
suggestion_config_name + "`.");
}
const auto& params = payload["params"];
if(!params.contains("source") || !params["source"].is_object()) {
return Option<bool>(400, "Bad or missing source.");
@ -56,18 +63,12 @@ Option<bool> AnalyticsManager::create_popular_queries_index(nlohmann::json &payl
return Option<bool>(400, "Bad or missing destination.");
}
size_t limit = 1000;
if(params.contains("limit") && params["limit"].is_number_integer()) {
limit = params["limit"].get<size_t>();
}
if(suggestion_configs.find(suggestion_config_name) != suggestion_configs.end()) {
return Option<bool>(400, "There's already another configuration with the name `" +
suggestion_config_name + "`.");
}
if(!params["source"].contains("collections") || !params["source"]["collections"].is_array()) {
return Option<bool>(400, "Must contain a valid list of source collections.");
}
@ -93,6 +94,14 @@ Option<bool> AnalyticsManager::create_popular_queries_index(nlohmann::json &payl
std::unique_lock lock(mutex);
if(already_exists) {
// remove the previous configuration with same name (upsert)
Option<bool> remove_op = remove_popular_queries_index(suggestion_config_name);
if(!remove_op.ok()) {
return Option<bool>(500, "Error erasing the existing configuration.");;
}
}
suggestion_configs.emplace(suggestion_config_name, suggestion_config);
for(const auto& query_coll: suggestion_config.query_collections) {
@ -130,13 +139,25 @@ Option<nlohmann::json> AnalyticsManager::list_rules() {
for(const auto& suggestion_config: suggestion_configs) {
nlohmann::json rule;
suggestion_config.second.to_json(rule);
rule["type"] = POPULAR_QUERIES_TYPE;
rules["rules"].push_back(rule);
}
return Option<nlohmann::json>(rules);
}
// Fetches a single analytics rule by name.
// Returns the rule serialized as JSON, or a 404 error when no rule with
// the given name has been configured.
Option<nlohmann::json> AnalyticsManager::get_rule(const string& name) {
    std::unique_lock lock(mutex);

    const auto it = suggestion_configs.find(name);
    if(it == suggestion_configs.end()) {
        return Option<nlohmann::json>(404, "Rule not found.");
    }

    nlohmann::json rule;
    it->second.to_json(rule);
    return Option<nlohmann::json>(rule);
}
Option<bool> AnalyticsManager::remove_rule(const string &name) {
std::unique_lock lock(mutex);

View File

@ -1108,7 +1108,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t facet_sample_threshold,
const size_t page_offset,
facet_index_type_t facet_index_type,
const size_t vector_query_hits) const {
const size_t vector_query_hits,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_try) const {
std::shared_lock lock(mutex);
@ -1234,10 +1236,15 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
std::string error = "Prefix search is not supported for remote embedders. Please set `prefix=false` as an additional search parameter to disable prefix searching.";
return Option<nlohmann::json>(400, error);
}
if(remote_embedding_num_try == 0) {
std::string error = "`remote-embedding-num-try` must be greater than 0.";
return Option<nlohmann::json>(400, error);
}
}
std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query;
auto embedding_op = embedder->Embed(embed_query);
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_try);
if(!embedding_op.success) {
if(!embedding_op.error["error"].get<std::string>().empty()) {
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());

View File

@ -285,6 +285,17 @@ Option<bool> CollectionManager::load(const size_t collection_batch_size, const s
iter->Next();
}
// restore query suggestions configs
std::vector<std::string> analytics_config_jsons;
store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX,
std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`",
analytics_config_jsons);
for(const auto& analytics_config_json: analytics_config_jsons) {
nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json);
AnalyticsManager::get_instance().create_rule(analytics_config, false, false);
}
delete iter;
LOG(INFO) << "Loaded " << num_collections << " collection(s).";
@ -671,6 +682,9 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *VECTOR_QUERY = "vector_query";
const char *VECTOR_QUERY_HITS = "vector_query_hits";
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try";
const char *GROUP_BY = "group_by";
const char *GROUP_LIMIT = "group_limit";
@ -824,6 +838,9 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
text_match_type_t match_type = max_score;
size_t vector_query_hits = 250;
size_t remote_embedding_timeout_ms = 30000;
size_t remote_embedding_num_try = 2;
size_t facet_sample_percent = 100;
size_t facet_sample_threshold = 0;
@ -850,6 +867,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
{VECTOR_QUERY_HITS, &vector_query_hits},
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
{REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try},
};
std::unordered_map<std::string, std::string*> str_values = {
@ -1062,7 +1081,11 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
match_type,
facet_sample_percent,
facet_sample_threshold,
offset
offset,
HASH,
vector_query_hits,
remote_embedding_timeout_ms,
remote_embedding_num_try
);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
@ -1312,17 +1335,6 @@ Option<bool> CollectionManager::load_collection(const nlohmann::json &collection
collection->add_synonym(collection_synonym, false);
}
// restore query suggestions configs
std::vector<std::string> analytics_config_jsons;
cm.store->scan_fill(AnalyticsManager::ANALYTICS_RULE_PREFIX,
std::string(AnalyticsManager::ANALYTICS_RULE_PREFIX) + "`",
analytics_config_jsons);
for(const auto& analytics_config_json: analytics_config_jsons) {
nlohmann::json analytics_config = nlohmann::json::parse(analytics_config_json);
AnalyticsManager::get_instance().create_rule(analytics_config, false);
}
// Fetch records from the store and re-create memory index
const std::string seq_id_prefix = collection->get_seq_id_collection_prefix();
std::string upper_bound_key = collection->get_seq_id_collection_prefix() + "`"; // cannot inline this

View File

@ -2078,13 +2078,25 @@ bool post_create_event(const std::shared_ptr<http_req>& req, const std::shared_p
bool get_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
auto rules_op = AnalyticsManager::get_instance().list_rules();
if(rules_op.ok()) {
res->set_200(rules_op.get().dump());
return true;
if(!rules_op.ok()) {
res->set(rules_op.code(), rules_op.error());
return false;
}
res->set(rules_op.code(), rules_op.error());
return false;
res->set_200(rules_op.get().dump());
return true;
}
bool get_analytics_rule(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
auto rules_op = AnalyticsManager::get_instance().get_rule(req->params["name"]);
if(!rules_op.ok()) {
res->set(rules_op.code(), rules_op.error());
return false;
}
res->set_200(rules_op.get().dump());
return true;
}
bool post_create_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
@ -2098,7 +2110,7 @@ bool post_create_analytics_rules(const std::shared_ptr<http_req>& req, const std
return false;
}
auto op = AnalyticsManager::get_instance().create_rule(req_json);
auto op = AnalyticsManager::get_instance().create_rule(req_json, false, true);
if(!op.ok()) {
res->set(op.code(), op.error());
@ -2109,6 +2121,29 @@ bool post_create_analytics_rules(const std::shared_ptr<http_req>& req, const std
return true;
}
bool put_upsert_analytics_rules(const std::shared_ptr<http_req> &req, const std::shared_ptr<http_res> &res) {
nlohmann::json req_json;
try {
req_json = nlohmann::json::parse(req->body);
} catch(const std::exception& e) {
LOG(ERROR) << "JSON error: " << e.what();
res->set_400("Bad JSON.");
return false;
}
req_json["name"] = req->params["name"];
auto op = AnalyticsManager::get_instance().create_rule(req_json, true, true);
if(!op.ok()) {
res->set(op.code(), op.error());
return false;
}
res->set_200(req_json.dump());
return true;
}
bool del_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
auto op = AnalyticsManager::get_instance().remove_rule(req->params["name"]);
if(!op.ok()) {
@ -2116,11 +2151,13 @@ bool del_analytics_rules(const std::shared_ptr<http_req>& req, const std::shared
return false;
}
res->set_200(R"({"ok": true)");
nlohmann::json res_json;
res_json["name"] = req->params["name"];
res->set_200(res_json.dump());
return true;
}
bool post_proxy(const std::shared_ptr<http_req>& req, const std::shared_ptr<http_res>& res) {
HttpProxy& proxy = HttpProxy::get_instance();
@ -2180,4 +2217,4 @@ bool post_proxy(const std::shared_ptr<http_req>& req, const std::shared_ptr<http
res->set_200(response.body);
return true;
}
}

View File

@ -156,6 +156,9 @@ long HttpClient::perform_curl(CURL *curl, std::map<std::string, std::string>& re
LOG(ERROR) << "CURL failed. URL: " << url << ", Code: " << res << ", strerror: " << curl_easy_strerror(res);
curl_easy_cleanup(curl);
curl_slist_free_all(chunk);
if(res == CURLE_OPERATION_TIMEDOUT) {
return 408;
}
return 500;
}

View File

@ -8,17 +8,19 @@ HttpProxy::HttpProxy() : cache(30s){
}
http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers) {
http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& method,
const std::string& body, const std::unordered_map<std::string, std::string>& headers,
const size_t timeout_ms) {
HttpClient& client = HttpClient::get_instance();
http_proxy_res_t res;
if(method == "GET") {
res.status_code = client.get_response(url, res.body, res.headers, headers, 20 * 1000);
res.status_code = client.get_response(url, res.body, res.headers, headers, timeout_ms);
} else if(method == "POST") {
res.status_code = client.post_response(url, body, res.body, res.headers, headers, 20 * 1000);
res.status_code = client.post_response(url, body, res.body, res.headers, headers, timeout_ms);
} else if(method == "PUT") {
res.status_code = client.put_response(url, body, res.body, res.headers, 20 * 1000);
res.status_code = client.put_response(url, body, res.body, res.headers, timeout_ms);
} else if(method == "DELETE") {
res.status_code = client.delete_response(url, res.body, res.headers, 20 * 1000);
res.status_code = client.delete_response(url, res.body, res.headers, timeout_ms);
} else {
res.status_code = 400;
nlohmann::json j;
@ -29,11 +31,25 @@ http_proxy_res_t HttpProxy::call(const std::string& url, const std::string& meth
}
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, const std::unordered_map<std::string, std::string>& headers) {
http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& method, const std::string& body, std::unordered_map<std::string, std::string>& headers) {
// check if url is in cache
uint64_t key = StringUtils::hash_wy(url.c_str(), url.size());
key = StringUtils::hash_combine(key, StringUtils::hash_wy(method.c_str(), method.size()));
key = StringUtils::hash_combine(key, StringUtils::hash_wy(body.c_str(), body.size()));
size_t timeout_ms = 30000;
size_t num_try = 2;
if(headers.find("timeout_ms") != headers.end()){
timeout_ms = std::stoul(headers.at("timeout_ms"));
headers.erase("timeout_ms");
}
if(headers.find("num_try") != headers.end()){
num_try = std::stoul(headers.at("num_try"));
headers.erase("num_try");
}
for(auto& header : headers){
key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.first.c_str(), header.first.size()));
key = StringUtils::hash_combine(key, StringUtils::hash_wy(header.second.c_str(), header.second.size()));
@ -42,22 +58,23 @@ http_proxy_res_t HttpProxy::send(const std::string& url, const std::string& meth
return cache[key];
}
auto res = call(url, method, body, headers);
http_proxy_res_t res;
for(size_t i = 0; i < num_try; i++){
res = call(url, method, body, headers, timeout_ms);
if(res.status_code == 500){
// retry
res = call(url, method, body, headers);
if(res.status_code != 408 && res.status_code < 500){
break;
}
}
if(res.status_code == 500){
if(res.status_code == 408){
nlohmann::json j;
j["message"] = "Server error on remote server. Please try again later.";
res.body = j.dump();
}
// add to cache
if(res.status_code != 500){
if(res.status_code == 200){
cache.insert(key, res);
}

View File

@ -5781,7 +5781,9 @@ void Index::get_doc_changes(const index_operation_t op, const tsl::htrie_map<cha
if(it.value().is_null()) {
// null values should not be indexed
new_doc.erase(it.key());
del_doc[it.key()] = old_doc[it.key()];
if(old_doc.contains(it.key())) {
del_doc[it.key()] = old_doc[it.key()];
}
it = update_doc.erase(it);
continue;
}

View File

@ -70,7 +70,9 @@ void master_server_routes() {
// analytics
server->get("/analytics/rules", get_analytics_rules);
server->get("/analytics/rules/:name", get_analytics_rule);
server->post("/analytics/rules", post_create_analytics_rules);
server->put("/analytics/rules/:name", put_upsert_analytics_rules);
server->del("/analytics/rules/:name", del_analytics_rules);
server->post("/analytics/events", post_create_event);

View File

@ -83,9 +83,9 @@ std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<floa
return pooled_output;
}
embedding_res_t TextEmbedder::Embed(const std::string& text) {
embedding_res_t TextEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
if(is_remote()) {
return remote_embedder_->Embed(text);
return remote_embedder_->Embed(text, remote_embedder_timeout_ms, remote_embedding_num_try);
} else {
// Cannot run same model in parallel, so lock the mutex
std::lock_guard<std::mutex> lock(mutex_);

View File

@ -59,15 +59,30 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
headers["Authorization"] = "Bearer " + api_key;
std::string res;
auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);
if(res_code == 408) {
return Option<bool>(408, "OpenAI API timeout.");
}
if (res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json json_res;
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "OpenAI API error: " + res);
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "OpenAI API error: " + res);
}
return Option<bool>(400, "OpenAI API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
auto models_json = nlohmann::json::parse(res);
nlohmann::json models_json;
try {
models_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "Got malformed response from OpenAI API.");
}
bool found = false;
// extract model name by removing "openai/" prefix
auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name);
@ -88,49 +103,57 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
std::string embedding_res;
headers["Content-Type"] = "application/json";
res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
if(res_code == 408) {
return Option<bool>(408, "OpenAI API timeout.");
}
if (res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(embedding_res);
nlohmann::json json_res;
try {
json_res = nlohmann::json::parse(embedding_res);
} catch (const std::exception& e) {
return Option<bool>(400, "OpenAI API error: " + embedding_res);
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "OpenAI API error: " + embedding_res);
}
return Option<bool>(400, "OpenAI API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
auto embedding = nlohmann::json::parse(embedding_res)["data"][0]["embedding"].get<std::vector<float>>();
std::vector<float> embedding;
try {
embedding = nlohmann::json::parse(embedding_res)["data"][0]["embedding"].get<std::vector<float>>();
} catch (const std::exception& e) {
return Option<bool>(400, "Got malformed response from OpenAI API.");
}
num_dims = embedding.size();
return Option<bool>(true);
}
embedding_res_t OpenAIEmbedder::Embed(const std::string& text) {
embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_try"] = std::to_string(remote_embedding_num_try);
std::string res;
nlohmann::json req_body;
req_body["input"] = text;
req_body["input"] = std::vector<std::string>{text};
// remove "openai/" prefix
req_body["model"] = TextEmbedderManager::get_model_name_without_namespace(openai_model_path);
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
if (res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
embedding_res["request"] = nlohmann::json::object();
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
embedding_res["request"]["method"] = "POST";
embedding_res["request"]["body"] = req_body;
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
}
return embedding_res_t(res_code, embedding_res);
return embedding_res_t(res_code, get_error_json(req_body, res_code, res));
}
try {
embedding_res_t embedding_res = embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
return embedding_res;
} catch (const std::exception& e) {
return embedding_res_t(500, get_error_json(req_body, res_code, res));
}
return embedding_res_t(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
}
std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
@ -147,20 +170,7 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
if(res_code != 200) {
std::vector<embedding_res_t> outputs;
nlohmann::json json_res = nlohmann::json::parse(res);
LOG(INFO) << "OpenAI API error: " << json_res.dump();
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
embedding_res["request"] = nlohmann::json::object();
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
embedding_res["request"]["method"] = "POST";
embedding_res["request"]["body"] = req_body;
embedding_res["request"]["body"]["input"] = std::vector<std::string>{inputs[0]};
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
}
nlohmann::json embedding_res = get_error_json(req_body, res_code, res);
for(size_t i = 0; i < inputs.size(); i++) {
embedding_res["request"]["body"]["input"][0] = inputs[i];
outputs.push_back(embedding_res_t(res_code, embedding_res));
@ -168,7 +178,18 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
return outputs;
}
nlohmann::json res_json = nlohmann::json::parse(res);
nlohmann::json res_json;
try {
res_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
nlohmann::json embedding_res = get_error_json(req_body, res_code, res);
std::vector<embedding_res_t> outputs;
for(size_t i = 0; i < inputs.size(); i++) {
embedding_res["request"]["body"]["input"][0] = inputs[i];
outputs.push_back(embedding_res_t(500, embedding_res));
}
return outputs;
}
std::vector<embedding_res_t> outputs;
for(auto& data : res_json["data"]) {
outputs.push_back(embedding_res_t(data["embedding"].get<std::vector<float>>()));
@ -178,6 +199,36 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
}
// Builds a structured error payload for a failed OpenAI embedding call.
// The result contains the parsed remote response (or a generic message when
// the body is not valid JSON), the request that was sent (with the input
// list truncated to its first element to keep error payloads small), and a
// top-level "error" string when one can be extracted.
nlohmann::json OpenAIEmbedder::get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) {
    nlohmann::json json_res;
    try {
        json_res = nlohmann::json::parse(res_body);
    } catch (const std::exception& e) {
        json_res = nlohmann::json::object();
        json_res["error"] = "Malformed response from OpenAI API.";
    }

    nlohmann::json embedding_res = nlohmann::json::object();
    embedding_res["response"] = json_res;
    embedding_res["request"] = nlohmann::json::object();
    embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
    embedding_res["request"]["method"] = "POST";
    embedding_res["request"]["body"] = req_body;

    // Truncate a batch "input" list to one element. Guard with is_array():
    // some call sites send a scalar input, and the previous
    // get<std::vector<std::string>>() would throw json::type_error from
    // inside an error-reporting path. Truncating via JSON ops also avoids
    // materializing the whole input as a std::vector twice.
    auto& request_body = embedding_res["request"]["body"];
    if(request_body.count("input") > 0 && request_body["input"].is_array() && request_body["input"].size() > 1) {
        request_body["input"] = nlohmann::json::array({request_body["input"][0]});
    }

    if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
        embedding_res["error"] = "OpenAI API error: " + json_res["error"]["message"].get<std::string>();
    }

    // a curl timeout is surfaced as HTTP 408 by the proxy layer
    if(res_code == 408) {
        embedding_res["error"] = "OpenAI API timeout.";
    }

    return embedding_res;
}
GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : google_api_key(google_api_key) {
}
@ -210,22 +261,38 @@ Option<bool> GoogleEmbedder::is_model_valid(const nlohmann::json& model_config,
auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers);
if(res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json json_res;
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
json_res = nlohmann::json::object();
json_res["error"] = "Malformed response from Google API.";
}
if(res_code == 408) {
return Option<bool>(408, "Google API timeout.");
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "Google API error: " + res);
}
return Option<bool>(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
num_dims = nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>().size();
try {
num_dims = nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>().size();
} catch (const std::exception& e) {
return Option<bool>(500, "Got malformed response from Google API.");
}
return Option<bool>(true);
}
embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
embedding_res_t GoogleEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_try"] = std::to_string(remote_embedding_num_try);
std::string res;
nlohmann::json req_body;
req_body["text"] = text;
@ -233,20 +300,14 @@ embedding_res_t GoogleEmbedder::Embed(const std::string& text) {
auto res_code = call_remote_api("POST", std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers);
if(res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
embedding_res["request"] = nlohmann::json::object();
embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING;
embedding_res["request"]["method"] = "POST";
embedding_res["request"]["body"] = req_body;
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get<std::string>();
}
return embedding_res_t(res_code, embedding_res);
return embedding_res_t(res_code, get_error_json(req_body, res_code, res));
}
return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
try {
return embedding_res_t(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
} catch (const std::exception& e) {
return embedding_res_t(500, get_error_json(req_body, res_code, res));
}
}
@ -260,6 +321,30 @@ std::vector<embedding_res_t> GoogleEmbedder::batch_embed(const std::vector<std::
return outputs;
}
// Builds the structured error payload attached to a failed Google embedding
// call: the (parsed) upstream response, the request we sent (url/method/body),
// and a human-readable "error" message when one can be derived.
nlohmann::json GoogleEmbedder::get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) {
    nlohmann::json json_res;
    try {
        // Assign to the outer json_res: the previous code re-declared a local
        // `json_res` here, shadowing the outer one, so the parsed response was
        // silently discarded and "response" below was always null.
        json_res = nlohmann::json::parse(res_body);
    } catch (const std::exception& e) {
        json_res = nlohmann::json::object();
        json_res["error"] = "Malformed response from Google API.";
    }

    nlohmann::json embedding_res = nlohmann::json::object();
    embedding_res["response"] = json_res;
    embedding_res["request"] = nlohmann::json::object();
    embedding_res["request"]["url"] = GOOGLE_CREATE_EMBEDDING;
    embedding_res["request"]["method"] = "POST";
    embedding_res["request"]["body"] = req_body;

    if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
        embedding_res["error"] = "Google API error: " + json_res["error"]["message"].get<std::string>();
    }

    // An HTTP 408 from the transport layer overrides any message found in the body.
    if(res_code == 408) {
        embedding_res["error"] = "Google API timeout.";
    }

    return embedding_res;
}
GCPEmbedder::GCPEmbedder(const std::string& project_id, const std::string& model_name, const std::string& access_token,
const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) :
@ -302,14 +387,26 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
auto res_code = call_remote_api("POST", get_gcp_embedding_url(project_id, model_name_without_namespace), req_body.dump(), res, res_headers, headers);
if(res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json json_res;
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "Got malformed response from GCP API.");
}
if(json_res == 408) {
return Option<bool>(408, "GCP API timeout.");
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "GCP API error: " + res);
}
return Option<bool>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
auto res_json = nlohmann::json::parse(res);
nlohmann::json res_json;
try {
res_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "Got malformed response from GCP API.");
}
if(res_json.count("predictions") == 0 || res_json["predictions"].size() == 0 || res_json["predictions"][0].count("embeddings") == 0) {
LOG(INFO) << "Invalid response from GCP API: " << res_json.dump();
return Option<bool>(400, "GCP API error: Invalid response");
@ -325,7 +422,7 @@ Option<bool> GCPEmbedder::is_model_valid(const nlohmann::json& model_config, uns
return Option<bool>(true);
}
embedding_res_t GCPEmbedder::Embed(const std::string& text) {
embedding_res_t GCPEmbedder::Embed(const std::string& text, const size_t remote_embedder_timeout_ms, const size_t remote_embedding_num_try) {
nlohmann::json req_body;
req_body["instances"] = nlohmann::json::array();
nlohmann::json instance;
@ -334,6 +431,8 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) {
std::unordered_map<std::string, std::string> headers;
headers["Authorization"] = "Bearer " + access_token;
headers["Content-Type"] = "application/json";
headers["timeout_ms"] = std::to_string(remote_embedder_timeout_ms);
headers["num_try"] = std::to_string(remote_embedding_num_try);
std::map<std::string, std::string> res_headers;
std::string res;
@ -355,24 +454,17 @@ embedding_res_t GCPEmbedder::Embed(const std::string& text) {
}
if(res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
embedding_res["request"] = nlohmann::json::object();
embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name);
embedding_res["request"]["method"] = "POST";
embedding_res["request"]["body"] = req_body;
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get<std::string>();
}
return embedding_res_t(res_code, embedding_res);
return embedding_res_t(res_code, get_error_json(req_body, res_code, res));
}
nlohmann::json res_json;
try {
res_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return embedding_res_t(500, get_error_json(req_body, res_code, res));
}
nlohmann::json res_json = nlohmann::json::parse(res);
return embedding_res_t(res_json["predictions"][0]["embeddings"]["values"].get<std::vector<float>>());
}
std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::string>& inputs) {
// GCP API has a limit of 5 instances per request
if(inputs.size() > 5) {
@ -416,24 +508,24 @@ std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::str
}
if(res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
embedding_res["request"] = nlohmann::json::object();
embedding_res["request"]["url"] = get_gcp_embedding_url(project_id, model_name);
embedding_res["request"]["method"] = "POST";
embedding_res["request"]["body"] = req_body;
if(json_res.count("error") != 0 && json_res["error"].count("message") != 0) {
embedding_res["error"] = "GCP API error: " + json_res["error"]["message"].get<std::string>();
}
auto embedding_res = get_error_json(req_body, res_code, res);
std::vector<embedding_res_t> outputs;
for(size_t i = 0; i < inputs.size(); i++) {
outputs.push_back(embedding_res_t(res_code, embedding_res));
}
return outputs;
}
nlohmann::json res_json = nlohmann::json::parse(res);
nlohmann::json res_json;
try {
res_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
nlohmann::json embedding_res = get_error_json(req_body, res_code, res);
std::vector<embedding_res_t> outputs;
for(size_t i = 0; i < inputs.size(); i++) {
outputs.push_back(embedding_res_t(400, embedding_res));
}
return outputs;
}
std::vector<embedding_res_t> outputs;
for(const auto& prediction : res_json["predictions"]) {
outputs.push_back(embedding_res_t(prediction["embeddings"]["values"].get<std::vector<float>>()));
@ -442,6 +534,34 @@ std::vector<embedding_res_t> GCPEmbedder::batch_embed(const std::vector<std::str
return outputs;
}
// Assembles the diagnostic JSON for a failed GCP embedding call: the parsed
// upstream response (or a "malformed" placeholder), the request we issued,
// and an "error" string summarising what went wrong.
nlohmann::json GCPEmbedder::get_error_json(const nlohmann::json& req_body, long res_code, const std::string& res_body) {
    nlohmann::json parsed_body;
    try {
        parsed_body = nlohmann::json::parse(res_body);
    } catch (const std::exception& e) {
        parsed_body = nlohmann::json::object();
        parsed_body["error"] = "Malformed response from GCP API.";
    }

    nlohmann::json error_json = nlohmann::json::object();
    error_json["response"] = parsed_body;

    nlohmann::json request_details = nlohmann::json::object();
    request_details["url"] = get_gcp_embedding_url(project_id, model_name);
    request_details["method"] = "POST";
    request_details["body"] = req_body;
    error_json["request"] = request_details;

    const bool has_upstream_message = parsed_body.count("error") != 0 &&
                                      parsed_body["error"].count("message") != 0;
    if(has_upstream_message) {
        error_json["error"] = "GCP API error: " + parsed_body["error"]["message"].get<std::string>();
    } else {
        error_json["error"] = "Malformed response from GCP API.";
    }

    // An HTTP 408 always wins over whatever message the body carried.
    if(res_code == 408) {
        error_json["error"] = "GCP API timeout.";
    }

    return error_json;
}
Option<std::string> GCPEmbedder::generate_access_token(const std::string& refresh_token, const std::string& client_id, const std::string& client_secret) {
std::unordered_map<std::string, std::string> headers;
headers["Content-Type"] = "application/x-www-form-urlencoded";
@ -453,17 +573,27 @@ Option<std::string> GCPEmbedder::generate_access_token(const std::string& refres
auto res_code = call_remote_api("POST", GCP_AUTH_TOKEN_URL, req_body, res, res_headers, headers);
if(res_code != 200) {
nlohmann::json json_res = nlohmann::json::parse(res);
nlohmann::json json_res;
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<std::string>(400, "Got malformed response from GCP API.");
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<std::string>(400, "GCP API error: " + res);
}
if(res_code == 408) {
return Option<std::string>(408, "GCP API timeout.");
}
return Option<std::string>(400, "GCP API error: " + nlohmann::json::parse(res)["error"]["message"].get<std::string>());
}
nlohmann::json res_json = nlohmann::json::parse(res);
nlohmann::json res_json;
try {
res_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<std::string>(400, "Got malformed response from GCP API.");
}
std::string access_token = res_json["access_token"].get<std::string>();
return Option<std::string>(access_token);
}

View File

@ -78,7 +78,7 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) {
}
})"_json;
auto create_op = analyticsManager.create_rule(analytics_rule);
auto create_op = analyticsManager.create_rule(analytics_rule, false, true);
ASSERT_TRUE(create_op.ok());
std::string q = "foobar";
@ -88,4 +88,97 @@ TEST_F(AnalyticsManagerTest, AddSuggestion) {
auto userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"];
ASSERT_EQ(1, userQueries.size());
ASSERT_EQ("foobar", userQueries[0].query);
// add another query which is more popular
q = "buzzfoo";
analyticsManager.add_suggestion("titles", q, true, "1");
analyticsManager.add_suggestion("titles", q, true, "2");
analyticsManager.add_suggestion("titles", q, true, "3");
popularQueries = analyticsManager.get_popular_queries();
userQueries = popularQueries["top_queries"]->get_user_prefix_queries()["1"];
ASSERT_EQ(2, userQueries.size());
ASSERT_EQ("foobar", userQueries[0].query);
ASSERT_EQ("buzzfoo", userQueries[1].query);
}
// Exercises the full lifecycle of popular-queries analytics rules:
// create (no upsert), list, fetch by name, 404 on unknown name,
// upsert-overwrite, duplicate rejection without upsert, and deletion.
TEST_F(AnalyticsManagerTest, GetAndDeleteSuggestions) {
// create first rule: track queries on "titles", aggregate into "top_queries"
nlohmann::json analytics_rule = R"({
"name": "top_search_queries",
"type": "popular_queries",
"params": {
"limit": 100,
"source": {
"collections": ["titles"]
},
"destination": {
"collection": "top_queries"
}
}
})"_json;
auto create_op = analyticsManager.create_rule(analytics_rule, false, true);
ASSERT_TRUE(create_op.ok());
// a second rule with a different name but the same destination is allowed
analytics_rule = R"({
"name": "top_search_queries2",
"type": "popular_queries",
"params": {
"limit": 100,
"source": {
"collections": ["titles"]
},
"destination": {
"collection": "top_queries"
}
}
})"_json;
create_op = analyticsManager.create_rule(analytics_rule, false, true);
ASSERT_TRUE(create_op.ok());
// both rules should be listed and individually retrievable
auto rules = analyticsManager.list_rules().get()["rules"];
ASSERT_EQ(2, rules.size());
ASSERT_TRUE(analyticsManager.get_rule("top_search_queries").ok());
ASSERT_TRUE(analyticsManager.get_rule("top_search_queries2").ok());
// unknown rule name yields a 404 with a fixed message
auto missing_rule_op = analyticsManager.get_rule("top_search_queriesX");
ASSERT_FALSE(missing_rule_op.ok());
ASSERT_EQ(404, missing_rule_op.code());
ASSERT_EQ("Rule not found.", missing_rule_op.error());
// upsert rule that already exists
analytics_rule = R"({
"name": "top_search_queries2",
"type": "popular_queries",
"params": {
"limit": 100,
"source": {
"collections": ["titles"]
},
"destination": {
"collection": "top_queriesUpdated"
}
}
})"_json;
create_op = analyticsManager.create_rule(analytics_rule, true, true);
ASSERT_TRUE(create_op.ok());
// upsert must have replaced the destination collection
auto existing_rule = analyticsManager.get_rule("top_search_queries2").get();
ASSERT_EQ("top_queriesUpdated", existing_rule["params"]["destination"]["collection"].get<std::string>());
// reject when upsert is not enabled
create_op = analyticsManager.create_rule(analytics_rule, false, true);
ASSERT_FALSE(create_op.ok());
ASSERT_EQ("There's already another configuration with the name `top_search_queries2`.", create_op.error());
// try deleting both rules
analyticsManager.remove_rule("top_search_queries");
analyticsManager.remove_rule("top_search_queries2");
// after removal, lookups must fail again
missing_rule_op = analyticsManager.get_rule("top_search_queries");
ASSERT_FALSE(missing_rule_op.ok());
missing_rule_op = analyticsManager.get_rule("top_search_queries2");
ASSERT_FALSE(missing_rule_op.ok());
}

View File

@ -1321,6 +1321,21 @@ TEST_F(CollectionSpecificMoreTest, UpdateArrayWithNullValue) {
auto results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(0, results["found"].get<size_t>());
// update document with no value (optional field) with a null value
auto doc3 = R"({
"id": "2"
})"_json;
ASSERT_TRUE(coll1->add(doc3.dump(), CREATE).ok());
results = coll1->search("alpha", {"tags"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}).get();
ASSERT_EQ(0, results["found"].get<size_t>());
doc_update = R"({
"id": "2",
"tags": null
})"_json;
ASSERT_TRUE(coll1->add(doc_update.dump(), UPDATE).ok());
// via upsert
doc_update = R"({

View File

@ -5177,4 +5177,45 @@ TEST_F(CollectionTest, EmbeddingFieldEmptyArrayInDocument) {
ASSERT_FALSE(get_op.get()["embedding"].is_null());
ASSERT_EQ(384, get_op.get()["embedding"].size());
}
// Verifies that get_error_json() survives a truncated/partial JSON body from
// the remote embedding API: it must report a malformed-response error while
// still echoing back the original request body for diagnostics.
TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) {
// deliberately truncated JSON — the raw string is cut off mid-array to
// simulate a connection dropped partway through the response
std::string partial_json = R"({
"results": [
{
"embedding": [
0.0,
0.0,
0.0
],
"text": "butter"
},
{
"embedding": [
0.0,
0.0,
0.0
],
"text": "butterball"
},
{
"embedding": [
0.0,
0.0)";
nlohmann::json req_body = R"({
"inputs": [
"butter",
"butterball",
"butterfly"
]
})"_json;
// empty API key / model suffice here: get_error_json does not hit the network
OpenAIEmbedder embedder("", "");
auto res = embedder.get_error_json(req_body, 200, partial_json);
ASSERT_EQ(res["response"]["error"], "Malformed response from OpenAI API.");
ASSERT_EQ(res["request"]["body"], req_body);
}

View File

@ -1137,4 +1137,27 @@ TEST_F(CoreAPIUtilsTest, TestProxyInvalid) {
ASSERT_EQ(400, resp->status_code);
ASSERT_EQ("Headers must be a JSON object.", nlohmann::json::parse(resp->body)["message"]);
}
// Verifies that the proxy endpoint surfaces a timeout as HTTP 408 with a
// generic remote-server error message.
TEST_F(CoreAPIUtilsTest, TestProxyTimeout) {
nlohmann::json body;
auto req = std::make_shared<http_req>();
auto resp = std::make_shared<http_res>(nullptr);
// valid url; a 1 ms timeout guarantees the upstream request cannot complete
body["url"] = "https://typesense.org/docs/";
body["method"] = "GET";
body["headers"] = nlohmann::json::object();
body["headers"]["timeout_ms"] = "1";
// NOTE(review): other call sites in this change set the "num_try" header,
// not "num_retry" — confirm which key the proxy actually reads.
body["headers"]["num_retry"] = "1";
req->body = body.dump();
post_proxy(req, resp);
ASSERT_EQ(408, resp->status_code);
ASSERT_EQ("Server error on remote server. Please try again later.", nlohmann::json::parse(resp->body)["message"]);
}