From 38eab6251bdcdc4df67f8500315f2a585a3dfbd3 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Tue, 11 Apr 2023 17:15:03 +0300 Subject: [PATCH 1/2] Temp fix for text embedding prefixes --- src/collection.cpp | 4 +++- src/index.cpp | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index a4e1a2d7..59b81ea3 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1186,7 +1186,9 @@ Option Collection::search(std::string raw_query, TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); auto embedder = embedder_manager.get_text_embedder(search_field.model_name.size() > 0 ? search_field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(raw_query); + std::string embed_query = "query: " + raw_query; + + std::vector embedding = embedder->Embed(embed_query); vector_query._reset(); vector_query.values = embedding; vector_query.field_name = field_name; diff --git a/src/index.cpp b/src/index.cpp index 97accb7a..6e99d0d0 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6337,7 +6337,7 @@ Option Index::embed_fields(nlohmann::json& document, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema) { for(const auto& field : embedding_fields) { - std::string text_to_embed; + std::string text_to_embed = "passage: "; for(const auto& field_name : field.embed_from) { auto field_it = search_schema.find(field_name); if(field_it.value().type == field_types::STRING) { From dede71bad99301d8327b519ce33c86f6499e32e4 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Tue, 11 Apr 2023 21:03:19 +0300 Subject: [PATCH 2/2] Added support to add text embedding fields with schema update --- src/collection.cpp | 4 ++-- test/collection_schema_change_test.cpp | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index 59b81ea3..d6aa3403 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -4131,7 +4131,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, } if(!f.embed_from.empty()) { - return Option(400, "Embedding fields can only be added at the time of collection creation."); + embedding_fields.emplace(f.name, f); } if(f.nested && enable_nested_fields) { @@ -4146,7 +4146,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); if(!prefix_kv.value().embed_from.empty()) { - return Option(400, "Embedding fields can only be added at the time of collection creation."); + embedding_fields.emplace(prefix_kv.key(), prefix_kv.value()); } if(is_reindex) { diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index 1a301446..88d8e940 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1455,10 +1455,13 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { ] })"_json; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); ASSERT_TRUE(op.ok()); Collection* coll = op.get(); + nlohmann::json update_schema = R"({ "fields": [ @@ -1468,8 +1471,17 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { auto res = coll->alter(update_schema); - ASSERT_FALSE(res.ok()); - ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); + ASSERT_TRUE(res.ok()); + ASSERT_EQ(1, coll->get_embedding_fields().size()); + + nlohmann::json doc; + doc["names"] = {"hello", "world"}; + auto add_op = coll->add(doc.dump()); + + ASSERT_TRUE(add_op.ok()); + auto added_doc = add_op.get(); + + ASSERT_EQ(384, added_doc["embedding"].get>().size()); } TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) {