mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 23:06:30 +08:00
Refactor validating vector query & add more tests for pq
This commit is contained in:
parent
1f4f197fd5
commit
d142172c0b
@ -333,6 +333,13 @@ private:
|
||||
|
||||
void remove_embedding_field(const std::string& field_name);
|
||||
|
||||
// Parses `vector_query_str` into `vector_query` and validates the result against this
// collection's schema (see the definition in the .cpp for the exact order of checks).
//
// - vector_query_str: raw vector query string handed to VectorQueryOps::parse_vector_query_str.
// - vector_query: output; populated by the parser. When `qs` is given, `values` is replaced by the
//   element-wise average of the embeddings generated for each `qs` entry.
// - is_wildcard_query: when true and the query carries neither values nor a query document,
//   `vector_query.field_name` is cleared and `per_page` may be lowered to `vector_query.k`;
//   otherwise (still wildcard only) the number of values must equal the field's `num_dim`.
// - remote_embedding_timeout_ms / remote_embedding_num_tries: forwarded to the embedder's Embed()
//   call; the timeout is also compared against the time elapsed since `search_begin_us`.
// - per_page: in/out; only ever reduced, via std::min(per_page, vector_query.k).
//
// Returns Option<bool>(true) on success, otherwise an error Option (400, or 500 on timeout).
Option<bool> parse_and_validate_vector_query(const std::string& vector_query_str,
|
||||
vector_query_t& vector_query,
|
||||
const bool is_wildcard_query,
|
||||
const size_t remote_embedding_timeout_ms,
|
||||
const size_t remote_embedding_num_tries,
|
||||
size_t& per_page) const;
|
||||
|
||||
public:
|
||||
|
||||
enum {MAX_ARRAY_MATCHES = 5};
|
||||
|
@ -1219,6 +1219,68 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
|
||||
return Option<bool>(400, "Could not find a field named `" + sort_field_std.vector_query.query.field_name + "` in vector index.");
|
||||
}
|
||||
|
||||
if(!sort_field_std.vector_query.query.qs.empty()) {
|
||||
if(embedding_fields.find(sort_field_std.vector_query.query.field_name) == embedding_fields.end()) {
|
||||
return Option<bool>(400, "`qs` parameter is only supported for auto-embedding fields.");
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> embeddings;
|
||||
for(const auto& q: sort_field_std.vector_query.query.qs) {
|
||||
EmbedderManager& embedder_manager = EmbedderManager::get_instance();
|
||||
auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<bool>(400, embedder_op.error());
|
||||
}
|
||||
|
||||
auto remote_embedding_timeout_us = remote_embedding_timeout_ms * 1000;
|
||||
if((std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > remote_embedding_timeout_us) {
|
||||
std::string error = "Request timed out.";
|
||||
return Option<bool>(500, error);
|
||||
}
|
||||
|
||||
auto embedder = embedder_op.get();
|
||||
|
||||
if(embedder->is_remote()) {
|
||||
if(remote_embedding_num_tries == 0) {
|
||||
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
|
||||
return Option<bool>(400, error);
|
||||
}
|
||||
}
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + q;
|
||||
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
|
||||
|
||||
if(!embedding_op.success) {
|
||||
if(!embedding_op.error["error"].get<std::string>().empty()) {
|
||||
return Option<bool>(400, embedding_op.error["error"].get<std::string>());
|
||||
} else {
|
||||
return Option<bool>(400, embedding_op.error.dump());
|
||||
}
|
||||
}
|
||||
|
||||
embeddings.emplace_back(embedding_op.embedding);
|
||||
}
|
||||
|
||||
// get average of all embeddings
|
||||
std::vector<float> avg_embedding(vector_field_it.value().num_dim, 0);
|
||||
for(const auto& embedding: embeddings) {
|
||||
for(size_t i = 0; i < embedding.size(); i++) {
|
||||
avg_embedding[i] += embedding[i];
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < avg_embedding.size(); i++) {
|
||||
avg_embedding[i] /= embeddings.size();
|
||||
}
|
||||
|
||||
sort_field_std.vector_query.query.values = avg_embedding;
|
||||
if(vector_field_it.value().vec_dist == cosine) {
|
||||
std::vector<float> normalized_values(sort_field_std.vector_query.query.values.size());
|
||||
hnsw_index_t::normalize_vector(sort_field_std.vector_query.query.values, normalized_values);
|
||||
sort_field_std.vector_query.query.values = normalized_values;
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_field_std.vector_query.query.values.empty() && embedding_fields.find(sort_field_std.vector_query.query.field_name) != embedding_fields.end()) {
|
||||
// generate embeddings for the query
|
||||
@ -1676,91 +1738,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
if(!vector_query_str.empty()) {
|
||||
bool is_wildcard_query = (raw_query == "*" || raw_query.empty());
|
||||
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query,
|
||||
is_wildcard_query, this, false);
|
||||
auto parse_vector_op = parse_and_validate_vector_query(vector_query_str, vector_query, is_wildcard_query, remote_embedding_timeout_ms, remote_embedding_num_tries, per_page);
|
||||
|
||||
if(!parse_vector_op.ok()) {
|
||||
return Option<nlohmann::json>(400, parse_vector_op.error());
|
||||
}
|
||||
|
||||
auto vector_field_it = search_schema.find(vector_query.field_name);
|
||||
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
|
||||
return Option<nlohmann::json>(400, "Field `" + vector_query.field_name + "` does not have a vector query index.");
|
||||
}
|
||||
|
||||
if(!vector_field_it.value().index) {
|
||||
return Option<nlohmann::json>(400, "Field `" + vector_query.field_name + "` is marked as a non-indexed field in the schema.");
|
||||
}
|
||||
|
||||
if(!vector_query.qs.empty()) {
|
||||
if(embedding_fields.find(vector_query.field_name) == embedding_fields.end()) {
|
||||
return Option<nlohmann::json>(400, "`qs` parameter is only supported for auto-embedding fields.");
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> embeddings;
|
||||
for(const auto& q: vector_query.qs) {
|
||||
EmbedderManager& embedder_manager = EmbedderManager::get_instance();
|
||||
auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<nlohmann::json>(400, embedder_op.error());
|
||||
}
|
||||
|
||||
auto remote_embedding_timeout_us = remote_embedding_timeout_ms * 1000;
|
||||
if((std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > remote_embedding_timeout_us) {
|
||||
std::string error = "Request timed out.";
|
||||
return Option<nlohmann::json>(500, error);
|
||||
}
|
||||
|
||||
auto embedder = embedder_op.get();
|
||||
|
||||
if(embedder->is_remote()) {
|
||||
if(remote_embedding_num_tries == 0) {
|
||||
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
}
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + q;
|
||||
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
|
||||
|
||||
if(!embedding_op.success) {
|
||||
if(!embedding_op.error["error"].get<std::string>().empty()) {
|
||||
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());
|
||||
} else {
|
||||
return Option<nlohmann::json>(400, embedding_op.error.dump());
|
||||
}
|
||||
}
|
||||
|
||||
embeddings.emplace_back(embedding_op.embedding);
|
||||
}
|
||||
|
||||
// get average of all embeddings
|
||||
std::vector<float> avg_embedding(vector_field_it.value().num_dim, 0);
|
||||
for(const auto& embedding: embeddings) {
|
||||
for(size_t i = 0; i < embedding.size(); i++) {
|
||||
avg_embedding[i] += embedding[i];
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < avg_embedding.size(); i++) {
|
||||
avg_embedding[i] /= embeddings.size();
|
||||
}
|
||||
|
||||
vector_query.values = avg_embedding;
|
||||
}
|
||||
|
||||
if(is_wildcard_query) {
|
||||
if(vector_query.values.empty() && !vector_query.query_doc_given) {
|
||||
// for usability we will treat this as non-vector query
|
||||
vector_query.field_name.clear();
|
||||
if(vector_query.k != 0) {
|
||||
per_page = std::min(per_page, vector_query.k);
|
||||
}
|
||||
}
|
||||
|
||||
else if(vector_field_it.value().num_dim != vector_query.values.size()) {
|
||||
return Option<nlohmann::json>(400, "Query field `" + vector_query.field_name + "` must have " +
|
||||
std::to_string(vector_field_it.value().num_dim) + " dimensions.");
|
||||
}
|
||||
return Option<nlohmann::json>(parse_vector_op.code(), parse_vector_op.error());
|
||||
}
|
||||
}
|
||||
|
||||
@ -6375,3 +6356,100 @@ tsl::htrie_map<char, field> Collection::get_embedding_fields_unsafe() {
|
||||
// Housekeeping entry point for the collection: its only action is to delegate to
// index->repair_hnsw_index().
void Collection::do_housekeeping() {
|
||||
index->repair_hnsw_index();
|
||||
}
|
||||
|
||||
// Parses `vector_query_str` into `vector_query` and validates it against this collection's schema.
//
// Checks are performed in this order, returning on the first failure:
//   1. VectorQueryOps::parse_vector_query_str must succeed (failure is reported as 400).
//   2. The target field must exist in `search_schema` with a non-zero `num_dim` (400).
//   3. The target field must be indexed (400).
//   4. If `qs` is present, the field must be an auto-embedding field (400). Each `qs` entry is
//      embedded (query prefix + text) and `vector_query.values` becomes the element-wise average
//      of those embeddings. Per entry: elapsed time since `search_begin_us` must not exceed the
//      remote embedding timeout (500), and a remote embedder requires
//      `remote_embedding_num_tries > 0` (400). Embedder failures are reported as 400.
//   5. Wildcard queries only: with no values and no query document, the query is downgraded to a
//      non-vector query (`field_name` is cleared, `per_page` is capped by a non-zero `k`);
//      otherwise the number of values must equal the field's `num_dim` (400).
//
// `per_page` is an in/out parameter and is only ever reduced.
// Returns Option<bool>(true) when every check passes.
Option<bool> Collection::parse_and_validate_vector_query(const std::string& vector_query_str,
                                                         vector_query_t& vector_query,
                                                         const bool is_wildcard_query,
                                                         const size_t remote_embedding_timeout_ms,
                                                         const size_t remote_embedding_num_tries,
                                                         size_t& per_page) const {

    auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query,
                                                                  is_wildcard_query, this, false);
    if(!parse_vector_op.ok()) {
        return Option<bool>(400, parse_vector_op.error());
    }

    auto vector_field_it = search_schema.find(vector_query.field_name);
    if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
        return Option<bool>(400, "Field `" + vector_query.field_name + "` does not have a vector query index.");
    }

    if(!vector_field_it.value().index) {
        return Option<bool>(400, "Field `" + vector_query.field_name + "` is marked as a non-indexed field in the schema.");
    }

    if(!vector_query.qs.empty()) {
        if(embedding_fields.find(vector_query.field_name) == embedding_fields.end()) {
            return Option<bool>(400, "`qs` parameter is only supported for auto-embedding fields.");
        }

        // The embedder depends only on the field's model config, so resolve it once up front.
        // `qs` is non-empty here, so this is still the first check performed, as before.
        EmbedderManager& embedder_manager = EmbedderManager::get_instance();
        auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
        if(!embedder_op.ok()) {
            return Option<bool>(400, embedder_op.error());
        }

        auto embedder = embedder_op.get();
        const auto remote_embedding_timeout_us = remote_embedding_timeout_ms * 1000;

        // Running per-dimension sum of the embeddings; divided by the number of queries below.
        std::vector<float> avg_embedding(vector_field_it.value().num_dim, 0);

        for(const auto& q: vector_query.qs) {
            if((std::chrono::duration_cast<std::chrono::microseconds>(
                    std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > remote_embedding_timeout_us) {
                std::string error = "Request timed out.";
                return Option<bool>(500, error);
            }

            if(embedder->is_remote()) {
                if(remote_embedding_num_tries == 0) {
                    std::string error = "`remote_embedding_num_tries` must be greater than 0.";
                    return Option<bool>(400, error);
                }
            }

            std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + q;
            auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);

            if(!embedding_op.success) {
                // Only use the "error" member when it is actually a non-empty string. Calling
                // get<std::string>() on a missing or non-string member would throw instead of
                // reaching the dump() fallback.
                const auto error_it = embedding_op.error.find("error");
                if(error_it != embedding_op.error.end() && error_it->is_string() &&
                   !error_it->get<std::string>().empty()) {
                    return Option<bool>(400, error_it->get<std::string>());
                }

                return Option<bool>(400, embedding_op.error.dump());
            }

            // Never index past `num_dim`, even if the embedder returns a longer vector.
            const size_t num_values = std::min(embedding_op.embedding.size(), avg_embedding.size());
            for(size_t i = 0; i < num_values; i++) {
                avg_embedding[i] += embedding_op.embedding[i];
            }
        }

        // Every `qs` entry contributed exactly one embedding (any failure returned above).
        const size_t num_embeddings = vector_query.qs.size();
        for(size_t i = 0; i < avg_embedding.size(); i++) {
            avg_embedding[i] /= num_embeddings;
        }

        vector_query.values = avg_embedding;
    }

    if(is_wildcard_query) {
        if(vector_query.values.empty() && !vector_query.query_doc_given) {
            // for usability we will treat this as non-vector query
            vector_query.field_name.clear();
            if(vector_query.k != 0) {
                per_page = std::min(per_page, vector_query.k);
            }
        }

        else if(vector_field_it.value().num_dim != vector_query.values.size()) {
            return Option<bool>(400, "Query field `" + vector_query.field_name + "` must have " +
                                     std::to_string(vector_field_it.value().num_dim) + " dimensions.");
        }
    }

    return Option<bool>(true);
}
|
@ -2582,3 +2582,59 @@ TEST_F(CollectionSortingTest, TestSortByVectorQuery) {
|
||||
ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
|
||||
ASSERT_EQ("0", results["hits"][2]["document"]["id"]);
|
||||
}
|
||||
|
||||
// Checks that sorting by a `_vector_query(... qs: [...])` expression on an auto-embedding field
// changes the order of the keyword hits compared to the search issued before it is set.
TEST_F(CollectionSortingTest, TestVectorQueryQsSorting) {
    auto schema_json =
            R"({
            "name": "test",
            "fields": [
                {"name": "name", "type": "string"},
                {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
            ]
        })"_json;

    EmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto collection_create_op = collectionManager.create_collection(schema_json);
    ASSERT_TRUE(collection_create_op.ok());

    auto coll = collection_create_op.get();

    auto add_op = coll->add(R"({
        "name": "buttercup"
    })"_json.dump());
    ASSERT_TRUE(add_op.ok());

    add_op = coll->add(R"({
        "name": "butter"
    })"_json.dump());
    ASSERT_TRUE(add_op.ok());

    // First search: uses `sort_fields` as it stands before the vector-query sort is assigned below.
    auto baseline_op = coll->search("butter", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                    spp::sparse_hash_set<std::string>(),
                                    spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                    "", 10, {}, {}, {}, 0,
                                    "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                                    4, {off}, 32767, 32767, 2,
                                    false, true, "");

    // Fail here if the search itself errored, rather than on a hit-count mismatch further down.
    ASSERT_TRUE(baseline_op.ok());

    auto results = baseline_op.get();
    ASSERT_EQ(2, results["hits"].size());
    ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("0", results["hits"][1]["document"]["id"]);

    sort_fields = {
        sort_by("_vector_query(embedding:([], qs: [powerpuff girls, cartoon]))", "asc"),
    };

    // Second search: same query, now sorted by the `qs`-based vector query.
    auto vector_sorted_op = coll->search("butter", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                         spp::sparse_hash_set<std::string>(),
                                         spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                         "", 10, {}, {}, {}, 0,
                                         "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                                         4, {off}, 32767, 32767, 2,
                                         false, true, "");

    ASSERT_TRUE(vector_sorted_op.ok());

    results = vector_sorted_op.get();
    ASSERT_EQ(2, results["hits"].size());
    ASSERT_EQ("0", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
}
|
@ -3483,4 +3483,43 @@ TEST_F(CollectionVectorTest, TestVectorQueryInvalidQs) {
|
||||
|
||||
ASSERT_EQ(results.error(), "Malformed vector query string: "
|
||||
"`qs` parameter must be a list of strings.");
|
||||
}
|
||||
|
||||
|
||||
|
||||
// A keyword query combined with a `qs`-based vector query on an auto-embedding field must succeed
// and return the single indexed document.
TEST_F(CollectionVectorTest, TestVectorQueryQsWithHybridSearch) {
    auto schema = R"({
        "name": "test",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/all-MiniLM-L12-v2"}}}
        ]
    })"_json;

    EmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto create_op = collectionManager.create_collection(schema);
    ASSERT_TRUE(create_op.ok());

    auto collection = create_op.get();

    auto insert_op = collection->add(R"({
        "name": "Stark Industries"
    })"_json.dump());
    ASSERT_TRUE(insert_op.ok());

    // The final argument is the vector query string: empty values, embeddings generated from `qs`.
    auto search_op = collection->search("stark", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                        spp::sparse_hash_set<std::string>(),
                                        spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                        "", 10, {}, {}, {}, 0,
                                        "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                        fallback,
                                        4, {off}, 32767, 32767, 2,
                                        false, true, "embedding:([], qs:[superhero, company])");

    ASSERT_TRUE(search_op.ok());
    ASSERT_EQ(search_op.get()["hits"].size(), 1);
}
|
Loading…
x
Reference in New Issue
Block a user