diff --git a/include/validator.h b/include/validator.h index b4952cf6..d3898dba 100644 --- a/include/validator.h +++ b/include/validator.h @@ -71,6 +71,6 @@ public: static Option validate_embed_fields(const nlohmann::json& document, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema, - const bool& error_if_field_not_found); + const bool& is_update); }; \ No newline at end of file diff --git a/src/index.cpp b/src/index.cpp index 69df9301..a87f245c 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -504,6 +504,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite index_rec.index_failure(400, e.what()); } } + if(generate_embeddings) { batch_embed_fields(records_to_embed, embedding_fields, search_schema, remote_embedding_batch_size); } @@ -6499,6 +6500,12 @@ void Index::batch_embed_fields(std::vector& records, if(document == nullptr) { continue; } + + if(document->contains(field.name) && !record->is_update) { + // embedding already exists (could be a restore from export) + continue; + } + std::string text = indexing_prefix; const auto& embed_from = field.embed[fields::from].get>(); for(const auto& field_name : embed_from) { diff --git a/src/validator.cpp b/src/validator.cpp index 5fe9bab3..51a7d19c 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -654,7 +654,7 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, if(validate_embedding_fields) { // validate embedding fields - auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update); + auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, is_update); if(!validate_embed_op.ok()) { return Option<>(validate_embed_op.code(), validate_embed_op.error()); } @@ -667,8 +667,26 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, Option validator_t::validate_embed_fields(const nlohmann::json& document, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema, - const bool& error_if_field_not_found) { + const bool& is_update) { for(const auto& field : embedding_fields) { + if(document.contains(field.name) && !is_update) { + const auto& field_vec = document[field.name]; + if(!field_vec.is_array() || field_vec.empty() || !field_vec[0].is_number() || + field_vec.size() != field.num_dim) { + return Option(400, "Field `" + field.name + "` contains an invalid embedding."); + } + + auto it = field_vec.begin(); + while(it != field_vec.end()) { + if(!it.value().is_number()) { + return Option(400, "Field `" + field.name + "` contains invalid float values."); + } + it++; + } + + continue; + } + const auto& embed_from = field.embed[fields::from].get>(); // flag to check if all fields to embed from are optional and null bool all_optional_and_null = true; @@ -679,7 +697,7 @@ Option validator_t::validate_embed_fields(const nlohmann::json& document, return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); } if(doc_field_it == document.end()) { - if(error_if_field_not_found && !schema_field_it->optional) { + if(!is_update && !schema_field_it->optional) { return Option(400, "Field `" + field_name + "` is needed to create embedding."); } else { continue; diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index d28d9c03..3ffe288e 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -1031,4 +1031,74 @@ TEST_F(CollectionVectorTest, EmbedFromOptionalNullField) { add_op = coll->add(doc.dump()); ASSERT_TRUE(add_op.ok()); -} \ No newline at end of file +} + +TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + nlohmann::json model_config = R"({ + "model_name": "ts/e5-small" + })"_json; + + // will be roughly 0.1110895648598671,-0.11710234731435776,-0.5319093465805054, ... + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + // document with explicit embedding vector + nlohmann::json doc; + doc["name"] = "FOO"; + + std::vector vec; + for(size_t i = 0; i < 384; i++) { + vec.push_back(0.345); + } + + doc["embedding"] = vec; + + auto add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + + // get the vector back + auto res = coll->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, + Index::DROP_TOKENS_THRESHOLD).get(); + + // let's check the first few vectors + auto stored_vec = res["hits"][0]["document"]["embedding"]; + ASSERT_NEAR(0.345, stored_vec[0], 0.01); + ASSERT_NEAR(0.345, stored_vec[1], 0.01); + ASSERT_NEAR(0.345, stored_vec[2], 0.01); + ASSERT_NEAR(0.345, stored_vec[3], 0.01); + ASSERT_NEAR(0.345, stored_vec[4], 0.01); + + // what happens when vector contains invalid value, like string + doc["embedding"] = "foo"; //{0.11, 0.11}; + add_op = coll->add(doc.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `embedding` contains an invalid embedding.", add_op.error()); + + // when dims don't match + doc["embedding"] = {0.11, 0.11}; + add_op = coll->add(doc.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `embedding` contains an invalid embedding.", add_op.error()); + + // invalid array value + doc["embedding"].clear(); + for(size_t i = 0; i < 384; i++) { + doc["embedding"].push_back(0.01); + } + doc["embedding"][5] = "foo"; + add_op = coll->add(doc.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `embedding` contains invalid float values.", add_op.error()); +}