Fix group by not being applied on vector search.

Kishore Nallan 2023-08-26 17:44:27 +05:30
parent 6087ff30d4
commit 6cbd4306e0
2 changed files with 74 additions and 12 deletions


@@ -3204,8 +3204,8 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
         for(size_t res_index = 0; res_index < vec_results.size(); res_index++) {
             auto& vec_result = vec_results[res_index];
-            auto doc_id = vec_result.first;
-            auto result_it = topster->kv_map.find(doc_id);
+            auto seq_id = vec_result.first;
+            auto result_it = topster->kv_map.find(seq_id);

             if(result_it != topster->kv_map.end()) {
                 if(result_it->second->match_score_index < 0 || result_it->second->match_score_index > 2) {
@@ -3214,22 +3214,23 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 // result overlaps with keyword search: we have to combine the scores
-                auto result = result_it->second;
+                KV* kv = result_it->second;
                 // old_score + (1 / rank_of_document) * WEIGHT)
-                result->vector_distance = vec_result.second;
-                result->text_match_score = result->scores[result->match_score_index];
+                kv->vector_distance = vec_result.second;
+                kv->text_match_score = kv->scores[kv->match_score_index];

                 int64_t match_score = float_to_int64_t(
-                    (int64_t_to_float(result->scores[result->match_score_index])) +
+                    (int64_t_to_float(kv->scores[kv->match_score_index])) +
                     ((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT));
                 int64_t match_score_index = -1;
                 int64_t scores[3] = {0};

-                compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
+                compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0,
+                                    match_score, scores, match_score_index, vec_result.second);

                 for(int i = 0; i < 3; i++) {
-                    result->scores[i] = scores[i];
+                    kv->scores[i] = scores[i];
                 }

-                result->match_score_index = match_score_index;
+                kv->match_score_index = match_score_index;
             } else {
                 // Result has been found only in vector search: we have to add it to both KV and result_ids
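
The first branch of the hunk above (document found by both keyword and vector search) keeps the existing reciprocal-rank fusion: the keyword score gets a 1/rank contribution from the vector result, as the `old_score + (1 / rank_of_document) * WEIGHT` comment describes; only the variable names change. A minimal standalone sketch of that formula, using plain floats instead of the float_to_int64_t/int64_t_to_float score packing used in the real code, and a hypothetical value for VECTOR_SEARCH_WEIGHT:

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical weight for the vector contribution; the real constant is
// defined elsewhere in the Typesense sources.
static constexpr float VECTOR_SEARCH_WEIGHT = 0.7f;

int main() {
    // Keyword match scores for three documents that also appear in the
    // vector result list; their vector rank is res_index + 1.
    std::vector<float> keyword_scores = {0.91f, 0.85f, 0.40f};

    for(std::size_t res_index = 0; res_index < keyword_scores.size(); res_index++) {
        // old_score + (1 / rank_of_document) * WEIGHT
        float fused = keyword_scores[res_index] +
                      (1.0f / (res_index + 1)) * VECTOR_SEARCH_WEIGHT;
        std::printf("rank %zu: keyword %.2f -> fused %.2f\n",
                    res_index + 1, keyword_scores[res_index], fused);
    }
    return 0;
}
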
@@ -3237,12 +3238,21 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 int64_t scores[3] = {0};
                 int64_t match_score = float_to_int64_t((1.0 / (res_index + 1)) * VECTOR_SEARCH_WEIGHT);
                 int64_t match_score_index = -1;

-                compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, doc_id, 0, match_score, scores, match_score_index, vec_result.second);
-                KV kv(searched_queries.size(), doc_id, doc_id, match_score_index, scores);
+                compute_sort_scores(sort_fields_std, sort_order, field_values, geopoint_indices, seq_id, 0, match_score, scores, match_score_index, vec_result.second);
+
+                uint64_t distinct_id = seq_id;
+                if (group_limit != 0) {
+                    distinct_id = get_distinct_id(group_by_fields, seq_id);
+                    if(excluded_group_ids.count(distinct_id) != 0) {
+                        continue;
+                    }
+                }
+
+                KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
                 kv.text_match_score = 0;
                 kv.vector_distance = vec_result.second;
                 topster->add(&kv);
-                vec_search_ids.push_back(doc_id);
+                vec_search_ids.push_back(seq_id);
             }
         }
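
The actual fix is in the vector-only branch of this last hunk: a hit found only by vector search used to be inserted with its sequence id as its distinct id, so group_by never collapsed such hits into their group; now a distinct id is derived from the group_by fields and hits from already-excluded groups are skipped. A rough sketch of that grouping step, with a simplified stand-in for get_distinct_id (here just hashing a single group value rather than the document's group_by fields) and a toy in-memory stand-in for the topster:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Simplified stand-in for Index::get_distinct_id(): the real code derives the
// id from the document's group_by field values; here we hash one group string.
uint64_t group_key(const std::string& group_value) {
    return std::hash<std::string>{}(group_value);
}

int main() {
    // seq_id -> group field value for three hits that came only from vector search.
    std::unordered_map<uint32_t, std::string> group_field = {
        {0, "0"}, {1, "0"}, {2, "1"}
    };
    std::vector<uint32_t> vec_hits = {0, 1, 2};          // already ordered by vector rank

    std::size_t group_limit = 1;                         // grouping is enabled
    std::unordered_set<uint64_t> excluded_group_ids;     // groups ruled out elsewhere

    std::unordered_map<uint64_t, uint32_t> best_per_group;  // toy stand-in for the topster

    for(uint32_t seq_id : vec_hits) {
        uint64_t distinct_id = seq_id;                   // default: each document is its own group
        if(group_limit != 0) {
            distinct_id = group_key(group_field[seq_id]);
            if(excluded_group_ids.count(distinct_id) != 0) {
                continue;                                // group already excluded: drop the hit
            }
        }
        // keep only the first (best ranked) hit per distinct_id
        best_per_group.emplace(distinct_id, seq_id);
    }

    std::printf("groups kept: %zu\n", best_per_group.size());  // prints 2 for the 3 hits
    return 0;
}

In the real code the KV goes into the topster, which enforces group_limit per distinct_id; the sketch only keeps the first hit per group to show where the excluded-group skip happens.
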


@@ -1327,6 +1327,58 @@ TEST_F(CollectionVectorTest, KeywordSearchReturnOnlyTextMatchInfo) {
     ASSERT_EQ(1, results["hits"][0].count("text_match_info"));
 }

+TEST_F(CollectionVectorTest, GroupByWithVectorSearch) {
+    nlohmann::json schema = R"({
+        "name": "coll1",
+        "fields": [
+            {"name": "title", "type": "string"},
+            {"name": "group", "type": "string", "facet": true},
+            {"name": "vec", "type": "float[]", "num_dim": 4}
+        ]
+    })"_json;
+
+    Collection* coll1 = collectionManager.create_collection(schema).get();
+
+    std::vector<std::vector<float>> values = {
+        {0.851758, 0.909671, 0.823431, 0.372063},
+        {0.97826, 0.933157, 0.39557, 0.306488},
+        {0.230606, 0.634397, 0.514009, 0.399594}
+    };
+
+    for (size_t i = 0; i < values.size(); i++) {
+        nlohmann::json doc;
+        doc["id"] = std::to_string(i);
+        doc["title"] = std::to_string(i) + " title";
+        doc["group"] = "0";
+        doc["vec"] = values[i];
+        ASSERT_TRUE(coll1->add(doc.dump()).ok());
+    }
+
+    auto res = coll1->search("title", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                             spp::sparse_hash_set<std::string>(),
+                             spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                             "", 10, {}, {}, {"group"}, 1,
+                             "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
+                             4, {off}, 32767, 32767, 2,
+                             false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
+
+    ASSERT_EQ(1, res["grouped_hits"].size());
+    ASSERT_EQ(1, res["grouped_hits"][0]["hits"].size());
+    ASSERT_EQ(1, res["grouped_hits"][0]["hits"][0].count("vector_distance"));
+
+    res = coll1->search("*", {"title"}, "", {}, {}, {0}, 10, 1, FREQUENCY, {true}, Index::DROP_TOKENS_THRESHOLD,
+                        spp::sparse_hash_set<std::string>(),
+                        spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
+                        "", 10, {}, {}, {"group"}, 1,
+                        "<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
+                        4, {off}, 32767, 32767, 2,
+                        false, true, "vec:([0.96826, 0.94, 0.39557, 0.306488])").get();
+
+    ASSERT_EQ(1, res["grouped_hits"].size());
+    ASSERT_EQ(1, res["grouped_hits"][0]["hits"].size());
+    ASSERT_EQ(1, res["grouped_hits"][0]["hits"][0].count("vector_distance"));
+}
+
 TEST_F(CollectionVectorTest, HybridSearchReturnAllInfo) {
     auto schema_json =
             R"({