Adding Google remote embedding platform

This commit is contained in:
ozanarmagan 2023-05-11 13:40:12 +03:00
parent 0ca0198f1d
commit 9cbb98c1b9
8 changed files with 285 additions and 106 deletions

View File

@ -6,33 +6,29 @@
#include <vector>
#include "option.h"
#include "text_embedder_tokenizer.h"
#include "text_embedder_remote.h"
class TextEmbedder {
public:
TextEmbedder(const std::string& model_path);
TextEmbedder(const std::string& openai_model_path, const std::string& api_key);
TextEmbedder(const std::string& model_name, const std::string& api_key);
~TextEmbedder();
Option<std::vector<float>> Embed(const std::string& text);
Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs);
const std::string& get_vocab_file_name() const;
bool is_openai() {
return !api_key.empty();
bool is_remote() {
return remote_embedder_ != nullptr;
}
static bool is_model_valid(const std::string& model_path, unsigned int& num_dims);
static Option<bool> is_model_valid(const std::string openai_model_path, const std::string api_key, unsigned int& num_dims);
static Option<bool> is_model_valid(const std::string model_name, const std::string api_key, unsigned int& num_dims);
private:
std::unique_ptr<Ort::Session> session_;
Ort::Env env_;
encoded_input_t Encode(const std::string& text);
std::unique_ptr<TextEmbeddingTokenizer> tokenizer_;
std::unique_ptr<RemoteEmbedder> remote_embedder_;
std::string vocab_file_name;
static std::vector<float> mean_pooling(const std::vector<std::vector<float>>& input);
std::string output_tensor_name;
std::string api_key;
std::string openai_model_path;
static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
};

View File

@ -58,6 +58,7 @@ public:
static const std::string get_vocab_url(const text_embedding_model& model);
static Option<nlohmann::json> get_public_model_config(const std::string& model_name);
static const std::string get_model_name_without_namespace(const std::string& model_name);
static const std::string get_model_namespace(const std::string& model_name);
static const std::string get_model_subdir(const std::string& model_name);
static const bool check_md5(const std::string& file_path, const std::string& target_md5);
Option<bool> download_public_model(const std::string& model_name);

View File

@ -0,0 +1,46 @@
#pragma once

#include <vector>
#include <string>
#include "http_client.h"
#include "option.h"

// Abstract interface for embedding models served by a remote HTTP API
// (OpenAI, Google, ...). Implementations issue a network request per call
// and surface HTTP/API failures through the returned Option.
class RemoteEmbedder {
    public:
        // Virtual destructor: concrete embedders are owned polymorphically
        // (e.g. std::unique_ptr<RemoteEmbedder> in TextEmbedder), and deleting
        // through the base pointer without it is undefined behavior.
        virtual ~RemoteEmbedder() = default;

        // Embeds a single text; returns the embedding vector or an API error.
        virtual Option<std::vector<float>> Embed(const std::string& text) = 0;

        // Embeds a batch of texts; returns one vector per input, in order,
        // or an API error.
        virtual Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) = 0;
};

class OpenAIEmbedder : public RemoteEmbedder {
    private:
        std::string api_key;
        // Fully-namespaced model path, e.g. "openai/text-embedding-ada-002".
        std::string openai_model_path;
        // NOTE: literals must bind to `const char*`; `constexpr char*` (pointer
        // to non-const) is ill-formed in standard C++.
        static constexpr const char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
        static constexpr const char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
    public:
        OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key);

        // Validates the model against the account's model listing and fills
        // num_dims from the (hard-coded) model-family dimensionality table.
        static Option<bool> is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims);

        Option<std::vector<float>> Embed(const std::string& text) override;
        Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
};

class GoogleEmbedder : public RemoteEmbedder {
    private:
        // only support this model for now
        inline static const char* SUPPORTED_MODEL = "embedding-gecko-001";
        inline static constexpr short GOOGLE_EMBEDDING_DIM = 768;
        // API key is appended to this URL as a query-string parameter.
        inline static constexpr const char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=";
        std::string google_api_key;
    public:
        GoogleEmbedder(const std::string& google_api_key);

        // Validates the model name and key by issuing a probe embedText call;
        // on success num_dims is set to GOOGLE_EMBEDDING_DIM.
        static Option<bool> is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims);

        Option<std::vector<float>> Embed(const std::string& text) override;
        Option<std::vector<std::vector<float>>> batch_embed(const std::vector<std::string>& inputs) override;
};

View File

@ -1195,10 +1195,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
auto embedder = embedder_op.get();
if(embedder->is_openai()) {
if(embedder->is_remote()) {
// return error if prefix search is used with openai embedder
if((prefixes.size() == 1 && prefixes[0] == true) || (prefixes.size() > 1 && prefixes[i] == true)) {
std::string error = "Prefix search is not supported for OpenAI embedder.";
std::string error = "Prefix search is not supported for remote embedders.";
return Option<nlohmann::json>(400, error);
}
}

View File

@ -36,8 +36,15 @@ TextEmbedder::TextEmbedder(const std::string& model_name) {
}
}
TextEmbedder::TextEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) {
TextEmbedder::TextEmbedder(const std::string& model_name, const std::string& api_key) {
LOG(INFO) << "Loading model from remote: " << model_name;
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
if (model_namespace == "openai") {
remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key);
} else if (model_namespace == "google") {
remote_embedder_ = std::make_unique<GoogleEmbedder>(api_key);
}
}
@ -55,23 +62,8 @@ std::vector<float> TextEmbedder::mean_pooling(const std::vector<std::vector<floa
}
Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
if(is_openai()) {
HttpClient& client = HttpClient::get_instance();
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
headers["Content-Type"] = "application/json";
std::string res;
nlohmann::json req_body;
req_body["input"] = text;
// remove "openai/" prefix
req_body["model"] = openai_model_path.substr(7);
auto res_code = client.post_response(TextEmbedder::OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
if (res_code != 200) {
LOG(ERROR) << "OpenAI API error: " << res;
return Option<std::vector<float>>(400, "OpenAI API error: " + res);
}
return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
if(is_remote()) {
return remote_embedder_->Embed(text);
} else {
auto encoded_input = tokenizer_->Encode(text);
// create input tensor object from data values
@ -119,34 +111,13 @@ Option<std::vector<float>> TextEmbedder::Embed(const std::string& text) {
Option<std::vector<std::vector<float>>> TextEmbedder::batch_embed(const std::vector<std::string>& inputs) {
std::vector<std::vector<float>> outputs;
if(!is_openai()) {
if(!is_remote()) {
// for now only openai is supported for batch embedding
for(const auto& input : inputs) {
outputs.push_back(Embed(input).get());
}
} else {
nlohmann::json req_body;
req_body["input"] = inputs;
// remove "openai/" prefix
req_body["model"] = openai_model_path.substr(7);
std::unordered_map<std::string, std::string> headers;
headers["Authorization"] = "Bearer " + api_key;
headers["Content-Type"] = "application/json";
std::map<std::string, std::string> res_headers;
std::string res;
HttpClient& client = HttpClient::get_instance();
auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
if(res_code != 200) {
LOG(ERROR) << "OpenAI API error: " << res;
return Option<std::vector<std::vector<float>>>(400, res);
}
nlohmann::json res_json = nlohmann::json::parse(res);
for(auto& data : res_json["data"]) {
outputs.push_back(data["embedding"].get<std::vector<float>>());
}
outputs = std::move(remote_embedder_->batch_embed(inputs).get());
}
return Option<std::vector<std::vector<float>>>(outputs);
}
@ -252,55 +223,14 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n
}
Option<bool> TextEmbedder::is_model_valid(const std::string openai_model_path, const std::string api_key, unsigned int& num_dims) {
if (openai_model_path.empty() || api_key.empty() || openai_model_path.length() < 7) {
return Option<bool>(400, "Invalid OpenAI model path or API key");
}
Option<bool> TextEmbedder::is_model_valid(const std::string model_name, const std::string api_key, unsigned int& num_dims) {
auto model_namespace = TextEmbedderManager::get_model_namespace(model_name);
HttpClient& client = HttpClient::get_instance();
std::unordered_map<std::string, std::string> headers;
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
std::string res;
auto res_code = client.get_response(TextEmbedder::OPENAI_LIST_MODELS, res, res_headers, headers);
if (res_code != 200) {
LOG(ERROR) << "OpenAI API error: " << res;
return Option<bool>(400, "OpenAI API error: " + res);
}
auto models_json = nlohmann::json::parse(res);
bool found = false;
// extract model name by removing "openai/" prefix
auto model_name = openai_model_path.substr(7);
for (auto& model : models_json["data"]) {
if (model["id"] == model_name) {
found = true;
break;
}
}
if (!found) {
return Option<bool>(400, "OpenAI model not found");
}
// This part is hard coded for now. Because OpenAI API does not provide a way to get the output dimensions of the model.
if(model_name.find("-ada-") != std::string::npos) {
if(model_name.substr(model_name.length() - 3) == "002") {
num_dims = 1536;
} else {
num_dims = 1024;
}
}
else if(model_name.find("-davinci-") != std::string::npos) {
num_dims = 12288;
} else if(model_name.find("-curie-") != std::string::npos) {
num_dims = 4096;
} else if(model_name.find("-babbage-") != std::string::npos) {
num_dims = 2048;
if(model_namespace == "openai") {
return OpenAIEmbedder::is_model_valid(model_name, api_key, num_dims);
} else if(model_namespace == "google") {
return GoogleEmbedder::is_model_valid(model_name, api_key, num_dims);
} else {
num_dims = 768;
return Option<bool>(400, "Invalid model namespace");
}
return Option<bool>(true);
}

View File

@ -254,4 +254,12 @@ const std::string TextEmbedderManager::get_model_url(const text_embedding_model&
const std::string TextEmbedderManager::get_vocab_url(const text_embedding_model& model) {
return MODELS_REPO_URL + model.model_name + "/" + model.vocab_file_name;
}
// Returns the provider prefix of a model name ("openai/ada" -> "openai").
// Names without a '/' separator belong to the default Typesense ("ts")
// namespace of locally-hosted models.
const std::string TextEmbedderManager::get_model_namespace(const std::string& model_name) {
    // Fix: look up the separator once (the original called find() twice)
    // and use the char overload instead of a single-character C string.
    const auto slash_pos = model_name.find('/');
    if(slash_pos != std::string::npos) {
        return model_name.substr(0, slash_pos);
    } else {
        return "ts";
    }
}

View File

@ -0,0 +1,198 @@
#include "text_embedder_remote.h"
#include "text_embedder_manager.h"
// Stores the fully-namespaced model path (e.g. "openai/text-embedding-ada-002")
// and the API key used for every subsequent request.
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) {
    this->openai_model_path = openai_model_path;
    this->api_key = api_key;
}
// Validates an "openai/..." model name and API key against OpenAI's
// list-models endpoint and derives the embedding dimensionality from the
// model family name (the API does not report output dimensions).
// Returns 400 with a descriptive message on any validation or API failure.
Option<bool> OpenAIEmbedder::is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims) {
    if (model_name.empty() || api_key.empty()) {
        return Option<bool>(400, "Invalid OpenAI model name or API key");
    }

    if(TextEmbedderManager::get_model_namespace(model_name) != "openai") {
        return Option<bool>(400, "Invalid OpenAI model name");
    }

    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + api_key;
    std::string res;

    auto res_code = client.get_response(OPENAI_LIST_MODELS, res, res_headers, headers);

    if (res_code != 200) {
        nlohmann::json json_res = nlohmann::json::parse(res);
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<bool>(400, "OpenAI API error: " + res);
        }
        // Fix: reuse the already-parsed body instead of parsing `res` a second time.
        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    auto models_json = nlohmann::json::parse(res);
    bool found = false;
    // extract model name by removing "openai/" prefix
    auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name);
    for (auto& model : models_json["data"]) {
        if (model["id"] == model_name_without_namespace) {
            found = true;
            break;
        }
    }

    if (!found) {
        return Option<bool>(400, "OpenAI model not found");
    }

    // This part is hard coded for now. Because OpenAI API does not provide a
    // way to get the output dimensions of the model. The family substrings are
    // matched on the full namespaced name; the "openai/" prefix cannot collide
    // with any of them.
    if(model_name.find("-ada-") != std::string::npos) {
        if(model_name.substr(model_name.length() - 3) == "002") {
            num_dims = 1536;
        } else {
            num_dims = 1024;
        }
    }
    else if(model_name.find("-davinci-") != std::string::npos) {
        num_dims = 12288;
    } else if(model_name.find("-curie-") != std::string::npos) {
        num_dims = 4096;
    } else if(model_name.find("-babbage-") != std::string::npos) {
        num_dims = 2048;
    } else {
        num_dims = 768;
    }

    return Option<bool>(true);
}
// Embeds a single text via the OpenAI embeddings endpoint.
// Returns the embedding vector, or a 400 Option carrying the API's error
// message (falling back to the raw response body when it isn't parseable).
Option<std::vector<float>> OpenAIEmbedder::Embed(const std::string& text) {
    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Authorization"] = "Bearer " + api_key;
    headers["Content-Type"] = "application/json";
    std::string res;
    nlohmann::json req_body;
    req_body["input"] = text;
    // remove "openai/" prefix
    req_body["model"] = openai_model_path.substr(7);

    auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);

    if (res_code != 200) {
        nlohmann::json json_res = nlohmann::json::parse(res);
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<std::vector<float>>(400, "OpenAI API error: " + res);
        }
        // Fix: both branches previously returned the raw body, making the
        // message-presence check above a dead branch; surface the parsed
        // API message like the other OpenAI endpoints do.
        return Option<std::vector<float>>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    return Option<std::vector<float>>(nlohmann::json::parse(res)["data"][0]["embedding"].get<std::vector<float>>());
}
// Embeds all inputs with a single OpenAI API call.
// Returns one embedding per input, in request order, or a 400 Option with
// the API's error message.
Option<std::vector<std::vector<float>>> OpenAIEmbedder::batch_embed(const std::vector<std::string>& inputs) {
    nlohmann::json req_body;
    req_body["input"] = inputs;
    // remove "openai/" prefix
    req_body["model"] = openai_model_path.substr(7);
    std::unordered_map<std::string, std::string> headers;
    headers["Authorization"] = "Bearer " + api_key;
    headers["Content-Type"] = "application/json";
    std::map<std::string, std::string> res_headers;
    std::string res;
    HttpClient& client = HttpClient::get_instance();

    auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);

    if(res_code != 200) {
        nlohmann::json json_res = nlohmann::json::parse(res);
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<std::vector<std::vector<float>>>(400, "OpenAI API error: " + res);
        }
        // Fix: the original returned the raw body here with no prefix and
        // without using the parsed message, inconsistent with Embed() and
        // is_model_valid(); return the parsed API message instead.
        return Option<std::vector<std::vector<float>>>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
    }

    nlohmann::json res_json = nlohmann::json::parse(res);
    std::vector<std::vector<float>> outputs;
    // One upfront allocation: the response has exactly one entry per input.
    outputs.reserve(res_json["data"].size());
    for(auto& data : res_json["data"]) {
        outputs.push_back(data["embedding"].get<std::vector<float>>());
    }

    return Option<std::vector<std::vector<float>>>(outputs);
}
// Keeps the API key; Google's endpoint takes it as a URL query parameter,
// so it is the only per-instance state needed.
GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) {
    this->google_api_key = google_api_key;
}
// Validates a "google/..." model name and API key by issuing a one-off
// embedText probe request (Google exposes no model-listing endpoint here,
// so the probe doubles as both auth and model validation).
// On success num_dims is set to the fixed gecko dimensionality.
Option<bool> GoogleEmbedder::is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims) {
    if(model_name.empty() || api_key.empty()) {
        return Option<bool>(400, "Invalid Google model name or API key");
    }

    if(TextEmbedderManager::get_model_namespace(model_name) != "google") {
        return Option<bool>(400, "Invalid Google model name");
    }

    if(TextEmbedderManager::get_model_name_without_namespace(model_name) != std::string(SUPPORTED_MODEL)) {
        return Option<bool>(400, "Invalid Google model name");
    }

    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Content-Type"] = "application/json";
    std::string res;
    nlohmann::json req_body;
    req_body["text"] = "test";

    auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers);

    if(res_code != 200) {
        nlohmann::json json_res = nlohmann::json::parse(res);
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<bool>(400, "Google API error: " + res);
        }
        // Fix: reuse the already-parsed body instead of parsing `res` again.
        return Option<bool>(400, "Google API error: " + json_res["error"]["message"].get<std::string>());
    }

    num_dims = GOOGLE_EMBEDDING_DIM;
    return Option<bool>(true);
}
// Embeds a single text via Google's embedText endpoint; the API key is
// appended to the URL as a query parameter. Returns the embedding vector
// (response field "embedding.value") or a 400 Option with the API message.
Option<std::vector<float>> GoogleEmbedder::Embed(const std::string& text) {
    HttpClient& client = HttpClient::get_instance();
    std::unordered_map<std::string, std::string> headers;
    std::map<std::string, std::string> res_headers;
    headers["Content-Type"] = "application/json";
    std::string res;
    nlohmann::json req_body;
    req_body["text"] = text;

    auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers);

    if(res_code != 200) {
        nlohmann::json json_res = nlohmann::json::parse(res);
        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
            return Option<std::vector<float>>(400, "Google API error: " + res);
        }
        // Fix: reuse the already-parsed body instead of parsing `res` again.
        return Option<std::vector<float>>(400, "Google API error: " + json_res["error"]["message"].get<std::string>());
    }

    return Option<std::vector<float>>(nlohmann::json::parse(res)["embedding"]["value"].get<std::vector<float>>());
}
// Google's embedText endpoint accepts one text per request, so a batch is
// served by embedding each input in turn; the first failure aborts the
// whole batch and its error is propagated unchanged.
Option<std::vector<std::vector<float>>> GoogleEmbedder::batch_embed(const std::vector<std::string>& inputs) {
    std::vector<std::vector<float>> embeddings;
    for(size_t i = 0; i < inputs.size(); i++) {
        auto embed_op = Embed(inputs[i]);
        if(!embed_op.ok()) {
            return Option<std::vector<std::vector<float>>>(embed_op.code(), embed_op.error());
        }
        embeddings.push_back(embed_op.get());
    }
    return Option<std::vector<std::vector<float>>>(embeddings);
}

View File

@ -5001,7 +5001,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) {
}
TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
TEST_F(CollectionTest, CreateOpenAIEmbeddingField) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
@ -5032,7 +5032,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) {
ASSERT_EQ(1536, add_op.get()["embedding"].size());
}
TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) {
TEST_F(CollectionTest, HideOpenAIApiKey) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
@ -5056,7 +5056,7 @@ TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) {
ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get<std::string>(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'));
}
TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) {
TEST_F(CollectionTest, PrefixSearchDisabledForOpenAI) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
@ -5086,10 +5086,10 @@ TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) {
auto search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
ASSERT_FALSE(search_res_op.ok());
ASSERT_EQ("Prefix search is not supported for OpenAI embedder.", search_res_op.error());
ASSERT_EQ("Prefix search is not supported for remote embedders.", search_res_op.error());
search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, "");
ASSERT_FALSE(search_res_op.ok());
ASSERT_TRUE(search_res_op.ok());
}