Add support for custom OpenAI API URL (#1583)

* Add support for custom OpenAI API URL

* Fix test
This commit is contained in:
Ozan Armağan 2024-02-28 12:38:37 +03:00 committed by GitHub
parent f514c42e2d
commit c30fc2791f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 46 additions and 48 deletions

View File

@ -47,12 +47,15 @@ class OpenAIEmbedder : public RemoteEmbedder {
private:
std::string api_key;
std::string openai_model_path;
static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
static constexpr char* OPENAI_CREATE_EMBEDDING = "v1/embeddings";
bool has_custom_dims;
size_t num_dims;
std::string openai_url = "https://api.openai.com";
static std::string get_openai_create_embedding_url(const std::string& openai_url) {
return openai_url.back() == '/' ? openai_url + OPENAI_CREATE_EMBEDDING : openai_url + "/" + OPENAI_CREATE_EMBEDDING;
}
public:
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims);
OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims, const std::string& openai_url);
static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200,

View File

@ -70,8 +70,9 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims,
if(model_namespace == "openai") {
auto api_key = model_config["api_key"].get<std::string>();
const std::string& url = model_config.contains("url") ? model_config["url"].get<std::string>() : "";
remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key, num_dims, has_custom_dims);
remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key, num_dims, has_custom_dims, url);
} else if(model_namespace == "google") {
auto api_key = model_config["api_key"].get<std::string>();

View File

@ -68,8 +68,13 @@ const std::string RemoteEmbedder::get_model_key(const nlohmann::json& model_conf
}
}
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims) : api_key(api_key), openai_model_path(openai_model_path), num_dims(num_dims), has_custom_dims(has_custom_dims) {
OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims, const std::string& openai_url) : api_key(api_key), openai_model_path(openai_model_path),
num_dims(num_dims), has_custom_dims(has_custom_dims){
if(openai_url.empty()) {
this->openai_url = "https://api.openai.com";
} else {
this->openai_url = openai_url;
}
}
Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, size_t& num_dims) {
@ -79,6 +84,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
return validate_properties;
}
const std::string openai_url = model_config.count("url") > 0 ? model_config["url"].get<std::string>() : "https://api.openai.com";
auto model_name = model_config["model_name"].get<std::string>();
auto api_key = model_config["api_key"].get<std::string>();
@ -90,47 +96,11 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
std::map<std::string, std::string> res_headers;
headers["Authorization"] = "Bearer " + api_key;
std::string res;
auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);
if(res_code == 408) {
return Option<bool>(408, "OpenAI API timeout.");
}
if (res_code != 200) {
nlohmann::json json_res;
try {
json_res = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "OpenAI API error: " + res);
}
if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
return Option<bool>(400, "OpenAI API error: " + res);
}
return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
}
nlohmann::json models_json;
try {
models_json = nlohmann::json::parse(res);
} catch (const std::exception& e) {
return Option<bool>(400, "Got malformed response from OpenAI API.");
}
bool found = false;
// extract model name by removing "openai/" prefix
auto model_name_without_namespace = EmbedderManager::get_model_name_without_namespace(model_name);
for (auto& model : models_json["data"]) {
if (model["id"] == model_name_without_namespace) {
found = true;
break;
}
}
if (!found) {
return Option<bool>(400, "Property `embed.model_config.model_name` is not a valid OpenAI model.");
}
nlohmann::json req_body;
req_body["input"] = "typesense";
// remove "openai/" prefix
auto model_name_without_namespace = EmbedderManager::get_model_name_without_namespace(model_name);
req_body["model"] = model_name_without_namespace;
if(num_dims > 0) {
@ -139,7 +109,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
std::string embedding_res;
headers["Content-Type"] = "application/json";
res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), embedding_res, res_headers, headers);
if(res_code == 408) {
@ -183,7 +153,7 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remo
}
// remove "openai/" prefix
req_body["model"] = EmbedderManager::get_model_name_without_namespace(openai_model_path);
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), res, res_headers, headers);
if (res_code != 200) {
return embedding_res_t(res_code, get_error_json(req_body, res_code, res));
}
@ -221,7 +191,7 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
headers["num_try"] = std::to_string(remote_embedding_num_tries);
std::map<std::string, std::string> res_headers;
std::string res;
auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), res, res_headers, headers);
if(res_code != 200) {
std::vector<embedding_res_t> outputs;
@ -279,7 +249,7 @@ nlohmann::json OpenAIEmbedder::get_error_json(const nlohmann::json& req_body, lo
nlohmann::json embedding_res = nlohmann::json::object();
embedding_res["response"] = json_res;
embedding_res["request"] = nlohmann::json::object();
embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
embedding_res["request"]["url"] = get_openai_create_embedding_url(openai_url);
embedding_res["request"]["method"] = "POST";
embedding_res["request"]["body"] = req_body;
if(embedding_res["request"]["body"].count("input") > 0 && embedding_res["request"]["body"]["input"].get<std::vector<std::string>>().size() > 1) {

View File

@ -5243,7 +5243,7 @@ TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) {
]
})"_json;
OpenAIEmbedder embedder("", "", 0, false);
OpenAIEmbedder embedder("", "", 0, false, "");
auto res = embedder.get_error_json(req_body, 200, partial_json);

View File

@ -4421,3 +4421,27 @@ TEST_F(CollectionVectorTest, TestCFModelResponseParsing) {
ASSERT_EQ("00,\n\"publishDateYear\": 2011,\n\"title\": \"SOPA\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"transcript\": \" \"\n},\n{\n\"altTitle\": \"I'm currently getting totally blacked out.\",\n\"id\": \"1006\",\n\"imageUrl\": \"https://imgs.xkcd.com/comics/blackout.png\",\n\"publishDateDay\": 18,\n\"publishDateMonth\": 1,\n\"publishDateTimestamp\": 1326866400,\n\"publishDateYear\": 2011,\n\"title\": \"Blackout\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"", parsed_string.get());
}
TEST_F(CollectionVectorTest, TestInvalidOpenAIURL) {
nlohmann::json schema_json = R"({
"name": "test",
"fields": [
{"name": "name", "type": "string"},
{
"name": "vector",
"type": "float[]",
"embed": {
"from": ["name"],
"model_config": {
"model_name": "openai/text-embedding-3-small",
"api_key": "123",
"url": "invalid url"
}
}
}
]
})"_json;
auto collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_FALSE(collection_create_op.ok());
ASSERT_EQ("OpenAI API error: ", collection_create_op.error());
}