From 1ad7bcdce3a249da0c17ae1f234a734bd2ca0c5b Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Wed, 5 Apr 2023 18:05:45 +0300 Subject: [PATCH] Review Changes --- include/collection.h | 3 + include/vector_query_ops.h | 4 +- src/collection.cpp | 140 ++++++++++++------------- src/index.cpp | 8 +- src/vector_query_ops.cpp | 8 +- test/collection_schema_change_test.cpp | 67 ++++++++++++ test/collection_test.cpp | 53 +++++----- test/collection_vector_search_test.cpp | 60 +++++++++++ 8 files changed, 236 insertions(+), 107 deletions(-) diff --git a/include/collection.h b/include/collection.h index 230ea65f..e3440006 100644 --- a/include/collection.h +++ b/include/collection.h @@ -209,6 +209,9 @@ private: std::vector& sort_fields_std, bool is_wildcard_query, bool is_group_by_query = false) const; + + Option validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const; + Option persist_collection_meta(); Option batch_alter_data(const std::vector& alter_fields, diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index 32ee7448..5ffd5c0b 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -10,7 +10,7 @@ struct vector_query_t { std::string field_name; size_t k = 0; size_t flat_search_cutoff = 0; - float similarity_cutoff = 0.0; + float distance_threshold = 2.01; std::vector values; uint32_t seq_id = 0; @@ -20,7 +20,7 @@ struct vector_query_t { // used for testing only field_name.clear(); k = 0; - similarity_cutoff = 0.0; + distance_threshold = 2.01; values.clear(); seq_id = 0; query_doc_given = false; diff --git a/src/collection.cpp b/src/collection.cpp index ae3b3d80..231a3f9a 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -51,6 +51,12 @@ Collection::Collection(const std::string& name, const uint32_t collection_id, co symbols_to_index(to_char_array(symbols_to_index)), token_separators(to_char_array(token_separators)), index(init_index()) { + for (auto const& field: fields) { + if (!field.create_from.empty()) { + embedding_fields.emplace(field.name, field); + } + } + this->num_documents = 0; } @@ -253,7 +259,7 @@ nlohmann::json Collection::get_summary_json() const { field_json[fields::infix] = coll_field.infix; field_json[fields::locale] = coll_field.locale; - if(coll_field.create_from.size() > 0) { + if(!coll_field.create_from.empty()) { field_json[fields::create_from] = coll_field.create_from; } @@ -1008,7 +1014,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().create_from.size() > 0; + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && !kv.value().create_from.empty(); if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -3765,7 +3771,7 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_fields.erase(del_field.name); } - if(del_field.create_from.size() > 0) { + if(!del_field.create_from.empty()) { embedding_fields.erase(del_field.name); } @@ -4063,7 +4069,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return Option(400, "Field `" + field_name + "` is not part of collection schema."); } - if(found_field && field_it.value().create_from.size() > 0) { + if(found_field && !field_it.value().create_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4072,7 +4078,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(field_it.key()); updated_nested_fields.erase(field_it.key()); - if(field_it.value().create_from.size() > 0) { + if(!field_it.value().create_from.empty()) { updated_embedding_fields.erase(field_it.key()); } @@ -4086,7 +4092,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.erase(prefix_kv.key()); updated_nested_fields.erase(prefix_kv.key()); - if(prefix_kv.value().create_from.size() > 0) { + if(!prefix_kv.value().create_from.empty()) { updated_embedding_fields.erase(prefix_kv.key()); } } @@ -4139,7 +4145,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, addition_fields.push_back(f); } - if(f.create_from.size() > 0) { + if(!f.create_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4154,7 +4160,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value()); updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); - if(prefix_kv.value().create_from.size() > 0) { + if(!prefix_kv.value().create_from.empty()) { return Option(400, "Embedding fields can only be added at the time of collection creation."); } @@ -4478,7 +4484,7 @@ Index* Collection::init_index() { nested_fields.emplace(field.name, field); } - if(field.create_from.size() > 0) { + if(!field.create_from.empty()) { embedding_fields.emplace(field.name, field); } @@ -4754,45 +4760,22 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha Option Collection::embed_fields(nlohmann::json& document) { + auto validate_res = validate_embed_fields(document, true); + if(!validate_res.ok()) { + return validate_res; + } for(const auto& field : embedding_fields) { - if(TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } std::string text_to_embed; for(const auto& field_name : field.create_from) { auto field_it = search_schema.find(field_name); - if(field_it != search_schema.end()) { - if(field_it.value().type == field_types::STRING) { - if(document.find(field_name) != document.end()) { - if(document[field_name].is_string()) { - text_to_embed += document[field_name].get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } else if(field_it.value().type == field_types::STRING_ARRAY) { - if(document.find(field_name) != document.end()) { - if(document[field_name].is_array()) { - for(const auto& val : document[field_name]) { - if(val.is_string()) { - text_to_embed += val.get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } + if(field_it.value().type == field_types::STRING) { + text_to_embed += document[field_name].get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : document[field_name]) { + text_to_embed += val.get() + " "; } - else { - return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); - } - } else { - return Option(400, "Field `" + field_name + "` is not a valid field."); } } - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); std::vector embedding = embedder->Embed(text_to_embed); @@ -4802,41 +4785,58 @@ Option Collection::embed_fields(nlohmann::json& document) { return Option(true); } +Option Collection::validate_embed_fields(const nlohmann::json& document, const bool& error_if_field_not_found) const { + if(!embedding_fields.empty() && TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); + } + for(const auto& field : embedding_fields) { + for(const auto& field_name : field.create_from) { + auto schema_field_it = search_schema.find(field_name); + auto doc_field_it = document.find(field_name); + if(schema_field_it == search_schema.end()) { + return Option(400, "Field `" + field.name + "` has invalid fields to create embeddings from."); + } + if(doc_field_it == document.end()) { + if(error_if_field_not_found) { + return Option(400, "Field `" + field_name + "` is needed to create embedding."); + } else { + continue; + } + } + if((schema_field_it.value().type == field_types::STRING && !doc_field_it.value().is_string()) || + (schema_field_it.value().type == field_types::STRING_ARRAY && !doc_field_it.value().is_array())) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + if(doc_field_it.value().is_array()) { + for(const auto& val : doc_field_it.value()) { + if(!val.is_string()) { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } + } + } + + return Option(true); +} + Option Collection::embed_fields_update(const nlohmann::json& old_doc, nlohmann::json& new_doc) { + auto validate_res = validate_embed_fields(new_doc, false); + if(!validate_res.ok()) { + return validate_res; + } nlohmann::json new_doc_copy = new_doc; for(const auto& field : embedding_fields) { - if(TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } std::string text_to_embed; for(const auto& field_name : field.create_from) { auto field_it = search_schema.find(field_name); - if(field_it != search_schema.end()) { - nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; - if(field_it.value().type == field_types::STRING) { - if(value.is_string()) { - text_to_embed += value.get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } else if(field_it.value().type == field_types::STRING_ARRAY) { - if(value.is_array()) { - for(const auto& val : value) { - if(val.is_string()) { - text_to_embed += val.get() + " "; - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } - } - } else { - return Option(400, "Field `" + field_name + "` has malformed data."); - } + nlohmann::json value = (new_doc.find(field_name) != new_doc.end()) ? new_doc[field_name] : old_doc[field_name]; + if(field_it.value().type == field_types::STRING) { + text_to_embed += value.get() + " "; + } else if(field_it.value().type == field_types::STRING_ARRAY) { + for(const auto& val : value) { + text_to_embed += val.get() + " "; } - else { - return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); - } - } else { - return Option(400, "Field `" + field_name + "` is not a valid field."); } } @@ -4860,7 +4860,7 @@ void Collection::process_remove_field_for_embedding_fields(const field& the_fiel })); embedding_field = *actual_field; // store to remove embedding field if it has no field names in 'create_from' anymore. - if(embedding_field.create_from.size() == 0) { + if(embedding_field.create_from.empty()) { empty_fields.push_back(actual_field); } } diff --git a/src/index.cpp b/src/index.cpp index 4e088576..7a110436 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2882,7 +2882,7 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - if(vector_query.similarity_cutoff > 0 && vec_dist_score > vector_query.similarity_cutoff) { + if(vec_dist_score > vector_query.distance_threshold) { continue; } @@ -3105,10 +3105,8 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - if(vector_query.similarity_cutoff > 0) { - if(vec_dist_score > vector_query.similarity_cutoff) { - continue; - } + if(vec_dist_score > vector_query.distance_threshold) { + continue; } vec_results.emplace_back(seq_id, vec_dist_score); } diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index 729157b0..54f65d5c 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -146,13 +146,13 @@ Option VectorQueryOps::parse_vector_query_str(std::string vector_query_str vector_query.flat_search_cutoff = std::stoi(param_kv[1]); } - if(param_kv[0] == "similarity_cutoff") { - if(!StringUtils::is_float(param_kv[1])) { + if(param_kv[0] == "distance_threshold") { + if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 2.0) { return Option(400, "Malformed vector query string: " - "`similarity_cutoff` parameter must be a float."); + "`distance_threshold` parameter must be a float between 0.0-2.0."); } - vector_query.similarity_cutoff = std::stof(param_kv[1]); + vector_query.distance_threshold = std::stof(param_kv[1]); } } diff --git a/test/collection_schema_change_test.cpp b/test/collection_schema_change_test.cpp index c32ba142..dd1301d9 100644 --- a/test/collection_schema_change_test.cpp +++ b/test/collection_schema_change_test.cpp @@ -1446,3 +1446,70 @@ TEST_F(CollectionSchemaChangeTest, GeoFieldSchemaAddition) { ASSERT_TRUE(res_op.ok()); ASSERT_EQ(2, res_op.get()["found"].get()); } + +TEST_F(CollectionSchemaChangeTest, UpdateSchemaWithNewEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"} + ] + })"_json; + + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json update_schema = R"({ + "fields": [ + {"name": "embedding", "type":"float[]", "create_from": ["names"]} + ] + })"_json; + + auto res = coll->alter(update_schema); + + ASSERT_FALSE(res.ok()); + ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); +} + +TEST_F(CollectionSchemaChangeTest, DropFieldUsedForEmbedding) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"}, + {"name": "category", "type":"string"}, + {"name": "embedding", "type":"float[]", "create_from": ["names","category"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + LOG(INFO) << "Created collection"; + + auto schema_changes = R"({ + "fields": [ + {"name": "names", "drop": true} + ] + })"_json; + + LOG(INFO) << "Dropping field"; + + auto embedding_fields = coll->get_embedding_fields(); + ASSERT_EQ(2, embedding_fields["embedding"].create_from.size()); + + LOG(INFO) << "Before alter"; + + auto alter_op = coll->alter(schema_changes); + ASSERT_TRUE(alter_op.ok()); + + LOG(INFO) << "After alter"; + + embedding_fields = coll->get_embedding_fields(); + ASSERT_EQ(1, embedding_fields["embedding"].create_from.size()); + ASSERT_EQ("category", embedding_fields["embedding"].create_from[0]); +} \ No newline at end of file diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 4e0b43a3..b19201f7 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4620,7 +4620,7 @@ TEST_F(CollectionTest, SemanticSearchTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4655,7 +4655,7 @@ TEST_F(CollectionTest, InvalidSemanticSearch) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4686,7 +4686,7 @@ TEST_F(CollectionTest, HybridSearch) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4719,7 +4719,7 @@ TEST_F(CollectionTest, EmbedFielsTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4747,7 +4747,7 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4821,7 +4821,7 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4853,7 +4853,7 @@ TEST_F(CollectionTest, EmbeddingFieldsMapTest) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4891,7 +4891,7 @@ TEST_F(CollectionTest, EmbedStringArrayField) { ] })"_json; - TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); TextEmbedderManager::download_default_model(); auto op = collectionManager.create_collection(schema); @@ -4907,28 +4907,29 @@ TEST_F(CollectionTest, EmbedStringArrayField) { ASSERT_TRUE(add_op.ok()); } -TEST_F(CollectionTest, UpdateSchemaWithNewEmbeddingField) { +TEST_F(CollectionTest, MissingFieldForEmbedding) { nlohmann::json schema = R"({ - "name": "objects", - "fields": [ - {"name": "names", "type": "string[]"} - ] - })"_json; - + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "create_from": ["names", "category"]} + ] + })"_json; + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + auto op = collectionManager.create_collection(schema); ASSERT_TRUE(op.ok()); Collection* coll = op.get(); - nlohmann::json update_schema = R"({ - "fields": [ - {"name": "embedding", "type":"float[]", "create_from": ["names"]} - ] - })"_json; - - auto res = coll->alter(update_schema); - - ASSERT_FALSE(res.ok()); - ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); -} + nlohmann::json doc; + doc["names"].push_back("butter"); + doc["names"].push_back("butterfly"); + doc["names"].push_back("butterball"); + auto add_op = coll->add(doc.dump()); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `category` is needed to create embedding.", add_op.error()); +} \ No newline at end of file diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index ec233436..7f8b212b 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -716,4 +716,64 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { ASSERT_EQ(1, results_op.get()["found"].get()); ASSERT_EQ(1, results_op.get()["hits"].size()); +} + +TEST_F(CollectionVectorTest, DistanceThresholdTest) { + nlohmann::json schema = R"({ + "name": "test", + "fields": [ + {"name": "vec", "type": "float[]", "num_dim": 3} + ] + })"_json; + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + doc["vec"] = {0.1, 0.2, 0.3}; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + // write a vector which is 0.5 away from the first vector + doc["vec"] = {0.6, 0.7, 0.8}; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + + auto results_op = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:([0.3,0.4,0.5])"); + + ASSERT_EQ(true, results_op.ok()); + ASSERT_EQ(2, results_op.get()["found"].get()); + ASSERT_EQ(2, results_op.get()["hits"].size()); + + ASSERT_FLOAT_EQ(0.6, results_op.get()["hits"][0]["document"]["vec"].get>()[0]); + ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get>()[1]); + ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get>()[2]); + + ASSERT_FLOAT_EQ(0.1, results_op.get()["hits"][1]["document"]["vec"].get>()[0]); + ASSERT_FLOAT_EQ(0.2, results_op.get()["hits"][1]["document"]["vec"].get>()[1]); + ASSERT_FLOAT_EQ(0.3, results_op.get()["hits"][1]["document"]["vec"].get>()[2]); + + results_op = coll1->search("*", {}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:([0.3,0.4,0.5], distance_threshold:0.01)"); + + ASSERT_EQ(true, results_op.ok()); + ASSERT_EQ(1, results_op.get()["found"].get()); + ASSERT_EQ(1, results_op.get()["hits"].size()); + + ASSERT_FLOAT_EQ(0.6, results_op.get()["hits"][0]["document"]["vec"].get>()[0]); + ASSERT_FLOAT_EQ(0.7, results_op.get()["hits"][0]["document"]["vec"].get>()[1]); + ASSERT_FLOAT_EQ(0.8, results_op.get()["hits"][0]["document"]["vec"].get>()[2]); + + } \ No newline at end of file