From 7ae3cc9781ee6efc3809f9eb4435c19b661ad1d9 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Apr 2023 23:56:25 +0300 Subject: [PATCH] Review Changes II --- include/collection.h | 5 - include/field.h | 42 +++---- include/index.h | 12 ++ include/validator.h | 6 + src/collection.cpp | 147 ++++--------------------- src/collection_manager.cpp | 6 +- src/field.cpp | 10 +- src/index.cpp | 64 ++++++++++- src/validator.cpp | 51 +++++++++ test/collection_all_fields_test.cpp | 14 +-- test/collection_schema_change_test.cpp | 16 +-- test/collection_test.cpp | 61 +++++----- test/collection_vector_search_test.cpp | 2 +- 13 files changed, 228 insertions(+), 208 deletions(-) diff --git a/include/collection.h b/include/collection.h index e3440006..dce59d0b 100644 --- a/include/collection.h +++ b/include/collection.h @@ -210,8 +210,6 @@ private: bool is_wildcard_query, bool is_group_by_query = false) const; - Option validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const; - Option persist_collection_meta(); Option batch_alter_data(const std::vector& alter_fields, @@ -358,9 +356,6 @@ public: const DIRTY_VALUES dirty_values, const std::string& id=""); - Option embed_fields(nlohmann::json& document); - - Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc); static uint32_t get_seq_id_from_key(const std::string & key); diff --git a/include/field.h b/include/field.h index 5b44b52a..febb2a1e 100644 --- a/include/field.h +++ b/include/field.h @@ -50,7 +50,7 @@ namespace fields { static const std::string num_dim = "num_dim"; static const std::string vec_dist = "vec_dist"; static const std::string reference = "reference"; - static const std::string create_from = "create_from"; + static const std::string embed_from = "embed_from"; static const std::string model_name = "model_name"; } @@ -77,7 +77,7 @@ struct field { int nested_array; size_t num_dim; - std::vector create_from; + std::vector embed_from; std::string model_name; vector_distance_type_t vec_dist; @@ -89,9 +89,9 @@ struct field { field(const std::string &name, const std::string &type, const bool facet, const bool optional = false, bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false, - int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector &create_from = {}, const std::string& model_name = "") : + int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector &embed_from = {}, const std::string& model_name = "") : name(name), type(type), facet(facet), optional(optional), index(index), locale(locale), - nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), create_from(create_from), model_name(model_name) { + nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_name(model_name) { set_computed_defaults(sort, infix); } @@ -319,8 +319,8 @@ struct field { if (!field.reference.empty()) { field_val[fields::reference] = field.reference; } - if(!field.create_from.empty()) { - field_val[fields::create_from] = field.create_from; + if(!field.embed_from.empty()) { + field_val[fields::embed_from] = field.embed_from; if(!field.model_name.empty()) { field_val[fields::model_name] = field.model_name; } @@ -421,36 +421,36 @@ struct field { for(nlohmann::json & field_json: fields_json) { - if(field_json.count(fields::create_from) != 0) { + if(field_json.count(fields::embed_from) != 0) { if(TextEmbedderManager::model_dir.empty()) { return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); } - if(!field_json[fields::create_from].is_array()) { - return Option(400, "Property `" + fields::create_from + "` must be an array."); + if(!field_json[fields::embed_from].is_array()) { + return Option(400, "Property `" + fields::embed_from + "` must be an array."); } - if(field_json[fields::create_from].empty()) { - return Option(400, "Property `" + fields::create_from + "` must have at least one element."); + if(field_json[fields::embed_from].empty()) { + return Option(400, "Property `" + fields::embed_from + "` must have at least one element."); } - for(auto& create_from_field : field_json[fields::create_from]) { - if(!create_from_field.is_string()) { - return Option(400, "Property `" + fields::create_from + "` must contain only field names as strings."); + for(auto& embed_from_field : field_json[fields::embed_from]) { + if(!embed_from_field.is_string()) { + return Option(400, "Property `" + fields::embed_from + "` must contain only field names as strings."); } } if(field_json[fields::type] != field_types::FLOAT_ARRAY) { - return Option(400, "Property `" + fields::create_from + "` is only allowed on a float array field."); + return Option(400, "Property `" + fields::embed_from + "` is only allowed on a float array field."); } - for(auto& create_from_field : field_json[fields::create_from]) { + for(auto& embed_from_field : field_json[fields::embed_from]) { bool flag = false; for(const auto& field : fields_json) { - if(field[fields::name] == create_from_field) { + if(field[fields::name] == embed_from_field) { if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) { - return Option(400, "Property `" + fields::create_from + "` can only have string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only have string or string array fields."); } flag = true; break; @@ -458,9 +458,9 @@ struct field { } if(!flag) { for(const auto& field : the_fields) { - if(field.name == create_from_field) { + if(field.name == embed_from_field) { if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) { - return Option(400, "Property `" + fields::create_from + "` can only have used with string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only have used with string or string array fields."); } flag = true; break; @@ -468,7 +468,7 @@ struct field { } } if(!flag) { - return Option(400, "Property `" + fields::create_from + "` can only be used with string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only be used with string or string array fields."); } } } diff --git a/include/index.h b/include/index.h index 6c4d66fe..76ca3696 100644 --- a/include/index.h +++ b/include/index.h @@ -532,6 +532,16 @@ private: const std::string& token, uint32_t seq_id); void initialize_facet_indexes(const field& facet_field); + + + + static Option embed_fields(nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema); + + static Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema); public: // for limiting number of results on multiple candidates / query rewrites @@ -663,6 +673,7 @@ public: const size_t batch_start_index, const size_t batch_size, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, @@ -672,6 +683,7 @@ public: std::vector& iter_batch, const std::string& default_sorting_field, const tsl::htrie_map& search_schema, + const tsl::htrie_map & embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, diff --git a/include/validator.h b/include/validator.h index 638b5eb3..3998f5dc 100644 --- a/include/validator.h +++ b/include/validator.h @@ -27,6 +27,7 @@ public: static Option validate_index_in_memory(nlohmann::json &document, uint32_t seq_id, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const index_operation_t op, const bool is_update, const std::string& fallback_field_type, @@ -67,4 +68,9 @@ public: nlohmann::json::iterator& array_iter, bool is_array, bool& array_ele_erased); + static Option validate_embed_fields(const nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema, + const bool& error_if_field_not_found); + }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 231a3f9a..676628d2 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -52,7 +52,7 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co index(init_index()) { for (auto const& field: fields) { - if (!field.create_from.empty()) { + if (!field.embed_from.empty()) { embedding_fields.emplace(field.name, field); } } @@ -108,12 +108,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: uint32_t seq_id = get_next_seq_id(); document["id"] = std::to_string(seq_id); - // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc - auto embed_res = embed_fields(document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } - // Add reference helper fields in the document. for (auto const& pair: reference_fields) { @@ -191,14 +185,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: // UPSERT, EMPLACE or UPDATE uint32_t seq_id = (uint32_t) std::stoul(seq_id_str); - //Handle embedding here for UPDATE - nlohmann::json old_doc; - get_document_from_store(get_seq_id_key(seq_id), old_doc); - auto embed_res = embed_fields_update(old_doc, document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } - return Option(doc_seq_id_t{seq_id, false}); } else { @@ -209,11 +195,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: // for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc uint32_t seq_id = get_next_seq_id(); - // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc - auto embed_res = embed_fields(document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } return Option(doc_seq_id_t{seq_id, true}); } } @@ -259,8 +240,8 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::infix] = coll_field.infix; field_json[fields::locale] = coll_field.locale; - if(!coll_field.create_from.empty()) { - field_json[fields::create_from] = coll_field.create_from; + if(!coll_field.embed_from.empty()) { + field_json[fields::embed_from] = coll_field.embed_from; } if(coll_field.model_name.size() > 0) { @@ -397,6 +378,7 @@ nlohmann::json Collection::add_many(std::vector& json_lines, nlohma do_batched_index: + if((i+1) % index_batch_size == 0 || i == json_lines.size()-1 || repeated_doc) { batch_index(index_records, json_lines, num_indexed, return_doc, return_id); @@ -593,7 +575,7 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t std::unique_lock lock(mutex); Option validation_op = validator_t::validate_index_in_memory(document, seq_id, default_sorting_field, - search_schema, op, false, + search_schema, embedding_fields, op, false, fallback_field_type, dirty_values); if(!validation_op.ok()) { @@ -604,7 +586,7 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t std::vector index_batch; index_batch.emplace_back(std::move(rec)); - Index::batch_memory_index(index, index_batch, default_sorting_field, search_schema, + Index::batch_memory_index(index, index_batch, default_sorting_field, search_schema, embedding_fields, fallback_field_type, token_separators, symbols_to_index, true); num_documents += 1; @@ -614,7 +596,7 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t size_t Collection::batch_index_in_memory(std::vector& index_records) { std::unique_lock lock(mutex); size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field, - search_schema, fallback_field_type, + search_schema, embedding_fields, fallback_field_type, token_separators, symbols_to_index, true); num_documents += num_indexed; return num_indexed; @@ -1014,7 +996,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().create_from.empty(); + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().embed_from.empty(); if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -3735,7 +3717,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields } } - Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, + Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields, fallback_field_type, token_separators, symbols_to_index, true); iter_batch.clear(); @@ -3771,7 +3753,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_fields.erase(del_field.name); } - if(!del_field.create_from.empty()) { + if(!del_field.embed_from.empty()) { embedding_fields.erase(del_field.name); } @@ -4069,7 +4051,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return Option(400, "Field `" + field_name + "` is not part of collection schema."); } - if(found_field && !field_it.value().create_from.empty()) { + if(found_field && !field_it.value().embed_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4078,7 +4060,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(field_it.key()); updated_nested_fields.erase(field_it.key()); - if(!field_it.value().create_from.empty()) { + if(!field_it.value().embed_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4092,7 +4074,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(prefix_kv.key()); updated_nested_fields.erase(prefix_kv.key()); - if(!prefix_kv.value().create_from.empty()) { + if(!prefix_kv.value().embed_from.empty()) { updated_embedding_fields.erase(prefix_kv.key()); } } @@ -4145,7 +4127,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, addition_fields.push_back(f); } - if(!f.create_from.empty()) { + if(!f.embed_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4160,7 +4142,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value()); updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); - if(!prefix_kv.value().create_from.empty()) { + if(!prefix_kv.value().embed_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4234,6 +4216,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, // validate existing data on disk for compatibility via updated_search_schema auto validate_op = validator_t::validate_index_in_memory(document, seq_id, default_sorting_field, updated_search_schema, + updated_embedding_fields, index_operation_t::CREATE, false, fallback_field_type, @@ -4484,7 +4467,7 @@ Index* Collection::init_index() { nested_fields.emplace(field.name, field); } - if(!field.create_from.empty()) { + if(!field.embed_from.empty()) { embedding_fields.emplace(field.name, field); } @@ -4759,108 +4742,18 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha } -Option Collection::embed_fields(nlohmann::json& document) { - auto validate_res = validate_embed_fields(document, true); - if(!validate_res.ok()) { - return validate_res; - } - for(const auto& field : embedding_fields) { - std::string text_to_embed; - for(const auto& field_name : field.create_from) { - auto field_it = search_schema.find(field_name); - if(field_it.value().type == field_types::STRING) { - text_to_embed += document[field_name].get() + " "; - } else if(field_it.value().type == field_types::STRING_ARRAY) { - for(const auto& val : document[field_name]) { - text_to_embed += val.get() + " "; - } - } - } - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - document[field.name] = embedding; - } - - return Option(true); -} - -Option Collection::validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const { - if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } - for(const auto& field : embedding_fields) { - for(const auto& field_name : field.create_from) { - auto schema_field_it = search_schema.find(field_name); - auto doc_field_it = document.find(field_name); - if(schema_field_it == search_schema.end()) { - return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); - } - if(doc_field_it == document.end()) { - if(error_if_field_not_found) { - return Option(400, "Field `" + field_name + "` is needed to create embedding."); - } else { - continue; - } - } - if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || - (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - if(doc_field_it.value().is_array()) { - for(const auto& val : doc_field_it.value()) { - if(!val.is_string()) { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } - } - } - - return Option(true); -} - -Option Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) { - auto validate_res = validate_embed_fields(new_doc, false); - if(!validate_res.ok()) { - return validate_res; - } - nlohmann::json new_doc_copy = new_doc; - for(const auto& field : embedding_fields) { - std::string text_to_embed; - for(const auto& field_name : field.create_from) { - auto field_it = search_schema.find(field_name); - nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; - if(field_it.value().type == field_types::STRING) { - text_to_embed += value.get() + " "; - } else if(field_it.value().type == field_types::STRING_ARRAY) { - for(const auto& val : value) { - text_to_embed += val.get() + " "; - } - } - } - - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - new_doc_copy[field.name] = embedding; - } - new_doc = new_doc_copy; - return Option(true); -} - void Collection::process_remove_field_for_embedding_fields(const field& the_field) { std::vector::iterator> empty_fields; for(auto& embedding_field : embedding_fields) { const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) { return other_field.name == embedding_field.name; }); - actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) { + actual_field->embed_from.erase(std::remove_if(actual_field->embed_from.begin(), actual_field->embed_from.end(), [&the_field](std::string field_name) { return the_field.name == field_name; })); embedding_field = *actual_field; - // store to remove embedding field if it has no field names in 'create_from' anymore. - if(embedding_field.create_from.empty()) { + // store to remove embedding field if it has no field names in 'embed_from' anymore. + if(embedding_field.embed_from.empty()) { empty_fields.push_back(actual_field); } } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f42d356a..aadcbf31 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -58,8 +58,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection field_obj[fields::reference] = ""; } - if(field_obj.count(fields::create_from) == 0) { - field_obj[fields::create_from] = std::vector(); + if(field_obj.count(fields::embed_from) == 0) { + field_obj[fields::embed_from] = std::vector(); } if(field_obj.count(fields::model_name) == 0) { @@ -78,7 +78,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet], field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale], -1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array], - field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::create_from], + field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed_from], field_obj[fields::model_name]); // value of `sort` depends on field type diff --git a/src/field.cpp b/src/field.cpp index 3675d43c..ee808158 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -672,11 +672,11 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } } - if(field_json.count(fields::model_name) > 0 && field_json.count(fields::create_from) == 0) { - return Option(400, "Property `" + fields::model_name + "` can only be used with `" + fields::create_from + "`."); + if(field_json.count(fields::model_name) > 0 && field_json.count(fields::embed_from) == 0) { + return Option(400, "Property `" + fields::model_name + "` can only be used with `" + fields::embed_from + "`."); } - if(field_json.count(fields::create_from) != 0) { + if(field_json.count(fields::embed_from) != 0) { // If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model) field_json[fields::num_dim] = static_cast(384); if(field_json.count(fields::model_name) != 0) { @@ -695,7 +695,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } } } else { - field_json[fields::create_from] = std::vector(); + field_json[fields::embed_from] = std::vector(); } @@ -784,7 +784,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::optional], field_json[fields::index], field_json[fields::locale], field_json[fields::sort], field_json[fields::infix], field_json[fields::nested], field_json[fields::nested_array], field_json[fields::num_dim], vec_dist, - field_json[fields::reference], field_json[fields::create_from].get>(), + field_json[fields::reference], field_json[fields::embed_from].get>(), field_json[fields::model_name]) ); diff --git a/src/index.cpp b/src/index.cpp index 7a110436..0b59c823 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -411,6 +411,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite const size_t batch_start_index, const size_t batch_size, const std::string& default_sorting_field, const tsl::htrie_map& search_schema, + const tsl::htrie_map& embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, @@ -435,6 +436,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite Option validation_op = validator_t::validate_index_in_memory(index_rec.doc, index_rec.seq_id, default_sorting_field, search_schema, + embedding_fields, index_rec.operation, index_rec.is_update, fallback_field_type, @@ -451,6 +453,9 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite get_doc_changes(index_rec.operation, index_rec.doc, index_rec.old_doc, index_rec.new_doc, index_rec.del_doc); scrub_reindex_doc(search_schema, index_rec.doc, index_rec.del_doc, index_rec.old_doc); + embed_fields(index_rec.new_doc, embedding_fields, search_schema); + } else { + embed_fields(index_rec.doc, embedding_fields, search_schema); } compute_token_offsets_facets(index_rec, search_schema, token_separators, symbols_to_index); @@ -485,6 +490,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite size_t Index::batch_memory_index(Index *index, std::vector& iter_batch, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, @@ -518,7 +524,7 @@ size_t Index::batch_memory_index(Index *index, std::vector& iter_b index->thread_pool->enqueue([&, batch_index, batch_len]() { write_log_index = local_write_log_index; validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema, - fallback_field_type, token_separators, symbols_to_index, do_validation); + embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation); std::unique_lock lock(m_process); num_processed++; @@ -6257,6 +6263,61 @@ bool Index::common_results_exist(std::vector& leaves, bool must_match return phrase_exists; } +Option Index::embed_fields(nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema) { + for(const auto& field : embedding_fields) { + std::string text_to_embed; + for(const auto& field_name : field.embed_from) { + auto field_it = search_schema.find(field_name); + if(field_it.value().type == field_types::STRING) { + text_to_embed += document[field_name].get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : document[field_name]) { + text_to_embed += val.get() + " "; + } + } + } + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + document[field.name] = embedding; + } + + return Option(true); +} + + +Option Index::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema) { + nlohmann::json new_doc_copy = new_doc; + for(const auto& field : embedding_fields) { + std::string text_to_embed; + for(const auto& field_name : field.embed_from) { + auto field_it = search_schema.find(field_name); + nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; + if(field_it.value().type == field_types::STRING) { + text_to_embed += value.get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : value) { + text_to_embed += val.get() + " "; + } + } + } + + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + new_doc_copy[field.name] = embedding; + } + new_doc = new_doc_copy; + return Option(true); +} + + + + /* // https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon // NOTE: polygon and point should have been transformed with `transform_for_180th_meridian` @@ -6302,3 +6363,4 @@ void Index::transform_for_180th_meridian(GeoCoord &point, double offset) { point.lon = point.lon < 0.0 ? point.lon + offset : point.lon; } */ + diff --git a/src/validator.cpp b/src/validator.cpp index 9616b84b..f66eaf9d 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -529,6 +529,7 @@ Option validator_t::coerce_float(const DIRTY_VALUES& dirty_values, con Option validator_t::validate_index_in_memory(nlohmann::json& document, uint32_t seq_id, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const index_operation_t op, const bool is_update, const std::string& fallback_field_type, @@ -544,6 +545,11 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, for(const auto& a_field: search_schema) { const std::string& field_name = a_field.name; + // ignore embedding fields, they will be validated later + if(embedding_fields.count(field_name) > 0) { + continue; + } + if(field_name == "id" || a_field.is_object()) { continue; } @@ -574,5 +580,50 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, } } + // validate embedding fields + auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update); + if(!validate_embed_op.ok()) { + return Option<>(validate_embed_op.code(), validate_embed_op.error()); + } + return Option<>(200); +} + + +Option validator_t::validate_embed_fields(const nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema, + const bool& error_if_field_not_found) { + if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); + } + for(const auto& field : embedding_fields) { + for(const auto& field_name : field.embed_from) { + auto schema_field_it = search_schema.find(field_name); + auto doc_field_it = document.find(field_name); + if(schema_field_it == search_schema.end()) { + return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); + } + if(doc_field_it == document.end()) { + if(error_if_field_not_found) { + return Option(400, "Field `" + field_name + "` is needed to create embedding."); + } else { + continue; + } + } + if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || + (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + if(doc_field_it.value().is_array()) { + for(const auto& val : doc_field_it.value()) { + if(!val.is_string()) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } + } + } + + return Option(true); } \ No newline at end of file diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 1a8af133..9269c858 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -1597,7 +1597,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { nlohmann::json field_json; field_json["name"] = "embedding"; field_json["type"] = "float[]"; - field_json["create_from"] = {"name"}; + field_json["embed_from"] = {"name"}; std::vector fields; std::string fallback_field_type; @@ -1607,7 +1607,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `create_from` can only be used with string or string array fields.", field_op.error()); + ASSERT_EQ("Property `embed_from` can only be used with string or string array fields.", field_op.error()); } TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { @@ -1615,7 +1615,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { nlohmann::json field_json; field_json["name"] = "embedding"; field_json["type"] = "float[]"; - field_json["create_from"] = {"name"}; + field_json["embed_from"] = {"name"}; std::vector fields; std::string fallback_field_type; @@ -1633,7 +1633,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) { nlohmann::json field_json; field_json["name"] = "embedding"; field_json["type"] = "float[]"; - field_json["create_from"] = "name"; + field_json["embed_from"] = "name"; std::vector fields; std::string fallback_field_type; @@ -1643,7 +1643,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `create_from` must be an array.", field_op.error()); + ASSERT_EQ("Property `embed_from` must be an array.", field_op.error()); } TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) { @@ -1660,7 +1660,7 @@ TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `model_name` can only be used with `create_from`.", field_op.error()); + ASSERT_EQ("Property `model_name` can only be used with `embed_from`.", field_op.error()); } @@ -1670,7 +1670,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) { TextEmbedderManager::download_default_model(); field embedding = field("embedding", field_types::FLOAT_ARRAY, false); - embedding.create_from.push_back("name"); + embedding.embed_from.push_back("name"); std::vector fields = {field("name", field_types::STRING, false), embedding}; auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO); diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index ec6a88e7..82ccd65a 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1462,7 +1462,7 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { nlohmann::json update_schema = R"({ "fields": [ - {"name": "embedding", "type":"float[]", "create_from": ["names"]} + {"name": "embedding", "type":"float[]", "embed_from": ["names"]} ] })"_json; @@ -1478,7 +1478,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { "fields": [ {"name": "names", "type": "string[]"}, {"name": "category", "type":"string"}, - {"name": "embedding", "type":"float[]", "create_from": ["names","category"]} + {"name": "embedding", "type":"float[]", "embed_from": ["names","category"]} ] })"_json; @@ -1500,7 +1500,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { LOG(INFO) << "Dropping field"; auto embedding_fields = coll->get_embedding_fields(); - ASSERT_EQ(2, embedding_fields["embedding"].create_from.size()); + ASSERT_EQ(2, embedding_fields["embedding"].embed_from.size()); LOG(INFO) << "Before alter"; @@ -1510,8 +1510,8 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { LOG(INFO) << "After alter"; embedding_fields = coll->get_embedding_fields(); - ASSERT_EQ(1, embedding_fields["embedding"].create_from.size()); - ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]); + ASSERT_EQ(1, embedding_fields["embedding"].embed_from.size()); + ASSERT_EQ("category", embedding_fields["embedding"].embed_from[0]); } TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { @@ -1519,7 +1519,7 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -1535,8 +1535,8 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { auto embedding_field_it = embedding_fields_map.find("embedding"); ASSERT_TRUE(embedding_field_it != embedding_fields_map.end()); ASSERT_EQ("embedding", embedding_field_it.value().name); - ASSERT_EQ(1, embedding_field_it.value().create_from.size()); - ASSERT_EQ("name", embedding_field_it.value().create_from[0]); + ASSERT_EQ(1, embedding_field_it.value().embed_from.size()); + ASSERT_EQ("name", embedding_field_it.value().embed_from[0]); // drop the embedding field nlohmann::json schema_without_embedding = R"({ diff --git a/test/collection_test.cpp b/test/collection_test.cpp index b558c1d6..2a15e531 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4616,7 +4616,7 @@ TEST_F(CollectionTest, SemanticSearchTest) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4651,7 +4651,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4682,7 +4682,7 @@ TEST_F(CollectionTest, HybridSearch) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4695,6 +4695,7 @@ TEST_F(CollectionTest, HybridSearch) { nlohmann::json object; object["name"] = "apple"; auto add_op = coll->add(object.dump()); + LOG(INFO) << "add_op.error(): " << add_op.error(); ASSERT_TRUE(add_op.ok()); ASSERT_EQ("apple", add_op.get()["name"]); @@ -4710,40 +4711,40 @@ TEST_F(CollectionTest, HybridSearch) { ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size()); } -TEST_F(CollectionTest, EmbedFielsTest) { - nlohmann::json schema = R"({ - "name": "objects", - "fields": [ - {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} - ] - })"_json; +// TEST_F(CollectionTest, EmbedFielsTest) { +// nlohmann::json schema = R"({ +// "name": "objects", +// "fields": [ +// {"name": "name", "type": "string"}, +// {"name": "embedding", "type":"float[]", "embed_from": ["name"]} +// ] +// })"_json; - TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); - TextEmbedderManager::download_default_model(); +// TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); +// TextEmbedderManager::download_default_model(); - auto op = collectionManager.create_collection(schema); - ASSERT_TRUE(op.ok()); - Collection* coll = op.get(); +// auto op = collectionManager.create_collection(schema); +// ASSERT_TRUE(op.ok()); +// Collection* coll = op.get(); - nlohmann::json object = R"({ - "name": "apple" - })"_json; +// nlohmann::json object = R"({ +// "name": "apple" +// })"_json; - auto embed_op = coll->embed_fields(object); +// auto embed_op = coll->embed_fields(object); - ASSERT_TRUE(embed_op.ok()); +// ASSERT_TRUE(embed_op.ok()); - ASSERT_EQ("apple", object["name"]); - ASSERT_EQ(384, object["embedding"].get>().size()); -} +// ASSERT_EQ("apple", object["name"]); +// ASSERT_EQ(384, object["embedding"].get>().size()); +// } TEST_F(CollectionTest, HybridSearchRankFusionTest) { nlohmann::json schema = R"({ "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4817,7 +4818,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4851,7 +4852,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) { "name": "objects", "fields": [ {"name": "names", "type": "string[]"}, - {"name": "embedding", "type":"float[]", "create_from": ["names"]} + {"name": "embedding", "type":"float[]", "embed_from": ["names"]} ] })"_json; @@ -4876,8 +4877,8 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) { "name": "objects", "fields": [ {"name": "names", "type": "string[]"}, - {"name": "category", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["names", "category"]} + {"name": "category", "type": "string", "optional": true}, + {"name": "embedding", "type":"float[]", "embed_from": ["names", "category"]} ] })"_json; @@ -4903,7 +4904,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 7f8b212b..e1dc5d10 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -683,7 +683,7 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { "name": "coll1", "fields": [ {"name": "name", "type": "string"}, - {"name": "vec", "type": "float[]", "create_from": ["name"]} + {"name": "vec", "type": "float[]", "embed_from": ["name"]} ] })"_json;