Allow vector_query_hits to be passed via vector_query.k

This commit is contained in:
Kishore Nallan 2023-07-12 20:49:10 +05:30
parent c2db7436a2
commit 4ec2e960d0
5 changed files with 35 additions and 36 deletions

View File

@ -464,7 +464,6 @@ public:
const size_t facet_sample_percent = 100,
const size_t facet_sample_threshold = 0,
const size_t page_offset = 0,
const size_t vector_query_hits = 250,
const size_t remote_embedding_timeout_ms = 30000,
const size_t remote_embedding_num_try = 2) const;

View File

@ -1108,8 +1108,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
const size_t facet_sample_percent,
const size_t facet_sample_threshold,
const size_t page_offset,
const size_t vector_query_hits,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_timeout_ms,
const size_t remote_embedding_num_try) const {
std::shared_lock lock(mutex);
@ -1258,13 +1257,9 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
}
}
std::vector<float> embedding = embedding_op.embedding;
// distance could have been set for an embed field, so we take a backup and restore
auto dist = vector_query.distance_threshold;
vector_query._reset();
// params could have been set for an embed field, so we take a backup and restore
vector_query.values = embedding;
vector_query.field_name = field_name;
vector_query.k = vector_query_hits;
vector_query.distance_threshold = dist;
continue;
}
@ -1280,11 +1275,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
return Option<nlohmann::json>(400, error);
}
std::string real_raw_query = raw_query;
if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) {
raw_query = "*";
}
if(!query_by_weights.empty() && processed_search_fields.size() != query_by_weights.size()) {
std::string error = "Error, query_by_weights.size != query_by.size.";
return Option<nlohmann::json>(400, error);
@ -2165,7 +2155,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
result["request_params"] = nlohmann::json::object();
result["request_params"]["collection_name"] = name;
result["request_params"]["per_page"] = per_page;
result["request_params"]["q"] = real_raw_query;
result["request_params"]["q"] = raw_query;
//long long int timeMillis = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - begin).count();
//!LOG(INFO) << "Time taken for result calc: " << timeMillis << "us";

View File

@ -680,7 +680,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *MAX_FACET_VALUES = "max_facet_values";
const char *VECTOR_QUERY = "vector_query";
const char *VECTOR_QUERY_HITS = "vector_query_hits";
const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms";
const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try";
@ -836,7 +835,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
size_t max_extra_suffix = INT16_MAX;
bool enable_highlight_v1 = true;
text_match_type_t match_type = max_score;
size_t vector_query_hits = 250;
size_t remote_embedding_timeout_ms = 5000;
size_t remote_embedding_num_try = 2;
@ -866,7 +864,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{FILTER_CURATED_HITS, &filter_curated_hits_option},
{FACET_SAMPLE_PERCENT, &facet_sample_percent},
{FACET_SAMPLE_THRESHOLD, &facet_sample_threshold},
{VECTOR_QUERY_HITS, &vector_query_hits},
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
{REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try},
};
@ -1082,7 +1079,6 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
facet_sample_percent,
facet_sample_threshold,
offset,
vector_query_hits,
remote_embedding_timeout_ms,
remote_embedding_num_try
);

View File

@ -2857,7 +2857,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
if (!vector_query.field_name.empty()) {
auto k = std::max<size_t>(vector_query.k, fetch_size);
auto k = vector_query.k == 0 ? std::max<size_t>(vector_query.k, fetch_size) : vector_query.k;
if(vector_query.query_doc_given) {
// since we will omit the query doc from results
k++;
@ -3147,7 +3148,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
std::vector<std::pair<float, size_t>> dist_labels;
// use k as 100 by default for ensuring results stability in pagination
size_t default_k = 100;
auto k = std::max<size_t>(vector_query.k, default_k);
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k) : vector_query.k;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());

View File

@ -161,17 +161,28 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) {
ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get<std::string>().c_str());
// `k` value should work correctly
results = coll1->search("*", {}, "", {}, {}, {0}, 1, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
// `k` value should override per_page
results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([], id: 1, k: 1)").get();
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488], k: 1)").get();
ASSERT_EQ(1, results["hits"].size());
// when k is not set, should use per_page
results = coll1->search("*", {}, "", {}, {}, {0}, 2, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
ASSERT_EQ(2, results["hits"].size());
// when `id` does not exist, return appropriate error
res_op = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
@ -184,19 +195,6 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) {
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Document id referenced in vector query is not found.", res_op.error());
// DEPRECATED: vector query is also supported on non-wildcard queries with hybrid search
// only supported with wildcard queries
// res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
// spp::sparse_hash_set<std::string>(),
// spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
// "", 10, {}, {}, {}, 0,
// "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
// 4, {off}, 32767, 32767, 2,
// false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])");
// ASSERT_FALSE(res_op.ok());
// ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error());
// support num_dim on only float array fields
schema = R"({
"name": "coll2",
@ -764,6 +762,21 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
ASSERT_FLOAT_EQ(0.0462081432, search_res["hits"][0]["vector_distance"].get<float>());
ASSERT_FLOAT_EQ(0.1213316321, search_res["hits"][1]["vector_distance"].get<float>());
// passing the `k` param should limit the number of hits returned
vec_query = "embedding:([], k: 1)";
search_res_op = coll->search("butter", {"embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
fallback,
4, {off}, 32767, 32767, 2,
false, true, vec_query);
ASSERT_TRUE(search_res_op.ok());
search_res = search_res_op.get();
ASSERT_EQ(1, search_res["found"].get<size_t>());
ASSERT_EQ(1, search_res["hits"].size());
// when no embedding field is passed, it should not be allowed
search_res_op = coll->search("butter", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),