From 7a5c0eafda18d3cf9d8df76381faa04b16ef5b7d Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 26 Jan 2023 18:24:43 +0530 Subject: [PATCH 1/5] Refactor fuzzy search state transition. Handle extra chars in the middle of a query. --- src/art.cpp | 46 ++++++++++++++++----------------- test/art_test.cpp | 28 +++++++++++++++++--- test/collection_locale_test.cpp | 22 ++++++++++++++++ test/collection_test.cpp | 16 ++++++------ test/documents.jsonl | 2 +- 5 files changed, 79 insertions(+), 35 deletions(-) diff --git a/src/art.cpp b/src/art.cpp index 2a6c99e9..f778eace 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1298,48 +1298,48 @@ static inline void rotate(int &i, int &j, int &k) { // -1: return without adding, 0 : continue iteration, 1: return after adding static inline int fuzzy_search_state(const bool prefix, int key_index, bool last_key_char, - int term_len, const int* cost_row, int min_cost, int max_cost) { + const int query_len, const int* cost_row, int min_cost, int max_cost) { - // a) iter_len < term_len: "pltninum" (term) on "pst" (key) - // b) term_len < iter_len: "pst" (term) on "pltninum" (key) + // There are 2 scenarios: + // a) key_len < query_len: "pltninum" (query) on "pst" (key) + // b) query_len < key_len: "pst" (query) on "pltninum" (key) - int cost = 0; + int key_len = last_key_char ? key_index : key_index + 1; - // a) because key's null character will appear first if(last_key_char) { - int key_len = key_index; - cost = cost_row[term_len]; - - if(cost >= min_cost && cost <= max_cost) { + // Last char, so have to return 1 or -1 + if(cost_row[query_len] >= min_cost && cost_row[query_len] <= max_cost) { return 1; } - cost = cost_row[key_len]; - - // used to match q=strawberries on key=strawberry, but limit to larger keys to prevent eager matches - if(key_len > 5 && term_len > key_len && (term_len - key_len) <= max_cost && - cost >= min_cost-1 && cost <= max_cost-1) { + // Special case used to match q=strawberries on key=strawberry (query_len > key_len) + // but limit to larger keys to prevent eager matches + if(key_len > 5 && query_len > key_len && (query_len - key_len) <= max_cost && + cost_row[key_len] >= min_cost && cost_row[key_len] <= max_cost-1) { return 1; } return -1; } - int key_len = key_index + 1; + // `key_len` can't exceed `query_len` since length of `cost_row` is `query_len + 1` + int cost = cost_row[std::min(key_len, query_len)]; - // b) we might iterate past term_len to catch trailing typos - if(key_len >= term_len && prefix) { - cost = cost_row[term_len]; + if(key_len >= query_len && prefix) { + // Case b) + // For prefix queries + // - we can return early if key_len reaches query_len and cost is within bounds. + // - might have to iterate past prefix query length to catch trailing typos. if(cost >= min_cost && cost <= max_cost) { return 1; } - } else { - // `key_len` can't exceed `term_len` since length of `cost_row` is `term_len + 1` - cost = cost_row[std::min(key_len, term_len)]; } - int bounded_cost = (max_cost == 0) ? max_cost : (max_cost + 1); - return (cost > bounded_cost) ? -1 : 0; + // Terminate the search early or continue iterating on the key? + // We have to account for the case that `cost` could momentarily exceed max_cost but resolve later. + // e.g. key=example, query=exZZample, after 5 chars, cost is 3 but drops to 2 at the end. + // But we will limit this for longer keys for performance. + return cost > max_cost && (key_len > 3 ? cost > (max_cost * 2) : true) ? -1 : 0; } static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node *n, int depth, const unsigned char *term, diff --git a/test/art_test.cpp b/test/art_test.cpp index 38c962db..c1030972 100644 --- a/test/art_test.cpp +++ b/test/art_test.cpp @@ -744,10 +744,9 @@ TEST(ArtTest, test_art_fuzzy_search) { } std::vector leaves; - - leaves.clear(); auto begin = std::chrono::high_resolution_clock::now(); + leaves.clear(); art_fuzzy_search(&t, (const unsigned char *) "pltinum", strlen("pltinum"), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); ASSERT_EQ(2, leaves.size()); ASSERT_STREQ("platinumsmith", (const char *)leaves.at(0)->key); @@ -800,7 +799,7 @@ TEST(ArtTest, test_art_fuzzy_search) { art_fuzzy_search(&t, (const unsigned char *) "hown", strlen("hown") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); ASSERT_EQ(10, leaves.size()); - std::set expected_words = {"town", "sown", "shown", "own", "mown", "lown", "howl", "howk", "howe", "how"}; + std::set expected_words = {"town", "sown", "mown", "lown", "howl", "howk", "howe", "how", "horn", "hoon"}; for(size_t leaf_index = 0; leaf_index < leaves.size(); leaf_index++) { art_leaf*& leaf = leaves.at(leaf_index); @@ -864,6 +863,29 @@ TEST(ArtTest, test_art_fuzzy_search_unicode_chars) { ASSERT_TRUE(res == 0); } +TEST(ArtTest, test_art_fuzzy_search_extra_chars) { + art_tree t; + int res = art_tree_init(&t); + ASSERT_TRUE(res == 0); + + std::vector keys = { + "abbviation" + }; + + for(const char* key: keys) { + art_document doc = get_document((uint32_t) 1); + ASSERT_TRUE(NULL == art_insert(&t, (unsigned char*)key, strlen(key)+1, &doc)); + } + + const char* query = "abbreviation"; + std::vector leaves; + art_fuzzy_search(&t, (unsigned char *)query, strlen(query), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + ASSERT_EQ(1, leaves.size()); + + res = art_tree_destroy(&t); + ASSERT_TRUE(res == 0); +} + TEST(ArtTest, test_art_search_sku_like_tokens) { art_tree t; int res = art_tree_init(&t); diff --git a/test/collection_locale_test.cpp b/test/collection_locale_test.cpp index 11ce2ed8..b12e68cb 100644 --- a/test/collection_locale_test.cpp +++ b/test/collection_locale_test.cpp @@ -774,6 +774,28 @@ TEST_F(CollectionLocaleTest, SearchOnCyrillicLargeText) { results["hits"][0]["highlights"][0]["snippet"].get().c_str()); } +TEST_F(CollectionLocaleTest, SearchOnArabicText) { + std::vector fields = {field("title", field_types::STRING, true, false, true, ""),}; + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields).get(); + + std::string data = "جهينة"; + std::string q = "جوهينة"; + + auto dchars = data.c_str(); + auto qchars = q.c_str(); + + nlohmann::json doc; + doc["title"] = "جهينة"; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("جوهينة", {"title"}, "", {}, {}, {2}, 10, 1, FREQUENCY, {true}).get(); + LOG(INFO) << results; + + ASSERT_STREQ("جهينة", + results["hits"][0]["highlights"][0]["snippet"].get().c_str()); +} + /* TEST_F(CollectionLocaleTest, TranslitPad) { UErrorCode translit_status = U_ZERO_ERROR; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 93bf05a4..b72573a3 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -392,7 +392,7 @@ TEST_F(CollectionTest, QueryWithTypo) { spp::sparse_hash_set(), 10, "", 30, 5, "", 10).get(); - ids = {"8", "1", "17"}; + ids = {"1", "13", "8"}; ASSERT_EQ(3, results["hits"].size()); @@ -667,20 +667,20 @@ TEST_F(CollectionTest, PrefixSearching) { } TEST_F(CollectionTest, TypoTokensThreshold) { - // Query expansion should happen only based on the `typo_tokens_threshold` value - auto results = collection->search("launch", {"title"}, "", {}, sort_fields, {2}, 10, 1, + // Typo correction should happen only based on the `typo_tokens_threshold` value + auto results = collection->search("redundant", {"title"}, "", {}, sort_fields, {2}, 10, 1, token_ordering::FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 5, 5, "", 0).get(); - ASSERT_EQ(5, results["hits"].size()); - ASSERT_EQ(5, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ(1, results["found"].get()); - results = collection->search("launch", {"title"}, "", {}, sort_fields, {2}, 10, 1, + results = collection->search("redundant", {"title"}, "", {}, sort_fields, {2}, 10, 1, token_ordering::FREQUENCY, {true}, 10, spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 5, 5, "", 10).get(); - ASSERT_EQ(7, results["hits"].size()); - ASSERT_EQ(7, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ(2, results["found"].get()); } TEST_F(CollectionTest, MultiOccurrenceString) { diff --git a/test/documents.jsonl b/test/documents.jsonl index 8caaa289..2c68f2d2 100644 --- a/test/documents.jsonl +++ b/test/documents.jsonl @@ -21,4 +21,4 @@ {"points":7,"title":"What kinds of things have been tossed out of ISSS in space?"} {"points":17,"title":"What does triple redundant closed loop digital avionics system mean?"} {"points":11,"title":"How are rockets guided to follow specific loop trajectory?"} -{"points":8,"title":"What do remotely controlled bolts look like?"} \ No newline at end of file +{"points":8,"title":"What do remotely controlled redundent bolts look like?"} \ No newline at end of file From de2de028b7e079bbd2dc1f8315e5fd0f2b5ba245 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 14 Feb 2023 14:28:39 +0530 Subject: [PATCH 2/5] Refactor fuzzy search restrictions. --- include/art.h | 7 +- src/art.cpp | 113 ++++++++++++++++++++++------- src/index.cpp | 85 +++++----------------- test/art_test.cpp | 116 +++++++++++++++++++----------- test/collection_specific_test.cpp | 2 + test/collection_test.cpp | 2 +- 6 files changed, 185 insertions(+), 140 deletions(-) diff --git a/include/art.h b/include/art.h index 0502641c..a9715fac 100644 --- a/include/art.h +++ b/include/art.h @@ -276,9 +276,10 @@ int art_iter_prefix(art_tree *t, const unsigned char *prefix, int prefix_len, ar * Returns leaves that match a given string within a fuzzy distance of max_cost. */ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const int max_words, const token_ordering token_order, const bool prefix, - const uint32_t *filter_ids, size_t filter_ids_length, - std::vector &results, const std::set& exclude_leaves = {}); + const size_t max_words, const token_ordering token_order, + const bool prefix, bool last_token, const std::string& prev_token, + const uint32_t *filter_ids, const size_t filter_ids_length, + std::vector &results, std::set& exclude_leaves); void encode_int32(int32_t n, unsigned char *chars); diff --git a/src/art.cpp b/src/art.cpp index f778eace..835f77b9 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -21,6 +21,7 @@ #include #include "art.h" #include "logger.h" +#include "array_utils.h" /** * Macros to manipulate pointer tags @@ -940,10 +941,69 @@ void* art_delete(art_tree *t, const unsigned char *key, int key_len) { return child->max_token_count; }*/ +const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, + const uint32_t* filter_ids, const size_t filter_ids_length, + size_t& prev_token_doc_ids_len) { + + art_leaf* prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + if(prev_token.empty() || !prev_leaf) { + prev_token_doc_ids_len = filter_ids_length; + return filter_ids; + } + + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + uint32_t* prev_token_doc_ids = nullptr; + + if(filter_ids_length != 0) { + prev_token_doc_ids_len = ArrayUtils::and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), + filter_ids, filter_ids_length, + &prev_token_doc_ids); + } else { + prev_token_doc_ids_len = prev_leaf_ids.size(); + prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; + std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); + } + + return prev_token_doc_ids; +} + +bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, + const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, + std::set& exclude_leaves, const art_leaf* exact_leaf, + std::vector& results) { + + if(leaf == exact_leaf) { + return false; + } + + std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); + if(exclude_leaves.count(tok) != 0) { + return false; + } + + if(allowed_doc_ids_len != 0) { + if(!posting_t::contains_atleast_one(leaf->values, allowed_doc_ids, + allowed_doc_ids_len)) { + return false; + } + } + + exclude_leaves.emplace(tok); + results.push_back(leaf); + + return true; +} + int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, - const uint32_t* filter_ids, size_t filter_ids_length, - const std::set& exclude_leaves, const art_leaf* exact_leaf, - std::vector& results) { + const art_leaf* exact_leaf, + const bool last_token, const std::string& prev_token, + const uint32_t* allowed_doc_ids, size_t allowed_doc_ids_len, + const art_tree* t, std::set& exclude_leaves, std::vector& results) { printf("INSIDE art_topk_iter: root->type: %d\n", root->type); @@ -972,25 +1032,8 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r if (IS_LEAF(n)) { art_leaf *l = (art_leaf *) LEAF_RAW(n); //LOG(INFO) << "END LEAF SCORE: " << l->max_score; - - if(filter_ids_length == 0) { - std::string tok(reinterpret_cast(l->key), l->key_len - 1); - if(exclude_leaves.count(tok) != 0 || l == exact_leaf) { - continue; - } - results.push_back(l); - } else { - // we will push leaf only if filter matches with leaf IDs - bool found_atleast_one = posting_t::contains_atleast_one(l->values, filter_ids, filter_ids_length); - if(found_atleast_one) { - std::string tok(reinterpret_cast(l->key), l->key_len - 1); - if(exclude_leaves.count(tok) != 0 || l == exact_leaf) { - continue; - } - results.push_back(l); - } - } - + validate_and_add_leaf(l, last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + exclude_leaves, exact_leaf, results); continue; } @@ -1491,9 +1534,10 @@ static void art_fuzzy_recurse(unsigned char p, unsigned char c, const art_node * * Returns leaves that match a given string within a fuzzy distance of max_cost. */ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const int max_words, const token_ordering token_order, const bool prefix, - const uint32_t *filter_ids, size_t filter_ids_length, - std::vector &results, const std::set& exclude_leaves) { + const size_t max_words, const token_ordering token_order, const bool prefix, + bool last_token, const std::string& prev_token, + const uint32_t *filter_ids, const size_t filter_ids_length, + std::vector &results, std::set& exclude_leaves) { std::vector nodes; int irow[term_len + 1]; @@ -1525,8 +1569,15 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; + // documents that contain the previous token and/or filter ids + size_t allowed_doc_ids_len = 0; + const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_ids, filter_ids_length, + allowed_doc_ids_len); + for(auto node: nodes) { - art_topk_iter(node, token_order, max_words, filter_ids, filter_ids_length, exclude_leaves, exact_leaf, results); + art_topk_iter(node, token_order, max_words, exact_leaf, + last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + t, exclude_leaves, results); } if(token_order == FREQUENCY) { @@ -1536,7 +1587,11 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, } if(exact_leaf && min_cost == 0) { - results.insert(results.begin(), exact_leaf); + std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); + if(exclude_leaves.count(tok) == 0) { + results.insert(results.begin(), exact_leaf); + exclude_leaves.emplace(tok); + } } if(results.size() > max_words) { @@ -1551,6 +1606,10 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, << ", filter_ids_length: " << filter_ids_length; }*/ + if(allowed_doc_ids != filter_ids) { + delete [] allowed_doc_ids; + } + return 0; } diff --git a/src/index.cpp b/src/index.cpp index 749ca44d..fab2a644 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3137,12 +3137,12 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, } //LOG(INFO) << "Searching for field: " << the_field.name << ", found token:" << token; + const auto& prev_token = last_token ? token_candidates_vec.back().candidates[0] : ""; std::vector field_leaves; - int max_words = 100000; art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, - costs[token_index], costs[token_index], max_words, token_order, prefix_search, - filter_ids, filter_ids_length, field_leaves, unique_tokens); + costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, + last_token, prev_token, filter_ids, filter_ids_length, field_leaves, unique_tokens); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3153,60 +3153,17 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, continue; } - uint32_t* prev_token_doc_ids = nullptr; // documents that contain the previous token - size_t prev_token_doc_ids_len = 0; - - if(last_token) { - auto& prev_token = token_candidates_vec.back().candidates[0]; - art_leaf* prev_leaf = static_cast( - art_search(search_index.at(the_field.name), - reinterpret_cast(prev_token.c_str()), - prev_token.size() + 1)); - - if(!prev_leaf) { - continue; - } - - std::vector prev_leaf_ids; - posting_t::merge({prev_leaf->values}, prev_leaf_ids); - - if(filter_ids_length != 0) { - prev_token_doc_ids_len = ArrayUtils::and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), - filter_ids, filter_ids_length, - &prev_token_doc_ids); - } else { - prev_token_doc_ids_len = prev_leaf_ids.size(); - prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; - std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); - } - } - for(size_t i = 0; i < field_leaves.size(); i++) { auto leaf = field_leaves[i]; std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); - if(unique_tokens.count(tok) == 0) { - if(last_token) { - if(!posting_t::contains_atleast_one(leaf->values, prev_token_doc_ids, - prev_token_doc_ids_len)) { - continue; - } - } - - unique_tokens.emplace(tok); - leaf_tokens.push_back(tok); - } - - if(leaf_tokens.size() >= max_candidates) { - token_cost_cache.emplace(token_cost_hash, leaf_tokens); - delete [] prev_token_doc_ids; - prev_token_doc_ids = nullptr; - goto token_done; - } + leaf_tokens.push_back(tok); } token_cost_cache.emplace(token_cost_hash, leaf_tokens); - delete [] prev_token_doc_ids; - prev_token_doc_ids = nullptr; + + if(leaf_tokens.size() >= max_candidates) { + goto token_done; + } } if(last_token && leaf_tokens.size() < max_candidates) { @@ -3235,10 +3192,9 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, } std::vector field_leaves; - int max_words = 100000; art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, - costs[token_index], costs[token_index], max_words, token_order, prefix_search, - filter_ids, filter_ids_length, field_leaves, unique_tokens); + costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, + false, "", filter_ids, filter_ids_length, field_leaves, unique_tokens); if(field_leaves.empty()) { // look at the next field @@ -3248,23 +3204,14 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, for(size_t i = 0; i < field_leaves.size(); i++) { auto leaf = field_leaves[i]; std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); - if(unique_tokens.count(tok) == 0) { - if(!posting_t::contains_atleast_one(leaf->values, &prev_token_doc_ids[0], - prev_token_doc_ids.size())) { - continue; - } - - unique_tokens.emplace(tok); - leaf_tokens.push_back(tok); - } - - if(leaf_tokens.size() >= max_candidates) { - token_cost_cache.emplace(token_cost_hash, leaf_tokens); - goto token_done; - } + leaf_tokens.push_back(tok); } token_cost_cache.emplace(token_cost_hash, leaf_tokens); + + if(leaf_tokens.size() >= max_candidates) { + goto token_done; + } } } } @@ -4635,7 +4582,7 @@ void Index::search_field(const uint8_t & field_id, // need less candidates for filtered searches since we already only pick tokens with results art_fuzzy_search(search_index.at(field_name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, - filter_ids, filter_ids_length, leaves, unique_tokens); + false, "", filter_ids, filter_ids_length, leaves, unique_tokens); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); diff --git a/test/art_test.cpp b/test/art_test.cpp index c1030972..df27cf6f 100644 --- a/test/art_test.cpp +++ b/test/art_test.cpp @@ -18,6 +18,8 @@ art_document get_document(uint32_t id) { return document; } +std::set exclude_leaves; + TEST(ArtTest, test_art_init_and_destroy) { art_tree t; int res = art_tree_init(&t); @@ -587,22 +589,25 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf) { EXPECT_EQ(1, posting_t::first_id(l->values)); std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *) implement_key, strlen(implement_key) + 1, 0, 0, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *) implement_key, strlen(implement_key) + 1, 0, 0, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); const char* implement_key_typo1 = "implment"; const char* implement_key_typo2 = "implwnent"; leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 0, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 0, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) implement_key_typo1, strlen(implement_key_typo1) + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) implement_key_typo2, strlen(implement_key_typo2) + 1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) implement_key_typo2, strlen(implement_key_typo2) + 1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -623,11 +628,12 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf_prefix) { std::vector leaves; std::string term = "aplication"; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -645,7 +651,7 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf_qlen_greater_than_key) { std::string term = "starkbin"; std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); } @@ -660,11 +666,12 @@ TEST(ArtTest, test_art_fuzzy_search_single_leaf_non_prefix) { std::string term = "spz"; std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size(), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -682,7 +689,7 @@ TEST(ArtTest, test_art_prefix_larger_than_key) { std::string term = "earrings"; std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *)(term.c_str()), term.size()+1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); res = art_tree_destroy(&t); @@ -706,7 +713,7 @@ TEST(ArtTest, test_art_fuzzy_search_prefix_token_ordering) { } std::vector leaves; - art_fuzzy_search(&t, (const unsigned char *) "e", 1, 0, 0, 3, MAX_SCORE, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *) "e", 1, 0, 0, 3, MAX_SCORE, true, false, "", nullptr, 0, leaves, exclude_leaves); std::string first_key(reinterpret_cast(leaves[0]->key), leaves[0]->key_len - 1); ASSERT_EQ("e", first_key); @@ -718,7 +725,8 @@ TEST(ArtTest, test_art_fuzzy_search_prefix_token_ordering) { ASSERT_EQ("elephant", third_key); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "enter", 5, 1, 1, 3, MAX_SCORE, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "enter", 5, 1, 1, 3, MAX_SCORE, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_TRUE(leaves.empty()); res = art_tree_destroy(&t); @@ -747,56 +755,65 @@ TEST(ArtTest, test_art_fuzzy_search) { auto begin = std::chrono::high_resolution_clock::now(); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "pltinum", strlen("pltinum"), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "pltinum", strlen("pltinum"), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); ASSERT_STREQ("platinumsmith", (const char *)leaves.at(0)->key); ASSERT_STREQ("platinum", (const char *)leaves.at(1)->key); leaves.clear(); + exclude_leaves.clear(); // extra char - art_fuzzy_search(&t, (const unsigned char *) "higghliving", strlen("higghliving") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + art_fuzzy_search(&t, (const unsigned char *) "higghliving", strlen("higghliving") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("highliving", (const char *)leaves.at(0)->key); // transpose leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "zymosthneic", strlen("zymosthneic") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "zymosthneic", strlen("zymosthneic") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("zymosthenic", (const char *)leaves.at(0)->key); // transpose + missing leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("dacrycystalgia", (const char *)leaves.at(0)->key); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 1, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "dacrcyystlgia", strlen("dacrcyystlgia") + 1, 1, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("dacrycystalgia", (const char *)leaves.at(0)->key); // missing char leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "gaberlunze", strlen("gaberlunze") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "gaberlunze", strlen("gaberlunze") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("gaberlunzie", (const char *)leaves.at(0)->key); // substituted char leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "eacemiferous", strlen("eacemiferous") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "eacemiferous", strlen("eacemiferous") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("racemiferous", (const char *)leaves.at(0)->key); // missing char + extra char leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "Sarbruckken", strlen("Sarbruckken") + 1, 0, 2, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "Sarbruckken", strlen("Sarbruckken") + 1, 0, 2, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ("Saarbrucken", (const char *)leaves.at(0)->key); // multiple matching results leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "hown", strlen("hown") + 1, 0, 1, 10, FREQUENCY, false, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "hown", strlen("hown") + 1, 0, 1, 10, FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(10, leaves.size()); std::set expected_words = {"town", "sown", "mown", "lown", "howl", "howk", "howe", "how", "horn", "hoon"}; @@ -809,23 +826,28 @@ TEST(ArtTest, test_art_fuzzy_search) { // fuzzy prefix search leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "lionhear", strlen("lionhear"), 0, 0, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "lionhear", strlen("lionhear"), 0, 0, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(3, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "lineage", strlen("lineage"), 0, 0, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "lineage", strlen("lineage"), 0, 0, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "liq", strlen("liq"), 0, 0, 50, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "liq", strlen("liq"), 0, 0, 50, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(39, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "antitraditiana", strlen("antitraditiana"), 0, 1, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "antitraditiana", strlen("antitraditiana"), 0, 1, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); - art_fuzzy_search(&t, (const unsigned char *) "antisocao", strlen("antisocao"), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + exclude_leaves.clear(); + art_fuzzy_search(&t, (const unsigned char *) "antisocao", strlen("antisocao"), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(6, leaves.size()); long long int timeMillis = std::chrono::duration_cast( @@ -855,7 +877,7 @@ TEST(ArtTest, test_art_fuzzy_search_unicode_chars) { EXPECT_EQ(1, posting_t::first_id(l->values)); std::vector leaves; - art_fuzzy_search(&t, (unsigned char *)key, strlen(key), 0, 0, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (unsigned char *)key, strlen(key), 0, 0, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); } @@ -879,7 +901,7 @@ TEST(ArtTest, test_art_fuzzy_search_extra_chars) { const char* query = "abbreviation"; std::vector leaves; - art_fuzzy_search(&t, (unsigned char *)query, strlen(query), 0, 2, 10, FREQUENCY, true, nullptr, 0, leaves); + art_fuzzy_search(&t, (unsigned char *)query, strlen(query), 0, 2, 10, FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); @@ -918,15 +940,16 @@ TEST(ArtTest, test_art_search_sku_like_tokens) { for (const auto &key : keys) { std::vector leaves; art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size(), 0, 0, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); leaves.clear(); + exclude_leaves.clear(); // non prefix art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size()+1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); } @@ -970,14 +993,17 @@ TEST(ArtTest, test_art_search_ill_like_tokens) { std::make_pair("ice", 2), }; + std::string key = "input"; + for (const auto &key : keys) { art_leaf* l = (art_leaf *) art_search(&t, (const unsigned char *)key.c_str(), key.size()+1); ASSERT_FALSE(l == nullptr); EXPECT_EQ(1, posting_t::num_ids(l->values)); std::vector leaves; + exclude_leaves.clear(); art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size(), 0, 0, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); if(key_to_count.count(key) != 0) { ASSERT_EQ(key_to_count[key], leaves.size()); @@ -987,10 +1013,14 @@ TEST(ArtTest, test_art_search_ill_like_tokens) { } leaves.clear(); + exclude_leaves.clear(); // non prefix art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size()+1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); + if(leaves.size() != 1) { + LOG(INFO) << key; + } ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); } @@ -1022,8 +1052,9 @@ TEST(ArtTest, test_art_search_ill_like_tokens2) { EXPECT_EQ(1, posting_t::num_ids(l->values)); std::vector leaves; + exclude_leaves.clear(); art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size(), 0, 0, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); if(key == "illustration") { ASSERT_EQ(2, leaves.size()); @@ -1033,10 +1064,11 @@ TEST(ArtTest, test_art_search_ill_like_tokens2) { } leaves.clear(); + exclude_leaves.clear(); // non prefix art_fuzzy_search(&t, (const unsigned char*)key.c_str(), key.size() + 1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); ASSERT_STREQ(key.c_str(), (const char *) leaves.at(0)->key); } @@ -1059,12 +1091,12 @@ TEST(ArtTest, test_art_search_roche_chews) { std::string term = "chews"; std::vector leaves; art_fuzzy_search(&t, (const unsigned char*)term.c_str(), term.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(0, leaves.size()); art_fuzzy_search(&t, (const unsigned char*)keys[0].c_str(), keys[0].size() + 1, 0, 0, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); @@ -1091,14 +1123,15 @@ TEST(ArtTest, test_art_search_raspberry) { std::string q_raspberries = "raspberries"; art_fuzzy_search(&t, (const unsigned char*)q_raspberries.c_str(), q_raspberries.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); leaves.clear(); + exclude_leaves.clear(); std::string q_raspberry = "raspberry"; art_fuzzy_search(&t, (const unsigned char*)q_raspberry.c_str(), q_raspberry.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(2, leaves.size()); res = art_tree_destroy(&t); @@ -1124,13 +1157,16 @@ TEST(ArtTest, test_art_search_highliving) { std::string query = "higghliving"; art_fuzzy_search(&t, (const unsigned char*)query.c_str(), query.size() + 1, 0, 1, 10, - FREQUENCY, false, nullptr, 0, leaves); + FREQUENCY, false, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); leaves.clear(); + exclude_leaves.clear(); + exclude_leaves.clear(); + exclude_leaves.clear(); art_fuzzy_search(&t, (const unsigned char*)query.c_str(), query.size(), 0, 2, 10, - FREQUENCY, true, nullptr, 0, leaves); + FREQUENCY, true, false, "", nullptr, 0, leaves, exclude_leaves); ASSERT_EQ(1, leaves.size()); res = art_tree_destroy(&t); diff --git a/test/collection_specific_test.cpp b/test/collection_specific_test.cpp index a6ea83e8..d7a4ad76 100644 --- a/test/collection_specific_test.cpp +++ b/test/collection_specific_test.cpp @@ -203,6 +203,8 @@ TEST_F(CollectionSpecificTest, ExactSingleFieldMatch) { spp::sparse_hash_set(), spp::sparse_hash_set(), 10, "", 30, 4, "title", 10).get(); + LOG(INFO) << results; + ASSERT_EQ(2, results["hits"].size()); ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); diff --git a/test/collection_test.cpp b/test/collection_test.cpp index b72573a3..109cc8d4 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -473,7 +473,7 @@ TEST_F(CollectionTest, TextContainingAnActualTypo) { ASSERT_EQ(4, results["hits"].size()); ASSERT_EQ(11, results["found"].get()); - std::vector ids = {"19", "22", "6", "13"}; + std::vector ids = {"19", "6", "21", "22"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); From fc8f0d72a78d202ad151e625410d1fa616772b7c Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 14 Feb 2023 16:07:20 +0530 Subject: [PATCH 3/5] Enable search cutoff for art search. --- src/art.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/art.cpp b/src/art.cpp index 835f77b9..40b028a3 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1017,6 +1017,8 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r q.push(root); + size_t num_processed = 0; + while(!q.empty() && results.size() < max_results*4) { art_node *n = (art_node *) q.top(); q.pop(); @@ -1034,6 +1036,13 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r //LOG(INFO) << "END LEAF SCORE: " << l->max_score; validate_and_add_leaf(l, last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, exclude_leaves, exact_leaf, results); + + if (++num_processed % 1024 == 0 && (microseconds( + std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { + search_cutoff = true; + break; + } + continue; } From 47879ff35c9452ef30fed6ab26dab5af8d533c42 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 26 Feb 2023 13:31:23 +0530 Subject: [PATCH 4/5] Handle bad filter query in override. --- include/field.h | 4 +- src/index.cpp | 4 +- test/collection_override_test.cpp | 75 +++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/include/field.h b/include/field.h index c8ba94b0..0f39ef69 100644 --- a/include/field.h +++ b/include/field.h @@ -516,8 +516,8 @@ struct filter_node_t { filter filter_exp; FILTER_OPERATOR filter_operator; bool isOperator; - filter_node_t* left; - filter_node_t* right; + filter_node_t* left = nullptr; + filter_node_t* right = nullptr; filter_node_t(filter filter_exp) : filter_exp(std::move(filter_exp)), diff --git a/src/index.cpp b/src/index.cpp index fab2a644..dc42a305 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2041,7 +2041,7 @@ bool Index::static_filter_query_eval(const override_t* override, if ((override->rule.match == override_t::MATCH_EXACT && override->rule.query == query) || (override->rule.match == override_t::MATCH_CONTAINS && StringUtils::contains_word(query, override->rule.query))) { - filter_node_t* new_filter_tree_root; + filter_node_t* new_filter_tree_root = nullptr; Option filter_op = filter::parse_filter_query(override->filter_by, search_schema, store, "", new_filter_tree_root); if (filter_op.ok()) { @@ -2196,7 +2196,7 @@ void Index::process_filter_overrides(const std::vector& filte token_order, absorbed_tokens, filter_by_clause); if (resolved_override) { - filter_node_t* new_filter_tree_root; + filter_node_t* new_filter_tree_root = nullptr; Option filter_op = filter::parse_filter_query(filter_by_clause, search_schema, store, "", new_filter_tree_root); if (filter_op.ok()) { diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 21af3405..0d283c1c 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -779,6 +779,35 @@ TEST_F(CollectionOverrideTest, IncludeOverrideWithFilterBy) { ASSERT_EQ(2, results["hits"].size()); ASSERT_EQ("2", results["hits"][0]["document"]["id"].get()); ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + + // when bad filter by clause is used in override + override_json_include = { + {"id", "include-rule-2"}, + { + "rule", { + {"query", "test"}, + {"match", override_t::MATCH_EXACT} + } + }, + {"filter_curated_hits", false}, + {"stop_processing", false}, + {"remove_matched_tokens", false}, + {"filter_by", "price >55"} + }; + + override_json_include["includes"] = nlohmann::json::array(); + override_json_include["includes"][0] = nlohmann::json::object(); + override_json_include["includes"][0]["id"] = "2"; + override_json_include["includes"][0]["position"] = 1; + + override_t override_include2; + op = override_t::parse(override_json_include, "include-rule-2", override_include2); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_include2); + + results = coll1->search("random-name", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + ASSERT_EQ(0, results["hits"].size()); } TEST_F(CollectionOverrideTest, ReplaceQuery) { @@ -1673,6 +1702,52 @@ TEST_F(CollectionOverrideTest, DynamicFilteringMissingField) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionOverrideTest, DynamicFilteringBadFilterBy) { + Collection *coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("category", field_types::STRING, true), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["category"] = "shoes"; + doc1["points"] = 3; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + + std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; + + nlohmann::json override_json = { + {"id", "dynamic-cat-filter"}, + { + "rule", { + {"query", "{category}"}, // this field does NOT exist + {"match", override_t::MATCH_EXACT} + } + }, + {"remove_matched_tokens", true}, + {"filter_by", "category: {category} && foo"} + }; + + override_t override; + auto op = override_t::parse(override_json, "dynamic-cat-filter", override); + ASSERT_TRUE(op.ok()); + coll1->add_override(override); + + auto results = coll1->search("shoes", {"name", "category"}, "", + {}, sort_fields, {2, 2}, 10).get(); + + ASSERT_EQ(1, results["hits"].size()); + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionOverrideTest, DynamicFilteringMultiplePlaceholders) { Collection* coll1; From 26e3407d1dc4ed7dc0ebdbf7c83bff25b74493a3 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 26 Feb 2023 19:58:24 +0530 Subject: [PATCH 5/5] Add guards for filter node. --- src/collection.cpp | 8 +++----- src/field.cpp | 8 +++++++- test/collection_filtering_test.cpp | 10 ++++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index 35a7a697..e0087da6 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -1035,6 +1035,8 @@ Option Collection::search(const std::string & raw_query, filter_node_t* filter_tree_root = nullptr; Option parse_filter_op = filter::parse_filter_query(filter_query, search_schema, store, doc_id_prefix, filter_tree_root); + std::unique_ptr filter_tree_root_guard(filter_tree_root); + if(!parse_filter_op.ok()) { return Option(parse_filter_op.code(), parse_filter_op.error()); } @@ -1277,6 +1279,7 @@ Option Collection::search(const std::string & raw_query, min_len_1typo, min_len_2typo, max_candidates, infixes, max_extra_prefix, max_extra_suffix, facet_query_num_typos, filter_curated_hits, split_join_tokens, vector_query); + std::unique_ptr search_params_guard(search_params); index->run_search(search_params); @@ -1804,11 +1807,6 @@ Option Collection::search(const std::string & raw_query, result["facet_counts"].push_back(facet_result); } - // free search params - delete search_params; - - delete filter_tree_root; - result["search_cutoff"] = search_cutoff; result["request_params"] = nlohmann::json::object(); diff --git a/src/field.cpp b/src/field.cpp index 18ac472a..ba9c67f2 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -372,7 +372,7 @@ Option toParseTree(std::queue& postfix, filter_node_t*& root, const std::string expression = postfix.front(); postfix.pop(); - filter_node_t* filter_node; + filter_node_t* filter_node = nullptr; if (isOperator(expression)) { auto message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; @@ -383,6 +383,7 @@ Option toParseTree(std::queue& postfix, filter_node_t*& root, nodeStack.pop(); if (nodeStack.empty()) { + delete operandB; return Option(400, message); } auto operandA = nodeStack.top(); @@ -393,6 +394,11 @@ Option toParseTree(std::queue& postfix, filter_node_t*& root, filter filter_exp; Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); if (!toFilter_op.ok()) { + while(!nodeStack.empty()) { + auto filterNode = nodeStack.top(); + delete filterNode; + nodeStack.pop(); + } return toFilter_op; } diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index e8587544..c28ab159 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -353,6 +353,16 @@ TEST_F(CollectionFilteringTest, HandleBadlyFormedFilterQuery) { nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "tagzz: gold", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(0, results["hits"].size()); + // compound filter expression containing an unknown field + results = coll_array_fields->search("Jeremy", query_fields, + "(age:>0 || timestamps:> 0) || tagzz: gold", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["hits"].size()); + + // unbalanced paranthesis + results = coll_array_fields->search("Jeremy", query_fields, + "(age:>0 || timestamps:> 0) || ", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(0, results["hits"].size()); + // searching using a string for a numeric field results = coll_array_fields->search("Jeremy", query_fields, "age: abcdef", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(0, results["hits"].size());