diff --git a/include/collection.h b/include/collection.h index 90586fc1..a1f5fa87 100644 --- a/include/collection.h +++ b/include/collection.h @@ -207,8 +207,10 @@ private: Option validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, - bool is_wildcard_query,const bool is_vector_query, - bool is_group_by_query = false) const; + bool is_wildcard_query, const bool is_vector_query, + const std::string& query, bool is_group_by_query = false, + const size_t remote_embedding_timeout_ms = 30000, + const size_t remote_embedding_num_tries = 2) const; Option persist_collection_meta(); diff --git a/include/field.h b/include/field.h index 6d262eba..0b5d98c8 100644 --- a/include/field.h +++ b/include/field.h @@ -11,6 +11,7 @@ #include #include "json.hpp" #include "text_embedder_manager.h" +#include "vector_query_ops.h" namespace field_types { // first field value indexed will determine the type @@ -661,6 +662,7 @@ namespace sort_field_const { static const std::string missing_values = "missing_values"; static const std::string vector_distance = "_vector_distance"; + static const std::string vector_query = "_vector_query"; } struct sort_by { @@ -690,10 +692,11 @@ struct sort_by { missing_values_t missing_values; eval_t eval; + vector_query_t vector_query; + sort_by(const std::string & name, const std::string & order): name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0), missing_values(normal) { - } sort_by(const std::string &name, const std::string &order, uint32_t text_match_buckets, int64_t geopoint, @@ -701,7 +704,6 @@ struct sort_by { name(name), order(order), text_match_buckets(text_match_buckets), geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision), missing_values(normal) { - } sort_by& operator=(const sort_by& other) { diff --git a/include/index.h b/include/index.h index a987f3f2..faa0d745 100644 --- a/include/index.h +++ b/include/index.h @@ -351,6 +351,7 @@ private: static spp::sparse_hash_map geo_sentinel_value; static spp::sparse_hash_map str_sentinel_value; static spp::sparse_hash_map vector_distance_sentinel_value; + static spp::sparse_hash_map vector_query_sentinel_value; // Internal utility functions diff --git a/include/vector_query_ops.h b/include/vector_query_ops.h index b161bd3e..3fc033d1 100644 --- a/include/vector_query_ops.h +++ b/include/vector_query_ops.h @@ -32,5 +32,6 @@ class VectorQueryOps { public: static Option parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query, const bool is_wildcard_query, - const Collection* coll); + const Collection* coll, + const bool allow_empty_query = false); }; \ No newline at end of file diff --git a/src/collection.cpp b/src/collection.cpp index adda2fc5..a7d993f9 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -744,7 +744,9 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< std::vector& sort_fields_std, const bool is_wildcard_query, const bool is_vector_query, - const bool is_group_by_query) const { + const std::string& query, const bool is_group_by_query, + const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_num_tries) const { size_t num_sort_expressions = 0; @@ -793,6 +795,62 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_field_std.name = actual_field_name; num_sort_expressions++; + } else if(actual_field_name == sort_field_const::vector_query) { + const std::string& vector_query_str = sort_field_std.name.substr(paran_start + 1, + sort_field_std.name.size() - paran_start - + 2); + if(vector_query_str.empty()) { + return Option(400, "The vector query in sort_by is empty."); + } + + + auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, sort_field_std.vector_query, + is_wildcard_query, this, true); + if(!parse_vector_op.ok()) { + return Option(400, parse_vector_op.error()); + } + + auto vector_field_it = search_schema.find(sort_field_std.vector_query.field_name); + if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) { + return Option(400, "Field `" + sort_field_std.vector_query.field_name + "` does not have a vector query index."); + } + + + if(sort_field_std.vector_query.values.empty() && embedding_fields.find(sort_field_std.vector_query.field_name) != embedding_fields.end()) { + // generate embeddings for the query + TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance(); + auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]); + if(!embedder_op.ok()) { + return Option(embedder_op.code(), embedder_op.error()); + } + + auto embedder = embedder_op.get(); + + if(embedder->is_remote() && remote_embedding_num_tries == 0) { + std::string error = "`remote_embedding_num_tries` must be greater than 0."; + return Option(400, error); + } + + std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + query; + auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries); + + if(!embedding_op.success) { + if(!embedding_op.error["error"].get().empty()) { + return Option(400, embedding_op.error["error"].get()); + } else { + return Option(400, embedding_op.error.dump()); + } + } + + sort_field_std.vector_query.values = embedding_op.embedding; + } + + if(vector_field_it.value().num_dim != sort_field_std.vector_query.values.size()) { + return Option(400, "Query field `" + sort_field_std.vector_query.field_name + "` must have " + + std::to_string(vector_field_it.value().num_dim) + " dimensions."); + } + + sort_field_std.name = actual_field_name; } else { if(field_it == search_schema.end()) { @@ -918,7 +976,8 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval && - sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) { + sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance && + sort_field_std.name != sort_field_const::vector_query) { const auto field_it = search_schema.find(sort_field_std.name); if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { @@ -1537,7 +1596,7 @@ Option Collection::search(std::string raw_query, if(curated_sort_by.empty()) { auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, - sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1549,7 +1608,7 @@ Option Collection::search(std::string raw_query, } auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, - sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -5048,4 +5107,4 @@ void Collection::remove_embedding_field(const std::string& field_name) { tsl::htrie_map Collection::get_embedding_fields_unsafe() { return embedding_fields; -} \ No newline at end of file +} diff --git a/src/index.cpp b/src/index.cpp index 8dc5ebb8..c7a9c8a7 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -43,6 +43,7 @@ spp::sparse_hash_map Index::eval_sentinel_value; spp::sparse_hash_map Index::geo_sentinel_value; spp::sparse_hash_map Index::str_sentinel_value; spp::sparse_hash_map Index::vector_distance_sentinel_value; +spp::sparse_hash_map Index::vector_query_sentinel_value; struct token_posting_t { uint32_t token_id; @@ -4395,6 +4396,28 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[0] = int64_t(found); } else if(field_values[0] == &vector_distance_sentinel_value) { scores[0] = float_to_int64_t(vector_distance); + } else if(field_values[0] == &vector_query_sentinel_value) { + scores[0] = float_to_int64_t(2.0f); + try { + auto& field_vector_index = vector_index.at(sort_fields[0].vector_query.field_name); + const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = field_vector_index->space->get_dist_func(); + float dist = 2.0f; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_values(sort_fields[0].vector_query.values.size()); + hnsw_index_t::normalize_vector(sort_fields[0].vector_query.values, normalized_values); + dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); + + } else { + dist = dist_func(sort_fields[0].vector_query.values.data(), values.data(), &field_vector_index->num_dim); + } + + scores[0] = float_to_int64_t(dist); + } catch(...) { + // probably not found + // do nothing + } + } else { auto it = field_values[0]->find(seq_id); scores[0] = (it == field_values[0]->end()) ? default_score : it->second; @@ -4453,6 +4476,28 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[1] = int64_t(found); } else if(field_values[1] == &vector_distance_sentinel_value) { scores[1] = float_to_int64_t(vector_distance); + } else if(field_values[1] == &vector_query_sentinel_value) { + scores[1] = float_to_int64_t(2.0f); + try { + auto& field_vector_index = vector_index.at(sort_fields[1].vector_query.field_name); + const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = field_vector_index->space->get_dist_func(); + float dist = 2.0f; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_values(sort_fields[1].vector_query.values.size()); + hnsw_index_t::normalize_vector(sort_fields[1].vector_query.values, normalized_values); + dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); + + } else { + dist = dist_func(sort_fields[1].vector_query.values.data(), values.data(), &field_vector_index->num_dim); + } + + scores[1] = float_to_int64_t(dist); + } catch(...) { + // probably not found + // do nothing + } + } else { auto it = field_values[1]->find(seq_id); scores[1] = (it == field_values[1]->end()) ? default_score : it->second; @@ -4507,6 +4552,28 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[2] = int64_t(found); } else if(field_values[2] == &vector_distance_sentinel_value) { scores[2] = float_to_int64_t(vector_distance); + } else if(field_values[2] == &vector_query_sentinel_value) { + scores[2] = float_to_int64_t(2.0f); + try { + auto& field_vector_index = vector_index.at(sort_fields[2].vector_query.field_name); + const auto& values = field_vector_index->vecdex->getDataByLabel(seq_id); + const auto& dist_func = field_vector_index->space->get_dist_func(); + float dist = 2.0f; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_values(sort_fields[2].vector_query.values.size()); + hnsw_index_t::normalize_vector(sort_fields[2].vector_query.values, normalized_values); + dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim); + + } else { + dist = dist_func(sort_fields[2].vector_query.values.data(), values.data(), &field_vector_index->num_dim); + } + + scores[2] = float_to_int64_t(dist); + } catch(...) { + // probably not found + // do nothing + } + } else { auto it = field_values[2]->find(seq_id); scores[2] = (it == field_values[2]->end()) ? default_score : it->second; @@ -5238,6 +5305,8 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint result.docs = nullptr; } else if(sort_fields_std[i].name == sort_field_const::vector_distance) { field_values[i] = &vector_distance_sentinel_value; + } else if(sort_fields_std[i].name == sort_field_const::vector_query) { + field_values[i] = &vector_query_sentinel_value; } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/src/vector_query_ops.cpp b/src/vector_query_ops.cpp index dba9d27d..74189948 100644 --- a/src/vector_query_ops.cpp +++ b/src/vector_query_ops.cpp @@ -5,7 +5,8 @@ Option VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query, const bool is_wildcard_query, - const Collection* coll) { + const Collection* coll, + const bool allow_empty_query) { // FORMAT: // field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10) size_t i = 0; @@ -72,7 +73,7 @@ Option VectorQueryOps::parse_vector_query_str(const std::string& vector_qu if(i == vector_query_str.size()-1) { // missing params - if(vector_query.values.empty()) { + if(vector_query.values.empty() && !allow_empty_query) { // when query values are missing, atleast the `id` parameter must be present return Option(400, "When a vector query value is empty, an `id` parameter must be present."); } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 95c255e9..8d51a079 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -2387,4 +2387,84 @@ TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) { ASSERT_FALSE(results.ok()); ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error()); +} + + +TEST_F(CollectionSortingTest, TestSortByVectorQuery) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "name", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + auto create_coll = collectionManager.create_collection(schema); + ASSERT_TRUE(create_coll.ok()); + + auto coll = create_coll.get(); + + std::vector> points = { + {7.0, 8.0}, + {8.0, 15.0}, + {5.0, 12.0}, + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["name"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll->add(doc.dump()).ok()); + } + + std::vector sort_fields = {}; + + auto results = coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "").get(); + + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"]); + ASSERT_EQ("1", results["hits"][1]["document"]["id"]); + ASSERT_EQ("0", results["hits"][2]["document"]["id"]); + + sort_fields = { + sort_by("_vector_query(points:([5.0, 5.0]))", "asc"), + }; + + results = coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "").get(); + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"]); + ASSERT_EQ("1", results["hits"][1]["document"]["id"]); + ASSERT_EQ("2", results["hits"][2]["document"]["id"]); + + sort_fields = { + sort_by("_vector_query(points:([5.0, 5.0]))", "desc"), + }; + + results = coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "").get(); + + ASSERT_EQ(3, results["hits"].size()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"]); + ASSERT_EQ("1", results["hits"][1]["document"]["id"]); + ASSERT_EQ("0", results["hits"][2]["document"]["id"]); } \ No newline at end of file