Add max_filter_by_candidates option. (#1969)

* Add `max_filter_by_candidates` option.

* Add tests.

* Fix tests.
This commit is contained in:
Harpreet Sangar 2024-09-30 15:13:17 +05:30 committed by GitHub
parent 087c6e2082
commit a891ff1a15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 115 additions and 22 deletions

View File

@@ -584,7 +584,8 @@ public:
bool synonym_prefix = false,
uint32_t synonym_num_typos = 0,
bool enable_lazy_filter = false,
bool enable_typos_for_alpha_numerical_tokens = true) const;
bool enable_typos_for_alpha_numerical_tokens = true,
const size_t& max_filter_by_candidates = DEFAULT_FILTER_BY_CANDIDATES) const;
Option<bool> get_filter_ids(const std::string & filter_query, filter_result_t& filter_result,
const bool& should_timeout = true) const;

View File

@@ -7,6 +7,7 @@
#include "store.h"
constexpr uint32_t COMPUTE_FILTER_ITERATOR_THRESHOLD = 25'000;
constexpr size_t DEFAULT_FILTER_BY_CANDIDATES = 4;
enum NUM_COMPARATOR {
LESS_THAN,

View File

@@ -261,6 +261,8 @@ private:
std::vector<std::vector<posting_list_t*>> posting_lists;
std::vector<std::vector<posting_list_t::iterator_t>> posting_list_iterators;
std::vector<posting_list_t*> expanded_plists;
/// Controls the number of similar words that Typesense considers during fuzzy search for filter_by values.
size_t max_filter_by_candidates;
bool is_not_equals_iterator = false;
uint32_t equals_iterator_id = 0;
@@ -347,11 +349,13 @@ public:
filter_result_iterator_t() = default;
explicit filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count,
const size_t& max_candidates = DEFAULT_FILTER_BY_CANDIDATES,
uint64_t search_begin_us = 0, uint64_t search_stop_us = UINT64_MAX);
explicit filter_result_iterator_t(const std::string& collection_name,
Index const* const index, filter_node_t const* const filter_node,
const bool& enable_lazy_evaluation = false,
const size_t& max_candidates = DEFAULT_FILTER_BY_CANDIDATES,
uint64_t search_begin_us = 0, uint64_t search_stop_us = UINT64_MAX);
~filter_result_iterator_t();

View File

@@ -177,6 +177,7 @@ struct search_args {
drop_tokens_param_t drop_tokens_mode;
bool enable_lazy_filter;
size_t max_filter_by_candidates;
search_args(std::vector<query_tokens_t> field_query_tokens, std::vector<search_field_t> search_fields,
const text_match_type_t match_type,
@@ -194,7 +195,7 @@ struct search_args {
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits, const enable_t split_join_tokens, vector_query_t& vector_query,
size_t facet_sample_percent, size_t facet_sample_threshold, drop_tokens_param_t drop_tokens_mode,
bool enable_lazy_filter) :
bool enable_lazy_filter, const size_t max_filter_by_candidates) :
field_query_tokens(field_query_tokens),
search_fields(search_fields), match_type(match_type), filter_tree_root(filter_tree_root), facets(facets),
included_ids(included_ids), excluded_ids(excluded_ids), sort_fields_std(sort_fields_std),
@@ -213,7 +214,8 @@ struct search_args {
facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits),
split_join_tokens(split_join_tokens), vector_query(vector_query),
facet_sample_percent(facet_sample_percent), facet_sample_threshold(facet_sample_threshold),
drop_tokens_mode(drop_tokens_mode), enable_lazy_filter(enable_lazy_filter) {
drop_tokens_mode(drop_tokens_mode), enable_lazy_filter(enable_lazy_filter),
max_filter_by_candidates(max_filter_by_candidates) {
const size_t topster_size = std::max((size_t)1, max_hits); // needs to be atleast 1 since scoring is mandatory
topster = new Topster(topster_size, group_limit);
@@ -712,8 +714,8 @@ public:
bool synonym_prefix = false,
uint32_t synonym_num_typos = 0,
bool enable_lazy_filter = false,
bool enable_typos_for_alpha_numerical_tokens = true
) const;
bool enable_typos_for_alpha_numerical_tokens = true,
const size_t& max_filter_by_candidates = DEFAULT_FILTER_BY_CANDIDATES) const;
void remove_field(uint32_t seq_id, nlohmann::json& document, const std::string& field_name,
const bool is_update);

View File

@@ -1684,7 +1684,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
bool synonym_prefix,
uint32_t synonyms_num_typos,
bool enable_lazy_filter,
bool enable_typos_for_alpha_numerical_tokens) const {
bool enable_typos_for_alpha_numerical_tokens,
const size_t& max_filter_by_candidates) const {
std::shared_lock lock(mutex);
// setup thread local vars
@@ -2314,7 +2315,7 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
max_extra_prefix, max_extra_suffix, facet_query_num_typos,
filter_curated_hits, split_join_tokens, vector_query,
facet_sample_percent, facet_sample_threshold, drop_tokens_param,
enable_lazy_filter);
enable_lazy_filter, max_filter_by_candidates);
std::unique_ptr<search_args> search_params_guard(search_params);

View File

@@ -1259,6 +1259,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
const char *ENABLE_TYPOS_FOR_NUMERICAL_TOKENS = "enable_typos_for_numerical_tokens";
const char *ENABLE_TYPOS_FOR_ALPHA_NUMERICAL_TOKENS = "enable_typos_for_alpha_numerical_tokens";
const char *ENABLE_LAZY_FILTER = "enable_lazy_filter";
const char *MAX_FILTER_BY_CANDIDATES = "max_filter_by_candidates";
const char *SYNONYM_PREFIX = "synonym_prefix";
const char *SYNONYM_NUM_TYPOS = "synonym_num_typos";
@@ -1393,6 +1394,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
bool enable_typos_for_numerical_tokens = true;
bool enable_typos_for_alpha_numerical_tokens = true;
bool enable_lazy_filter = Config::get_instance().get_enable_lazy_filter();
size_t max_filter_by_candidates = DEFAULT_FILTER_BY_CANDIDATES;
std::string facet_strategy = "automatic";
@@ -1437,6 +1439,7 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
{REMOTE_EMBEDDING_TIMEOUT_MS, &remote_embedding_timeout_ms},
{REMOTE_EMBEDDING_NUM_TRIES, &remote_embedding_num_tries},
{SYNONYM_NUM_TYPOS, &synonym_num_typos},
{MAX_FILTER_BY_CANDIDATES, &max_filter_by_candidates}
};
std::unordered_map<std::string, std::string*> str_values = {
@@ -1693,7 +1696,8 @@ Option<bool> CollectionManager::do_search(std::map<std::string, std::string>& re
synonym_prefix,
synonym_num_typos,
enable_lazy_filter,
enable_typos_for_alpha_numerical_tokens);
enable_typos_for_alpha_numerical_tokens,
max_filter_by_candidates);
uint64_t timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin).count();

View File

@@ -1499,7 +1499,7 @@ void filter_result_iterator_t::init(const bool& enable_lazy_evaluation) {
}
value_tokens.back().is_prefix_searched = true;
filter_result_iterator_t filter_result_it(nullptr, 0);
filter_result_iterator_t dummy_it(nullptr, 0);
std::vector<sort_by> sort_fields;
std::vector<std::vector<art_leaf*>> searched_filters;
tsl::htrie_map<char, token_leaf> qtoken_set;
@@ -1510,19 +1510,18 @@ void filter_result_iterator_t::init(const bool& enable_lazy_evaluation) {
std::vector<std::string> group_by_fields;
std::set<uint64> query_hashes;
size_t typo_tokens_threshold = 0;
size_t max_candidates = 4;
size_t min_len_1typo = 0;
size_t min_len_2typo = 0;
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values{};
const std::vector<size_t> geopoint_indices;
auto fuzzy_search_fields_op = index->fuzzy_search_fields(fq_fields, value_tokens, {}, text_match_type_t::max_score,
nullptr, 0, &filter_result_it, {}, {}, sort_fields,
nullptr, 0, &dummy_it, {}, {}, sort_fields,
{0}, searched_filters, qtoken_set, topster,
groups_processed, all_result_ids, all_result_ids_len,
0, group_by_fields, false, false, false, false,
query_hashes, MAX_SCORE, {true}, typo_tokens_threshold,
false, max_candidates, min_len_1typo, min_len_2typo,
false, max_filter_by_candidates, min_len_1typo, min_len_2typo,
0, nullptr, field_values, geopoint_indices, "", false);
delete[] all_result_ids;
if(!fuzzy_search_fields_op.ok()) {
@@ -2194,7 +2193,7 @@ void filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& len
filter_result_iterator_t::filter_result_iterator_t(const std::string& collection_name, const Index *const index,
const filter_node_t *const filter_node,
const bool& enable_lazy_evaluation,
const bool& enable_lazy_evaluation, const size_t& max_candidates,
uint64_t search_begin, uint64_t search_stop) :
collection_name(collection_name),
index(index),
@@ -2211,7 +2210,8 @@ filter_result_iterator_t::filter_result_iterator_t(const std::string& collection
// Generate the iterator tree and then initialize each node.
if (filter_node->isOperator) {
left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, enable_lazy_evaluation);
left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, enable_lazy_evaluation,
max_candidates);
// If left subtree of && operator is invalid, we don't have to evaluate its right subtree.
if (filter_node->filter_operator == AND && left_it->validity == invalid) {
validity = invalid;
@@ -2221,9 +2221,12 @@ filter_result_iterator_t::filter_result_iterator_t(const std::string& collection
return;
}
right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, enable_lazy_evaluation);
right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, enable_lazy_evaluation,
max_candidates);
}
max_filter_by_candidates = max_candidates;
init(enable_lazy_evaluation);
if (!validity) {
@@ -2389,7 +2392,7 @@ filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_le
delete_filter_node = true;
}
filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count,
filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count, const size_t& max_candidates,
uint64_t search_begin, uint64_t search_stop) {
filter_result.count = approx_filter_ids_length = ids_count;
filter_result.docs = ids;
@@ -2405,6 +2408,8 @@ filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t
timeout_info = std::make_unique<filter_result_iterator_timeout_info>(search_begin, search_stop);
}
}
max_filter_by_candidates = max_candidates;
}
void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& fit,
@@ -2412,7 +2417,8 @@ void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& fit,
fit->reset();
auto root_iterator = new filter_result_iterator_t(std::min(phrase_result_count, fit->approx_filter_ids_length));
root_iterator->left_it = new filter_result_iterator_t(phrase_result_ids, phrase_result_count);
root_iterator->left_it = new filter_result_iterator_t(phrase_result_ids, phrase_result_count,
fit->max_filter_by_candidates);
root_iterator->right_it = fit;
root_iterator->timeout_info = std::move(fit->timeout_info);

View File

@@ -1806,6 +1806,7 @@ Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root
std::shared_lock lock(mutex);
auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root, false,
DEFAULT_FILTER_BY_CANDIDATES,
search_begin_us, should_timeout ? search_stop_us : UINT64_MAX);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
@@ -1866,6 +1867,7 @@ Option<bool> Index::do_reference_filtering_with_lock(filter_node_t* const filter
std::shared_lock lock(mutex);
auto ref_filter_result_iterator = filter_result_iterator_t(ref_collection_name, this, filter_tree_root, false,
DEFAULT_FILTER_BY_CANDIDATES,
search_begin_us, search_stop_us);
auto filter_init_op = ref_filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
@@ -2386,7 +2388,8 @@ Option<bool> Index::run_search(search_args* search_params, const std::string& co
synonym_prefix,
synonym_num_typos,
search_params->enable_lazy_filter,
enable_typos_for_alpha_numerical_tokens
enable_typos_for_alpha_numerical_tokens,
search_params->max_filter_by_candidates
);
return res;
@@ -2885,11 +2888,12 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
bool enable_synonyms, bool synonym_prefix,
uint32_t synonym_num_typos,
bool enable_lazy_filter,
bool enable_typos_for_alpha_numerical_tokens) const {
bool enable_typos_for_alpha_numerical_tokens, const size_t& max_filter_by_candidates) const {
std::shared_lock lock(mutex);
auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root,
enable_lazy_filter, search_begin_us, search_stop_us);
enable_lazy_filter, max_filter_by_candidates,
search_begin_us, search_stop_us);
std::unique_ptr<filter_result_iterator_t> filter_iterator_guard(filter_result_iterator);
auto filter_init_op = filter_result_iterator->init_status();
@@ -3211,6 +3215,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
// if filters were not provided, use the seq_ids index to generate the list of all document ids
if (!filter_by_provided) {
filter_result_iterator = new filter_result_iterator_t(seq_ids->uncompress(), seq_ids->num_ids(),
max_filter_by_candidates,
search_begin_us, search_stop_us);
filter_iterator_guard.reset(filter_result_iterator);
}
@@ -6416,6 +6421,7 @@ Option<bool> Index::populate_sort_mapping(int* sort_order, std::vector<size_t>&
auto count = sort_fields_std[i].eval_expressions.size();
for (uint32_t j = 0; j < count; j++) {
auto filter_result_iterator = filter_result_iterator_t("", this, eval_exp.filter_trees[j], false,
DEFAULT_FILTER_BY_CANDIDATES,
search_begin_us, search_stop_us);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {

View File

@@ -3020,4 +3020,72 @@ TEST_F(CollectionFilteringTest, FilterOnStemmedField) {
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ("125", results["hits"][0]["document"]["id"].get<std::string>());
}
}
TEST_F(CollectionFilteringTest, MaxFilterByCandidates) {
Collection *coll1;
std::vector<field> fields = {field("title", field_types::STRING, false),
field("points", field_types::INT32, false)};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
for(size_t i = 0; i < 20; i++) {
nlohmann::json doc;
doc["title"] = "Independent" + std::to_string(i);
doc["points"] = i;
coll1->add(doc.dump());
}
std::map<std::string, std::string> req_params = {
{"collection", "coll1"},
{"q", "*"},
{"filter_by", "title:independent*"},
};
nlohmann::json embedded_params;
std::string json_res;
auto now_ts = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count();
auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_TRUE(search_op.ok());
auto res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(4, res_obj["found"].get<size_t>());
ASSERT_EQ(4, res_obj["hits"].size());
ASSERT_EQ("Independent19", res_obj["hits"][0]["document"]["title"]);
ASSERT_EQ("Independent18", res_obj["hits"][1]["document"]["title"]);
ASSERT_EQ("Independent17", res_obj["hits"][2]["document"]["title"]);
ASSERT_EQ("Independent16", res_obj["hits"][3]["document"]["title"]);
req_params = {
{"collection", "coll1"},
{"q", "*"},
{"filter_by", "title:independent*"},
{"max_filter_by_candidates", "0"}
};
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_TRUE(search_op.ok());
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(0, res_obj["found"].get<size_t>());
ASSERT_EQ(0, res_obj["hits"].size());
req_params = {
{"collection", "coll1"},
{"q", "*"},
{"filter_by", "title:independent*"},
{"max_filter_by_candidates", "1"}
};
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_TRUE(search_op.ok());
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(1, res_obj["found"].get<size_t>());
ASSERT_EQ(1, res_obj["hits"].size());
ASSERT_EQ("Independent19", res_obj["hits"][0]["document"]["title"]);
}

View File

@@ -708,7 +708,7 @@ TEST_F(FilterTest, FilterTreeIteratorTimeout) {
for (auto i = 0; i < count; i++) {
filter_ids[i] = i;
}
auto filter_iterator = new filter_result_iterator_t(filter_ids, count,
auto filter_iterator = new filter_result_iterator_t(filter_ids, count, DEFAULT_FILTER_BY_CANDIDATES,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count(),
10000000); // Timeout after 10 seconds