From abff0e0cb67f453eaf4308cd2d7c1170dbc21b74 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Thu, 27 Jul 2023 13:37:04 +0300 Subject: [PATCH 1/4] Ignore null optional fields while generating embedding --- src/index.cpp | 5 ++++- src/validator.cpp | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index a8008a2d..dd1f5335 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6495,6 +6495,9 @@ void Index::batch_embed_fields(std::vector& records, auto embed_from = field.embed[fields::from].get>(); for(const auto& field_name : embed_from) { auto field_it = search_schema.find(field_name); + if(document->count(field_name) == 0) { + continue; + } if(field_it.value().type == field_types::STRING) { text += (*document)[field_name].get() + " "; } else if(field_it.value().type == field_types::STRING_ARRAY) { @@ -6511,7 +6514,7 @@ void Index::batch_embed_fields(std::vector& records, if(texts_to_embed.empty()) { continue; } - + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); auto embedder_op = embedder_manager.get_text_embedder(field.embed[fields::model_config]); diff --git a/src/validator.cpp b/src/validator.cpp index b8516130..58a443fe 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -677,7 +677,7 @@ Option validator_t::validate_embed_fields(const nlohmann::json& document, return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); } if(doc_field_it == document.end()) { - if(error_if_field_not_found) { + if(error_if_field_not_found && schema_field_it->optional == false) { return Option(400, "Field `" + field_name + "` is needed to create embedding."); } else { continue; From 94d7b54b8b339938e02dc218081fce28a28ec288 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 28 Jul 2023 14:38:27 +0300 Subject: [PATCH 2/4] Add support for optional embedding field --- src/validator.cpp | 10 ++++-- test/collection_test.cpp | 3 +- test/collection_vector_search_test.cpp | 48 ++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/validator.cpp b/src/validator.cpp index 58a443fe..5fe9bab3 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -669,7 +669,9 @@ Option validator_t::validate_embed_fields(const nlohmann::json& document, const tsl::htrie_map & search_schema, const bool& error_if_field_not_found) { for(const auto& field : embedding_fields) { - auto embed_from = field.embed[fields::from].get>(); + const auto& embed_from = field.embed[fields::from].get>(); + // flag to check if all fields to embed from are optional and null + bool all_optional_and_null = true; for(const auto& field_name : embed_from) { auto schema_field_it = search_schema.find(field_name); auto doc_field_it = document.find(field_name); @@ -677,12 +679,13 @@ Option validator_t::validate_embed_fields(const nlohmann::json& document, return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); } if(doc_field_it == document.end()) { - if(error_if_field_not_found && schema_field_it->optional == false) { + if(error_if_field_not_found && !schema_field_it->optional) { return Option(400, "Field `" + field_name + "` is needed to create embedding."); } else { continue; } } + all_optional_and_null = false; if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { return Option(400, "Field `" + field_name + "` has malformed data."); @@ -695,6 +698,9 @@ Option validator_t::validate_embed_fields(const nlohmann::json& document, } } } + if(all_optional_and_null && !field.optional) { + return Option(400, "No valid fields found to create embedding for `" + field.name + "`, please provide at least one valid field or make the embedding field optional."); + } } return Option(true); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index bb1faa88..8211f962 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4870,8 +4870,7 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) { doc["names"].push_back("butterball"); auto add_op = coll->add(doc.dump()); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error()); + ASSERT_TRUE(add_op.ok()); } TEST_F(CollectionTest, WrongTypeInEmbedFrom) { diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index b726a06a..7f1e7eef 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -987,4 +987,52 @@ TEST_F(CollectionVectorTest, HybridSearchSortByGeopoint) { ASSERT_EQ("butter", search_res["hits"][0]["document"]["name"].get()); ASSERT_EQ("butterball", search_res["hits"][1]["document"]["name"].get()); ASSERT_EQ("butterfly", search_res["hits"][2]["document"]["name"].get()); +} + + +TEST_F(CollectionVectorTest, EmbedFromOptionalNullField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "text", "type": "string", "optional": true}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["text"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto op = collectionManager.create_collection(schema); + + ASSERT_TRUE(op.ok()); + auto coll = op.get(); + + nlohmann::json doc = R"({ + })"_json; + + auto add_op = coll->add(doc.dump()); + + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("No valid fields found to create embedding for `embedding`, please provide at least one valid field or make the embedding field optional.", add_op.error()); + + doc["text"] = "butter"; + add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + // drop the embedding field and reindex + + nlohmann::json alter_schema = R"({ + "fields": [ + {"name": "embedding", "drop": true}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["text"], "model_config": {"model_name": "ts/e5-small"}}, "optional": true} + ] + })"_json; + + auto update_op = coll->alter(alter_schema); + ASSERT_TRUE(update_op.ok()); + + + doc = R"({ + })"_json; + add_op = coll->add(doc.dump()); + + ASSERT_TRUE(add_op.ok()); } \ No newline at end of file From 5aaaac29a6eecb8de004bedbe225a9084b5d6f85 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 28 Jul 2023 18:28:29 +0300 Subject: [PATCH 3/4] Fix document iterator comparison --- src/index.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index dd1f5335..22a34b0d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6492,10 +6492,11 @@ void Index::batch_embed_fields(std::vector& records, continue; } std::string text = indexing_prefix; - auto embed_from = field.embed[fields::from].get>(); + const auto& embed_from = field.embed[fields::from].get>(); for(const auto& field_name : embed_from) { auto field_it = search_schema.find(field_name); - if(document->count(field_name) == 0) { + auto doc_field_it = document->find(field_name); + if(doc_field_it == document->end()) { continue; } if(field_it.value().type == field_types::STRING) { From dce27b918ef418eb85ddde7f72d27e3085979219 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 28 Jul 2023 19:00:16 +0300 Subject: [PATCH 4/4] Use doc_field_it --- src/index.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 22a34b0d..7192faba 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6500,9 +6500,9 @@ void Index::batch_embed_fields(std::vector& records, continue; } if(field_it.value().type == field_types::STRING) { - text += (*document)[field_name].get() + " "; + text += doc_field_it->get() + " "; } else if(field_it.value().type == field_types::STRING_ARRAY) { - for(const auto& val : (*document)[field_name]) { + for(const auto& val : *(doc_field_it)) { text += val.get() + " "; } }