diff --git a/include/text_embedder_remote.h b/include/text_embedder_remote.h
index 5ba12690..f2b0b41e 100644
--- a/include/text_embedder_remote.h
+++ b/include/text_embedder_remote.h
@@ -47,12 +47,22 @@ class OpenAIEmbedder : public RemoteEmbedder {
     private:
         std::string api_key;
         std::string openai_model_path;
-        static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
-        static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
+        // Base URL used when `model_config.url` is absent or empty.
+        static constexpr const char* OPENAI_DEFAULT_URL = "https://api.openai.com";
+        // Endpoint path appended to the (possibly custom) base URL.
+        static constexpr const char* OPENAI_CREATE_EMBEDDING = "v1/embeddings";
         bool has_custom_dims;
         size_t num_dims;
+        std::string openai_url = OPENAI_DEFAULT_URL;
+        // Joins the base URL and the embeddings path with exactly one '/' between them.
+        // An empty base falls back to OPENAI_DEFAULT_URL: this mirrors the constructor and
+        // avoids calling back() on an empty string, which is undefined behaviour.
+        static std::string get_openai_create_embedding_url(const std::string& openai_url) {
+            const std::string base = openai_url.empty() ? std::string(OPENAI_DEFAULT_URL) : openai_url;
+            return base.back() == '/' ? base + OPENAI_CREATE_EMBEDDING : base + "/" + OPENAI_CREATE_EMBEDDING;
+        }
     public:
-        OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims);
+        OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims, const std::string& openai_url);
         static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
         embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
         std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200,
diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp
index 56dfc335..804970b1 100644
--- a/src/text_embedder.cpp
+++ b/src/text_embedder.cpp
@@ -70,8 +70,9 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims,
 
     if(model_namespace == "openai") {
         auto api_key = model_config["api_key"].get<std::string>();
+        const std::string url = model_config.contains("url") ? model_config["url"].get<std::string>() : "";
 
-        remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key, num_dims, has_custom_dims);
+        remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key, num_dims, has_custom_dims, url);
     } else if(model_namespace == "google") {
         auto api_key = model_config["api_key"].get<std::string>();
 
diff --git a/src/text_embedder_remote.cpp b/src/text_embedder_remote.cpp
index 46f7193a..48df5b06 100644
--- a/src/text_embedder_remote.cpp
+++ b/src/text_embedder_remote.cpp
@@ -68,8 +68,13 @@ const std::string RemoteEmbedder::get_model_key(const nlohmann::json& model_conf
     }
 }
 
-OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims) : api_key(api_key), openai_model_path(openai_model_path), num_dims(num_dims), has_custom_dims(has_custom_dims) {
-
+OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims, const std::string& openai_url) : api_key(api_key), openai_model_path(openai_model_path),
+                                                                                                                                            num_dims(num_dims), has_custom_dims(has_custom_dims){
+    if(openai_url.empty()) {
+        this->openai_url = OPENAI_DEFAULT_URL;
+    } else {
+        this->openai_url = openai_url;
+    }
 }
 
 Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, size_t& num_dims) {
@@ -79,6 +84,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
         return validate_properties;
     }
 
+    const std::string openai_url = model_config.count("url") > 0 ? model_config["url"].get<std::string>() : OPENAI_DEFAULT_URL;
     auto model_name = model_config["model_name"].get<std::string>();
     auto api_key = model_config["api_key"].get<std::string>();
 
@@ -90,47 +96,11 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
     std::map<std::string, std::string> res_headers;
     headers["Authorization"] = "Bearer " + api_key;
     std::string res;
-    auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);
-    if(res_code == 408) {
-        return Option<bool>(408, "OpenAI API timeout.");
-    }
-
-    if (res_code != 200) {
-        nlohmann::json json_res;
-        try {
-            json_res = nlohmann::json::parse(res);
-        } catch (const std::exception& e) {
-            return Option<bool>(400, "OpenAI API error: " + res);
-        }
-        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
-            return Option<bool>(400, "OpenAI API error: " + res);
-        }
-        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
-    }
-
-    nlohmann::json models_json;
-    try {
-        models_json = nlohmann::json::parse(res);
-    } catch (const std::exception& e) {
-        return Option<bool>(400, "Got malformed response from OpenAI API.");
-    }
-    bool found = false;
-    // extract model name by removing "openai/" prefix
-    auto model_name_without_namespace = EmbedderManager::get_model_name_without_namespace(model_name);
-    for (auto& model : models_json["data"]) {
-        if (model["id"] == model_name_without_namespace) {
-            found = true;
-            break;
-        }
-    }
-
-    if (!found) {
-        return Option<bool>(400, "Property `embed.model_config.model_name` is not a valid OpenAI model.");
-    }
 
     nlohmann::json req_body;
     req_body["input"] = "typesense";
     // remove "openai/" prefix
+    auto model_name_without_namespace = EmbedderManager::get_model_name_without_namespace(model_name);
     req_body["model"] = model_name_without_namespace;
 
     if(num_dims > 0) {
@@ -139,7 +109,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
 
     std::string embedding_res;
     headers["Content-Type"] = "application/json";
-    res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
+    auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), embedding_res, res_headers, headers);
 
 
     if(res_code == 408) {
@@ -183,7 +153,7 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remo
     }
     // remove "openai/" prefix
     req_body["model"] = EmbedderManager::get_model_name_without_namespace(openai_model_path);
-    auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
+    auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), res, res_headers, headers);
     if (res_code != 200) {
         return embedding_res_t(res_code, get_error_json(req_body, res_code, res));
     }
@@ -221,7 +191,7 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
 
     std::map<std::string, std::string> res_headers;
     std::string res;
-    auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
+    auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), res, res_headers, headers);
 
     if(res_code != 200) {
         std::vector<embedding_res_t> outputs;
@@ -279,7 +249,7 @@ nlohmann::json OpenAIEmbedder::get_error_json(const nlohmann::json& req_body, lo
     nlohmann::json embedding_res = nlohmann::json::object();
     embedding_res["response"] = json_res;
     embedding_res["request"] = nlohmann::json::object();
-    embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
+    embedding_res["request"]["url"] = get_openai_create_embedding_url(openai_url);
     embedding_res["request"]["method"] = "POST";
     embedding_res["request"]["body"] = req_body;
     if(embedding_res["request"]["body"].count("input") > 0 && embedding_res["request"]["body"]["input"].get<std::vector<std::string>>().size() > 1) {
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index d6edc02b..7df81a1a 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -5243,7 +5243,7 @@ TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) {
         ]
     })"_json;
 
-    OpenAIEmbedder embedder("", "", 0, false);
+    OpenAIEmbedder embedder("", "", 0, false, "");
 
     auto res = embedder.get_error_json(req_body, 200, partial_json);
 
diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp
index a1e3e0f0..452e8a4a 100644
--- a/test/collection_vector_search_test.cpp
+++ b/test/collection_vector_search_test.cpp
@@ -4421,3 +4421,28 @@ TEST_F(CollectionVectorTest, TestCFModelResponseParsing) {
 
     ASSERT_EQ("00,\n\"publishDateYear\": 2011,\n\"title\": \"SOPA\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"transcript\": \" \"\n},\n{\n\"altTitle\": \"I'm currently getting totally blacked out.\",\n\"id\": \"1006\",\n\"imageUrl\": \"https://imgs.xkcd.com/comics/blackout.png\",\n\"publishDateDay\": 18,\n\"publishDateMonth\": 1,\n\"publishDateTimestamp\": 1326866400,\n\"publishDateYear\": 2011,\n\"title\": \"Blackout\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"", parsed_string.get<std::string>());
 }
+// Collection creation must fail with an OpenAI API error when `model_config.url` is not a usable URL.
+TEST_F(CollectionVectorTest, TestInvalidOpenAIURL) {
+    nlohmann::json schema_json = R"({
+        "name": "test",
+        "fields": [
+            {"name": "name", "type": "string"},
+            {
+                "name": "vector",
+                "type": "float[]",
+                "embed": {
+                    "from": ["name"],
+                    "model_config": {
+                        "model_name": "openai/text-embedding-3-small",
+                        "api_key": "123",
+                        "url": "invalid url"
+                    }
+                }
+            }
+        ]
+    })"_json;
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_FALSE(collection_create_op.ok());
+    ASSERT_EQ("OpenAI API error: ", collection_create_op.error());
+}