diff --git a/src/collection.cpp b/src/collection.cpp index 82811659..04625c8c 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1043,7 +1043,7 @@ Option Collection::extract_field_name(const std::string& field_name, for(auto kv = prefix_it.first; kv != prefix_it.second; ++kv) { bool exact_key_match = (kv.key().size() == field_name.size()); bool exact_primitive_match = exact_key_match && !kv.value().is_object(); - bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().embed.count(fields::from) != 0; + bool text_embedding = kv.value().type == field_types::FLOAT_ARRAY && kv.value().num_dim > 0; if(extract_only_string_fields && !kv.value().is_string() && !text_embedding) { if(exact_primitive_match && !is_wildcard) { @@ -1073,7 +1073,7 @@ Option Collection::extract_field_name(const std::string& field_name, return Option(true); } -Option Collection::search(std::string raw_query, +Option Collection::search(std::string raw_query, const std::vector& raw_search_fields, const std::string & filter_query, const std::vector& facet_fields, const std::vector & sort_fields, const std::vector& num_typos, @@ -1201,6 +1201,7 @@ Option Collection::search(std::string raw_query, std::vector processed_search_fields; std::vector query_by_weights; size_t num_embed_fields = 0; + std::string query = raw_query; for(size_t i = 0; i < raw_search_fields.size(); i++) { const std::string& field_name = raw_search_fields[i]; @@ -1289,6 +1290,11 @@ Option Collection::search(std::string raw_query, } } + // Set query to * if it is semantic search + if(!vector_query.field_name.empty() && processed_search_fields.empty()) { + query = "*"; + } + if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) { std::string error = "Vector query could not find any embedded fields."; return Option(400, error); @@ -1444,7 +1450,7 @@ Option Collection::search(std::string raw_query, size_t max_hits = DEFAULT_TOPSTER_SIZE; // ensure that `max_hits` never exceeds number of documents in collection - if(search_fields.size() <= 1 || raw_query == "*") { + if(search_fields.size() <= 1 || query == "*") { max_hits = std::min(std::max(fetch_size, max_hits), get_num_documents()); } else { max_hits = std::min(std::max(fetch_size, max_hits), get_num_documents()); @@ -1477,7 +1483,6 @@ Option Collection::search(std::string raw_query, StringUtils::split(hidden_hits_str, hidden_hits, ","); std::vector filter_overrides; - std::string query = raw_query; bool filter_curated_hits = false; std::string curated_sort_by; curate_results(query, filter_query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 3e5b43a7..fbc78e22 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -2582,4 +2582,6 @@ TEST_F(CollectionSpecificMoreTest, HybridSearchTextMatchInfo) { ASSERT_EQ(0, results["hits"][0]["text_match_info"]["tokens_matched"].get()); ASSERT_EQ(0, results["hits"][1]["text_match_info"]["tokens_matched"].get()); -} \ No newline at end of file +} + + diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 3ffe288e..2f55cd34 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -1102,3 +1102,115 @@ TEST_F(CollectionVectorTest, SkipEmbeddingOpWhenValueExists) { ASSERT_FALSE(add_op.ok()); ASSERT_EQ("Field `embedding` contains invalid float values.", add_op.error()); } + +TEST_F(CollectionVectorTest, SemanticSearchReturnOnlyVectorDistance) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + + auto add_op = coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + + ASSERT_TRUE(add_op.ok()); + + auto results = coll1->search("moisturizer", {"embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(1, results["hits"].size()); + + // Return only vector distance + ASSERT_EQ(0, results["hits"][0].count("text_match_info")); + ASSERT_EQ(0, results["hits"][0].count("hybrid_search_info")); + ASSERT_EQ(1, results["hits"][0].count("vector_distance")); +} + +TEST_F(CollectionVectorTest, KeywordSearchReturnOnlyTextMatchInfo) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + auto add_op = coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + auto results = coll1->search("moisturizer", {"product_name"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + + ASSERT_EQ(1, results["hits"].size()); + + // Return only text match info + ASSERT_EQ(0, results["hits"][0].count("vector_distance")); + ASSERT_EQ(0, results["hits"][0].count("hybrid_search_info")); + ASSERT_EQ(1, results["hits"][0].count("text_match_info")); +} + +TEST_F(CollectionVectorTest, HybridSearchReturnAllInfo) { + auto schema_json = + R"({ + "name": "Products", + "fields": [ + {"name": "product_name", "type": "string", "infix": true}, + {"name": "category", "type": "string"}, + {"name": "embedding", "type":"float[]", "embed":{"from": ["product_name", "category"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + + auto collection_create_op = collectionManager.create_collection(schema_json); + ASSERT_TRUE(collection_create_op.ok()); + auto coll1 = collection_create_op.get(); + + auto add_op = coll1->add(R"({ + "product_name": "moisturizer", + "category": "beauty" + })"_json.dump()); + ASSERT_TRUE(add_op.ok()); + + + auto results = coll1->search("moisturizer", {"product_name", "embedding"}, + "", {}, {}, {2}, 10, + 1, FREQUENCY, {true}, + 0, spp::sparse_hash_set()).get(); + + ASSERT_EQ(1, results["hits"].size()); + + // Return all info + ASSERT_EQ(1, results["hits"][0].count("vector_distance")); + ASSERT_EQ(1, results["hits"][0].count("text_match_info")); + ASSERT_EQ(1, results["hits"][0].count("hybrid_search_info")); +}