Refactor art_fuzzy_search_i.

This commit is contained in:
Harpreet Sangar 2023-05-01 14:53:19 +05:30
parent 56de9b1265
commit 85de14c8c3

View File

@ -973,36 +973,6 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token,
return prev_token_doc_ids;
}
/**
 * Builds the array of document ids that a candidate leaf is allowed to match.
 *
 * - No usable previous token (empty `prev_token`, or it is not found in the tree):
 *   the ids come straight from the filter iterator via `to_filter_id_array`.
 * - Previous token found and the filter is valid: the previous token's ids are
 *   combined with the filter via `and_scalar`.
 * - Previous token found and the filter is not valid: a copy of the previous
 *   token's ids is returned.
 *
 * @param t                       Tree in which `prev_token` is looked up.
 * @param prev_token              Previous query token; may be empty.
 * @param filter_result_iterator  Filter to combine with the previous token's ids.
 * @param prev_token_doc_ids_len  Out: number of ids in the returned array.
 * @return Pointer to the id array. In the copy branch it is allocated with `new[]`;
 *         NOTE(review): the arrays produced by `to_filter_id_array` / `and_scalar`
 *         are presumably heap-allocated and caller-owned as well -- confirm.
 */
const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token,
                                    filter_result_iterator_t& filter_result_iterator,
                                    size_t& prev_token_doc_ids_len) {
    // The lookup key spans size() + 1 bytes, i.e. it includes the terminating NUL of c_str().
    const auto* token_key = reinterpret_cast<const unsigned char*>(prev_token.c_str());
    auto* token_leaf = static_cast<art_leaf*>(art_search(t, token_key, prev_token.size() + 1));

    uint32_t* allowed_ids = nullptr;

    const bool has_prev_token = !prev_token.empty() && token_leaf;
    if(!has_prev_token) {
        // Nothing to intersect with: the filter alone decides.
        prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(allowed_ids);
        return allowed_ids;
    }

    // Materialise the ids of the previous token's posting list.
    std::vector<uint32_t> token_doc_ids;
    posting_t::merge({token_leaf->values}, token_doc_ids);

    if(!filter_result_iterator.is_valid) {
        // No filter to apply: hand back a plain copy of the previous token's ids.
        prev_token_doc_ids_len = token_doc_ids.size();
        allowed_ids = new uint32_t[prev_token_doc_ids_len];
        std::copy(token_doc_ids.begin(), token_doc_ids.end(), allowed_ids);
        return allowed_ids;
    }

    prev_token_doc_ids_len = filter_result_iterator.and_scalar(token_doc_ids.data(), token_doc_ids.size(),
                                                               allowed_ids);
    return allowed_ids;
}
bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token,
const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len,
std::set<std::string>& exclude_leaves, const art_leaf* exact_leaf,
@ -1030,6 +1000,52 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str
return true;
}
/**
 * Checks whether `leaf` is an acceptable candidate and, if so, records it.
 *
 * A leaf is rejected when:
 *  - it is the exact-match leaf (`exact_leaf`);
 *  - its token is already present in `exclude_leaves`;
 *  - it does not satisfy the document constraint below.
 *
 * Document constraint:
 *  - no usable previous token (empty `prev_token` or null `prev_leaf`): if the filter
 *    is valid, the leaf must contain at least one filtered document;
 *  - previous token present: the leaf must contain at least one document of the
 *    previous token and, when the filter is valid, that document must also pass
 *    the filter.
 *
 * @param leaf                    Candidate leaf.
 * @param prev_token              Previous query token; may be empty.
 * @param prev_leaf               Leaf of `prev_token`, or nullptr when not found.
 * @param exact_leaf              Leaf of the exact term; never added by this function.
 * @param filter_result_iterator  Filter; consulted only while `is_valid` is true. Calls made
 *                                here may advance it -- the caller is expected to reset it
 *                                (NOTE(review): inferred from the caller's `reset()`).
 * @param exclude_leaves          Tokens already emitted; the accepted token is added.
 * @param results                 Accepted leaves; the accepted leaf is appended.
 * @return true when the leaf was added to `results`, false otherwise.
 */
bool validate_and_add_leaf(art_leaf* leaf,
                           const std::string& prev_token, const art_leaf* prev_leaf,
                           const art_leaf* exact_leaf,
                           filter_result_iterator_t& filter_result_iterator,
                           std::set<std::string>& exclude_leaves,
                           std::vector<art_leaf *>& results) {
    if(leaf == exact_leaf) {
        return false;
    }

    // The last key byte is dropped -- presumably the terminating NUL stored with the key.
    std::string tok(reinterpret_cast<char*>(leaf->key), leaf->key_len - 1);
    if(exclude_leaves.count(tok) != 0) {
        return false;
    }

    if(prev_token.empty() || !prev_leaf) {
        if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) {
            return false;
        }
    } else {
        // Documents that contain the previous token.
        std::vector<uint32_t> prev_leaf_ids;
        posting_t::merge({prev_leaf->values}, prev_leaf_ids);

        if (filter_result_iterator.is_valid) {
            // Keep only the previous-token documents that pass the filter. Previously the ids of
            // `prev_leaf` and `leaf` were merged together before consulting the filter, so a leaf
            // could be accepted without sharing any document with the previous token.
            // NOTE(review): ids are visited in the order produced by posting_t::merge, the same
            // order the earlier loop relied on when calling valid().
            std::vector<uint32_t> allowed_ids;
            for (const uint32_t doc_id : prev_leaf_ids) {
                if (filter_result_iterator.valid(doc_id) == 1) {
                    allowed_ids.push_back(doc_id);
                }
            }

            if (allowed_ids.empty()) {
                return false;
            }

            prev_leaf_ids.swap(allowed_ids);
        }

        // The candidate must share at least one of the remaining documents.
        if (!posting_t::contains_atleast_one(leaf->values, prev_leaf_ids.data(), prev_leaf_ids.size())) {
            return false;
        }
    }

    exclude_leaves.emplace(tok);
    results.push_back(leaf);

    return true;
}
int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results,
const art_leaf* exact_leaf,
const bool last_token, const std::string& prev_token,
@ -1126,6 +1142,101 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r
return 0;
}
int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results,
const art_leaf* exact_leaf,
const bool last_token, const std::string& prev_token,
filter_result_iterator_t& filter_result_iterator,
const art_tree* t, std::set<std::string>& exclude_leaves, std::vector<art_leaf *>& results) {
printf("INSIDE art_topk_iter: root->type: %d\n", root->type);
auto prev_leaf = static_cast<art_leaf*>(
art_search(t, reinterpret_cast<const unsigned char*>(prev_token.c_str()), prev_token.size() + 1)
);
std::priority_queue<const art_node *, std::vector<const art_node *>,
decltype(&compare_art_node_score_pq)> q(compare_art_node_score_pq);
if(token_order == FREQUENCY) {
q = std::priority_queue<const art_node *, std::vector<const art_node *>,
decltype(&compare_art_node_frequency_pq)>(compare_art_node_frequency_pq);
}
q.push(root);
size_t num_processed = 0;
while(!q.empty() && results.size() < max_results*4) {
art_node *n = (art_node *) q.top();
q.pop();
if (!n) continue;
if (IS_LEAF(n)) {
art_leaf *l = (art_leaf *) LEAF_RAW(n);
//LOG(INFO) << "END LEAF SCORE: " << l->max_score;
validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator,
exclude_leaves, results);
filter_result_iterator.reset();
if (++num_processed % 1024 == 0 && (microseconds(
std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) {
search_cutoff = true;
break;
}
continue;
}
int idx;
switch (n->type) {
case NODE4:
//LOG(INFO) << "NODE4, SCORE: " << n->max_score;
for (int i=0; i < n->num_children; i++) {
art_node* child = ((art_node4*)n)->children[i];
q.push(child);
}
break;
case NODE16:
//LOG(INFO) << "NODE16, SCORE: " << n->max_score;
for (int i=0; i < n->num_children; i++) {
q.push(((art_node16*)n)->children[i]);
}
break;
case NODE48:
//LOG(INFO) << "NODE48, SCORE: " << n->max_score;
for (int i=0; i < 256; i++) {
idx = ((art_node48*)n)->keys[i];
if (!idx) continue;
art_node *child = ((art_node48*)n)->children[idx - 1];
q.push(child);
}
break;
case NODE256:
//LOG(INFO) << "NODE256, SCORE: " << n->max_score;
for (int i=0; i < 256; i++) {
if (!((art_node256*)n)->children[i]) continue;
q.push(((art_node256*)n)->children[i]);
}
break;
default:
printf("ABORTING BECAUSE OF UNKNOWN NODE TYPE: %d\n", n->type);
abort();
}
}
/*LOG(INFO) << "leaf results.size: " << results.size()
<< ", filter_ids_length: " << filter_ids_length
<< ", num_large_lists: " << num_large_lists;*/
printf("OUTSIDE art_topk_iter: results size: %d\n", results.size());
return 0;
}
// Recursively iterates over the tree
static int recursive_iter(art_node *n, art_callback cb, void *data) {
// Handle base cases
@ -1689,14 +1800,10 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le
art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len);
//LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len;
// documents that contain the previous token and/or filter ids
size_t allowed_doc_ids_len = 0;
const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len);
filter_result_iterator.reset();
for(auto node: nodes) {
art_topk_iter(node, token_order, max_words, exact_leaf,
last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len,
last_token, prev_token,
filter_result_iterator,
t, exclude_leaves, results);
}
@ -1722,91 +1829,9 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le
if(time_micro > 1000) {
LOG(INFO) << "Time taken for art_topk_iter: " << time_micro
<< "us, size of nodes: " << nodes.size()
<< ", filter_ids_length: " << filter_ids_length;
<< ", filter_ids_length: " << filter_result_iterator.approx_filter_ids_length;
}*/
delete [] allowed_doc_ids;
return 0;
}
/**
 * Fuzzy search over the tree with a filter iterator.
 *
 * Steps:
 *  1. `art_fuzzy_recurse` collects the nodes matching `term` within [min_cost, max_cost].
 *  2. `art_topk_iter` gathers candidate leaves under each node, restricted to the documents
 *     returned by `get_allowed_doc_ids` (previous token and/or filter).
 *  3. Candidates are sorted by frequency or score depending on `token_order`.
 *  4. When `min_cost == 0` and an exact leaf exists that is not excluded, it is placed first.
 *  5. `results` is truncated to `max_words`.
 *
 * @param t                       Tree to search; an empty tree (null root) yields no results.
 * @param term, term_len          Search term and its length.
 * @param min_cost, max_cost      Cost bounds forwarded to `art_fuzzy_recurse`.
 * @param max_words               Maximum number of leaves kept in `results`.
 * @param token_order             FREQUENCY sorts by frequency, anything else by score.
 * @param prefix                  Forwarded to `art_fuzzy_recurse`; also makes the exact lookup
 *                                use a key length of `term_len + 1`.
 * @param last_token              Forwarded to `art_topk_iter`.
 * @param prev_token              Previous query token; may be empty.
 * @param filter_result_iterator  Filter used to build the allowed document ids.
 * @param results                 Output: matching leaves.
 * @param exclude_leaves          Tokens already emitted; updated with newly accepted tokens.
 * @return Always 0.
 */
int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost,
                     const size_t max_words, const token_ordering token_order, const bool prefix,
                     bool last_token, const std::string& prev_token,
                     filter_result_iterator_t& filter_result_iterator,
                     std::vector<art_leaf *> &results, std::set<std::string>& exclude_leaves) {

    std::vector<const art_node*> nodes;
    // Two rows passed to art_fuzzy_recurse, both initialised to 0..term_len.
    int irow[term_len + 1];
    int jrow[term_len + 1];
    for (int i = 0; i <= term_len; i++){
        irow[i] = jrow[i] = i;
    }

    //auto begin = std::chrono::high_resolution_clock::now();

    if(IS_LEAF(t->root)) {
        art_leaf *l = (art_leaf *) LEAF_RAW(t->root);
        art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes);
    } else {
        if(t->root == nullptr) {
            // Empty tree: nothing has been allocated yet, so returning here is safe.
            return 0;
        }

        // send depth as -1 to indicate that this is a root node
        art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes);
    }

    //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count();
    //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size();

    //auto begin = std::chrono::high_resolution_clock::now();

    size_t key_len = prefix ? term_len + 1 : term_len;
    art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len);
    //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len;

    // documents that contain the previous token and/or filter ids
    size_t allowed_doc_ids_len = 0;
    const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len);

    for(auto node: nodes) {
        art_topk_iter(node, token_order, max_words, exact_leaf,
                      last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len,
                      t, exclude_leaves, results);
    }

    if(token_order == FREQUENCY) {
        std::sort(results.begin(), results.end(), compare_art_leaf_frequency);
    } else {
        std::sort(results.begin(), results.end(), compare_art_leaf_score);
    }

    // An exact match always goes first, unless it was already emitted.
    if(exact_leaf && min_cost == 0) {
        std::string tok(reinterpret_cast<char*>(exact_leaf->key), exact_leaf->key_len - 1);
        if(exclude_leaves.count(tok) == 0) {
            results.insert(results.begin(), exact_leaf);
            exclude_leaves.emplace(tok);
        }
    }

    if(results.size() > max_words) {
        results.resize(max_words);
    }

    /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count();

    if(time_micro > 1000) {
        LOG(INFO) << "Time taken for art_topk_iter: " << time_micro
                  << "us, size of nodes: " << nodes.size()
                  << ", filter_ids_length: " << filter_ids_length;
    }*/

    // The id array is built for this call by the iterator-based get_allowed_doc_ids and is
    // not shared with any caller-owned filter array, so it is released here (the sibling
    // art_fuzzy_search_i releases it the same way). delete[] on nullptr is a no-op.
    // NOTE(review): assumes to_filter_id_array / and_scalar allocate with new[] -- confirm.
    delete [] allowed_doc_ids;

    return 0;
}