From e85ae5d7d2706caee47b2efa7b185d5289a7d20f Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sun, 2 Apr 2023 01:32:49 +0300 Subject: [PATCH 1/8] Hybrid & sematic search improvements --- include/collection.h | 4 + include/field.h | 12 +-- include/text_embedder_manager.h | 6 +- include/vector_query_ops.h | 2 + src/collection.cpp | 104 +++++++++++++++++------ src/index.cpp | 11 ++- src/typesense_server_utils.cpp | 2 +- src/vector_query_ops.cpp | 9 ++ test/collection_all_fields_test.cpp | 2 +- test/collection_test.cpp | 110 +++++++++++++++++++++++-- test/collection_vector_search_test.cpp | 59 +++++++++++-- 11 files changed, 269 insertions(+), 52 deletions(-) diff --git a/include/collection.h b/include/collection.h index 3b50c435..c81c4123 100644 --- a/include/collection.h +++ b/include/collection.h @@ -120,6 +120,8 @@ private: tsl::htrie_map nested_fields; + tsl::htrie_map embedding_fields; + bool enable_nested_fields; std::vector symbols_to_index; @@ -342,6 +344,8 @@ public: tsl::htrie_map get_nested_fields(); + tsl::htrie_map get_embedding_fields(); + std::string get_default_sorting_field(); Option to_doc(const std::string& json_str, nlohmann::json& document, diff --git a/include/field.h b/include/field.h index 792aaec6..5b44b52a 100644 --- a/include/field.h +++ b/include/field.h @@ -436,7 +436,7 @@ struct field { for(auto& create_from_field : field_json[fields::create_from]) { if(!create_from_field.is_string()) { - return Option(400, "Property `" + fields::create_from + "` must be an array of strings."); + return Option(400, "Property `" + fields::create_from + "` must contain only field names as strings."); } } @@ -449,8 +449,8 @@ struct field { bool flag = false; for(const auto& field : fields_json) { if(field[fields::name] == create_from_field) { - if(field[fields::type] != field_types::STRING) { - return Option(400, "Property `" + fields::create_from + "` can only be used with array of string fields."); + if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) { + return Option(400, "Property `" + fields::create_from + "` can only have string or string array fields."); } flag = true; break; @@ -459,8 +459,8 @@ struct field { if(!flag) { for(const auto& field : the_fields) { if(field.name == create_from_field) { - if(field.type != field_types::STRING) { - return Option(400, "Property `" + fields::create_from + "` can only be used with array of string fields."); + if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) { + return Option(400, "Property `" + fields::create_from + "` can only have used with string or string array fields."); } flag = true; break; @@ -468,7 +468,7 @@ struct field { } } if(!flag) { - return Option(400, "Property `" + fields::create_from + "` can only be used with array of string fields."); + return Option(400, "Property `" + fields::create_from + "` can only be used with string or string array fields."); } } } diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index 3b626a97..192044c8 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -1,6 +1,6 @@ #pragma once - +#include #include #include #include @@ -43,6 +43,10 @@ public: } static void set_model_dir(const std::string& dir) { + // create the directory if it doesn't exist + if(!std::filesystem::exists(dir)) { + std::filesystem::create_directories(dir); + } model_dir = dir; } diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index d3424e30..32ee7448 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -10,6 +10,7 @@ struct vector_query_t { std::string field_name; size_t k = 0; size_t flat_search_cutoff = 0; + float similarity_cutoff = 0.0; std::vector values; uint32_t seq_id = 0; @@ -19,6 +20,7 @@ struct vector_query_t { // used for testing only field_name.clear(); k = 0; + similarity_cutoff = 0.0; values.clear(); seq_id = 0; query_doc_given = false; diff --git a/src/collection.cpp b/src/collection.cpp index b5310d46..bea45361 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1120,10 +1120,6 @@ Option Collection::search(std::string raw_query, vector_query_t vector_query; if(!vector_query_str.empty()) { - if(raw_query != "*") { - return Option(400, "Vector query is supported only on wildcard (q=*) searches."); - } - auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this); if(!parse_vector_op.ok()) { return Option(400, parse_vector_op.error()); @@ -3459,6 +3455,11 @@ tsl::htrie_map Collection::get_nested_fields() { return nested_fields; }; +tsl::htrie_map Collection::get_embedding_fields() { + std::shared_lock lock(mutex); + return embedding_fields; +}; + std::string Collection::get_meta_key(const std::string & collection_name) { return std::string(COLLECTION_META_PREFIX) + "_" + collection_name; } @@ -3747,6 +3748,10 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_fields.erase(del_field.name); } + if(del_field.create_from.size() > 0) { + embedding_fields.erase(del_field.name); + } + if(del_field.name == ".*") { fallback_field_type = ""; } @@ -3982,6 +3987,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, std::vector diff_fields; tsl::htrie_map updated_search_schema = search_schema; tsl::htrie_map updated_nested_fields = nested_fields; + tsl::htrie_map updated_embedding_fields = embedding_fields; size_t num_auto_detect_fields = 0; // since fields can be deleted and added in the same change set, @@ -4038,10 +4044,18 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return Option(400, "Field `" + field_name + "` is not part of collection schema."); } + if(found_field && field_it.value().create_from.size() > 0) { + updated_embedding_fields.erase(field_it.key()); + } + if(found_field) { del_fields.push_back(field_it.value()); updated_search_schema.erase(field_it.key()); updated_nested_fields.erase(field_it.key()); + + if(field_it.value().create_from.size() > 0) { + updated_embedding_fields.erase(field_it.key()); + } // should also remove children if the field being dropped is an object if(field_it.value().nested && enable_nested_fields) { @@ -4052,6 +4066,10 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, del_fields.push_back(prefix_kv.value()); updated_search_schema.erase(prefix_kv.key()); updated_nested_fields.erase(prefix_kv.key()); + + if(prefix_kv.value().create_from.size() > 0) { + updated_embedding_fields.erase(prefix_kv.key()); + } } } } @@ -4102,6 +4120,10 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, addition_fields.push_back(f); } + if(f.create_from.size() > 0) { + return Option(400, "Embedding fields can only be added at the time of collection creation."); + } + if(f.nested && enable_nested_fields) { updated_nested_fields.emplace(f.name, f); @@ -4113,6 +4135,10 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value()); updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); + if(prefix_kv.value().create_from.size() > 0) { + return Option(400, "Embedding fields can only be added at the time of collection creation."); + } + if(is_reindex) { reindex_fields.push_back(prefix_kv.value()); } else { @@ -4432,6 +4458,10 @@ Index* Collection::init_index() { nested_fields.emplace(field.name, field); } + if(field.create_from.size() > 0) { + embedding_fields.emplace(field.name, field); + } + if(!field.reference.empty()) { auto dot_index = field.reference.find('.'); auto collection_name = field.reference.substr(0, dot_index); @@ -4704,30 +4734,50 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha Option Collection::embed_fields(nlohmann::json& document) { - for(const auto& field : fields) { - if(field.create_from.size() > 0) { - if(TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } - std::string text_to_embed; - for(const auto& field_name : field.create_from) { - auto field_it = document.find(field_name); - if(field_it != document.end()) { - if(field_it->is_string()) { - text_to_embed += field_it->get() + " "; - } else { - return Option(400, "Field `" + field_name + "` is not a string."); - } - } else { - return Option(400, "Field `" + field_name + "` not found in document."); - } - } - - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - document[field.name] = embedding; + for(const auto& field : embedding_fields) { + if(TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); } + std::string text_to_embed; + for(const auto& field_name : field.create_from) { + auto field_it = search_schema.find(field_name); + if(field_it != search_schema.end()) { + if(field_it.value().type == field_types::STRING) { + if(document.find(field_name) != document.end()) { + if(document[field_name].is_string()) { + text_to_embed += document[field_name].get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } else if(field_it.value().type == field_types::STRING_ARRAY) { + if(document.find(field_name) != document.end()) { + if(document[field_name].is_array()) { + for(const auto& val : document[field_name]) { + if(val.is_string()) { + text_to_embed += val.get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } + else { + return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); + } + } else { + return Option(400, "Field `" + field_name + "` is not a valid field."); + } + } + + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + document[field.name] = embedding; } + return Option(true); } diff --git a/src/index.cpp b/src/index.cpp index 5ea1f9c3..4859a451 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2867,6 +2867,10 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; + + if(vector_query.similarity_cutoff > 0 && vec_dist_score > vector_query.similarity_cutoff) { + continue; + } int64_t scores[3] = {0}; scores[0] = -float_to_int64_t(vec_dist_score); @@ -3087,9 +3091,14 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - + if(vector_query.similarity_cutoff > 0) { + if(vec_dist_score > vector_query.similarity_cutoff) { + continue; + } + } vec_results.emplace_back(seq_id, vec_dist_score); } + std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) { return a.second < b.second; }); diff --git a/src/typesense_server_utils.cpp b/src/typesense_server_utils.cpp index cd9b4952..ed381e8b 100644 --- a/src/typesense_server_utils.cpp +++ b/src/typesense_server_utils.cpp @@ -456,7 +456,7 @@ int run_server(const Config & config, const std::string & version, void (*master if(config.get_model_dir().size() > 0) { LOG(INFO) << "Loading text embedding models from " << config.get_model_dir(); - TextEmbedderManager::model_dir = config.get_model_dir(); + TextEmbedderManager::set_model_dir(config.get_model_dir()); TextEmbedderManager::download_default_model(); } // first we start the peering service diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index b4cd3ffa..729157b0 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -145,6 +145,15 @@ Option VectorQueryOps::parse_vector_query_str(std::string vector_query_str vector_query.flat_search_cutoff = std::stoi(param_kv[1]); } + + if(param_kv[0] == "similarity_cutoff") { + if(!StringUtils::is_float(param_kv[1])) { + return Option(400, "Malformed vector query string: " + "`similarity_cutoff` parameter must be a float."); + } + + vector_query.similarity_cutoff = std::stof(param_kv[1]); + } } if(!vector_query.query_doc_given && vector_query.values.empty()) { diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 9234d6bd..1a8af133 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -1607,7 +1607,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `create_from` can only be used with array of string fields.", field_op.error()); + ASSERT_EQ("Property `create_from` can only be used with string or string array fields.", field_op.error()); } TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { diff --git a/test/collection_test.cpp b/test/collection_test.cpp index f9e842ab..7bc1a768 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4815,7 +4815,39 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) { } TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { - nlohmann::json schema = R"({ + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "create_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + spp::sparse_hash_set dummy_include_exclude; + auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); + + ASSERT_FALSE(search_res_op.ok()); + ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error()); +} + +TEST_F(CollectionTest, CreateModelDirIfNotExists) { + system("mkdir -p /tmp/typesense_test/models"); + system("rm -rf /tmp/typesense_test/models"); + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + // check if model dir is created + ASSERT_TRUE(std::filesystem::exists("/tmp/typesense_test/models")); +} + +TEST_F(CollectionTest, EmbeddingFieldsMapTest) { + nlohmann::json schema = R"({ "name": "objects", "fields": [ {"name": "name", "type": "string"}, @@ -4830,9 +4862,75 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { ASSERT_TRUE(op.ok()); Collection* coll = op.get(); - spp::sparse_hash_set dummy_include_exclude; - auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); + auto embedding_fields_map = coll->get_embedding_fields(); + ASSERT_EQ(1, embedding_fields_map.size()); + auto embedding_field_it = embedding_fields_map.find("embedding"); + ASSERT_TRUE(embedding_field_it != embedding_fields_map.end()); + ASSERT_EQ("embedding", embedding_field_it.value().name); + ASSERT_EQ(1, embedding_field_it.value().create_from.size()); + ASSERT_EQ("name", embedding_field_it.value().create_from[0]); + + // drop the embedding field + nlohmann::json schema_without_embedding = R"({ + "fields": [ + {"name": "embedding", "drop": true} + ] + })"_json; + auto update_op = coll->alter(schema_without_embedding); + + ASSERT_TRUE(update_op.ok()); + + embedding_fields_map = coll->get_embedding_fields(); + ASSERT_EQ(0, embedding_fields_map.size()); +} + +TEST_F(CollectionTest, EmbedStringArrayField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"}, + {"name": "embedding", "type":"float[]", "create_from": ["names"]} + ] + })"_json; + + TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json doc; + doc["names"].push_back("butter"); + doc["names"].push_back("butterfly"); + doc["names"].push_back("butterball"); + + auto add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); +} + +TEST_F(CollectionTest, UpdateSchemaWithNewEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"} + ] + })"_json; + + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json update_schema = R"({ + "fields": [ + {"name": "embedding", "type":"float[]", "create_from": ["names"]} + ] + })"_json; + + auto res = coll->alter(update_schema); + + ASSERT_FALSE(res.ok()); + ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); +} - ASSERT_FALSE(search_res_op.ok()); - ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error()); -} \ No newline at end of file diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 36381780..ec233436 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -184,17 +184,18 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) { ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Document id referenced in vector query is not found.", res_op.error()); + // DEPRECATED: vector query is also supported on non-wildcard queries with hybrid search // only supported with wildcard queries - res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, - spp::sparse_hash_set(), - spp::sparse_hash_set(), 10, "", 30, 5, - "", 10, {}, {}, {}, 0, - "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, - 4, {off}, 32767, 32767, 2, - false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])"); + // res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + // spp::sparse_hash_set(), + // spp::sparse_hash_set(), 10, "", 30, 5, + // "", 10, {}, {}, {}, 0, + // "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + // 4, {off}, 32767, 32767, 2, + // false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])"); - ASSERT_FALSE(res_op.ok()); - ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error()); + // ASSERT_FALSE(res_op.ok()); + // ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error()); // support num_dim on only float array fields schema = R"({ @@ -676,3 +677,43 @@ TEST_F(CollectionVectorTest, VectorWithNullValue) { ASSERT_EQ("Field `vec` must be an array.", nlohmann::json::parse(json_lines[1])["error"].get()); } + +TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "vec", "type": "float[]", "create_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + + doc["name"] = "john doe"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + std::string dummy_vec_string = "[0.9"; + for (int i = 0; i < 382; i++) { + dummy_vec_string += ", 0.9"; + } + dummy_vec_string += ", 0.9]"; + + auto results_op = coll1->search("john", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:(" + dummy_vec_string +")"); + ASSERT_EQ(true, results_op.ok()); + + + ASSERT_EQ(1, results_op.get()["found"].get()); + ASSERT_EQ(1, results_op.get()["hits"].size()); +} \ No newline at end of file From 401ebbe481f89d389d74751d602a28f217d37347 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sun, 2 Apr 2023 12:48:33 +0300 Subject: [PATCH 2/8] Fix for text embedding when schema or document updated --- include/collection.h | 4 ++ src/collection.cpp | 101 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/include/collection.h b/include/collection.h index c81c4123..5fcb54c8 100644 --- a/include/collection.h +++ b/include/collection.h @@ -162,6 +162,8 @@ private: void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); + void process_remove_field_for_embedding_fields(const field& the_field); + void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, const std::vector& hidden_hits, @@ -355,6 +357,8 @@ public: Option embed_fields(nlohmann::json& document); + Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc); + static uint32_t get_seq_id_from_key(const std::string & key); Option get_document_from_store(const std::string & seq_id_key, nlohmann::json & document, bool raw_doc = false) const; diff --git a/src/collection.cpp b/src/collection.cpp index bea45361..3937bbfc 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -72,10 +72,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: const std::string& id) { try { document = nlohmann::json::parse(json_str); - auto embed_res = embed_fields(document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } } catch(const std::exception& e) { LOG(ERROR) << "JSON error: " << e.what(); return Option(400, std::string("Bad JSON: ") + e.what()); @@ -106,6 +102,13 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: uint32_t seq_id = get_next_seq_id(); document["id"] = std::to_string(seq_id); + // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc + auto embed_res = embed_fields(document); + if (!embed_res.ok()) { + return Option(400, embed_res.error()); + } + + // Add reference helper fields in the document. for (auto const& pair: reference_fields) { auto field_name = pair.first; @@ -176,9 +179,20 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: if(operation == CREATE) { return Option(409, std::string("A document with id ") + doc_id + " already exists."); } + + // UPSERT, EMPLACE or UPDATE uint32_t seq_id = (uint32_t) std::stoul(seq_id_str); + + //Handle embedding here for UPDATE + nlohmann::json old_doc; + get_document_from_store(get_seq_id_key(seq_id), old_doc); + auto embed_res = embed_fields_update(old_doc, document); + if (!embed_res.ok()) { + return Option(400, embed_res.error()); + } + return Option(doc_seq_id_t{seq_id, false}); } else { @@ -188,6 +202,12 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: } else { // for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc uint32_t seq_id = get_next_seq_id(); + + // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc + auto embed_res = embed_fields(document); + if (!embed_res.ok()) { + return Option(400, embed_res.error()); + } return Option(doc_seq_id_t{seq_id, true}); } } @@ -298,7 +318,6 @@ nlohmann::json Collection::add_many(std::vector& json_lines, nlohma const std::string & json_line = json_lines[i]; Option doc_seq_id_op = to_doc(json_line, document, operation, dirty_values, id); - const uint32_t seq_id = doc_seq_id_op.ok() ? doc_seq_id_op.get().seq_id : 0; index_record record(i, seq_id, document, operation, dirty_values); @@ -3759,6 +3778,8 @@ Option Collection::batch_alter_data(const std::vector& alter_fields if(del_field.name == default_sorting_field) { default_sorting_field = ""; } + + process_remove_field_for_embedding_fields(del_field); } index->refresh_schemas({}, del_fields); @@ -4781,3 +4802,73 @@ Option Collection::embed_fields(nlohmann::json& document) { return Option(true); } + +Option Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) { + nlohmann::json new_doc_copy = new_doc; + for(const auto& field : embedding_fields) { + if(TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); + } + std::string text_to_embed; + for(const auto& field_name : field.create_from) { + auto field_it = search_schema.find(field_name); + if(field_it != search_schema.end()) { + nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; + if(field_it.value().type == field_types::STRING) { + if(value.is_string()) { + text_to_embed += value.get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } else if(field_it.value().type == field_types::STRING_ARRAY) { + if(value.is_array()) { + for(const auto& val : value) { + if(val.is_string()) { + text_to_embed += val.get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + else { + return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); + } + } else { + return Option(400, "Field `" + field_name + "` is not a valid field."); + } + } + + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + new_doc_copy[field.name] = embedding; + } + new_doc = new_doc_copy; + return Option(true); +} + +void Collection::process_remove_field_for_embedding_fields(const field& the_field) { + std::vector::iterator> empty_fields; + for(auto& embedding_field : embedding_fields) { + const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) { + return other_field.name == embedding_field.name; + }); + actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) { + return the_field.name == field_name; + })); + embedding_field = *actual_field; + // store to remove embedding field if it has no field names in 'create_from' anymore. + if(embedding_field.create_from.size() == 0) { + empty_fields.push_back(actual_field); + } + } + + for(const auto& empty_field : empty_fields) { + search_schema.erase(empty_field->name); + embedding_fields.erase(empty_field->name); + fields.erase(empty_field); + } +} \ No newline at end of file From 1ad7bcdce3a249da0c17ae1f234a734bd2ca0c5b Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Wed, 5 Apr 2023 18:05:45 +0300 Subject: [PATCH 3/8] Review Changes --- include/collection.h | 3 + include/vector_query_ops.h | 4 +- src/collection.cpp | 140 ++++++++++++------------- src/index.cpp | 8 +- src/vector_query_ops.cpp | 8 +- test/collection_schema_change_test.cpp | 67 ++++++++++++ test/collection_test.cpp | 53 +++++----- test/collection_vector_search_test.cpp | 60 +++++++++++ 8 files changed, 236 insertions(+), 107 deletions(-) diff --git a/include/collection.h b/include/collection.h index 230ea65f..e3440006 100644 --- a/include/collection.h +++ b/include/collection.h @@ -209,6 +209,9 @@ private: std::vector& sort_fields_std, bool is_wildcard_query, bool is_group_by_query = false) const; + + Option validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const; + Option persist_collection_meta(); Option batch_alter_data(const std::vector& alter_fields, diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index 32ee7448..5ffd5c0b 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -10,7 +10,7 @@ struct vector_query_t { std::string field_name; size_t k = 0; size_t flat_search_cutoff = 0; - float similarity_cutoff = 0.0; + float distance_threshold = 2.01; std::vector values; uint32_t seq_id = 0; @@ -20,7 +20,7 @@ struct vector_query_t { // used for testing only field_name.clear(); k = 0; - similarity_cutoff = 0.0; + distance_threshold = 2.01; values.clear(); seq_id = 0; query_doc_given = false; diff --git a/src/collection.cpp b/src/collection.cpp index ae3b3d80..231a3f9a 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -51,6 +51,12 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co symbols_to_index(to_char_array(symbols_to_index)), token_separators(to_char_array(token_separators)), index(init_index()) { + for (auto const& field: fields) { + if (!field.create_from.empty()) { + embedding_fields.emplace(field.name, field); + } + } + this->num_documents = 0; } @@ -253,7 +259,7 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::infix] = coll_field.infix; field_json[fields::locale] = coll_field.locale; - if(coll_field.create_from.size() > 0) { + if(!coll_field.create_from.empty()) { field_json[fields::create_from] = coll_field.create_from; } @@ -1008,7 +1014,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().create_from.size() > 0; + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().create_from.empty(); if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -3765,7 +3771,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_fields.erase(del_field.name); } - if(del_field.create_from.size() > 0) { + if(!del_field.create_from.empty()) { embedding_fields.erase(del_field.name); } @@ -4063,7 +4069,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return Option(400, "Field `" + field_name + "` is not part of collection schema."); } - if(found_field && field_it.value().create_from.size() > 0) { + if(found_field && !field_it.value().create_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4072,7 +4078,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(field_it.key()); updated_nested_fields.erase(field_it.key()); - if(field_it.value().create_from.size() > 0) { + if(!field_it.value().create_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4086,7 +4092,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(prefix_kv.key()); updated_nested_fields.erase(prefix_kv.key()); - if(prefix_kv.value().create_from.size() > 0) { + if(!prefix_kv.value().create_from.empty()) { updated_embedding_fields.erase(prefix_kv.key()); } } @@ -4139,7 +4145,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, addition_fields.push_back(f); } - if(f.create_from.size() > 0) { + if(!f.create_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4154,7 +4160,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value()); updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); - if(prefix_kv.value().create_from.size() > 0) { + if(!prefix_kv.value().create_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4478,7 +4484,7 @@ Index* Collection::init_index() { nested_fields.emplace(field.name, field); } - if(field.create_from.size() > 0) { + if(!field.create_from.empty()) { embedding_fields.emplace(field.name, field); } @@ -4754,45 +4760,22 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha Option Collection::embed_fields(nlohmann::json& document) { + auto validate_res = validate_embed_fields(document, true); + if(!validate_res.ok()) { + return validate_res; + } for(const auto& field : embedding_fields) { - if(TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } std::string text_to_embed; for(const auto& field_name : field.create_from) { auto field_it = search_schema.find(field_name); - if(field_it != search_schema.end()) { - if(field_it.value().type == field_types::STRING) { - if(document.find(field_name) != document.end()) { - if(document[field_name].is_string()) { - text_to_embed += document[field_name].get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } else if(field_it.value().type == field_types::STRING_ARRAY) { - if(document.find(field_name) != document.end()) { - if(document[field_name].is_array()) { - for(const auto& val : document[field_name]) { - if(val.is_string()) { - text_to_embed += val.get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } + if(field_it.value().type == field_types::STRING) { + text_to_embed += document[field_name].get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : document[field_name]) { + text_to_embed += val.get() + " "; } - else { - return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); - } - } else { - return Option(400, "Field `" + field_name + "` is not a valid field."); } } - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); std::vector embedding = embedder->Embed(text_to_embed); @@ -4802,41 +4785,58 @@ Option Collection::embed_fields(nlohmann::json& document) { return Option(true); } +Option Collection::validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const { + if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); + } + for(const auto& field : embedding_fields) { + for(const auto& field_name : field.create_from) { + auto schema_field_it = search_schema.find(field_name); + auto doc_field_it = document.find(field_name); + if(schema_field_it == search_schema.end()) { + return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); + } + if(doc_field_it == document.end()) { + if(error_if_field_not_found) { + return Option(400, "Field `" + field_name + "` is needed to create embedding."); + } else { + continue; + } + } + if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || + (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + if(doc_field_it.value().is_array()) { + for(const auto& val : doc_field_it.value()) { + if(!val.is_string()) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } + } + } + + return Option(true); +} + Option Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) { + auto validate_res = validate_embed_fields(new_doc, false); + if(!validate_res.ok()) { + return validate_res; + } nlohmann::json new_doc_copy = new_doc; for(const auto& field : embedding_fields) { - if(TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } std::string text_to_embed; for(const auto& field_name : field.create_from) { auto field_it = search_schema.find(field_name); - if(field_it != search_schema.end()) { - nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; - if(field_it.value().type == field_types::STRING) { - if(value.is_string()) { - text_to_embed += value.get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } else if(field_it.value().type == field_types::STRING_ARRAY) { - if(value.is_array()) { - for(const auto& val : value) { - if(val.is_string()) { - text_to_embed += val.get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } + nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; + if(field_it.value().type == field_types::STRING) { + text_to_embed += value.get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : value) { + text_to_embed += val.get() + " "; } - else { - return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); - } - } else { - return Option(400, "Field `" + field_name + "` is not a valid field."); } } @@ -4860,7 +4860,7 @@ void Collection::process_remove_field_for_embedding_fields(const field& the_fiel })); embedding_field = *actual_field; // store to remove embedding field if it has no field names in 'create_from' anymore. - if(embedding_field.create_from.size() == 0) { + if(embedding_field.create_from.empty()) { empty_fields.push_back(actual_field); } } diff --git a/src/index.cpp b/src/index.cpp index 4e088576..7a110436 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2882,7 +2882,7 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - if(vector_query.similarity_cutoff > 0 && vec_dist_score > vector_query.similarity_cutoff) { + if(vec_dist_score > vector_query.distance_threshold) { continue; } @@ -3105,10 +3105,8 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - if(vector_query.similarity_cutoff > 0) { - if(vec_dist_score > vector_query.similarity_cutoff) { - continue; - } + if(vec_dist_score > vector_query.distance_threshold) { + continue; } vec_results.emplace_back(seq_id, vec_dist_score); } diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index 729157b0..54f65d5c 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -146,13 +146,13 @@ Option VectorQueryOps::parse_vector_query_str(std::string vector_query_str vector_query.flat_search_cutoff = std::stoi(param_kv[1]); } - if(param_kv[0] == "similarity_cutoff") { - if(!StringUtils::is_float(param_kv[1])) { + if(param_kv[0] == "distance_threshold") { + if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 2.0) { return Option(400, "Malformed vector query string: " - "`similarity_cutoff` parameter must be a float."); + "`distance_threshold` parameter must be a float between 0.0-2.0."); } - vector_query.similarity_cutoff = std::stof(param_kv[1]); + vector_query.distance_threshold = std::stof(param_kv[1]); } } diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index c32ba142..dd1301d9 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1446,3 +1446,70 @@ TEST_F(CollectionSchemaChangeTest, GeoFieldSchemaAddition) { ASSERT_TRUE(res_op.ok()); ASSERT_EQ(2, res_op.get()["found"].get()); } + +TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"} + ] + })"_json; + + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json update_schema = R"({ + "fields": [ + {"name": "embedding", "type":"float[]", "create_from": ["names"]} + ] + })"_json; + + auto res = coll->alter(update_schema); + + ASSERT_FALSE(res.ok()); + ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); +} + +TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"}, + {"name": "category", "type":"string"}, + {"name": "embedding", "type":"float[]", "create_from": ["names","category"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + LOG(INFO) << "Created collection"; + + auto schema_changes = R"({ + "fields": [ + {"name": "names", "drop": true} + ] + })"_json; + + LOG(INFO) << "Dropping field"; + + auto embedding_fields = coll->get_embedding_fields(); + ASSERT_EQ(2, embedding_fields["embedding"].create_from.size()); + + LOG(INFO) << "Before alter"; + + auto alter_op = coll->alter(schema_changes); + ASSERT_TRUE(alter_op.ok()); + + LOG(INFO) << "After alter"; + + embedding_fields = coll->get_embedding_fields(); + ASSERT_EQ(1, embedding_fields["embedding"].create_from.size()); + ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]); +} \ No newline at end of file diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 4e0b43a3..b19201f7 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4620,7 +4620,7 @@ TEST_F(CollectionTest, SemanticSearchTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4655,7 +4655,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4686,7 +4686,7 @@ TEST_F(CollectionTest, HybridSearch) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4719,7 +4719,7 @@ TEST_F(CollectionTest, EmbedFielsTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4747,7 +4747,7 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4821,7 +4821,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4853,7 +4853,7 @@ TEST_F(CollectionTest, EmbeddingFieldsMapTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4891,7 +4891,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4907,28 +4907,29 @@ TEST_F(CollectionTest, EmbedStringArrayField) { ASSERT_TRUE(add_op.ok()); } -TEST_F(CollectionTest, UpdateSchemaWithNewEmbeddingField) { +TEST_F(CollectionTest, MissingFieldForEmbedding) { nlohmann::json schema = R"({ - "name": "objects", - "fields": [ - {"name": "names", "type": "string[]"} - ] - })"_json; - + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "create_from": ["names", "category"]} + ] + })"_json; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + auto op = collectionManager.create_collection(schema); ASSERT_TRUE(op.ok()); Collection* coll = op.get(); - nlohmann::json update_schema = R"({ - "fields": [ - {"name": "embedding", "type":"float[]", "create_from": ["names"]} - ] - })"_json; - - auto res = coll->alter(update_schema); - - ASSERT_FALSE(res.ok()); - ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); -} + nlohmann::json doc; + doc["names"].push_back("butter"); + doc["names"].push_back("butterfly"); + doc["names"].push_back("butterball"); + auto add_op = coll->add(doc.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error()); +} \ No newline at end of file diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index ec233436..7f8b212b 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -716,4 +716,64 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { ASSERT_EQ(1, results_op.get()["found"].get()); ASSERT_EQ(1, results_op.get()["hits"].size()); +} + +TEST_F(CollectionVectorTest, DistanceThresholdTest) { + nlohmann::json schema = R"({ + "name": "test", + "fields": [ + {"name": "vec", "type": "float[]", "num_dim": 3} + ] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + doc["vec"] = {0.1, 0.2, 0.3}; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + // write a vector which is 0.5 away from the first vector + doc["vec"] = {0.6, 0.7, 0.8}; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + + auto results_op = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:([0.3,0.4,0.5])"); + + ASSERT_EQ(true, results_op.ok()); + ASSERT_EQ(2, results_op.get()["found"].get()); + ASSERT_EQ(2, results_op.get()["hits"].size()); + + ASSERT_FLOAT_EQ(0.6, results_op.get()["hits"][0]["document"]["vec"].get>()[0]); + ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get>()[1]); + ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get>()[2]); + + ASSERT_FLOAT_EQ(0.1, results_op.get()["hits"][1]["document"]["vec"].get>()[0]); + ASSERT_FLOAT_EQ(0.2, results_op.get()["hits"][1]["document"]["vec"].get>()[1]); + ASSERT_FLOAT_EQ(0.3, results_op.get()["hits"][1]["document"]["vec"].get>()[2]); + + results_op = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:([0.3,0.4,0.5], distance_threshold:0.01)"); + + ASSERT_EQ(true, results_op.ok()); + ASSERT_EQ(1, results_op.get()["found"].get()); + ASSERT_EQ(1, results_op.get()["hits"].size()); + + ASSERT_FLOAT_EQ(0.6, results_op.get()["hits"][0]["document"]["vec"].get>()[0]); + ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get>()[1]); + ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get>()[2]); + + } \ No newline at end of file From b7c988ab45398e1130a6b5792c709b7e1165350f Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Wed, 5 Apr 2023 18:23:46 +0300 Subject: [PATCH 4/8] Added more test --- test/collection_schema_change_test.cpp | 38 ++++++++++++ test/collection_test.cpp | 83 +++++++++++++++----------- 2 files changed, 85 insertions(+), 36 deletions(-) diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index dd1301d9..ec6a88e7 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1512,4 +1512,42 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { embedding_fields = coll->get_embedding_fields(); ASSERT_EQ(1, embedding_fields["embedding"].create_from.size()); ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]); +} + +TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "create_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + auto embedding_fields_map = coll->get_embedding_fields(); + ASSERT_EQ(1, embedding_fields_map.size()); + auto embedding_field_it = embedding_fields_map.find("embedding"); + ASSERT_TRUE(embedding_field_it != embedding_fields_map.end()); + ASSERT_EQ("embedding", embedding_field_it.value().name); + ASSERT_EQ(1, embedding_field_it.value().create_from.size()); + ASSERT_EQ("name", embedding_field_it.value().create_from[0]); + + // drop the embedding field + nlohmann::json schema_without_embedding = R"({ + "fields": [ + {"name": "embedding", "drop": true} + ] + })"_json; + auto update_op = coll->alter(schema_without_embedding); + + ASSERT_TRUE(update_op.ok()); + + embedding_fields_map = coll->get_embedding_fields(); + ASSERT_EQ(0, embedding_fields_map.size()); } \ No newline at end of file diff --git a/test/collection_test.cpp b/test/collection_test.cpp index b19201f7..b558c1d6 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4844,43 +4844,7 @@ TEST_F(CollectionTest, CreateModelDirIfNotExists) { ASSERT_TRUE(std::filesystem::exists("/tmp/typesense_test/models")); } -TEST_F(CollectionTest, EmbeddingFieldsMapTest) { - nlohmann::json schema = R"({ - "name": "objects", - "fields": [ - {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} - ] - })"_json; - - TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); - TextEmbedderManager::download_default_model(); - auto op = collectionManager.create_collection(schema); - ASSERT_TRUE(op.ok()); - Collection* coll = op.get(); - - auto embedding_fields_map = coll->get_embedding_fields(); - ASSERT_EQ(1, embedding_fields_map.size()); - auto embedding_field_it = embedding_fields_map.find("embedding"); - ASSERT_TRUE(embedding_field_it != embedding_fields_map.end()); - ASSERT_EQ("embedding", embedding_field_it.value().name); - ASSERT_EQ(1, embedding_field_it.value().create_from.size()); - ASSERT_EQ("name", embedding_field_it.value().create_from[0]); - - // drop the embedding field - nlohmann::json schema_without_embedding = R"({ - "fields": [ - {"name": "embedding", "drop": true} - ] - })"_json; - auto update_op = coll->alter(schema_without_embedding); - - ASSERT_TRUE(update_op.ok()); - - embedding_fields_map = coll->get_embedding_fields(); - ASSERT_EQ(0, embedding_fields_map.size()); -} TEST_F(CollectionTest, EmbedStringArrayField) { nlohmann::json schema = R"({ @@ -4932,4 +4896,51 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) { auto add_op = coll->add(doc.dump()); ASSERT_FALSE(add_op.ok()); ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error()); +} + +TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "create_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json doc; + doc["name"] = "butter"; + + auto add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + // get embedding field + + // get id of the document + auto id = add_op.get()["id"]; + // get embedding field from the document + auto embedding_field = add_op.get()["embedding"].get>(); + ASSERT_EQ(384, embedding_field.size()); + + // update the document + nlohmann::json update_doc; + update_doc["name"] = "butterball"; + std::string dirty_values; + + auto update_op = coll->update_matching_filter("id:=" + id.get(), update_doc.dump(), dirty_values); + ASSERT_TRUE(update_op.ok()); + ASSERT_EQ(1, update_op.get()["num_updated"]); + + // get the document again + auto get_op = coll->get(id); + ASSERT_TRUE(get_op.ok()); + auto updated_embedding_field = get_op.get()["embedding"].get>(); + + // check if the embedding field is updated + ASSERT_NE(embedding_field, updated_embedding_field); } \ No newline at end of file From 7ae3cc9781ee6efc3809f9eb4435c19b661ad1d9 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Fri, 7 Apr 2023 23:56:25 +0300 Subject: [PATCH 5/8] Review Changes II --- include/collection.h | 5 - include/field.h | 42 +++---- include/index.h | 12 ++ include/validator.h | 6 + src/collection.cpp | 147 ++++--------------------- src/collection_manager.cpp | 6 +- src/field.cpp | 10 +- src/index.cpp | 64 ++++++++++- src/validator.cpp | 51 +++++++++ test/collection_all_fields_test.cpp | 14 +-- test/collection_schema_change_test.cpp | 16 +-- test/collection_test.cpp | 61 +++++----- test/collection_vector_search_test.cpp | 2 +- 13 files changed, 228 insertions(+), 208 deletions(-) diff --git a/include/collection.h b/include/collection.h index e3440006..dce59d0b 100644 --- a/include/collection.h +++ b/include/collection.h @@ -210,8 +210,6 @@ private: bool is_wildcard_query, bool is_group_by_query = false) const; - Option validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const; - Option persist_collection_meta(); Option batch_alter_data(const std::vector& alter_fields, @@ -358,9 +356,6 @@ public: const DIRTY_VALUES dirty_values, const std::string& id=""); - Option embed_fields(nlohmann::json& document); - - Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc); static uint32_t get_seq_id_from_key(const std::string & key); diff --git a/include/field.h b/include/field.h index 5b44b52a..febb2a1e 100644 --- a/include/field.h +++ b/include/field.h @@ -50,7 +50,7 @@ namespace fields { static const std::string num_dim = "num_dim"; static const std::string vec_dist = "vec_dist"; static const std::string reference = "reference"; - static const std::string create_from = "create_from"; + static const std::string embed_from = "embed_from"; static const std::string model_name = "model_name"; } @@ -77,7 +77,7 @@ struct field { int nested_array; size_t num_dim; - std::vector create_from; + std::vector embed_from; std::string model_name; vector_distance_type_t vec_dist; @@ -89,9 +89,9 @@ struct field { field(const std::string &name, const std::string &type, const bool facet, const bool optional = false, bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false, - int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector &create_from = {}, const std::string& model_name = "") : + int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const std::vector &embed_from = {}, const std::string& model_name = "") : name(name), type(type), facet(facet), optional(optional), index(index), locale(locale), - nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), create_from(create_from), model_name(model_name) { + nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed_from(embed_from), model_name(model_name) { set_computed_defaults(sort, infix); } @@ -319,8 +319,8 @@ struct field { if (!field.reference.empty()) { field_val[fields::reference] = field.reference; } - if(!field.create_from.empty()) { - field_val[fields::create_from] = field.create_from; + if(!field.embed_from.empty()) { + field_val[fields::embed_from] = field.embed_from; if(!field.model_name.empty()) { field_val[fields::model_name] = field.model_name; } @@ -421,36 +421,36 @@ struct field { for(nlohmann::json & field_json: fields_json) { - if(field_json.count(fields::create_from) != 0) { + if(field_json.count(fields::embed_from) != 0) { if(TextEmbedderManager::model_dir.empty()) { return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); } - if(!field_json[fields::create_from].is_array()) { - return Option(400, "Property `" + fields::create_from + "` must be an array."); + if(!field_json[fields::embed_from].is_array()) { + return Option(400, "Property `" + fields::embed_from + "` must be an array."); } - if(field_json[fields::create_from].empty()) { - return Option(400, "Property `" + fields::create_from + "` must have at least one element."); + if(field_json[fields::embed_from].empty()) { + return Option(400, "Property `" + fields::embed_from + "` must have at least one element."); } - for(auto& create_from_field : field_json[fields::create_from]) { - if(!create_from_field.is_string()) { - return Option(400, "Property `" + fields::create_from + "` must contain only field names as strings."); + for(auto& embed_from_field : field_json[fields::embed_from]) { + if(!embed_from_field.is_string()) { + return Option(400, "Property `" + fields::embed_from + "` must contain only field names as strings."); } } if(field_json[fields::type] != field_types::FLOAT_ARRAY) { - return Option(400, "Property `" + fields::create_from + "` is only allowed on a float array field."); + return Option(400, "Property `" + fields::embed_from + "` is only allowed on a float array field."); } - for(auto& create_from_field : field_json[fields::create_from]) { + for(auto& embed_from_field : field_json[fields::embed_from]) { bool flag = false; for(const auto& field : fields_json) { - if(field[fields::name] == create_from_field) { + if(field[fields::name] == embed_from_field) { if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) { - return Option(400, "Property `" + fields::create_from + "` can only have string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only have string or string array fields."); } flag = true; break; @@ -458,9 +458,9 @@ struct field { } if(!flag) { for(const auto& field : the_fields) { - if(field.name == create_from_field) { + if(field.name == embed_from_field) { if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) { - return Option(400, "Property `" + fields::create_from + "` can only have used with string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only have used with string or string array fields."); } flag = true; break; @@ -468,7 +468,7 @@ struct field { } } if(!flag) { - return Option(400, "Property `" + fields::create_from + "` can only be used with string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only be used with string or string array fields."); } } } diff --git a/include/index.h b/include/index.h index 6c4d66fe..76ca3696 100644 --- a/include/index.h +++ b/include/index.h @@ -532,6 +532,16 @@ private: const std::string& token, uint32_t seq_id); void initialize_facet_indexes(const field& facet_field); + + + + static Option embed_fields(nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema); + + static Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema); public: // for limiting number of results on multiple candidates / query rewrites @@ -663,6 +673,7 @@ public: const size_t batch_start_index, const size_t batch_size, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, @@ -672,6 +683,7 @@ public: std::vector& iter_batch, const std::string& default_sorting_field, const tsl::htrie_map& search_schema, + const tsl::htrie_map & embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, diff --git a/include/validator.h b/include/validator.h index 638b5eb3..3998f5dc 100644 --- a/include/validator.h +++ b/include/validator.h @@ -27,6 +27,7 @@ public: static Option validate_index_in_memory(nlohmann::json &document, uint32_t seq_id, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const index_operation_t op, const bool is_update, const std::string& fallback_field_type, @@ -67,4 +68,9 @@ public: nlohmann::json::iterator& array_iter, bool is_array, bool& array_ele_erased); + static Option validate_embed_fields(const nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema, + const bool& error_if_field_not_found); + }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 231a3f9a..676628d2 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -52,7 +52,7 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co index(init_index()) { for (auto const& field: fields) { - if (!field.create_from.empty()) { + if (!field.embed_from.empty()) { embedding_fields.emplace(field.name, field); } } @@ -108,12 +108,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: uint32_t seq_id = get_next_seq_id(); document["id"] = std::to_string(seq_id); - // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc - auto embed_res = embed_fields(document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } - // Add reference helper fields in the document. for (auto const& pair: reference_fields) { @@ -191,14 +185,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: // UPSERT, EMPLACE or UPDATE uint32_t seq_id = (uint32_t) std::stoul(seq_id_str); - //Handle embedding here for UPDATE - nlohmann::json old_doc; - get_document_from_store(get_seq_id_key(seq_id), old_doc); - auto embed_res = embed_fields_update(old_doc, document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } - return Option(doc_seq_id_t{seq_id, false}); } else { @@ -209,11 +195,6 @@ Option Collection::to_doc(const std::string & json_str, nlohmann:: // for UPSERT, EMPLACE or CREATE, if a document with given ID is not found, we will treat it as a new doc uint32_t seq_id = get_next_seq_id(); - // Handle embedding here for UPSERT, EMPLACE or CREATE when we treat is as a new doc - auto embed_res = embed_fields(document); - if (!embed_res.ok()) { - return Option(400, embed_res.error()); - } return Option(doc_seq_id_t{seq_id, true}); } } @@ -259,8 +240,8 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::infix] = coll_field.infix; field_json[fields::locale] = coll_field.locale; - if(!coll_field.create_from.empty()) { - field_json[fields::create_from] = coll_field.create_from; + if(!coll_field.embed_from.empty()) { + field_json[fields::embed_from] = coll_field.embed_from; } if(coll_field.model_name.size() > 0) { @@ -397,6 +378,7 @@ nlohmann::json Collection::add_many(std::vector& json_lines, nlohma do_batched_index: + if((i+1) % index_batch_size == 0 || i == json_lines.size()-1 || repeated_doc) { batch_index(index_records, json_lines, num_indexed, return_doc, return_id); @@ -593,7 +575,7 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t std::unique_lock lock(mutex); Option validation_op = validator_t::validate_index_in_memory(document, seq_id, default_sorting_field, - search_schema, op, false, + search_schema, embedding_fields, op, false, fallback_field_type, dirty_values); if(!validation_op.ok()) { @@ -604,7 +586,7 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t std::vector index_batch; index_batch.emplace_back(std::move(rec)); - Index::batch_memory_index(index, index_batch, default_sorting_field, search_schema, + Index::batch_memory_index(index, index_batch, default_sorting_field, search_schema, embedding_fields, fallback_field_type, token_separators, symbols_to_index, true); num_documents += 1; @@ -614,7 +596,7 @@ Option Collection::index_in_memory(nlohmann::json &document, uint32_t size_t Collection::batch_index_in_memory(std::vector& index_records) { std::unique_lock lock(mutex); size_t num_indexed = Index::batch_memory_index(index, index_records, default_sorting_field, - search_schema, fallback_field_type, + search_schema, embedding_fields, fallback_field_type, token_separators, symbols_to_index, true); num_documents += num_indexed; return num_indexed; @@ -1014,7 +996,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().create_from.empty(); + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().embed_from.empty(); if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -3735,7 +3717,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields } } - Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, + Index::batch_memory_index(index, iter_batch, default_sorting_field, schema_additions, embedding_fields, fallback_field_type, token_separators, symbols_to_index, true); iter_batch.clear(); @@ -3771,7 +3753,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_fields.erase(del_field.name); } - if(!del_field.create_from.empty()) { + if(!del_field.embed_from.empty()) { embedding_fields.erase(del_field.name); } @@ -4069,7 +4051,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return Option(400, "Field `" + field_name + "` is not part of collection schema."); } - if(found_field && !field_it.value().create_from.empty()) { + if(found_field && !field_it.value().embed_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4078,7 +4060,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(field_it.key()); updated_nested_fields.erase(field_it.key()); - if(!field_it.value().create_from.empty()) { + if(!field_it.value().embed_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4092,7 +4074,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(prefix_kv.key()); updated_nested_fields.erase(prefix_kv.key()); - if(!prefix_kv.value().create_from.empty()) { + if(!prefix_kv.value().embed_from.empty()) { updated_embedding_fields.erase(prefix_kv.key()); } } @@ -4145,7 +4127,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, addition_fields.push_back(f); } - if(!f.create_from.empty()) { + if(!f.embed_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4160,7 +4142,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value()); updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); - if(!prefix_kv.value().create_from.empty()) { + if(!prefix_kv.value().embed_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4234,6 +4216,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, // validate existing data on disk for compatibility via updated_search_schema auto validate_op = validator_t::validate_index_in_memory(document, seq_id, default_sorting_field, updated_search_schema, + updated_embedding_fields, index_operation_t::CREATE, false, fallback_field_type, @@ -4484,7 +4467,7 @@ Index* Collection::init_index() { nested_fields.emplace(field.name, field); } - if(!field.create_from.empty()) { + if(!field.embed_from.empty()) { embedding_fields.emplace(field.name, field); } @@ -4759,108 +4742,18 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha } -Option Collection::embed_fields(nlohmann::json& document) { - auto validate_res = validate_embed_fields(document, true); - if(!validate_res.ok()) { - return validate_res; - } - for(const auto& field : embedding_fields) { - std::string text_to_embed; - for(const auto& field_name : field.create_from) { - auto field_it = search_schema.find(field_name); - if(field_it.value().type == field_types::STRING) { - text_to_embed += document[field_name].get() + " "; - } else if(field_it.value().type == field_types::STRING_ARRAY) { - for(const auto& val : document[field_name]) { - text_to_embed += val.get() + " "; - } - } - } - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - document[field.name] = embedding; - } - - return Option(true); -} - -Option Collection::validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const { - if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } - for(const auto& field : embedding_fields) { - for(const auto& field_name : field.create_from) { - auto schema_field_it = search_schema.find(field_name); - auto doc_field_it = document.find(field_name); - if(schema_field_it == search_schema.end()) { - return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); - } - if(doc_field_it == document.end()) { - if(error_if_field_not_found) { - return Option(400, "Field `" + field_name + "` is needed to create embedding."); - } else { - continue; - } - } - if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || - (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - if(doc_field_it.value().is_array()) { - for(const auto& val : doc_field_it.value()) { - if(!val.is_string()) { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } - } - } - - return Option(true); -} - -Option Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) { - auto validate_res = validate_embed_fields(new_doc, false); - if(!validate_res.ok()) { - return validate_res; - } - nlohmann::json new_doc_copy = new_doc; - for(const auto& field : embedding_fields) { - std::string text_to_embed; - for(const auto& field_name : field.create_from) { - auto field_it = search_schema.find(field_name); - nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; - if(field_it.value().type == field_types::STRING) { - text_to_embed += value.get() + " "; - } else if(field_it.value().type == field_types::STRING_ARRAY) { - for(const auto& val : value) { - text_to_embed += val.get() + " "; - } - } - } - - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - new_doc_copy[field.name] = embedding; - } - new_doc = new_doc_copy; - return Option(true); -} - void Collection::process_remove_field_for_embedding_fields(const field& the_field) { std::vector::iterator> empty_fields; for(auto& embedding_field : embedding_fields) { const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) { return other_field.name == embedding_field.name; }); - actual_field->create_from.erase(std::remove_if(actual_field->create_from.begin(), actual_field->create_from.end(), [&the_field](std::string field_name) { + actual_field->embed_from.erase(std::remove_if(actual_field->embed_from.begin(), actual_field->embed_from.end(), [&the_field](std::string field_name) { return the_field.name == field_name; })); embedding_field = *actual_field; - // store to remove embedding field if it has no field names in 'create_from' anymore. - if(embedding_field.create_from.empty()) { + // store to remove embedding field if it has no field names in 'embed_from' anymore. + if(embedding_field.embed_from.empty()) { empty_fields.push_back(actual_field); } } diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index f42d356a..aadcbf31 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -58,8 +58,8 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection field_obj[fields::reference] = ""; } - if(field_obj.count(fields::create_from) == 0) { - field_obj[fields::create_from] = std::vector(); + if(field_obj.count(fields::embed_from) == 0) { + field_obj[fields::embed_from] = std::vector(); } if(field_obj.count(fields::model_name) == 0) { @@ -78,7 +78,7 @@ Collection* CollectionManager::init_collection(const nlohmann::json & collection field f(field_obj[fields::name], field_obj[fields::type], field_obj[fields::facet], field_obj[fields::optional], field_obj[fields::index], field_obj[fields::locale], -1, field_obj[fields::infix], field_obj[fields::nested], field_obj[fields::nested_array], - field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::create_from], + field_obj[fields::num_dim], vec_dist_type, field_obj[fields::reference], field_obj[fields::embed_from], field_obj[fields::model_name]); // value of `sort` depends on field type diff --git a/src/field.cpp b/src/field.cpp index 3675d43c..ee808158 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -672,11 +672,11 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } } - if(field_json.count(fields::model_name) > 0 && field_json.count(fields::create_from) == 0) { - return Option(400, "Property `" + fields::model_name + "` can only be used with `" + fields::create_from + "`."); + if(field_json.count(fields::model_name) > 0 && field_json.count(fields::embed_from) == 0) { + return Option(400, "Property `" + fields::model_name + "` can only be used with `" + fields::embed_from + "`."); } - if(field_json.count(fields::create_from) != 0) { + if(field_json.count(fields::embed_from) != 0) { // If the model path is not specified, use the default model and set the number of dimensions to 384 (number of dimensions of the default model) field_json[fields::num_dim] = static_cast(384); if(field_json.count(fields::model_name) != 0) { @@ -695,7 +695,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } } } else { - field_json[fields::create_from] = std::vector(); + field_json[fields::embed_from] = std::vector(); } @@ -784,7 +784,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::optional], field_json[fields::index], field_json[fields::locale], field_json[fields::sort], field_json[fields::infix], field_json[fields::nested], field_json[fields::nested_array], field_json[fields::num_dim], vec_dist, - field_json[fields::reference], field_json[fields::create_from].get>(), + field_json[fields::reference], field_json[fields::embed_from].get>(), field_json[fields::model_name]) ); diff --git a/src/index.cpp b/src/index.cpp index 7a110436..0b59c823 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -411,6 +411,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite const size_t batch_start_index, const size_t batch_size, const std::string& default_sorting_field, const tsl::htrie_map& search_schema, + const tsl::htrie_map& embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, @@ -435,6 +436,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite Option validation_op = validator_t::validate_index_in_memory(index_rec.doc, index_rec.seq_id, default_sorting_field, search_schema, + embedding_fields, index_rec.operation, index_rec.is_update, fallback_field_type, @@ -451,6 +453,9 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite get_doc_changes(index_rec.operation, index_rec.doc, index_rec.old_doc, index_rec.new_doc, index_rec.del_doc); scrub_reindex_doc(search_schema, index_rec.doc, index_rec.del_doc, index_rec.old_doc); + embed_fields(index_rec.new_doc, embedding_fields, search_schema); + } else { + embed_fields(index_rec.doc, embedding_fields, search_schema); } compute_token_offsets_facets(index_rec, search_schema, token_separators, symbols_to_index); @@ -485,6 +490,7 @@ void Index::validate_and_preprocess(Index *index, std::vector& ite size_t Index::batch_memory_index(Index *index, std::vector& iter_batch, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const std::string& fallback_field_type, const std::vector& token_separators, const std::vector& symbols_to_index, @@ -518,7 +524,7 @@ size_t Index::batch_memory_index(Index *index, std::vector& iter_b index->thread_pool->enqueue([&, batch_index, batch_len]() { write_log_index = local_write_log_index; validate_and_preprocess(index, iter_batch, batch_index, batch_len, default_sorting_field, search_schema, - fallback_field_type, token_separators, symbols_to_index, do_validation); + embedding_fields, fallback_field_type, token_separators, symbols_to_index, do_validation); std::unique_lock lock(m_process); num_processed++; @@ -6257,6 +6263,61 @@ bool Index::common_results_exist(std::vector& leaves, bool must_match return phrase_exists; } +Option Index::embed_fields(nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema) { + for(const auto& field : embedding_fields) { + std::string text_to_embed; + for(const auto& field_name : field.embed_from) { + auto field_it = search_schema.find(field_name); + if(field_it.value().type == field_types::STRING) { + text_to_embed += document[field_name].get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : document[field_name]) { + text_to_embed += val.get() + " "; + } + } + } + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + document[field.name] = embedding; + } + + return Option(true); +} + + +Option Index::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema) { + nlohmann::json new_doc_copy = new_doc; + for(const auto& field : embedding_fields) { + std::string text_to_embed; + for(const auto& field_name : field.embed_from) { + auto field_it = search_schema.find(field_name); + nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; + if(field_it.value().type == field_types::STRING) { + text_to_embed += value.get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : value) { + text_to_embed += val.get() + " "; + } + } + } + + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + new_doc_copy[field.name] = embedding; + } + new_doc = new_doc_copy; + return Option(true); +} + + + + /* // https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon // NOTE: polygon and point should have been transformed with `transform_for_180th_meridian` @@ -6302,3 +6363,4 @@ void Index::transform_for_180th_meridian(GeoCoord &point, double offset) { point.lon = point.lon < 0.0 ? point.lon + offset : point.lon; } */ + diff --git a/src/validator.cpp b/src/validator.cpp index 9616b84b..f66eaf9d 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -529,6 +529,7 @@ Option validator_t::coerce_float(const DIRTY_VALUES& dirty_values, con Option validator_t::validate_index_in_memory(nlohmann::json& document, uint32_t seq_id, const std::string & default_sorting_field, const tsl::htrie_map & search_schema, + const tsl::htrie_map & embedding_fields, const index_operation_t op, const bool is_update, const std::string& fallback_field_type, @@ -544,6 +545,11 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, for(const auto& a_field: search_schema) { const std::string& field_name = a_field.name; + // ignore embedding fields, they will be validated later + if(embedding_fields.count(field_name) > 0) { + continue; + } + if(field_name == "id" || a_field.is_object()) { continue; } @@ -574,5 +580,50 @@ Option validator_t::validate_index_in_memory(nlohmann::json& document, } } + // validate embedding fields + auto validate_embed_op = validate_embed_fields(document, embedding_fields, search_schema, !is_update); + if(!validate_embed_op.ok()) { + return Option<>(validate_embed_op.code(), validate_embed_op.error()); + } + return Option<>(200); +} + + +Option validator_t::validate_embed_fields(const nlohmann::json& document, + const tsl::htrie_map& embedding_fields, + const tsl::htrie_map & search_schema, + const bool& error_if_field_not_found) { + if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); + } + for(const auto& field : embedding_fields) { + for(const auto& field_name : field.embed_from) { + auto schema_field_it = search_schema.find(field_name); + auto doc_field_it = document.find(field_name); + if(schema_field_it == search_schema.end()) { + return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); + } + if(doc_field_it == document.end()) { + if(error_if_field_not_found) { + return Option(400, "Field `" + field_name + "` is needed to create embedding."); + } else { + continue; + } + } + if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || + (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + if(doc_field_it.value().is_array()) { + for(const auto& val : doc_field_it.value()) { + if(!val.is_string()) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } + } + } + + return Option(true); } \ No newline at end of file diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 1a8af133..9269c858 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -1597,7 +1597,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { nlohmann::json field_json; field_json["name"] = "embedding"; field_json["type"] = "float[]"; - field_json["create_from"] = {"name"}; + field_json["embed_from"] = {"name"}; std::vector fields; std::string fallback_field_type; @@ -1607,7 +1607,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `create_from` can only be used with string or string array fields.", field_op.error()); + ASSERT_EQ("Property `embed_from` can only be used with string or string array fields.", field_op.error()); } TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { @@ -1615,7 +1615,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { nlohmann::json field_json; field_json["name"] = "embedding"; field_json["type"] = "float[]"; - field_json["create_from"] = {"name"}; + field_json["embed_from"] = {"name"}; std::vector fields; std::string fallback_field_type; @@ -1633,7 +1633,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) { nlohmann::json field_json; field_json["name"] = "embedding"; field_json["type"] = "float[]"; - field_json["create_from"] = "name"; + field_json["embed_from"] = "name"; std::vector fields; std::string fallback_field_type; @@ -1643,7 +1643,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `create_from` must be an array.", field_op.error()); + ASSERT_EQ("Property `embed_from` must be an array.", field_op.error()); } TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) { @@ -1660,7 +1660,7 @@ TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `model_name` can only be used with `create_from`.", field_op.error()); + ASSERT_EQ("Property `model_name` can only be used with `embed_from`.", field_op.error()); } @@ -1670,7 +1670,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) { TextEmbedderManager::download_default_model(); field embedding = field("embedding", field_types::FLOAT_ARRAY, false); - embedding.create_from.push_back("name"); + embedding.embed_from.push_back("name"); std::vector fields = {field("name", field_types::STRING, false), embedding}; auto obj_coll_op = collectionManager.create_collection("obj_coll", 1, fields, "", 0, field_types::AUTO); diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index ec6a88e7..82ccd65a 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1462,7 +1462,7 @@ TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { nlohmann::json update_schema = R"({ "fields": [ - {"name": "embedding", "type":"float[]", "create_from": ["names"]} + {"name": "embedding", "type":"float[]", "embed_from": ["names"]} ] })"_json; @@ -1478,7 +1478,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { "fields": [ {"name": "names", "type": "string[]"}, {"name": "category", "type":"string"}, - {"name": "embedding", "type":"float[]", "create_from": ["names","category"]} + {"name": "embedding", "type":"float[]", "embed_from": ["names","category"]} ] })"_json; @@ -1500,7 +1500,7 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { LOG(INFO) << "Dropping field"; auto embedding_fields = coll->get_embedding_fields(); - ASSERT_EQ(2, embedding_fields["embedding"].create_from.size()); + ASSERT_EQ(2, embedding_fields["embedding"].embed_from.size()); LOG(INFO) << "Before alter"; @@ -1510,8 +1510,8 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { LOG(INFO) << "After alter"; embedding_fields = coll->get_embedding_fields(); - ASSERT_EQ(1, embedding_fields["embedding"].create_from.size()); - ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]); + ASSERT_EQ(1, embedding_fields["embedding"].embed_from.size()); + ASSERT_EQ("category", embedding_fields["embedding"].embed_from[0]); } TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { @@ -1519,7 +1519,7 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -1535,8 +1535,8 @@ TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { auto embedding_field_it = embedding_fields_map.find("embedding"); ASSERT_TRUE(embedding_field_it != embedding_fields_map.end()); ASSERT_EQ("embedding", embedding_field_it.value().name); - ASSERT_EQ(1, embedding_field_it.value().create_from.size()); - ASSERT_EQ("name", embedding_field_it.value().create_from[0]); + ASSERT_EQ(1, embedding_field_it.value().embed_from.size()); + ASSERT_EQ("name", embedding_field_it.value().embed_from[0]); // drop the embedding field nlohmann::json schema_without_embedding = R"({ diff --git a/test/collection_test.cpp b/test/collection_test.cpp index b558c1d6..2a15e531 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4616,7 +4616,7 @@ TEST_F(CollectionTest, SemanticSearchTest) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4651,7 +4651,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4682,7 +4682,7 @@ TEST_F(CollectionTest, HybridSearch) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4695,6 +4695,7 @@ TEST_F(CollectionTest, HybridSearch) { nlohmann::json object; object["name"] = "apple"; auto add_op = coll->add(object.dump()); + LOG(INFO) << "add_op.error(): " << add_op.error(); ASSERT_TRUE(add_op.ok()); ASSERT_EQ("apple", add_op.get()["name"]); @@ -4710,40 +4711,40 @@ TEST_F(CollectionTest, HybridSearch) { ASSERT_EQ(384, search_res["hits"][0]["document"]["embedding"].size()); } -TEST_F(CollectionTest, EmbedFielsTest) { - nlohmann::json schema = R"({ - "name": "objects", - "fields": [ - {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} - ] - })"_json; +// TEST_F(CollectionTest, EmbedFielsTest) { +// nlohmann::json schema = R"({ +// "name": "objects", +// "fields": [ +// {"name": "name", "type": "string"}, +// {"name": "embedding", "type":"float[]", "embed_from": ["name"]} +// ] +// })"_json; - TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); - TextEmbedderManager::download_default_model(); +// TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); +// TextEmbedderManager::download_default_model(); - auto op = collectionManager.create_collection(schema); - ASSERT_TRUE(op.ok()); - Collection* coll = op.get(); +// auto op = collectionManager.create_collection(schema); +// ASSERT_TRUE(op.ok()); +// Collection* coll = op.get(); - nlohmann::json object = R"({ - "name": "apple" - })"_json; +// nlohmann::json object = R"({ +// "name": "apple" +// })"_json; - auto embed_op = coll->embed_fields(object); +// auto embed_op = coll->embed_fields(object); - ASSERT_TRUE(embed_op.ok()); +// ASSERT_TRUE(embed_op.ok()); - ASSERT_EQ("apple", object["name"]); - ASSERT_EQ(384, object["embedding"].get>().size()); -} +// ASSERT_EQ("apple", object["name"]); +// ASSERT_EQ(384, object["embedding"].get>().size()); +// } TEST_F(CollectionTest, HybridSearchRankFusionTest) { nlohmann::json schema = R"({ "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4817,7 +4818,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; @@ -4851,7 +4852,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) { "name": "objects", "fields": [ {"name": "names", "type": "string[]"}, - {"name": "embedding", "type":"float[]", "create_from": ["names"]} + {"name": "embedding", "type":"float[]", "embed_from": ["names"]} ] })"_json; @@ -4876,8 +4877,8 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) { "name": "objects", "fields": [ {"name": "names", "type": "string[]"}, - {"name": "category", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["names", "category"]} + {"name": "category", "type": "string", "optional": true}, + {"name": "embedding", "type":"float[]", "embed_from": ["names", "category"]} ] })"_json; @@ -4903,7 +4904,7 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) { "name": "objects", "fields": [ {"name": "name", "type": "string"}, - {"name": "embedding", "type":"float[]", "create_from": ["name"]} + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} ] })"_json; diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 7f8b212b..e1dc5d10 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -683,7 +683,7 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { "name": "coll1", "fields": [ {"name": "name", "type": "string"}, - {"name": "vec", "type": "float[]", "create_from": ["name"]} + {"name": "vec", "type": "float[]", "embed_from": ["name"]} ] })"_json; From 7d2d3f39109e7a8adc08452377b9f64eb770dfd3 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sat, 8 Apr 2023 00:31:44 +0300 Subject: [PATCH 6/8] Removed unused function --- include/index.h | 4 ---- src/index.cpp | 31 ------------------------------- 2 files changed, 35 deletions(-) diff --git a/include/index.h b/include/index.h index 76ca3696..19bfc902 100644 --- a/include/index.h +++ b/include/index.h @@ -538,10 +538,6 @@ private: static Option embed_fields(nlohmann::json& document, const tsl::htrie_map& embedding_fields, const tsl::htrie_map & search_schema); - - static Option embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc, - const tsl::htrie_map& embedding_fields, - const tsl::htrie_map & search_schema); public: // for limiting number of results on multiple candidates / query rewrites diff --git a/src/index.cpp b/src/index.cpp index 0b59c823..cf5a49d3 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -6287,37 +6287,6 @@ Option Index::embed_fields(nlohmann::json& document, return Option(true); } - -Option Index::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc, - const tsl::htrie_map& embedding_fields, - const tsl::htrie_map & search_schema) { - nlohmann::json new_doc_copy = new_doc; - for(const auto& field : embedding_fields) { - std::string text_to_embed; - for(const auto& field_name : field.embed_from) { - auto field_it = search_schema.find(field_name); - nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; - if(field_it.value().type == field_types::STRING) { - text_to_embed += value.get() + " "; - } else if(field_it.value().type == field_types::STRING_ARRAY) { - for(const auto& val : value) { - text_to_embed += val.get() + " "; - } - } - } - - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - new_doc_copy[field.name] = embedding; - } - new_doc = new_doc_copy; - return Option(true); -} - - - - /* // https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon // NOTE: polygon and point should have been transformed with `transform_for_180th_meridian` From a50fe99a7c47ec9104e6e5ae6688e6a96cc99d56 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Mon, 10 Apr 2023 06:23:16 +0300 Subject: [PATCH 7/8] Review Changes III --- include/collection.h | 2 +- include/field.h | 6 +-- src/collection.cpp | 40 ++++++++++---------- test/collection_all_fields_test.cpp | 30 ++++++++++++--- test/collection_schema_change_test.cpp | 18 ++++++--- test/collection_test.cpp | 52 +++++++++++++++++++++++++- test/collection_vector_search_test.cpp | 33 ++++++++++++++++ 7 files changed, 145 insertions(+), 36 deletions(-) diff --git a/include/collection.h b/include/collection.h index dce59d0b..8cc8b15a 100644 --- a/include/collection.h +++ b/include/collection.h @@ -162,7 +162,7 @@ private: void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); - void process_remove_field_for_embedding_fields(const field& the_field); + std::vector process_remove_field_for_embedding_fields(const field& the_field); void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, diff --git a/include/field.h b/include/field.h index febb2a1e..d4361db9 100644 --- a/include/field.h +++ b/include/field.h @@ -450,7 +450,7 @@ struct field { for(const auto& field : fields_json) { if(field[fields::name] == embed_from_field) { if(field[fields::type] != field_types::STRING && field[fields::type] != field_types::STRING_ARRAY) { - return Option(400, "Property `" + fields::embed_from + "` can only have string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields."); } flag = true; break; @@ -460,7 +460,7 @@ struct field { for(const auto& field : the_fields) { if(field.name == embed_from_field) { if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) { - return Option(400, "Property `" + fields::embed_from + "` can only have used with string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields."); } flag = true; break; @@ -468,7 +468,7 @@ struct field { } } if(!flag) { - return Option(400, "Property `" + fields::embed_from + "` can only be used with string or string array fields."); + return Option(400, "Property `" + fields::embed_from + "` can only refer to string or string array fields."); } } } diff --git a/src/collection.cpp b/src/collection.cpp index 676628d2..457e155d 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -3736,7 +3736,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields } LOG(INFO) << "Finished altering " << num_found_docs << " document(s)."; - + std::vector garbage_embedding_fields_vec; for(auto& del_field: del_fields) { search_schema.erase(del_field.name); auto new_end = std::remove_if(fields.begin(), fields.end(), [&del_field](const field& f) { @@ -3765,10 +3765,13 @@ Option Collection::batch_alter_data(const std::vector& alter_fields default_sorting_field = ""; } - process_remove_field_for_embedding_fields(del_field); + auto garbage_embedding_fields = process_remove_field_for_embedding_fields(del_field); + garbage_embedding_fields_vec.insert(garbage_embedding_fields_vec.end(), garbage_embedding_fields.begin(), + garbage_embedding_fields.end()); } index->refresh_schemas({}, del_fields); + index->refresh_schemas({}, garbage_embedding_fields_vec); auto persist_op = persist_collection_meta(); if(!persist_op.ok()) { @@ -4741,26 +4744,25 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha return populate_include_exclude_fields(include_fields, exclude_fields, include_fields_full, exclude_fields_full); } - -void Collection::process_remove_field_for_embedding_fields(const field& the_field) { - std::vector::iterator> empty_fields; - for(auto& embedding_field : embedding_fields) { - const auto& actual_field = std::find_if(fields.begin(), fields.end(), [&embedding_field] (field other_field) { - return other_field.name == embedding_field.name; - }); - actual_field->embed_from.erase(std::remove_if(actual_field->embed_from.begin(), actual_field->embed_from.end(), [&the_field](std::string field_name) { +// Removes the dropped field from embed_from of all embedding fields. +std::vector Collection::process_remove_field_for_embedding_fields(const field& the_field) { + std::vector garbage_fields; + for(auto& field : fields) { + if(field.embed_from.empty()) { + continue; + } + field.embed_from.erase(std::remove_if(field.embed_from.begin(), field.embed_from.end(), [&the_field](std::string field_name) { return the_field.name == field_name; })); - embedding_field = *actual_field; - // store to remove embedding field if it has no field names in 'embed_from' anymore. - if(embedding_field.embed_from.empty()) { - empty_fields.push_back(actual_field); + embedding_fields[field.name] = field; + + // mark this embedding field as "garbage" if it has no more embed_from fields + if(field.embed_from.empty()) { + embedding_fields.erase(field.name); + garbage_fields.push_back(field); } } - for(const auto& empty_field : empty_fields) { - search_schema.erase(empty_field->name); - embedding_fields.erase(empty_field->name); - fields.erase(empty_field); - } + // return garbage embedding fields + return garbage_fields; } \ No newline at end of file diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 9269c858..3acc794c 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -1592,7 +1592,7 @@ TEST_F(CollectionAllFieldsTest, FieldNameMatchingRegexpShouldNotBeIndexedInNonAu ASSERT_EQ(1, results["hits"].size()); } -TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { +TEST_F(CollectionAllFieldsTest, EmbedFromFieldJSONInvalidField) { TextEmbedderManager::model_dir = "/tmp/models"; nlohmann::json field_json; field_json["name"] = "embedding"; @@ -1607,10 +1607,10 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `embed_from` can only be used with string or string array fields.", field_op.error()); + ASSERT_EQ("Property `embed_from` can only refer to string or string array fields.", field_op.error()); } -TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { +TEST_F(CollectionAllFieldsTest, EmbedFromFieldNoModelDir) { TextEmbedderManager::model_dir = std::string(); nlohmann::json field_json; field_json["name"] = "embedding"; @@ -1628,7 +1628,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { ASSERT_EQ("Text embedding is not enabled. Please set `model-dir` at startup.", field_op.error()); } -TEST_F(CollectionAllFieldsTest, CreateFromNotArray) { +TEST_F(CollectionAllFieldsTest, EmbedFromNotArray) { TextEmbedderManager::model_dir = "/tmp/models"; nlohmann::json field_json; field_json["name"] = "embedding"; @@ -1646,7 +1646,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromNotArray) { ASSERT_EQ("Property `embed_from` must be an array.", field_op.error()); } -TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) { +TEST_F(CollectionAllFieldsTest, ModelPathWithoutEmbedFrom) { TextEmbedderManager::model_dir = "/tmp/models"; nlohmann::json field_json; field_json["name"] = "embedding"; @@ -1664,7 +1664,7 @@ TEST_F(CollectionAllFieldsTest, ModelPathWithoutCreateFrom) { } -TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) { +TEST_F(CollectionAllFieldsTest, EmbedFromBasicValid) { TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; TextEmbedderManager::download_default_model(); @@ -1690,3 +1690,21 @@ TEST_F(CollectionAllFieldsTest, CreateFromBasicValid) { } +TEST_F(CollectionAllFieldsTest, WrongDataTypeForEmbedFrom) { + TextEmbedderManager::model_dir = "/tmp/models"; + nlohmann::json field_json; + field_json["name"] = "embedding"; + field_json["type"] = "float[]"; + field_json["embed_from"] = {"age"}; + + std::vector fields; + std::string fallback_field_type; + auto arr = nlohmann::json::array(); + arr.push_back(field_json); + field_json["name"] = "age"; + field_json["type"] = "int32"; + arr.push_back(field_json); + auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); + ASSERT_FALSE(field_op.ok()); + ASSERT_EQ("Property `embed_from` can only refer to string or string array fields.", field_op.error()); +} \ No newline at end of file diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index 82ccd65a..1a301446 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1497,21 +1497,29 @@ TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { ] })"_json; - LOG(INFO) << "Dropping field"; auto embedding_fields = coll->get_embedding_fields(); ASSERT_EQ(2, embedding_fields["embedding"].embed_from.size()); - LOG(INFO) << "Before alter"; - auto alter_op = coll->alter(schema_changes); ASSERT_TRUE(alter_op.ok()); - LOG(INFO) << "After alter"; - embedding_fields = coll->get_embedding_fields(); ASSERT_EQ(1, embedding_fields["embedding"].embed_from.size()); ASSERT_EQ("category", embedding_fields["embedding"].embed_from[0]); + + schema_changes = R"({ + "fields": [ + {"name": "category", "drop": true} + ] + })"_json; + + alter_op = coll->alter(schema_changes); + ASSERT_TRUE(alter_op.ok()); + + embedding_fields = coll->get_embedding_fields(); + ASSERT_EQ(0, embedding_fields.size()); + ASSERT_EQ(0, coll->_get_index()->_get_vector_index().size()); } TEST_F(CollectionSchemaChangeTest, EmbeddingFieldsMapTest) { diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 2a15e531..2bd77929 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4666,7 +4666,6 @@ TEST_F(CollectionTest, InvalidSemanticSearch) { object["name"] = "apple"; auto add_op = coll->add(object.dump()); ASSERT_TRUE(add_op.ok()); - LOG(INFO) << "add_op.get(): " << add_op.get().dump(); ASSERT_EQ("apple", add_op.get()["name"]); ASSERT_EQ(384, add_op.get()["embedding"].size()); @@ -4899,6 +4898,55 @@ TEST_F(CollectionTest, MissingFieldForEmbedding) { ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error()); } + +TEST_F(CollectionTest, WrongTypeForEmbedding) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed_from": ["category"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json doc; + doc["category"] = 1; + + auto add_op = validator_t::validate_embed_fields(doc, coll->get_embedding_fields(), coll->get_schema(), true); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `category` has malformed data.", add_op.error()); +} + +TEST_F(CollectionTest, WrongTypeOfElementForEmbeddingInStringArray) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "category", "type": "string[]"}, + {"name": "embedding", "type":"float[]", "embed_from": ["category"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json doc; + doc["category"].push_back(33); + + auto add_op = validator_t::validate_embed_fields(doc, coll->get_embedding_fields(), coll->get_schema(), true); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `category` has malformed data.", add_op.error()); +} + TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) { nlohmann::json schema = R"({ "name": "objects", @@ -4944,4 +4992,4 @@ TEST_F(CollectionTest, UpdateEmbeddingsForUpdatedDocument) { // check if the embedding field is updated ASSERT_NE(embedding_field, updated_embedding_field); -} \ No newline at end of file +} diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index e1dc5d10..c78e004e 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -775,5 +775,38 @@ TEST_F(CollectionVectorTest, DistanceThresholdTest) { ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get>()[1]); ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get>()[2]); +} +TEST_F(CollectionVectorTest, EmbeddingFieldVectorIndexTest) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + auto& vec_index = coll->_get_index()->_get_vector_index(); + ASSERT_EQ(1, vec_index.size()); + ASSERT_EQ(1, vec_index.count("embedding")); + + + nlohmann::json schema_change = R"({ + "fields": [ + {"name": "embedding", "drop": true} + ] + })"_json; + + auto schema_change_op = coll->alter(schema_change); + + ASSERT_TRUE(schema_change_op.ok()); + ASSERT_EQ(0, vec_index.size()); + ASSERT_EQ(0, vec_index.count("embedding")); } \ No newline at end of file From cdbe63747b0d4ed9681b72042665b51acf6288ab Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Mon, 10 Apr 2023 07:32:48 +0300 Subject: [PATCH 8/8] Update for process_remove_field_for_embedding_fields --- include/collection.h | 2 +- src/collection.cpp | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/include/collection.h b/include/collection.h index 8cc8b15a..caaf897f 100644 --- a/include/collection.h +++ b/include/collection.h @@ -162,7 +162,7 @@ private: void remove_document(const nlohmann::json & document, const uint32_t seq_id, bool remove_from_store); - std::vector process_remove_field_for_embedding_fields(const field& the_field); + void process_remove_field_for_embedding_fields(const field& the_field, std::vector& garbage_fields); void curate_results(string& actual_query, const string& filter_query, bool enable_overrides, bool already_segmented, const std::map>& pinned_hits, diff --git a/src/collection.cpp b/src/collection.cpp index 457e155d..a4e1a2d7 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -3765,9 +3765,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields default_sorting_field = ""; } - auto garbage_embedding_fields = process_remove_field_for_embedding_fields(del_field); - garbage_embedding_fields_vec.insert(garbage_embedding_fields_vec.end(), garbage_embedding_fields.begin(), - garbage_embedding_fields.end()); + process_remove_field_for_embedding_fields(del_field, garbage_embedding_fields_vec); } index->refresh_schemas({}, del_fields); @@ -4745,8 +4743,7 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha } // Removes the dropped field from embed_from of all embedding fields. -std::vector Collection::process_remove_field_for_embedding_fields(const field& the_field) { - std::vector garbage_fields; +void Collection::process_remove_field_for_embedding_fields(const field& the_field, std::vector& garbage_fields) { for(auto& field : fields) { if(field.embed_from.empty()) { continue; @@ -4763,6 +4760,4 @@ std::vector Collection::process_remove_field_for_embedding_fields(const f } } - // return garbage embedding fields - return garbage_fields; } \ No newline at end of file