Add support for custom OpenAI API URL (#1583)
* Add support for custom OpenAI API URL
* Fix test

parent f514c42e2d
commit c30fc2791f
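With this change, an embedding field's model_config accepts an optional "url" key that points the OpenAI embedder at any OpenAI-compatible endpoint; when the key is absent, requests keep going to https://api.openai.com. A minimal sketch of such a schema follows (the collection name, api_key value, and proxy URL are placeholders, not part of this commit):

#include <nlohmann/json.hpp>

// Hypothetical schema: only the "url" key inside model_config reflects the option
// added by this commit; the rest mirrors the existing embedding field syntax.
nlohmann::json schema_json = R"({
    "name": "docs",
    "fields": [
        {"name": "name", "type": "string"},
        {
            "name": "vector",
            "type": "float[]",
            "embed": {
                "from": ["name"],
                "model_config": {
                    "model_name": "openai/text-embedding-3-small",
                    "api_key": "placeholder-key",
                    "url": "https://my-openai-proxy.example.com"
                }
            }
        }
    ]
})"_json;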
@@ -47,12 +47,15 @@ class OpenAIEmbedder : public RemoteEmbedder {
     private:
         std::string api_key;
         std::string openai_model_path;
         static constexpr char* OPENAI_LIST_MODELS = "https://api.openai.com/v1/models";
-        static constexpr char* OPENAI_CREATE_EMBEDDING = "https://api.openai.com/v1/embeddings";
+        static constexpr char* OPENAI_CREATE_EMBEDDING = "v1/embeddings";
         bool has_custom_dims;
         size_t num_dims;
+        std::string openai_url = "https://api.openai.com";
+        static std::string get_openai_create_embedding_url(const std::string& openai_url) {
+            return openai_url.back() == '/' ? openai_url + OPENAI_CREATE_EMBEDDING : openai_url + "/" + OPENAI_CREATE_EMBEDDING;
+        }
     public:
-        OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims);
+        OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims, const std::string& openai_url);
         static Option<bool> is_model_valid(const nlohmann::json& model_config, size_t& num_dims);
         embedding_res_t Embed(const std::string& text, const size_t remote_embedder_timeout_ms = 30000, const size_t remote_embedding_num_tries = 2) override;
         std::vector<embedding_res_t> batch_embed(const std::vector<std::string>& inputs, const size_t remote_embedding_batch_size = 200,
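The relative OPENAI_CREATE_EMBEDDING path is joined onto whichever base URL is configured, inserting a "/" only when the base URL does not already end with one. A standalone sketch of that joining rule follows (the function and variable names are illustrative, not the class's private helper):

#include <iostream>
#include <string>

// Mirrors the rule in get_openai_create_embedding_url(): append "v1/embeddings"
// to the base URL, adding a separating slash only if one is not already present.
static std::string join_embedding_url(const std::string& base_url) {
    const std::string path = "v1/embeddings";
    return base_url.back() == '/' ? base_url + path : base_url + "/" + path;
}

int main() {
    std::cout << join_embedding_url("https://api.openai.com") << "\n";  // https://api.openai.com/v1/embeddings
    std::cout << join_embedding_url("https://example.com/") << "\n";    // https://example.com/v1/embeddings
    return 0;
}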
@@ -70,8 +70,9 @@ TextEmbedder::TextEmbedder(const nlohmann::json& model_config, size_t num_dims,

     if(model_namespace == "openai") {
         auto api_key = model_config["api_key"].get<std::string>();
+        const std::string& url = model_config.contains("url") ? model_config["url"].get<std::string>() : "";

-        remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key, num_dims, has_custom_dims);
+        remote_embedder_ = std::make_unique<OpenAIEmbedder>(model_name, api_key, num_dims, has_custom_dims, url);
     } else if(model_namespace == "google") {
         auto api_key = model_config["api_key"].get<std::string>();
@@ -68,8 +68,13 @@ const std::string RemoteEmbedder::get_model_key(const nlohmann::json& model_conf
     }
 }

-OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims) : api_key(api_key), openai_model_path(openai_model_path), num_dims(num_dims), has_custom_dims(has_custom_dims) {
-
+OpenAIEmbedder::OpenAIEmbedder(const std::string& openai_model_path, const std::string& api_key, const size_t num_dims, const bool has_custom_dims, const std::string& openai_url) : api_key(api_key), openai_model_path(openai_model_path),
+                               num_dims(num_dims), has_custom_dims(has_custom_dims){
+    if(openai_url.empty()) {
+        this->openai_url = "https://api.openai.com";
+    } else {
+        this->openai_url = openai_url;
+    }
 }

 Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config, size_t& num_dims) {
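An empty url argument keeps the stock OpenAI base URL, so callers that never set model_config.url behave exactly as before. A brief sketch of the two cases (the key and proxy URL are placeholders):

// Falls back to https://api.openai.com internally.
OpenAIEmbedder default_embedder("openai/text-embedding-3-small", "placeholder-key", 0, false, "");

// Sends embedding requests to the supplied OpenAI-compatible endpoint instead.
OpenAIEmbedder proxied_embedder("openai/text-embedding-3-small", "placeholder-key", 0, false, "https://my-openai-proxy.example.com");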
@@ -79,6 +84,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
         return validate_properties;
     }

+    const std::string openai_url = model_config.count("url") > 0 ? model_config["url"].get<std::string>() : "https://api.openai.com";
     auto model_name = model_config["model_name"].get<std::string>();
     auto api_key = model_config["api_key"].get<std::string>();

@@ -90,47 +96,11 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,
     std::map<std::string, std::string> res_headers;
     headers["Authorization"] = "Bearer " + api_key;
     std::string res;
-    auto res_code = call_remote_api("GET", OPENAI_LIST_MODELS, "", res, res_headers, headers);
-
-    if(res_code == 408) {
-        return Option<bool>(408, "OpenAI API timeout.");
-    }
-
-    if (res_code != 200) {
-        nlohmann::json json_res;
-        try {
-            json_res = nlohmann::json::parse(res);
-        } catch (const std::exception& e) {
-            return Option<bool>(400, "OpenAI API error: " + res);
-        }
-        if(json_res.count("error") == 0 || json_res["error"].count("message") == 0) {
-            return Option<bool>(400, "OpenAI API error: " + res);
-        }
-        return Option<bool>(400, "OpenAI API error: " + json_res["error"]["message"].get<std::string>());
-    }
-
-    nlohmann::json models_json;
-    try {
-        models_json = nlohmann::json::parse(res);
-    } catch (const std::exception& e) {
-        return Option<bool>(400, "Got malformed response from OpenAI API.");
-    }
-    bool found = false;
-    // extract model name by removing "openai/" prefix
-    auto model_name_without_namespace = EmbedderManager::get_model_name_without_namespace(model_name);
-    for (auto& model : models_json["data"]) {
-        if (model["id"] == model_name_without_namespace) {
-            found = true;
-            break;
-        }
-    }
-
-    if (!found) {
-        return Option<bool>(400, "Property `embed.model_config.model_name` is not a valid OpenAI model.");
-    }
     nlohmann::json req_body;
     req_body["input"] = "typesense";
+    // remove "openai/" prefix
+    auto model_name_without_namespace = EmbedderManager::get_model_name_without_namespace(model_name);
     req_body["model"] = model_name_without_namespace;

     if(num_dims > 0) {
@@ -139,7 +109,7 @@ Option<bool> OpenAIEmbedder::is_model_valid(const nlohmann::json& model_config,

     std::string embedding_res;
     headers["Content-Type"] = "application/json";
-    res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), embedding_res, res_headers, headers);
+    auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), embedding_res, res_headers, headers);


     if(res_code == 408) {
@@ -183,7 +153,7 @@ embedding_res_t OpenAIEmbedder::Embed(const std::string& text, const size_t remo
     }
     // remove "openai/" prefix
     req_body["model"] = EmbedderManager::get_model_name_without_namespace(openai_model_path);
-    auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
+    auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), res, res_headers, headers);
     if (res_code != 200) {
         return embedding_res_t(res_code, get_error_json(req_body, res_code, res));
     }
@@ -221,7 +191,7 @@ std::vector<embedding_res_t> OpenAIEmbedder::batch_embed(const std::vector<std::
     headers["num_try"] = std::to_string(remote_embedding_num_tries);
     std::map<std::string, std::string> res_headers;
     std::string res;
-    auto res_code = call_remote_api("POST", OPENAI_CREATE_EMBEDDING, req_body.dump(), res, res_headers, headers);
+    auto res_code = call_remote_api("POST", get_openai_create_embedding_url(openai_url), req_body.dump(), res, res_headers, headers);

     if(res_code != 200) {
         std::vector<embedding_res_t> outputs;
@@ -279,7 +249,7 @@ nlohmann::json OpenAIEmbedder::get_error_json(const nlohmann::json& req_body, lo
     nlohmann::json embedding_res = nlohmann::json::object();
     embedding_res["response"] = json_res;
     embedding_res["request"] = nlohmann::json::object();
-    embedding_res["request"]["url"] = OPENAI_CREATE_EMBEDDING;
+    embedding_res["request"]["url"] = get_openai_create_embedding_url(openai_url);
     embedding_res["request"]["method"] = "POST";
     embedding_res["request"]["body"] = req_body;
     if(embedding_res["request"]["body"].count("input") > 0 && embedding_res["request"]["body"]["input"].get<std::vector<std::string>>().size() > 1) {
@@ -5243,7 +5243,7 @@ TEST_F(CollectionTest, CatchPartialResponseFromRemoteEmbedding) {
         ]
     })"_json;

-    OpenAIEmbedder embedder("", "", 0, false);
+    OpenAIEmbedder embedder("", "", 0, false, "");

     auto res = embedder.get_error_json(req_body, 200, partial_json);
@@ -4421,3 +4421,27 @@ TEST_F(CollectionVectorTest, TestCFModelResponseParsing) {
     ASSERT_EQ("00,\n\"publishDateYear\": 2011,\n\"title\": \"SOPA\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"transcript\": \" \"\n},\n{\n\"altTitle\": \"I'm currently getting totally blacked out.\",\n\"id\": \"1006\",\n\"imageUrl\": \"https://imgs.xkcd.com/comics/blackout.png\",\n\"publishDateDay\": 18,\n\"publishDateMonth\": 1,\n\"publishDateTimestamp\": 1326866400,\n\"publishDateYear\": 2011,\n\"title\": \"Blackout\",\n\"topics\": [\n\"Links to xkcd.com\",\n\"April fools' comics\",\n\"Interactive comics\",\n\"Comics with animation\",\n\"Dynamic comics\",\n\"Comics with audio\"\n ],\n\"", parsed_string.get());
 }

+TEST_F(CollectionVectorTest, TestInvalidOpenAIURL) {
+    nlohmann::json schema_json = R"({
+        "name": "test",
+        "fields": [
+            {"name": "name", "type": "string"},
+            {
+                "name": "vector",
+                "type": "float[]",
+                "embed": {
+                    "from": ["name"],
+                    "model_config": {
+                        "model_name": "openai/text-embedding-3-small",
+                        "api_key": "123",
+                        "url": "invalid url"
+                    }
+                }
+            }
+        ]
+    })"_json;
+
+    auto collection_create_op = collectionManager.create_collection(schema_json);
+    ASSERT_FALSE(collection_create_op.ok());
+    ASSERT_EQ("OpenAI API error: ", collection_create_op.error());
+}
|
Loading…
x
Reference in New Issue
Block a user