mirror of
https://github.com/typesense/typesense.git
synced 2025-05-18 12:42:50 +08:00
Allow vector query to pass a document ID.
This commit is contained in:
parent
a9b926e24b
commit
bf0f7430a0
@ -181,8 +181,6 @@ public:
|
||||
|
||||
static bool parse_sort_by_str(std::string sort_by_str, std::vector<sort_by>& sort_fields);
|
||||
|
||||
static bool parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query);
|
||||
|
||||
// symlinks
|
||||
Option<std::string> resolve_symlink(const std::string & symlink_name) const;
|
||||
|
||||
|
@ -609,20 +609,6 @@ struct sort_by {
|
||||
}
|
||||
};
|
||||
|
||||
struct vector_query_t {
|
||||
std::string field_name;
|
||||
size_t k = 0;
|
||||
size_t flat_search_cutoff = 0;
|
||||
std::vector<float> values;
|
||||
|
||||
void _reset() {
|
||||
// used for testing only
|
||||
field_name.clear();
|
||||
k = 0;
|
||||
values.clear();
|
||||
}
|
||||
};
|
||||
|
||||
class GeoPoint {
|
||||
constexpr static const double EARTH_RADIUS = 3958.75;
|
||||
constexpr static const double METER_CONVERT = 1609.00;
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include "id_list.h"
|
||||
#include "synonym_index.h"
|
||||
#include "override.h"
|
||||
#include "vector_query_ops.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
|
||||
static constexpr size_t ARRAY_FACET_DIM = 4;
|
||||
|
32
include/vector_query_ops.h
Normal file
32
include/vector_query_ops.h
Normal file
@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "option.h"
|
||||
|
||||
class Collection;
|
||||
|
||||
struct vector_query_t {
|
||||
std::string field_name;
|
||||
size_t k = 0;
|
||||
size_t flat_search_cutoff = 0;
|
||||
std::vector<float> values;
|
||||
|
||||
uint32_t seq_id = 0;
|
||||
bool query_doc_given = false;
|
||||
|
||||
void _reset() {
|
||||
// used for testing only
|
||||
field_name.clear();
|
||||
k = 0;
|
||||
values.clear();
|
||||
seq_id = 0;
|
||||
query_doc_given = false;
|
||||
}
|
||||
};
|
||||
|
||||
class VectorQueryOps {
|
||||
public:
|
||||
static Option<bool> parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query,
|
||||
const Collection* coll);
|
||||
};
|
@ -15,6 +15,7 @@
|
||||
#include "topster.h"
|
||||
#include "logger.h"
|
||||
#include "thread_local_vars.h"
|
||||
#include "vector_query_ops.h"
|
||||
|
||||
const std::string override_t::MATCH_EXACT = "exact";
|
||||
const std::string override_t::MATCH_CONTAINS = "contains";
|
||||
@ -921,8 +922,9 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
return Option<nlohmann::json>(400, "Vector query is supported only on wildcard (q=*) searches.");
|
||||
}
|
||||
|
||||
if(!CollectionManager::parse_vector_query_str(vector_query_str, vector_query)) {
|
||||
return Option<nlohmann::json>(400, "The `vector_query` parameter is malformed.");
|
||||
auto parse_vector_op = VectorQueryOps::parse_vector_query_str(vector_query_str, vector_query, this);
|
||||
if(!parse_vector_op.ok()) {
|
||||
return Option<nlohmann::json>(400, parse_vector_op.error());
|
||||
}
|
||||
|
||||
auto vector_field_it = search_schema.find(vector_query.field_name);
|
||||
|
@ -1411,112 +1411,3 @@ Option<Collection*> CollectionManager::clone_collection(const string& existing_n
|
||||
|
||||
return Option<Collection*>(new_coll);
|
||||
}
|
||||
|
||||
bool CollectionManager::parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query) {
|
||||
// FORMAT:
|
||||
// field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)
|
||||
size_t i = 0;
|
||||
while(i < vector_query_str.size()) {
|
||||
if(vector_query_str[i] != ':') {
|
||||
vector_query.field_name += vector_query_str[i];
|
||||
i++;
|
||||
} else {
|
||||
if(vector_query_str[i] != ':') {
|
||||
// missing ":"
|
||||
return false;
|
||||
}
|
||||
|
||||
// field name is done
|
||||
i++;
|
||||
|
||||
StringUtils::trim(vector_query.field_name);
|
||||
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != '(') {
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != '(') {
|
||||
// missing "("
|
||||
return false;
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != '[') {
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != '[') {
|
||||
// missing opening "["
|
||||
return false;
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
std::string values_str;
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != ']') {
|
||||
values_str += vector_query_str[i];
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != ']') {
|
||||
// missing closing "]"
|
||||
return false;
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
std::vector<std::string> svalues;
|
||||
StringUtils::split(values_str, svalues, ",");
|
||||
|
||||
for(auto& svalue: svalues) {
|
||||
if(!StringUtils::is_float(svalue)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector_query.values.push_back(std::stof(svalue));
|
||||
}
|
||||
|
||||
if(i == vector_query_str.size()-1) {
|
||||
// missing params
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string param_str = vector_query_str.substr(i, (vector_query_str.size() - i));
|
||||
std::vector<std::string> param_kvs;
|
||||
StringUtils::split(param_str, param_kvs, ",");
|
||||
|
||||
for(auto& param_kv_str: param_kvs) {
|
||||
if(param_kv_str.back() == ')') {
|
||||
param_kv_str.pop_back();
|
||||
}
|
||||
|
||||
std::vector<std::string> param_kv;
|
||||
StringUtils::split(param_kv_str, param_kv, ":");
|
||||
if(param_kv.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if(param_kv[0] == "k") {
|
||||
if(!StringUtils::is_uint32_t(param_kv[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector_query.k = std::stoul(param_kv[1]);
|
||||
}
|
||||
|
||||
if(param_kv[0] == "flat_search_cutoff") {
|
||||
if(!StringUtils::is_uint32_t(param_kv[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector_query.flat_search_cutoff = std::stoi(param_kv[1]);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -2561,6 +2561,11 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
|
||||
|
||||
for (const auto& dist_label : dist_labels) {
|
||||
uint32 seq_id = dist_label.second;
|
||||
|
||||
if(vector_query.query_doc_given && vector_query.seq_id == seq_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint64_t distinct_id = seq_id;
|
||||
if (group_limit != 0) {
|
||||
distinct_id = get_distinct_id(group_by_fields, seq_id);
|
||||
|
159
src/vector_query_ops.cpp
Normal file
159
src/vector_query_ops.cpp
Normal file
@ -0,0 +1,159 @@
|
||||
#include "vector_query_ops.h"
|
||||
#include "string_utils.h"
|
||||
#include "collection.h"
|
||||
|
||||
Option<bool> VectorQueryOps::parse_vector_query_str(std::string vector_query_str, vector_query_t& vector_query,
|
||||
const Collection* coll) {
|
||||
// FORMAT:
|
||||
// field_name([0.34, 0.66, 0.12, 0.68], exact: false, k: 10)
|
||||
size_t i = 0;
|
||||
while(i < vector_query_str.size()) {
|
||||
if(vector_query_str[i] != ':') {
|
||||
vector_query.field_name += vector_query_str[i];
|
||||
i++;
|
||||
} else {
|
||||
if(vector_query_str[i] != ':') {
|
||||
// missing ":"
|
||||
return Option<bool>(400, "Malformed vector query string: `:` is missing.");
|
||||
}
|
||||
|
||||
// field name is done
|
||||
i++;
|
||||
|
||||
StringUtils::trim(vector_query.field_name);
|
||||
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != '(') {
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != '(') {
|
||||
// missing "("
|
||||
return Option<bool>(400, "Malformed vector query string.");
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != '[') {
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != '[') {
|
||||
// missing opening "["
|
||||
return Option<bool>(400, "Malformed vector query string.");
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
std::string values_str;
|
||||
while(i < vector_query_str.size() && vector_query_str[i] != ']') {
|
||||
values_str += vector_query_str[i];
|
||||
i++;
|
||||
}
|
||||
|
||||
if(vector_query_str[i] != ']') {
|
||||
// missing closing "]"
|
||||
return Option<bool>(400, "Malformed vector query string.");
|
||||
}
|
||||
|
||||
i++;
|
||||
|
||||
std::vector<std::string> svalues;
|
||||
StringUtils::split(values_str, svalues, ",");
|
||||
|
||||
for(auto& svalue: svalues) {
|
||||
if(!StringUtils::is_float(svalue)) {
|
||||
return Option<bool>(400, "Malformed vector query string: one of the vector values is not a float.");
|
||||
}
|
||||
|
||||
vector_query.values.push_back(std::stof(svalue));
|
||||
}
|
||||
|
||||
if(i == vector_query_str.size()-1) {
|
||||
// missing params
|
||||
if(vector_query.values.empty()) {
|
||||
// when query values are missing, atleast the `id` parameter must be present
|
||||
return Option<bool>(400, "When a vector query value is empty, an `id` parameter must be present.");
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
std::string param_str = vector_query_str.substr(i, (vector_query_str.size() - i));
|
||||
std::vector<std::string> param_kvs;
|
||||
StringUtils::split(param_str, param_kvs, ",");
|
||||
|
||||
for(auto& param_kv_str: param_kvs) {
|
||||
if(param_kv_str.back() == ')') {
|
||||
param_kv_str.pop_back();
|
||||
}
|
||||
|
||||
std::vector<std::string> param_kv;
|
||||
StringUtils::split(param_kv_str, param_kv, ":");
|
||||
if(param_kv.size() != 2) {
|
||||
return Option<bool>(400, "Malformed vector query string.");
|
||||
}
|
||||
|
||||
if(param_kv[0] == "id") {
|
||||
if(!vector_query.values.empty()) {
|
||||
// cannot pass both vector values and id
|
||||
return Option<bool>(400, "Malformed vector query string: cannot pass both vector query "
|
||||
"and `id` parameter.");
|
||||
}
|
||||
|
||||
Option<uint32_t> id_op = coll->doc_id_to_seq_id(param_kv[1]);
|
||||
if(!id_op.ok()) {
|
||||
return Option<bool>(400, "Document id referenced in vector query is not found.");
|
||||
}
|
||||
|
||||
nlohmann::json document;
|
||||
auto doc_op = coll->get_document_from_store(id_op.get(), document);
|
||||
if(!doc_op.ok()) {
|
||||
return Option<bool>(400, "Document id referenced in vector query is not found.");
|
||||
}
|
||||
|
||||
if(!document.contains(vector_query.field_name) || !document[vector_query.field_name].is_array()) {
|
||||
return Option<bool>(400, "Document referenced in vector query does not contain a valid "
|
||||
"vector field.");
|
||||
}
|
||||
|
||||
for(auto& fvalue: document[vector_query.field_name]) {
|
||||
if(!fvalue.is_number_float()) {
|
||||
return Option<bool>(400, "Document referenced in vector query does not contain a valid "
|
||||
"vector field.");
|
||||
}
|
||||
|
||||
vector_query.values.push_back(fvalue.get<float>());
|
||||
}
|
||||
|
||||
vector_query.query_doc_given = true;
|
||||
vector_query.seq_id = id_op.get();
|
||||
}
|
||||
|
||||
if(param_kv[0] == "k") {
|
||||
if(!StringUtils::is_uint32_t(param_kv[1])) {
|
||||
return Option<bool>(400, "Malformed vector query string: `k` parameter must be an integer.");
|
||||
}
|
||||
|
||||
vector_query.k = std::stoul(param_kv[1]);
|
||||
}
|
||||
|
||||
if(param_kv[0] == "flat_search_cutoff") {
|
||||
if(!StringUtils::is_uint32_t(param_kv[1])) {
|
||||
return Option<bool>(400, "Malformed vector query string: "
|
||||
"`flat_search_cutoff` parameter must be an integer.");
|
||||
}
|
||||
|
||||
vector_query.flat_search_cutoff = std::stoi(param_kv[1]);
|
||||
}
|
||||
}
|
||||
|
||||
if(!vector_query.query_doc_given && vector_query.values.empty()) {
|
||||
return Option<bool>(400, "When a vector query value is empty, an `id` parameter must be present.");
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(400, "Malformed vector query string.");
|
||||
}
|
@ -989,43 +989,6 @@ TEST_F(CollectionManagerTest, ParseSortByClause) {
|
||||
ASSERT_FALSE(sort_by_parsed);
|
||||
}
|
||||
|
||||
TEST_F(CollectionManagerTest, ParseVectorQueryString) {
|
||||
vector_query_t vector_query;
|
||||
bool parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query);
|
||||
ASSERT_TRUE(parsed);
|
||||
ASSERT_EQ("vec", vector_query.field_name);
|
||||
ASSERT_EQ(10, vector_query.k);
|
||||
std::vector<float> fvs = {0.34, 0.66, 0.12, 0.68};
|
||||
ASSERT_EQ(fvs.size(), vector_query.values.size());
|
||||
for(size_t i = 0; i < fvs.size(); i++) {
|
||||
ASSERT_EQ(fvs[i], vector_query.values[i]);
|
||||
}
|
||||
|
||||
vector_query._reset();
|
||||
parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query);
|
||||
ASSERT_TRUE(parsed);
|
||||
|
||||
vector_query._reset();
|
||||
parsed = CollectionManager::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query);
|
||||
ASSERT_FALSE(parsed);
|
||||
|
||||
vector_query._reset();
|
||||
parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query);
|
||||
ASSERT_TRUE(parsed);
|
||||
|
||||
vector_query._reset();
|
||||
parsed = CollectionManager::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query);
|
||||
ASSERT_FALSE(parsed);
|
||||
|
||||
vector_query._reset();
|
||||
parsed = CollectionManager::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query);
|
||||
ASSERT_FALSE(parsed);
|
||||
|
||||
vector_query._reset();
|
||||
parsed = CollectionManager::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query);
|
||||
ASSERT_FALSE(parsed);
|
||||
}
|
||||
|
||||
TEST_F(CollectionManagerTest, Presets) {
|
||||
// try getting on a blank slate
|
||||
auto presets = collectionManager.get_presets();
|
||||
|
@ -144,6 +144,33 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) {
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Field `zec` does not have a vector query index.", res_op.error());
|
||||
|
||||
// pass `id` of existing doc instead of vector, query doc should be omitted from results
|
||||
results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "vec:([], id: 1)").get();
|
||||
|
||||
ASSERT_EQ(2, results["found"].get<size_t>());
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// when `id` does not exist, return appropriate error
|
||||
res_op = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "vec:([], id: 100)");
|
||||
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Document id referenced in vector query is not found.", res_op.error());
|
||||
|
||||
// only supported with wildcard queries
|
||||
res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
|
73
test/vector_query_ops_test.cpp
Normal file
73
test/vector_query_ops_test.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "vector_query_ops.h"
|
||||
|
||||
class VectorQueryOpsTest : public ::testing::Test {
|
||||
protected:
|
||||
void setupCollection() {
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
setupCollection();
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(VectorQueryOpsTest, ParseVectorQueryString) {
|
||||
vector_query_t vector_query;
|
||||
auto parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
ASSERT_EQ("vec", vector_query.field_name);
|
||||
ASSERT_EQ(10, vector_query.k);
|
||||
std::vector<float> fvs = {0.34, 0.66, 0.12, 0.68};
|
||||
ASSERT_EQ(fvs.size(), vector_query.values.size());
|
||||
for (size_t i = 0; i < fvs.size(); i++) {
|
||||
ASSERT_EQ(fvs[i], vector_query.values[i]);
|
||||
}
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([])", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error());
|
||||
|
||||
// cannot pass both vector and id
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], id: 10)", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string: cannot pass both vector query and `id` parameter.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([], k: 10)", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("When a vector query value is empty, an `id` parameter must be present.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:[0.34, 0.66, 0.12, 0.68], k: 10)", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], k: 10", vector_query, nullptr);
|
||||
ASSERT_TRUE(parsed.ok());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:(0.34, 0.66, 0.12, 0.68, k: 10)", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec:([0.34, 0.66, 0.12, 0.68], )", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
|
||||
vector_query._reset();
|
||||
parsed = VectorQueryOps::parse_vector_query_str("vec([0.34, 0.66, 0.12, 0.68])", vector_query, nullptr);
|
||||
ASSERT_FALSE(parsed.ok());
|
||||
ASSERT_EQ("Malformed vector query string.", parsed.error());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user