From 4ec2e960d0fc0f8dbe635914dcdae67c83537986 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Wed, 12 Jul 2023 20:49:10 +0530 Subject: [PATCH] Allow vector_query_hits to be passed via vector_query.k --- include/collection.h | 1 - src/collection.cpp | 16 ++------- src/collection_manager.cpp | 4 --- src/index.cpp | 5 +-- test/collection_vector_search_test.cpp | 45 +++++++++++++++++--------- 5 files changed, 35 insertions(+), 36 deletions(-) diff --git a/include/collection.h b/include/collection.h index 93e826ff..641692f9 100644 --- a/include/collection.h +++ b/include/collection.h @@ -464,7 +464,6 @@ public: const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, const size_t page_offset = 0, - const size_t vector_query_hits = 250, const size_t remote_embedding_timeout_ms = 30000, const size_t remote_embedding_num_try = 2) const; diff --git a/src/collection.cpp b/src/collection.cpp index 7778bf0c..738157fe 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1108,8 +1108,7 @@ Option Collection::search(std::string raw_query, const size_t facet_sample_percent, const size_t facet_sample_threshold, const size_t page_offset, - const size_t vector_query_hits, - const size_t remote_embedding_timeout_ms, + const size_t remote_embedding_timeout_ms, const size_t remote_embedding_num_try) const { std::shared_lock lock(mutex); @@ -1258,13 +1257,9 @@ Option Collection::search(std::string raw_query, } } std::vector embedding = embedding_op.embedding; - // distance could have been set for an embed field, so we take a backup and restore - auto dist = vector_query.distance_threshold; - vector_query._reset(); + // params could have been set for an embed field, so we leave them untouched and only fill in the values and field name vector_query.values = embedding; vector_query.field_name = field_name; - vector_query.k = vector_query_hits; - vector_query.distance_threshold = dist; continue; } @@ -1280,11 +1275,6 @@ Option Collection::search(std::string raw_query, return Option(400, 
error); } - std::string real_raw_query = raw_query; - if(!vector_query.field_name.empty() && processed_search_fields.size() == 0) { - raw_query = "*"; - } - if(!query_by_weights.empty() && processed_search_fields.size() != query_by_weights.size()) { std::string error = "Error, query_by_weights.size != query_by.size."; return Option(400, error); @@ -2165,7 +2155,7 @@ Option Collection::search(std::string raw_query, result["request_params"] = nlohmann::json::object(); result["request_params"]["collection_name"] = name; result["request_params"]["per_page"] = per_page; - result["request_params"]["q"] = real_raw_query; + result["request_params"]["q"] = raw_query; //long long int timeMillis = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - begin).count(); //!LOG(INFO) << "Time taken for result calc: " << timeMillis << "us"; diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index ac740da8..059af6a8 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -680,7 +680,6 @@ Option CollectionManager::do_search(std::map& re const char *MAX_FACET_VALUES = "max_facet_values"; const char *VECTOR_QUERY = "vector_query"; - const char *VECTOR_QUERY_HITS = "vector_query_hits"; const char* REMOTE_EMBEDDING_TIMEOUT_MS = "remote_embedding_timeout_ms"; const char* REMOTE_EMBEDDING_NUM_TRY = "remote_embedding_num_try"; @@ -836,7 +835,6 @@ Option CollectionManager::do_search(std::map& re size_t max_extra_suffix = INT16_MAX; bool enable_highlight_v1 = true; text_match_type_t match_type = max_score; - size_t vector_query_hits = 250; size_t remote_embedding_timeout_ms = 5000; size_t remote_embedding_num_try = 2; @@ -866,7 +864,6 @@ Option CollectionManager::do_search(std::map& re {FILTER_CURATED_HITS, &filter_curated_hits_option}, {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, - {VECTOR_QUERY_HITS, &vector_query_hits}, {REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms}, 
{REMOTE_EMBEDDING_NUM_TRY, &remote_embedding_num_try}, }; @@ -1082,7 +1079,6 @@ Option CollectionManager::do_search(std::map& re facet_sample_percent, facet_sample_threshold, offset, - vector_query_hits, remote_embedding_timeout_ms, remote_embedding_num_try ); diff --git a/src/index.cpp b/src/index.cpp index f77bf06a..6cf559f0 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2857,7 +2857,8 @@ Option Index::search(std::vector& field_query_tokens, cons collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { - auto k = std::max(vector_query.k, fetch_size); + auto k = vector_query.k == 0 ? std::max(vector_query.k, fetch_size) : vector_query.k; + if(vector_query.query_doc_given) { // since we will omit the query doc from results k++; @@ -3147,7 +3148,7 @@ Option Index::search(std::vector& field_query_tokens, cons std::vector> dist_labels; // use k as 100 by default for ensuring results stability in pagination size_t default_k = 100; - auto k = std::max(vector_query.k, default_k); + auto k = vector_query.k == 0 ? 
std::max(fetch_size, default_k) : vector_query.k; if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index e12b58ad..493847c7 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -161,17 +161,28 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) { ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); ASSERT_STREQ("2", results["hits"][1]["document"]["id"].get().c_str()); - // `k` value should work correctly - results = coll1->search("*", {}, "", {}, {}, {0}, 1, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + // `k` value should override per_page + results = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 5, "", 10, {}, {}, {}, 0, "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, 4, {off}, 32767, 32767, 2, - false, true, "vec:([], id: 1, k: 1)").get(); + false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488], k: 1)").get(); ASSERT_EQ(1, results["hits"].size()); + // when k is not set, should use per_page + results = coll1->search("*", {}, "", {}, {}, {0}, 2, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get(); + + ASSERT_EQ(2, results["hits"].size()); + // when `id` does not exist, return appropriate error res_op = coll1->search("*", {}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(), @@ -184,19 +195,6 @@ TEST_F(CollectionVectorTest, BasicVectorQuerying) { ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Document id 
referenced in vector query is not found.", res_op.error()); - // DEPRECATED: vector query is also supported on non-wildcard queries with hybrid search - // only supported with wildcard queries - // res_op = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, - // spp::sparse_hash_set(), - // spp::sparse_hash_set(), 10, "", 30, 5, - // "", 10, {}, {}, {}, 0, - // "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, - // 4, {off}, 32767, 32767, 2, - // false, true, "zec:([0.96826, 0.94, 0.39557, 0.4542])"); - - // ASSERT_FALSE(res_op.ok()); - // ASSERT_EQ("Vector query is supported only on wildcard (q=*) searches.", res_op.error()); - // support num_dim on only float array fields schema = R"({ "name": "coll2", @@ -764,6 +762,21 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { ASSERT_FLOAT_EQ(0.0462081432, search_res["hits"][0]["vector_distance"].get()); ASSERT_FLOAT_EQ(0.1213316321, search_res["hits"][1]["vector_distance"].get()); + // to pass k param + vec_query = "embedding:([], k: 1)"; + search_res_op = coll->search("butter", {"embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2, + false, true, vec_query); + ASSERT_TRUE(search_res_op.ok()); + search_res = search_res_op.get(); + ASSERT_EQ(1, search_res["found"].get()); + ASSERT_EQ(1, search_res["hits"].size()); + // when no embedding field is passed, it should not be allowed search_res_op = coll->search("butter", {"name"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, spp::sparse_hash_set(),