Mirror of https://github.com/typesense/typesense.git, synced 2025-05-15 10:42:29 +08:00
hybrid search flat_search_cutoff (#2099)
* hybrid search flat_search_cutoff
* refactor repetitive code
* make both approaches exclusive
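In short, the commit gates hybrid vector search on flat_search_cutoff: when a filter is provided and the approximate count of filtered IDs is below the cutoff, distances are computed by brute force over just those IDs (process_results_bruteforce); otherwise the query goes to the HNSW index (process_results_hnsw_index), and the two paths are now mutually exclusive. The sketch below illustrates only that decision as a simplified, self-contained C++17 example; every name in it (VectorQuery, l2_distance, brute_force_scan, ann_query, vector_search) is a hypothetical stand-in, not part of the Typesense codebase.

// Minimal sketch of the flat_search_cutoff decision. All types and functions
// here are illustrative stand-ins; only the control flow follows the commit:
// exact scan below the cutoff, ANN index otherwise, never both.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct VectorQuery {
    std::vector<float> values;
    size_t flat_search_cutoff = 1000;  // filtered-ID count below which we scan exactly
    size_t k = 10;
};

static float l2_distance(const std::vector<float>& a, const std::vector<float>& b) {
    float sum = 0.0f;
    for (size_t i = 0; i < a.size(); i++) {
        const float d = a[i] - b[i];
        sum += d * d;
    }
    return sum;
}

// Exact (flat) scan over only the documents that survived the filter.
static std::vector<std::pair<float, uint32_t>> brute_force_scan(
        const VectorQuery& q, const std::vector<uint32_t>& filtered_ids,
        const std::vector<std::vector<float>>& vectors) {
    std::vector<std::pair<float, uint32_t>> results;
    for (const uint32_t seq_id : filtered_ids) {
        results.emplace_back(l2_distance(q.values, vectors[seq_id]), seq_id);
    }
    return results;
}

// Stand-in for the approximate path; the real code queries hnswlib via
// searchKnnCloserFirst() with a filter functor.
static std::vector<std::pair<float, uint32_t>> ann_query(
        const VectorQuery& q, const std::vector<std::vector<float>>& vectors) {
    std::vector<std::pair<float, uint32_t>> results;
    for (uint32_t seq_id = 0; seq_id < vectors.size() && results.size() < q.k; seq_id++) {
        results.emplace_back(l2_distance(q.values, vectors[seq_id]), seq_id);
    }
    return results;
}

// The two strategies are mutually exclusive: a small filtered set is scanned
// exactly, everything else goes through the ANN index.
std::vector<std::pair<float, uint32_t>> vector_search(
        const VectorQuery& q, bool filter_by_provided,
        const std::vector<uint32_t>& filtered_ids,
        const std::vector<std::vector<float>>& vectors) {
    if (filter_by_provided && filtered_ids.size() < q.flat_search_cutoff) {
        return brute_force_scan(q, filtered_ids, vectors);
    }
    return ann_query(q, vectors);
}

int main() {
    const std::vector<std::vector<float>> vectors = {{0.1f, 0.2f}, {0.9f, 0.8f}, {0.4f, 0.4f}};
    const VectorQuery q{{0.1f, 0.25f}, /*flat_search_cutoff=*/100, /*k=*/2};
    // Two filtered IDs and a cutoff of 100, so the exact (flat) path is taken.
    for (const auto& [dist, id] : vector_search(q, true, {0, 2}, vectors)) {
        std::cout << "doc " << id << " dist " << dist << "\n";
    }
}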
parent 225ed057cd
commit 9923a5455e
src/index.cpp (247 changed lines)
@@ -2870,6 +2870,108 @@ Option<bool> Index::search_infix(const std::string& query, const std::string& fi
     return Option<bool>(true);
 }
 
+void process_results_bruteforce(filter_result_iterator_t* filter_result_iterator, const vector_query_t& vector_query,
+                                hnsw_index_t* field_vector_index, std::vector<std::pair<float, single_filter_result_t>>& dist_results) {
+
+    while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
+        auto seq_id = filter_result_iterator->seq_id;
+        auto filter_result = single_filter_result_t(seq_id, std::move(filter_result_iterator->reference));
+        filter_result_iterator->next();
+        std::vector<float> values;
+
+        try {
+            values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
+        } catch (...) {
+            // likely not found
+            continue;
+        }
+
+        float dist;
+        if (field_vector_index->distance_type == cosine) {
+            std::vector<float> normalized_q(vector_query.values.size());
+            hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
+            dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(),
+                                                              &field_vector_index->num_dim);
+        } else {
+            dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(),
+                                                              &field_vector_index->num_dim);
+        }
+
+        dist_results.emplace_back(dist, filter_result);
+    }
+}
+
+void process_results_hnsw_index(filter_result_iterator_t* filter_result_iterator, const vector_query_t& vector_query,
+                                hnsw_index_t* field_vector_index, VectorFilterFunctor& filterFunctor, size_t k,
+                                std::vector<std::pair<float, single_filter_result_t>>& dist_results, bool is_wildcard_non_phrase_query = false) {
+
+    std::vector<std::pair<float, size_t>> pairs;
+    if(field_vector_index->distance_type == cosine) {
+        std::vector<float> normalized_q(vector_query.values.size());
+        hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
+        pairs = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
+    } else {
+        pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
+    }
+
+    std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
+        return x.second < y.second;
+    });
+
+    filter_result_iterator->reset();
+
+    if (!filter_result_iterator->reference.empty() && is_wildcard_non_phrase_query) {
+        // We'll have to get the references of each document.
+        for (auto pair: pairs) {
+            if (filter_result_iterator->validity == filter_result_iterator_t::timed_out) {
+                // Overriding timeout since we need to get the references of matched docs.
+                filter_result_iterator->reset(true);
+                search_cutoff = true;
+            }
+
+            auto const& seq_id = pair.second;
+            if (filter_result_iterator->is_valid(seq_id, search_cutoff) != 1) {
+                continue;
+            }
+            // The seq_id must be valid otherwise it would've been filtered out upstream.
+            auto filter_result = single_filter_result_t(seq_id,
+                                                        std::move(filter_result_iterator->reference));
+            dist_results.emplace_back(pair.first, filter_result);
+        }
+    } else {
+
+        search_cutoff = search_cutoff || filter_result_iterator->validity ==
+                                            filter_result_iterator_t::timed_out;
+
+        if(!is_wildcard_non_phrase_query) {
+            std::vector<std::pair<float, size_t>> vec_results;
+
+            for(const auto& pair: pairs) {
+                auto vec_dist_score = (field_vector_index->distance_type == cosine)
+                                        ? std::abs(pair.first) :
+                                        pair.first;
+                if (vec_dist_score > vector_query.distance_threshold) {
+                    continue;
+                }
+                vec_results.push_back(pair);
+            }
+
+            // iteration needs to happen on sorted sequence ID but score wise sort needed for compute rank fusion
+            std::sort(vec_results.begin(), vec_results.end(),
+                      [](const auto &a, const auto &b) {
+                          return a.first < b.first;
+                      });
+
+            pairs = std::move(vec_results);
+        }
+
+        for (const auto &pair: pairs) {
+            auto filter_result = single_filter_result_t(pair.second, {});
+            dist_results.emplace_back(pair.first, filter_result);
+        }
+    }
+}
+
 Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
                            const text_match_type_t match_type,
                            filter_node_t*& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
@@ -3101,82 +3203,17 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
         filter_result_iterator->compute_iterators();
 
         uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
 
         if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
-            while (filter_result_iterator->validity == filter_result_iterator_t::valid) {
-                auto seq_id = filter_result_iterator->seq_id;
-                auto filter_result = single_filter_result_t(seq_id, std::move(filter_result_iterator->reference));
-                filter_result_iterator->next();
-                std::vector<float> values;
-
-                try {
-                    values = field_vector_index->vecdex->getDataByLabel<float>(seq_id);
-                } catch (...) {
-                    // likely not found
-                    continue;
-                }
-
-                float dist;
-                if (field_vector_index->distance_type == cosine) {
-                    std::vector<float> normalized_q(vector_query.values.size());
-                    hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
-                    dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(),
-                                                                      &field_vector_index->num_dim);
-                } else {
-                    dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(),
-                                                                      &field_vector_index->num_dim);
-                }
-
-                dist_results.emplace_back(dist, filter_result);
-            }
-        }
-        filter_result_iterator->reset();
-        search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;
-
-        if(!filter_by_provided ||
+            process_results_bruteforce(filter_result_iterator, vector_query, field_vector_index, dist_results);
+        } else if(!filter_by_provided ||
             (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->validity == filter_result_iterator_t::valid)) {
             dist_results.clear();
-
-            std::vector<std::pair<float, size_t>> pairs;
-            if(field_vector_index->distance_type == cosine) {
-                std::vector<float> normalized_q(vector_query.values.size());
-                hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
-                pairs = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
-            } else {
-                pairs = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
-            }
-
-            std::sort(pairs.begin(), pairs.end(), [](auto& x, auto& y) {
-                return x.second < y.second;
-            });
-
-            filter_result_iterator->reset();
-
-            if (!filter_result_iterator->reference.empty()) {
-                // We'll have to get the references of each document.
-                for (auto pair: pairs) {
-                    if (filter_result_iterator->validity == filter_result_iterator_t::timed_out) {
-                        // Overriding timeout since we need to get the references of matched docs.
-                        filter_result_iterator->reset(true);
-                        search_cutoff = true;
-                    }
-
-                    auto const& seq_id = pair.second;
-                    if (filter_result_iterator->is_valid(seq_id, search_cutoff) != 1) {
-                        continue;
-                    }
-                    // The seq_id must be valid otherwise it would've been filtered out upstream.
-                    auto filter_result = single_filter_result_t(seq_id,
-                                                                std::move(filter_result_iterator->reference));
-                    dist_results.emplace_back(pair.first, filter_result);
-                }
-            } else {
-                for (const auto &pair: pairs) {
-                    auto filter_result = single_filter_result_t(pair.second, {});
-                    dist_results.emplace_back(pair.first, filter_result);
-                }
-            }
+            process_results_hnsw_index(filter_result_iterator, vector_query, field_vector_index, filterFunctor, k, dist_results, true);
         }
 
         search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;
 
         std::vector<uint32_t> nearest_ids;
         std::vector<uint32_t> eval_filter_indexes;
 
@@ -3542,46 +3579,32 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
         VectorFilterFunctor filterFunctor(filter_result_iterator, excluded_result_ids, excluded_result_ids_size);
         auto& field_vector_index = vector_index.at(vector_query.field_name);
 
-        std::vector<std::pair<float, size_t>> dist_labels;
-        // use k as 100 by default for ensuring results stability in pagination
-        size_t default_k = 100;
-        auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k) : vector_query.k;
-        if(field_vector_index->distance_type == cosine) {
-            std::vector<float> normalized_q(vector_query.values.size());
-            hnsw_index_t::normalize_vector(vector_query.values, normalized_q);
-            dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(normalized_q.data(), k, vector_query.ef, &filterFunctor);
-        } else {
-            dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, vector_query.ef, &filterFunctor);
-        }
+        uint32_t filter_id_count = filter_result_iterator->approx_filter_ids_length;
+        std::vector<std::pair<float, single_filter_result_t>> dist_results;
+
+        if (filter_by_provided && filter_id_count < vector_query.flat_search_cutoff) {
+            process_results_bruteforce(filter_result_iterator, vector_query, field_vector_index, dist_results);
+        } else if (!filter_by_provided || (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->validity == filter_result_iterator_t::valid)) {
+            dist_results.clear();
+            // use k as 100 by default for ensuring results stability in pagination
+            size_t default_k = 100;
+            auto k = vector_query.k == 0 ? std::max<size_t>(fetch_size, default_k)
+                                         : vector_query.k;
+
+            process_results_hnsw_index(filter_result_iterator, vector_query, field_vector_index, filterFunctor, k, dist_results);
+        }
         filter_result_iterator->reset();
         search_cutoff = search_cutoff || filter_result_iterator->validity == filter_result_iterator_t::timed_out;
 
-        std::vector<std::pair<uint32_t,float>> vec_results;
-        for (const auto& dist_label : dist_labels) {
-            uint32_t seq_id = dist_label.second;
-
-            auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
-                                  dist_label.first;
-            if(vec_dist_score > vector_query.distance_threshold) {
-                continue;
-            }
-            vec_results.emplace_back(seq_id, vec_dist_score);
-        }
-
-        // iteration needs to happen on sorted sequence ID but score wise sort needed for compute rank fusion
-        std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
-            return a.second < b.second;
-        });
-
         std::unordered_map<uint32_t, uint32_t> seq_id_to_rank;
 
-        for(size_t vec_index = 0; vec_index < vec_results.size(); vec_index++) {
-            seq_id_to_rank.emplace(vec_results[vec_index].first, vec_index);
+        for (size_t vec_index = 0; vec_index < dist_results.size(); vec_index++) {
+            seq_id_to_rank.emplace(dist_results[vec_index].second.seq_id, vec_index);
         }
 
-        std::sort(vec_results.begin(), vec_results.end(), [](const auto& a, const auto& b) {
-            return a.first < b.first;
-        });
+        std::sort(dist_results.begin(), dist_results.end(),
+                  [](const auto &a, const auto &b) {
+                      return a.second.seq_id < b.second.seq_id;
+                  });
 
         std::vector<KV*> kvs;
         if(group_limit != 0) {
@@ -3618,10 +3641,10 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                 group_by_field_it_vec = get_group_by_field_iterators(group_by_fields);
             }
 
-            for(size_t res_index = 0; res_index < vec_results.size() &&
+            for(size_t res_index = 0; res_index < dist_results.size() &&
                     filter_result_iterator->validity != filter_result_iterator_t::timed_out; res_index++) {
-                auto& vec_result = vec_results[res_index];
-                auto seq_id = vec_result.first;
+                auto& dist_result = dist_results[res_index];
+                auto seq_id = dist_result.second.seq_id;
 
                 if (filter_by_provided && filter_result_iterator->is_valid(seq_id) != 1) {
                     continue;
@@ -3651,7 +3674,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     // result overlaps with keyword search: we have to combine the scores
 
                     // old_score + (1 / rank_of_document) * WEIGHT)
-                    found_kv->vector_distance = vec_result.second;
+                    found_kv->vector_distance = dist_result.first;
                     int64_t match_score = float_to_int64_t(
                             (int64_t_to_float(found_kv->scores[found_kv->match_score_index])) +
                             ((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT));
@@ -3662,7 +3685,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
                                                                       geopoint_indices, seq_id, references, eval_filter_indexes,
                                                                       match_score, scores, match_score_index, should_skip,
-                                                                      vec_result.second, collection_name);
+                                                                      dist_result.first, collection_name);
                     if (!compute_sort_scores_op.ok()) {
                         return compute_sort_scores_op;
                     }
@@ -3688,7 +3711,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
                                                                       geopoint_indices, seq_id, references, eval_filter_indexes,
                                                                       match_score, scores, match_score_index, should_skip,
-                                                                      vec_result.second, collection_name);
+                                                                      dist_result.first, collection_name);
                     if (!compute_sort_scores_op.ok()) {
                         return compute_sort_scores_op;
                     }
@@ -3711,7 +3734,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
                     }
                     KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, std::move(references));
                     kv.text_match_score = 0;
-                    kv.vector_distance = vec_result.second;
+                    kv.vector_distance = dist_result.first;
 
                     auto ret = topster->add(&kv);
                     vec_search_ids.push_back(seq_id);