From 4167fe69c8a93d8f1f5e64974f8be0903b6d754e Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sat, 14 Oct 2023 02:07:51 +0300 Subject: [PATCH 1/4] Sort by vector query --- include/collection.h | 6 ++- include/field.h | 6 ++- include/index.h | 1 + include/vector_query_ops.h | 3 +- src/collection.cpp | 69 +++++++++++++++++++++++++-- src/index.cpp | 69 +++++++++++++++++++++++++++ src/vector_query_ops.cpp | 5 +- test/collection_sorting_test.cpp | 80 ++++++++++++++++++++++++++++++++ 8 files changed, 227 insertions(+), 12 deletions(-) diff --git a/include/collection.h b/include/collection.h index 90586fc1..a1f5fa87 100644 --- a/include/collection.h +++ b/include/collection.h @@ -207,8 +207,10 @@ private: Option validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, - bool is_wildcard_query,const bool is_vector_query, - bool is_group_by_query = false) const; + bool is_wildcard_query, const bool is_vector_query, + const std::string& query, bool is_group_by_query = false, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_tries = 2) const; Option persist_collection_meta(); diff --git a/include/field.h b/include/field.h index 6d262eba..0b5d98c8 100644 --- a/include/field.h +++ b/include/field.h @@ -11,6 +11,7 @@ #include #include "json.hpp" #include "text_embedder_manager.h" +#include "vector_query_ops.h" namespace field_types { // first field value indexed will determine the type @@ -661,6 +662,7 @@ namespace sort_field_const { static const std::string missing_values = "missing_values"; static const std::string vector_distance = "_vector_distance"; + static const std::string vector_query = "_vector_query"; } struct sort_by { @@ -690,10 +692,11 @@ struct sort_by { missing_values_t missing_values; eval_t eval; + vector_query_t vector_query; + sort_by(const std::string & name, const std::string & order): name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0), missing_values(normal) { - } sort_by(const std::string &name, const std::string &order, uint32_t text_match_buckets, int64_t geopoint, @@ -701,7 +704,6 @@ struct sort_by { name(name), order(order), text_match_buckets(text_match_buckets), geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision), missing_values(normal) { - } sort_by& operator=(const sort_by& other) { diff --git a/include/index.h b/include/index.h index a987f3f2..faa0d745 100644 --- a/include/index.h +++ b/include/index.h @@ -351,6 +351,7 @@ private: static spp::sparse_hash_map geo_sentinel_value; static spp::sparse_hash_map str_sentinel_value; static spp::sparse_hash_map vector_distance_sentinel_value; + static spp::sparse_hash_map vector_query_sentinel_value; // Internal utility functions diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index b161bd3e..3fc033d1 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -32,5 +32,6 @@ class VectorQueryOps { public: static Option parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query, const bool is_wildcard_query, - const Collection* coll); + const Collection* coll, + const bool allow_empty_query = false); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index adda2fc5..a7d993f9 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -744,7 +744,9 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< std::vector& sort_fields_std, const bool is_wildcard_query, const bool is_vector_query, - const bool is_group_by_query) const { + const std::string& query, const bool is_group_by_query, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_tries) const { size_t num_sort_expressions = 0; @@ -793,6 +795,62 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_field_std.name = actual_field_name; num_sort_expressions++; + } else if(actual_field_name == sort_field_const::vector_query) { + const std::string& vector_query_str = sort_field_std.name.substr(paran_start + 1, + sort_field_std.name.size() - paran_start - + 2); + if(vector_query_str.empty()) { + return Option(400, "The vector query in sort_by is empty."); + } + + + auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, sort_field_std.vector_query, + is_wildcard_query, this, true); + if(!parse_vector_op.ok()) { + return Option(400, parse_vector_op.error()); + } + + auto vector_field_it = search_schema.find(sort_field_std.vector_query.field_name); + if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) { + return Option(400, "Field `" + sort_field_std.vector_query.field_name + "` does not have a vector query index."); + } + + + if(sort_field_std.vector_query.values.empty() && embedding_fields.find(sort_field_std.vector_query.field_name) != embedding_fields.end()) { + // generate embeddings for the query + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]); + if(!embedder_op.ok()) { + return Option(embedder_op.code(), embedder_op.error()); + } + + auto embedder = embedder_op.get(); + + if(embedder->is_remote() && remote_embedding_num_tries == 0) { + std::string error = "`remote_embedding_num_tries` must be greater than 0."; + return Option(400, error); + } + + std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + query; + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries); + + if(!embedding_op.success) { + if(!embedding_op.error["error"].get().empty()) { + return Option(400, embedding_op.error["error"].get()); + } else { + return Option(400, embedding_op.error.dump()); + } + } + + sort_field_std.vector_query.values = embedding_op.embedding; + } + + if(vector_field_it.value().num_dim != sort_field_std.vector_query.values.size()) { + return Option(400, "Query field `" + sort_field_std.vector_query.field_name + "` must have " + + std::to_string(vector_field_it.value().num_dim) + " dimensions."); + } + + sort_field_std.name = actual_field_name; } else { if(field_it == search_schema.end()) { @@ -918,7 +976,8 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval && - sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) { + sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance && + sort_field_std.name != sort_field_const::vector_query) { const auto field_it = search_schema.find(sort_field_std.name); if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { @@ -1537,7 +1596,7 @@ Option Collection::search(std::string raw_query, if(curated_sort_by.empty()) { auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, - sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1549,7 +1608,7 @@ Option Collection::search(std::string raw_query, } auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, - sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -5048,4 +5107,4 @@ void Collection::remove_embedding_field(const std::string& field_name) { tsl::htrie_map Collection::get_embedding_fields_unsafe() { return embedding_fields; -} \ No newline at end of file +} diff --git a/src/index.cpp b/src/index.cpp index 8dc5ebb8..c7a9c8a7 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -43,6 +43,7 @@ spp::sparse_hash_map Index::eval_sentinel_value; spp::sparse_hash_map Index::geo_sentinel_value; spp::sparse_hash_map Index::str_sentinel_value; spp::sparse_hash_map Index::vector_distance_sentinel_value; +spp::sparse_hash_map Index::vector_query_sentinel_value; struct token_posting_t { uint32_t token_id; @@ -4395,6 +4396,28 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[0] = int64_t(found); } else if(field_values[0] == &vector_distance_sentinel_value) { scores[0] = float_to_int64_t(vector_distance); + } else if(field_values[0] == &vector_query_sentinel_value) { + scores[0] = float_to_int64_t(2.0f); + try { + auto& field_vector_index = vector_index.at(sort_fields[0].vector_query.field_name); + const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = field_vector_index->space->get_dist_func(); + float dist = 2.0f; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_values(sort_fields[0].vector_query.values.size()); + hnsw_index_t::normalize_vector(sort_fields[0].vector_query.values, normalized_values); + dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); + + } else { + dist = dist_func(sort_fields[0].vector_query.values.data(), values.data(), &field_vector_index->num_dim); + } + + scores[0] = float_to_int64_t(dist); + } catch(...) { + // probably not found + // do nothing + } + } else { auto it = field_values[0]->find(seq_id); scores[0] = (it == field_values[0]->end()) ? default_score : it->second; @@ -4453,6 +4476,28 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[1] = int64_t(found); } else if(field_values[1] == &vector_distance_sentinel_value) { scores[1] = float_to_int64_t(vector_distance); + } else if(field_values[1] == &vector_query_sentinel_value) { + scores[1] = float_to_int64_t(2.0f); + try { + auto& field_vector_index = vector_index.at(sort_fields[1].vector_query.field_name); + const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = field_vector_index->space->get_dist_func(); + float dist = 2.0f; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_values(sort_fields[1].vector_query.values.size()); + hnsw_index_t::normalize_vector(sort_fields[1].vector_query.values, normalized_values); + dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); + + } else { + dist = dist_func(sort_fields[1].vector_query.values.data(), values.data(), &field_vector_index->num_dim); + } + + scores[1] = float_to_int64_t(dist); + } catch(...) { + // probably not found + // do nothing + } + } else { auto it = field_values[1]->find(seq_id); scores[1] = (it == field_values[1]->end()) ? default_score : it->second; @@ -4507,6 +4552,28 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[2] = int64_t(found); } else if(field_values[2] == &vector_distance_sentinel_value) { scores[2] = float_to_int64_t(vector_distance); + } else if(field_values[2] == &vector_query_sentinel_value) { + scores[2] = float_to_int64_t(2.0f); + try { + auto& field_vector_index = vector_index.at(sort_fields[2].vector_query.field_name); + const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = field_vector_index->space->get_dist_func(); + float dist = 2.0f; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_values(sort_fields[2].vector_query.values.size()); + hnsw_index_t::normalize_vector(sort_fields[2].vector_query.values, normalized_values); + dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); + + } else { + dist = dist_func(sort_fields[2].vector_query.values.data(), values.data(), &field_vector_index->num_dim); + } + + scores[2] = float_to_int64_t(dist); + } catch(...) { + // probably not found + // do nothing + } + } else { auto it = field_values[2]->find(seq_id); scores[2] = (it == field_values[2]->end()) ? default_score : it->second; @@ -5238,6 +5305,8 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint result.docs = nullptr; } else if(sort_fields_std[i].name == sort_field_const::vector_distance) { field_values[i] = &vector_distance_sentinel_value; + } else if(sort_fields_std[i].name == sort_field_const::vector_query) { + field_values[i] = &vector_query_sentinel_value; } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index dba9d27d..74189948 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -5,7 +5,8 @@ Option VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query, const bool is_wildcard_query, - const Collection* coll) { + const Collection* coll, + const bool allow_empty_query) { // FORMAT: // field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10) size_t i = 0; @@ -72,7 +73,7 @@ Option VectorQueryOps::parse_vector_query_str(const std::string& vector_qu if(i == vector_query_str.size()-1) { // missing params - if(vector_query.values.empty()) { + if(vector_query.values.empty() && !allow_empty_query) { // when query values are missing, atleast the `id` parameter must be present return Option(400, "When a vector query value is empty, an `id` parameter must be present."); } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 95c255e9..8d51a079 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -2387,4 +2387,84 @@ TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) { ASSERT_FALSE(results.ok()); ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error()); +} + + +TEST_F(CollectionSortingTest, TestSortByVectorQuery) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "name", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + auto create_coll = collectionManager.create_collection(schema); + ASSERT_TRUE(create_coll.ok()); + + auto coll = create_coll.get(); + + std::vector> points = { + {7.0, 8.0}, + {8.0, 15.0}, + {5.0, 12.0}, + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["name"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll->add(doc.dump()).ok()); + } + + std::vector sort_fields = {}; + + auto results = coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "").get(); + + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"]); + ASSERT_EQ("1", results["hits"][1]["document"]["id"]); + ASSERT_EQ("0", results["hits"][2]["document"]["id"]); + + sort_fields = { + sort_by("_vector_query(points:([5.0, 5.0]))", "asc"), + }; + + results = coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "").get(); + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"]); + ASSERT_EQ("1", results["hits"][1]["document"]["id"]); + ASSERT_EQ("2", results["hits"][2]["document"]["id"]); + + sort_fields = { + sort_by("_vector_query(points:([5.0, 5.0]))", "desc"), + }; + + results = coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "").get(); + + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"]); + ASSERT_EQ("1", results["hits"][1]["document"]["id"]); + ASSERT_EQ("0", results["hits"][2]["document"]["id"]); } \ No newline at end of file From dfbb1ebfb43c06f0e84addcc2b0d11b3f4929843 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sun, 15 Oct 2023 14:08:14 +0300 Subject: [PATCH 2/4] Refactor vector index lookup & normalization --- include/field.h | 9 ++++++++- src/collection.cpp | 33 +++++++++++++++++++++++--------- src/index.cpp | 47 +++++++++------------------------------------- 3 files changed, 41 insertions(+), 48 deletions(-) diff --git a/include/field.h b/include/field.h index 0b5d98c8..b6f245eb 100644 --- a/include/field.h +++ b/include/field.h @@ -665,6 +665,13 @@ namespace sort_field_const { static const std::string vector_query = "_vector_query"; } +struct hnsw_index_t; + +struct sort_vector_query_t { + vector_query_t query; + hnsw_index_t* vector_index; +}; + struct sort_by { enum missing_values_t { first, @@ -692,7 +699,7 @@ struct sort_by { missing_values_t missing_values; eval_t eval; - vector_query_t vector_query; + sort_vector_query_t vector_query; sort_by(const std::string & name, const std::string & order): name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0), diff --git a/src/collection.cpp b/src/collection.cpp index a7d993f9..997a47a6 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -17,6 +17,7 @@ #include "thread_local_vars.h" #include "vector_query_ops.h" #include "text_embedder_manager.h" +#include "field.h" const std::string override_t::MATCH_EXACT = "exact"; const std::string override_t::MATCH_CONTAINS = "contains"; @@ -804,20 +805,20 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } - auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, sort_field_std.vector_query, + auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, sort_field_std.vector_query.query, is_wildcard_query, this, true); if(!parse_vector_op.ok()) { return Option(400, parse_vector_op.error()); } - auto vector_field_it = search_schema.find(sort_field_std.vector_query.field_name); + auto vector_field_it = search_schema.find(sort_field_std.vector_query.query.field_name); if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) { - return Option(400, "Field `" + sort_field_std.vector_query.field_name + "` does not have a vector query index."); + return Option(400, "Field `" + sort_field_std.vector_query.query.field_name + "` does not have a vector query index."); } - - - if(sort_field_std.vector_query.values.empty() && embedding_fields.find(sort_field_std.vector_query.field_name) != embedding_fields.end()) { + + if(sort_field_std.vector_query.query.values.empty() && embedding_fields.find(sort_field_std.vector_query.query.field_name) != embedding_fields.end()) { // generate embeddings for the query + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]); if(!embedder_op.ok()) { @@ -842,14 +843,28 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } } - sort_field_std.vector_query.values = embedding_op.embedding; + sort_field_std.vector_query.query.values = embedding_op.embedding; } - if(vector_field_it.value().num_dim != sort_field_std.vector_query.values.size()) { - return Option(400, "Query field `" + sort_field_std.vector_query.field_name + "` must have " + + const auto& vector_index_map = index->_get_vector_index(); + if(vector_index_map.find(sort_field_std.vector_query.query.field_name) == vector_index_map.end()) { + return Option(400, "Field `" + sort_field_std.vector_query.query.field_name + "` does not have a vector index."); + } + + + if(vector_field_it.value().num_dim != sort_field_std.vector_query.query.values.size()) { + return Option(400, "Query field `" + sort_field_std.vector_query.query.field_name + "` must have " + std::to_string(vector_field_it.value().num_dim) + " dimensions."); } + sort_field_std.vector_query.vector_index = vector_index_map.at(sort_field_std.vector_query.query.field_name); + + if(sort_field_std.vector_query.vector_index->distance_type == cosine) { + std::vector normalized_values(sort_field_std.vector_query.query.values.size()); + hnsw_index_t::normalize_vector(sort_field_std.vector_query.query.values, normalized_values); + sort_field_std.vector_query.query.values = normalized_values; + } + sort_field_std.name = actual_field_name; } else { diff --git a/src/index.cpp b/src/index.cpp index c7a9c8a7..65ae04ee 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4399,25 +4399,15 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } else if(field_values[0] == &vector_query_sentinel_value) { scores[0] = float_to_int64_t(2.0f); try { - auto& field_vector_index = vector_index.at(sort_fields[0].vector_query.field_name); - const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); - const auto& dist_func = field_vector_index->space->get_dist_func(); - float dist = 2.0f; - if(field_vector_index->distance_type == cosine) { - std::vector normalized_values(sort_fields[0].vector_query.values.size()); - hnsw_index_t::normalize_vector(sort_fields[0].vector_query.values, normalized_values); - dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); - - } else { - dist = dist_func(sort_fields[0].vector_query.values.data(), values.data(), &field_vector_index->num_dim); - } + const auto& values = sort_fields[0].vector_query.vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = sort_fields[0].vector_query.vector_index->space->get_dist_func(); + float dist = dist_func(sort_fields[0].vector_query.query.values.data(), values.data(), &sort_fields[0].vector_query.vector_index->num_dim); scores[0] = float_to_int64_t(dist); } catch(...) { // probably not found // do nothing } - } else { auto it = field_values[0]->find(seq_id); scores[0] = (it == field_values[0]->end()) ? default_score : it->second; @@ -4479,18 +4469,9 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } else if(field_values[1] == &vector_query_sentinel_value) { scores[1] = float_to_int64_t(2.0f); try { - auto& field_vector_index = vector_index.at(sort_fields[1].vector_query.field_name); - const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); - const auto& dist_func = field_vector_index->space->get_dist_func(); - float dist = 2.0f; - if(field_vector_index->distance_type == cosine) { - std::vector normalized_values(sort_fields[1].vector_query.values.size()); - hnsw_index_t::normalize_vector(sort_fields[1].vector_query.values, normalized_values); - dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); - - } else { - dist = dist_func(sort_fields[1].vector_query.values.data(), values.data(), &field_vector_index->num_dim); - } + const auto& values = sort_fields[1].vector_query.vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = sort_fields[1].vector_query.vector_index->space->get_dist_func(); + float dist = dist_func(sort_fields[1].vector_query.query.values.data(), values.data(), &sort_fields[1].vector_query.vector_index->num_dim); scores[1] = float_to_int64_t(dist); } catch(...) { @@ -4555,25 +4536,15 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } else if(field_values[2] == &vector_query_sentinel_value) { scores[2] = float_to_int64_t(2.0f); try { - auto& field_vector_index = vector_index.at(sort_fields[2].vector_query.field_name); - const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); - const auto& dist_func = field_vector_index->space->get_dist_func(); - float dist = 2.0f; - if(field_vector_index->distance_type == cosine) { - std::vector normalized_values(sort_fields[2].vector_query.values.size()); - hnsw_index_t::normalize_vector(sort_fields[2].vector_query.values, normalized_values); - dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); - - } else { - dist = dist_func(sort_fields[2].vector_query.values.data(), values.data(), &field_vector_index->num_dim); - } + const auto& values = sort_fields[2].vector_query.vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = sort_fields[2].vector_query.vector_index->space->get_dist_func(); + float dist = dist_func(sort_fields[2].vector_query.query.values.data(), values.data(), &sort_fields[2].vector_query.vector_index->num_dim); scores[2] = float_to_int64_t(dist); } catch(...) { // probably not found // do nothing } - } else { auto it = field_values[2]->find(seq_id); scores[2] = (it == field_values[2]->end()) ? default_score : it->second; From 8b01c5cd859ca953523a93c0362c630dd6ab4521 Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sun, 15 Oct 2023 14:27:59 +0300 Subject: [PATCH 3/4] Update error message --- src/collection.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/collection.cpp b/src/collection.cpp index 997a47a6..84ec2184 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -813,7 +813,7 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< auto vector_field_it = search_schema.find(sort_field_std.vector_query.query.field_name); if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) { - return Option(400, "Field `" + sort_field_std.vector_query.query.field_name + "` does not have a vector query index."); + return Option(400, "Could not find a field named `" + sort_field_std.vector_query.query.field_name + "` in vector index."); } if(sort_field_std.vector_query.query.values.empty() && embedding_fields.find(sort_field_std.vector_query.query.field_name) != embedding_fields.end()) { From a44f996a1b8e311832e59806592365579e8f562a Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Sun, 15 Oct 2023 14:34:07 +0300 Subject: [PATCH 4/4] Update `parse_vector_query_str` --- include/vector_query_ops.h | 2 +- src/collection.cpp | 2 +- test/vector_query_ops_test.cpp | 22 +++++++++++----------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index 3fc033d1..ac124ca8 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -33,5 +33,5 @@ public: static Option parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query, const bool is_wildcard_query, const Collection* coll, - const bool allow_empty_query = false); + const bool allow_empty_query); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index 84ec2184..1bc16495 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1250,7 +1250,7 @@ Option Collection::search(std::string raw_query, bool is_wildcard_query = (raw_query == "*" || raw_query.empty()); auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, - is_wildcard_query, this); + is_wildcard_query, this, false); if(!parse_vector_op.ok()) { return Option(400, parse_vector_op.error()); } diff --git a/test/vector_query_ops_test.cpp b/test/vector_query_ops_test.cpp index e2c81e71..d1846ea6 100644 --- a/test/vector_query_ops_test.cpp +++ b/test/vector_query_ops_test.cpp @@ -17,7 +17,7 @@ protected: TEST_F(VectorQueryOpsTest, ParseVectorQueryString) { vector_query_t vector_query; - auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr); + auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr, false); ASSERT_TRUE(parsed.ok()); ASSERT_EQ("vec", vector_query.field_name); ASSERT_EQ(10, vector_query.k); @@ -28,49 +28,49 @@ TEST_F(VectorQueryOpsTest, ParseVectorQueryString) { } vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr, false); ASSERT_TRUE(parsed.ok()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, false, nullptr, false); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error()); // cannot pass both vector and id vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, false, nullptr, false); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string: cannot pass both vector query and `id` parameter.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, false, nullptr, false); ASSERT_TRUE(parsed.ok()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, true, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, true, nullptr, false); ASSERT_TRUE(parsed.ok()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, false, nullptr, false); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, false, nullptr, false); ASSERT_TRUE(parsed.ok()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, false, nullptr, false); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, false, nullptr, false); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); vector_query._reset(); - parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, false, nullptr); + parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, false, nullptr, false); ASSERT_FALSE(parsed.ok()); ASSERT_EQ("Malformed vector query string.", parsed.error()); }