Sort by vector query

This commit is contained in:
ozanarmagan 2023-10-14 02:07:51 +03:00
parent 998b071956
commit 4167fe69c8
8 changed files with 227 additions and 12 deletions

View File

@ -207,8 +207,10 @@ private:
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
bool is_wildcard_query,const bool is_vector_query,
bool is_group_by_query = false) const;
bool is_wildcard_query, const bool is_vector_query,
const std::string& query, bool is_group_by_query = false,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_tries = 2) const;
Option<bool> persist_collection_meta();

View File

@ -11,6 +11,7 @@
#include <tsl/htrie_map.h>
#include "json.hpp"
#include "text_embedder_manager.h"
#include "vector_query_ops.h"
namespace field_types {
// first field value indexed will determine the type
@ -661,6 +662,7 @@ namespace sort_field_const {
static const std::string missing_values = "missing_values";
static const std::string vector_distance = "_vector_distance";
static const std::string vector_query = "_vector_query";
}
struct sort_by {
@ -690,10 +692,11 @@ struct sort_by {
missing_values_t missing_values;
eval_t eval;
vector_query_t vector_query;
sort_by(const std::string & name, const std::string & order):
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
missing_values(normal) {
}
sort_by(const std::string &name, const std::string &order, uint32_t text_match_buckets, int64_t geopoint,
@ -701,7 +704,6 @@ struct sort_by {
name(name), order(order), text_match_buckets(text_match_buckets),
geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision),
missing_values(normal) {
}
sort_by& operator=(const sort_by& other) {

View File

@ -351,6 +351,7 @@ private:
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> vector_distance_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> vector_query_sentinel_value;
// Internal utility functions

View File

@ -32,5 +32,6 @@ class VectorQueryOps {
public:
static Option<bool> parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query,
const bool is_wildcard_query,
const Collection* coll);
const Collection* coll,
const bool allow_empty_query = false);
};

View File

@ -744,7 +744,9 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query,
const bool is_vector_query,
const bool is_group_by_query) const {
const std::string& query, const bool is_group_by_query,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_tries) const {
size_t num_sort_expressions = 0;
@ -793,6 +795,62 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_field_std.name = actual_field_name;
num_sort_expressions++;
} else if(actual_field_name == sort_field_const::vector_query) {
const std::string& vector_query_str = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() - paran_start -
2);
if(vector_query_str.empty()) {
return Option<bool>(400, "The vector query in sort_by is empty.");
}
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, sort_field_std.vector_query,
is_wildcard_query, this, true);
if(!parse_vector_op.ok()) {
return Option<bool>(400, parse_vector_op.error());
}
auto vector_field_it = search_schema.find(sort_field_std.vector_query.field_name);
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
return Option<bool>(400, "Field `" + sort_field_std.vector_query.field_name + "` does not have a vector query index.");
}
if(sort_field_std.vector_query.values.empty() && embedding_fields.find(sort_field_std.vector_query.field_name) != embedding_fields.end()) {
// generate embeddings for the query
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
if(!embedder_op.ok()) {
return Option<bool>(embedder_op.code(), embedder_op.error());
}
auto embedder = embedder_op.get();
if(embedder->is_remote() && remote_embedding_num_tries == 0) {
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
return Option<bool>(400, error);
}
std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + query;
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
if(!embedding_op.success) {
if(!embedding_op.error["error"].get<std::string>().empty()) {
return Option<bool>(400, embedding_op.error["error"].get<std::string>());
} else {
return Option<bool>(400, embedding_op.error.dump());
}
}
sort_field_std.vector_query.values = embedding_op.embedding;
}
if(vector_field_it.value().num_dim != sort_field_std.vector_query.values.size()) {
return Option<bool>(400, "Query field `" + sort_field_std.vector_query.field_name + "` must have " +
std::to_string(vector_field_it.value().num_dim) + " dimensions.");
}
sort_field_std.name = actual_field_name;
} else {
if(field_it == search_schema.end()) {
@ -918,7 +976,8 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
}
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) {
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance &&
sort_field_std.name != sort_field_const::vector_query) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
@ -1537,7 +1596,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
if(curated_sort_by.empty()) {
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
@ -1549,7 +1608,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
@ -5048,4 +5107,4 @@ void Collection::remove_embedding_field(const std::string& field_name) {
tsl::htrie_map<char, field> Collection::get_embedding_fields_unsafe() {
return embedding_fields;
}
}

View File

@ -43,6 +43,7 @@ spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::vector_distance_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::vector_query_sentinel_value;
struct token_posting_t {
uint32_t token_id;
@ -4395,6 +4396,28 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[0] = int64_t(found);
} else if(field_values[0] == &vector_distance_sentinel_value) {
scores[0] = float_to_int64_t(vector_distance);
} else if(field_values[0] == &vector_query_sentinel_value) {
scores[0] = float_to_int64_t(2.0f);
try {
auto& field_vector_index = vector_index.at(sort_fields[0].vector_query.field_name);
const auto& values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = field_vector_index->space->get_dist_func();
float dist = 2.0f;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_values(sort_fields[0].vector_query.values.size());
hnsw_index_t::normalize_vector(sort_fields[0].vector_query.values, normalized_values);
dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim);
} else {
dist = dist_func(sort_fields[0].vector_query.values.data(), values.data(), &field_vector_index->num_dim);
}
scores[0] = float_to_int64_t(dist);
} catch(...) {
// probably not found
// do nothing
}
} else {
auto it = field_values[0]->find(seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
@ -4453,6 +4476,28 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[1] = int64_t(found);
} else if(field_values[1] == &vector_distance_sentinel_value) {
scores[1] = float_to_int64_t(vector_distance);
} else if(field_values[1] == &vector_query_sentinel_value) {
scores[1] = float_to_int64_t(2.0f);
try {
auto& field_vector_index = vector_index.at(sort_fields[1].vector_query.field_name);
const auto& values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = field_vector_index->space->get_dist_func();
float dist = 2.0f;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_values(sort_fields[1].vector_query.values.size());
hnsw_index_t::normalize_vector(sort_fields[1].vector_query.values, normalized_values);
dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim);
} else {
dist = dist_func(sort_fields[1].vector_query.values.data(), values.data(), &field_vector_index->num_dim);
}
scores[1] = float_to_int64_t(dist);
} catch(...) {
// probably not found
// do nothing
}
} else {
auto it = field_values[1]->find(seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
@ -4507,6 +4552,28 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[2] = int64_t(found);
} else if(field_values[2] == &vector_distance_sentinel_value) {
scores[2] = float_to_int64_t(vector_distance);
} else if(field_values[2] == &vector_query_sentinel_value) {
scores[2] = float_to_int64_t(2.0f);
try {
auto& field_vector_index = vector_index.at(sort_fields[2].vector_query.field_name);
const auto& values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = field_vector_index->space->get_dist_func();
float dist = 2.0f;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_values(sort_fields[2].vector_query.values.size());
hnsw_index_t::normalize_vector(sort_fields[2].vector_query.values, normalized_values);
dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim);
} else {
dist = dist_func(sort_fields[2].vector_query.values.data(), values.data(), &field_vector_index->num_dim);
}
scores[2] = float_to_int64_t(dist);
} catch(...) {
// probably not found
// do nothing
}
} else {
auto it = field_values[2]->find(seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
@ -5238,6 +5305,8 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
result.docs = nullptr;
} else if(sort_fields_std[i].name == sort_field_const::vector_distance) {
field_values[i] = &vector_distance_sentinel_value;
} else if(sort_fields_std[i].name == sort_field_const::vector_query) {
field_values[i] = &vector_query_sentinel_value;
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
geopoint_indices.push_back(i);

View File

@ -5,7 +5,8 @@
Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str,
vector_query_t& vector_query,
const bool is_wildcard_query,
const Collection* coll) {
const Collection* coll,
const bool allow_empty_query) {
// FORMAT:
// field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)
size_t i = 0;
@ -72,7 +73,7 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_qu
if(i == vector_query_str.size()-1) {
// missing params
if(vector_query.values.empty()) {
if(vector_query.values.empty() && !allow_empty_query) {
// when query values are missing, at least the `id` parameter must be present
return Option<bool>(400, "When a vector query value is empty, an `id` parameter must be present.");
}

View File

@ -2387,4 +2387,84 @@ TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) {
ASSERT_FALSE(results.ok());
ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error());
}
// Exercises `_vector_query(...)` as a sort_by expression on a keyword search:
// the same query is run with no explicit sort, then sorted by the vector query
// ascending and descending, and the resulting hit order is checked each time.
TEST_F(CollectionSortingTest, TestSortByVectorQuery) {
    std::string schema_str = R"(
{
"name": "coll1",
"fields": [
{"name": "name", "type": "string" },
{"name": "points", "type": "float[]", "num_dim": 2}
]
}
)";

    nlohmann::json schema_json = nlohmann::json::parse(schema_str);
    auto create_op = collectionManager.create_collection(schema_json);
    ASSERT_TRUE(create_op.ok());
    auto coll = create_op.get();

    // One document per vector; the assertions below expect the ids "0", "1", "2".
    const std::vector<std::vector<float>> doc_vectors = {
        {7.0, 8.0},
        {8.0, 15.0},
        {5.0, 12.0},
    };

    size_t doc_num = 0;
    for(const auto& doc_vector : doc_vectors) {
        nlohmann::json doc;
        doc["name"] = "Title " + std::to_string(doc_num);
        doc["points"] = doc_vector;
        ASSERT_TRUE(coll->add(doc.dump()).ok());
        ++doc_num;
    }

    // Every search in this test differs only in its sort fields, so the long argument list lives here once.
    auto search_sorted_by = [&](std::vector<sort_by> sort_fields) {
        return coll->search("title", {"name"}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                            spp::sparse_hash_set<std::string>(),
                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                            "", 10, {}, {}, {}, 0,
                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                            4, {off}, 32767, 32767, 2,
                            false, true, "").get();
    };

    // No explicit sort fields.
    auto default_results = search_sorted_by({});

    ASSERT_EQ(3, default_results["hits"].size());
    ASSERT_EQ("2", default_results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", default_results["hits"][1]["document"]["id"]);
    ASSERT_EQ("0", default_results["hits"][2]["document"]["id"]);

    // Ascending by the vector query.
    auto asc_results = search_sorted_by({
        sort_by("_vector_query(points:([5.0, 5.0]))", "asc"),
    });

    ASSERT_EQ(3, asc_results["hits"].size());
    ASSERT_EQ("0", asc_results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", asc_results["hits"][1]["document"]["id"]);
    ASSERT_EQ("2", asc_results["hits"][2]["document"]["id"]);

    // Descending by the same vector query: the order is expected to be reversed.
    auto desc_results = search_sorted_by({
        sort_by("_vector_query(points:([5.0, 5.0]))", "desc"),
    });

    ASSERT_EQ(3, desc_results["hits"].size());
    ASSERT_EQ("2", desc_results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", desc_results["hits"][1]["document"]["id"]);
    ASSERT_EQ("0", desc_results["hits"][2]["document"]["id"]);
}