mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 14:12:27 +08:00
Allow vector_query_hits to be passed via vector_query.k
This commit is contained in:
parent
c2db7436a2
commit
4ec2e960d0
@ -464,7 +464,6 @@ public:
|
||||
const size_t facet_sample_percent = 100,
|
||||
const size_t facet_sample_threshold = 0,
|
||||
const size_t page_offset = 0,
|
||||
const size_t vector_query_hits = 250,
|
||||
const size_t remote_embedding_timeout_ms = 30000,
|
||||
const size_t remote_embedding_num_try = 2) const;
|
||||
|
||||
|
@ -1108,8 +1108,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
const size_t facet_sample_percent,
|
||||
const size_t facet_sample_threshold,
|
||||
const size_t page_offset,
|
||||
const size_t vector_query_hits,
|
||||
const size_t remote_embedding_timeout_ms,
|
||||
const size_t remote_embedding_timeout_ms,
|
||||
const size_t remote_embedding_num_try) const {
|
||||
|
||||
std::shared_lock lock(mutex);
|
||||
@ -1258,13 +1257,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
}
|
||||
}
|
||||
std::vector<float> embedding = embedding_op.embedding;
|
||||
// distance could have been set for an embed field, so we take a backup and restore
|
||||
auto dist = vector_query.distance_threshold;
|
||||
vector_query._reset();
|
||||
// params could have been set for an embed field, so we take a backup and restore
|
||||
vector_query.values = embedding;
|
||||
vector_query.field_name = field_name;
|
||||
vector_query.k = vector_query_hits;
|
||||
vector_query.distance_threshold = dist;
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1280,11 +1275,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
std::string real_raw_query = raw_query;
|
||||
if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) {
|
||||
raw_query = "*";
|
||||
}
|
||||
|
||||
if(!query_by_weights.empty() && processed_search_fields.size() != query_by_weights.size()) {
|
||||
std::string error = "Error, query_by_weights.size != query_by.size.";
|
||||
return Option<nlohmann::json>(400, error);
|
||||
@ -2165,7 +2155,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
|
||||
result["request_params"] = nlohmann::json::object();
|
||||
result["request_params"]["collection_name"] = name;
|
||||
result["request_params"]["per_page"] = per_page;
|
||||
result["request_params"]["q"] = real_raw_query;
|
||||
result["request_params"]["q"] = raw_query;
|
||||
|
||||
//long long int timeMillis = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - begin).count();
|
||||
//!LOG(INFO) << "Time taken for result calc: " << timeMillis << "us";
|
||||
|
@ -680,7 +680,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
const char *MAX_FACET_VALUES = "max_facet_values";
|
||||
|
||||
const char *VECTOR_QUERY = "vector_query";
|
||||
const char *VECTOR_QUERY_HITS = "vector_query_hits";
|
||||
|
||||
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
|
||||
const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try";
|
||||
@ -836,7 +835,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
size_t max_extra_suffix = INT16_MAX;
|
||||
bool enable_highlight_v1 = true;
|
||||
text_match_type_t match_type = max_score;
|
||||
size_t vector_query_hits = 250;
|
||||
|
||||
size_t remote_embedding_timeout_ms = 5000;
|
||||
size_t remote_embedding_num_try = 2;
|
||||
@ -866,7 +864,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
{FILTER_CURATED_HITS, &filter_curated_hits_option},
|
||||
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
|
||||
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
|
||||
{VECTOR_QUERY_HITS, &vector_query_hits},
|
||||
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
|
||||
{REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try},
|
||||
};
|
||||
@ -1082,7 +1079,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
|
||||
facet_sample_percent,
|
||||
facet_sample_threshold,
|
||||
offset,
|
||||
vector_query_hits,
|
||||
remote_embedding_timeout_ms,
|
||||
remote_embedding_num_try
|
||||
);
|
||||
|
@ -2857,7 +2857,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
|
||||
|
||||
if (!vector_query.field_name.empty()) {
|
||||
auto k = std::max<size_t>(vector_query.k, fetch_size);
|
||||
auto k = vector_query.k == 0 ? std::max<size_t>(vector_query.k, fetch_size) : vector_query.k;
|
||||
|
||||
if(vector_query.query_doc_given) {
|
||||
// since we will omit the query doc from results
|
||||
k++;
|
||||
@ -3147,7 +3148,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
|
||||
std::vector<std::pair<float, size_t>> dist_labels;
|
||||
// use k as 100 by default for ensuring results stability in pagination
|
||||
size_t default_k = 100;
|
||||
auto k = std::max<size_t>(vector_query.k, default_k);
|
||||
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k) : vector_query.k;
|
||||
|
||||
if(field_vector_index->distance_type == cosine) {
|
||||
std::vector<float> normalized_q(vector_query.values.size());
|
||||
|
@ -161,17 +161,28 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) {
|
||||
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
|
||||
ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get<std::string>().c_str());
|
||||
|
||||
// `k` value should work correctly
|
||||
results = coll1->search("*", {}, "", {}, {}, {0}, 1, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
// `k` value should overrides per_page
|
||||
results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "vec:([], id: 1, k: 1)").get();
|
||||
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488], k: 1)").get();
|
||||
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
|
||||
// when k is not set, should use per_page
|
||||
results = coll1->search("*", {}, "", {}, {}, {0}, 2, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
|
||||
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
// when `id` does not exist, return appropriate error
|
||||
res_op = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
@ -184,19 +195,6 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) {
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Document id referenced in vector query is not found.", res_op.error());
|
||||
|
||||
// DEPRECATED: vector query is also supported on non-wildcard queries with hybrid search
|
||||
// only supported with wildcard queries
|
||||
// res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
// spp::sparse_hash_set<std::string>(),
|
||||
// spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
// "", 10, {}, {}, {}, 0,
|
||||
// "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
|
||||
// 4, {off}, 32767, 32767, 2,
|
||||
// false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])");
|
||||
|
||||
// ASSERT_FALSE(res_op.ok());
|
||||
// ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error());
|
||||
|
||||
// support num_dim on only float array fields
|
||||
schema = R"({
|
||||
"name": "coll2",
|
||||
@ -764,6 +762,21 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
|
||||
ASSERT_FLOAT_EQ(0.0462081432, search_res["hits"][0]["vector_distance"].get<float>());
|
||||
ASSERT_FLOAT_EQ(0.1213316321, search_res["hits"][1]["vector_distance"].get<float>());
|
||||
|
||||
// to pass k param
|
||||
vec_query = "embedding:([], k: 1)";
|
||||
search_res_op = coll->search("butter", {"embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10, {}, {}, {}, 0,
|
||||
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
|
||||
fallback,
|
||||
4, {off}, 32767, 32767, 2,
|
||||
false, true, vec_query);
|
||||
ASSERT_TRUE(search_res_op.ok());
|
||||
search_res = search_res_op.get();
|
||||
ASSERT_EQ(1, search_res["found"].get<size_t>());
|
||||
ASSERT_EQ(1, search_res["hits"].size());
|
||||
|
||||
// when no embedding field is passed, it should not be allowed
|
||||
search_res_op = coll->search("butter", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user