mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 13:12:22 +08:00
Allow params for hybrid vector search to be sent via vector_query.
This commit is contained in:
parent
2b204112ef
commit
c2db7436a2
@ -16,7 +16,7 @@ struct KV {
|
||||
int64_t scores[3]{}; // match score + 2 custom attributes
|
||||
|
||||
// only to be used in hybrid search
|
||||
float vector_distance = 0.0f;
|
||||
float vector_distance = 2.0f;
|
||||
int64_t text_match_score = 0;
|
||||
|
||||
reference_filter_result_t* reference_filter_result = nullptr;
|
||||
|
@ -29,6 +29,7 @@ struct vector_query_t {
|
||||
|
||||
class VectorQueryOps {
|
||||
public:
|
||||
static Option<bool> parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query,
|
||||
static Option<bool> parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query,
|
||||
const bool is_wildcard_query,
|
||||
const Collection* coll);
|
||||
};
|
@ -1168,7 +1168,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
|
||||
vector_query_t vector_query;
|
||||
if(!vector_query_str.empty()) {
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this);
|
||||
bool is_wildcard_query = (raw_query == "*" || raw_query.empty());
|
||||
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query,
|
||||
is_wildcard_query, this);
|
||||
if(!parse_vector_op.ok()) {
|
||||
return Option<nlohmann::json>(400, parse_vector_op.error());
|
||||
}
|
||||
@ -1178,18 +1181,17 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
return Option<nlohmann::json>(400, "Field `" + vector_query.field_name + "` does not have a vector query index.");
|
||||
}
|
||||
|
||||
if(vector_field_it.value().num_dim != vector_query.values.size()) {
|
||||
if(is_wildcard_query && vector_field_it.value().num_dim != vector_query.values.size()) {
|
||||
return Option<nlohmann::json>(400, "Query field `" + vector_query.field_name + "` must have " +
|
||||
std::to_string(vector_field_it.value().num_dim) + " dimensions.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// validate search fields
|
||||
std::vector<std::string> processed_search_fields;
|
||||
std::vector<uint32_t> query_by_weights;
|
||||
bool has_embedding_query = false;
|
||||
size_t num_embed_fields = 0;
|
||||
|
||||
for(size_t i = 0; i < raw_search_fields.size(); i++) {
|
||||
const std::string& field_name = raw_search_fields[i];
|
||||
if(field_name == "id") {
|
||||
@ -1208,7 +1210,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
auto search_field = search_schema.at(expanded_search_field);
|
||||
|
||||
if(search_field.num_dim > 0) {
|
||||
if(!vector_query.field_name.empty()) {
|
||||
num_embed_fields++;
|
||||
|
||||
if(num_embed_fields > 1 ||
|
||||
(!vector_query.field_name.empty() && search_field.name != vector_query.field_name)) {
|
||||
std::string error = "Only one embedding field is allowed in the query.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
@ -1253,10 +1258,13 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
}
|
||||
std::vector<float> embedding = embedding_op.embedding;
|
||||
// distance could have been set for an embed field, so we take a backup and restore
|
||||
auto dist = vector_query.distance_threshold;
|
||||
vector_query._reset();
|
||||
vector_query.values = embedding;
|
||||
vector_query.field_name = field_name;
|
||||
vector_query.k = vector_query_hits;
|
||||
vector_query.distance_threshold = dist;
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1267,6 +1275,11 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
}
|
||||
|
||||
if(!vector_query.field_name.empty() && vector_query.values.empty() && num_embed_fields == 0) {
|
||||
std::string error = "Vector query could not find any embedded fields.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
std::string real_raw_query = raw_query;
|
||||
if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) {
|
||||
raw_query = "*";
|
||||
@ -1962,7 +1975,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
wrapper_doc["geo_distance_meters"] = geo_distances;
|
||||
}
|
||||
|
||||
if(!vector_query.field_name.empty() && query == "*") {
|
||||
if(!vector_query.field_name.empty()) {
|
||||
wrapper_doc["vector_distance"] = field_order_kv->vector_distance;
|
||||
}
|
||||
|
||||
|
@ -2,8 +2,10 @@
|
||||
#include "string_utils.h"
|
||||
#include "collection.h"
|
||||
|
||||
Option<bool> VectorQueryOps::parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query,
|
||||
const Collection* coll) {
|
||||
Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str,
|
||||
vector_query_t& vector_query,
|
||||
const bool is_wildcard_query,
|
||||
const Collection* coll) {
|
||||
// FORMAT:
|
||||
// field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)
|
||||
size_t i = 0;
|
||||
@ -156,7 +158,7 @@ Option<bool> VectorQueryOps::parse_vector_query_str(std::string vector_query_str
|
||||
}
|
||||
}
|
||||
|
||||
if(!vector_query.query_doc_given && vector_query.values.empty()) {
|
||||
if(is_wildcard_query && !vector_query.query_doc_given && vector_query.values.empty()) {
|
||||
return Option<bool>(400, "When a vector query value is empty, an `id` parameter must be present.");
|
||||
}
|
||||
|
||||
|
@ -743,8 +743,64 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
|
||||
ASSERT_FLOAT_EQ((1.0/1.0 * 0.7) + (1.0/1.0 * 0.3), search_res["hits"][0]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/2.0 * 0.7) + (1.0/3.0 * 0.3), search_res["hits"][1]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
|
||||
}
|
||||
|
||||
// hybrid search with empty vector (to pass distance threshold param)
|
||||
std::string vec_query = "embedding:([], distance_threshold: 0.20)";
|
||||
|
||||
search_res_op = coll->search("butter", {"embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
|
||||
fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, vec_query);
|
||||
ASSERT_TRUE(search_res_op.ok());
|
||||
search_res = search_res_op.get();
|
||||
|
||||
ASSERT_EQ(2, search_res["found"].get<size_t>());
|
||||
ASSERT_EQ(2, search_res["hits"].size());
|
||||
|
||||
ASSERT_FLOAT_EQ(0.0462081432, search_res["hits"][0]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(0.1213316321, search_res["hits"][1]["vector_distance"].get<float>());
|
||||
|
||||
// when no embedding field is passed, it should not be allowed
|
||||
search_res_op = coll->search("butter", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
|
||||
fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, vec_query);
|
||||
ASSERT_FALSE(search_res_op.ok());
|
||||
ASSERT_EQ("Vector query could not find any embedded fields.", search_res_op.error());
|
||||
|
||||
// when no vector matches distance threshold, only text matches are entertained and distance score should be
|
||||
// 2 in those cases
|
||||
vec_query = "embedding:([], distance_threshold: 0.01)";
|
||||
search_res_op = coll->search("butter", {"name", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
|
||||
fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, vec_query);
|
||||
ASSERT_TRUE(search_res_op.ok());
|
||||
search_res = search_res_op.get();
|
||||
|
||||
ASSERT_EQ(3, search_res["found"].get<size_t>());
|
||||
ASSERT_EQ(3, search_res["hits"].size());
|
||||
|
||||
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][0]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][1]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][2]["vector_distance"].get<float>());
|
||||
|
||||
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][0]["hybrid_search_info"]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][1]["hybrid_search_info"]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][2]["hybrid_search_info"]["vector_distance"].get<float>());
|
||||
}
|
||||
|
||||
TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {
|
||||
nlohmann::json schema = R"({
|
||||
@ -837,7 +893,7 @@ TEST_F(CollectionVectorTest, DistanceThresholdTest) {
|
||||
|
||||
}
|
||||
|
||||
TEST_F(CollectionVectorTest, EmbeddingFieldVectorIndexTest) {
|
||||
TEST_F(CollectionVectorTest, EmbeddingFieldAlterDropTest) {
|
||||
nlohmann::json schema = R"({
|
||||
"name": "objects",
|
||||
"fields": [
|
||||
|
@ -17,7 +17,7 @@ protected:
|
||||
|
||||
TEST_F(VectorQueryOpsTest, ParseVectorQueryString) {
|
||||
vector_query_t vector_query;
|
||||
auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr);
|
||||
auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
ASSERT_EQ("vec", vector_query.field_name);
|
||||
ASSERT_EQ(10, vector_query.k);
|
||||
@ -28,46 +28,50 @@ TEST_F(VectorQueryOpsTest, ParseVectorQueryString) {
|
||||
}
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, false, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error());
|
||||
|
||||
// cannot pass both vector and id
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, false, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string: cannot pass both vector query and `id` parameter.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, false, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, true, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, false, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, false, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, false, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, nullptr);
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, false, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user