From e85ae5d7d2706caee47b2efa7b185d5289a7d20f Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sun, 2 Apr 2023 01:32:49 +0300 Subject: [PATCH] Hybrid & semantic search improvements --- include/collection.h | 4 + include/field.h | 12 +-- include/text_embedder_manager.h | 6 +- include/vector_query_ops.h | 2 + src/collection.cpp | 104 +++++++++++++++++------ src/index.cpp | 11 ++- src/typesense_server_utils.cpp | 2 +- src/vector_query_ops.cpp | 9 ++ test/collection_all_fields_test.cpp | 2 +- test/collection_test.cpp | 110 +++++++++++++++++++++++-- test/collection_vector_search_test.cpp | 59 +++++++++++-- 11 files changed, 269 insertions(+), 52 deletions(-) diff --git a/include/collection.h b/include/collection.h index 3b50c435..c81c4123 100644 --- a/include/collection.h +++ b/include/collection.h @@ -120,6 +120,8 @@ private: tsl::htrie_map nested_fields; + tsl::htrie_map embedding_fields; + bool enable_nested_fields; std::vector symbols_to_index; @@ -342,6 +344,8 @@ public: tsl::htrie_map get_nested_fields(); + tsl::htrie_map get_embedding_fields(); + std::string get_default_sorting_field(); Option to_doc(const std::string& json_str, nlohmann::json& document, diff --git a/include/field.h b/include/field.h index 792aaec6..5b44b52a 100644 --- a/include/field.h +++ b/include/field.h @@ -436,7 +436,7 @@ struct field { for(auto& create_from_field : field_json[fields::create_from]) { if(!create_from_field.is_string()) { - return Option(400, "Property `" + fields::create_from + "` must be an array of strings."); + return Option(400, "Property `" + fields::create_from + "` must contain only field names as strings."); } } @@ -449,8 +449,8 @@ struct field { bool flag = false; for(const auto& field : fields_json) { if(field[fields::name] == create_from_field) { - if(field[fields::type] != field_types::STRING) { - return Option(400, "Property `" + fields::create_from + "` can only be used with array of string fields."); + if(field[fields::type] != field_types::STRING && 
field[fields::type] != field_types::STRING_ARRAY) { + return Option(400, "Property `" + fields::create_from + "` can only have string or string array fields."); } flag = true; break; @@ -459,8 +459,8 @@ struct field { if(!flag) { for(const auto& field : the_fields) { if(field.name == create_from_field) { - if(field.type != field_types::STRING) { - return Option(400, "Property `" + fields::create_from + "` can only be used with array of string fields."); + if(field.type != field_types::STRING && field.type != field_types::STRING_ARRAY) { + return Option(400, "Property `" + fields::create_from + "` can only have used with string or string array fields."); } flag = true; break; @@ -468,7 +468,7 @@ struct field { } } if(!flag) { - return Option(400, "Property `" + fields::create_from + "` can only be used with array of string fields."); + return Option(400, "Property `" + fields::create_from + "` can only be used with string or string array fields."); } } } diff --git a/include/text_embedder_manager.h b/include/text_embedder_manager.h index 3b626a97..192044c8 100644 --- a/include/text_embedder_manager.h +++ b/include/text_embedder_manager.h @@ -1,6 +1,6 @@ #pragma once - +#include #include #include #include @@ -43,6 +43,10 @@ public: } static void set_model_dir(const std::string& dir) { + // create the directory if it doesn't exist + if(!std::filesystem::exists(dir)) { + std::filesystem::create_directories(dir); + } model_dir = dir; } diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index d3424e30..32ee7448 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -10,6 +10,7 @@ struct vector_query_t { std::string field_name; size_t k = 0; size_t flat_search_cutoff = 0; + float similarity_cutoff = 0.0; std::vector values; uint32_t seq_id = 0; @@ -19,6 +20,7 @@ struct vector_query_t { // used for testing only field_name.clear(); k = 0; + similarity_cutoff = 0.0; values.clear(); seq_id = 0; query_doc_given = false; diff --git 
a/src/collection.cpp b/src/collection.cpp index b5310d46..bea45361 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1120,10 +1120,6 @@ Option Collection::search(std::string raw_query, vector_query_t vector_query; if(!vector_query_str.empty()) { - if(raw_query != "*") { - return Option(400, "Vector query is supported only on wildcard (q=*) searches."); - } - auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this); if(!parse_vector_op.ok()) { return Option(400, parse_vector_op.error()); @@ -3459,6 +3455,11 @@ tsl::htrie_map Collection::get_nested_fields() { return nested_fields; }; +tsl::htrie_map Collection::get_embedding_fields() { + std::shared_lock lock(mutex); + return embedding_fields; +}; + std::string Collection::get_meta_key(const std::string & collection_name) { return std::string(COLLECTION_META_PREFIX) + "_" + collection_name; } @@ -3747,6 +3748,10 @@ Option Collection::batch_alter_data(const std::vector& alter_fields nested_fields.erase(del_field.name); } + if(del_field.create_from.size() > 0) { + embedding_fields.erase(del_field.name); + } + if(del_field.name == ".*") { fallback_field_type = ""; } @@ -3982,6 +3987,7 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, std::vector diff_fields; tsl::htrie_map updated_search_schema = search_schema; tsl::htrie_map updated_nested_fields = nested_fields; + tsl::htrie_map updated_embedding_fields = embedding_fields; size_t num_auto_detect_fields = 0; // since fields can be deleted and added in the same change set, @@ -4038,10 +4044,18 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, return Option(400, "Field `" + field_name + "` is not part of collection schema."); } + if(found_field && field_it.value().create_from.size() > 0) { + updated_embedding_fields.erase(field_it.key()); + } + if(found_field) { del_fields.push_back(field_it.value()); updated_search_schema.erase(field_it.key()); 
updated_nested_fields.erase(field_it.key()); + + if(field_it.value().create_from.size() > 0) { + updated_embedding_fields.erase(field_it.key()); + } // should also remove children if the field being dropped is an object if(field_it.value().nested && enable_nested_fields) { @@ -4052,6 +4066,10 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, del_fields.push_back(prefix_kv.value()); updated_search_schema.erase(prefix_kv.key()); updated_nested_fields.erase(prefix_kv.key()); + + if(prefix_kv.value().create_from.size() > 0) { + updated_embedding_fields.erase(prefix_kv.key()); + } } } } @@ -4102,6 +4120,10 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, addition_fields.push_back(f); } + if(f.create_from.size() > 0) { + return Option(400, "Embedding fields can only be added at the time of collection creation."); + } + if(f.nested && enable_nested_fields) { updated_nested_fields.emplace(f.name, f); @@ -4113,6 +4135,10 @@ Option Collection::validate_alter_payload(nlohmann::json& schema_changes, updated_search_schema.emplace(prefix_kv.key(), prefix_kv.value()); updated_nested_fields.emplace(prefix_kv.key(), prefix_kv.value()); + if(prefix_kv.value().create_from.size() > 0) { + return Option(400, "Embedding fields can only be added at the time of collection creation."); + } + if(is_reindex) { reindex_fields.push_back(prefix_kv.value()); } else { @@ -4432,6 +4458,10 @@ Index* Collection::init_index() { nested_fields.emplace(field.name, field); } + if(field.create_from.size() > 0) { + embedding_fields.emplace(field.name, field); + } + if(!field.reference.empty()) { auto dot_index = field.reference.find('.'); auto collection_name = field.reference.substr(0, dot_index); @@ -4704,30 +4734,50 @@ Option Collection::populate_include_exclude_fields_lk(const spp::sparse_ha Option Collection::embed_fields(nlohmann::json& document) { - for(const auto& field : fields) { - if(field.create_from.size() > 0) { - 
if(TextEmbedderManager::model_dir.empty()) { - return Option(400, "Text embedding is not enabled. Please set `model-dir` at startup."); - } - std::string text_to_embed; - for(const auto& field_name : field.create_from) { - auto field_it = document.find(field_name); - if(field_it != document.end()) { - if(field_it->is_string()) { - text_to_embed += field_it->get() + " "; - } else { - return Option(400, "Field `" + field_name + "` is not a string."); - } - } else { - return Option(400, "Field `" + field_name + "` not found in document."); - } - } - - TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); - auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); - std::vector embedding = embedder->Embed(text_to_embed); - document[field.name] = embedding; + for(const auto& field : embedding_fields) { + if(TextEmbedderManager::model_dir.empty()) { + return Option(400, "Text embedding is not enabled. 
Please set `model-dir` at startup."); } + std::string text_to_embed; + for(const auto& field_name : field.create_from) { + auto field_it = search_schema.find(field_name); + if(field_it != search_schema.end()) { + if(field_it.value().type == field_types::STRING) { + if(document.find(field_name) != document.end()) { + if(document[field_name].is_string()) { + text_to_embed += document[field_name].get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } else if(field_it.value().type == field_types::STRING_ARRAY) { + if(document.find(field_name) != document.end()) { + if(document[field_name].is_array()) { + for(const auto& val : document[field_name]) { + if(val.is_string()) { + text_to_embed += val.get() + " "; + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } else { + return Option(400, "Field `" + field_name + "` has malformed data."); + } + } + } + else { + return Option(400, "Field `" + field_name + "` is not a string nor string array. Can not create vector from it."); + } + } else { + return Option(400, "Field `" + field_name + "` is not a valid field."); + } + } + + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder = embedder_manager.get_text_embedder(field.model_name.size() > 0 ? field.model_name : TextEmbedderManager::DEFAULT_MODEL_NAME); + std::vector embedding = embedder->Embed(text_to_embed); + document[field.name] = embedding; } + return Option(true); } diff --git a/src/index.cpp b/src/index.cpp index 5ea1f9c3..4859a451 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2867,6 +2867,10 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? 
std::abs(dist_label.first) : dist_label.first; + + if(vector_query.similarity_cutoff > 0 && vec_dist_score > vector_query.similarity_cutoff) { + continue; + } int64_t scores[3] = {0}; scores[0] = -float_to_int64_t(vec_dist_score); @@ -3087,9 +3091,14 @@ Option Index::search(std::vector& field_query_tokens, cons auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) : dist_label.first; - + if(vector_query.similarity_cutoff > 0) { + if(vec_dist_score > vector_query.similarity_cutoff) { + continue; + } + } vec_results.emplace_back(seq_id, vec_dist_score); } + std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) { return a.second < b.second; }); diff --git a/src/typesense_server_utils.cpp b/src/typesense_server_utils.cpp index cd9b4952..ed381e8b 100644 --- a/src/typesense_server_utils.cpp +++ b/src/typesense_server_utils.cpp @@ -456,7 +456,7 @@ int run_server(const Config & config, const std::string & version, void (*master if(config.get_model_dir().size() > 0) { LOG(INFO) << "Loading text embedding models from " << config.get_model_dir(); - TextEmbedderManager::model_dir = config.get_model_dir(); + TextEmbedderManager::set_model_dir(config.get_model_dir()); TextEmbedderManager::download_default_model(); } // first we start the peering service diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index b4cd3ffa..729157b0 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -145,6 +145,15 @@ Option VectorQueryOps::parse_vector_query_str(std::string vector_query_str vector_query.flat_search_cutoff = std::stoi(param_kv[1]); } + + if(param_kv[0] == "similarity_cutoff") { + if(!StringUtils::is_float(param_kv[1])) { + return Option(400, "Malformed vector query string: " + "`similarity_cutoff` parameter must be a float."); + } + + vector_query.similarity_cutoff = std::stof(param_kv[1]); + } } if(!vector_query.query_doc_given && vector_query.values.empty()) { diff --git 
a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 9234d6bd..1a8af133 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -1607,7 +1607,7 @@ TEST_F(CollectionAllFieldsTest, CreateFromFieldJSONInvalidField) { auto field_op = field::json_fields_to_fields(false, arr, fallback_field_type, fields); ASSERT_FALSE(field_op.ok()); - ASSERT_EQ("Property `create_from` can only be used with array of string fields.", field_op.error()); + ASSERT_EQ("Property `create_from` can only be used with string or string array fields.", field_op.error()); } TEST_F(CollectionAllFieldsTest, CreateFromFieldNoModelDir) { diff --git a/test/collection_test.cpp b/test/collection_test.cpp index f9e842ab..7bc1a768 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -4815,7 +4815,39 @@ TEST_F(CollectionTest, HybridSearchRankFusionTest) { } TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { - nlohmann::json schema = R"({ + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "embedding", "type":"float[]", "create_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + spp::sparse_hash_set dummy_include_exclude; + auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); + + ASSERT_FALSE(search_res_op.ok()); + ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error()); +} + +TEST_F(CollectionTest, CreateModelDirIfNotExists) { + system("mkdir -p /tmp/typesense_test/models"); + system("rm -rf /tmp/typesense_test/models"); + 
TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + // check if model dir is created + ASSERT_TRUE(std::filesystem::exists("/tmp/typesense_test/models")); +} + +TEST_F(CollectionTest, EmbeddingFieldsMapTest) { + nlohmann::json schema = R"({ "name": "objects", "fields": [ {"name": "name", "type": "string"}, @@ -4830,9 +4862,75 @@ TEST_F(CollectionTest, WildcardSearchWithEmbeddingField) { ASSERT_TRUE(op.ok()); Collection* coll = op.get(); - spp::sparse_hash_set dummy_include_exclude; - auto search_res_op = coll->search("*", {"name","embedding"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, dummy_include_exclude, dummy_include_exclude, 10, "", 30, 4, ""); + auto embedding_fields_map = coll->get_embedding_fields(); + ASSERT_EQ(1, embedding_fields_map.size()); + auto embedding_field_it = embedding_fields_map.find("embedding"); + ASSERT_TRUE(embedding_field_it != embedding_fields_map.end()); + ASSERT_EQ("embedding", embedding_field_it.value().name); + ASSERT_EQ(1, embedding_field_it.value().create_from.size()); + ASSERT_EQ("name", embedding_field_it.value().create_from[0]); + + // drop the embedding field + nlohmann::json schema_without_embedding = R"({ + "fields": [ + {"name": "embedding", "drop": true} + ] + })"_json; + auto update_op = coll->alter(schema_without_embedding); + + ASSERT_TRUE(update_op.ok()); + + embedding_fields_map = coll->get_embedding_fields(); + ASSERT_EQ(0, embedding_fields_map.size()); +} + +TEST_F(CollectionTest, EmbedStringArrayField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"}, + {"name": "embedding", "type":"float[]", "create_from": ["names"]} + ] + })"_json; + + TextEmbedderManager::model_dir = "/tmp/typesense_test/models"; + TextEmbedderManager::download_default_model(); + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json doc; + 
doc["names"].push_back("butter"); + doc["names"].push_back("butterfly"); + doc["names"].push_back("butterball"); + + auto add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); +} + +TEST_F(CollectionTest, UpdateSchemaWithNewEmbeddingField) { + nlohmann::json schema = R"({ + "name": "objects", + "fields": [ + {"name": "names", "type": "string[]"} + ] + })"_json; + + + auto op = collectionManager.create_collection(schema); + ASSERT_TRUE(op.ok()); + Collection* coll = op.get(); + + nlohmann::json update_schema = R"({ + "fields": [ + {"name": "embedding", "type":"float[]", "create_from": ["names"]} + ] + })"_json; + + auto res = coll->alter(update_schema); + + ASSERT_FALSE(res.ok()); + ASSERT_EQ("Embedding fields can only be added at the time of collection creation.", res.error()); +} - ASSERT_FALSE(search_res_op.ok()); - ASSERT_EQ("Wildcard query is not supported for embedding fields.", search_res_op.error()); -} \ No newline at end of file diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 36381780..ec233436 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -184,17 +184,18 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) { ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Document id referenced in vector query is not found.", res_op.error()); + // DEPRECATED: vector query is also supported on non-wildcard queries with hybrid search // only supported with wildcard queries - res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, - spp::sparse_hash_set(), - spp::sparse_hash_set(), 10, "", 30, 5, - "", 10, {}, {}, {}, 0, - "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, - 4, {off}, 32767, 32767, 2, - false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])"); + // res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + // 
spp::sparse_hash_set(), + // spp::sparse_hash_set(), 10, "", 30, 5, + // "", 10, {}, {}, {}, 0, + // "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + // 4, {off}, 32767, 32767, 2, + // false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])"); - ASSERT_FALSE(res_op.ok()); - ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error()); + // ASSERT_FALSE(res_op.ok()); + // ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error()); // support num_dim on only float array fields schema = R"({ @@ -676,3 +677,43 @@ TEST_F(CollectionVectorTest, VectorWithNullValue) { ASSERT_EQ("Field `vec` must be an array.", nlohmann::json::parse(json_lines[1])["error"].get()); } + +TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "vec", "type": "float[]", "create_from": ["name"]} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + TextEmbedderManager::download_default_model(); + + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + + doc["name"] = "john doe"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + std::string dummy_vec_string = "[0.9"; + for (int i = 0; i < 382; i++) { + dummy_vec_string += ", 0.9"; + } + dummy_vec_string += ", 0.9]"; + + auto results_op = coll1->search("john", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:(" + dummy_vec_string +")"); + ASSERT_EQ(true, results_op.ok()); + + + ASSERT_EQ(1, results_op.get()["found"].get()); + ASSERT_EQ(1, results_op.get()["hits"].size()); +} \ No newline at end of file