From 85de14c8c345e81240df1d3c753f8dc120d28ee8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 1 May 2023 14:53:19 +0530 Subject: [PATCH] Refactor `art_fuzzy_search_i`. --- src/art.cpp | 263 ++++++++++++++++++++++++++++------------------------ 1 file changed, 144 insertions(+), 119 deletions(-) diff --git a/src/art.cpp b/src/art.cpp index 7a7d5d5b..8c3b99ad 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -973,36 +973,6 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, return prev_token_doc_ids; } -const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, - size_t& prev_token_doc_ids_len) { - - art_leaf* prev_leaf = static_cast( - art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) - ); - - uint32_t* prev_token_doc_ids = nullptr; - - if(prev_token.empty() || !prev_leaf) { - prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(prev_token_doc_ids); - return prev_token_doc_ids; - } - - std::vector prev_leaf_ids; - posting_t::merge({prev_leaf->values}, prev_leaf_ids); - - if(filter_result_iterator.is_valid) { - prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), - prev_token_doc_ids); - } else { - prev_token_doc_ids_len = prev_leaf_ids.size(); - prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; - std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); - } - - return prev_token_doc_ids; -} - bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, std::set& exclude_leaves, const art_leaf* exact_leaf, @@ -1030,6 +1000,52 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str return true; } +bool validate_and_add_leaf(art_leaf* leaf, + const std::string& prev_token, const art_leaf* prev_leaf, + const art_leaf* exact_leaf, + filter_result_iterator_t& filter_result_iterator, + std::set& exclude_leaves, + std::vector& results) { + if(leaf == exact_leaf) { + return false; + } + + std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); + if(exclude_leaves.count(tok) != 0) { + return false; + } + + if(prev_token.empty() || !prev_leaf) { + if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) { + return false; + } + } else if (!filter_result_iterator.is_valid) { + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + if (!posting_t::contains_atleast_one(leaf->values, prev_leaf_ids.data(), prev_leaf_ids.size())) { + return false; + } + } else { + std::vector leaf_ids; + posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); + + bool found = false; + for (uint32_t i = 0; i < leaf_ids.size() && !found; i++) { + found = (filter_result_iterator.valid(leaf_ids[i]) == 1); + } + + if (!found) { + return false; + } + } + + exclude_leaves.emplace(tok); + results.push_back(leaf); + + return true; +} + int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, const art_leaf* exact_leaf, const bool last_token, const std::string& prev_token, @@ -1126,6 +1142,101 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r return 0; } +int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, + const art_leaf* exact_leaf, + const bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + const art_tree* t, std::set& exclude_leaves, std::vector& results) { + + printf("INSIDE art_topk_iter: root->type: %d\n", root->type); + + auto prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + std::priority_queue, + decltype(&compare_art_node_score_pq)> q(compare_art_node_score_pq); + + if(token_order == FREQUENCY) { + q = std::priority_queue, + decltype(&compare_art_node_frequency_pq)>(compare_art_node_frequency_pq); + } + + q.push(root); + + size_t num_processed = 0; + + while(!q.empty() && results.size() < max_results*4) { + art_node *n = (art_node *) q.top(); + q.pop(); + + if (!n) continue; + if (IS_LEAF(n)) { + art_leaf *l = (art_leaf *) LEAF_RAW(n); + //LOG(INFO) << "END LEAF SCORE: " << l->max_score; + + validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator, + exclude_leaves, results); + filter_result_iterator.reset(); + + if (++num_processed % 1024 == 0 && (microseconds( + std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { + search_cutoff = true; + break; + } + + continue; + } + + int idx; + switch (n->type) { + case NODE4: + //LOG(INFO) << "NODE4, SCORE: " << n->max_score; + for (int i=0; i < n->num_children; i++) { + art_node* child = ((art_node4*)n)->children[i]; + q.push(child); + } + break; + + case NODE16: + //LOG(INFO) << "NODE16, SCORE: " << n->max_score; + for (int i=0; i < n->num_children; i++) { + q.push(((art_node16*)n)->children[i]); + } + break; + + case NODE48: + //LOG(INFO) << "NODE48, SCORE: " << n->max_score; + for (int i=0; i < 256; i++) { + idx = ((art_node48*)n)->keys[i]; + if (!idx) continue; + art_node *child = ((art_node48*)n)->children[idx - 1]; + q.push(child); + } + break; + + case NODE256: + //LOG(INFO) << "NODE256, SCORE: " << n->max_score; + for (int i=0; i < 256; i++) { + if (!((art_node256*)n)->children[i]) continue; + q.push(((art_node256*)n)->children[i]); + } + break; + + default: + printf("ABORTING BECAUSE OF UNKNOWN NODE TYPE: %d\n", n->type); + abort(); + } + } + + /*LOG(INFO) << "leaf results.size: " << results.size() + << ", filter_ids_length: " << filter_ids_length + << ", num_large_lists: " << num_large_lists;*/ + + printf("OUTSIDE art_topk_iter: results size: %d\n", results.size()); + return 0; +} + // Recursively iterates over the tree static int recursive_iter(art_node *n, art_callback cb, void *data) { // Handle base cases @@ -1689,14 +1800,10 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; - // documents that contain the previous token and/or filter ids - size_t allowed_doc_ids_len = 0; - const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); - filter_result_iterator.reset(); - for(auto node: nodes) { art_topk_iter(node, token_order, max_words, exact_leaf, - last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + last_token, prev_token, + filter_result_iterator, t, exclude_leaves, results); } @@ -1722,91 +1829,9 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le if(time_micro > 1000) { LOG(INFO) << "Time taken for art_topk_iter: " << time_micro << "us, size of nodes: " << nodes.size() - << ", filter_ids_length: " << filter_ids_length; + << ", filter_ids_length: " << filter_result_iterator.approx_filter_ids_length; }*/ - delete [] allowed_doc_ids; - - return 0; -} - -int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const size_t max_words, const token_ordering token_order, const bool prefix, - bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, - std::vector &results, std::set& exclude_leaves) { - - std::vector nodes; - int irow[term_len + 1]; - int jrow[term_len + 1]; - for (int i = 0; i <= term_len; i++){ - irow[i] = jrow[i] = i; - } - - //auto begin = std::chrono::high_resolution_clock::now(); - - if(IS_LEAF(t->root)) { - art_leaf *l = (art_leaf *) LEAF_RAW(t->root); - art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); - } else { - if(t->root == nullptr) { - return 0; - } - - // send depth as -1 to indicate that this is a root node - art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); - } - - //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); - //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); - - //auto begin = std::chrono::high_resolution_clock::now(); - - size_t key_len = prefix ? term_len + 1 : term_len; - art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); - //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; - - // documents that contain the previous token and/or filter ids - size_t allowed_doc_ids_len = 0; - const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); - - for(auto node: nodes) { - art_topk_iter(node, token_order, max_words, exact_leaf, - last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, - t, exclude_leaves, results); - } - - if(token_order == FREQUENCY) { - std::sort(results.begin(), results.end(), compare_art_leaf_frequency); - } else { - std::sort(results.begin(), results.end(), compare_art_leaf_score); - } - - if(exact_leaf && min_cost == 0) { - std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); - if(exclude_leaves.count(tok) == 0) { - results.insert(results.begin(), exact_leaf); - exclude_leaves.emplace(tok); - } - } - - if(results.size() > max_words) { - results.resize(max_words); - } - - /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); - - if(time_micro > 1000) { - LOG(INFO) << "Time taken for art_topk_iter: " << time_micro - << "us, size of nodes: " << nodes.size() - << ", filter_ids_length: " << filter_ids_length; - }*/ - -// TODO: Figure out this edge case. -// if(allowed_doc_ids != filter_ids) { -// delete [] allowed_doc_ids; -// } - return 0; }