From d0edc91cbb219051daf04c20f695ae8e14d71cdd Mon Sep 17 00:00:00 2001 From: ozanarmagan Date: Mon, 26 Jun 2023 14:19:19 +0300 Subject: [PATCH 1/5] vector_query_hits & sort by vector_distance --- include/collection.h | 6 +- include/field.h | 2 + include/index.h | 3 +- src/collection.cpp | 33 +++++-- src/collection_manager.cpp | 3 + src/index.cpp | 42 +++++++-- test/collection_sorting_test.cpp | 142 +++++++++++++++++++++++++++++++ 7 files changed, 214 insertions(+), 17 deletions(-) diff --git a/include/collection.h b/include/collection.h index 8387b9e6..2f79bc68 100644 --- a/include/collection.h +++ b/include/collection.h @@ -207,7 +207,8 @@ private: Option validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, - bool is_wildcard_query, bool is_group_by_query = false) const; + bool is_wildcard_query,const bool is_vector_query, + bool is_group_by_query = false) const; Option persist_collection_meta(); @@ -462,7 +463,8 @@ public: const text_match_type_t match_type = max_score, const size_t facet_sample_percent = 100, const size_t facet_sample_threshold = 0, - const size_t page_offset = UINT32_MAX) const; + const size_t page_offset = UINT32_MAX, + const size_t vector_query_hits = 250) const; Option get_filter_ids(const std::string & filter_query, filter_result_t& filter_result) const; diff --git a/include/field.h b/include/field.h index 8d1606e3..ccf456e2 100644 --- a/include/field.h +++ b/include/field.h @@ -726,6 +726,8 @@ namespace sort_field_const { static const std::string precision = "precision"; static const std::string missing_values = "missing_values"; + + static const std::string vector_distance = "_vector_distance"; } struct sort_by { diff --git a/include/index.h b/include/index.h index 34a85749..3f93e182 100644 --- a/include/index.h +++ b/include/index.h @@ -344,6 +344,7 @@ private: static spp::sparse_hash_map eval_sentinel_value; static spp::sparse_hash_map geo_sentinel_value; static spp::sparse_hash_map str_sentinel_value; + static spp::sparse_hash_map vector_distance_sentinel_value; // Internal utility functions @@ -947,7 +948,7 @@ public: size_t filter_index, int64_t max_field_match_score, int64_t* scores, - int64_t& match_score_index) const; + int64_t& match_score_index, float vector_distance = 0) const; void process_curated_ids(const std::vector>& included_ids, diff --git a/src/collection.cpp b/src/collection.cpp index 1884004b..d52ffb3e 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -740,6 +740,7 @@ void Collection::curate_results(string& actual_query, const string& filter_query Option Collection::validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, const bool is_wildcard_query, + const bool is_vector_query, const bool is_group_by_query) const { size_t num_sort_expressions = 0; @@ -914,7 +915,7 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval && - sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found) { + sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance) { const auto field_it = search_schema.find(sort_field_std.name); if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { @@ -928,6 +929,11 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< std::string error = "group_by parameters should not be empty when using sort_by group_found"; return Option(404, error); } + + if(sort_field_std.name == sort_field_const::vector_distance && !is_vector_query) { + std::string error = "sort_by vector_distance is only supported for vector queries, semantic search and hybrid search."; + return Option(404, error); + } StringUtils::toupper(sort_field_std.order); @@ -950,6 +956,10 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc); } + if(is_vector_query) { + sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc); + } + if(!default_sorting_field.empty()) { sort_fields_std.emplace_back(default_sorting_field, sort_field_const::desc); } else { @@ -958,9 +968,15 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } bool found_match_score = false; + bool found_vector_distance = false; for(const auto & sort_field : sort_fields_std) { if(sort_field.name == sort_field_const::text_match) { found_match_score = true; + } + if(sort_field.name == sort_field_const::vector_distance) { + found_vector_distance = true; + } + if(found_match_score && found_vector_distance) { break; } } @@ -969,6 +985,10 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_fields_std.emplace_back(sort_field_const::text_match, sort_field_const::desc); } + if(!found_vector_distance && is_vector_query && sort_fields.size() < 3) { + sort_fields_std.emplace_back(sort_field_const::vector_distance, sort_field_const::asc); + } + if(sort_fields_std.size() > 3) { std::string message = "Only upto 3 sort_by fields can be specified."; return Option(422, message); @@ -1085,7 +1105,8 @@ Option Collection::search(std::string raw_query, const text_match_type_t match_type, const size_t facet_sample_percent, const size_t facet_sample_threshold, - const size_t page_offset) const { + const size_t page_offset, + const size_t vector_query_hits) const { std::shared_lock lock(mutex); @@ -1226,6 +1247,7 @@ Option Collection::search(std::string raw_query, vector_query._reset(); vector_query.values = embedding; vector_query.field_name = field_name; + vector_query.k = vector_query_hits; continue; } @@ -1455,10 +1477,11 @@ Option Collection::search(std::string raw_query, bool is_wildcard_query = (query == "*"); bool is_group_by_query = group_by_fields.size() > 0; + bool is_vector_query = !vector_query.field_name.empty(); if(curated_sort_by.empty()) { auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, - sort_fields_std, is_wildcard_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1470,7 +1493,7 @@ Option Collection::search(std::string raw_query, } auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, - sort_fields_std, is_wildcard_query, is_group_by_query); + sort_fields_std, is_wildcard_query, is_vector_query, is_group_by_query); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1916,7 +1939,7 @@ Option Collection::search(std::string raw_query, } if(!vector_query.field_name.empty() && query == "*") { - wrapper_doc["vector_distance"] = Index::int64_t_to_float(-field_order_kv->scores[0]); + wrapper_doc["vector_distance"] = field_order_kv->vector_distance; } hits_array.push_back(wrapper_doc); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index eb2fcc71..93d905e5 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -669,6 +669,7 @@ Option CollectionManager::do_search(std::map& re const char *MAX_FACET_VALUES = "max_facet_values"; const char *VECTOR_QUERY = "vector_query"; + const char *VECTOR_QUERY_HITS = "vector_query_hits"; const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; @@ -800,6 +801,7 @@ Option CollectionManager::do_search(std::map& re size_t max_extra_suffix = INT16_MAX; bool enable_highlight_v1 = true; text_match_type_t match_type = max_score; + size_t vector_query_hits = 250; size_t facet_sample_percent = 100; size_t facet_sample_threshold = 0; @@ -826,6 +828,7 @@ Option CollectionManager::do_search(std::map& re {FILTER_CURATED_HITS, &filter_curated_hits_option}, {FACET_SAMPLE_PERCENT, &facet_sample_percent}, {FACET_SAMPLE_THRESHOLD, &facet_sample_threshold}, + {VECTOR_QUERY_HITS, &vector_query_hits}, }; std::unordered_map str_values = { diff --git a/src/index.cpp b/src/index.cpp index ce24b8e1..65f8f7fc 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -42,6 +42,7 @@ spp::sparse_hash_map Index::seq_id_sentinel_value; spp::sparse_hash_map Index::eval_sentinel_value; spp::sparse_hash_map Index::geo_sentinel_value; spp::sparse_hash_map Index::str_sentinel_value; +spp::sparse_hash_map Index::vector_distance_sentinel_value; struct token_posting_t { uint32_t token_id; @@ -2854,7 +2855,9 @@ Option Index::search(std::vector& field_query_tokens, cons collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { - auto k = std::max(vector_query.k, fetch_size); + // use k as 250 by default for ensuring results stability in pagination + size_t default_k = 250; + auto k = std::max(vector_query.k, default_k); if(vector_query.query_doc_given) { // since we will omit the query doc from results k++; @@ -2925,12 +2928,12 @@ Option Index::search(std::vector& field_query_tokens, cons } int64_t scores[3] = {0}; - scores[0] = -float_to_int64_t(vec_dist_score); int64_t match_score_index = -1; - //LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first; + compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0, 0, scores, match_score_index, vec_dist_score); KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, nullptr); + kv.vector_distance = vec_dist_score; int ret = topster->add(&kv); if(group_limit != 0 && ret < 2) { @@ -3142,7 +3145,9 @@ Option Index::search(std::vector& field_query_tokens, cons VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; - auto k = std::max(vector_query.k, fetch_size); + // use k as 250 by default for ensuring results stability in pagination + size_t default_k = 250; + auto k = std::max(vector_query.k, default_k); if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); @@ -3177,8 +3182,7 @@ Option Index::search(std::vector& field_query_tokens, cons continue; } // (1 / rank_of_document) * WEIGHT) - result->text_match_score = result->scores[result->match_score_index]; - LOG(INFO) << "SEQ_ID: " << result->key << ", score: " << result->text_match_score; + result->text_match_score = result->scores[result->match_score_index]; result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT); } @@ -3193,11 +3197,23 @@ Option Index::search(std::vector& field_query_tokens, cons // old_score + (1 / rank_of_document) * WEIGHT) result->vector_distance = vec_result.second; result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT)); + + for(size_t i = 0;i < 3; i++) { + if(field_values[i] == &vector_distance_sentinel_value) { + result->scores[i] = float_to_int64_t(vec_result.second); + } + + if(sort_order[i] == -1) { + result->scores[i] = -result->scores[i]; + } + } + } else { int64_t scores[3] = {0}; // (1 / rank_of_document) * WEIGHT) - scores[0] = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT); - int64_t match_score_index = 0; + int64_t match_score = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT); + int64_t match_score_index = -1; + compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second); KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores); kv.vector_distance = vec_result.second; topster->add(&kv); @@ -4164,7 +4180,7 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i std::array*, 3> field_values, const std::vector& geopoint_indices, uint32_t seq_id, size_t filter_index, int64_t max_field_match_score, - int64_t* scores, int64_t& match_score_index) const { + int64_t* scores, int64_t& match_score_index, float vector_distance) const { int64_t geopoint_distances[3]; @@ -4259,6 +4275,8 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } scores[0] = int64_t(found); + } else if(field_values[0] == &vector_distance_sentinel_value) { + scores[0] = float_to_int64_t(vector_distance); } else { auto it = field_values[0]->find(seq_id); scores[0] = (it == field_values[0]->end()) ? default_score : it->second; @@ -4315,6 +4333,8 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } scores[1] = int64_t(found); + } else if(field_values[1] == &vector_distance_sentinel_value) { + scores[1] = float_to_int64_t(vector_distance); } else { auto it = field_values[1]->find(seq_id); scores[1] = (it == field_values[1]->end()) ? default_score : it->second; @@ -4367,6 +4387,8 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } scores[2] = int64_t(found); + } else if(field_values[2] == &vector_distance_sentinel_value) { + scores[2] = float_to_int64_t(vector_distance); } else { auto it = field_values[2]->find(seq_id); scores[2] = (it == field_values[2]->end()) ? default_score : it->second; @@ -5085,6 +5107,8 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint sort_fields_std[i].eval.ids = result.docs; sort_fields_std[i].eval.size = result.count; result.docs = nullptr; + } else if(sort_fields_std[i].name == sort_field_const::vector_distance) { + field_values[i] = &vector_distance_sentinel_value; } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 7598a589..95c255e9 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -2246,3 +2246,145 @@ TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) { ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); } } + + +TEST_F(CollectionSortingTest, AscendingVectorDistance) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + std::vector> points = { + {3.0, 4.0}, + {9.0, 21.0}, + {8.0, 15.0}, + {1.0, 1.0}, + {5.0, 7.0} + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector sort_fields = { + sort_by("_vector_distance", "asc"), + }; + + auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "points:([8.0, 15.0])").get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector expected_ids = {"2", "1", "4", "0", "3"}; + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); + } +} + + +TEST_F(CollectionSortingTest, DescendingVectorDistance) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + std::vector> points = { + {3.0, 4.0}, + {9.0, 21.0}, + {8.0, 15.0}, + {1.0, 1.0}, + {5.0, 7.0} + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector sort_fields = { + sort_by("_vector_distance", "DESC"), + }; + + auto results = coll1->search("*", {}, "", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback, + 4, {off}, 32767, 32767, 2, + false, true, "points:([8.0, 15.0])").get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector expected_ids = {"3", "0", "4", "1", "2"}; + + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get()); + } +} + + +TEST_F(CollectionSortingTest, InvalidVectorDistanceSorting) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "points", "type": "float[]", "num_dim": 2} + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + std::vector> points = { + {1.0, 1.0}, + {2.0, 2.0}, + {3.0, 3.0}, + {4.0, 4.0}, + {5.0, 5.0}, + }; + + for(size_t i = 0; i < points.size(); i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = points[i]; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector sort_fields = { + sort_by("_vector_distance", "desc"), + }; + + + + auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + + ASSERT_FALSE(results.ok()); + + ASSERT_EQ("sort_by vector_distance is only supported for vector queries, semantic search and hybrid search.", results.error()); +} \ No newline at end of file From 3a8a3997835a4d8f094515c82571ddb3172be346 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 27 Jun 2023 14:12:19 +0530 Subject: [PATCH 2/5] Don't send proxy logs through raft log. --- src/http_server.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/http_server.cpp b/src/http_server.cpp index f72de8da..813ee450 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -547,7 +547,9 @@ bool HttpServer::is_write_request(const std::string& root_resource, const std::s return false; } - bool write_free_request = (root_resource == "multi_search" || root_resource == "operations"); + bool write_free_request = (root_resource == "multi_search" || root_resource == "proxy" || + root_resource == "operations"); + if(!write_free_request && (http_method == "POST" || http_method == "PUT" || http_method == "DELETE" || http_method == "PATCH")) { From 425611fe0169fc36dc5852e697a4cbbcd7e07c2c Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 27 Jun 2023 14:21:36 +0530 Subject: [PATCH 3/5] Improve logging message. --- src/text_embedder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/text_embedder.cpp b/src/text_embedder.cpp index ae2297a2..b2a67607 100644 --- a/src/text_embedder.cpp +++ b/src/text_embedder.cpp @@ -46,7 +46,7 @@ TextEmbedder::TextEmbedder(const std::string& model_name) { TextEmbedder::TextEmbedder(const nlohmann::json& model_config) { auto model_name = model_config["model_name"].get(); - LOG(INFO) << "Loading model from remote: " << model_name; + LOG(INFO) << "Initializing remote embedding model: " << model_name; auto model_namespace = TextEmbedderManager::get_model_namespace(model_name); if(model_namespace == "openai") { From c0686a59365a225751a164cce4bb49ca29943f9f Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Wed, 28 Jun 2023 13:21:08 +0530 Subject: [PATCH 4/5] Fixed bug when only vector search produced results in hybrid. The vector search produced IDs must be merged back to all_result_ids in addition to incrementing the all_result_ids_len. --- src/index.cpp | 34 ++++++++++++++++++++------ test/collection_vector_search_test.cpp | 34 ++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 65f8f7fc..285b2af8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3186,17 +3186,26 @@ Option Index::search(std::vector& field_query_tokens, cons result->scores[result->match_score_index] = float_to_int64_t((1.0 / (i + 1)) * TEXT_MATCH_WEIGHT); } - for(int i = 0; i < vec_results.size(); i++) { - auto& vec_result = vec_results[i]; - auto doc_id = vec_result.first; + std::vector vec_search_ids; // list of IDs found only in vector search + for(size_t res_index = 0; res_index < vec_results.size(); res_index++) { + auto& vec_result = vec_results[res_index]; + auto doc_id = vec_result.first; auto result_it = topster->kv_map.find(doc_id); - if(result_it != topster->kv_map.end()&& result_it->second->match_score_index >= 0 && result_it->second->match_score_index <= 2) { + if(result_it != topster->kv_map.end()) { + if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) { + continue; + } + + // result overlaps with keyword search: we have to combine the scores + auto result = result_it->second; // old_score + (1 / rank_of_document) * WEIGHT) result->vector_distance = vec_result.second; - result->scores[result->match_score_index] = float_to_int64_t((int64_t_to_float(result->scores[result->match_score_index])) + ((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT)); + result->scores[result->match_score_index] = float_to_int64_t( + (int64_t_to_float(result->scores[result->match_score_index])) + + ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT)); for(size_t i = 0;i < 3; i++) { if(field_values[i] == &vector_distance_sentinel_value) { @@ -3209,17 +3218,26 @@ Option Index::search(std::vector& field_query_tokens, cons } } else { - int64_t scores[3] = {0}; + // Result has been found only in vector search: we have to add it to both KV and result_ids // (1 / rank_of_document) * WEIGHT) - int64_t match_score = float_to_int64_t((1.0 / (i + 1)) * VECTOR_SEARCH_WEIGHT); + int64_t scores[3] = {0}; + int64_t match_score = float_to_int64_t((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT); int64_t match_score_index = -1; compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second); KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores); kv.vector_distance = vec_result.second; topster->add(&kv); - ++all_result_ids_len; + vec_search_ids.push_back(doc_id); } } + + if(!vec_search_ids.empty()) { + uint32_t* new_all_result_ids = nullptr; + all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, &vec_search_ids[0], + vec_search_ids.size(), &new_all_result_ids); + delete[] all_result_ids; + all_result_ids = new_all_result_ids; + } } } diff --git a/test/collection_vector_search_test.cpp b/test/collection_vector_search_test.cpp index 995bddd7..b6b598dc 100644 --- a/test/collection_vector_search_test.cpp +++ b/test/collection_vector_search_test.cpp @@ -711,12 +711,42 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) { 4, {off}, 32767, 32767, 2, false, true, "vec:(" + dummy_vec_string +")"); ASSERT_EQ(true, results_op.ok()); - - ASSERT_EQ(1, results_op.get()["found"].get()); ASSERT_EQ(1, results_op.get()["hits"].size()); } +TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) { + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "name", "type": "string", "facet": true}, + {"name": "vec", "type": "float[]", "embed":{"from": ["name"], "model_config": {"model_name": "ts/e5-small"}}} + ] + })"_json; + + TextEmbedderManager::set_model_dir("/tmp/typesense_test/models"); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + nlohmann::json doc; + doc["name"] = "john doe"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results_op = coll1->search("zzz", {"name", "vec"}, "", {"name"}, {}, {0}, 20, 1, FREQUENCY, {true}, + Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, {}, {}, {}, 0, + "", "", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, + fallback, + 4, {off}, 32767, 32767, 2); + ASSERT_EQ(true, results_op.ok()); + ASSERT_EQ(1, results_op.get()["found"].get()); + ASSERT_EQ(1, results_op.get()["hits"].size()); + ASSERT_EQ(1, results_op.get()["facet_counts"].size()); + ASSERT_EQ(4, results_op.get()["facet_counts"][0].size()); + ASSERT_EQ("name", results_op.get()["facet_counts"][0]["field_name"]); +} + TEST_F(CollectionVectorTest, DistanceThresholdTest) { nlohmann::json schema = R"({ "name": "test", From a7d549d0aeea26310e1a02889bfaaf9539e5a085 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Wed, 28 Jun 2023 13:24:17 +0530 Subject: [PATCH 5/5] Tweak default values for k in vector search. --- src/index.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 285b2af8..5cfb2287 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2855,9 +2855,7 @@ Option Index::search(std::vector& field_query_tokens, cons collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { - // use k as 250 by default for ensuring results stability in pagination - size_t default_k = 250; - auto k = std::max(vector_query.k, default_k); + auto k = std::max(vector_query.k, fetch_size); if(vector_query.query_doc_given) { // since we will omit the query doc from results k++; @@ -3145,8 +3143,8 @@ Option Index::search(std::vector& field_query_tokens, cons VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; - // use k as 250 by default for ensuring results stability in pagination - size_t default_k = 250; + // use k as 100 by default for ensuring results stability in pagination + size_t default_k = 100; auto k = std::max(vector_query.k, default_k); if(field_vector_index->distance_type == cosine) {