mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 13:42:26 +08:00
Adding Google remote embedding platform
This commit is contained in:
parent
0ca0198f1d
commit
9cbb98c1b9
@ -6,33 +6,29 @@
|
||||
#include <vector>
|
||||
#include "option.h"
|
||||
#include "text_embedder_tokenizer.h"
|
||||
#include "text_embedder_remote.h"
|
||||
|
||||
|
||||
class TextEmbedder {
|
||||
public:
|
||||
TextEmbedder(const std::string& model_path);
|
||||
TextEmbedder(const std::string& openai_model_path, const std::string& api_key);
|
||||
TextEmbedder(const std::string& model_name, const std::string& api_key);
|
||||
~TextEmbedder();
|
||||
Option<std::vector<float>> Embed(const std::string& text);
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs);
|
||||
const std::string& get_vocab_file_name() const;
|
||||
|
||||
bool is_openai() {
|
||||
return !api_key.empty();
|
||||
bool is_remote() {
|
||||
return remote_embedder_ != nullptr;
|
||||
}
|
||||
|
||||
static bool is_model_valid(const std::string& model_path, unsigned int& num_dims);
|
||||
static Option<bool> is_model_valid(const std::string openai_model_path, const std::string api_key, unsigned int& num_dims);
|
||||
static Option<bool> is_model_valid(const std::string model_name, const std::string api_key, unsigned int& num_dims);
|
||||
private:
|
||||
std::unique_ptr<Ort::Session> session_;
|
||||
Ort::Env env_;
|
||||
encoded_input_t Encode(const std::string& text);
|
||||
std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
|
||||
std::unique_ptr<RemoteEmbedder> remote_embedder_;
|
||||
std::string vocab_file_name;
|
||||
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
|
||||
std::string output_tensor_name;
|
||||
std::string api_key;
|
||||
std::string openai_model_path;
|
||||
static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
|
||||
static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
|
||||
};
|
||||
|
@ -58,6 +58,7 @@ public:
|
||||
static const std::string get_vocab_url(const text_embedding_model& model);
|
||||
static Option<nlohmann::json> get_public_model_config(const std::string& model_name);
|
||||
static const std::string get_model_name_without_namespace(const std::string& model_name);
|
||||
static const std::string get_model_namespace(const std::string& model_name);
|
||||
static const std::string get_model_subdir(const std::string& model_name);
|
||||
static const bool check_md5(const std::string& file_path, const std::string& target_md5);
|
||||
Option<bool> download_public_model(const std::string& model_name);
|
||||
|
46
include/text_embedder_remote.h
Normal file
46
include/text_embedder_remote.h
Normal file
@ -0,0 +1,46 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "http_client.h"
|
||||
#include "option.h"
|
||||
|
||||
|
||||
|
||||
|
||||
class RemoteEmbedder {
|
||||
public:
|
||||
virtual Option<std::vector<float>> Embed(const std::string& text) = 0;
|
||||
virtual Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) = 0;
|
||||
};
|
||||
|
||||
|
||||
class OpenAIEmbedder : public RemoteEmbedder {
|
||||
private:
|
||||
std::string api_key;
|
||||
std::string openai_model_path;
|
||||
static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
|
||||
static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
|
||||
public:
|
||||
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);
|
||||
static Option<bool> is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
||||
class GoogleEmbedder : public RemoteEmbedder {
|
||||
private:
|
||||
// only support this model for now
|
||||
inline static const char* SUPPORTED_MODEL = "embedding-gecko-001";
|
||||
inline static constexpr short GOOGLE_EMBEDDING_DIM = 768;
|
||||
inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=";
|
||||
std::string google_api_key;
|
||||
public:
|
||||
GoogleEmbedder(const std::string& google_api_key);
|
||||
static Option<bool> is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims);
|
||||
Option<std::vector<float>> Embed(const std::string& text) override;
|
||||
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
|
||||
};
|
||||
|
||||
|
@ -1195,10 +1195,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
auto embedder = embedder_op.get();
|
||||
|
||||
if(embedder->is_openai()) {
|
||||
if(embedder->is_remote()) {
|
||||
// return error if prefix search is used with openai embedder
|
||||
if((prefixes.size() == 1 && prefixes[0] == true) || (prefixes.size() > 1 && prefixes[i] == true)) {
|
||||
std::string error = "Prefix search is not supported for OpenAI embedder.";
|
||||
std::string error = "Prefix search is not supported for remote embedders.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
}
|
||||
|
@ -36,8 +36,15 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
|
||||
}
|
||||
}
|
||||
|
||||
TextEmbedder::TextEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
|
||||
TextEmbedder::TextEmbedder(const std::string& model_name, const std::string& api_key) {
|
||||
LOG(INFO) << "Loading model from remote: " << model_name;
|
||||
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
|
||||
|
||||
if (model_namespace == "openai") {
|
||||
remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key);
|
||||
} else if (model_namespace == "google") {
|
||||
remote_embedder_ = std::make_unique<GoogleEmbedder>(api_key);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -55,23 +62,8 @@ std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<floa
|
||||
}
|
||||
|
||||
Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
if(is_openai()) {
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
headers["Content-Type"] = "application/json";
|
||||
std::string res;
|
||||
nlohmann::json req_body;
|
||||
req_body["input"] = text;
|
||||
// remove "openai/" prefix
|
||||
req_body["model"] = openai_model_path.substr(7);
|
||||
auto res_code = client.post_response(TextEmbedder::OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
|
||||
if (res_code != 200) {
|
||||
LOG(ERROR) << "OpenAI API error: " << res;
|
||||
return Option<std::vector<float>>(400, "OpenAI API error: " + res);
|
||||
}
|
||||
return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
|
||||
if(is_remote()) {
|
||||
return remote_embedder_->Embed(text);
|
||||
} else {
|
||||
auto encoded_input = tokenizer_->Encode(text);
|
||||
// create input tensor object from data values
|
||||
@ -119,34 +111,13 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
|
||||
|
||||
// Embeds every string in `inputs`.
// Remote embedders handle the whole batch themselves; local models are run
// one input at a time through Embed().
// Returns one embedding per input, or the first error encountered.
Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
    if(is_remote()) {
        // Hand the remote result back as-is so that API failures reach the
        // caller instead of being unwrapped with get() on a failed Option.
        return remote_embedder_->batch_embed(inputs);
    }

    std::vector<std::vector<float>> outputs;
    outputs.reserve(inputs.size());

    for(const auto& input : inputs) {
        auto embedding_op = Embed(input);
        if(!embedding_op.ok()) {
            return Option<std::vector<std::vector<float>>>(embedding_op.code(), embedding_op.error());
        }
        outputs.push_back(embedding_op.get());
    }

    return Option<std::vector<std::vector<float>>>(outputs);
}
|
||||
@ -252,55 +223,14 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n
|
||||
}
|
||||
|
||||
|
||||
Option<bool> TextEmbedder::is_model_valid(const std::string openai_model_path, const std::string api_key, unsigned int& num_dims) {
|
||||
if (openai_model_path.empty() || api_key.empty() || openai_model_path.length() < 7) {
|
||||
return Option<bool>(400, "Invalid OpenAI model path or API key");
|
||||
}
|
||||
Option<bool> TextEmbedder::is_model_valid(const std::string model_name, const std::string api_key, unsigned int& num_dims) {
|
||||
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
|
||||
|
||||
HttpClient& client = HttpClient::get_instance();
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
std::map<std::string, std::string> res_headers;
|
||||
headers["Authorization"] = "Bearer " + api_key;
|
||||
std::string res;
|
||||
auto res_code = client.get_response(TextEmbedder::OPENAI_LIST_MODELS, res, res_headers, headers);
|
||||
if (res_code != 200) {
|
||||
LOG(ERROR) << "OpenAI API error: " << res;
|
||||
return Option<bool>(400, "OpenAI API error: " + res);
|
||||
}
|
||||
|
||||
auto models_json = nlohmann::json::parse(res);
|
||||
bool found = false;
|
||||
// extract model name by removing "openai/" prefix
|
||||
auto model_name = openai_model_path.substr(7);
|
||||
for (auto& model : models_json["data"]) {
|
||||
if (model["id"] == model_name) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
return Option<bool>(400, "OpenAI model not found");
|
||||
}
|
||||
|
||||
// This part is hard coded for now. Because OpenAI API does not provide a way to get the output dimensions of the model.
|
||||
if(model_name.find("-ada-") != std::string::npos) {
|
||||
if(model_name.substr(model_name.length() - 3) == "002") {
|
||||
num_dims = 1536;
|
||||
} else {
|
||||
num_dims = 1024;
|
||||
}
|
||||
}
|
||||
else if(model_name.find("-davinci-") != std::string::npos) {
|
||||
num_dims = 12288;
|
||||
} else if(model_name.find("-curie-") != std::string::npos) {
|
||||
num_dims = 4096;
|
||||
} else if(model_name.find("-babbage-") != std::string::npos) {
|
||||
num_dims = 2048;
|
||||
if(model_namespace == "openai") {
|
||||
return OpenAIEmbedder::is_model_valid(model_name, api_key, num_dims);
|
||||
} else if(model_namespace == "google") {
|
||||
return GoogleEmbedder::is_model_valid(model_name, api_key, num_dims);
|
||||
} else {
|
||||
num_dims = 768;
|
||||
return Option<bool>(400, "Invalid model namespace");
|
||||
}
|
||||
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
@ -254,4 +254,12 @@ const std::string TextEmbedderManager::get_model_url(const text_embedding_model&
|
||||
|
||||
const std::string TextEmbedderManager::get_vocab_url(const text_embedding_model& model) {
|
||||
return MODELS_REPO_URL + model.model_name + "/" + model.vocab_file_name;
|
||||
}
|
||||
|
||||
// Returns the part of `model_name` that precedes the first "/".
// Names without a "/" belong to the default "ts" namespace.
const std::string TextEmbedderManager::get_model_namespace(const std::string& model_name) {
    const auto separator_pos = model_name.find("/");

    if(separator_pos == std::string::npos) {
        return "ts";
    }

    return model_name.substr(0, separator_pos);
}
|
198
src/text_embedder_remote.cpp
Normal file
198
src/text_embedder_remote.cpp
Normal file
@ -0,0 +1,198 @@
|
||||
#include "text_embedder_remote.h"
|
||||
#include "text_embedder_manager.h"
|
||||
|
||||
|
||||
|
||||
// Stores the API key and the namespaced model path used for later requests.
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key)
    : api_key(api_key),
      openai_model_path(openai_model_path) {}
|
||||
|
||||
|
||||
// Validates an "openai/<model>" name and API key against the OpenAI model list.
// On success writes the embedding dimensionality to `num_dims` and returns true;
// otherwise returns a 400 Option describing the problem.
Option<bool> OpenAIEmbedder::is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims) {
    if (model_name.empty() || api_key.empty()) {
        return Option<bool>(400, "Invalid OpenAI model name or API key");
    }

    if(TextEmbedderManager::get_model_namespace(model_name) != "openai") {
        return Option<bool>(400, "Invalid OpenAI model name");
    }

    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + api_key;
    std::string res;
    auto res_code = client.get_response(OPENAI_LIST_MODELS, res, res_headers, headers);

    if (res_code != 200) {
        // The body of a failed request is not guaranteed to be JSON (it may be
        // empty or plain text), so fall back to the raw body when the
        // structured error message cannot be extracted.
        std::string message = res;
        try {
            const auto error_json = nlohmann::json::parse(res);
            message = error_json.at("error").at("message").get<std::string>();
        } catch (const nlohmann::json::exception&) {
            // keep the raw response body as the message
        }
        return Option<bool>(400, "OpenAI API error: " + message);
    }

    // extract model name by removing "openai/" prefix
    auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name);

    bool found = false;
    try {
        const auto models_json = nlohmann::json::parse(res);
        for (const auto& model : models_json.at("data")) {
            if (model.is_object() && model.count("id") != 0 && model["id"] == model_name_without_namespace) {
                found = true;
                break;
            }
        }
    } catch (const nlohmann::json::exception&) {
        // A 200 response that is not the expected JSON document.
        return Option<bool>(400, "OpenAI API error: " + res);
    }

    if (!found) {
        return Option<bool>(400, "OpenAI model not found");
    }

    // This part is hard coded for now. Because OpenAI API does not provide a way to get the output dimensions of the model.
    if(model_name.find("-ada-") != std::string::npos) {
        if(model_name.substr(model_name.length() - 3) == "002") {
            num_dims = 1536;
        } else {
            num_dims = 1024;
        }
    } else if(model_name.find("-davinci-") != std::string::npos) {
        num_dims = 12288;
    } else if(model_name.find("-curie-") != std::string::npos) {
        num_dims = 4096;
    } else if(model_name.find("-babbage-") != std::string::npos) {
        num_dims = 2048;
    } else {
        num_dims = 768;
    }

    return Option<bool>(true);
}
|
||||
|
||||
// Requests the embedding of `text` from the OpenAI embeddings endpoint.
// Returns the embedding vector, or a 400 Option carrying the API error.
Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + api_key;
    headers["Content-Type"] = "application/json";
    std::string res;
    nlohmann::json req_body;
    req_body["input"] = text;
    // remove "openai/" prefix
    req_body["model"] = openai_model_path.substr(7);
    auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);

    if (res_code != 200) {
        // Prefer the structured error message; the body may be empty or not
        // JSON at all, in which case the raw body is reported instead.
        std::string message = res;
        try {
            const auto error_json = nlohmann::json::parse(res);
            message = error_json.at("error").at("message").get<std::string>();
        } catch (const nlohmann::json::exception&) {
            // keep the raw response body as the message
        }
        return Option<std::vector<float>>(400, "OpenAI API error: " + message);
    }

    try {
        const auto res_json = nlohmann::json::parse(res);
        return Option<std::vector<float>>(res_json.at("data").at(0).at("embedding").get<std::vector<float>>());
    } catch (const nlohmann::json::exception&) {
        // A 200 response without the expected data[0].embedding array.
        return Option<std::vector<float>>(400, "OpenAI API error: " + res);
    }
}
|
||||
|
||||
// Requests embeddings for all `inputs` in a single call to the OpenAI
// embeddings endpoint.
// Returns the embeddings in the order they appear in the response "data"
// array, or a 400 Option carrying the API error.
Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
    nlohmann::json req_body;
    req_body["input"] = inputs;
    // remove "openai/" prefix
    req_body["model"] = openai_model_path.substr(7);
    std::unordered_map<std::string, std::string> headers;
    headers["Authorization"] = "Bearer " + api_key;
    headers["Content-Type"] = "application/json";
    std::map<std::string, std::string> res_headers;
    std::string res;
    HttpClient& client = HttpClient::get_instance();

    auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);

    if(res_code != 200) {
        // Prefer the structured error message; the body may be empty or not
        // JSON at all, in which case the raw body is reported instead.
        std::string message = res;
        try {
            const auto error_json = nlohmann::json::parse(res);
            message = error_json.at("error").at("message").get<std::string>();
        } catch (const nlohmann::json::exception&) {
            // keep the raw response body as the message
        }
        return Option<std::vector<std::vector<float>>>(400, "OpenAI API error: " + message);
    }

    std::vector<std::vector<float>> outputs;
    try {
        const auto res_json = nlohmann::json::parse(res);
        const auto& data = res_json.at("data");
        outputs.reserve(data.size());
        for(const auto& item : data) {
            outputs.push_back(item.at("embedding").get<std::vector<float>>());
        }
    } catch (const nlohmann::json::exception&) {
        // A 200 response without the expected data[*].embedding arrays.
        return Option<std::vector<std::vector<float>>>(400, "OpenAI API error: " + res);
    }

    return Option<std::vector<std::vector<float>>>(outputs);
}
|
||||
|
||||
|
||||
// Stores the API key that is appended to every embedding request URL.
GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key)
    : google_api_key(google_api_key) {}
|
||||
|
||||
// Validates a "google/<model>" name and API key.
// Only SUPPORTED_MODEL is accepted; the key is checked by sending a probe
// embedding request. On success writes GOOGLE_EMBEDDING_DIM to `num_dims`.
Option<bool> GoogleEmbedder::is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims) {
    if(model_name.empty() || api_key.empty()) {
        return Option<bool>(400, "Invalid Google model name or API key");
    }

    if(TextEmbedderManager::get_model_namespace(model_name) != "google") {
        return Option<bool>(400, "Invalid Google model name");
    }

    if(TextEmbedderManager::get_model_name_without_namespace(model_name) != std::string(SUPPORTED_MODEL)) {
        return Option<bool>(400, "Invalid Google model name");
    }

    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Content-Type"] = "application/json";
    std::string res;
    nlohmann::json req_body;
    req_body["text"] = "test";

    auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers);

    if(res_code != 200) {
        // Prefer the structured error message; the body may be empty or not
        // JSON at all, in which case the raw body is reported instead.
        std::string message = res;
        try {
            const auto error_json = nlohmann::json::parse(res);
            message = error_json.at("error").at("message").get<std::string>();
        } catch (const nlohmann::json::exception&) {
            // keep the raw response body as the message
        }
        return Option<bool>(400, "Google API error: " + message);
    }

    num_dims = GOOGLE_EMBEDDING_DIM;

    return Option<bool>(true);
}
|
||||
|
||||
// Requests the embedding of `text` from the Google embedText endpoint.
// Returns the embedding vector, or a 400 Option carrying the API error.
Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Content-Type"] = "application/json";
    std::string res;
    nlohmann::json req_body;
    req_body["text"] = text;

    auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers);

    if(res_code != 200) {
        // Prefer the structured error message; the body may be empty or not
        // JSON at all, in which case the raw body is reported instead.
        std::string message = res;
        try {
            const auto error_json = nlohmann::json::parse(res);
            message = error_json.at("error").at("message").get<std::string>();
        } catch (const nlohmann::json::exception&) {
            // keep the raw response body as the message
        }
        return Option<std::vector<float>>(400, "Google API error: " + message);
    }

    try {
        const auto res_json = nlohmann::json::parse(res);
        return Option<std::vector<float>>(res_json.at("embedding").at("value").get<std::vector<float>>());
    } catch (const nlohmann::json::exception&) {
        // A 200 response without the expected embedding.value array.
        return Option<std::vector<float>>(400, "Google API error: " + res);
    }
}
|
||||
|
||||
|
||||
// Embeds each input with a separate Embed() call, in order.
// Stops at the first failure and returns that error's code and message.
Option<std::vector<std::vector<float>>> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
    std::vector<std::vector<float>> embeddings;

    for(size_t idx = 0; idx < inputs.size(); ++idx) {
        auto embed_op = Embed(inputs[idx]);
        if(embed_op.ok()) {
            embeddings.push_back(embed_op.get());
            continue;
        }
        return Option<std::vector<std::vector<float>>>(embed_op.code(), embed_op.error());
    }

    return Option<std::vector<std::vector<float>>>(embeddings);
}
|
@ -5001,7 +5001,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) {
|
||||
}
|
||||
|
||||
|
||||
TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
TEST_F(CollectionTest, CreateOpenAIEmbeddingField) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
@ -5032,7 +5032,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
|
||||
ASSERT_EQ(1536, add_op.get()["embedding"].size());
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) {
|
||||
TEST_F(CollectionTest, HideOpenAIApiKey) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
@ -5056,7 +5056,7 @@ TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) {
|
||||
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'));
|
||||
}
|
||||
|
||||
TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) {
|
||||
TEST_F(CollectionTest, PrefixSearchDisabledForOpenAI) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
@ -5086,10 +5086,10 @@ TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) {
|
||||
auto search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
ASSERT_EQ("Prefix search is not supported for OpenAI embedder.", search_res_op.error());
|
||||
ASSERT_EQ("Prefix search is not supported for remote embedders.", search_res_op.error());
|
||||
|
||||
search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
ASSERT_TRUE(search_res_op.ok());
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user