hybrid search flat_search_cutoff (#2099)
Some checks failed
tests / test (push) Has been cancelled

* hybrid search flat_search_cutoff

* refactor repeatative code

* make both approach exclusive
This commit is contained in:
Krunal Gandhi 2024-12-10 19:03:20 +05:30 committed by GitHub
parent 225ed057cd
commit 9923a5455e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2870,6 +2870,108 @@ Option<bool> Index::search_infix(const std::string& query, const std::string& fi
return Option<bool>(true);
}
void process_results_bruteforce(filter_result_iterator_t* filter_result_iterator, const vector_query_t& vector_query,
hnsw_index_t* field_vector_index, std::vector<std::pair<float, single_filter_result_t>>& dist_results) {
while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
auto seq_id = filter_result_iterator->seq_id;
auto filter_result = single_filter_result_t(seq_id, std::move(filter_result_iterator->reference));
filter_result_iterator->next();
std::vector<float> values;
try {
values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
} catch (...) {
// likely not found
continue;
}
float dist;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(),
&field_vector_index->num_dim);
} else {
dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(),
&field_vector_index->num_dim);
}
dist_results.emplace_back(dist, filter_result);
}
}
void process_results_hnsw_index(filter_result_iterator_t* filter_result_iterator, const vector_query_t& vector_query,
hnsw_index_t* field_vector_index, VectorFilterFunctor& filterFunctor, size_t k,
std::vector<std::pair<float, single_filter_result_t>>& dist_results, bool is_wildcard_non_phrase_query = false) {
std::vector<std::pair<float, size_t>> pairs;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
pairs = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
}
std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
return x.second < y.second;
});
filter_result_iterator->reset();
if (!filter_result_iterator->reference.empty() && is_wildcard_non_phrase_query) {
// We'll have to get the references of each document.
for (auto pair: pairs) {
if (filter_result_iterator->validity == filter_result_iterator_t::timed_out) {
// Overriding timeout since we need to get the references of matched docs.
filter_result_iterator->reset(true);
search_cutoff = true;
}
auto const& seq_id = pair.second;
if (filter_result_iterator->is_valid(seq_id, search_cutoff) != 1) {
continue;
}
// The seq_id must be valid otherwise it would've been filtered out upstream.
auto filter_result = single_filter_result_t(seq_id,
std::move(filter_result_iterator->reference));
dist_results.emplace_back(pair.first, filter_result);
}
} else {
search_cutoff = search_cutoff || filter_result_iterator->validity ==
filter_result_iterator_t::timed_out;
if(!is_wildcard_non_phrase_query) {
std::vector<std::pair<float, size_t>> vec_results;
for(const auto& pair: pairs) {
auto vec_dist_score = (field_vector_index->distance_type == cosine)
? std::abs(pair.first) :
pair.first;
if (vec_dist_score > vector_query.distance_threshold) {
continue;
}
vec_results.push_back(pair);
}
// iteration needs to happen on sorted sequence ID but score wise sort needed for compute rank fusion
std::sort(vec_results.begin(), vec_results.end(),
[](const auto &a, const auto &b) {
return a.first < b.first;
});
pairs = std::move(vec_results);
}
for (const auto &pair: pairs) {
auto filter_result = single_filter_result_t(pair.second, {});
dist_results.emplace_back(pair.first, filter_result);
}
}
}
Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
filter_node_t*& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
@ -3101,82 +3203,17 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
filter_result_iterator->compute_iterators();
uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
auto seq_id = filter_result_iterator->seq_id;
auto filter_result = single_filter_result_t(seq_id, std::move(filter_result_iterator->reference));
filter_result_iterator->next();
std::vector<float> values;
try {
values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
} catch (...) {
// likely not found
continue;
}
float dist;
if (field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(),
&field_vector_index->num_dim);
} else {
dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(),
&field_vector_index->num_dim);
}
dist_results.emplace_back(dist, filter_result);
}
}
filter_result_iterator->reset();
search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;
if(!filter_by_provided ||
process_results_bruteforce(filter_result_iterator, vector_query, field_vector_index, dist_results);
} else if(!filter_by_provided ||
(filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->validity == filter_result_iterator_t::valid)) {
dist_results.clear();
std::vector<std::pair<float, size_t>> pairs;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
pairs = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
}
std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
return x.second < y.second;
});
filter_result_iterator->reset();
if (!filter_result_iterator->reference.empty()) {
// We'll have to get the references of each document.
for (auto pair: pairs) {
if (filter_result_iterator->validity == filter_result_iterator_t::timed_out) {
// Overriding timeout since we need to get the references of matched docs.
filter_result_iterator->reset(true);
search_cutoff = true;
}
auto const& seq_id = pair.second;
if (filter_result_iterator->is_valid(seq_id, search_cutoff) != 1) {
continue;
}
// The seq_id must be valid otherwise it would've been filtered out upstream.
auto filter_result = single_filter_result_t(seq_id,
std::move(filter_result_iterator->reference));
dist_results.emplace_back(pair.first, filter_result);
}
} else {
for (const auto &pair: pairs) {
auto filter_result = single_filter_result_t(pair.second, {});
dist_results.emplace_back(pair.first, filter_result);
}
}
process_results_hnsw_index(filter_result_iterator, vector_query, field_vector_index, filterFunctor, k, dist_results, true);
}
search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;
std::vector<uint32_t> nearest_ids;
std::vector<uint32_t> eval_filter_indexes;
@ -3542,46 +3579,32 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
VectorFilterFunctor filterFunctor(filter_result_iterator, excluded_result_ids, excluded_result_ids_size);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
// use k as 100 by default for ensuring results stability in pagination
size_t default_k = 100;
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k) : vector_query.k;
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
std::vector<std::pair<float, single_filter_result_t>> dist_results;
if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
process_results_bruteforce(filter_result_iterator, vector_query, field_vector_index, dist_results);
} else if (!filter_by_provided || (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->validity == filter_result_iterator_t::valid)) {
dist_results.clear();
// use k as 100 by default for ensuring results stability in pagination
size_t default_k = 100;
auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k)
: vector_query.k;
process_results_hnsw_index(filter_result_iterator, vector_query, field_vector_index, filterFunctor, k, dist_results);
}
filter_result_iterator->reset();
search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;
std::vector<std::pair<uint32_t,float>> vec_results;
for (const auto& dist_label : dist_labels) {
uint32_t seq_id = dist_label.second;
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
dist_label.first;
if(vec_dist_score > vector_query.distance_threshold) {
continue;
}
vec_results.emplace_back(seq_id, vec_dist_score);
}
// iteration needs to happen on sorted sequence ID but score wise sort needed for compute rank fusion
std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::unordered_map<uint32_t, uint32_t> seq_id_to_rank;
for(size_t vec_index = 0; vec_index < vec_results.size(); vec_index++) {
seq_id_to_rank.emplace(vec_results[vec_index].first, vec_index);
for (size_t vec_index = 0; vec_index < dist_results.size(); vec_index++) {
seq_id_to_rank.emplace(dist_results[vec_index].second.seq_id, vec_index);
}
std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
return a.first < b.first;
});
std::sort(dist_results.begin(), dist_results.end(),
[](const auto &a, const auto &b) {
return a.second.seq_id < b.second.seq_id;
});
std::vector<KV*> kvs;
if(group_limit != 0) {
@ -3618,10 +3641,10 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
}
for(size_t res_index = 0; res_index < vec_results.size() &&
for(size_t res_index = 0; res_index < dist_results.size() &&
filter_result_iterator->validity != filter_result_iterator_t::timed_out; res_index++) {
auto& vec_result = vec_results[res_index];
auto seq_id = vec_result.first;
auto& dist_result = dist_results[res_index];
auto seq_id = dist_result.second.seq_id;
if (filter_by_provided && filter_result_iterator->is_valid(seq_id) != 1) {
continue;
@ -3651,7 +3674,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
// result overlaps with keyword search: we have to combine the scores
// old_score + (1 / rank_of_document) * WEIGHT)
found_kv->vector_distance = vec_result.second;
found_kv->vector_distance = dist_result.first;
int64_t match_score = float_to_int64_t(
(int64_t_to_float(found_kv->scores[found_kv->match_score_index])) +
((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT));
@ -3662,7 +3685,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
dist_result.first, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}
@ -3688,7 +3711,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
dist_result.first, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}
@ -3711,7 +3734,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, std::move(references));
kv.text_match_score = 0;
kv.vector_distance = vec_result.second;
kv.vector_distance = dist_result.first;
auto ret = topster->add(&kv);
vec_search_ids.push_back(seq_id);