Use embedding in auto embedding field if already present.

This commit is contained in:
Kishore Nallan 2023-08-04 14:43:26 +05:30
parent bfb81173b9
commit 3ba9b48eff
4 changed files with 100 additions and 5 deletions

View File

@ -71,6 +71,6 @@ public:
static Option<bool> validate_embed_fields(const nlohmann::json& document,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema,
const bool& error_if_field_not_found);
const bool& is_update);
};

View File

@ -504,6 +504,7 @@ void Index::validate_and_preprocess(Index *index, std::vector<index_record>& ite
index_rec.index_failure(400, e.what());
}
}
if(generate_embeddings) {
batch_embed_fields(records_to_embed, embedding_fields, search_schema, remote_embedding_batch_size);
}
@ -6499,6 +6500,12 @@ void Index::batch_embed_fields(std::vector<index_record*>& records,
if(document == nullptr) {
continue;
}
if(document->contains(field.name) && !record->is_update) {
// embedding already exists (could be a restore from export)
continue;
}
std::string text = indexing_prefix;
const auto& embed_from = field.embed[fields::from].get<std::vector<std::string>>();
for(const auto& field_name : embed_from) {

View File

@ -654,7 +654,7 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
if(validate_embedding_fields) {
// validate embedding fields
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update);
auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, is_update);
if(!validate_embed_op.ok()) {
return Option<>(validate_embed_op.code(), validate_embed_op.error());
}
@ -667,8 +667,26 @@ Option<uint32_t> validator_t::validate_index_in_memory(nlohmann::json& document,
Option<bool> validator_t::validate_embed_fields(const nlohmann::json& document,
const tsl::htrie_map<char, field>& embedding_fields,
const tsl::htrie_map<char, field> & search_schema,
const bool& error_if_field_not_found) {
const bool& is_update) {
for(const auto& field : embedding_fields) {
if(document.contains(field.name) && !is_update) {
const auto& field_vec = document[field.name];
if(!field_vec.is_array() || field_vec.empty() || !field_vec[0].is_number() ||
field_vec.size() != field.num_dim) {
return Option<bool>(400, "Field `" + field.name + "` contains an invalid embedding.");
}
auto it = field_vec.begin();
while(it != field_vec.end()) {
if(!it.value().is_number()) {
return Option<bool>(400, "Field `" + field.name + "` contains invalid float values.");
}
it++;
}
continue;
}
const auto& embed_from = field.embed[fields::from].get<std::vector<std::string>>();
// flag to check if all fields to embed from are optional and null
bool all_optional_and_null = true;
@ -679,7 +697,7 @@ Option<bool> validator_t::validate_embed_fields(const nlohmann::json& document,
return Option<bool>(400, "Field `" + field.name + "` has invalid fields to create embeddings from.");
}
if(doc_field_it == document.end()) {
if(error_if_field_not_found && !schema_field_it->optional) {
if(!is_update && !schema_field_it->optional) {
return Option<bool>(400, "Field `" + field_name + "` is needed to create embedding.");
} else {
continue;

View File

@ -1031,4 +1031,74 @@ TEST_F(CollectionVectorTest, EmbedFromOptionalNullField) {
add_op = coll->add(doc.dump());
ASSERT_TRUE(add_op.ok());
}
}
TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) {
nlohmann::json schema = R"({
"name": "objects",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");
nlohmann::json model_config = R"({
"model_name": "ts/e5-small"
})"_json;
// will be roughly 0.1110895648598671,-0.11710234731435776,-0.5319093465805054, ...
auto op = collectionManager.create_collection(schema);
ASSERT_TRUE(op.ok());
Collection* coll = op.get();
// document with explicit embedding vector
nlohmann::json doc;
doc["name"] = "FOO";
std::vector<float> vec;
for(size_t i = 0; i < 384; i++) {
vec.push_back(0.345);
}
doc["embedding"] = vec;
auto add_op = coll->add(doc.dump());
ASSERT_TRUE(add_op.ok());
// get the vector back
auto res = coll->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true},
Index::DROP_TOKENS_THRESHOLD).get();
// let's check the first few vectors
auto stored_vec = res["hits"][0]["document"]["embedding"];
ASSERT_NEAR(0.345, stored_vec[0], 0.01);
ASSERT_NEAR(0.345, stored_vec[1], 0.01);
ASSERT_NEAR(0.345, stored_vec[2], 0.01);
ASSERT_NEAR(0.345, stored_vec[3], 0.01);
ASSERT_NEAR(0.345, stored_vec[4], 0.01);
// what happens when vector contains invalid value, like string
doc["embedding"] = "foo"; //{0.11, 0.11};
add_op = coll->add(doc.dump());
ASSERT_FALSE(add_op.ok());
ASSERT_EQ("Field `embedding` contains an invalid embedding.", add_op.error());
// when dims don't match
doc["embedding"] = {0.11, 0.11};
add_op = coll->add(doc.dump());
ASSERT_FALSE(add_op.ok());
ASSERT_EQ("Field `embedding` contains an invalid embedding.", add_op.error());
// invalid array value
doc["embedding"].clear();
for(size_t i = 0; i < 384; i++) {
doc["embedding"].push_back(0.01);
}
doc["embedding"][5] = "foo";
add_op = coll->add(doc.dump());
ASSERT_FALSE(add_op.ok());
ASSERT_EQ("Field `embedding` contains invalid float values.", add_op.error());
}