Refactor vector query validation & add more tests for qs

This commit is contained in:
ozanarmagan 2023-12-28 15:30:12 +03:00
parent 1f4f197fd5
commit d142172c0b
4 changed files with 264 additions and 84 deletions

View File

@ -333,6 +333,13 @@ private:
void remove_embedding_field(const std::string& field_name);
// Parses `vector_query_str` into `vector_query` and validates it against the collection's schema.
// The referenced field must exist with a non-zero `num_dim` and be indexed. When `qs` is given,
// the field must be an auto-embedding field and `values` is replaced by the average of the
// embeddings generated for each `qs` entry (using the remote embedding timeout / retry settings).
// For wildcard queries, a query with no values and no query document is turned into a non-vector
// query (`field_name` is cleared and `per_page` is capped to a non-zero `k`); otherwise the number
// of values must equal the field's dimensions.
// Returns Option<bool>(true) on success, or an error Option (400, or 500 when the request timed out).
Option<bool> parse_and_validate_vector_query(const std::string& vector_query_str,
vector_query_t& vector_query,
const bool is_wildcard_query,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_tries,
size_t& per_page) const;
public:
enum {MAX_ARRAY_MATCHES = 5};

View File

@ -1219,6 +1219,68 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
return Option<bool>(400, "Could not find a field named `" + sort_field_std.vector_query.query.field_name + "` in vector index.");
}
if(!sort_field_std.vector_query.query.qs.empty()) {
if(embedding_fields.find(sort_field_std.vector_query.query.field_name) == embedding_fields.end()) {
return Option<bool>(400, "`qs` parameter is only supported for auto-embedding fields.");
}
std::vector<std::vector<float>> embeddings;
for(const auto& q: sort_field_std.vector_query.query.qs) {
EmbedderManager& embedder_manager = EmbedderManager::get_instance();
auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
if(!embedder_op.ok()) {
return Option<bool>(400, embedder_op.error());
}
auto remote_embedding_timeout_us = remote_embedding_timeout_ms * 1000;
if((std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > remote_embedding_timeout_us) {
std::string error = "Request timed out.";
return Option<bool>(500, error);
}
auto embedder = embedder_op.get();
if(embedder->is_remote()) {
if(remote_embedding_num_tries == 0) {
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
return Option<bool>(400, error);
}
}
std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + q;
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
if(!embedding_op.success) {
if(!embedding_op.error["error"].get<std::string>().empty()) {
return Option<bool>(400, embedding_op.error["error"].get<std::string>());
} else {
return Option<bool>(400, embedding_op.error.dump());
}
}
embeddings.emplace_back(embedding_op.embedding);
}
// get average of all embeddings
std::vector<float> avg_embedding(vector_field_it.value().num_dim, 0);
for(const auto& embedding: embeddings) {
for(size_t i = 0; i < embedding.size(); i++) {
avg_embedding[i] += embedding[i];
}
}
for(size_t i = 0; i < avg_embedding.size(); i++) {
avg_embedding[i] /= embeddings.size();
}
sort_field_std.vector_query.query.values = avg_embedding;
if(vector_field_it.value().vec_dist == cosine) {
std::vector<float> normalized_values(sort_field_std.vector_query.query.values.size());
hnsw_index_t::normalize_vector(sort_field_std.vector_query.query.values, normalized_values);
sort_field_std.vector_query.query.values = normalized_values;
}
}
if(sort_field_std.vector_query.query.values.empty() && embedding_fields.find(sort_field_std.vector_query.query.field_name) != embedding_fields.end()) {
// generate embeddings for the query
@ -1676,91 +1738,10 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
if(!vector_query_str.empty()) {
bool is_wildcard_query = (raw_query == "*" || raw_query.empty());
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query,
is_wildcard_query, this, false);
auto parse_vector_op = parse_and_validate_vector_query(vector_query_str, vector_query, is_wildcard_query, remote_embedding_timeout_ms, remote_embedding_num_tries, per_page);
if(!parse_vector_op.ok()) {
return Option<nlohmann::json>(400, parse_vector_op.error());
}
auto vector_field_it = search_schema.find(vector_query.field_name);
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
return Option<nlohmann::json>(400, "Field `" + vector_query.field_name + "` does not have a vector query index.");
}
if(!vector_field_it.value().index) {
return Option<nlohmann::json>(400, "Field `" + vector_query.field_name + "` is marked as a non-indexed field in the schema.");
}
if(!vector_query.qs.empty()) {
if(embedding_fields.find(vector_query.field_name) == embedding_fields.end()) {
return Option<nlohmann::json>(400, "`qs` parameter is only supported for auto-embedding fields.");
}
std::vector<std::vector<float>> embeddings;
for(const auto& q: vector_query.qs) {
EmbedderManager& embedder_manager = EmbedderManager::get_instance();
auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
if(!embedder_op.ok()) {
return Option<nlohmann::json>(400, embedder_op.error());
}
auto remote_embedding_timeout_us = remote_embedding_timeout_ms * 1000;
if((std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > remote_embedding_timeout_us) {
std::string error = "Request timed out.";
return Option<nlohmann::json>(500, error);
}
auto embedder = embedder_op.get();
if(embedder->is_remote()) {
if(remote_embedding_num_tries == 0) {
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
return Option<nlohmann::json>(400, error);
}
}
std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + q;
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
if(!embedding_op.success) {
if(!embedding_op.error["error"].get<std::string>().empty()) {
return Option<nlohmann::json>(400, embedding_op.error["error"].get<std::string>());
} else {
return Option<nlohmann::json>(400, embedding_op.error.dump());
}
}
embeddings.emplace_back(embedding_op.embedding);
}
// get average of all embeddings
std::vector<float> avg_embedding(vector_field_it.value().num_dim, 0);
for(const auto& embedding: embeddings) {
for(size_t i = 0; i < embedding.size(); i++) {
avg_embedding[i] += embedding[i];
}
}
for(size_t i = 0; i < avg_embedding.size(); i++) {
avg_embedding[i] /= embeddings.size();
}
vector_query.values = avg_embedding;
}
if(is_wildcard_query) {
if(vector_query.values.empty() && !vector_query.query_doc_given) {
// for usability we will treat this as non-vector query
vector_query.field_name.clear();
if(vector_query.k != 0) {
per_page = std::min(per_page, vector_query.k);
}
}
else if(vector_field_it.value().num_dim != vector_query.values.size()) {
return Option<nlohmann::json>(400, "Query field `" + vector_query.field_name + "` must have " +
std::to_string(vector_field_it.value().num_dim) + " dimensions.");
}
return Option<nlohmann::json>(parse_vector_op.code(), parse_vector_op.error());
}
}
@ -6375,3 +6356,100 @@ tsl::htrie_map<char, field> Collection::get_embedding_fields_unsafe() {
// Housekeeping entry point for the collection: asks the underlying index to repair its HNSW index.
void Collection::do_housekeeping() {
index->repair_hnsw_index();
}
/**
 * Parses a raw vector query string and validates it against this collection's schema.
 *
 * Steps, in order:
 *  1. Parse `vector_query_str` into `vector_query` via VectorQueryOps::parse_vector_query_str.
 *  2. Require the referenced field to exist in the search schema with a non-zero `num_dim`
 *     and to be indexed.
 *  3. If `qs` is present, require an auto-embedding field, embed every `qs` entry and store the
 *     element-wise average of those embeddings in `vector_query.values`.
 *  4. For wildcard queries: with no values and no query document the query is treated as a
 *     non-vector query (`field_name` cleared, `per_page` capped to a non-zero `k`); otherwise the
 *     number of values must match the field's dimensions.
 *
 * @param vector_query_str            raw vector query string to parse.
 * @param vector_query                output of the parser; `values` and `field_name` may be rewritten here.
 * @param is_wildcard_query           forwarded to the parser and enables the wildcard-only checks of step 4.
 * @param remote_embedding_timeout_ms timeout handed to the embedder; also bounds the time elapsed since
 *                                    `search_begin_us` (defined outside this function).
 * @param remote_embedding_num_tries  retry count handed to the embedder; must be non-zero for remote embedders.
 * @param per_page                    in/out page size, lowered to `vector_query.k` in the non-vector fallback.
 * @return Option<bool>(true) on success; otherwise an error Option with code 400
 *         (500 when the request timed out before an embedding call).
 */
Option<bool> Collection::parse_and_validate_vector_query(const std::string& vector_query_str,
                                                         vector_query_t& vector_query,
                                                         const bool is_wildcard_query,
                                                         const size_t remote_embedding_timeout_ms,
                                                         const size_t remote_embedding_num_tries,
                                                         size_t& per_page) const {
    auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query,
                                                                  is_wildcard_query, this, false);
    if(!parse_vector_op.ok()) {
        return Option<bool>(400, parse_vector_op.error());
    }

    auto vector_field_it = search_schema.find(vector_query.field_name);
    if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
        return Option<bool>(400, "Field `" + vector_query.field_name + "` does not have a vector query index.");
    }

    if(!vector_field_it.value().index) {
        return Option<bool>(400, "Field `" + vector_query.field_name + "` is marked as a non-indexed field in the schema.");
    }

    if(!vector_query.qs.empty()) {
        if(embedding_fields.find(vector_query.field_name) == embedding_fields.end()) {
            return Option<bool>(400, "`qs` parameter is only supported for auto-embedding fields.");
        }

        // Loop-invariant values, computed once instead of once per `qs` entry.
        EmbedderManager& embedder_manager = EmbedderManager::get_instance();
        const auto remote_embedding_timeout_us = remote_embedding_timeout_ms * 1000;

        // Running element-wise sum of the embeddings. Accumulating in iteration order yields the same
        // result as collecting every embedding first, without keeping a copy of each one.
        std::vector<float> avg_embedding(vector_field_it.value().num_dim, 0);
        size_t num_embeddings = 0;

        for(const auto& q: vector_query.qs) {
            auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
            if(!embedder_op.ok()) {
                return Option<bool>(400, embedder_op.error());
            }

            // Give up when the time elapsed since the search began already exceeds the embedding timeout.
            const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                    std::chrono::system_clock::now().time_since_epoch()).count();
            if((now_us - search_begin_us) > remote_embedding_timeout_us) {
                return Option<bool>(500, "Request timed out.");
            }

            auto embedder = embedder_op.get();
            if(embedder->is_remote() && remote_embedding_num_tries == 0) {
                return Option<bool>(400, "`remote_embedding_num_tries` must be greater than 0.");
            }

            std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + q;
            auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
            if(!embedding_op.success) {
                // Probe for a string-valued "error" key before reading it: a missing or non-string key
                // would otherwise raise instead of reaching the dump() fallback.
                // NOTE(review): assumes `embedding_op.error` is an nlohmann::json value, as its use of dump() suggests.
                if(embedding_op.error.count("error") != 0 && embedding_op.error["error"].is_string()) {
                    const std::string error_message = embedding_op.error["error"].get<std::string>();
                    if(!error_message.empty()) {
                        return Option<bool>(400, error_message);
                    }
                }

                return Option<bool>(400, embedding_op.error.dump());
            }

            // Refuse embeddings whose size differs from the field's dimensions: a longer one would write
            // past the end of the accumulator, a shorter one would silently produce a partial average.
            const auto& embedding = embedding_op.embedding;
            if(embedding.size() != avg_embedding.size()) {
                return Option<bool>(400, "Embedding generated for field `" + vector_query.field_name + "` has " +
                                         std::to_string(embedding.size()) + " dimensions, but the field expects " +
                                         std::to_string(avg_embedding.size()) + ".");
            }

            for(size_t i = 0; i < embedding.size(); i++) {
                avg_embedding[i] += embedding[i];
            }

            num_embeddings++;
        }

        // `qs` is non-empty here, so at least one embedding was accumulated.
        for(auto& component: avg_embedding) {
            component /= num_embeddings;
        }

        vector_query.values = avg_embedding;
    }

    if(is_wildcard_query) {
        if(vector_query.values.empty() && !vector_query.query_doc_given) {
            // for usability we will treat this as non-vector query
            vector_query.field_name.clear();
            if(vector_query.k != 0) {
                per_page = std::min(per_page, vector_query.k);
            }
        }

        else if(vector_field_it.value().num_dim != vector_query.values.size()) {
            return Option<bool>(400, "Query field `" + vector_query.field_name + "` must have " +
                                     std::to_string(vector_field_it.value().num_dim) + " dimensions.");
        }
    }

    return Option<bool>(true);
}

View File

@ -2582,3 +2582,59 @@ TEST_F(CollectionSortingTest, TestSortByVectorQuery) {
ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
ASSERT_EQ("0", results["hits"][2]["document"]["id"]);
}
// Checks that sorting by `_vector_query(...)` with a `qs` list changes the order of the hits
// returned for the same text query.
TEST_F(CollectionSortingTest, TestVectorQueryQsSorting) {
    EmbedderManager::set_model_dir("/tmp/typesense_test/models");

    // Collection with a plain string field and an embedding field generated from it.
    auto schema =
R"({
"name": "test",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}}
]
})"_json;

    auto create_op = collectionManager.create_collection(schema);
    ASSERT_TRUE(create_op.ok());
    auto collection = create_op.get();

    auto first_add_op = collection->add(R"({
"name": "buttercup"
})"_json.dump());
    ASSERT_TRUE(first_add_op.ok());

    auto second_add_op = collection->add(R"({
"name": "butter"
})"_json.dump());
    ASSERT_TRUE(second_add_op.ok());

    // Both searches are identical apart from the value of `sort_fields` at the time of the call.
    // NOTE(review): `sort_fields` is not declared in this test; presumably a fixture member — confirm.
    auto search_butter = [&]() {
        return collection->search("butter", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                  spp::sparse_hash_set<std::string>(),
                                  spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                  "", 10, {}, {}, {}, 0,
                                  "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                                  4, {off}, 32767, 32767, 2,
                                  false, true, "").get();
    };

    // With `sort_fields` as it is on entry: "butter" (id 1) comes before "buttercup" (id 0).
    auto results = search_butter();
    ASSERT_EQ(2, results["hits"].size());
    ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("0", results["hits"][1]["document"]["id"]);

    // Sorting by the vector query built from `qs` is expected to flip the order.
    sort_fields = {
        sort_by("_vector_query(embedding:([], qs: [powerpuff girls, cartoon]))", "asc"),
    };

    results = search_butter();
    ASSERT_EQ(2, results["hits"].size());
    ASSERT_EQ("0", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
}

View File

@ -3483,4 +3483,43 @@ TEST_F(CollectionVectorTest, TestVectorQueryInvalidQs) {
ASSERT_EQ(results.error(), "Malformed vector query string: "
"`qs` parameter must be a list of strings.");
}
// Checks that a text query combined with a `qs`-based vector query on an auto-embedding field
// succeeds and returns the single indexed document.
TEST_F(CollectionVectorTest, TestVectorQueryQsWithHybridSearch) {
    EmbedderManager::set_model_dir("/tmp/typesense_test/models");

    // Collection with a plain string field and an embedding field generated from it.
    auto schema =
R"({
"name": "test",
"fields": [
{"name": "name", "type": "string"},
{"name": "embedding", "type":"float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/all-MiniLM-L12-v2"}}}
]
})"_json;

    auto create_op = collectionManager.create_collection(schema);
    ASSERT_TRUE(create_op.ok());
    auto collection = create_op.get();

    auto insert_op = collection->add(R"({
"name": "Stark Industries"
})"_json.dump());
    ASSERT_TRUE(insert_op.ok());

    // Text query "stark" together with a vector query whose values come from the `qs` strings.
    auto search_op = collection->search("stark", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                        spp::sparse_hash_set<std::string>(),
                                        spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                        "", 10, {}, {}, {}, 0,
                                        "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                        fallback,
                                        4, {off}, 32767, 32767, 2,
                                        false, true, "embedding:([], qs:[superhero, company])");
    ASSERT_TRUE(search_op.ok());

    auto response = search_op.get();
    ASSERT_EQ(response["hits"].size(), 1);
}