diff --git a/include/text_embedder.h b/include/text_embedder.h index 855154a7..4e1084da 100644 --- a/include/text_embedder.h +++ b/include/text_embedder.h @@ -6,33 +6,29 @@ #include #include "option.h" #include "text_embedder_tokenizer.h" +#include "text_embedder_remote.h" class TextEmbedder { public: TextEmbedder(const std::string& model_path); - TextEmbedder(const std::string& openai_model_path, const std::string& api_key); + TextEmbedder(const std::string& model_name, const std::string& api_key); ~TextEmbedder(); Option> Embed(const std::string& text); Option>> batch_embed(const std::vector& inputs); const std::string& get_vocab_file_name() const; - - bool is_openai() { - return !api_key.empty(); + bool is_remote() { + return remote_embedder_ != nullptr; } - static bool is_model_valid(const std::string& model_path, unsigned int& num_dims); - static Option is_model_valid(const std::string openai_model_path, const std::string api_key, unsigned int& num_dims); + static Option is_model_valid(const std::string model_name, const std::string api_key, unsigned int& num_dims); private: std::unique_ptr session_; Ort::Env env_; encoded_input_t Encode(const std::string& text); std::unique_ptr tokenizer_; + std::unique_ptr remote_embedder_; std::string vocab_file_name; static std::vector mean_pooling(const std::vector>& input); std::string output_tensor_name; - std::string api_key; - std::string openai_model_path; - static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models"; - static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings"; }; diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index a008ee76..d67a2c55 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -58,6 +58,7 @@ public: static const std::string get_vocab_url(const text_embedding_model& model); static Option get_public_model_config(const std::string& model_name); static const std::string get_model_name_without_namespace(const std::string& model_name); + static const std::string get_model_namespace(const std::string& model_name); static const std::string get_model_subdir(const std::string& model_name); static const bool check_md5(const std::string& file_path, const std::string& target_md5); Option download_public_model(const std::string& model_name); diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h new file mode 100644 index 00000000..660c441d --- /dev/null +++ b/include/text_embedder_remote.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include "http_client.h" +#include "option.h" + + + + +class RemoteEmbedder { + public: + virtual Option> Embed(const std::string& text) = 0; + virtual Option>> batch_embed(const std::vector& inputs) = 0; +}; + + +class OpenAIEmbedder : public RemoteEmbedder { + private: + std::string api_key; + std::string openai_model_path; + static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models"; + static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings"; + public: + OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key); + static Option is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims); + Option> Embed(const std::string& text) override; + Option>> batch_embed(const std::vector& inputs) override; +}; + + +class GoogleEmbedder : public RemoteEmbedder { + private: + // only support this model for now + inline static const char* SUPPORTED_MODEL = "embedding-gecko-001"; + inline static constexpr short GOOGLE_EMBEDDING_DIM = 768; + inline static constexpr char* GOOGLE_CREATE_EMBEDDING = "https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key="; + std::string google_api_key; + public: + GoogleEmbedder(const std::string& google_api_key); + static Option is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims); + Option> Embed(const std::string& text) override; + Option>> batch_embed(const std::vector& inputs) override; +}; + + diff --git a/src/collection.cpp b/src/collection.cpp index d9e4d5ca..46ab45b5 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1195,10 +1195,10 @@ Option Collection::search(std::string raw_query, } auto embedder = embedder_op.get(); - if(embedder->is_openai()) { + if(embedder->is_remote()) { // return error if prefix search is used with openai embedder if((prefixes.size() == 1 && prefixes[0] == true) || (prefixes.size() > 1 && prefixes[i] == true)) { - std::string error = "Prefix search is not supported for OpenAI embedder."; + std::string error = "Prefix search is not supported for remote embedders."; return Option(400, error); } } diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 0ee7cfef..c285f06e 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -36,8 +36,15 @@ TextEmbedder::TextEmbedder(const std::string& model_name) { } } -TextEmbedder::TextEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) { +TextEmbedder::TextEmbedder(const std::string& model_name, const std::string& api_key) { + LOG(INFO) << "Loading model from remote: " << model_name; + auto model_namespace = TextEmbedderManager::get_model_namespace(model_name); + if (model_namespace == "openai") { + remote_embedder_ = std::make_unique(model_name, api_key); + } else if (model_namespace == "google") { + remote_embedder_ = std::make_unique(api_key); + } } @@ -55,23 +62,8 @@ std::vector TextEmbedder::mean_pooling(const std::vector> TextEmbedder::Embed(const std::string& text) { - if(is_openai()) { - HttpClient& client = HttpClient::get_instance(); - std::unordered_map headers; - std::map res_headers; - headers["Authorization"] = "Bearer " + api_key; - headers["Content-Type"] = "application/json"; - std::string res; - nlohmann::json req_body; - req_body["input"] = text; - // remove "openai/" prefix - req_body["model"] = openai_model_path.substr(7); - auto res_code = client.post_response(TextEmbedder::OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); - if (res_code != 200) { - LOG(ERROR) << "OpenAI API error: " << res; - return Option>(400, "OpenAI API error: " + res); - } - return Option>(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); + if(is_remote()) { + return remote_embedder_->Embed(text); } else { auto encoded_input = tokenizer_->Encode(text); // create input tensor object from data values @@ -119,34 +111,13 @@ Option> TextEmbedder::Embed(const std::string& text) { Option>> TextEmbedder::batch_embed(const std::vector& inputs) { std::vector> outputs; - if(!is_openai()) { + if(!is_remote()) { // for now only openai is supported for batch embedding for(const auto& input : inputs) { outputs.push_back(Embed(input).get()); } } else { - nlohmann::json req_body; - req_body["input"] = inputs; - // remove "openai/" prefix - req_body["model"] = openai_model_path.substr(7); - std::unordered_map headers; - headers["Authorization"] = "Bearer " + api_key; - headers["Content-Type"] = "application/json"; - std::map res_headers; - std::string res; - HttpClient& client = HttpClient::get_instance(); - - auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); - - if(res_code != 200) { - LOG(ERROR) << "OpenAI API error: " << res; - return Option>>(400, res); - } - - nlohmann::json res_json = nlohmann::json::parse(res); - for(auto& data : res_json["data"]) { - outputs.push_back(data["embedding"].get>()); - } + outputs = std::move(remote_embedder_->batch_embed(inputs).get()); } return Option>>(outputs); } @@ -252,55 +223,14 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n } -Option TextEmbedder::is_model_valid(const std::string openai_model_path, const std::string api_key, unsigned int& num_dims) { - if (openai_model_path.empty() || api_key.empty() || openai_model_path.length() < 7) { - return Option(400, "Invalid OpenAI model path or API key"); - } +Option TextEmbedder::is_model_valid(const std::string model_name, const std::string api_key, unsigned int& num_dims) { + auto model_namespace = TextEmbedderManager::get_model_namespace(model_name); - HttpClient& client = HttpClient::get_instance(); - std::unordered_map headers; - std::map res_headers; - headers["Authorization"] = "Bearer " + api_key; - std::string res; - auto res_code = client.get_response(TextEmbedder::OPENAI_LIST_MODELS, res, res_headers, headers); - if (res_code != 200) { - LOG(ERROR) << "OpenAI API error: " << res; - return Option(400, "OpenAI API error: " + res); - } - - auto models_json = nlohmann::json::parse(res); - bool found = false; - // extract model name by removing "openai/" prefix - auto model_name = openai_model_path.substr(7); - for (auto& model : models_json["data"]) { - if (model["id"] == model_name) { - found = true; - break; - } - } - - if (!found) { - return Option(400, "OpenAI model not found"); - } - - // This part is hard coded for now. Because OpenAI API does not provide a way to get the output dimensions of the model. - if(model_name.find("-ada-") != std::string::npos) { - if(model_name.substr(model_name.length() - 3) == "002") { - num_dims = 1536; - } else { - num_dims = 1024; - } - } - else if(model_name.find("-davinci-") != std::string::npos) { - num_dims = 12288; - } else if(model_name.find("-curie-") != std::string::npos) { - num_dims = 4096; - } else if(model_name.find("-babbage-") != std::string::npos) { - num_dims = 2048; + if(model_namespace == "openai") { + return OpenAIEmbedder::is_model_valid(model_name, api_key, num_dims); + } else if(model_namespace == "google") { + return GoogleEmbedder::is_model_valid(model_name, api_key, num_dims); } else { - num_dims = 768; + return Option(400, "Invalid model namespace"); } - - - return Option(true); } \ No newline at end of file diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp index 579adbd3..75d472a6 100644 --- a/src/text_embedder_manager.cpp +++ b/src/text_embedder_manager.cpp @@ -254,4 +254,12 @@ const std::string TextEmbedderManager::get_model_url(const text_embedding_model& const std::string TextEmbedderManager::get_vocab_url(const text_embedding_model& model) { return MODELS_REPO_URL + model.model_name + "/" + model.vocab_file_name; +} + +const std::string TextEmbedderManager::get_model_namespace(const std::string& model_name) { + if(model_name.find("/") != std::string::npos) { + return model_name.substr(0, model_name.find("/")); + } else { + return "ts"; + } } \ No newline at end of file diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp new file mode 100644 index 00000000..8c257df8 --- /dev/null +++ b/src/text_embedder_remote.cpp @@ -0,0 +1,198 @@ +#include "text_embedder_remote.h" +#include "text_embedder_manager.h" + + + +OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key) : api_key(api_key), openai_model_path(openai_model_path) { + +} + + +Option OpenAIEmbedder::is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims) { + if (model_name.empty() || api_key.empty()) { + return Option(400, "Invalid OpenAI model name or API key"); + } + + if(TextEmbedderManager::get_model_namespace(model_name) != "openai") { + return Option(400, "Invalid OpenAI model name"); + } + + HttpClient& client = HttpClient::get_instance(); + std::unordered_map headers; + std::map res_headers; + headers["Authorization"] = "Bearer " + api_key; + std::string res; + auto res_code = client.get_response(OPENAI_LIST_MODELS, res, res_headers, headers); + if (res_code != 200) { + nlohmann::json json_res = nlohmann::json::parse(res); + if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { + return Option(400, "OpenAI API error: " + res); + } + return Option(400, "OpenAI API error: " + nlohmann::json::parse(res)["error"]["message"].get()); + } + + auto models_json = nlohmann::json::parse(res); + bool found = false; + // extract model name by removing "openai/" prefix + auto model_name_without_namespace = TextEmbedderManager::get_model_name_without_namespace(model_name); + for (auto& model : models_json["data"]) { + if (model["id"] == model_name_without_namespace) { + found = true; + break; + } + } + + if (!found) { + return Option(400, "OpenAI model not found"); + } + + // This part is hard coded for now. Because OpenAI API does not provide a way to get the output dimensions of the model. + if(model_name.find("-ada-") != std::string::npos) { + if(model_name.substr(model_name.length() - 3) == "002") { + num_dims = 1536; + } else { + num_dims = 1024; + } + } + else if(model_name.find("-davinci-") != std::string::npos) { + num_dims = 12288; + } else if(model_name.find("-curie-") != std::string::npos) { + num_dims = 4096; + } else if(model_name.find("-babbage-") != std::string::npos) { + num_dims = 2048; + } else { + num_dims = 768; + } + + return Option(true); +} + +Option> OpenAIEmbedder::Embed(const std::string& text) { + HttpClient& client = HttpClient::get_instance(); + std::unordered_map headers; + std::map res_headers; + headers["Authorization"] = "Bearer " + api_key; + headers["Content-Type"] = "application/json"; + std::string res; + nlohmann::json req_body; + req_body["input"] = text; + // remove "openai/" prefix + req_body["model"] = openai_model_path.substr(7); + auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); + if (res_code != 200) { + nlohmann::json json_res = nlohmann::json::parse(res); + if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { + return Option>(400, "OpenAI API error: " + res); + } + return Option>(400, "OpenAI API error: " + res); + } + return Option>(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); +} + +Option>> OpenAIEmbedder::batch_embed(const std::vector& inputs) { + nlohmann::json req_body; + req_body["input"] = inputs; + // remove "openai/" prefix + req_body["model"] = openai_model_path.substr(7); + std::unordered_map headers; + headers["Authorization"] = "Bearer " + api_key; + headers["Content-Type"] = "application/json"; + std::map res_headers; + std::string res; + HttpClient& client = HttpClient::get_instance(); + + auto res_code = client.post_response(OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers); + + if(res_code != 200) { + nlohmann::json json_res = nlohmann::json::parse(res); + if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { + return Option>>(400, "OpenAI API error: " + res); + } + return Option>>(400, res); + } + + nlohmann::json res_json = nlohmann::json::parse(res); + std::vector> outputs; + for(auto& data : res_json["data"]) { + outputs.push_back(data["embedding"].get>()); + } + + return Option>>(outputs); +} + + +GoogleEmbedder::GoogleEmbedder(const std::string& google_api_key) : google_api_key(google_api_key) { + +} + +Option GoogleEmbedder::is_model_valid(const std::string& model_name, const std::string& api_key, unsigned int& num_dims) { + if(model_name.empty() || api_key.empty()) { + return Option(400, "Invalid Google model name or API key"); + } + + if(TextEmbedderManager::get_model_namespace(model_name) != "google") { + return Option(400, "Invalid Google model name"); + } + + if(TextEmbedderManager::get_model_name_without_namespace(model_name) != std::string(SUPPORTED_MODEL)) { + return Option(400, "Invalid Google model name"); + } + + HttpClient& client = HttpClient::get_instance(); + std::unordered_map headers; + std::map res_headers; + headers["Content-Type"] = "application/json"; + std::string res; + nlohmann::json req_body; + req_body["text"] = "test"; + + auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + api_key, req_body.dump(), res, res_headers, headers); + + if(res_code != 200) { + nlohmann::json json_res = nlohmann::json::parse(res); + if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { + return Option(400, "Google API error: " + res); + } + return Option(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); + } + + num_dims = GOOGLE_EMBEDDING_DIM; + + return Option(true); +} + +Option> GoogleEmbedder::Embed(const std::string& text) { + HttpClient& client = HttpClient::get_instance(); + std::unordered_map headers; + std::map res_headers; + headers["Content-Type"] = "application/json"; + std::string res; + nlohmann::json req_body; + req_body["text"] = text; + + auto res_code = client.post_response(std::string(GOOGLE_CREATE_EMBEDDING) + google_api_key, req_body.dump(), res, res_headers, headers); + + if(res_code != 200) { + nlohmann::json json_res = nlohmann::json::parse(res); + if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) { + return Option>(400, "Google API error: " + res); + } + return Option>(400, "Google API error: " + nlohmann::json::parse(res)["error"]["message"].get()); + } + + return Option>(nlohmann::json::parse(res)["embedding"]["value"].get>()); +} + + +Option>> GoogleEmbedder::batch_embed(const std::vector& inputs) { + std::vector> outputs; + for(auto& input : inputs) { + auto res = Embed(input); + if(!res.ok()) { + return Option>>(res.code(), res.error()); + } + outputs.push_back(res.get()); + } + + return Option>>(outputs); +} diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 2c909a96..69a17e0e 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -5001,7 +5001,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) { } -TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) { +TEST_F(CollectionTest, CreateOpenAIEmbeddingField) { nlohmann::json schema = R"({ "name": "objects", "fields": [ @@ -5032,7 +5032,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) { ASSERT_EQ(1536, add_op.get()["embedding"].size()); } -TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) { +TEST_F(CollectionTest, HideOpenAIApiKey) { nlohmann::json schema = R"({ "name": "objects", "fields": [ @@ -5056,7 +5056,7 @@ TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) { ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*')); } -TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) { +TEST_F(CollectionTest, PrefixSearchDisabledForOpenAI) { nlohmann::json schema = R"({ "name": "objects", "fields": [ @@ -5086,10 +5086,10 @@ TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) { auto search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); ASSERT_FALSE(search_res_op.ok()); - ASSERT_EQ("Prefix search is not supported for OpenAI embedder.", search_res_op.error()); + ASSERT_EQ("Prefix search is not supported for remote embedders.", search_res_op.error()); search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); - ASSERT_FALSE(search_res_op.ok()); + ASSERT_TRUE(search_res_op.ok()); }