mirror of
https://github.com/typesense/typesense.git
synced 2025-05-23 07:09:44 +08:00
Refactor `art_fuzzy_search_i`.
This commit is contained in:
parent
56de9b1265
commit
85de14c8c3
263
src/art.cpp
263
src/art.cpp
@ -973,36 +973,6 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token,
|
||||
return prev_token_doc_ids;
|
||||
}
|
||||
|
||||
// Builds the set of document ids that candidate leaves are allowed to match:
// the docs containing `prev_token`, intersected with the filter ids when a
// filter is active. Writes the result length into `prev_token_doc_ids_len`.
// Returns a heap-allocated array owned by the caller (released via delete[]).
const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token,
                                    filter_result_iterator_t& filter_result_iterator,
                                    size_t& prev_token_doc_ids_len) {
    // Resolve the leaf holding the previous token's posting list.
    // Keys are stored NUL-terminated, hence the +1 on the length.
    auto* prev_leaf = static_cast<art_leaf*>(
            art_search(t, reinterpret_cast<const unsigned char*>(prev_token.c_str()), prev_token.size() + 1)
    );

    uint32_t* doc_ids = nullptr;

    // No usable previous token: the allowed set is just the filter ids.
    if(prev_token.empty() || prev_leaf == nullptr) {
        prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(doc_ids);
        return doc_ids;
    }

    // Flatten the previous token's posting list into a plain id vector.
    std::vector<uint32_t> prev_ids;
    posting_t::merge({prev_leaf->values}, prev_ids);

    if(!filter_result_iterator.is_valid) {
        // No active filter: allowed ids are exactly the previous token's ids.
        prev_token_doc_ids_len = prev_ids.size();
        doc_ids = new uint32_t[prev_token_doc_ids_len];
        std::copy(prev_ids.begin(), prev_ids.end(), doc_ids);
        return doc_ids;
    }

    // Filter is active: intersect the previous token's ids with the filter ids.
    prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_ids.data(), prev_ids.size(), doc_ids);
    return doc_ids;
}
|
||||
|
||||
bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token,
|
||||
const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len,
|
||||
std::set<std::string>& exclude_leaves, const art_leaf* exact_leaf,
|
||||
@ -1030,6 +1000,52 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str
|
||||
return true;
|
||||
}
|
||||
|
||||
// Decides whether `leaf` should be emitted as a fuzzy-search result, applying
// the previous-token and filter constraints, and appends it to `results` when
// accepted. Accepted tokens are recorded in `exclude_leaves` to deduplicate.
// Returns true iff the leaf was added.
bool validate_and_add_leaf(art_leaf* leaf,
                           const std::string& prev_token, const art_leaf* prev_leaf,
                           const art_leaf* exact_leaf,
                           filter_result_iterator_t& filter_result_iterator,
                           std::set<std::string>& exclude_leaves,
                           std::vector<art_leaf *>& results) {
    // The exact match is handled separately by the caller; skip it here.
    if(leaf == exact_leaf) {
        return false;
    }

    // Key is stored with a trailing NUL; drop it when building the token.
    const std::string token(reinterpret_cast<char*>(leaf->key), leaf->key_len - 1);
    if(exclude_leaves.find(token) != exclude_leaves.end()) {
        return false;   // already emitted earlier
    }

    const bool has_prev = !prev_token.empty() && prev_leaf != nullptr;

    if(!has_prev) {
        // Only the filter (if any) constrains this leaf.
        if(filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) {
            return false;
        }
    } else if(!filter_result_iterator.is_valid) {
        // Previous-token constraint only: the leaf must share at least one doc id
        // with the previous token's posting list.
        std::vector<uint32_t> prev_ids;
        posting_t::merge({prev_leaf->values}, prev_ids);

        if(!posting_t::contains_atleast_one(leaf->values, prev_ids.data(), prev_ids.size())) {
            return false;
        }
    } else {
        // Both constraints apply: at least one doc id from the merged posting
        // lists must pass the filter.
        // NOTE(review): merge() appears to combine both lists into one candidate
        // set — confirm whether an intersection (rather than a union) of
        // prev_leaf and leaf ids is the intended semantics here.
        std::vector<uint32_t> merged_ids;
        posting_t::merge({prev_leaf->values, leaf->values}, merged_ids);

        bool any_valid = false;
        for(const uint32_t id : merged_ids) {
            if(filter_result_iterator.valid(id) == 1) {
                any_valid = true;
                break;
            }
        }

        if(!any_valid) {
            return false;
        }
    }

    exclude_leaves.emplace(token);
    results.push_back(leaf);

    return true;
}
|
||||
|
||||
int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results,
|
||||
const art_leaf* exact_leaf,
|
||||
const bool last_token, const std::string& prev_token,
|
||||
@ -1126,6 +1142,101 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r
|
||||
return 0;
|
||||
}
|
||||
|
||||
int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results,
|
||||
const art_leaf* exact_leaf,
|
||||
const bool last_token, const std::string& prev_token,
|
||||
filter_result_iterator_t& filter_result_iterator,
|
||||
const art_tree* t, std::set<std::string>& exclude_leaves, std::vector<art_leaf *>& results) {
|
||||
|
||||
printf("INSIDE art_topk_iter: root->type: %d\n", root->type);
|
||||
|
||||
auto prev_leaf = static_cast<art_leaf*>(
|
||||
art_search(t, reinterpret_cast<const unsigned char*>(prev_token.c_str()), prev_token.size() + 1)
|
||||
);
|
||||
|
||||
std::priority_queue<const art_node *, std::vector<const art_node *>,
|
||||
decltype(&compare_art_node_score_pq)> q(compare_art_node_score_pq);
|
||||
|
||||
if(token_order == FREQUENCY) {
|
||||
q = std::priority_queue<const art_node *, std::vector<const art_node *>,
|
||||
decltype(&compare_art_node_frequency_pq)>(compare_art_node_frequency_pq);
|
||||
}
|
||||
|
||||
q.push(root);
|
||||
|
||||
size_t num_processed = 0;
|
||||
|
||||
while(!q.empty() && results.size() < max_results*4) {
|
||||
art_node *n = (art_node *) q.top();
|
||||
q.pop();
|
||||
|
||||
if (!n) continue;
|
||||
if (IS_LEAF(n)) {
|
||||
art_leaf *l = (art_leaf *) LEAF_RAW(n);
|
||||
//LOG(INFO) << "END LEAF SCORE: " << l->max_score;
|
||||
|
||||
validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator,
|
||||
exclude_leaves, results);
|
||||
filter_result_iterator.reset();
|
||||
|
||||
if (++num_processed % 1024 == 0 && (microseconds(
|
||||
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) {
|
||||
search_cutoff = true;
|
||||
break;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
int idx;
|
||||
switch (n->type) {
|
||||
case NODE4:
|
||||
//LOG(INFO) << "NODE4, SCORE: " << n->max_score;
|
||||
for (int i=0; i < n->num_children; i++) {
|
||||
art_node* child = ((art_node4*)n)->children[i];
|
||||
q.push(child);
|
||||
}
|
||||
break;
|
||||
|
||||
case NODE16:
|
||||
//LOG(INFO) << "NODE16, SCORE: " << n->max_score;
|
||||
for (int i=0; i < n->num_children; i++) {
|
||||
q.push(((art_node16*)n)->children[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case NODE48:
|
||||
//LOG(INFO) << "NODE48, SCORE: " << n->max_score;
|
||||
for (int i=0; i < 256; i++) {
|
||||
idx = ((art_node48*)n)->keys[i];
|
||||
if (!idx) continue;
|
||||
art_node *child = ((art_node48*)n)->children[idx - 1];
|
||||
q.push(child);
|
||||
}
|
||||
break;
|
||||
|
||||
case NODE256:
|
||||
//LOG(INFO) << "NODE256, SCORE: " << n->max_score;
|
||||
for (int i=0; i < 256; i++) {
|
||||
if (!((art_node256*)n)->children[i]) continue;
|
||||
q.push(((art_node256*)n)->children[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
printf("ABORTING BECAUSE OF UNKNOWN NODE TYPE: %d\n", n->type);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
/*LOG(INFO) << "leaf results.size: " << results.size()
|
||||
<< ", filter_ids_length: " << filter_ids_length
|
||||
<< ", num_large_lists: " << num_large_lists;*/
|
||||
|
||||
printf("OUTSIDE art_topk_iter: results size: %d\n", results.size());
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Recursively iterates over the tree
|
||||
static int recursive_iter(art_node *n, art_callback cb, void *data) {
|
||||
// Handle base cases
|
||||
@ -1689,14 +1800,10 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le
|
||||
art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len);
|
||||
//LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len;
|
||||
|
||||
// documents that contain the previous token and/or filter ids
|
||||
size_t allowed_doc_ids_len = 0;
|
||||
const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len);
|
||||
filter_result_iterator.reset();
|
||||
|
||||
for(auto node: nodes) {
|
||||
art_topk_iter(node, token_order, max_words, exact_leaf,
|
||||
last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len,
|
||||
last_token, prev_token,
|
||||
filter_result_iterator,
|
||||
t, exclude_leaves, results);
|
||||
}
|
||||
|
||||
@ -1722,91 +1829,9 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le
|
||||
if(time_micro > 1000) {
|
||||
LOG(INFO) << "Time taken for art_topk_iter: " << time_micro
|
||||
<< "us, size of nodes: " << nodes.size()
|
||||
<< ", filter_ids_length: " << filter_ids_length;
|
||||
<< ", filter_ids_length: " << filter_result_iterator.approx_filter_ids_length;
|
||||
}*/
|
||||
|
||||
delete [] allowed_doc_ids;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Fuzzy-searches the ART for `term` within [min_cost, max_cost] Levenshtein
// distance (prefix match when `prefix` is set), placing up to `max_words`
// leaves into `results` ordered by frequency or score. An exact match (when
// min_cost == 0) is promoted to the front. `exclude_leaves` deduplicates
// tokens across calls.
//
// Fix vs. previous revision: the DP rows were C variable-length arrays
// (`int irow[term_len + 1]`), which are a compiler extension and not valid
// ISO C++; they are now std::vector<int>, passed via .data().
int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost,
                     const size_t max_words, const token_ordering token_order, const bool prefix,
                     bool last_token, const std::string& prev_token,
                     filter_result_iterator_t& filter_result_iterator,
                     std::vector<art_leaf *> &results, std::set<std::string>& exclude_leaves) {

    std::vector<const art_node*> nodes;

    // Two rolling rows of the Levenshtein DP table, initialized to 0..term_len.
    std::vector<int> irow(term_len + 1);
    std::vector<int> jrow(term_len + 1);
    for (int i = 0; i <= term_len; i++){
        irow[i] = jrow[i] = i;
    }

    //auto begin = std::chrono::high_resolution_clock::now();

    if(IS_LEAF(t->root)) {
        art_leaf *l = (art_leaf *) LEAF_RAW(t->root);
        art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow.data(), jrow.data(),
                          min_cost, max_cost, prefix, nodes);
    } else {
        if(t->root == nullptr) {
            return 0;
        }

        // send depth as -1 to indicate that this is a root node
        art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow.data(), jrow.data(),
                          min_cost, max_cost, prefix, nodes);
    }

    //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count();
    //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size();

    //auto begin = std::chrono::high_resolution_clock::now();

    // For a prefix search the stored key includes the trailing NUL.
    size_t key_len = prefix ? term_len + 1 : term_len;
    art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len);
    //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len;

    // documents that contain the previous token and/or filter ids
    size_t allowed_doc_ids_len = 0;
    const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len);

    for(auto node: nodes) {
        art_topk_iter(node, token_order, max_words, exact_leaf,
                      last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len,
                      t, exclude_leaves, results);
    }

    if(token_order == FREQUENCY) {
        std::sort(results.begin(), results.end(), compare_art_leaf_frequency);
    } else {
        std::sort(results.begin(), results.end(), compare_art_leaf_score);
    }

    // Promote the exact match to the front when exact matching was requested.
    if(exact_leaf && min_cost == 0) {
        std::string tok(reinterpret_cast<char*>(exact_leaf->key), exact_leaf->key_len - 1);
        if(exclude_leaves.count(tok) == 0) {
            results.insert(results.begin(), exact_leaf);
            exclude_leaves.emplace(tok);
        }
    }

    if(results.size() > max_words) {
        results.resize(max_words);
    }

    /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count();

    if(time_micro > 1000) {
        LOG(INFO) << "Time taken for art_topk_iter: " << time_micro
                  << "us, size of nodes: " << nodes.size()
                  << ", filter_ids_length: " << filter_ids_length;
    }*/

    // TODO: Figure out this edge case.
    // NOTE(review): allowed_doc_ids is leaked here while the aliasing question
    // below is unresolved — get_allowed_doc_ids() heap-allocates its result.
    // if(allowed_doc_ids != filter_ids) {
    //     delete [] allowed_doc_ids;
    // }

    return 0;
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user