Fix phrase search.

This commit is contained in:
Harpreet Sangar 2023-05-03 18:42:31 +05:30
parent 91c1c321dc
commit 9362c5a5e0
7 changed files with 155 additions and 80 deletions

View File

@ -279,7 +279,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len,
int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost,
const size_t max_words, const token_ordering token_order,
const bool prefix, bool last_token, const std::string& prev_token,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
std::vector<art_leaf *> &results, std::set<std::string>& exclude_leaves);
void encode_int32(int32_t n, unsigned char *chars);

View File

@ -109,6 +109,8 @@ private:
std::vector<std::vector<posting_list_t::iterator_t>> posting_list_iterators;
std::vector<posting_list_t*> expanded_plists;
bool delete_filter_node = false;
/// Initializes the state of iterator node after it's creation.
void init();
@ -127,6 +129,8 @@ private:
/// Finds the next match for a filter on string field.
void get_string_filter_next_match(const bool& field_is_array);
explicit filter_result_iterator_t(uint32_t approx_filter_ids_length);
public:
uint32_t seq_id = 0;
/// Collection name -> references
@ -143,6 +147,8 @@ public:
/// iterator reaching it's end. (is_valid would be false in both these cases)
uint32_t approx_filter_ids_length;
explicit filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count);
explicit filter_result_iterator_t(const std::string collection_name,
Index const* const index, filter_node_t const* const filter_node,
uint32_t approx_filter_ids_length = UINT32_MAX);
@ -193,4 +199,7 @@ public:
/// Performs AND with the contents of A and allocates a new array of results.
/// \return size of the results array
uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results);
static void add_phrase_ids(filter_result_iterator_t*& filter_result_iterator,
uint32_t* phrase_result_ids, const uint32_t& phrase_result_count);
};

View File

@ -408,7 +408,7 @@ private:
void search_all_candidates(const size_t num_search_fields,
const text_match_type_t match_type,
const std::vector<search_field_t>& the_fields,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::unordered_set<uint32_t>& excluded_group_ids,
const std::vector<sort_by>& sort_fields,
@ -723,7 +723,7 @@ public:
const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size, const std::unordered_set<uint32_t>& excluded_group_ids,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length,
filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length,
const size_t concurrency,
const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
@ -764,7 +764,7 @@ public:
std::vector<std::vector<art_leaf*>>& searched_queries, const size_t group_limit,
const std::vector<std::string>& group_by_fields, const size_t max_extra_prefix,
const size_t max_extra_suffix, const std::vector<token_t>& query_tokens, Topster* actual_topster,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const int sort_order[3],
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
@ -795,7 +795,7 @@ public:
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
std::set<uint64>& query_hashes,
const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
@ -812,6 +812,7 @@ public:
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
const std::vector<uint32_t>& curated_ids_sorted,
filter_result_iterator_t*& filter_result_iterator,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
const std::set<uint32_t>& curated_ids,
@ -819,8 +820,7 @@ public:
const std::unordered_set<uint32_t>& excluded_group_ids,
Topster* curated_topster,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
bool is_wildcard_query,
uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const;
bool is_wildcard_query) const;
void fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const std::vector<token_t>& query_tokens,
@ -828,7 +828,7 @@ public:
const text_match_type_t match_type,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const std::vector<uint32_t>& curated_ids,
const std::unordered_set<uint32_t>& excluded_group_ids,
const std::vector<sort_by>& sort_fields,
@ -857,7 +857,7 @@ public:
const std::string& previous_token_str,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
std::vector<uint32_t>& prev_token_doc_ids,
@ -879,7 +879,7 @@ public:
const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,
const bool search_all_candidates,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const uint32_t total_cost,
const int syn_orig_num_tokens,
const uint32_t* exclude_token_ids,
@ -931,7 +931,7 @@ public:
void process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<std::string>& group_by_fields,
const size_t group_limit, const bool filter_curated_hits,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
std::set<uint32_t>& curated_ids,
std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
std::vector<uint32_t>& included_ids_vec,

View File

@ -1003,7 +1003,7 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str
bool validate_and_add_leaf(art_leaf* leaf,
const std::string& prev_token, const art_leaf* prev_leaf,
const art_leaf* exact_leaf,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
std::set<std::string>& exclude_leaves,
std::vector<art_leaf *>& results) {
if(leaf == exact_leaf) {
@ -1016,10 +1016,10 @@ bool validate_and_add_leaf(art_leaf* leaf,
}
if(prev_token.empty() || !prev_leaf) {
if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) {
if (filter_result_iterator->is_valid && !filter_result_iterator->contains_atleast_one(leaf->values)) {
return false;
}
} else if (!filter_result_iterator.is_valid) {
} else if (!filter_result_iterator->is_valid) {
std::vector<uint32_t> prev_leaf_ids;
posting_t::merge({prev_leaf->values}, prev_leaf_ids);
@ -1031,8 +1031,8 @@ bool validate_and_add_leaf(art_leaf* leaf,
posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids);
bool found = false;
for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator.is_valid && !found; i++) {
found = (filter_result_iterator.valid(leaf_ids[i]) == 1);
for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator->is_valid && !found; i++) {
found = (filter_result_iterator->valid(leaf_ids[i]) == 1);
}
if (!found) {
@ -1145,7 +1145,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r
int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results,
const art_leaf* exact_leaf,
const bool last_token, const std::string& prev_token,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const art_tree* t, std::set<std::string>& exclude_leaves, std::vector<art_leaf *>& results) {
printf("INSIDE art_topk_iter: root->type: %d\n", root->type);
@ -1177,7 +1177,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r
validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator,
exclude_leaves, results);
filter_result_iterator.reset();
filter_result_iterator->reset();
if (++num_processed % 1024 == 0 && (microseconds(
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) {
@ -1767,7 +1767,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len,
int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost,
const size_t max_words, const token_ordering token_order,
const bool prefix, bool last_token, const std::string& prev_token,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
std::vector<art_leaf *> &results, std::set<std::string>& exclude_leaves) {
std::vector<const art_node*> nodes;

View File

@ -1252,6 +1252,10 @@ filter_result_iterator_t::~filter_result_iterator_t() {
delete expanded_plist;
}
if (delete_filter_node) {
delete filter_node;
}
delete left_it;
delete right_it;
}
@ -1343,3 +1347,44 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n,
next();
}
}
/// Constructs a bare AND iterator intended to act as the root node that combines
/// two child iterators (used by add_phrase_ids below). The iterator owns the
/// dummy filter node it allocates here.
filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_length) :
approx_filter_ids_length(approx_filter_ids_length) {
// Dummy AND node with no children; delete_filter_node makes the destructor
// free it (see the `if (delete_filter_node) delete filter_node;` path there).
filter_node = new filter_node_t(AND, nullptr, nullptr);
delete_filter_node = true;
}
/// Constructs an iterator over a pre-computed, materialized list of ids
/// (e.g. phrase-match ids), bypassing filter-tree evaluation entirely.
/// NOTE(review): `ids` is stored directly into filter_result.docs — this
/// appears to transfer ownership of the array to filter_result; confirm
/// filter_result_t frees `docs` in its destructor.
filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count) {
filter_result.count = approx_filter_ids_length = ids_count;
filter_result.docs = ids;
// An empty id list means there is nothing to iterate.
is_valid = ids_count > 0;
if (is_valid) {
seq_id = filter_result.docs[result_index];
is_filter_result_initialized = true;
// Placeholder node so the iterator has a non-null filter_node to own.
filter_node = new filter_node_t({"dummy", {}, {}});
delete_filter_node = true;
}
}
/// ANDs the phrase-match ids with an existing filter iterator by wrapping both
/// in a new root iterator: left child iterates `phrase_result_ids`, right child
/// is the caller's iterator. On return, `filter_result_iterator` points at the
/// new root (the old iterator becomes the root's right child and is owned by
/// it — callers must swap any ownership guard to the new pointer).
void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_result_iterator,
uint32_t* phrase_result_ids, const uint32_t& phrase_result_count) {
// The intersection can be no larger than the smaller side's estimate.
auto root_iterator = new filter_result_iterator_t(std::min(phrase_result_count, filter_result_iterator->approx_filter_ids_length));
root_iterator->left_it = new filter_result_iterator_t(phrase_result_ids, phrase_result_count);
root_iterator->right_it = filter_result_iterator;
auto& left_it = root_iterator->left_it;
auto& right_it = root_iterator->right_it;
// Advance both children to their first common seq_id (standard two-pointer
// intersection via skip_to).
while (left_it->is_valid && right_it->is_valid && left_it->seq_id != right_it->seq_id) {
if (left_it->seq_id < right_it->seq_id) {
left_it->skip_to(right_it->seq_id);
} else {
right_it->skip_to(left_it->seq_id);
}
}
// Root is valid only if both sides still have a match; seq_id is the common id.
root_iterator->is_valid = left_it->is_valid && right_it->is_valid;
root_iterator->seq_id = left_it->seq_id;
filter_result_iterator = root_iterator;
}

View File

@ -1286,7 +1286,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) {
void Index::search_all_candidates(const size_t num_search_fields,
const text_match_type_t match_type,
const std::vector<search_field_t>& the_fields,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::unordered_set<uint32_t>& excluded_group_ids,
const std::vector<sort_by>& sort_fields,
@ -2270,14 +2270,16 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
return rearrange_op;
}
auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root,
auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root,
approx_filter_ids_length);
auto filter_init_op = filter_result_iterator.init_status();
std::unique_ptr<filter_result_iterator_t> filter_iterator_guard(filter_result_iterator);
auto filter_init_op = filter_result_iterator->init_status();
if (!filter_init_op.ok()) {
return filter_init_op;
}
if (filter_tree_root != nullptr && !filter_result_iterator.is_valid) {
if (filter_tree_root != nullptr && !filter_result_iterator->is_valid) {
return Option(true);
}
@ -2291,7 +2293,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits,
filter_result_iterator, curated_ids, included_ids_map,
included_ids_vec, excluded_group_ids);
filter_result_iterator.reset();
filter_result_iterator->reset();
std::vector<uint32_t> curated_ids_sorted(curated_ids.begin(), curated_ids.end());
std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end());
@ -2322,24 +2324,19 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
field_query_tokens[0].q_include_tokens[0].value == "*";
// TODO: Do AND with phrase ids at last
// handle phrase searches
uint32_t* phrase_result_ids = nullptr;
uint32_t phrase_result_count = 0;
std::unique_ptr<uint32_t> phrase_result_ids_guard;
if (!field_query_tokens[0].q_phrases.empty()) {
do_phrase_search(num_search_fields, the_fields, field_query_tokens,
sort_fields_std, searched_queries, group_limit, group_by_fields,
topster, sort_order, field_values, geopoint_indices, curated_ids_sorted,
all_result_ids, all_result_ids_len, groups_processed, curated_ids,
filter_result_iterator, all_result_ids, all_result_ids_len, groups_processed, curated_ids,
excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster,
included_ids_map, is_wildcard_query,
phrase_result_ids, phrase_result_count);
included_ids_map, is_wildcard_query);
phrase_result_ids_guard.reset(phrase_result_ids);
filter_iterator_guard.release();
filter_iterator_guard.reset(filter_result_iterator);
if (phrase_result_count == 0) {
if (filter_result_iterator->approx_filter_ids_length == 0) {
goto process_search_results;
}
}
@ -2347,7 +2344,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
// for phrase query, parser will set field_query_tokens to "*", need to handle that
if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) {
const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0);
bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.is_valid);
bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator->is_valid);
if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() &&
sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id &&
@ -2395,8 +2392,10 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
Option<bool> parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema,
store, doc_id_prefix, filter_tree_root);
filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root);
approx_filter_ids_length = filter_result_iterator.approx_filter_ids_length;
filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root);
filter_iterator_guard.reset(filter_result_iterator);
approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length;
}
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
@ -2414,9 +2413,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
uint32_t filter_id_count = 0;
while (!no_filters_provided &&
filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) {
auto seq_id = filter_result_iterator.seq_id;
filter_result_iterator.next();
filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator->is_valid) {
auto seq_id = filter_result_iterator->seq_id;
filter_result_iterator->next();
std::vector<float> values;
try {
@ -2440,12 +2439,13 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
dist_labels.emplace_back(dist, seq_id);
filter_id_count++;
}
filter_result_iterator->reset();
if(no_filters_provided ||
(filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.is_valid)) {
(filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->is_valid)) {
dist_labels.clear();
VectorFilterFunctor filterFunctor(&filter_result_iterator);
VectorFilterFunctor filterFunctor(filter_result_iterator);
if(field_vector_index->distance_type == cosine) {
std::vector<float> normalized_q(vector_query.values.size());
@ -2455,8 +2455,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor);
}
}
filter_result_iterator.reset();
filter_result_iterator->reset();
std::vector<uint32_t> nearest_ids;
@ -2511,7 +2510,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
all_result_ids, all_result_ids_len,
filter_result_iterator, approx_filter_ids_length, concurrency,
sort_order, field_values, geopoint_indices);
filter_result_iterator.reset();
filter_result_iterator->reset();
}
// filter tree was initialized to have all sequence ids in this flow.
@ -2572,7 +2571,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
typo_tokens_threshold, exhaustive_search,
max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order,
field_values, geopoint_indices);
filter_result_iterator.reset();
filter_result_iterator->reset();
// try split/joining tokens if no results are found
if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) {
@ -2609,7 +2608,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match,
prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search,
max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices);
filter_result_iterator.reset();
filter_result_iterator->reset();
}
}
@ -2625,7 +2624,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
filter_result_iterator, query_hashes,
sort_order, field_values, geopoint_indices,
qtoken_set);
filter_result_iterator.reset();
filter_result_iterator->reset();
// gather up both original query and synonym queries and do drop tokens
@ -2682,7 +2681,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
token_order, prefixes, typo_tokens_threshold,
exhaustive_search, max_candidates, min_len_1typo,
min_len_2typo, -1, sort_order, field_values, geopoint_indices);
filter_result_iterator.reset();
filter_result_iterator->reset();
} else {
break;
@ -2699,7 +2698,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
sort_order, field_values, geopoint_indices,
curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed);
filter_result_iterator.reset();
filter_result_iterator->reset();
if(!vector_query.field_name.empty()) {
// check at least one of sort fields is text match
@ -2716,7 +2715,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
constexpr float TEXT_MATCH_WEIGHT = 0.7;
constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT;
VectorFilterFunctor filterFunctor(&filter_result_iterator);
VectorFilterFunctor filterFunctor(filter_result_iterator);
auto& field_vector_index = vector_index.at(vector_query.field_name);
std::vector<std::pair<float, size_t>> dist_labels;
auto k = std::max<size_t>(vector_query.k, fetch_size);
@ -2728,7 +2727,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
} else {
dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor);
}
filter_result_iterator.reset();
filter_result_iterator->reset();
std::vector<std::pair<uint32_t,float>> vec_results;
for (const auto& dist_label : dist_labels) {
@ -2938,7 +2937,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids,
const std::vector<std::string>& group_by_fields, const size_t group_limit,
const bool filter_curated_hits, filter_result_iterator_t& filter_result_iterator,
const bool filter_curated_hits, filter_result_iterator_t* const filter_result_iterator,
std::set<uint32_t>& curated_ids,
std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
std::vector<uint32_t>& included_ids_vec,
@ -2961,9 +2960,9 @@ void Index::process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>
// if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition
std::set<uint32_t> included_ids_set;
if(filter_result_iterator.is_valid && filter_curated_hits) {
if(filter_result_iterator->is_valid && filter_curated_hits) {
for (const auto &included_id: included_ids_vec) {
auto result = filter_result_iterator.valid(included_id);
auto result = filter_result_iterator->valid(included_id);
if (result == -1) {
break;
@ -3030,7 +3029,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const std::vector<uint32_t>& curated_ids,
const std::unordered_set<uint32_t>& excluded_group_ids,
const std::vector<sort_by> & sort_fields,
@ -3176,7 +3175,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len,
costs[token_index], costs[token_index], max_candidates, token_order, prefix_search,
last_token, prev_token, filter_result_iterator, field_leaves, unique_tokens);
filter_result_iterator.reset();
filter_result_iterator->reset();
/*auto timeMillis = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - begin).count();
@ -3207,7 +3206,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
token_candidates_vec.back().candidates[0],
the_fields, num_search_fields, filter_result_iterator, exclude_token_ids,
exclude_token_ids_size, prev_token_doc_ids, popular_field_ids);
filter_result_iterator.reset();
filter_result_iterator->reset();
for(size_t field_id: query_field_ids) {
auto& the_field = the_fields[field_id];
@ -3230,7 +3229,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len,
costs[token_index], costs[token_index], max_candidates, token_order, prefix_search,
false, "", filter_result_iterator, field_leaves, unique_tokens);
filter_result_iterator.reset();
filter_result_iterator->reset();
if(field_leaves.empty()) {
// look at the next field
@ -3294,7 +3293,6 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
exhaustive_search, max_candidates,
syn_orig_num_tokens, sort_order, field_values, geopoint_indices,
query_hashes, id_buff);
filter_result_iterator.reset();
if(id_buff.size() > 1) {
gfx::timsort(id_buff.begin(), id_buff.end());
@ -3355,7 +3353,7 @@ void Index::find_across_fields(const token_t& previous_token,
const std::string& previous_token_str,
const std::vector<search_field_t>& the_fields,
const size_t num_search_fields,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
std::vector<uint32_t>& prev_token_doc_ids,
std::vector<size_t>& top_prefix_field_ids) const {
@ -3366,7 +3364,7 @@ void Index::find_across_fields(const token_t& previous_token,
// used to track plists that must be destructed once done
std::vector<posting_list_t*> expanded_plists;
result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator);
result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator);
const bool prefix_search = previous_token.is_prefix_searched;
const uint32_t token_num_typos = previous_token.num_typos;
@ -3447,7 +3445,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
const std::vector<std::string>& group_by_fields,
const bool prioritize_exact_match,
const bool prioritize_token_position,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const uint32_t total_cost, const int syn_orig_num_tokens,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::unordered_set<uint32_t>& excluded_group_ids,
@ -3508,7 +3506,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
// used to track plists that must be destructed once done
std::vector<posting_list_t*> expanded_plists;
result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator);
result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator);
// for each token, find the posting lists across all query_by fields
for(size_t ti = 0; ti < query_tokens.size(); ti++) {
@ -3970,6 +3968,7 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector<s
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
const std::vector<uint32_t>& curated_ids_sorted,
filter_result_iterator_t*& filter_result_iterator,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
const std::set<uint32_t>& curated_ids,
@ -3977,9 +3976,10 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector<s
const std::unordered_set<uint32_t>& excluded_group_ids,
Topster* curated_topster,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
bool is_wildcard_query,
uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const {
bool is_wildcard_query) const {
uint32_t* phrase_result_ids = nullptr;
uint32_t phrase_result_count = 0;
std::map<uint32_t, size_t> phrase_match_id_scores;
for(size_t i = 0; i < num_search_fields; i++) {
@ -4068,12 +4068,19 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector<s
excluded_result_ids_size, phrase_result_ids, phrase_result_count, curated_ids_sorted);
collate_included_ids({}, included_ids_map, curated_topster, searched_queries);
// AND phrase id matches with filter ids
if(filter_result_iterator->is_valid) {
filter_result_iterator_t::add_phrase_ids(filter_result_iterator, phrase_result_ids, phrase_result_count);
} else {
delete filter_result_iterator;
filter_result_iterator = new filter_result_iterator_t(phrase_result_ids, phrase_result_count);
}
size_t filter_index = 0;
if(is_wildcard_query) {
all_result_ids = new uint32_t[phrase_result_count];
std::copy(phrase_result_ids, phrase_result_ids + phrase_result_count, all_result_ids);
all_result_ids_len = phrase_result_count;
all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids);
filter_result_iterator->reset();
} else {
// this means that the there are non-phrase tokens in the query
// so we cannot directly copy to the all_result_ids array
@ -4081,8 +4088,8 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector<s
}
// populate topster
for(size_t i = 0; i < std::min<size_t>(10000, phrase_result_count); i++) {
auto seq_id = phrase_result_ids[i];
for(size_t i = 0; i < std::min<size_t>(10000, all_result_ids_len); i++) {
auto seq_id = all_result_ids[i];
int64_t match_score = phrase_match_id_scores[seq_id];
int64_t scores[3] = {0};
@ -4135,7 +4142,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
std::set<uint64>& query_hashes,
const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
@ -4163,7 +4170,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
const std::vector<std::string>& group_by_fields, const size_t max_extra_prefix,
const size_t max_extra_suffix,
const std::vector<token_t>& query_tokens, Topster* actual_topster,
filter_result_iterator_t& filter_result_iterator,
filter_result_iterator_t* const filter_result_iterator,
const int sort_order[3],
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
@ -4197,10 +4204,10 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
raw_infix_ids_length = infix_ids.size();
}
if(filter_result_iterator.is_valid) {
if(filter_result_iterator->is_valid) {
uint32_t *filtered_raw_infix_ids = nullptr;
raw_infix_ids_length = filter_result_iterator.and_scalar(raw_infix_ids, raw_infix_ids_length,
raw_infix_ids_length = filter_result_iterator->and_scalar(raw_infix_ids, raw_infix_ids_length,
filtered_raw_infix_ids);
if(raw_infix_ids != &infix_ids[0]) {
delete [] raw_infix_ids;
@ -4495,7 +4502,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
size_t exclude_token_ids_size, const std::unordered_set<uint32_t>& excluded_group_ids,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length,
filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length,
const size_t concurrency,
const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values,
@ -4525,11 +4532,11 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
auto parent_search_cutoff = search_cutoff;
uint32_t excluded_result_index = 0;
for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) {
for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator->is_valid; thread_id++) {
std::vector<uint32_t> batch_result_ids;
batch_result_ids.reserve(window_size);
filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size,
filter_result_iterator->get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size,
batch_result_ids);
num_queued++;
@ -4611,8 +4618,8 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
std::chrono::high_resolution_clock::now() - beginF).count();
LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/
filter_result_iterator.reset();
all_result_ids_len = filter_result_iterator.to_filter_id_array(all_result_ids);
filter_result_iterator->reset();
all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids);
}
void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,

View File

@ -482,5 +482,19 @@ TEST_F(FilterTest, FilterTreeIterator) {
ASSERT_EQ(6, iter_skip_test4.seq_id);
ASSERT_TRUE(iter_skip_test4.is_valid);
auto iter_add_phrase_ids_test = new filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root);
std::unique_ptr<filter_result_iterator_t> filter_iter_guard(iter_add_phrase_ids_test);
ASSERT_TRUE(iter_add_phrase_ids_test->init_status().ok());
auto phrase_ids = new uint32_t[4];
for (uint32_t i = 0; i < 4; i++) {
phrase_ids[i] = i * 2;
}
filter_result_iterator_t::add_phrase_ids(iter_add_phrase_ids_test, phrase_ids, 4);
filter_iter_guard.reset(iter_add_phrase_ids_test);
ASSERT_TRUE(iter_add_phrase_ids_test->is_valid);
ASSERT_EQ(6, iter_add_phrase_ids_test->seq_id);
delete filter_tree_root;
}