diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index 04f5e3e6..a008ee76 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -8,6 +8,7 @@ #include #include "logger.h" #include "http_client.h" +#include "option.h" #include "text_embedder.h" struct text_embedding_model { @@ -34,7 +35,7 @@ public: TextEmbedderManager(const TextEmbedderManager&) = delete; TextEmbedderManager& operator=(const TextEmbedderManager&) = delete; - TextEmbedder* get_text_embedder(const nlohmann::json& model_config); + Option get_text_embedder(const nlohmann::json& model_config); void delete_text_embedder(const std::string& model_path); void delete_all_text_embedders(); @@ -59,7 +60,7 @@ public: static const std::string get_model_name_without_namespace(const std::string& model_name); static const std::string get_model_subdir(const std::string& model_name); static const bool check_md5(const std::string& file_path, const std::string& target_md5); - void download_public_model(const std::string& model_name); + Option download_public_model(const std::string& model_name); const bool is_public_model(const std::string& model_name); diff --git a/src/collection.cpp b/src/collection.cpp index fe1cbb63..d9e4d5ca 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -242,13 +242,13 @@ nlohmann::json Collection::get_summary_json() const { if(coll_field.embed.count(fields::from) != 0) { field_json[fields::embed] = coll_field.embed; - if(field_json[fields::embed].count(fields::api_key) != 0) { + if(field_json[fields::embed].count(fields::model_config) != 0 && field_json[fields::embed][fields::model_config].count(fields::api_key) != 0) { // hide api key with * except first 3 chars - std::string api_key = field_json[fields::embed][fields::api_key]; + std::string api_key = field_json[fields::embed][fields::model_config][fields::api_key]; if(api_key.size() > 3) { - field_json[fields::embed][fields::api_key] = api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'); + field_json[fields::embed][fields::model_config][fields::api_key] = api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*'); } else { - field_json[fields::embed][fields::api_key] = api_key.replace(0, api_key.size(), api_key.size(), '*'); + field_json[fields::embed][fields::model_config][fields::api_key] = api_key.replace(0, api_key.size(), api_key.size(), '*'); } } } @@ -1189,7 +1189,19 @@ Option Collection::search(std::string raw_query, } TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(search_field.embed[fields::model_config]); + auto embedder_op = embedder_manager.get_text_embedder(search_field.embed[fields::model_config]); + if(!embedder_op.ok()) { + return Option(400, embedder_op.error()); + } + auto embedder = embedder_op.get(); + + if(embedder->is_openai()) { + // return error if prefix search is used with openai embedder + if((prefixes.size() == 1 && prefixes[0] == true) || (prefixes.size() > 1 && prefixes[i] == true)) { + std::string error = "Prefix search is not supported for OpenAI embedder."; + return Option(400, error); + } + } std::string embed_query = embedder_manager.get_query_prefix(search_field.embed[fields::model_config]) + raw_query; auto embedding_op = embedder->Embed(embed_query); diff --git a/src/index.cpp b/src/index.cpp index 013ab0b9..c8971785 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6365,8 +6365,13 @@ Option Index::batch_embed_fields(std::vector& documents, text_to_embed.push_back(text); } TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.embed[fields::model_config]); - auto embedding_op = embedder->batch_embed(text_to_embed); + auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]); + + if(!embedder_op.ok()) { + return Option(400, embedder_op.error()); + } + + auto embedding_op = embedder_op.get()->batch_embed(text_to_embed); if(!embedding_op.ok()) { return Option(400, embedding_op.error()); diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index 10647d09..775edf2a 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -73,7 +73,6 @@ Option> TextEmbedder::Embed(const std::string& text) { } return Option>(nlohmann::json::parse(res)["data"][0]["embedding"].get>()); } else { - LOG(INFO) << "Embedding text: " << text; auto encoded_input = tokenizer_->Encode(text); // create input tensor object from data values Ort::AllocatorWithDefaultOptions allocator; @@ -160,7 +159,11 @@ bool TextEmbedder::is_model_valid(const std::string& model_name, unsigned int& n LOG(INFO) << "Loading model: " << model_name; if(TextEmbedderManager::get_instance().is_public_model(model_name)) { - TextEmbedderManager::get_instance().download_public_model(model_name); + auto res = TextEmbedderManager::get_instance().download_public_model(model_name); + if(!res.ok()) { + LOG(ERROR) << res.error(); + return false; + } } diff --git a/src/text_embedder_manager.cpp b/src/text_embedder_manager.cpp index 0db49149..579adbd3 100644 --- a/src/text_embedder_manager.cpp +++ b/src/text_embedder_manager.cpp @@ -6,21 +6,24 @@ TextEmbedderManager& TextEmbedderManager::get_instance() { return instance; } -TextEmbedder* TextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) { +OptionTextEmbedderManager::get_text_embedder(const nlohmann::json& model_config) { std::unique_lock lock(text_embedders_mutex); const std::string& model_name = model_config.at("model_name"); if(text_embedders[model_name] == nullptr) { if(model_config.count("api_key") == 0) { if(is_public_model(model_name)) { // download the model if it doesn't exist - download_public_model(model_name); + auto res = download_public_model(model_name); + if(!res.ok()) { + return Option(res.code(), res.error()); + } } text_embedders[model_name] = std::make_shared(get_model_name_without_namespace(model_name)); } else { text_embedders[model_name] = std::make_shared(model_name, model_config.at("api_key").get()); } } - return text_embedders[model_name].get(); + return Option(text_embedders[model_name].get()); } void TextEmbedderManager::delete_text_embedder(const std::string& model_path) { @@ -119,7 +122,7 @@ const bool TextEmbedderManager::check_md5(const std::string& file_path, const st } return res.str() == target_md5; } -void TextEmbedderManager::download_public_model(const std::string& model_name) { +Option TextEmbedderManager::download_public_model(const std::string& model_name) { HttpClient& httpClient = HttpClient::get_instance(); auto model = public_models[model_name]; auto actual_model_name = get_model_name_without_namespace(model_name); @@ -127,6 +130,7 @@ void TextEmbedderManager::download_public_model(const std::string& model_name) { long res = httpClient.download_file(get_model_url(model), get_absolute_model_path(actual_model_name)); if(res != 200) { LOG(INFO) << "Failed to download public model " << model_name << ": " << res; + return Option(400, "Failed to download model file"); } } @@ -134,9 +138,11 @@ void TextEmbedderManager::download_public_model(const std::string& model_name) { long res = httpClient.download_file(get_vocab_url(model), get_absolute_vocab_path(actual_model_name, model.vocab_file_name)); if(res != 200) { LOG(INFO) << "Failed to download default vocab " << model_name << ": " << res; + return Option(400, "Failed to download vocab file"); } } + return Option(true); } const bool TextEmbedderManager::is_public_model(const std::string& model_name) { diff --git a/test/collection_test.cpp b/test/collection_test.cpp index af81fb61..2c909a96 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4892,6 +4892,21 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) { ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error()); } +TEST_F(CollectionTest, WrongTypeInEmbedFrom) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": [1122], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + ASSERT_FALSE(op.ok()); + ASSERT_EQ("Property `embed.from` must contain only field names as strings.", op.error()); +} TEST_F(CollectionTest, WrongTypeForEmbedding) { nlohmann::json schema = R"({ @@ -4991,7 +5006,7 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}} + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}} ] })"_json; @@ -5001,15 +5016,13 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) { } auto api_key = std::string(std::getenv("api_key")); - schema["fields"][1]["model_config"]["api_key"] = api_key; + schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key; TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); auto op = collectionManager.create_collection(schema); ASSERT_TRUE(op.ok()); auto summary = op.get()->get_summary_json(); - ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["model_config"]["model_name"]); + ASSERT_EQ("openai/text-embedding-ada-002", summary["fields"][1]["embed"]["model_config"]["model_name"]); ASSERT_EQ(1536, summary["fields"][1]["num_dim"]); - // make sure api_key is - ASSERT_EQ("", summary["fields"][1]["model_config"]["api_key"]); nlohmann::json doc; doc["name"] = "butter"; @@ -5019,7 +5032,68 @@ TEST_F(CollectionTest, DISABLED_CreateOpenAIEmbeddingField) { ASSERT_EQ(1536, add_op.get()["embedding"].size()); } -TEST_F(CollectionTest, MoreThganOneEmbeddingField) { +TEST_F(CollectionTest, DISABLED_HideOpenAIApiKey) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}} + ] + })"_json; + + if (std::getenv("api_key") == nullptr) { + LOG(INFO) << "Skipping test as api_key is not set."; + return; + } + + auto api_key = std::string(std::getenv("api_key")); + schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + auto summary = op.get()->get_summary_json(); + // hide api key with * after first 3 characters + ASSERT_EQ(summary["fields"][1]["embed"]["model_config"]["api_key"].get(), api_key.replace(3, api_key.size() - 3, api_key.size() - 3, '*')); +} + +TEST_F(CollectionTest, DISABLED_PrefixSearchDisabledForOpenAI) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "openai/text-embedding-ada-002"}}} + ] + })"_json; + + if (std::getenv("api_key") == nullptr) { + LOG(INFO) << "Skipping test as api_key is not set."; + return; + } + + auto api_key = std::string(std::getenv("api_key")); + schema["fields"][1]["embed"]["model_config"]["api_key"] = api_key; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + + nlohmann::json doc; + doc["name"] = "butter"; + + auto add_op = op.get()->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + + spp::sparse_hash_set dummy_include_exclude; + auto search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); + + ASSERT_FALSE(search_res_op.ok()); + ASSERT_EQ("Prefix search is not supported for OpenAI embedder.", search_res_op.error()); + + search_res_op = op.get()->search("dummy", {"embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {false}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); + ASSERT_FALSE(search_res_op.ok()); +} + + +TEST_F(CollectionTest, MoreThanOneEmbeddingField) { nlohmann::json schema = R"({ "name": "objects", "fields": [