Don't return vector_distance for keyword-only matches in hybrid search

This commit is contained in:
ozanarmagan 2023-09-24 18:52:43 +03:00
parent e54f680b22
commit d42590c638
3 changed files with 90 additions and 18 deletions

View File

@ -16,7 +16,7 @@ struct KV {
int64_t scores[3]{}; // match score + 2 custom attributes
// only to be used in hybrid search
float vector_distance = 2.0f;
float vector_distance = -1.0f;
int64_t text_match_score = 0;
reference_filter_result_t* reference_filter_result = nullptr;
@ -44,7 +44,7 @@ struct KV {
KV(KV&& kv) noexcept : match_score_index(kv.match_score_index),
query_index(kv.query_index), array_index(kv.array_index),
key(kv.key), distinct_key(kv.distinct_key) {
scores[0] = kv.scores[0];
scores[1] = kv.scores[1];
scores[2] = kv.scores[2];

View File

@ -3205,21 +3205,34 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
for(size_t res_index = 0; res_index < vec_results.size(); res_index++) {
auto& vec_result = vec_results[res_index];
auto seq_id = vec_result.first;
auto result_it = topster->kv_map.find(seq_id);
if(result_it != topster->kv_map.end()) {
if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) {
KV* found_kv = nullptr;
if(group_limit == 0) {
auto result_it = topster->kv_map.find(seq_id);
if(result_it != topster->kv_map.end()) {
found_kv = result_it->second;
}
} else {
auto g_topster_it = topster->group_kv_map.find(get_distinct_id(group_by_fields, seq_id));
if(g_topster_it != topster->group_kv_map.end()) {
auto g_topster = g_topster_it->second;
auto result_it = g_topster->kv_map.find(seq_id);
if(result_it != g_topster->kv_map.end()) {
found_kv = result_it->second;
}
}
}
if(found_kv) {
if(found_kv->match_score_index < 0 || found_kv->match_score_index > 2) {
continue;
}
// result overlaps with keyword search: we have to combine the scores
KV* kv = result_it->second;
// old_score + (1 / rank_of_document) * WEIGHT)
kv->vector_distance = vec_result.second;
kv->text_match_score = kv->scores[kv->match_score_index];
found_kv->vector_distance = vec_result.second;
found_kv->text_match_score = found_kv->scores[found_kv->match_score_index];
int64_t match_score = float_to_int64_t(
(int64_t_to_float(kv->scores[kv->match_score_index])) +
(int64_t_to_float(found_kv->scores[found_kv->match_score_index])) +
((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
int64_t match_score_index = -1;
int64_t scores[3] = {0};
@ -3228,9 +3241,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
match_score, scores, match_score_index, vec_result.second);
for(int i = 0; i < 3; i++) {
kv->scores[i] = scores[i];
found_kv->scores[i] = scores[i];
}
kv->match_score_index = match_score_index;
found_kv->match_score_index = match_score_index;
} else {
// Result has been found only in vector search: we have to add it to both KV and result_ids
@ -3251,8 +3264,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
kv.text_match_score = 0;
kv.vector_distance = vec_result.second;
topster->add(&kv);
auto ret = topster->add(&kv);
vec_search_ids.push_back(seq_id);
if(group_limit != 0 && ret < 2) {
groups_processed[distinct_id]++;
}
}
}

View File

@ -1154,7 +1154,7 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
ASSERT_FLOAT_EQ((1.0/3.0 * 0.7) + (1.0/2.0 * 0.3), search_res["hits"][2]["hybrid_search_info"]["rank_fusion_score"].get<float>());
// hybrid search with empty vector (to pass distance threshold param)
std::string vec_query = "embedding:([], distance_threshold: 0.20)";
std::string vec_query = "embedding:([], distance_threshold: 0.13)";
search_res_op = coll->search("butter", {"embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
@ -1217,9 +1217,9 @@ TEST_F(CollectionVectorTest, HybridSearchWithExplicitVector) {
ASSERT_EQ(3, search_res["found"].get<size_t>());
ASSERT_EQ(3, search_res["hits"].size());
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][0]["vector_distance"].get<float>());
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][1]["vector_distance"].get<float>());
ASSERT_FLOAT_EQ(2.0f, search_res["hits"][2]["vector_distance"].get<float>());
ASSERT_TRUE(search_res["hits"][0].count("vector_distance") == 0);
ASSERT_TRUE(search_res["hits"][1].count("vector_distance") == 0);
ASSERT_TRUE(search_res["hits"][2].count("vector_distance") == 0);
}
TEST_F(CollectionVectorTest, HybridSearchOnlyVectorMatches) {
@ -1846,7 +1846,7 @@ TEST_F(CollectionVectorTest, GroupByWithVectorSearch) {
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
4, {off}, 32767, 32767, 2,
false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
ASSERT_EQ(1, res["grouped_hits"].size());
ASSERT_EQ(1, res["grouped_hits"][0]["hits"].size());
ASSERT_EQ(1, res["grouped_hits"][0]["hits"][0].count("vector_distance"));
@ -2122,4 +2122,59 @@ TEST_F(CollectionVectorTest, TestOneEmbeddingOneKeywordFieldsHaveSamePrefix) {
0, spp::sparse_hash_set<std::string>());
ASSERT_TRUE(keyword_results.ok());
}
// Checks that a hybrid search hit is returned without a "vector_distance"
// field. The test name indicates the hit is expected to come from the keyword
// side of the hybrid search only.
TEST_F(CollectionVectorTest, HybridSearchOnlyKeyworMatchDoNotHaveVectorDistance) {
    // One plain string field plus an embedding field generated from it.
    nlohmann::json collection_schema = R"({
"name": "test",
"fields": [
{
"name": "title",
"type": "string"
},
{
"name": "embedding",
"type": "float[]",
"embed": {
"from": [
"title"
],
"model_config": {
"model_name": "ts/e5-small"
}
}
}
]
})"_json;

    TextEmbedderManager::set_model_dir("/tmp/typesense_test/models");

    auto create_op = collectionManager.create_collection(collection_schema);
    ASSERT_TRUE(create_op.ok());
    auto coll = create_op.get();

    // Index a single document so the search has exactly one candidate.
    auto insert_op = coll->add(R"({
"title": "john doe"
})"_json.dump());
    ASSERT_TRUE(insert_op.ok());

    // An empty vector is supplied only so that the distance_threshold
    // parameter can be passed along with the hybrid query.
    // NOTE(review): presumably the tight 0.05 threshold keeps the document
    // from qualifying as a vector match -- confirm against the model output.
    std::string vector_query_str = "embedding:([], distance_threshold: 0.05)";

    auto search_op = coll->search("john", {"title", "embedding"}, "", {}, {}, {0}, 20, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
                                  spp::sparse_hash_set<std::string>(),
                                  spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
                                  "", 10, {}, {}, {}, 0,
                                  "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7,
                                  fallback,
                                  4, {off}, 32767, 32767, 2,
                                  false, true, vector_query_str);
    ASSERT_TRUE(search_op.ok());

    // Exactly one hit is expected, and it must not expose a vector distance.
    const auto search_json = search_op.get();
    ASSERT_EQ(1, search_json["hits"].size());
    ASSERT_EQ(0, search_json["hits"][0].count("vector_distance"));
}