mirror of
https://github.com/typesense/typesense.git
synced 2025-05-22 23:06:30 +08:00
Sort by vector query
This commit is contained in:
parent
998b071956
commit
4167fe69c8
@ -207,8 +207,10 @@ private:
|
||||
|
||||
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
bool is_wildcard_query,const bool is_vector_query,
|
||||
bool is_group_by_query = false) const;
|
||||
bool is_wildcard_query, const bool is_vector_query,
|
||||
const std::string& query, bool is_group_by_query = false,
|
||||
const size_t remote_embedding_timeout_ms = 30000,
|
||||
const size_t remote_embedding_num_tries = 2) const;
|
||||
|
||||
|
||||
Option<bool> persist_collection_meta();
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <tsl/htrie_map.h>
|
||||
#include "json.hpp"
|
||||
#include "text_embedder_manager.h"
|
||||
#include "vector_query_ops.h"
|
||||
|
||||
namespace field_types {
|
||||
// first field value indexed will determine the type
|
||||
@ -661,6 +662,7 @@ namespace sort_field_const {
|
||||
static const std::string missing_values = "missing_values";
|
||||
|
||||
static const std::string vector_distance = "_vector_distance";
|
||||
static const std::string vector_query = "_vector_query";
|
||||
}
|
||||
|
||||
struct sort_by {
|
||||
@ -690,10 +692,11 @@ struct sort_by {
|
||||
missing_values_t missing_values;
|
||||
eval_t eval;
|
||||
|
||||
vector_query_t vector_query;
|
||||
|
||||
sort_by(const std::string & name, const std::string & order):
|
||||
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
|
||||
missing_values(normal) {
|
||||
|
||||
}
|
||||
|
||||
sort_by(const std::string &name, const std::string &order, uint32_t text_match_buckets, int64_t geopoint,
|
||||
@ -701,7 +704,6 @@ struct sort_by {
|
||||
name(name), order(order), text_match_buckets(text_match_buckets),
|
||||
geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision),
|
||||
missing_values(normal) {
|
||||
|
||||
}
|
||||
|
||||
sort_by& operator=(const sort_by& other) {
|
||||
|
@ -351,6 +351,7 @@ private:
|
||||
// Sentinel maps. For the vector sentinels, the sorting code stores the ADDRESS of the map in
// `field_values[i]` (populate_sort_mapping) and later compares against that address
// (compute_sort_scores) to recognise a special sort key; it does not look anything up in the map.
// NOTE(review): geo/str sentinels presumably follow the same pattern — confirm in index.cpp.
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> vector_distance_sentinel_value;
|
||||
// Marks a `_vector_query(...)` sort key.
static spp::sparse_hash_map<uint32_t, int64_t> vector_query_sentinel_value;
|
||||
|
||||
// Internal utility functions
|
||||
|
||||
|
@ -32,5 +32,6 @@ class VectorQueryOps {
|
||||
public:
|
||||
static Option<bool> parse_vector_query_str(const std::string& vector_query_str, vector_query_t& vector_query,
|
||||
const bool is_wildcard_query,
|
||||
const Collection* coll);
|
||||
const Collection* coll,
|
||||
const bool allow_empty_query = false);
|
||||
};
|
@ -744,7 +744,9 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
const bool is_wildcard_query,
|
||||
const bool is_vector_query,
|
||||
const bool is_group_by_query) const {
|
||||
const std::string& query, const bool is_group_by_query,
|
||||
const size_t remote_embedding_timeout_ms,
|
||||
const size_t remote_embedding_num_tries) const {
|
||||
|
||||
size_t num_sort_expressions = 0;
|
||||
|
||||
@ -793,6 +795,62 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
|
||||
sort_field_std.name = actual_field_name;
|
||||
num_sort_expressions++;
|
||||
} else if(actual_field_name == sort_field_const::vector_query) {
|
||||
const std::string& vector_query_str = sort_field_std.name.substr(paran_start + 1,
|
||||
sort_field_std.name.size() - paran_start -
|
||||
2);
|
||||
if(vector_query_str.empty()) {
|
||||
return Option<bool>(400, "The vector query in sort_by is empty.");
|
||||
}
|
||||
|
||||
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, sort_field_std.vector_query,
|
||||
is_wildcard_query, this, true);
|
||||
if(!parse_vector_op.ok()) {
|
||||
return Option<bool>(400, parse_vector_op.error());
|
||||
}
|
||||
|
||||
auto vector_field_it = search_schema.find(sort_field_std.vector_query.field_name);
|
||||
if(vector_field_it == search_schema.end() || vector_field_it.value().num_dim == 0) {
|
||||
return Option<bool>(400, "Field `" + sort_field_std.vector_query.field_name + "` does not have a vector query index.");
|
||||
}
|
||||
|
||||
|
||||
if(sort_field_std.vector_query.values.empty() && embedding_fields.find(sort_field_std.vector_query.field_name) != embedding_fields.end()) {
|
||||
// generate embeddings for the query
|
||||
TextEmbedderManager& embedder_manager = TextEmbedderManager::get_instance();
|
||||
auto embedder_op = embedder_manager.get_text_embedder(vector_field_it.value().embed[fields::model_config]);
|
||||
if(!embedder_op.ok()) {
|
||||
return Option<bool>(embedder_op.code(), embedder_op.error());
|
||||
}
|
||||
|
||||
auto embedder = embedder_op.get();
|
||||
|
||||
if(embedder->is_remote() && remote_embedding_num_tries == 0) {
|
||||
std::string error = "`remote_embedding_num_tries` must be greater than 0.";
|
||||
return Option<bool>(400, error);
|
||||
}
|
||||
|
||||
std::string embed_query = embedder_manager.get_query_prefix(vector_field_it.value().embed[fields::model_config]) + query;
|
||||
auto embedding_op = embedder->Embed(embed_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
|
||||
|
||||
if(!embedding_op.success) {
|
||||
if(!embedding_op.error["error"].get<std::string>().empty()) {
|
||||
return Option<bool>(400, embedding_op.error["error"].get<std::string>());
|
||||
} else {
|
||||
return Option<bool>(400, embedding_op.error.dump());
|
||||
}
|
||||
}
|
||||
|
||||
sort_field_std.vector_query.values = embedding_op.embedding;
|
||||
}
|
||||
|
||||
if(vector_field_it.value().num_dim != sort_field_std.vector_query.values.size()) {
|
||||
return Option<bool>(400, "Query field `" + sort_field_std.vector_query.field_name + "` must have " +
|
||||
std::to_string(vector_field_it.value().num_dim) + " dimensions.");
|
||||
}
|
||||
|
||||
sort_field_std.name = actual_field_name;
|
||||
|
||||
} else {
|
||||
if(field_it == search_schema.end()) {
|
||||
@ -918,7 +976,8 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
}
|
||||
|
||||
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
|
||||
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) {
|
||||
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance &&
|
||||
sort_field_std.name != sort_field_const::vector_query) {
|
||||
|
||||
const auto field_it = search_schema.find(sort_field_std.name);
|
||||
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
|
||||
@ -1537,7 +1596,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
|
||||
if(curated_sort_by.empty()) {
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
|
||||
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
|
||||
sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
@ -1549,7 +1608,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
|
||||
sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query);
|
||||
sort_fields_std, is_wildcard_query, is_vector_query, raw_query, is_group_by_query, remote_embedding_timeout_ms, remote_embedding_num_tries);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
@ -5048,4 +5107,4 @@ void Collection::remove_embedding_field(const std::string& field_name) {
|
||||
|
||||
// Returns the collection's `embedding_fields` map by value, i.e. the caller receives a copy.
// NOTE(review): the `_unsafe` suffix presumably means no lock is acquired here and the caller is
// expected to already hold the collection lock — confirm against the locked accessors.
tsl::htrie_map<char, field> Collection::get_embedding_fields_unsafe() {
|
||||
return embedding_fields;
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
|
||||
// Out-of-class definitions of Index's static sentinel maps. For the vector sentinels the sort
// code only compares `field_values[i]` against the map's address (see compute_sort_scores /
// populate_sort_mapping); the map contents are not consulted there.
// NOTE(review): geo/str sentinels presumably work the same way — confirm.
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::vector_distance_sentinel_value;
|
||||
// Address marker for `_vector_query(...)` sort keys.
spp::sparse_hash_map<uint32_t, int64_t> Index::vector_query_sentinel_value;
|
||||
|
||||
struct token_posting_t {
|
||||
uint32_t token_id;
|
||||
@ -4395,6 +4396,28 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
scores[0] = int64_t(found);
|
||||
} else if(field_values[0] == &vector_distance_sentinel_value) {
|
||||
scores[0] = float_to_int64_t(vector_distance);
|
||||
} else if(field_values[0] == &vector_query_sentinel_value) {
|
||||
scores[0] = float_to_int64_t(2.0f);
|
||||
try {
|
||||
auto& field_vector_index = vector_index.at(sort_fields[0].vector_query.field_name);
|
||||
const auto& values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
|
||||
const auto& dist_func = field_vector_index->space->get_dist_func();
|
||||
float dist = 2.0f;
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_values(sort_fields[0].vector_query.values.size());
|
||||
hnsw_index_t::normalize_vector(sort_fields[0].vector_query.values, normalized_values);
|
||||
dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim);
|
||||
|
||||
} else {
|
||||
dist = dist_func(sort_fields[0].vector_query.values.data(), values.data(), &field_vector_index->num_dim);
|
||||
}
|
||||
|
||||
scores[0] = float_to_int64_t(dist);
|
||||
} catch(...) {
|
||||
// probably not found
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} else {
|
||||
auto it = field_values[0]->find(seq_id);
|
||||
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
|
||||
@ -4453,6 +4476,28 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
scores[1] = int64_t(found);
|
||||
} else if(field_values[1] == &vector_distance_sentinel_value) {
|
||||
scores[1] = float_to_int64_t(vector_distance);
|
||||
} else if(field_values[1] == &vector_query_sentinel_value) {
|
||||
scores[1] = float_to_int64_t(2.0f);
|
||||
try {
|
||||
auto& field_vector_index = vector_index.at(sort_fields[1].vector_query.field_name);
|
||||
const auto& values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
|
||||
const auto& dist_func = field_vector_index->space->get_dist_func();
|
||||
float dist = 2.0f;
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_values(sort_fields[1].vector_query.values.size());
|
||||
hnsw_index_t::normalize_vector(sort_fields[1].vector_query.values, normalized_values);
|
||||
dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim);
|
||||
|
||||
} else {
|
||||
dist = dist_func(sort_fields[1].vector_query.values.data(), values.data(), &field_vector_index->num_dim);
|
||||
}
|
||||
|
||||
scores[1] = float_to_int64_t(dist);
|
||||
} catch(...) {
|
||||
// probably not found
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} else {
|
||||
auto it = field_values[1]->find(seq_id);
|
||||
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
|
||||
@ -4507,6 +4552,28 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
scores[2] = int64_t(found);
|
||||
} else if(field_values[2] == &vector_distance_sentinel_value) {
|
||||
scores[2] = float_to_int64_t(vector_distance);
|
||||
} else if(field_values[2] == &vector_query_sentinel_value) {
|
||||
scores[2] = float_to_int64_t(2.0f);
|
||||
try {
|
||||
auto& field_vector_index = vector_index.at(sort_fields[2].vector_query.field_name);
|
||||
const auto& values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
|
||||
const auto& dist_func = field_vector_index->space->get_dist_func();
|
||||
float dist = 2.0f;
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_values(sort_fields[2].vector_query.values.size());
|
||||
hnsw_index_t::normalize_vector(sort_fields[2].vector_query.values, normalized_values);
|
||||
dist = dist_func(normalized_values.data(), values.data(), &field_vector_index->num_dim);
|
||||
|
||||
} else {
|
||||
dist = dist_func(sort_fields[2].vector_query.values.data(), values.data(), &field_vector_index->num_dim);
|
||||
}
|
||||
|
||||
scores[2] = float_to_int64_t(dist);
|
||||
} catch(...) {
|
||||
// probably not found
|
||||
// do nothing
|
||||
}
|
||||
|
||||
} else {
|
||||
auto it = field_values[2]->find(seq_id);
|
||||
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
|
||||
@ -5238,6 +5305,8 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
|
||||
result.docs = nullptr;
|
||||
} else if(sort_fields_std[i].name == sort_field_const::vector_distance) {
|
||||
field_values[i] = &vector_distance_sentinel_value;
|
||||
} else if(sort_fields_std[i].name == sort_field_const::vector_query) {
|
||||
field_values[i] = &vector_query_sentinel_value;
|
||||
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
|
||||
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
|
||||
geopoint_indices.push_back(i);
|
||||
|
@ -5,7 +5,8 @@
|
||||
Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str,
|
||||
vector_query_t& vector_query,
|
||||
const bool is_wildcard_query,
|
||||
const Collection* coll) {
|
||||
const Collection* coll,
|
||||
const bool allow_empty_query) {
|
||||
// FORMAT:
|
||||
// field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)
|
||||
size_t i = 0;
|
||||
@ -72,7 +73,7 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_qu
|
||||
|
||||
if(i == vector_query_str.size()-1) {
|
||||
// missing params
|
||||
if(vector_query.values.empty()) {
|
||||
if(vector_query.values.empty() && !allow_empty_query) {
|
||||
// when query values are missing, atleast the `id` parameter must be present
|
||||
return Option<bool>(400, "When a vector query value is empty, an `id` parameter must be present.");
|
||||
}
|
||||
|
@ -2387,4 +2387,84 @@ TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) {
|
||||
ASSERT_FALSE(results.ok());
|
||||
|
||||
ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error());
|
||||
}
|
||||
|
||||
|
||||
// Verifies `sort_by=_vector_query(<field>:([...]))`: hits of a keyword search must be ordered by
// their distance to the vector given in the sort clause, for both `asc` and `desc`.
//
// Every search result is checked with `.ok()` before it is unwrapped, so that a failing search
// reports its own error message instead of showing up as a confusing hit-count mismatch.
TEST_F(CollectionSortingTest, TestSortByVectorQuery) {
    std::string coll_schema = R"(
        {
            "name": "coll1",
            "fields": [
                {"name": "name", "type": "string" },
                {"name": "points", "type": "float[]", "num_dim": 2}
            ]
        }
    )";

    nlohmann::json schema = nlohmann::json::parse(coll_schema);
    auto create_coll = collectionManager.create_collection(schema);
    ASSERT_TRUE(create_coll.ok());

    auto coll = create_coll.get();

    // One 2-d vector per document; document ids are assigned in insertion order ("0", "1", "2").
    std::vector<std::vector<float>> points = {
        {7.0, 8.0},
        {8.0, 15.0},
        {5.0, 12.0},
    };

    for(size_t i = 0; i < points.size(); i++) {
        nlohmann::json doc;
        doc["name"] = "Title " + std::to_string(i);
        doc["points"] = points[i];
        ASSERT_TRUE(coll->add(doc.dump()).ok());
    }

    // The three searches differ only in their sort clause, so the long argument list lives in
    // exactly one place. The Option is returned un-unwrapped so the caller can assert on it.
    const auto run_search = [coll](const std::vector<sort_by>& sort_spec) {
        return coll->search("title", {"name"}, "", {}, sort_spec, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                            spp::sparse_hash_set<std::string>(),
                            spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                            "", 10, {}, {}, {}, 0,
                            "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
                            4, {off}, 32767, 32767, 2,
                            false, true, "");
    };

    // Baseline without any sort clause: hits are expected in descending id order.
    std::vector<sort_by> sort_fields = {};

    auto default_op = run_search(sort_fields);
    ASSERT_TRUE(default_op.ok()) << default_op.error();
    auto results = default_op.get();

    ASSERT_EQ(3, results["hits"].size());
    ASSERT_EQ("2", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
    ASSERT_EQ("0", results["hits"][2]["document"]["id"]);

    // Ascending by distance to [5.0, 5.0]: closest document first.
    sort_fields = {
        sort_by("_vector_query(points:([5.0, 5.0]))", "asc"),
    };

    auto asc_op = run_search(sort_fields);
    ASSERT_TRUE(asc_op.ok()) << asc_op.error();
    results = asc_op.get();

    ASSERT_EQ(3, results["hits"].size());
    ASSERT_EQ("0", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
    ASSERT_EQ("2", results["hits"][2]["document"]["id"]);

    // Descending: the exact reverse of the ascending order.
    sort_fields = {
        sort_by("_vector_query(points:([5.0, 5.0]))", "desc"),
    };

    auto desc_op = run_search(sort_fields);
    ASSERT_TRUE(desc_op.ok()) << desc_op.error();
    results = desc_op.get();

    ASSERT_EQ(3, results["hits"].size());
    ASSERT_EQ("2", results["hits"][0]["document"]["id"]);
    ASSERT_EQ("1", results["hits"][1]["document"]["id"]);
    ASSERT_EQ("0", results["hits"][2]["document"]["id"]);
}
|
Loading…
x
Reference in New Issue
Block a user