diff --git a/include/topster.h b/include/topster.h index d5889b5b..b0b8f125 100644 --- a/include/topster.h +++ b/include/topster.h @@ -16,7 +16,7 @@ struct KV { int64_t scores[3]{}; // match score + 2 custom attributes // only to be used in hybrid search - float vector_distance = 0.0f; + float vector_distance = 2.0f; int64_t text_match_score = 0; reference_filter_result_t* reference_filter_result = nullptr; diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index 5ffd5c0b..29bb36f6 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -29,6 +29,7 @@ struct vector_query_t { class VectorQueryOps { public: - static Option parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query, + static Option parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query, + const bool is_wildcard_query, const Collection* coll); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 7d6f8506..7778bf0c 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1168,7 +1168,10 @@ Option Collection::search(std::string raw_query, vector_query_t vector_query; if(!vector_query_str.empty()) { - auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this); + bool is_wildcard_query = (raw_query == "*" || raw_query.empty()); + + auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, + is_wildcard_query, this); if(!parse_vector_op.ok()) { return Option(400, parse_vector_op.error()); } @@ -1178,18 +1181,17 @@ Option Collection::search(std::string raw_query, return Option(400, "Field `" + vector_query.field_name + "` does not have a vector query index."); } - if(vector_field_it.value().num_dim != vector_query.values.size()) { + if(is_wildcard_query && vector_field_it.value().num_dim != vector_query.values.size()) { return Option(400, "Query field `" + vector_query.field_name + "` must have " + std::to_string(vector_field_it.value().num_dim) + " dimensions."); } } - - // validate search fields std::vector processed_search_fields; std::vector query_by_weights; - bool has_embedding_query = false; + size_t num_embed_fields = 0; + for(size_t i = 0; i < raw_search_fields.size(); i++) { const std::string& field_name = raw_search_fields[i]; if(field_name == "id") { @@ -1208,7 +1210,10 @@ Option Collection::search(std::string raw_query, auto search_field = search_schema.at(expanded_search_field); if(search_field.num_dim > 0) { - if(!vector_query.field_name.empty()) { + num_embed_fields++; + + if(num_embed_fields > 1 || + (!vector_query.field_name.empty() && search_field.name != vector_query.field_name)) { std::string error = "Only one embedding field is allowed in the query."; return Option(400, error); } @@ -1253,10 +1258,13 @@ Option Collection::search(std::string raw_query, } } std::vector embedding = embedding_op.embedding; + // distance could have been set for an embed field, so we take a backup and restore + auto dist = vector_query.distance_threshold; vector_query._reset(); vector_query.values = embedding; vector_query.field_name = field_name; vector_query.k = vector_query_hits; + vector_query.distance_threshold = dist; continue; } @@ -1267,6 +1275,11 @@ Option Collection::search(std::string raw_query, } } + if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) { + std::string error = "Vector query could not find any embedded fields."; + return Option(400, error); + } + std::string real_raw_query = raw_query; if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) { raw_query = "*"; @@ -1962,7 +1975,7 @@ Option Collection::search(std::string raw_query, wrapper_doc["geo_distance_meters"] = geo_distances; } - if(!vector_query.field_name.empty() && query == "*") { + if(!vector_query.field_name.empty()) { wrapper_doc["vector_distance"] = field_order_kv->vector_distance; } diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index 54f65d5c..66ec3a13 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -2,8 +2,10 @@ #include "string_utils.h" #include "collection.h" -Option VectorQueryOps::parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query, - const Collection* coll) { +Option VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str, + vector_query_t& vector_query, + const bool is_wildcard_query, + const Collection* coll) { // FORMAT: // field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10) size_t i = 0; @@ -156,7 +158,7 @@ Option VectorQueryOps::parse_vector_query_str(std::string vector_query_str } } - if(!vector_query.query_doc_given && vector_query.values.empty()) { + if(is_wildcard_query && !vector_query.query_doc_given && vector_query.values.empty()) { return Option(400, "When a vector query value is empty, an `id` parameter must be present."); } diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 797b610b..e12b58ad 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -743,8 +743,64 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get()); ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get()); ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get()); -} + // hybrid search with empty vector (to pass distance threshold param) + std::string vec_query = "embedding:([], distance_threshold: 0.20)"; + + search_res_op = coll->search("butter", {"embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, vec_query); + ASSERT_TRUE(search_res_op.ok()); + search_res = search_res_op.get(); + + ASSERT_EQ(2, search_res["found"].get()); + ASSERT_EQ(2, search_res["hits"].size()); + + ASSERT_FLOAT_EQ(0.0462081432, search_res["hits"][0]["vector_distance"].get()); + ASSERT_FLOAT_EQ(0.1213316321, search_res["hits"][1]["vector_distance"].get()); + + // when no embedding field is passed, it should not be allowed + search_res_op = coll->search("butter", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, vec_query); + ASSERT_FALSE(search_res_op.ok()); + ASSERT_EQ("Vector query could not find any embedded fields.", search_res_op.error()); + + // when no vector matches distance threshold, only text matches are entertained and distance score should be + // 2 in those cases + vec_query = "embedding:([], distance_threshold: 0.01)"; + search_res_op = coll->search("butter", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, vec_query); + ASSERT_TRUE(search_res_op.ok()); + search_res = search_res_op.get(); + + ASSERT_EQ(3, search_res["found"].get()); + ASSERT_EQ(3, search_res["hits"].size()); + + ASSERT_FLOAT_EQ(2.0f, search_res["hits"][0]["vector_distance"].get()); + ASSERT_FLOAT_EQ(2.0f, search_res["hits"][1]["vector_distance"].get()); + ASSERT_FLOAT_EQ(2.0f, search_res["hits"][2]["vector_distance"].get()); + + ASSERT_FLOAT_EQ(2.0f, search_res["hits"][0]["hybrid_search_info"]["vector_distance"].get()); + ASSERT_FLOAT_EQ(2.0f, search_res["hits"][1]["hybrid_search_info"]["vector_distance"].get()); + ASSERT_FLOAT_EQ(2.0f, search_res["hits"][2]["hybrid_search_info"]["vector_distance"].get()); +} TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) { nlohmann::json schema = R"({ @@ -837,7 +893,7 @@ TEST_F(CollectionVectorTest, DistanceThresholdTest) { } -TEST_F(CollectionVectorTest, EmbeddingFieldVectorIndexTest) { +TEST_F(CollectionVectorTest, EmbeddingFieldAlterDropTest) { nlohmann::json schema = R"({ "name": "objects", "fields": [ diff --git a/test/vector_query_ops_test.cpp b/test/vector_query_ops_test.cpp index 96661fce..0a8b6b49 100644 --- a/test/vector_query_ops_test.cpp +++ b/test/vector_query_ops_test.cpp @@ -17,7 +17,7 @@ protected: TEST_F(VectorQueryOpsTest, ParseVectorQueryString) { vector_query_t vector_query; - auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr); + auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr); ASSERT_TRUE(parsed.ok()); ASSERT_EQ("vec", vector_query.field_name); ASSERT_EQ(10, vector_query.k); @@ -28,46 +28,50 @@ TEST_F(VectorQueryOpsTest, ParseVectorQueryString) { } vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr); ASSERT_TRUE(parsed.ok()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, false, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error()); // cannot pass both vector and id vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, false, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string: cannot pass both vector query and `id` parameter.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, false, nullptr); + ASSERT_TRUE(parsed.ok()); + + vector_query._reset(); + parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, true, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, false, nullptr); ASSERT_TRUE(parsed.ok()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, false, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, false, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, false, nullptr); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); }