From 25318fd3ac894da76194ef8b9229ba75061039d5 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Sat, 18 Mar 2023 17:25:25 +0530 Subject: [PATCH 01/64] Add `filter_result_iterator_t`. --- include/filter.h | 81 ++++++++ include/index.h | 2 + include/option.h | 12 ++ include/posting_list.h | 2 + src/filter.cpp | 430 +++++++++++++++++++++++++++++++++++++++++ src/posting_list.cpp | 36 ++++ test/filter_test.cpp | 182 +++++++++++++++++ 7 files changed, 745 insertions(+) create mode 100644 include/filter.h create mode 100644 src/filter.cpp create mode 100644 test/filter_test.cpp diff --git a/include/filter.h b/include/filter.h new file mode 100644 index 00000000..239a7e77 --- /dev/null +++ b/include/filter.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include "index.h" + +class filter_result_iterator_t { +private: + std::string collection_name; + const Index* index; + filter_node_t* filter_node; + filter_result_iterator_t* left_it = nullptr; + filter_result_iterator_t* right_it = nullptr; + + // Used in case of id and reference filter. + uint32_t result_index = 0; + + // Stores the result of the filters that cannot be iterated. + filter_result_t filter_result; + + // Initialized in case of filter on string field. + // Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + // for each token. + // + // Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_list_iterators; + std::vector expanded_plists; + + // Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initializes the state of iterator node after it's creation. + void init(); + + /// Performs AND on the subtrees of operator. + void and_filter_iterators(); + + /// Performs OR on the subtrees of operator. + void or_filter_iterators(); + + /// Finds the next match for a filter on string field. 
+ void doc_matching_string_filter(); + +public: + uint32_t doc; + // Collection name -> references + std::map reference; + Option status; + + explicit filter_result_iterator_t(const std::string& collection_name, + const Index* index, filter_node_t* filter_node, + Option& status) : + collection_name(collection_name), + index(index), + filter_node(filter_node), + status(status) { + // Generate the iterator tree and then initialize each node. + if (filter_node->isOperator) { + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, status); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, status); + } + + init(); + } + + ~filter_result_iterator_t() { + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + } + + [[nodiscard]] bool valid(); + + void next(); + + void skip_to(uint32_t id); +}; diff --git a/include/index.h b/include/index.h index 38fdc576..f2430e97 100644 --- a/include/index.h +++ b/include/index.h @@ -956,6 +956,8 @@ public: std::map>& included_ids_map, std::vector& included_ids_vec, std::unordered_set& excluded_group_ids) const; + + friend class filter_result_iterator_t; }; template diff --git a/include/option.h b/include/option.h index 0f8c49e0..ced54ae9 100644 --- a/include/option.h +++ b/include/option.h @@ -31,6 +31,18 @@ public: error_code = obj.error_code; } + Option& operator=(Option&& obj) noexcept { + if (&obj == this) + return *this; + + value = obj.value; + is_ok = obj.is_ok; + error_msg = obj.error_msg; + error_code = obj.error_code; + + return *this; + } + bool ok() const { return is_ok; } diff --git a/include/posting_list.h b/include/posting_list.h index dea742f7..16f42ea7 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -164,6 +164,8 @@ public: static void intersect(const std::vector& posting_lists, std::vector& result_ids); + static void 
intersect(std::vector& posting_list_iterators, bool& is_valid); + template static bool block_intersect( std::vector& its, diff --git a/src/filter.cpp b/src/filter.cpp new file mode 100644 index 00000000..804e19a4 --- /dev/null +++ b/src/filter.cpp @@ -0,0 +1,430 @@ +#include +#include +#include +#include "filter.h" + +void filter_result_iterator_t::and_filter_iterators() { + while (left_it->valid() && right_it->valid()) { + while (left_it->doc < right_it->doc) { + left_it->next(); + if (!left_it->valid()) { + is_valid = false; + return; + } + } + + while (left_it->doc > right_it->doc) { + right_it->next(); + if (!right_it->valid()) { + is_valid = false; + return; + } + } + + if (left_it->doc == right_it->doc) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + } + + is_valid = false; +} + +void filter_result_iterator_t::or_filter_iterators() { + if (left_it->valid() && right_it->valid()) { + if (left_it->doc < right_it->doc) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->doc > right_it->doc) { + doc = right_it->doc; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->valid()) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (right_it->valid()) { + doc = right_it->doc; + reference.clear(); + + for (const auto& 
item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + is_valid = false; +} + +void filter_result_iterator_t::doc_matching_string_filter() { + // If none of the filter value iterators are valid, mark this node as invalid. + bool one_is_valid = false; + + // Since we do OR between filter values, the lowest doc id from all is selected. + uint32_t lowest_id = UINT32_MAX; + + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + + if (one_is_valid) { + doc = lowest_id; + } + + is_valid = one_is_valid; +} + +void filter_result_iterator_t::next() { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Advance all the filter values that are at doc. Then find the next one. 
+ std::vector doc_matching_indexes; + for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + const auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { + doc_matching_indexes.push_back(i); + } + } + + for (const auto &lowest_id_index: doc_matching_indexes) { + for (auto &iter: posting_list_iterators[lowest_id_index]) { + iter.next(); + } + } + + doc_matching_string_filter(); + return; + } +} + +void filter_result_iterator_t::init() { + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(a_filter.referenced_collection_name); + if (collection == nullptr) { + status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + is_valid = false; + return; + } + + auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, + filter_result, + collection_name); + if (!reference_filter_op.ok()) { + status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name + + "` collection: " + reference_filter_op.error()); + is_valid = false; + return; + } + + is_valid = filter_result.count > 0; + return; + } + + if (a_filter.field_name == "id") { + if (a_filter.values.empty()) { + is_valid = false; + return; + } + + // we handle `ids` separately + std::vector result_ids; + for (const auto& id_str : a_filter.values) { + result_ids.push_back(std::stoul(id_str)); + } + + std::sort(result_ids.begin(), result_ids.end()); + + 
filter_result.count = result_ids.size(); + filter_result.docs = new uint32_t[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + art_tree* t = index->search_index.at(a_filter.field_name); + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + std::vector plists; + posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + + posting_list_iterators.emplace_back(std::vector()); + + for (auto const& plist: plists) { + posting_list_iterators.back().push_back(plist->new_iterator()); + } + } + + doc_matching_string_filter(); + return; + } +} + +bool filter_result_iterator_t::valid() { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + is_valid = left_it->valid() && right_it->valid(); + return is_valid; + } else { + is_valid = left_it->valid() || right_it->valid(); + return is_valid; + } + } + + const filter a_filter = filter_node->filter_exp; + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + is_valid = result_index < filter_result.count; + 
return is_valid; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return is_valid; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + bool one_is_valid = false; + for (auto& filter_value_tokens: posting_list_iterators) { + posting_list_t::intersect(filter_value_tokens, one_is_valid); + + if (one_is_valid) { + break; + } + } + + is_valid = one_is_valid; + return is_valid; + } + + return true; +} + +void filter_result_iterator_t::skip_to(uint32_t id) { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + if (filter_node->filter_operator == AND) { + left_it->skip_to(id); + and_filter_iterators(); + } else { + right_it->skip_to(id); + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Skip all the token iterators and find a new match. 
+ for (auto& filter_value_tokens : posting_list_iterators) { + for (auto& token: filter_value_tokens) { + // We perform AND on tokens. Short-circuiting here. + if (!token.valid()) { + break; + } + + token.skip_to(id); + } + } + + doc_matching_string_filter(); + return; + } +} diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 23f17e8b..39f3ac00 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -754,6 +754,42 @@ void posting_list_t::intersect(const std::vector& posting_lists } } +void posting_list_t::intersect(std::vector& posting_list_iterators, bool& is_valid) { + if (posting_list_iterators.empty()) { + is_valid = false; + return; + } + + if (posting_list_iterators.size() == 1) { + is_valid = posting_list_iterators.front().valid(); + return; + } + + switch (posting_list_iterators.size()) { + case 2: + while(!at_end2(posting_list_iterators)) { + if(equals2(posting_list_iterators)) { + is_valid = true; + return; + } else { + advance_non_largest2(posting_list_iterators); + } + } + is_valid = false; + break; + default: + while(!at_end(posting_list_iterators)) { + if(equals(posting_list_iterators)) { + is_valid = true; + return; + } else { + advance_non_largest(posting_list_iterators); + } + } + is_valid = false; + } +} + bool posting_list_t::take_id(result_iter_state_t& istate, uint32_t id) { // decide if this result id should be excluded if(istate.excluded_result_ids_size != 0) { diff --git a/test/filter_test.cpp b/test/filter_test.cpp new file mode 100644 index 00000000..60cff669 --- /dev/null +++ b/test/filter_test.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class FilterTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = 
"/tmp/typesense_test/collection_join"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(FilterTest, FilterTreeIterator) { + nlohmann::json schema = + R"({ + "name": "Collection", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int32"}, + {"name": "years", "type": "int32[]"}, + {"name": "rating", "type": "float"}, + {"name": "tags", "type": "string[]"} + ] + })"_json; + + Collection* coll = collectionManager.create_collection(schema).get(); + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::string json_line; + while (std::getline(infile, json_line)) { + auto add_op = coll->add(json_line); + ASSERT_TRUE(add_op.ok()); + } + infile.close(); + + const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_"; + filter_node_t* filter_tree_root = nullptr; + Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + Option iter_op(true); + auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_FALSE(iter_no_match_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: [foo bar, baz]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + 
ASSERT_FALSE(iter_no_match_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: Jeremy", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_contains_test.valid()); + ASSERT_EQ(i, iter_contains_test.doc); + iter_contains_test.next(); + } + ASSERT_FALSE(iter_contains_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: [Jeremy, Howard, Richard]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_contains_multi_test.valid()); + ASSERT_EQ(i, iter_contains_multi_test.doc); + iter_contains_multi_test.next(); + } + ASSERT_FALSE(iter_contains_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name:= Jeremy Howard", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_exact_match_test.valid()); + ASSERT_EQ(i, iter_exact_match_test.doc); + iter_exact_match_test.next(); + } + ASSERT_FALSE(iter_exact_match_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= [gold, silver]", coll->get_schema(), store, doc_id_prefix, + 
filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + std::vector expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_exact_match_multi_test.valid()); + ASSERT_EQ(i, iter_exact_match_multi_test.doc); + iter_exact_match_multi_test.next(); + } + ASSERT_FALSE(iter_exact_match_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + +// delete filter_tree_root; +// filter_tree_root = nullptr; +// filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, +// filter_tree_root); +// ASSERT_TRUE(filter_op.ok()); +// +// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); +// +// std::vector expected = {1, 3}; +// for (auto const& i : expected) { +// ASSERT_TRUE(iter_not_equals_test.valid()); +// ASSERT_EQ(i, iter_not_equals_test.doc); +// iter_not_equals_test.next(); +// } +// +// ASSERT_FALSE(iter_not_equals_test.valid()); +// ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_skip_test.valid()); + iter_skip_test.skip_to(3); + ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_EQ(4, iter_skip_test.doc); + iter_skip_test.next(); + + ASSERT_FALSE(iter_skip_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; +} \ No newline at end of file From b024e0d4ecfdb2bc86184bf43d15c31882889965 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 10:33:01 +0530 Subject: [PATCH 02/64] Add `filter_result_iterator_t::valid(uint32_t id)`. 
--- include/filter.h | 23 +++++++++++++++-------- src/filter.cpp | 26 ++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/include/filter.h b/include/filter.h index 239a7e77..d11a8f9d 100644 --- a/include/filter.h +++ b/include/filter.h @@ -12,21 +12,21 @@ private: filter_result_iterator_t* left_it = nullptr; filter_result_iterator_t* right_it = nullptr; - // Used in case of id and reference filter. + /// Used in case of id and reference filter. uint32_t result_index = 0; - // Stores the result of the filters that cannot be iterated. + /// Stores the result of the filters that cannot be iterated. filter_result_t filter_result; - // Initialized in case of filter on string field. - // Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator - // for each token. - // - // Multiple filter values: Multiple tokens: posting list iterator + /// Initialized in case of filter on string field. + /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + /// for each token. + /// + /// Multiple filter values: Multiple tokens: posting list iterator std::vector> posting_list_iterators; std::vector expanded_plists; - // Set to false when this iterator or it's subtree becomes invalid. + /// Set to false when this iterator or it's subtree becomes invalid. bool is_valid = true; /// Initializes the state of iterator node after it's creation. @@ -73,9 +73,16 @@ public: delete right_it; } + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). [[nodiscard]] bool valid(); + /// Returns true when id is a match to the filter. Handles moving the individual iterators internally. + [[nodiscard]] bool valid(uint32_t id); + + /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this + /// operation. 
void next(); + /// Advances the iterator until the doc value is less than id. The iterator may become invalid during this operation. void skip_to(uint32_t id); }; diff --git a/src/filter.cpp b/src/filter.cpp index 804e19a4..a8fa1eeb 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -361,11 +361,12 @@ void filter_result_iterator_t::skip_to(uint32_t id) { if (filter_node->isOperator) { // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + left_it->skip_to(id); + right_it->skip_to(id); + if (filter_node->filter_operator == AND) { - left_it->skip_to(id); and_filter_iterators(); } else { - right_it->skip_to(id); or_filter_iterators(); } @@ -428,3 +429,24 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } } + +bool filter_result_iterator_t::valid(uint32_t id) { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + auto and_is_valid = left_it->valid(id) && right_it->valid(id); + is_valid = left_it->is_valid && right_it->is_valid; + return and_is_valid; + } else { + auto or_is_valid = left_it->valid(id) || right_it->valid(id); + is_valid = left_it->is_valid || right_it->is_valid; + return or_is_valid; + } + } + + skip_to(id); + return is_valid && doc == id; +} From 0a022382ce7fa62764be620ca4ea181cbb43d6d9 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 12:56:53 +0530 Subject: [PATCH 03/64] Add test cases for `AND` and `OR`. --- include/filter.h | 3 ++- src/filter.cpp | 24 ++++++++++++++++++------ test/filter_test.cpp | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/include/filter.h b/include/filter.h index d11a8f9d..643c5006 100644 --- a/include/filter.h +++ b/include/filter.h @@ -83,6 +83,7 @@ public: /// operation. void next(); - /// Advances the iterator until the doc value is less than id. The iterator may become invalid during this operation. 
+ /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during + /// this operation. void skip_to(uint32_t id); }; diff --git a/src/filter.cpp b/src/filter.cpp index a8fa1eeb..b42c3e5d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -4,10 +4,10 @@ #include "filter.h" void filter_result_iterator_t::and_filter_iterators() { - while (left_it->valid() && right_it->valid()) { + while (left_it->is_valid && right_it->is_valid) { while (left_it->doc < right_it->doc) { left_it->next(); - if (!left_it->valid()) { + if (!left_it->is_valid) { is_valid = false; return; } @@ -15,7 +15,7 @@ void filter_result_iterator_t::and_filter_iterators() { while (left_it->doc > right_it->doc) { right_it->next(); - if (!right_it->valid()) { + if (!right_it->is_valid) { is_valid = false; return; } @@ -40,7 +40,7 @@ void filter_result_iterator_t::and_filter_iterators() { } void filter_result_iterator_t::or_filter_iterators() { - if (left_it->valid() && right_it->valid()) { + if (left_it->is_valid && right_it->is_valid) { if (left_it->doc < right_it->doc) { doc = left_it->doc; reference.clear(); @@ -76,7 +76,7 @@ void filter_result_iterator_t::or_filter_iterators() { return; } - if (left_it->valid()) { + if (left_it->is_valid) { doc = left_it->doc; reference.clear(); @@ -87,7 +87,7 @@ void filter_result_iterator_t::or_filter_iterators() { return; } - if (right_it->valid()) { + if (right_it->is_valid) { doc = right_it->doc; reference.clear(); @@ -133,9 +133,21 @@ void filter_result_iterator_t::next() { } if (filter_node->isOperator) { + // Advance the subtrees and then apply operators to arrive at the next valid doc. 
if (filter_node->filter_operator == AND) { + left_it->next(); + right_it->next(); and_filter_iterators(); } else { + if (left_it->doc == doc && right_it->doc == doc) { + left_it->next(); + right_it->next(); + } else if (left_it->doc == doc) { + left_it->next(); + } else { + right_it->next(); + } + or_filter_iterators(); } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 60cff669..99edf27f 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -178,5 +178,49 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_skip_test.valid()); ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: jeremy && tags: fine platinum", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_and_test.valid()); + ASSERT_EQ(1, iter_and_test.doc); + iter_and_test.next(); + + ASSERT_FALSE(iter_and_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || tags: bronze", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto doc = + R"({ + "name": "James Rowdy", + "age": 36, + "years": [2005, 2022], + "rating": 6.03, + "tags": ["copper"] + })"_json; + auto add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + + auto iter_or_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + expected = {2, 4, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_or_test.valid()); + ASSERT_EQ(i, iter_or_test.doc); + iter_or_test.next(); + } + + ASSERT_FALSE(iter_or_test.valid()); + ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; } \ No newline at end of file From 41063e2e2a89b02c01306f0a4294b30e4a968d88 Mon Sep 17 
00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 13:22:04 +0530 Subject: [PATCH 04/64] Add test case for complex filter. --- test/filter_test.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 99edf27f..ee8b7b00 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -222,5 +222,26 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_or_test.valid()); ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || (tags: gold && tags: silver)", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + iter_skip_complex_filter_test.skip_to(4); + + expected = {4, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + ASSERT_EQ(i, iter_skip_complex_filter_test.doc); + iter_skip_complex_filter_test.next(); + } + + ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; } \ No newline at end of file From dee367c7d2c49499ffe39dc53341814ba6e08f78 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 15:50:14 +0530 Subject: [PATCH 05/64] Add `filter_result_iterator_t::valid(uint32_t id)` test case. 
--- test/filter_test.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/filter_test.cpp b/test/filter_test.cpp index ee8b7b00..96121e92 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -243,5 +243,21 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_skip_complex_filter_test.valid()); ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || (tags: gold && tags: [silver, bronze])", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + expected = {0, 2, 4, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_validate_ids_test.valid(i)); + } + + ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; } \ No newline at end of file From ed719d19f3623e43462ee568d2a6a8e39cdb8d46 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 18:01:26 +0530 Subject: [PATCH 06/64] Handle `apply_not_equals`. --- src/filter.cpp | 17 +++++++++++++++++ test/filter_test.cpp | 17 ++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/filter.cpp b/src/filter.cpp index b42c3e5d..e7a46093 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -459,6 +459,23 @@ bool filter_result_iterator_t::valid(uint32_t id) { } } + if (filter_node->filter_exp.apply_not_equals) { + // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. 
+ if (!valid()) { + is_valid = true; + return is_valid; + } + + skip_to(id); + + if (!is_valid) { + is_valid = true; + return is_valid; + } + + return doc != id; + } + skip_to(id); return is_valid && doc == id; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 96121e92..3c4b94a5 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -256,7 +256,22 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_validate_ids_test.valid(i)); } - ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || tags: != gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_not_equals_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root, iter_op); + + expected = {1, 3, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.valid(i)); + } + ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; From 3b0a2cbd55fa693118006618e2e5b27125a41681 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 22 Mar 2023 10:46:01 +0530 Subject: [PATCH 07/64] Refactor `filter_result_iterator_t::next()`. --- src/filter.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index e7a46093..2c289761 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -190,19 +190,14 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { - // Advance all the filter values that are at doc. Then find the next one. - std::vector doc_matching_indexes; + // Advance all the filter values that are at doc. Then find the next doc. 
for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { - const auto& filter_value_tokens = posting_list_iterators[i]; + auto& filter_value_tokens = posting_list_iterators[i]; if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { - doc_matching_indexes.push_back(i); - } - } - - for (const auto &lowest_id_index: doc_matching_indexes) { - for (auto &iter: posting_list_iterators[lowest_id_index]) { - iter.next(); + for (auto& iter: filter_value_tokens) { + iter.next(); + } } } @@ -363,7 +358,7 @@ bool filter_result_iterator_t::valid() { return is_valid; } - return true; + return false; } void filter_result_iterator_t::skip_to(uint32_t id) { From eb7a5b55e662f759a1825a4c0819ce1bc74cad5b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 22 Mar 2023 13:23:40 +0530 Subject: [PATCH 08/64] Add `filter_result_iterator_t::init_status()`. --- include/filter.h | 17 +++++----- src/filter.cpp | 54 ++++++++++++++++++------------- test/filter_test.cpp | 76 ++++++++++++++++++++++---------------------- 3 files changed, 79 insertions(+), 68 deletions(-) diff --git a/include/filter.h b/include/filter.h index 643c5006..c5a20f11 100644 --- a/include/filter.h +++ b/include/filter.h @@ -42,22 +42,20 @@ private: void doc_matching_string_filter(); public: - uint32_t doc; + uint32_t seq_id = 0; // Collection name -> references std::map reference; - Option status; + Option status = Option(true); explicit filter_result_iterator_t(const std::string& collection_name, - const Index* index, filter_node_t* filter_node, - Option& status) : + const Index* index, filter_node_t* filter_node) : collection_name(collection_name), index(index), - filter_node(filter_node), - status(status) { + filter_node(filter_node) { // Generate the iterator tree and then initialize each node. 
if (filter_node->isOperator) { - left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, status); - right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, status); + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right); } init(); @@ -73,6 +71,9 @@ public: delete right_it; } + /// Returns the status of the initialization of iterator tree. + Option init_status(); + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). [[nodiscard]] bool valid(); diff --git a/src/filter.cpp b/src/filter.cpp index 2c289761..16e52c49 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -5,7 +5,7 @@ void filter_result_iterator_t::and_filter_iterators() { while (left_it->is_valid && right_it->is_valid) { - while (left_it->doc < right_it->doc) { + while (left_it->seq_id < right_it->seq_id) { left_it->next(); if (!left_it->is_valid) { is_valid = false; @@ -13,7 +13,7 @@ void filter_result_iterator_t::and_filter_iterators() { } } - while (left_it->doc > right_it->doc) { + while (left_it->seq_id > right_it->seq_id) { right_it->next(); if (!right_it->is_valid) { is_valid = false; @@ -21,8 +21,8 @@ void filter_result_iterator_t::and_filter_iterators() { } } - if (left_it->doc == right_it->doc) { - doc = left_it->doc; + if (left_it->seq_id == right_it->seq_id) { + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -41,8 +41,8 @@ void filter_result_iterator_t::and_filter_iterators() { void filter_result_iterator_t::or_filter_iterators() { if (left_it->is_valid && right_it->is_valid) { - if (left_it->doc < right_it->doc) { - doc = left_it->doc; + if (left_it->seq_id < right_it->seq_id) { + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -52,8 +52,8 @@ void 
filter_result_iterator_t::or_filter_iterators() { return; } - if (left_it->doc > right_it->doc) { - doc = right_it->doc; + if (left_it->seq_id > right_it->seq_id) { + seq_id = right_it->seq_id; reference.clear(); for (const auto& item: right_it->reference) { @@ -63,7 +63,7 @@ void filter_result_iterator_t::or_filter_iterators() { return; } - doc = left_it->doc; + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -77,7 +77,7 @@ void filter_result_iterator_t::or_filter_iterators() { } if (left_it->is_valid) { - doc = left_it->doc; + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -88,7 +88,7 @@ void filter_result_iterator_t::or_filter_iterators() { } if (right_it->is_valid) { - doc = right_it->doc; + seq_id = right_it->seq_id; reference.clear(); for (const auto& item: right_it->reference) { @@ -105,7 +105,7 @@ void filter_result_iterator_t::doc_matching_string_filter() { // If none of the filter value iterators are valid, mark this node as invalid. bool one_is_valid = false; - // Since we do OR between filter values, the lowest doc id from all is selected. + // Since we do OR between filter values, the lowest seq_id id from all is selected. 
uint32_t lowest_id = UINT32_MAX; for (auto& filter_value_tokens : posting_list_iterators) { @@ -121,7 +121,7 @@ void filter_result_iterator_t::doc_matching_string_filter() { } if (one_is_valid) { - doc = lowest_id; + seq_id = lowest_id; } is_valid = one_is_valid; @@ -139,10 +139,10 @@ void filter_result_iterator_t::next() { right_it->next(); and_filter_iterators(); } else { - if (left_it->doc == doc && right_it->doc == doc) { + if (left_it->seq_id == seq_id && right_it->seq_id == seq_id) { left_it->next(); right_it->next(); - } else if (left_it->doc == doc) { + } else if (left_it->seq_id == seq_id) { left_it->next(); } else { right_it->next(); @@ -163,7 +163,7 @@ void filter_result_iterator_t::next() { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; reference.clear(); for (auto const& item: filter_result.reference_filter_results) { reference[item.first] = item.second[result_index]; @@ -178,7 +178,7 @@ void filter_result_iterator_t::next() { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; return; } @@ -194,7 +194,7 @@ void filter_result_iterator_t::next() { for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { auto& filter_value_tokens = posting_list_iterators[i]; - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { for (auto& iter: filter_value_tokens) { iter.next(); } @@ -391,7 +391,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; reference.clear(); for (auto const& item: filter_result.reference_filter_results) { reference[item.first] = item.second[result_index]; @@ -408,7 +408,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; return; } @@ -468,9 
+468,19 @@ bool filter_result_iterator_t::valid(uint32_t id) { return is_valid; } - return doc != id; + return seq_id != id; } skip_to(id); - return is_valid && doc == id; + return is_valid && seq_id == id; +} + +Option filter_result_iterator_t::init_status() { + if (filter_node->isOperator) { + auto left_status = left_it->init_status(); + + return !left_status.ok() ? left_status : right_it->init_status(); + } + + return status; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 3c4b94a5..935efff8 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -64,11 +64,10 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - Option iter_op(true); - auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_no_match_test.init_status().ok()); ASSERT_FALSE(iter_no_match_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -76,10 +75,10 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_no_match_multi_test.init_status().ok()); ASSERT_FALSE(iter_no_match_multi_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -87,14 +86,15 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + 
ASSERT_TRUE(iter_contains_test.init_status().ok()); + for (uint32_t i = 0; i < 5; i++) { ASSERT_TRUE(iter_contains_test.valid()); - ASSERT_EQ(i, iter_contains_test.doc); + ASSERT_EQ(i, iter_contains_test.seq_id); iter_contains_test.next(); } ASSERT_FALSE(iter_contains_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -102,14 +102,15 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_contains_multi_test.init_status().ok()); + for (uint32_t i = 0; i < 5; i++) { ASSERT_TRUE(iter_contains_multi_test.valid()); - ASSERT_EQ(i, iter_contains_multi_test.doc); + ASSERT_EQ(i, iter_contains_multi_test.seq_id); iter_contains_multi_test.next(); } ASSERT_FALSE(iter_contains_multi_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -117,14 +118,15 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_test.init_status().ok()); + for (uint32_t i = 0; i < 5; i++) { ASSERT_TRUE(iter_exact_match_test.valid()); - ASSERT_EQ(i, iter_exact_match_test.doc); + ASSERT_EQ(i, iter_exact_match_test.seq_id); iter_exact_match_test.next(); } ASSERT_FALSE(iter_exact_match_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -132,16 +134,16 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_exact_match_multi_test 
= filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_multi_test.init_status().ok()); std::vector expected = {0, 2, 3, 4}; for (auto const& i : expected) { ASSERT_TRUE(iter_exact_match_multi_test.valid()); - ASSERT_EQ(i, iter_exact_match_multi_test.doc); + ASSERT_EQ(i, iter_exact_match_multi_test.seq_id); iter_exact_match_multi_test.next(); } ASSERT_FALSE(iter_exact_match_multi_test.valid()); - ASSERT_TRUE(iter_op.ok()); // delete filter_tree_root; // filter_tree_root = nullptr; @@ -149,17 +151,17 @@ TEST_F(FilterTest, FilterTreeIterator) { // filter_tree_root); // ASSERT_TRUE(filter_op.ok()); // -// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); +// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); +// ASSERT_TRUE(iter_not_equals_test.init_status().ok()); // // std::vector expected = {1, 3}; // for (auto const& i : expected) { // ASSERT_TRUE(iter_not_equals_test.valid()); -// ASSERT_EQ(i, iter_not_equals_test.doc); +// ASSERT_EQ(i, iter_not_equals_test.seq_id); // iter_not_equals_test.next(); // } // // ASSERT_FALSE(iter_not_equals_test.valid()); -// ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -167,16 +169,16 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test.init_status().ok()); ASSERT_TRUE(iter_skip_test.valid()); iter_skip_test.skip_to(3); ASSERT_TRUE(iter_skip_test.valid()); - ASSERT_EQ(4, iter_skip_test.doc); + 
ASSERT_EQ(4, iter_skip_test.seq_id); iter_skip_test.next(); ASSERT_FALSE(iter_skip_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -184,14 +186,14 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_and_test.init_status().ok()); ASSERT_TRUE(iter_and_test.valid()); - ASSERT_EQ(1, iter_and_test.doc); + ASSERT_EQ(1, iter_and_test.seq_id); iter_and_test.next(); ASSERT_FALSE(iter_and_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -210,17 +212,17 @@ TEST_F(FilterTest, FilterTreeIterator) { auto add_op = coll->add(doc.dump()); ASSERT_TRUE(add_op.ok()); - auto iter_or_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_or_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_or_test.init_status().ok()); expected = {2, 4, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_or_test.valid()); - ASSERT_EQ(i, iter_or_test.doc); + ASSERT_EQ(i, iter_or_test.seq_id); iter_or_test.next(); } ASSERT_FALSE(iter_or_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -228,7 +230,8 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_complex_filter_test.init_status().ok()); ASSERT_TRUE(iter_skip_complex_filter_test.valid()); 
iter_skip_complex_filter_test.skip_to(4); @@ -236,12 +239,11 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {4, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_skip_complex_filter_test.valid()); - ASSERT_EQ(i, iter_skip_complex_filter_test.doc); + ASSERT_EQ(i, iter_skip_complex_filter_test.seq_id); iter_skip_complex_filter_test.next(); } ASSERT_FALSE(iter_skip_complex_filter_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -249,15 +251,14 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); expected = {0, 2, 4, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_validate_ids_test.valid(i)); } - ASSERT_TRUE(iter_op.ok()); - delete filter_tree_root; filter_tree_root = nullptr; filter_op = filter::parse_filter_query("name: James || tags: != gold", coll->get_schema(), store, doc_id_prefix, @@ -265,14 +266,13 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(filter_op.ok()); auto iter_validate_ids_not_equals_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), - filter_tree_root, iter_op); + filter_tree_root); + ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); expected = {1, 3, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.valid(i)); } - ASSERT_TRUE(iter_op.ok()); - delete filter_tree_root; } \ No newline at end of file From 4d0c7b5112bd7d448cd405fa109ed8c87cc1ed72 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 27 Mar 2023 08:25:15 +0530 Subject: [PATCH 09/64] Refactor `valid(id)`. 
--- include/filter.h | 9 +++++++-- src/filter.cpp | 38 ++++++++++++++++++++++++++++---------- test/filter_test.cpp | 16 +++++++++------- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/include/filter.h b/include/filter.h index c5a20f11..9bb78deb 100644 --- a/include/filter.h +++ b/include/filter.h @@ -77,8 +77,13 @@ public: /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). [[nodiscard]] bool valid(); - /// Returns true when id is a match to the filter. Handles moving the individual iterators internally. - [[nodiscard]] bool valid(uint32_t id); + /// Returns a tri-state: + /// 0: id is not valid + /// 1: id is valid + /// -1: end of iterator + /// + /// Handles moving the individual iterators internally. + [[nodiscard]] int valid(uint32_t id); /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this /// operation. diff --git a/src/filter.cpp b/src/filter.cpp index 16e52c49..44393d8d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -437,20 +437,38 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } } -bool filter_result_iterator_t::valid(uint32_t id) { +int filter_result_iterator_t::valid(uint32_t id) { if (!is_valid) { - return false; + return -1; } if (filter_node->isOperator) { + auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); + if (filter_node->filter_operator == AND) { - auto and_is_valid = left_it->valid(id) && right_it->valid(id); is_valid = left_it->is_valid && right_it->is_valid; - return and_is_valid; + + if (left_valid < 1 || right_valid < 1) { + if (left_valid == -1 || right_valid == -1) { + return -1; + } + + return 0; + } + + return 1; } else { - auto or_is_valid = left_it->valid(id) || right_it->valid(id); is_valid = left_it->is_valid || right_it->is_valid; - return or_is_valid; + + if (left_valid < 1 && right_valid < 1) { + if (left_valid == -1 && right_valid == -1) { + return -1; + } + + 
return 0; + } + + return 1; } } @@ -458,21 +476,21 @@ bool filter_result_iterator_t::valid(uint32_t id) { // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. if (!valid()) { is_valid = true; - return is_valid; + return 1; } skip_to(id); if (!is_valid) { is_valid = true; - return is_valid; + return 1; } - return seq_id != id; + return seq_id != id ? 1 : 0; } skip_to(id); - return is_valid && seq_id == id; + return is_valid ? (seq_id == id ? 1 : 0) : -1; } Option filter_result_iterator_t::init_status() { diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 935efff8..e607e03b 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -137,7 +137,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_exact_match_multi_test.init_status().ok()); - std::vector expected = {0, 2, 3, 4}; + std::vector expected = {0, 2, 3, 4}; for (auto const& i : expected) { ASSERT_TRUE(iter_exact_match_multi_test.valid()); ASSERT_EQ(i, iter_exact_match_multi_test.seq_id); @@ -254,9 +254,10 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); - expected = {0, 2, 4, 5}; - for (auto const& i : expected) { - ASSERT_TRUE(iter_validate_ids_test.valid(i)); + std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}; + expected = {1, 0, 1, 0, 1, 1, -1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); } delete filter_tree_root; @@ -269,9 +270,10 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); - expected = {1, 3, 5}; - for (auto const& i : expected) { - 
ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.valid(i)); + validate_ids = {0, 1, 2, 3, 4, 5, 6, 7, 100}; + expected = {0, 1, 0, 1, 0, 1, 1, 1, 1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); } delete filter_tree_root; From a7719d7b30f901041dd5585fbe8bbfae4fa1c0da Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 28 Mar 2023 10:00:47 +0530 Subject: [PATCH 10/64] Add `filter_result_iterator_t::contains_atleast_one`. --- include/filter.h | 3 +++ src/filter.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/include/filter.h b/include/filter.h index 9bb78deb..c8d4e3b0 100644 --- a/include/filter.h +++ b/include/filter.h @@ -2,6 +2,7 @@ #include #include +#include "posting_list.h" #include "index.h" class filter_result_iterator_t { @@ -92,4 +93,6 @@ public: /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. 
void skip_to(uint32_t id); + + bool contains_atleast_one(const void* obj); }; diff --git a/src/filter.cpp b/src/filter.cpp index 44393d8d..304799b3 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -502,3 +502,45 @@ Option filter_result_iterator_t::init_status() { return status; } + +bool filter_result_iterator_t::contains_atleast_one(const void *obj) { + if(IS_COMPACT_POSTING(obj)) { + compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); + + size_t i = 0; + while(i < list->length && valid()) { + size_t num_existing_offsets = list->id_offsets[i]; + size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; + + if (existing_id == seq_id) { + return true; + } + + // advance smallest value + if (existing_id < seq_id) { + i += num_existing_offsets + 2; + } else { + skip_to(existing_id); + } + } + } else { + auto list = (posting_list_t*)(obj); + posting_list_t::iterator_t it = list->new_iterator(); + + while(it.valid() && valid()) { + uint32_t id = it.id(); + + if(id == seq_id) { + return true; + } + + if(id < seq_id) { + it.skip_to(seq_id); + } else { + skip_to(id); + } + } + } + + return false; +} From 4a1de71e81b74cb60596b45f693c60f3b25c6c88 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 28 Mar 2023 18:21:58 +0530 Subject: [PATCH 11/64] Add tests for `filter_result_iterator_t::contains_atleast_one`. 
--- test/filter_test.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/test/filter_test.cpp b/test/filter_test.cpp index e607e03b..0c3687db 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "collection.h" class FilterTest : public ::testing::Test { @@ -276,5 +277,59 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); } + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_compact_plist_contains_atleast_one_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root); + ASSERT_TRUE(iter_compact_plist_contains_atleast_one_test1.init_status().ok()); + + std::vector ids = {1, 3, 5}; + std::vector offset_index = {0, 3, 6}; + std::vector offsets = {0, 3, 4, 0, 3, 4, 0, 3, 4}; + + compact_posting_list_t* c_list1 = compact_posting_list_t::create(3, &ids[0], &offset_index[0], 9, &offsets[0]); + ASSERT_FALSE(iter_compact_plist_contains_atleast_one_test1.contains_atleast_one(SET_COMPACT_POSTING(c_list1))); + free(c_list1); + + auto iter_compact_plist_contains_atleast_one_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root); + ASSERT_TRUE(iter_compact_plist_contains_atleast_one_test2.init_status().ok()); + + ids = {1, 3, 4}; + offset_index = {0, 3, 6}; + offsets = {0, 3, 4, 0, 3, 4, 0, 3, 4}; + + compact_posting_list_t* c_list2 = compact_posting_list_t::create(3, &ids[0], &offset_index[0], 9, &offsets[0]); + ASSERT_TRUE(iter_compact_plist_contains_atleast_one_test2.contains_atleast_one(SET_COMPACT_POSTING(c_list2))); + free(c_list2); + + auto iter_plist_contains_atleast_one_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + 
filter_tree_root); + ASSERT_TRUE(iter_plist_contains_atleast_one_test1.init_status().ok()); + + posting_list_t p_list1(2); + ids = {1, 3, 5}; + for (const auto &i: ids) { + p_list1.upsert(i, {1, 2, 3}); + } + + ASSERT_FALSE(iter_plist_contains_atleast_one_test1.contains_atleast_one(&p_list1)); + + auto iter_plist_contains_atleast_one_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root); + ASSERT_TRUE(iter_plist_contains_atleast_one_test2.init_status().ok()); + + posting_list_t p_list2(2); + ids = {1, 3, 4}; + for (const auto &i: ids) { + p_list1.upsert(i, {1, 2, 3}); + } + + ASSERT_TRUE(iter_plist_contains_atleast_one_test2.contains_atleast_one(&p_list1)); + delete filter_tree_root; } \ No newline at end of file From b1aaffe7d00e77e83bcb37dd2609ff3374257d08 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 29 Mar 2023 13:35:39 +0530 Subject: [PATCH 12/64] Add `filter_result_iterator_t::reset`. --- include/filter.h | 4 ++++ src/filter.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ test/filter_test.cpp | 26 ++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/include/filter.h b/include/filter.h index c8d4e3b0..41ccfac8 100644 --- a/include/filter.h +++ b/include/filter.h @@ -94,5 +94,9 @@ public: /// this operation. void skip_to(uint32_t id); + /// Returns true if at least one id from the posting list object matches the filter. bool contains_atleast_one(const void* obj); + + /// Returns to the initial state of the iterator. + void reset(); }; diff --git a/src/filter.cpp b/src/filter.cpp index 304799b3..d40a5564 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -544,3 +544,45 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { return false; } + +void filter_result_iterator_t::reset() { + if (filter_node->isOperator) { + // Reset the subtrees then apply operators to arrive at the first valid doc. 
+ left_it->reset(); + right_it->reset(); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter || a_filter.field_name == "id") { + result_index = 0; + is_valid = filter_result.count > 0; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + posting_list_iterators.clear(); + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + expanded_plists.clear(); + + init(); + return; + } +} diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 0c3687db..af7145f2 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -331,5 +331,31 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_plist_contains_atleast_one_test2.contains_atleast_one(&p_list1)); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= [gold, silver]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_reset_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_reset_test.init_status().ok()); + + expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_EQ(i, iter_reset_test.seq_id); + iter_reset_test.next(); + } + ASSERT_FALSE(iter_reset_test.valid()); + + iter_reset_test.reset(); + + for (auto const& i : expected) { + ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_EQ(i, iter_reset_test.seq_id); + iter_reset_test.next(); + } + ASSERT_FALSE(iter_reset_test.valid()); + delete filter_tree_root; } \ No newline at end of file From 7471e45678a8fec1a4a69e4ffcc4ae9f8b1988e9 Mon Sep 17 00:00:00 2001 From: Harpreet 
Sangar Date: Thu, 30 Mar 2023 09:52:10 +0530 Subject: [PATCH 13/64] Handle null filter tree. --- include/filter.h | 5 +++++ src/filter.cpp | 10 +++++++++- test/filter_test.cpp | 6 ++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/filter.h b/include/filter.h index 41ccfac8..babee59d 100644 --- a/include/filter.h +++ b/include/filter.h @@ -53,6 +53,11 @@ public: collection_name(collection_name), index(index), filter_node(filter_node) { + if (filter_node == nullptr) { + is_valid = false; + return; + } + // Generate the iterator tree and then initialize each node. if (filter_node->isOperator) { left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); diff --git a/src/filter.cpp b/src/filter.cpp index d40a5564..27c21198 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -207,6 +207,10 @@ void filter_result_iterator_t::next() { } void filter_result_iterator_t::init() { + if (filter_node == nullptr) { + return; + } + if (filter_node->isOperator) { if (filter_node->filter_operator == AND) { and_filter_iterators(); @@ -494,7 +498,7 @@ int filter_result_iterator_t::valid(uint32_t id) { } Option filter_result_iterator_t::init_status() { - if (filter_node->isOperator) { + if (filter_node != nullptr && filter_node->isOperator) { auto left_status = left_it->init_status(); return !left_status.ok() ? left_status : right_it->init_status(); @@ -546,6 +550,10 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { } void filter_result_iterator_t::reset() { + if (filter_node == nullptr) { + return; + } + if (filter_node->isOperator) { // Reset the subtrees then apply operators to arrive at the first valid doc. 
left_it->reset(); diff --git a/test/filter_test.cpp b/test/filter_test.cpp index af7145f2..6a952c57 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -61,6 +61,12 @@ TEST_F(FilterTest, FilterTreeIterator) { const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_"; filter_node_t* filter_tree_root = nullptr; + + auto iter_null_filter_tree_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + + ASSERT_TRUE(iter_null_filter_tree_test.init_status().ok()); + ASSERT_FALSE(iter_null_filter_tree_test.valid()); + Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, filter_tree_root); ASSERT_TRUE(filter_op.ok()); From a55fe2259bab1873cb19e820e94d9e01450459e3 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 3 Apr 2023 11:35:27 +0530 Subject: [PATCH 14/64] Refactor `filter_result_iterator_t`. --- include/filter.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/filter.h b/include/filter.h index babee59d..901e6e4f 100644 --- a/include/filter.h +++ b/include/filter.h @@ -5,11 +5,13 @@ #include "posting_list.h" #include "index.h" +class Index; + class filter_result_iterator_t { private: - std::string collection_name; - const Index* index; - filter_node_t* filter_node; + const std::string collection_name; + Index const* const index = nullptr; + filter_node_t const* const filter_node = nullptr; filter_result_iterator_t* left_it = nullptr; filter_result_iterator_t* right_it = nullptr; @@ -49,7 +51,7 @@ public: Option status = Option(true); explicit filter_result_iterator_t(const std::string& collection_name, - const Index* index, filter_node_t* filter_node) : + Index const* const index, filter_node_t const* const filter_node) : collection_name(collection_name), index(index), filter_node(filter_node) { From 9be1d8e841d4d35108cc6d4421951213b3ddbf7e Mon Sep 17 00:00:00 2001 From: 
Harpreet Sangar Date: Tue, 4 Apr 2023 11:53:47 +0530 Subject: [PATCH 15/64] Add move assignment operator in `filter_result_iterator_t`. --- include/filter.h | 45 ++++++++++++++++++++++++++++++++++++++++---- test/filter_test.cpp | 13 +++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/include/filter.h b/include/filter.h index 901e6e4f..5deca7b5 100644 --- a/include/filter.h +++ b/include/filter.h @@ -9,9 +9,9 @@ class Index; class filter_result_iterator_t { private: - const std::string collection_name; - Index const* const index = nullptr; - filter_node_t const* const filter_node = nullptr; + std::string collection_name; + const Index* index = nullptr; + const filter_node_t* filter_node = nullptr; filter_result_iterator_t* left_it = nullptr; filter_result_iterator_t* right_it = nullptr; @@ -50,7 +50,7 @@ public: std::map reference; Option status = Option(true); - explicit filter_result_iterator_t(const std::string& collection_name, + explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node) : collection_name(collection_name), index(index), @@ -79,6 +79,43 @@ public: delete right_it; } + filter_result_iterator_t& operator=(filter_result_iterator_t&& obj) noexcept { + if (&obj == this) + return *this; + + // In case the filter was on string field. 
+ for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + + collection_name = obj.collection_name; + index = obj.index; + filter_node = obj.filter_node; + left_it = obj.left_it; + right_it = obj.right_it; + + obj.left_it = nullptr; + obj.right_it = nullptr; + + result_index = obj.result_index; + + filter_result = std::move(obj.filter_result); + + posting_list_iterators = std::move(obj.posting_list_iterators); + expanded_plists = std::move(obj.expanded_plists); + + is_valid = obj.is_valid; + + seq_id = obj.seq_id; + reference = std::move(obj.reference); + status = std::move(obj.status); + + return *this; + } + /// Returns the status of the initialization of iterator tree. Option init_status(); diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 6a952c57..91dd81c8 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -363,5 +363,18 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_reset_test.valid()); + auto iter_move_assignment_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + + iter_reset_test.reset(); + iter_move_assignment_test = std::move(iter_reset_test); + + expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_move_assignment_test.valid()); + ASSERT_EQ(i, iter_move_assignment_test.seq_id); + iter_move_assignment_test.next(); + } + ASSERT_FALSE(iter_move_assignment_test.valid()); + delete filter_tree_root; } \ No newline at end of file From 3a3814aba500d51594ebbd5cbd19a83d54751960 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 5 Apr 2023 10:10:13 +0530 Subject: [PATCH 16/64] Add `to_filter_id_array` and `and_scalar` methods. 
--- include/filter.h | 16 +++++++++++++--- src/filter.cpp | 45 ++++++++++++++++++++++++++++++++++++++++++++ test/filter_test.cpp | 39 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/include/filter.h b/include/filter.h index 5deca7b5..4103f49f 100644 --- a/include/filter.h +++ b/include/filter.h @@ -2,10 +2,14 @@ #include #include +#include #include "posting_list.h" #include "index.h" class Index; +struct filter_node_t; +struct reference_filter_result_t; +struct filter_result_t; class filter_result_iterator_t { private: @@ -52,9 +56,9 @@ public: explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node) : - collection_name(collection_name), - index(index), - filter_node(filter_node) { + collection_name(collection_name), + index(index), + filter_node(filter_node) { if (filter_node == nullptr) { is_valid = false; return; @@ -143,4 +147,10 @@ public: /// Returns to the initial state of the iterator. void reset(); + + /// Iterates and collects all the filter ids into filter_array. 
+ /// \return size of the filter array + uint32_t to_filter_id_array(uint32_t*& filter_array); + + uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); }; diff --git a/src/filter.cpp b/src/filter.cpp index 27c21198..77ef5d5d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -594,3 +594,48 @@ void filter_result_iterator_t::reset() { return; } } + +uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + do { + filter_ids.push_back(seq_id); + next(); + } while (valid()); + + filter_array = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), filter_array); + + return filter_ids.size(); +} + +uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + for (uint32_t i = 0; i < lenA; i++) { + auto result = valid(A[i]); + + if (result == -1) { + break; + } + + if (result == 1) { + filter_ids.push_back(A[i]); + } + } + + if (filter_ids.empty()) { + return 0; + } + + results = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), results); + + return filter_ids.size(); +} diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 91dd81c8..b4f836f1 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -377,4 +377,43 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_move_assignment_test.valid()); delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_to_array_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_to_array_test.init_status().ok()); + + uint32_t* filter_ids = nullptr; + uint32_t filter_ids_length; + + filter_ids_length = 
iter_to_array_test.to_filter_id_array(filter_ids); + ASSERT_EQ(3, filter_ids_length); + + expected = {0, 2, 4}; + for (uint32_t i = 0; i < filter_ids_length; i++) { + ASSERT_EQ(expected[i], filter_ids[i]); + } + ASSERT_FALSE(iter_to_array_test.valid()); + + delete filter_ids; + + auto iter_and_scalar_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_and_scalar_test.init_status().ok()); + + uint32_t a_ids[6] = {0, 1, 3, 4, 5, 6}; + uint32_t* and_result = nullptr; + uint32_t and_result_length; + and_result_length = iter_and_scalar_test.and_scalar(a_ids, 6, and_result); + ASSERT_EQ(2, and_result_length); + + expected = {0, 4}; + for (uint32_t i = 0; i < and_result_length; i++) { + ASSERT_EQ(expected[i], and_result[i]); + } + ASSERT_FALSE(iter_and_test.valid()); + + delete and_result; + delete filter_tree_root; } \ No newline at end of file From 191013dc5d0e53459ee2985a0110d874ccc88ab8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 10 Apr 2023 17:43:39 +0530 Subject: [PATCH 17/64] Refactor filtering logic to overcome circular referencing. Handle exact string filtering in `filter_result_iterator`. 
--- include/art.h | 19 +- include/field.h | 205 +----- include/filter.h | 218 ++---- include/filter_result_iterator.h | 174 +++++ include/index.h | 56 +- include/posting_list.h | 11 + include/topster.h | 1 + include/validator.h | 15 - src/art.cpp | 190 +++++ src/collection.cpp | 1 + src/field.cpp | 669 ----------------- src/filter.cpp | 1038 +++++++++++++-------------- src/filter_result_iterator.cpp | 901 +++++++++++++++++++++++ src/index.cpp | 284 ++++---- src/or_iterator.cpp | 6 +- src/posting_list.cpp | 140 ++++ src/validator.cpp | 1 + test/collection_all_fields_test.cpp | 5 +- test/collection_test.cpp | 1 + test/filter_test.cpp | 22 +- 20 files changed, 2195 insertions(+), 1762 deletions(-) create mode 100644 include/filter_result_iterator.h create mode 100644 src/filter_result_iterator.cpp diff --git a/include/art.h b/include/art.h index a9715fac..11f57a68 100644 --- a/include/art.h +++ b/include/art.h @@ -1,5 +1,4 @@ -#ifndef ART_H -#define ART_H +#pragma once #include #include @@ -7,6 +6,7 @@ #include #include "array.h" #include "sorted_array.h" +#include "filter_result_iterator.h" #define IGNORE_PRINTF 1 @@ -111,7 +111,7 @@ struct token_leaf { uint32_t num_typos; token_leaf(art_leaf* leaf, uint32_t root_len, uint32_t num_typos, bool is_prefix) : - leaf(leaf), root_len(root_len), num_typos(num_typos), is_prefix(is_prefix) { + leaf(leaf), root_len(root_len), num_typos(num_typos), is_prefix(is_prefix) { } }; @@ -157,11 +157,6 @@ enum NUM_COMPARATOR { RANGE_INCLUSIVE }; -enum FILTER_OPERATOR { - AND, - OR -}; - /** * Initializes an ART tree * @return 0 on success. 
@@ -281,6 +276,12 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const uint32_t *filter_ids, const size_t filter_ids_length, std::vector &results, std::set& exclude_leaves); +int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, + const size_t max_words, const token_ordering token_order, + const bool prefix, bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + std::vector &results, std::set& exclude_leaves); + void encode_int32(int32_t n, unsigned char *chars); void encode_int64(int64_t n, unsigned char *chars); @@ -295,6 +296,4 @@ int art_float_search(art_tree *t, float value, NUM_COMPARATOR comparator, std::v #ifdef __cplusplus } -#endif - #endif \ No newline at end of file diff --git a/include/field.h b/include/field.h index 4fa6cc6e..03c7015a 100644 --- a/include/field.h +++ b/include/field.h @@ -2,13 +2,13 @@ #include #include -#include "art.h" #include "option.h" #include "string_utils.h" #include "logger.h" #include "store.h" #include #include +#include #include "json.hpp" #include "text_embedder_manager.h" @@ -505,200 +505,19 @@ struct field { static void compact_nested_fields(tsl::htrie_map& nested_fields); }; -struct filter_node_t; - -struct filter { - std::string field_name; - std::vector values; - std::vector comparators; - // Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the - // case of int and float fields. During filtering, all the results of matching the field against the values are - // aggregated and then this flag is checked if negation on the aggregated result is required. 
- bool apply_not_equals = false; - - // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` - std::string referenced_collection_name = ""; - - static const std::string RANGE_OPERATOR() { - return ".."; - } - - static Option validate_numerical_filter_value(field _field, const std::string& raw_value) { - if(_field.is_int32() && !StringUtils::is_int32_t(raw_value)) { - return Option(400, "Error with filter field `" + _field.name + "`: Not an int32."); - } - - else if(_field.is_int64() && !StringUtils::is_int64_t(raw_value)) { - return Option(400, "Error with filter field `" + _field.name + "`: Not an int64."); - } - - else if(_field.is_float() && !StringUtils::is_float(raw_value)) { - return Option(400, "Error with filter field `" + _field.name + "`: Not a float."); - } - - return Option(true); - } - - static Option extract_num_comparator(std::string & comp_and_value) { - auto num_comparator = EQUALS; - - if(StringUtils::is_integer(comp_and_value) || StringUtils::is_float(comp_and_value)) { - num_comparator = EQUALS; - } - - // the ordering is important - we have to compare 2-letter operators first - else if(comp_and_value.compare(0, 2, "<=") == 0) { - num_comparator = LESS_THAN_EQUALS; - } - - else if(comp_and_value.compare(0, 2, ">=") == 0) { - num_comparator = GREATER_THAN_EQUALS; - } - - else if(comp_and_value.compare(0, 2, "!=") == 0) { - num_comparator = NOT_EQUALS; - } - - else if(comp_and_value.compare(0, 1, "<") == 0) { - num_comparator = LESS_THAN; - } - - else if(comp_and_value.compare(0, 1, ">") == 0) { - num_comparator = GREATER_THAN; - } - - else if(comp_and_value.find("..") != std::string::npos) { - num_comparator = RANGE_INCLUSIVE; - } - - else { - return Option(400, "Numerical field has an invalid comparator."); - } - - if(num_comparator == LESS_THAN || num_comparator == GREATER_THAN) { - comp_and_value = comp_and_value.substr(1); - } else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS || 
num_comparator == NOT_EQUALS) { - comp_and_value = comp_and_value.substr(2); - } - - comp_and_value = StringUtils::trim(comp_and_value); - - return Option(num_comparator); - } - - static Option parse_geopoint_filter_value(std::string& raw_value, - const std::string& format_err_msg, - std::string& processed_filter_val, - NUM_COMPARATOR& num_comparator); - - static Option parse_filter_query(const std::string& filter_query, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix, - filter_node_t*& root); +enum index_operation_t { + CREATE, + UPSERT, + UPDATE, + EMPLACE, + DELETE }; -struct filter_node_t { - filter filter_exp; - FILTER_OPERATOR filter_operator; - bool isOperator; - filter_node_t* left = nullptr; - filter_node_t* right = nullptr; - - filter_node_t(filter filter_exp) - : filter_exp(std::move(filter_exp)), - isOperator(false), - left(nullptr), - right(nullptr) {} - - filter_node_t(FILTER_OPERATOR filter_operator, - filter_node_t* left, - filter_node_t* right) - : filter_operator(filter_operator), - isOperator(true), - left(left), - right(right) {} - - ~filter_node_t() { - delete left; - delete right; - } -}; - -struct reference_filter_result_t { - uint32_t count = 0; - uint32_t* docs = nullptr; - - reference_filter_result_t& operator=(const reference_filter_result_t& obj) noexcept { - if (&obj == this) - return *this; - - count = obj.count; - docs = new uint32_t[count]; - memcpy(docs, obj.docs, count * sizeof(uint32_t)); - - return *this; - } - - ~reference_filter_result_t() { - delete[] docs; - } -}; - -struct filter_result_t { - uint32_t count = 0; - uint32_t* docs = nullptr; - // Collection name -> Reference filter result - std::map reference_filter_results; - - filter_result_t() = default; - - filter_result_t(uint32_t count, uint32_t* docs) : count(count), docs(docs) {} - - filter_result_t& operator=(const filter_result_t& obj) noexcept { - if (&obj == this) - return *this; - - count = obj.count; - docs = new 
uint32_t[count]; - memcpy(docs, obj.docs, count * sizeof(uint32_t)); - - // Copy every collection's references. - for (const auto &item: obj.reference_filter_results) { - reference_filter_results[item.first] = new reference_filter_result_t[count]; - - for (uint32_t i = 0; i < count; i++) { - reference_filter_results[item.first][i] = item.second[i]; - } - } - - return *this; - } - - filter_result_t& operator=(filter_result_t&& obj) noexcept { - if (&obj == this) - return *this; - - count = obj.count; - docs = obj.docs; - reference_filter_results = std::map(obj.reference_filter_results); - - obj.docs = nullptr; - obj.reference_filter_results.clear(); - - return *this; - } - - ~filter_result_t() { - delete[] docs; - for (const auto &item: reference_filter_results) { - delete[] item.second; - } - } - - static void and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); - - static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); +enum class DIRTY_VALUES { + REJECT = 1, + DROP = 2, + COERCE_OR_REJECT = 3, + COERCE_OR_DROP = 4, }; namespace sort_field_const { diff --git a/include/filter.h b/include/filter.h index 4103f49f..1961b0ac 100644 --- a/include/filter.h +++ b/include/filter.h @@ -2,155 +2,73 @@ #include #include -#include -#include "posting_list.h" -#include "index.h" +#include +#include +#include "store.h" -class Index; -struct filter_node_t; -struct reference_filter_result_t; -struct filter_result_t; - -class filter_result_iterator_t { -private: - std::string collection_name; - const Index* index = nullptr; - const filter_node_t* filter_node = nullptr; - filter_result_iterator_t* left_it = nullptr; - filter_result_iterator_t* right_it = nullptr; - - /// Used in case of id and reference filter. - uint32_t result_index = 0; - - /// Stores the result of the filters that cannot be iterated. 
- filter_result_t filter_result; - - /// Initialized in case of filter on string field. - /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator - /// for each token. - /// - /// Multiple filter values: Multiple tokens: posting list iterator - std::vector> posting_list_iterators; - std::vector expanded_plists; - - /// Set to false when this iterator or it's subtree becomes invalid. - bool is_valid = true; - - /// Initializes the state of iterator node after it's creation. - void init(); - - /// Performs AND on the subtrees of operator. - void and_filter_iterators(); - - /// Performs OR on the subtrees of operator. - void or_filter_iterators(); - - /// Finds the next match for a filter on string field. - void doc_matching_string_filter(); - -public: - uint32_t seq_id = 0; - // Collection name -> references - std::map reference; - Option status = Option(true); - - explicit filter_result_iterator_t(const std::string collection_name, - Index const* const index, filter_node_t const* const filter_node) : - collection_name(collection_name), - index(index), - filter_node(filter_node) { - if (filter_node == nullptr) { - is_valid = false; - return; - } - - // Generate the iterator tree and then initialize each node. - if (filter_node->isOperator) { - left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); - right_it = new filter_result_iterator_t(collection_name, index, filter_node->right); - } - - init(); - } - - ~filter_result_iterator_t() { - // In case the filter was on string field. - for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - - delete left_it; - delete right_it; - } - - filter_result_iterator_t& operator=(filter_result_iterator_t&& obj) noexcept { - if (&obj == this) - return *this; - - // In case the filter was on string field. 
- for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - - delete left_it; - delete right_it; - - collection_name = obj.collection_name; - index = obj.index; - filter_node = obj.filter_node; - left_it = obj.left_it; - right_it = obj.right_it; - - obj.left_it = nullptr; - obj.right_it = nullptr; - - result_index = obj.result_index; - - filter_result = std::move(obj.filter_result); - - posting_list_iterators = std::move(obj.posting_list_iterators); - expanded_plists = std::move(obj.expanded_plists); - - is_valid = obj.is_valid; - - seq_id = obj.seq_id; - reference = std::move(obj.reference); - status = std::move(obj.status); - - return *this; - } - - /// Returns the status of the initialization of iterator tree. - Option init_status(); - - /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). - [[nodiscard]] bool valid(); - - /// Returns a tri-state: - /// 0: id is not valid - /// 1: id is valid - /// -1: end of iterator - /// - /// Handles moving the individual iterators internally. - [[nodiscard]] int valid(uint32_t id); - - /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this - /// operation. - void next(); - - /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during - /// this operation. - void skip_to(uint32_t id); - - /// Returns true if at least one id from the posting list object matches the filter. - bool contains_atleast_one(const void* obj); - - /// Returns to the initial state of the iterator. - void reset(); - - /// Iterates and collects all the filter ids into filter_array. 
- /// \return size of the filter array - uint32_t to_filter_id_array(uint32_t*& filter_array); - - uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); +enum FILTER_OPERATOR { + AND, + OR +}; + +struct filter_node_t; +struct field; + +struct filter { + std::string field_name; + std::vector values; + std::vector comparators; + // Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the + // case of int and float fields. During filtering, all the results of matching the field against the values are + // aggregated and then this flag is checked if negation on the aggregated result is required. + bool apply_not_equals = false; + + // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` + std::string referenced_collection_name = ""; + + static const std::string RANGE_OPERATOR() { + return ".."; + } + + static Option validate_numerical_filter_value(field _field, const std::string& raw_value); + + static Option extract_num_comparator(std::string & comp_and_value); + + static Option parse_geopoint_filter_value(std::string& raw_value, + const std::string& format_err_msg, + std::string& processed_filter_val, + NUM_COMPARATOR& num_comparator); + + static Option parse_filter_query(const std::string& filter_query, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix, + filter_node_t*& root); +}; + +struct filter_node_t { + filter filter_exp; + FILTER_OPERATOR filter_operator; + bool isOperator; + filter_node_t* left = nullptr; + filter_node_t* right = nullptr; + + filter_node_t(filter filter_exp) + : filter_exp(std::move(filter_exp)), + isOperator(false), + left(nullptr), + right(nullptr) {} + + filter_node_t(FILTER_OPERATOR filter_operator, + filter_node_t* left, + filter_node_t* right) + : filter_operator(filter_operator), + isOperator(true), + left(left), + right(right) {} + + ~filter_node_t() { + delete left; + delete right; + 
} }; diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h new file mode 100644 index 00000000..b67d67ca --- /dev/null +++ b/include/filter_result_iterator.h @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include +#include "option.h" +#include "posting_list.h" + +class Index; +struct filter_node_t; + +struct reference_filter_result_t { + uint32_t count = 0; + uint32_t* docs = nullptr; + + reference_filter_result_t& operator=(const reference_filter_result_t& obj) noexcept { + if (&obj == this) + return *this; + + count = obj.count; + docs = new uint32_t[count]; + memcpy(docs, obj.docs, count * sizeof(uint32_t)); + + return *this; + } + + ~reference_filter_result_t() { + delete[] docs; + } +}; + +struct filter_result_t { + uint32_t count = 0; + uint32_t* docs = nullptr; + // Collection name -> Reference filter result + std::map reference_filter_results; + + filter_result_t() = default; + + filter_result_t(uint32_t count, uint32_t* docs) : count(count), docs(docs) {} + + filter_result_t& operator=(const filter_result_t& obj) noexcept { + if (&obj == this) + return *this; + + count = obj.count; + docs = new uint32_t[count]; + memcpy(docs, obj.docs, count * sizeof(uint32_t)); + + // Copy every collection's references. 
+ for (const auto &item: obj.reference_filter_results) { + reference_filter_results[item.first] = new reference_filter_result_t[count]; + + for (uint32_t i = 0; i < count; i++) { + reference_filter_results[item.first][i] = item.second[i]; + } + } + + return *this; + } + + filter_result_t& operator=(filter_result_t&& obj) noexcept { + if (&obj == this) + return *this; + + count = obj.count; + docs = obj.docs; + reference_filter_results = std::map(obj.reference_filter_results); + + obj.docs = nullptr; + obj.reference_filter_results.clear(); + + return *this; + } + + ~filter_result_t() { + delete[] docs; + for (const auto &item: reference_filter_results) { + delete[] item.second; + } + } + + static void and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); + + static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); +}; + + +class filter_result_iterator_t { +private: + std::string collection_name; + const Index* index = nullptr; + const filter_node_t* filter_node = nullptr; + filter_result_iterator_t* left_it = nullptr; + filter_result_iterator_t* right_it = nullptr; + + /// Used in case of id and reference filter. + uint32_t result_index = 0; + + /// Stores the result of the filters that cannot be iterated. + filter_result_t filter_result; + + /// Initialized in case of filter on string field. + /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + /// for each token. + /// + /// Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_list_iterators; + std::vector expanded_plists; + + /// Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initializes the state of iterator node after it's creation. + void init(); + + /// Performs AND on the subtrees of operator. 
+ void and_filter_iterators(); + + /// Performs OR on the subtrees of operator. + void or_filter_iterators(); + + /// Finds the next match for a filter on string field. + void doc_matching_string_filter(bool field_is_array); + +public: + uint32_t seq_id = 0; + // Collection name -> references + std::map reference; + Option status = Option(true); + + explicit filter_result_iterator_t(const std::string collection_name, + Index const* const index, filter_node_t const* const filter_node); + + ~filter_result_iterator_t(); + + filter_result_iterator_t& operator=(filter_result_iterator_t&& obj) noexcept; + + /// Returns the status of the initialization of iterator tree. + Option init_status(); + + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). + [[nodiscard]] bool valid(); + + /// Returns a tri-state: + /// 0: id is not valid + /// 1: id is valid + /// -1: end of iterator + /// + /// Handles moving the individual iterators internally. + [[nodiscard]] int valid(uint32_t id); + + /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this + /// operation. + void next(); + + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during + /// this operation. + void skip_to(uint32_t id); + + /// Returns true if at least one id from the posting list object matches the filter. + bool contains_atleast_one(const void* obj); + + /// Returns to the initial state of the iterator. + void reset(); + + /// Iterates and collects all the filter ids into filter_array. + /// \return size of the filter array + uint32_t to_filter_id_array(uint32_t*& filter_array); + + /// Performs AND with the contents of A and allocates a new array of results. 
+ /// \return size of the results array + uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); +}; diff --git a/include/index.h b/include/index.h index f2430e97..1795ecd6 100644 --- a/include/index.h +++ b/include/index.h @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include "string_utils.h" @@ -30,6 +29,7 @@ #include "override.h" #include "vector_query_ops.h" #include "hnswlib/hnswlib.h" +#include "filter.h" static constexpr size_t ARRAY_FACET_DIM = 4; using facet_map_t = spp::sparse_hash_map; @@ -233,19 +233,15 @@ struct index_record { }; class VectorFilterFunctor: public hnswlib::BaseFilterFunctor { - const uint32_t* filter_ids = nullptr; - const uint32_t filter_ids_length = 0; + filter_result_iterator_t* const filter_result_iterator; public: - explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) : - filter_ids(filter_ids), filter_ids_length(filter_ids_length) {} + explicit VectorFilterFunctor(filter_result_iterator_t* const filter_result_iterator) : + filter_result_iterator(filter_result_iterator) {} - bool operator()(hnswlib::labeltype id) override { - if(filter_ids_length == 0) { - return true; - } - - return std::binary_search(filter_ids, filter_ids + filter_ids_length, id); + bool operator()(unsigned int id) { + filter_result_iterator->reset(); + return filter_result_iterator->valid(id) == 1; } }; @@ -412,7 +408,7 @@ private: void search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -565,6 +561,9 @@ public: // in the query that have the least individual hits one by one until enough results are found. 
static const int DROP_TOKENS_THRESHOLD = 1; + // "_all_" is a special field that maps to all the ids in the index. + static constexpr const char* SEQ_IDS_FILTER = "_all_: 1"; + Index() = delete; Index(const std::string& name, @@ -743,8 +742,9 @@ public: const std::vector& group_by_fields, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, - uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, - uint32_t filter_ids_length, const size_t concurrency, + uint32_t*& all_result_ids, size_t& all_result_ids_len, + filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices) const; @@ -784,7 +784,7 @@ public: std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - const uint32_t *filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -815,7 +815,7 @@ public: spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -840,7 +840,7 @@ public: Topster* curated_topster, const std::map>& included_ids_map, bool is_wildcard_query, - uint32_t*& filter_ids, uint32_t& filter_ids_length) const; + uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const; void fuzzy_search_fields(const std::vector& the_fields, const std::vector& 
query_tokens, @@ -848,7 +848,7 @@ public: const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -877,7 +877,7 @@ public: const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, @@ -899,7 +899,7 @@ public: const std::vector& group_by_fields, bool prioritize_exact_match, const bool search_all_candidates, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, @@ -948,14 +948,14 @@ public: int64_t* scores, int64_t& match_score_index) const; - void - process_curated_ids(const std::vector>& included_ids, - const std::vector& excluded_ids, const std::vector& group_by_fields, - const size_t group_limit, const bool filter_curated_hits, const uint32_t* filter_ids, - uint32_t filter_ids_length, std::set& curated_ids, - std::map>& included_ids_map, - std::vector& included_ids_vec, - std::unordered_set& excluded_group_ids) const; + void process_curated_ids(const std::vector>& included_ids, + const std::vector& excluded_ids, const std::vector& group_by_fields, + const size_t group_limit, const bool filter_curated_hits, + filter_result_iterator_t& filter_result_iterator, + std::set& curated_ids, + std::map>& included_ids_map, + std::vector& included_ids_vec, + std::unordered_set& excluded_group_ids) const; friend class filter_result_iterator_t; }; diff --git a/include/posting_list.h b/include/posting_list.h index 
16f42ea7..11ed9c91 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -8,6 +8,7 @@ #include "thread_local_vars.h" typedef uint32_t last_id_t; +class filter_result_iterator_t; struct result_iter_state_t { const uint32_t* excluded_result_ids = nullptr; @@ -19,12 +20,19 @@ struct result_iter_state_t { size_t excluded_result_ids_index = 0; size_t filter_ids_index = 0; + filter_result_iterator_t* fit = nullptr; + result_iter_state_t() = default; result_iter_state_t(const uint32_t* excluded_result_ids, size_t excluded_result_ids_size, const uint32_t* filter_ids, const size_t filter_ids_length) : excluded_result_ids(excluded_result_ids), excluded_result_ids_size(excluded_result_ids_size), filter_ids(filter_ids), filter_ids_length(filter_ids_length) {} + + result_iter_state_t(const uint32_t* excluded_result_ids, size_t excluded_result_ids_size, + filter_result_iterator_t* fit) : excluded_result_ids(excluded_result_ids), + excluded_result_ids_size(excluded_result_ids_size), + fit(fit){} }; /* @@ -186,6 +194,9 @@ public: const uint32_t* ids, const uint32_t num_ids, uint32_t*& exact_ids, size_t& num_exact_ids); + static bool has_exact_match(std::vector& posting_list_iterators, + const bool field_is_array); + static void get_phrase_matches(std::vector& its, bool field_is_array, const uint32_t* ids, const uint32_t num_ids, uint32_t*& phrase_ids, size_t& num_phrase_ids); diff --git a/include/topster.h b/include/topster.h index e59ae74c..8e10abcb 100644 --- a/include/topster.h +++ b/include/topster.h @@ -6,6 +6,7 @@ #include #include #include +#include "filter_result_iterator.h" struct KV { int8_t match_score_index{}; diff --git a/include/validator.h b/include/validator.h index 3998f5dc..a8a4a8f9 100644 --- a/include/validator.h +++ b/include/validator.h @@ -6,21 +6,6 @@ #include "tsl/htrie_map.h" #include "field.h" -enum index_operation_t { - CREATE, - UPSERT, - UPDATE, - EMPLACE, - DELETE -}; - -enum class DIRTY_VALUES { - REJECT = 1, - DROP = 2, - 
COERCE_OR_REJECT = 3, - COERCE_OR_DROP = 4, -}; - class validator_t { public: diff --git a/src/art.cpp b/src/art.cpp index 40b028a3..48470551 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -22,6 +22,7 @@ #include "art.h" #include "logger.h" #include "array_utils.h" +#include "filter_result_iterator.h" /** * Macros to manipulate pointer tags @@ -972,6 +973,36 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, return prev_token_doc_ids; } +const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + size_t& prev_token_doc_ids_len) { + + art_leaf* prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + uint32_t* prev_token_doc_ids = nullptr; + + if(prev_token.empty() || !prev_leaf) { + prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(prev_token_doc_ids); + return prev_token_doc_ids; + } + + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + if(filter_result_iterator.valid()) { + prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), + prev_token_doc_ids); + } else { + prev_token_doc_ids_len = prev_leaf_ids.size(); + prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; + std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); + } + + return prev_token_doc_ids; +} + bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, std::set& exclude_leaves, const art_leaf* exact_leaf, @@ -1622,6 +1653,165 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, return 0; } +int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, + const size_t max_words, const token_ordering token_order, + const bool 
prefix, bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + std::vector &results, std::set& exclude_leaves) { + + std::vector nodes; + int irow[term_len + 1]; + int jrow[term_len + 1]; + for (int i = 0; i <= term_len; i++){ + irow[i] = jrow[i] = i; + } + + //auto begin = std::chrono::high_resolution_clock::now(); + + if(IS_LEAF(t->root)) { + art_leaf *l = (art_leaf *) LEAF_RAW(t->root); + art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } else { + if(t->root == nullptr) { + return 0; + } + + // send depth as -1 to indicate that this is a root node + art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } + + //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); + + //auto begin = std::chrono::high_resolution_clock::now(); + + size_t key_len = prefix ? 
term_len + 1 : term_len; + art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); + //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; + + // documents that contain the previous token and/or filter ids + size_t allowed_doc_ids_len = 0; + const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); + + for(auto node: nodes) { + art_topk_iter(node, token_order, max_words, exact_leaf, + last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + t, exclude_leaves, results); + } + + if(token_order == FREQUENCY) { + std::sort(results.begin(), results.end(), compare_art_leaf_frequency); + } else { + std::sort(results.begin(), results.end(), compare_art_leaf_score); + } + + if(exact_leaf && min_cost == 0) { + std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); + if(exclude_leaves.count(tok) == 0) { + results.insert(results.begin(), exact_leaf); + exclude_leaves.emplace(tok); + } + } + + if(results.size() > max_words) { + results.resize(max_words); + } + + /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + if(time_micro > 1000) { + LOG(INFO) << "Time taken for art_topk_iter: " << time_micro + << "us, size of nodes: " << nodes.size() + << ", filter_ids_length: " << filter_ids_length; + }*/ + +// TODO: Figure out this edge case. 
+// if(allowed_doc_ids != filter_ids) { +// delete [] allowed_doc_ids; +// } + + return 0; +} + +int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, + const size_t max_words, const token_ordering token_order, const bool prefix, + bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + std::vector &results, std::set& exclude_leaves) { + + std::vector nodes; + int irow[term_len + 1]; + int jrow[term_len + 1]; + for (int i = 0; i <= term_len; i++){ + irow[i] = jrow[i] = i; + } + + //auto begin = std::chrono::high_resolution_clock::now(); + + if(IS_LEAF(t->root)) { + art_leaf *l = (art_leaf *) LEAF_RAW(t->root); + art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } else { + if(t->root == nullptr) { + return 0; + } + + // send depth as -1 to indicate that this is a root node + art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } + + //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); + + //auto begin = std::chrono::high_resolution_clock::now(); + + size_t key_len = prefix ? 
term_len + 1 : term_len; + art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); + //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; + + // documents that contain the previous token and/or filter ids + size_t allowed_doc_ids_len = 0; + const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); + + for(auto node: nodes) { + art_topk_iter(node, token_order, max_words, exact_leaf, + last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + t, exclude_leaves, results); + } + + if(token_order == FREQUENCY) { + std::sort(results.begin(), results.end(), compare_art_leaf_frequency); + } else { + std::sort(results.begin(), results.end(), compare_art_leaf_score); + } + + if(exact_leaf && min_cost == 0) { + std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); + if(exclude_leaves.count(tok) == 0) { + results.insert(results.begin(), exact_leaf); + exclude_leaves.emplace(tok); + } + } + + if(results.size() > max_words) { + results.resize(max_words); + } + + /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + + if(time_micro > 1000) { + LOG(INFO) << "Time taken for art_topk_iter: " << time_micro + << "us, size of nodes: " << nodes.size() + << ", filter_ids_length: " << filter_ids_length; + }*/ + +// TODO: Figure out this edge case. 
+// if(allowed_doc_ids != filter_ids) { +// delete [] allowed_doc_ids; +// } + + return 0; +} + void encode_int32(int32_t n, unsigned char *chars) { unsigned char symbols[16] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 diff --git a/src/collection.cpp b/src/collection.cpp index 7d7d031b..dddbf66b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "validator.h" #include "topster.h" #include "logger.h" #include "thread_local_vars.h" diff --git a/src/field.cpp b/src/field.cpp index bcc8470a..69925b84 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -6,511 +6,6 @@ #include #include -Option filter::parse_geopoint_filter_value(std::string& raw_value, - const std::string& format_err_msg, - std::string& processed_filter_val, - NUM_COMPARATOR& num_comparator) { - - num_comparator = LESS_THAN_EQUALS; - - if(!(raw_value[0] == '(' && raw_value[raw_value.size() - 1] == ')')) { - return Option(400, format_err_msg); - } - - std::vector filter_values; - auto raw_val_without_paran = raw_value.substr(1, raw_value.size() - 2); - StringUtils::split(raw_val_without_paran, filter_values, ","); - - // we will end up with: "10.45 34.56 2 km" or "10.45 34.56 2mi" or a geo polygon - - if(filter_values.size() < 3) { - return Option(400, format_err_msg); - } - - // do validation: format should match either a point + radius or polygon - - size_t num_floats = 0; - for(const auto& fvalue: filter_values) { - if(StringUtils::is_float(fvalue)) { - num_floats++; - } - } - - bool is_polygon = (num_floats == filter_values.size()); - if(!is_polygon) { - // we have to ensure that this is a point + radius match - if(!StringUtils::is_float(filter_values[0]) || !StringUtils::is_float(filter_values[1])) { - return Option(400, format_err_msg); - } - - if(filter_values[0] == "nan" || filter_values[0] == "NaN" || - filter_values[1] == "nan" || filter_values[1] == "NaN") { - return Option(400, format_err_msg); - } - } - - if(is_polygon) { 
- processed_filter_val = raw_val_without_paran; - } else { - // point + radius - // filter_values[2] is distance, get the unit, validate it and split on that - if(filter_values[2].size() < 2) { - return Option(400, "Unit must be either `km` or `mi`."); - } - - std::string unit = filter_values[2].substr(filter_values[2].size()-2, 2); - - if(unit != "km" && unit != "mi") { - return Option(400, "Unit must be either `km` or `mi`."); - } - - std::vector dist_values; - StringUtils::split(filter_values[2], dist_values, unit); - - if(dist_values.size() != 1) { - return Option(400, format_err_msg); - } - - if(!StringUtils::is_float(dist_values[0])) { - return Option(400, format_err_msg); - } - - processed_filter_val = filter_values[0] + ", " + filter_values[1] + ", " + // co-ords - dist_values[0] + ", " + unit; // X km - } - - return Option(true); -} - -bool isOperator(const std::string& expression) { - return expression == "&&" || expression == "||"; -} - -// https://en.wikipedia.org/wiki/Shunting_yard_algorithm -Option toPostfix(std::queue& tokens, std::queue& postfix) { - std::stack operatorStack; - - while (!tokens.empty()) { - auto expression = tokens.front(); - tokens.pop(); - - if (isOperator(expression)) { - // We only have two operators &&, || having the same precedence and both being left associative. 
- while (!operatorStack.empty() && operatorStack.top() != "(") { - postfix.push(operatorStack.top()); - operatorStack.pop(); - } - - operatorStack.push(expression); - } else if (expression == "(") { - operatorStack.push(expression); - } else if (expression == ")") { - while (!operatorStack.empty() && operatorStack.top() != "(") { - postfix.push(operatorStack.top()); - operatorStack.pop(); - } - - if (operatorStack.empty() || operatorStack.top() != "(") { - return Option(400, "Could not parse the filter query: unbalanced parentheses."); - } - operatorStack.pop(); - } else { - postfix.push(expression); - } - } - - while (!operatorStack.empty()) { - if (operatorStack.top() == "(") { - return Option(400, "Could not parse the filter query: unbalanced parentheses."); - } - postfix.push(operatorStack.top()); - operatorStack.pop(); - } - - return Option(true); -} - -Option toMultiValueNumericFilter(std::string& raw_value, filter& filter_exp, const field& _field) { - std::vector filter_values; - StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); - filter_exp = {_field.name, {}, {}}; - for (std::string& filter_value: filter_values) { - Option op_comparator = filter::extract_num_comparator(filter_value); - if (!op_comparator.ok()) { - return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); - } - if (op_comparator.get() == RANGE_INCLUSIVE) { - // split the value around range operator to extract bounds - std::vector range_values; - StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR()); - for (const std::string& range_value: range_values) { - auto validate_op = filter::validate_numerical_filter_value(_field, range_value); - if (!validate_op.ok()) { - return validate_op; - } - filter_exp.values.push_back(range_value); - filter_exp.comparators.push_back(op_comparator.get()); - } - } else { - auto validate_op = filter::validate_numerical_filter_value(_field, filter_value); - if 
(!validate_op.ok()) { - return validate_op; - } - filter_exp.values.push_back(filter_value); - filter_exp.comparators.push_back(op_comparator.get()); - } - } - - return Option(true); -} - -Option toFilter(const std::string expression, - filter& filter_exp, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix) { - // split into [field_name, value] - size_t found_index = expression.find(':'); - if (found_index == std::string::npos) { - return Option(400, "Could not parse the filter query."); - } - std::string&& field_name = expression.substr(0, found_index); - StringUtils::trim(field_name); - if (field_name == "id") { - std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); - StringUtils::trim(raw_value); - std::string empty_filter_err = "Error with filter field `id`: Filter value cannot be empty."; - if (raw_value.empty()) { - return Option(400, empty_filter_err); - } - filter_exp = {field_name, {}, {}}; - NUM_COMPARATOR id_comparator = EQUALS; - size_t filter_value_index = 0; - if (raw_value[0] == '=') { - id_comparator = EQUALS; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - return Option(400, "Not equals filtering is not supported on the `id` field."); - } - if (filter_value_index != 0) { - raw_value = raw_value.substr(filter_value_index); - } - if (raw_value.empty()) { - return Option(400, empty_filter_err); - } - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector doc_ids; - StringUtils::split_to_values(raw_value.substr(1, raw_value.size() - 2), doc_ids); - for (std::string& doc_id: doc_ids) { - // we have to convert the doc_id to seq id - std::string seq_id_str; - StoreStatus seq_id_status = store->get(doc_id_prefix + doc_id, seq_id_str); - if (seq_id_status != StoreStatus::FOUND) { - continue; - } - filter_exp.values.push_back(seq_id_str); - filter_exp.comparators.push_back(id_comparator); - } - } else { - std::vector doc_ids; - StringUtils::split_to_values(raw_value, doc_ids); // to handle backticks - std::string seq_id_str; - StoreStatus seq_id_status = store->get(doc_id_prefix + doc_ids[0], seq_id_str); - if (seq_id_status == StoreStatus::FOUND) { - filter_exp.values.push_back(seq_id_str); - filter_exp.comparators.push_back(id_comparator); - } - } - return Option(true); - } - - auto field_it = search_schema.find(field_name); - - if (field_it == search_schema.end()) { - return Option(404, "Could not find a filter field named `" + field_name + "` in the schema."); - } - - if (field_it->num_dim > 0) { - return Option(404, "Cannot filter on vector field `" + field_name + "`."); - } - - const field& _field = field_it.value(); - std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); - StringUtils::trim(raw_value); - // skip past optional `:=` operator, which has no meaning for non-string fields - if (!_field.is_string() && raw_value[0] == '=') { - size_t filter_value_index = 0; - while (raw_value[++filter_value_index] == ' '); - raw_value = raw_value.substr(filter_value_index); - } - if (_field.is_integer() || _field.is_float()) { - // could be a 
single value or a list - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); - if (!op.ok()) { - return op; - } - } else { - Option op_comparator = filter::extract_num_comparator(raw_value); - if (!op_comparator.ok()) { - return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); - } - if (op_comparator.get() == RANGE_INCLUSIVE) { - // split the value around range operator to extract bounds - std::vector range_values; - StringUtils::split(raw_value, range_values, filter::RANGE_OPERATOR()); - filter_exp.field_name = field_name; - for (const std::string& range_value: range_values) { - auto validate_op = filter::validate_numerical_filter_value(_field, range_value); - if (!validate_op.ok()) { - return validate_op; - } - filter_exp.values.push_back(range_value); - filter_exp.comparators.push_back(op_comparator.get()); - } - } else if (op_comparator.get() == NOT_EQUALS && raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); - if (!op.ok()) { - return op; - } - filter_exp.apply_not_equals = true; - } else { - auto validate_op = filter::validate_numerical_filter_value(_field, raw_value); - if (!validate_op.ok()) { - return validate_op; - } - filter_exp = {field_name, {raw_value}, {op_comparator.get()}}; - } - } - } else if (_field.is_bool()) { - NUM_COMPARATOR bool_comparator = EQUALS; - size_t filter_value_index = 0; - if (raw_value[0] == '=') { - bool_comparator = EQUALS; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - bool_comparator = NOT_EQUALS; - filter_value_index++; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } - if (filter_value_index != 0) { - raw_value = raw_value.substr(filter_value_index); - } - if (filter_value_index == raw_value.size()) { - return Option(400, "Error with filter field `" + _field.name + - "`: Filter value cannot be empty."); - } - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector filter_values; - StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); - filter_exp = {field_name, {}, {}}; - for (std::string& filter_value: filter_values) { - if (filter_value != "true" && filter_value != "false") { - return Option(400, "Values of filter field `" + _field.name + - "`: must be `true` or `false`."); - } - filter_value = (filter_value == "true") ? "1" : "0"; - filter_exp.values.push_back(filter_value); - filter_exp.comparators.push_back(bool_comparator); - } - } else { - if (raw_value != "true" && raw_value != "false") { - return Option(400, "Value of filter field `" + _field.name + "` must be `true` or `false`."); - } - std::string bool_value = (raw_value == "true") ? 
"1" : "0"; - filter_exp = {field_name, {bool_value}, {bool_comparator}}; - } - } else if (_field.is_geopoint()) { - filter_exp = {field_name, {}, {}}; - const std::string& format_err_msg = "Value of filter field `" + _field.name + - "`: must be in the `(-44.50, 170.29, 0.75 km)` or " - "(56.33, -65.97, 23.82, -127.82) format."; - NUM_COMPARATOR num_comparator; - // could be a single value or a list - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector filter_values; - StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, "),"); - for (std::string& filter_value: filter_values) { - filter_value += ")"; - std::string processed_filter_val; - auto parse_op = filter::parse_geopoint_filter_value(filter_value, format_err_msg, processed_filter_val, - num_comparator); - if (!parse_op.ok()) { - return parse_op; - } - filter_exp.values.push_back(processed_filter_val); - filter_exp.comparators.push_back(num_comparator); - } - } else { - // single value, e.g. (10.45, 34.56, 2 km) - std::string processed_filter_val; - auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, processed_filter_val, - num_comparator); - if (!parse_op.ok()) { - return parse_op; - } - filter_exp.values.push_back(processed_filter_val); - filter_exp.comparators.push_back(num_comparator); - } - } else if (_field.is_string()) { - size_t filter_value_index = 0; - NUM_COMPARATOR str_comparator = CONTAINS; - if (raw_value[0] == '=') { - // string filter should be evaluated in strict "equals" mode - str_comparator = EQUALS; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - str_comparator = NOT_EQUALS; - filter_value_index++; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } - if (filter_value_index == raw_value.size()) { - return Option(400, "Error with filter field `" + _field.name + - "`: Filter value cannot be empty."); - } - if (raw_value[filter_value_index] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector filter_values; - StringUtils::split_to_values( - raw_value.substr(filter_value_index + 1, raw_value.size() - filter_value_index - 2), filter_values); - filter_exp = {field_name, filter_values, {str_comparator}}; - } else { - filter_exp = {field_name, {raw_value.substr(filter_value_index)}, {str_comparator}}; - } - - filter_exp.apply_not_equals = (str_comparator == NOT_EQUALS); - } else { - return Option(400, "Error with filter field `" + _field.name + - "`: Unidentified field data type, see docs for supported data types."); - } - - return Option(true); -} - -// https://stackoverflow.com/a/423914/11218270 -Option toParseTree(std::queue& postfix, filter_node_t*& root, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix) { - std::stack nodeStack; - bool is_successful = true; - std::string error_message; - - filter_node_t *filter_node = nullptr; - - while (!postfix.empty()) { - const std::string expression = postfix.front(); - postfix.pop(); - - if (isOperator(expression)) { - if (nodeStack.empty()) { - is_successful = false; - error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; - break; - } - auto operandB = nodeStack.top(); - nodeStack.pop(); - - if (nodeStack.empty()) { - delete operandB; - is_successful = false; - error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; - break; - } - auto operandA = nodeStack.top(); - nodeStack.pop(); - - filter_node = new filter_node_t(expression == "&&" ? 
AND : OR, operandA, operandB); - } else { - filter filter_exp; - - // Expected value: $Collection(...) - bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')'); - if (is_referenced_filter) { - size_t parenthesis_index = expression.find('('); - - std::string collection_name = expression.substr(1, parenthesis_index - 1); - auto &cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(collection_name); - if (collection == nullptr) { - is_successful = false; - error_message = "Referenced collection `" + collection_name + "` not found."; - break; - } - - filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)}; - filter_exp.referenced_collection_name = collection_name; - } else { - Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); - if (!toFilter_op.ok()) { - is_successful = false; - error_message = toFilter_op.error(); - break; - } - } - - filter_node = new filter_node_t(filter_exp); - } - - nodeStack.push(filter_node); - } - - if (!is_successful) { - while (!nodeStack.empty()) { - auto filterNode = nodeStack.top(); - delete filterNode; - nodeStack.pop(); - } - - return Option(400, error_message); - } - - if (nodeStack.empty()) { - return Option(400, "Filter query cannot be empty."); - } - root = nodeStack.top(); - - return Option(true); -} - -Option filter::parse_filter_query(const std::string& filter_query, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix, - filter_node_t*& root) { - auto _filter_query = filter_query; - StringUtils::trim(_filter_query); - if (_filter_query.empty()) { - return Option(true); - } - - std::queue tokens; - Option tokenize_op = StringUtils::tokenize_filter_query(filter_query, tokens); - if (!tokenize_op.ok()) { - return tokenize_op; - } - - if (tokens.size() > 100) { - return Option(400, "Filter expression is not valid."); - } - - std::queue postfix; 
- Option toPostfix_op = toPostfix(tokens, postfix); - if (!toPostfix_op.ok()) { - return toPostfix_op; - } - - Option toParseTree_op = toParseTree(postfix, - root, - search_schema, - store, - doc_id_prefix); - if (!toParseTree_op.ok()) { - return toParseTree_op; - } - - return Option(true); -} - Option field::json_field_to_field(bool enable_nested_fields, nlohmann::json& field_json, std::vector& the_fields, string& fallback_field_type, size_t& num_auto_detect_fields) { @@ -1030,167 +525,3 @@ void field::compact_nested_fields(tsl::htrie_map& nested_fields) { nested_fields.erase_prefix(field_name + "."); } } - -void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { - auto lenA = a.count, lenB = b.count; - if (lenA == 0 || lenB == 0) { - return; - } - - result.docs = new uint32_t[std::min(lenA, lenB)]; - - auto A = a.docs, B = b.docs, out = result.docs; - const uint32_t *endA = A + lenA; - const uint32_t *endB = B + lenB; - - // Add an entry of references in the result for each unique collection in a and b. - for (auto const& item: a.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; - } - } - for (auto const& item: b.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; - } - } - - while (true) { - while (*A < *B) { - SKIP_FIRST_COMPARE: - if (++A == endA) { - result.count = out - result.docs; - return; - } - } - while (*A > *B) { - if (++B == endB) { - result.count = out - result.docs; - return; - } - } - if (*A == *B) { - *out = *A; - - // Copy the references of the document from every collection into result. 
- for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][out - result.docs] = item.second[A - a.docs]; - } - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][out - result.docs] = item.second[B - b.docs]; - } - - out++; - - if (++A == endA || ++B == endB) { - result.count = out - result.docs; - return; - } - } else { - goto SKIP_FIRST_COMPARE; - } - } -} - -void filter_result_t::or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { - if (a.count == 0 && b.count == 0) { - return; - } - - // If either one of a or b does not have any matches, copy other into result. - if (a.count == 0) { - result = b; - return; - } - if (b.count == 0) { - result = a; - return; - } - - size_t indexA = 0, indexB = 0, res_index = 0, lenA = a.count, lenB = b.count; - result.docs = new uint32_t[lenA + lenB]; - - // Add an entry of references in the result for each unique collection in a and b. - for (auto const& item: a.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; - } - } - for (auto const& item: b.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; - } - } - - while (indexA < lenA && indexB < lenB) { - if (a.docs[indexA] < b.docs[indexB]) { - // check for duplicate - if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { - result.docs[res_index] = a.docs[indexA]; - res_index++; - } - - // Copy references of the last result document from every collection in a. 
- for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; - } - - indexA++; - } else { - if (res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { - result.docs[res_index] = b.docs[indexB]; - res_index++; - } - - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; - } - - indexB++; - } - } - - while (indexA < lenA) { - if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { - result.docs[res_index] = a.docs[indexA]; - res_index++; - } - - for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; - } - - indexA++; - } - - while (indexB < lenB) { - if(res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { - result.docs[res_index] = b.docs[indexB]; - res_index++; - } - - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; - } - - indexB++; - } - - result.count = res_index; - - // shrink fit - auto out = new uint32_t[res_index]; - memcpy(out, result.docs, res_index * sizeof(uint32_t)); - delete[] result.docs; - result.docs = out; - - for (auto &item: result.reference_filter_results) { - auto out_references = new reference_filter_result_t[res_index]; - - for (uint32_t i = 0; i < result.count; i++) { - out_references[i] = item.second[i]; - } - delete[] item.second; - item.second = out_references; - } -} diff --git a/src/filter.cpp b/src/filter.cpp index 77ef5d5d..95fbfefc 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -1,641 +1,573 @@ #include #include #include +#include #include "filter.h" -void filter_result_iterator_t::and_filter_iterators() { - while (left_it->is_valid && right_it->is_valid) { - while (left_it->seq_id < right_it->seq_id) { - left_it->next(); - if (!left_it->is_valid) { - is_valid = 
false; - return; - } - } - - while (left_it->seq_id > right_it->seq_id) { - right_it->next(); - if (!right_it->is_valid) { - is_valid = false; - return; - } - } - - if (left_it->seq_id == right_it->seq_id) { - seq_id = left_it->seq_id; - reference.clear(); - - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; - } +Option filter::validate_numerical_filter_value(field _field, const string &raw_value) { + if(_field.is_int32() && !StringUtils::is_int32_t(raw_value)) { + return Option(400, "Error with filter field `" + _field.name + "`: Not an int32."); } - is_valid = false; + else if(_field.is_int64() && !StringUtils::is_int64_t(raw_value)) { + return Option(400, "Error with filter field `" + _field.name + "`: Not an int64."); + } + + else if(_field.is_float() && !StringUtils::is_float(raw_value)) { + return Option(400, "Error with filter field `" + _field.name + "`: Not a float."); + } + + return Option(true); } -void filter_result_iterator_t::or_filter_iterators() { - if (left_it->is_valid && right_it->is_valid) { - if (left_it->seq_id < right_it->seq_id) { - seq_id = left_it->seq_id; - reference.clear(); +Option filter::extract_num_comparator(string &comp_and_value) { + auto num_comparator = EQUALS; - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - - return; - } - - if (left_it->seq_id > right_it->seq_id) { - seq_id = right_it->seq_id; - reference.clear(); - - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; - } - - seq_id = left_it->seq_id; - reference.clear(); - - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; + if(StringUtils::is_integer(comp_and_value) || 
StringUtils::is_float(comp_and_value)) { + num_comparator = EQUALS; } - if (left_it->is_valid) { - seq_id = left_it->seq_id; - reference.clear(); - - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - - return; + // the ordering is important - we have to compare 2-letter operators first + else if(comp_and_value.compare(0, 2, "<=") == 0) { + num_comparator = LESS_THAN_EQUALS; } - if (right_it->is_valid) { - seq_id = right_it->seq_id; - reference.clear(); - - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; + else if(comp_and_value.compare(0, 2, ">=") == 0) { + num_comparator = GREATER_THAN_EQUALS; } - is_valid = false; + else if(comp_and_value.compare(0, 2, "!=") == 0) { + num_comparator = NOT_EQUALS; + } + + else if(comp_and_value.compare(0, 1, "<") == 0) { + num_comparator = LESS_THAN; + } + + else if(comp_and_value.compare(0, 1, ">") == 0) { + num_comparator = GREATER_THAN; + } + + else if(comp_and_value.find("..") != std::string::npos) { + num_comparator = RANGE_INCLUSIVE; + } + + else { + return Option(400, "Numerical field has an invalid comparator."); + } + + if(num_comparator == LESS_THAN || num_comparator == GREATER_THAN) { + comp_and_value = comp_and_value.substr(1); + } else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS || num_comparator == NOT_EQUALS) { + comp_and_value = comp_and_value.substr(2); + } + + comp_and_value = StringUtils::trim(comp_and_value); + + return Option(num_comparator); } -void filter_result_iterator_t::doc_matching_string_filter() { - // If none of the filter value iterators are valid, mark this node as invalid. - bool one_is_valid = false; +Option filter::parse_geopoint_filter_value(std::string& raw_value, + const std::string& format_err_msg, + std::string& processed_filter_val, + NUM_COMPARATOR& num_comparator) { - // Since we do OR between filter values, the lowest seq_id id from all is selected. 
- uint32_t lowest_id = UINT32_MAX; + num_comparator = LESS_THAN_EQUALS; - for (auto& filter_value_tokens : posting_list_iterators) { - // Perform AND between tokens of a filter value. - bool tokens_iter_is_valid; - posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + if(!(raw_value[0] == '(' && raw_value[raw_value.size() - 1] == ')')) { + return Option(400, format_err_msg); + } - one_is_valid = tokens_iter_is_valid || one_is_valid; + std::vector filter_values; + auto raw_val_without_paran = raw_value.substr(1, raw_value.size() - 2); + StringUtils::split(raw_val_without_paran, filter_values, ","); - if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { - lowest_id = filter_value_tokens[0].id(); + // we will end up with: "10.45 34.56 2 km" or "10.45 34.56 2mi" or a geo polygon + + if(filter_values.size() < 3) { + return Option(400, format_err_msg); + } + + // do validation: format should match either a point + radius or polygon + + size_t num_floats = 0; + for(const auto& fvalue: filter_values) { + if(StringUtils::is_float(fvalue)) { + num_floats++; } } - if (one_is_valid) { - seq_id = lowest_id; + bool is_polygon = (num_floats == filter_values.size()); + if(!is_polygon) { + // we have to ensure that this is a point + radius match + if(!StringUtils::is_float(filter_values[0]) || !StringUtils::is_float(filter_values[1])) { + return Option(400, format_err_msg); + } + + if(filter_values[0] == "nan" || filter_values[0] == "NaN" || + filter_values[1] == "nan" || filter_values[1] == "NaN") { + return Option(400, format_err_msg); + } } - is_valid = one_is_valid; + if(is_polygon) { + processed_filter_val = raw_val_without_paran; + } else { + // point + radius + // filter_values[2] is distance, get the unit, validate it and split on that + if(filter_values[2].size() < 2) { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::string unit = filter_values[2].substr(filter_values[2].size()-2, 2); + + if(unit != "km" && unit 
!= "mi") { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::vector dist_values; + StringUtils::split(filter_values[2], dist_values, unit); + + if(dist_values.size() != 1) { + return Option(400, format_err_msg); + } + + if(!StringUtils::is_float(dist_values[0])) { + return Option(400, format_err_msg); + } + + processed_filter_val = filter_values[0] + ", " + filter_values[1] + ", " + // co-ords + dist_values[0] + ", " + unit; // X km + } + + return Option(true); } -void filter_result_iterator_t::next() { - if (!is_valid) { - return; - } +bool isOperator(const std::string& expression) { + return expression == "&&" || expression == "||"; +} - if (filter_node->isOperator) { - // Advance the subtrees and then apply operators to arrive at the next valid doc. - if (filter_node->filter_operator == AND) { - left_it->next(); - right_it->next(); - and_filter_iterators(); +// https://en.wikipedia.org/wiki/Shunting_yard_algorithm +Option toPostfix(std::queue& tokens, std::queue& postfix) { + std::stack operatorStack; + + while (!tokens.empty()) { + auto expression = tokens.front(); + tokens.pop(); + + if (isOperator(expression)) { + // We only have two operators &&, || having the same precedence and both being left associative. 
+ while (!operatorStack.empty() && operatorStack.top() != "(") { + postfix.push(operatorStack.top()); + operatorStack.pop(); + } + + operatorStack.push(expression); + } else if (expression == "(") { + operatorStack.push(expression); + } else if (expression == ")") { + while (!operatorStack.empty() && operatorStack.top() != "(") { + postfix.push(operatorStack.top()); + operatorStack.pop(); + } + + if (operatorStack.empty() || operatorStack.top() != "(") { + return Option(400, "Could not parse the filter query: unbalanced parentheses."); + } + operatorStack.pop(); } else { - if (left_it->seq_id == seq_id && right_it->seq_id == seq_id) { - left_it->next(); - right_it->next(); - } else if (left_it->seq_id == seq_id) { - left_it->next(); - } else { - right_it->next(); - } - - or_filter_iterators(); + postfix.push(expression); } - - return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - if (++result_index >= filter_result.count) { - is_valid = false; - return; + while (!operatorStack.empty()) { + if (operatorStack.top() == "(") { + return Option(400, "Could not parse the filter query: unbalanced parentheses."); } - - seq_id = filter_result.docs[result_index]; - reference.clear(); - for (auto const& item: filter_result.reference_filter_results) { - reference[item.first] = item.second[result_index]; - } - - return; + postfix.push(operatorStack.top()); + operatorStack.pop(); } - if (a_filter.field_name == "id") { - if (++result_index >= filter_result.count) { - is_valid = false; - return; + return Option(true); +} + +Option toMultiValueNumericFilter(std::string& raw_value, filter& filter_exp, const field& _field) { + std::vector filter_values; + StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); + filter_exp = {_field.name, {}, {}}; + for (std::string& filter_value: filter_values) { + Option op_comparator = 
filter::extract_num_comparator(filter_value); + if (!op_comparator.ok()) { + return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); } - - seq_id = filter_result.docs[result_index]; - return; - } - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - // Advance all the filter values that are at doc. Then find the next doc. - for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { - auto& filter_value_tokens = posting_list_iterators[i]; - - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { - for (auto& iter: filter_value_tokens) { - iter.next(); + if (op_comparator.get() == RANGE_INCLUSIVE) { + // split the value around range operator to extract bounds + std::vector range_values; + StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR()); + for (const std::string& range_value: range_values) { + auto validate_op = filter::validate_numerical_filter_value(_field, range_value); + if (!validate_op.ok()) { + return validate_op; } + filter_exp.values.push_back(range_value); + filter_exp.comparators.push_back(op_comparator.get()); } + } else { + auto validate_op = filter::validate_numerical_filter_value(_field, filter_value); + if (!validate_op.ok()) { + return validate_op; + } + filter_exp.values.push_back(filter_value); + filter_exp.comparators.push_back(op_comparator.get()); } - - doc_matching_string_filter(); - return; } + + return Option(true); } -void filter_result_iterator_t::init() { - if (filter_node == nullptr) { - return; +Option toFilter(const std::string expression, + filter& filter_exp, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix) { + // split into [field_name, value] + size_t found_index = expression.find(':'); + if (found_index == std::string::npos) { + return Option(400, "Could not parse the 
filter query."); } - - if (filter_node->isOperator) { - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); + std::string&& field_name = expression.substr(0, found_index); + StringUtils::trim(field_name); + if (field_name == "id") { + std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); + StringUtils::trim(raw_value); + std::string empty_filter_err = "Error with filter field `id`: Filter value cannot be empty."; + if (raw_value.empty()) { + return Option(400, empty_filter_err); } - - return; - } - - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. - auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(a_filter.referenced_collection_name); - if (collection == nullptr) { - status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); - is_valid = false; - return; + filter_exp = {field_name, {}, {}}; + NUM_COMPARATOR id_comparator = EQUALS; + size_t filter_value_index = 0; + if (raw_value[0] == '=') { + id_comparator = EQUALS; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { + return Option(400, "Not equals filtering is not supported on the `id` field."); } - - auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, - filter_result, - collection_name); - if (!reference_filter_op.ok()) { - status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name - + "` collection: " + reference_filter_op.error()); - is_valid = false; - return; + if (filter_value_index != 0) { + raw_value = raw_value.substr(filter_value_index); } - - is_valid = filter_result.count > 0; - return; - } - - if (a_filter.field_name == "id") { - if (a_filter.values.empty()) { - is_valid = false; - return; + if (raw_value.empty()) { + return Option(400, empty_filter_err); } - - // we handle `ids` separately - std::vector result_ids; - for (const auto& id_str : a_filter.values) { - result_ids.push_back(std::stoul(id_str)); - } - - std::sort(result_ids.begin(), result_ids.end()); - - filter_result.count = result_ids.size(); - filter_result.docs = new uint32_t[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); - } - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - art_tree* t = index->search_index.at(a_filter.field_name); - - for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; - - // there could be multiple tokens in a filter value, which we have to treat as ANDs - // e.g. 
country: South Africa - Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); - - std::string str_token; - size_t token_index = 0; - std::vector str_tokens; - - while (tokenizer.next(str_token, token_index)) { - str_tokens.push_back(str_token); - - art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), - str_token.length()+1); - if (leaf == nullptr) { + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector doc_ids; + StringUtils::split_to_values(raw_value.substr(1, raw_value.size() - 2), doc_ids); + for (std::string& doc_id: doc_ids) { + // we have to convert the doc_id to seq id + std::string seq_id_str; + StoreStatus seq_id_status = store->get(doc_id_prefix + doc_id, seq_id_str); + if (seq_id_status != StoreStatus::FOUND) { continue; } - - posting_lists.push_back(leaf->values); + filter_exp.values.push_back(seq_id_str); + filter_exp.comparators.push_back(id_comparator); } - - if (posting_lists.size() != str_tokens.size()) { - continue; - } - - std::vector plists; - posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); - - posting_list_iterators.emplace_back(std::vector()); - - for (auto const& plist: plists) { - posting_list_iterators.back().push_back(plist->new_iterator()); + } else { + std::vector doc_ids; + StringUtils::split_to_values(raw_value, doc_ids); // to handle backticks + std::string seq_id_str; + StoreStatus seq_id_status = store->get(doc_id_prefix + doc_ids[0], seq_id_str); + if (seq_id_status == StoreStatus::FOUND) { + filter_exp.values.push_back(seq_id_str); + filter_exp.comparators.push_back(id_comparator); } } - - doc_matching_string_filter(); - return; + return Option(true); } + + auto field_it = search_schema.find(field_name); + + if (field_it == search_schema.end()) { + return Option(404, "Could not find a filter field named `" + field_name + "` in the schema."); + } + + if (field_it->num_dim > 0) { + return 
Option(404, "Cannot filter on vector field `" + field_name + "`."); + } + + const field& _field = field_it.value(); + std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); + StringUtils::trim(raw_value); + // skip past optional `:=` operator, which has no meaning for non-string fields + if (!_field.is_string() && raw_value[0] == '=') { + size_t filter_value_index = 0; + while (raw_value[++filter_value_index] == ' '); + raw_value = raw_value.substr(filter_value_index); + } + if (_field.is_integer() || _field.is_float()) { + // could be a single value or a list + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); + if (!op.ok()) { + return op; + } + } else { + Option op_comparator = filter::extract_num_comparator(raw_value); + if (!op_comparator.ok()) { + return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); + } + if (op_comparator.get() == RANGE_INCLUSIVE) { + // split the value around range operator to extract bounds + std::vector range_values; + StringUtils::split(raw_value, range_values, filter::RANGE_OPERATOR()); + filter_exp.field_name = field_name; + for (const std::string& range_value: range_values) { + auto validate_op = filter::validate_numerical_filter_value(_field, range_value); + if (!validate_op.ok()) { + return validate_op; + } + filter_exp.values.push_back(range_value); + filter_exp.comparators.push_back(op_comparator.get()); + } + } else if (op_comparator.get() == NOT_EQUALS && raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); + if (!op.ok()) { + return op; + } + filter_exp.apply_not_equals = true; + } else { + auto validate_op = filter::validate_numerical_filter_value(_field, raw_value); + if (!validate_op.ok()) { + return validate_op; + } + filter_exp = {field_name, {raw_value}, {op_comparator.get()}}; + } + 
} + } else if (_field.is_bool()) { + NUM_COMPARATOR bool_comparator = EQUALS; + size_t filter_value_index = 0; + if (raw_value[0] == '=') { + bool_comparator = EQUALS; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if (raw_value.size() >= 2 && raw_value[0] == '!' && raw_value[1] == '=') { + bool_comparator = NOT_EQUALS; + filter_value_index++; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } + if (filter_value_index != 0) { + raw_value = raw_value.substr(filter_value_index); + } + if (filter_value_index == raw_value.size()) { + return Option(400, "Error with filter field `" + _field.name + + "`: Filter value cannot be empty."); + } + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector filter_values; + StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); + filter_exp = {field_name, {}, {}}; + for (std::string& filter_value: filter_values) { + if (filter_value != "true" && filter_value != "false") { + return Option(400, "Values of filter field `" + _field.name + + "`: must be `true` or `false`."); + } + filter_value = (filter_value == "true") ? "1" : "0"; + filter_exp.values.push_back(filter_value); + filter_exp.comparators.push_back(bool_comparator); + } + } else { + if (raw_value != "true" && raw_value != "false") { + return Option(400, "Value of filter field `" + _field.name + "` must be `true` or `false`."); + } + std::string bool_value = (raw_value == "true") ? 
"1" : "0"; + filter_exp = {field_name, {bool_value}, {bool_comparator}}; + } + } else if (_field.is_geopoint()) { + filter_exp = {field_name, {}, {}}; + const std::string& format_err_msg = "Value of filter field `" + _field.name + + "`: must be in the `(-44.50, 170.29, 0.75 km)` or " + "(56.33, -65.97, 23.82, -127.82) format."; + NUM_COMPARATOR num_comparator; + // could be a single value or a list + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector filter_values; + StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, "),"); + for (std::string& filter_value: filter_values) { + filter_value += ")"; + std::string processed_filter_val; + auto parse_op = filter::parse_geopoint_filter_value(filter_value, format_err_msg, processed_filter_val, + num_comparator); + if (!parse_op.ok()) { + return parse_op; + } + filter_exp.values.push_back(processed_filter_val); + filter_exp.comparators.push_back(num_comparator); + } + } else { + // single value, e.g. (10.45, 34.56, 2 km) + std::string processed_filter_val; + auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, processed_filter_val, + num_comparator); + if (!parse_op.ok()) { + return parse_op; + } + filter_exp.values.push_back(processed_filter_val); + filter_exp.comparators.push_back(num_comparator); + } + } else if (_field.is_string()) { + size_t filter_value_index = 0; + NUM_COMPARATOR str_comparator = CONTAINS; + if (raw_value[0] == '=') { + // string filter should be evaluated in strict "equals" mode + str_comparator = EQUALS; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { + str_comparator = NOT_EQUALS; + filter_value_index++; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } + if (filter_value_index == raw_value.size()) { + return Option(400, "Error with filter field `" + _field.name + + "`: Filter value cannot be empty."); + } + if (raw_value[filter_value_index] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector filter_values; + StringUtils::split_to_values( + raw_value.substr(filter_value_index + 1, raw_value.size() - filter_value_index - 2), filter_values); + filter_exp = {field_name, filter_values, {str_comparator}}; + } else { + filter_exp = {field_name, {raw_value.substr(filter_value_index)}, {str_comparator}}; + } + + filter_exp.apply_not_equals = (str_comparator == NOT_EQUALS); + } else { + return Option(400, "Error with filter field `" + _field.name + + "`: Unidentified field data type, see docs for supported data types."); + } + + return Option(true); } -bool filter_result_iterator_t::valid() { - if (!is_valid) { - return false; - } +// https://stackoverflow.com/a/423914/11218270 +Option toParseTree(std::queue& postfix, filter_node_t*& root, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix) { + std::stack nodeStack; + bool is_successful = true; + std::string error_message; - if (filter_node->isOperator) { - if (filter_node->filter_operator == AND) { - is_valid = left_it->valid() && right_it->valid(); - return is_valid; - } else { - is_valid = left_it->valid() || right_it->valid(); - return is_valid; - } - } + filter_node_t *filter_node = nullptr; - const filter a_filter = filter_node->filter_exp; + while (!postfix.empty()) { + const std::string expression = postfix.front(); + postfix.pop(); - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { - is_valid = result_index < filter_result.count; - return is_valid; - } - - if 
(!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return is_valid; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - bool one_is_valid = false; - for (auto& filter_value_tokens: posting_list_iterators) { - posting_list_t::intersect(filter_value_tokens, one_is_valid); - - if (one_is_valid) { + if (isOperator(expression)) { + if (nodeStack.empty()) { + is_successful = false; + error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; break; } - } + auto operandB = nodeStack.top(); + nodeStack.pop(); - is_valid = one_is_valid; - return is_valid; - } + if (nodeStack.empty()) { + delete operandB; + is_successful = false; + error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; + break; + } + auto operandA = nodeStack.top(); + nodeStack.pop(); - return false; -} - -void filter_result_iterator_t::skip_to(uint32_t id) { - if (!is_valid) { - return; - } - - if (filter_node->isOperator) { - // Skip the subtrees to id and then apply operators to arrive at the next valid doc. - left_it->skip_to(id); - right_it->skip_to(id); - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); + filter_node = new filter_node_t(expression == "&&" ? AND : OR, operandA, operandB); } else { - or_filter_iterators(); - } + filter filter_exp; - return; - } + // Expected value: $Collection(...) 
+ bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')'); + if (is_referenced_filter) { + size_t parenthesis_index = expression.find('('); - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); - - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - reference.clear(); - for (auto const& item: filter_result.reference_filter_results) { - reference[item.first] = item.second[result_index]; - } - - return; - } - - if (a_filter.field_name == "id") { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); - - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - // Skip all the token iterators and find a new match. - for (auto& filter_value_tokens : posting_list_iterators) { - for (auto& token: filter_value_tokens) { - // We perform AND on tokens. Short-circuiting here. 
- if (!token.valid()) { + std::string collection_name = expression.substr(1, parenthesis_index - 1); + auto &cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(collection_name); + if (collection == nullptr) { + is_successful = false; + error_message = "Referenced collection `" + collection_name + "` not found."; break; } - token.skip_to(id); - } - } - - doc_matching_string_filter(); - return; - } -} - -int filter_result_iterator_t::valid(uint32_t id) { - if (!is_valid) { - return -1; - } - - if (filter_node->isOperator) { - auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); - - if (filter_node->filter_operator == AND) { - is_valid = left_it->is_valid && right_it->is_valid; - - if (left_valid < 1 || right_valid < 1) { - if (left_valid == -1 || right_valid == -1) { - return -1; - } - - return 0; - } - - return 1; - } else { - is_valid = left_it->is_valid || right_it->is_valid; - - if (left_valid < 1 && right_valid < 1) { - if (left_valid == -1 && right_valid == -1) { - return -1; - } - - return 0; - } - - return 1; - } - } - - if (filter_node->filter_exp.apply_not_equals) { - // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. - if (!valid()) { - is_valid = true; - return 1; - } - - skip_to(id); - - if (!is_valid) { - is_valid = true; - return 1; - } - - return seq_id != id ? 1 : 0; - } - - skip_to(id); - return is_valid ? (seq_id == id ? 1 : 0) : -1; -} - -Option filter_result_iterator_t::init_status() { - if (filter_node != nullptr && filter_node->isOperator) { - auto left_status = left_it->init_status(); - - return !left_status.ok() ? 
left_status : right_it->init_status(); - } - - return status; -} - -bool filter_result_iterator_t::contains_atleast_one(const void *obj) { - if(IS_COMPACT_POSTING(obj)) { - compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); - - size_t i = 0; - while(i < list->length && valid()) { - size_t num_existing_offsets = list->id_offsets[i]; - size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; - - if (existing_id == seq_id) { - return true; - } - - // advance smallest value - if (existing_id < seq_id) { - i += num_existing_offsets + 2; + filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)}; + filter_exp.referenced_collection_name = collection_name; } else { - skip_to(existing_id); - } - } - } else { - auto list = (posting_list_t*)(obj); - posting_list_t::iterator_t it = list->new_iterator(); - - while(it.valid() && valid()) { - uint32_t id = it.id(); - - if(id == seq_id) { - return true; + Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); + if (!toFilter_op.ok()) { + is_successful = false; + error_message = toFilter_op.error(); + break; + } } - if(id < seq_id) { - it.skip_to(seq_id); - } else { - skip_to(id); - } + filter_node = new filter_node_t(filter_exp); } + + nodeStack.push(filter_node); } - return false; + if (!is_successful) { + while (!nodeStack.empty()) { + auto filterNode = nodeStack.top(); + delete filterNode; + nodeStack.pop(); + } + + return Option(400, error_message); + } + + if (nodeStack.empty()) { + return Option(400, "Filter query cannot be empty."); + } + root = nodeStack.top(); + + return Option(true); } -void filter_result_iterator_t::reset() { - if (filter_node == nullptr) { - return; +Option filter::parse_filter_query(const std::string& filter_query, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix, + filter_node_t*& root) { + auto _filter_query = filter_query; + 
StringUtils::trim(_filter_query); + if (_filter_query.empty()) { + return Option(true); } - if (filter_node->isOperator) { - // Reset the subtrees then apply operators to arrive at the first valid doc. - left_it->reset(); - right_it->reset(); - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); - } - - return; + std::queue tokens; + Option tokenize_op = StringUtils::tokenize_filter_query(filter_query, tokens); + if (!tokenize_op.ok()) { + return tokenize_op; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter || a_filter.field_name == "id") { - result_index = 0; - is_valid = filter_result.count > 0; - return; + if (tokens.size() > 100) { + return Option(400, "Filter expression is not valid."); } - if (!index->field_is_indexed(a_filter.field_name)) { - return; + std::queue postfix; + Option toPostfix_op = toPostfix(tokens, postfix); + if (!toPostfix_op.ok()) { + return toPostfix_op; } - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - posting_list_iterators.clear(); - for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - expanded_plists.clear(); - - init(); - return; + Option toParseTree_op = toParseTree(postfix, + root, + search_schema, + store, + doc_id_prefix); + if (!toParseTree_op.ok()) { + return toParseTree_op; } -} - -uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { - if (!valid()) { - return 0; - } - - std::vector filter_ids; - do { - filter_ids.push_back(seq_id); - next(); - } while (valid()); - - filter_array = new uint32_t[filter_ids.size()]; - std::copy(filter_ids.begin(), filter_ids.end(), filter_array); - - return filter_ids.size(); -} - -uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { - if (!valid()) { - return 0; - } - - std::vector filter_ids; 
- for (uint32_t i = 0; i < lenA; i++) { - auto result = valid(A[i]); - - if (result == -1) { - break; - } - - if (result == 1) { - filter_ids.push_back(A[i]); - } - } - - if (filter_ids.empty()) { - return 0; - } - - results = new uint32_t[filter_ids.size()]; - std::copy(filter_ids.begin(), filter_ids.end(), results); - - return filter_ids.size(); + + return Option(true); } diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp new file mode 100644 index 00000000..21a4128f --- /dev/null +++ b/src/filter_result_iterator.cpp @@ -0,0 +1,901 @@ +#include "filter_result_iterator.h" +#include "index.h" +#include "posting.h" +#include "collection_manager.h" + +void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { + auto lenA = a.count, lenB = b.count; + if (lenA == 0 || lenB == 0) { + return; + } + + result.docs = new uint32_t[std::min(lenA, lenB)]; + + auto A = a.docs, B = b.docs, out = result.docs; + const uint32_t *endA = A + lenA; + const uint32_t *endB = B + lenB; + + // Add an entry of references in the result for each unique collection in a and b. + for (auto const& item: a.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; + } + } + for (auto const& item: b.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; + } + } + + while (true) { + while (*A < *B) { + SKIP_FIRST_COMPARE: + if (++A == endA) { + result.count = out - result.docs; + return; + } + } + while (*A > *B) { + if (++B == endB) { + result.count = out - result.docs; + return; + } + } + if (*A == *B) { + *out = *A; + + // Copy the references of the document from every collection into result. 
+ for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][out - result.docs] = item.second[A - a.docs]; + } + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][out - result.docs] = item.second[B - b.docs]; + } + + out++; + + if (++A == endA || ++B == endB) { + result.count = out - result.docs; + return; + } + } else { + goto SKIP_FIRST_COMPARE; + } + } +} + +void filter_result_t::or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { + if (a.count == 0 && b.count == 0) { + return; + } + + // If either one of a or b does not have any matches, copy other into result. + if (a.count == 0) { + result = b; + return; + } + if (b.count == 0) { + result = a; + return; + } + + size_t indexA = 0, indexB = 0, res_index = 0, lenA = a.count, lenB = b.count; + result.docs = new uint32_t[lenA + lenB]; + + // Add an entry of references in the result for each unique collection in a and b. + for (auto const& item: a.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; + } + } + for (auto const& item: b.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; + } + } + + while (indexA < lenA && indexB < lenB) { + if (a.docs[indexA] < b.docs[indexB]) { + // check for duplicate + if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { + result.docs[res_index] = a.docs[indexA]; + res_index++; + } + + // Copy references of the last result document from every collection in a. 
+ for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; + } + + indexA++; + } else { + if (res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { + result.docs[res_index] = b.docs[indexB]; + res_index++; + } + + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; + } + + indexB++; + } + } + + while (indexA < lenA) { + if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { + result.docs[res_index] = a.docs[indexA]; + res_index++; + } + + for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; + } + + indexA++; + } + + while (indexB < lenB) { + if(res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { + result.docs[res_index] = b.docs[indexB]; + res_index++; + } + + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; + } + + indexB++; + } + + result.count = res_index; + + // shrink fit + auto out = new uint32_t[res_index]; + memcpy(out, result.docs, res_index * sizeof(uint32_t)); + delete[] result.docs; + result.docs = out; + + for (auto &item: result.reference_filter_results) { + auto out_references = new reference_filter_result_t[res_index]; + + for (uint32_t i = 0; i < result.count; i++) { + out_references[i] = item.second[i]; + } + delete[] item.second; + item.second = out_references; + } +} + +void filter_result_iterator_t::and_filter_iterators() { + while (left_it->is_valid && right_it->is_valid) { + while (left_it->seq_id < right_it->seq_id) { + left_it->next(); + if (!left_it->is_valid) { + is_valid = false; + return; + } + } + + while (left_it->seq_id > right_it->seq_id) { + right_it->next(); + if (!right_it->is_valid) { + is_valid = false; + return; + } + } + + if (left_it->seq_id == 
right_it->seq_id) { + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + } + + is_valid = false; +} + +void filter_result_iterator_t::or_filter_iterators() { + if (left_it->is_valid && right_it->is_valid) { + if (left_it->seq_id < right_it->seq_id) { + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->seq_id > right_it->seq_id) { + seq_id = right_it->seq_id; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->is_valid) { + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (right_it->is_valid) { + seq_id = right_it->seq_id; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + is_valid = false; +} + +void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { + // If none of the filter value iterators are valid, mark this node as invalid. + bool one_is_valid = false; + + // Since we do OR between filter values, the lowest seq_id id from all is selected. 
+ uint32_t lowest_id = UINT32_MAX; + + if (filter_node->filter_exp.comparators[0] == EQUALS || filter_node->filter_exp.comparators[0] == NOT_EQUALS) { + for (auto& filter_value_tokens : posting_list_iterators) { + bool tokens_iter_is_valid, exact_match = false; + while(true) { + // Perform AND between tokens of a filter value. + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + if (!tokens_iter_is_valid) { + break; + } + + if (posting_list_t::has_exact_match(filter_value_tokens, field_is_array)) { + exact_match = true; + break; + } else { + // Keep advancing token iterators till exact match is not found. + for (auto &item: filter_value_tokens) { + item.next(); + } + } + } + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && exact_match && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + } else { + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + } + + if (one_is_valid) { + seq_id = lowest_id; + } + + is_valid = one_is_valid; +} + +void filter_result_iterator_t::next() { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Advance the subtrees and then apply operators to arrive at the next valid doc. 
+ if (filter_node->filter_operator == AND) { + left_it->next(); + right_it->next(); + and_filter_iterators(); + } else { + if (left_it->seq_id == seq_id && right_it->seq_id == seq_id) { + left_it->next(); + right_it->next(); + } else if (left_it->seq_id == seq_id) { + left_it->next(); + } else { + right_it->next(); + } + + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Advance all the filter values that are at doc. Then find the next doc. 
+ for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { + for (auto& iter: filter_value_tokens) { + iter.next(); + } + } + } + + doc_matching_string_filter(f.is_array()); + return; + } +} + +void filter_result_iterator_t::init() { + if (filter_node == nullptr) { + return; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(a_filter.referenced_collection_name); + if (collection == nullptr) { + status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + is_valid = false; + return; + } + + auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, + filter_result, + collection_name); + if (!reference_filter_op.ok()) { + status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name + + "` collection: " + reference_filter_op.error()); + is_valid = false; + return; + } + + is_valid = filter_result.count > 0; + return; + } + + if (a_filter.field_name == "id") { + if (a_filter.values.empty()) { + is_valid = false; + return; + } + + // we handle `ids` separately + std::vector result_ids; + for (const auto& id_str : a_filter.values) { + result_ids.push_back(std::stoul(id_str)); + } + + std::sort(result_ids.begin(), result_ids.end()); + + filter_result.count = result_ids.size(); + filter_result.docs = new uint32_t[result_ids.size()]; + 
std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + art_tree* t = index->search_index.at(a_filter.field_name); + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + std::vector plists; + posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + + posting_list_iterators.emplace_back(std::vector()); + + for (auto const& plist: plists) { + posting_list_iterators.back().push_back(plist->new_iterator()); + } + } + + doc_matching_string_filter(f.is_array()); + return; + } +} + +bool filter_result_iterator_t::valid() { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + is_valid = left_it->valid() && right_it->valid(); + return is_valid; + } else { + is_valid = left_it->valid() || right_it->valid(); + return is_valid; + } + } + + const filter a_filter = filter_node->filter_exp; + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + is_valid = result_index < filter_result.count; + return is_valid; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + 
is_valid = false; + return is_valid; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + bool one_is_valid = false; + for (auto& filter_value_tokens: posting_list_iterators) { + posting_list_t::intersect(filter_value_tokens, one_is_valid); + + if (one_is_valid) { + break; + } + } + + is_valid = one_is_valid; + return is_valid; + } + + return false; +} + +void filter_result_iterator_t::skip_to(uint32_t id) { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + left_it->skip_to(id); + right_it->skip_to(id); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Skip all the token iterators and find a new match. + for (auto& filter_value_tokens : posting_list_iterators) { + for (auto& token: filter_value_tokens) { + // We perform AND on tokens. Short-circuiting here. 
+ if (!token.valid()) { + break; + } + + token.skip_to(id); + } + } + + doc_matching_string_filter(f.is_array()); + return; + } +} + +int filter_result_iterator_t::valid(uint32_t id) { + if (!is_valid) { + return -1; + } + + if (filter_node->isOperator) { + auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); + + if (filter_node->filter_operator == AND) { + is_valid = left_it->is_valid && right_it->is_valid; + + if (left_valid < 1 || right_valid < 1) { + if (left_valid == -1 || right_valid == -1) { + return -1; + } + + return 0; + } + + return 1; + } else { + is_valid = left_it->is_valid || right_it->is_valid; + + if (left_valid < 1 && right_valid < 1) { + if (left_valid == -1 && right_valid == -1) { + return -1; + } + + return 0; + } + + return 1; + } + } + + if (filter_node->filter_exp.apply_not_equals) { + // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. + if (!valid()) { + is_valid = true; + return 1; + } + + skip_to(id); + + if (!is_valid) { + is_valid = true; + return 1; + } + + return seq_id != id ? 1 : 0; + } + + skip_to(id); + return is_valid ? (seq_id == id ? 1 : 0) : -1; +} + +Option filter_result_iterator_t::init_status() { + if (filter_node != nullptr && filter_node->isOperator) { + auto left_status = left_it->init_status(); + + return !left_status.ok() ? 
left_status : right_it->init_status(); + } + + return status; +} + +bool filter_result_iterator_t::contains_atleast_one(const void *obj) { + if(IS_COMPACT_POSTING(obj)) { + compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); + + size_t i = 0; + while(i < list->length && valid()) { + size_t num_existing_offsets = list->id_offsets[i]; + size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; + + if (existing_id == seq_id) { + return true; + } + + // advance smallest value + if (existing_id < seq_id) { + i += num_existing_offsets + 2; + } else { + skip_to(existing_id); + } + } + } else { + auto list = (posting_list_t*)(obj); + posting_list_t::iterator_t it = list->new_iterator(); + + while(it.valid() && valid()) { + uint32_t id = it.id(); + + if(id == seq_id) { + return true; + } + + if(id < seq_id) { + it.skip_to(seq_id); + } else { + skip_to(id); + } + } + } + + return false; +} + +void filter_result_iterator_t::reset() { + if (filter_node == nullptr) { + return; + } + + if (filter_node->isOperator) { + // Reset the subtrees then apply operators to arrive at the first valid doc. 
+ left_it->reset(); + right_it->reset(); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter || a_filter.field_name == "id") { + result_index = 0; + is_valid = filter_result.count > 0; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + posting_list_iterators.clear(); + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + expanded_plists.clear(); + + init(); + return; + } +} + +uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + do { + filter_ids.push_back(seq_id); + next(); + } while (valid()); + + filter_array = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), filter_array); + + return filter_ids.size(); +} + +uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + for (uint32_t i = 0; i < lenA; i++) { + auto result = valid(A[i]); + + if (result == -1) { + break; + } + + if (result == 1) { + filter_ids.push_back(A[i]); + } + } + + if (filter_ids.empty()) { + return 0; + } + + results = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), results); + + return filter_ids.size(); +} + +filter_result_iterator_t::filter_result_iterator_t(const std::string collection_name, const Index *const index, + const filter_node_t *const filter_node) : + collection_name(collection_name), + index(index), + filter_node(filter_node) { + if (filter_node == nullptr) { + is_valid = false; + return; + } + + // Generate the iterator tree and then 
initialize each node. + if (filter_node->isOperator) { + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right); + } + + init(); +} + +filter_result_iterator_t::~filter_result_iterator_t() { + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; +} + +filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iterator_t &&obj) noexcept { + if (&obj == this) + return *this; + + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + + collection_name = obj.collection_name; + index = obj.index; + filter_node = obj.filter_node; + left_it = obj.left_it; + right_it = obj.right_it; + + obj.left_it = nullptr; + obj.right_it = nullptr; + + result_index = obj.result_index; + + filter_result = std::move(obj.filter_result); + + posting_list_iterators = std::move(obj.posting_list_iterators); + expanded_plists = std::move(obj.expanded_plists); + + is_valid = obj.is_valid; + + seq_id = obj.seq_id; + reference = std::move(obj.reference); + status = std::move(obj.status); + + return *this; +} diff --git a/src/index.cpp b/src/index.cpp index b337bb4b..fb01ee09 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -5,10 +5,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include @@ -24,6 +24,7 @@ #include #include "logger.h" #include +#include "validator.h" #define RETURN_CIRCUIT_BREAKER if((std::chrono::duration_cast( \ std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { \ @@ -1262,7 +1263,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { void Index::search_all_candidates(const size_t num_search_fields, const 
text_match_type_t match_type, const std::vector& the_fields, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -1331,7 +1332,8 @@ void Index::search_all_candidates(const size_t num_search_fields, sort_fields, topster,groups_processed, searched_queries, qtoken_set, dropped_tokens, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, - filter_ids, filter_ids_length, total_cost, syn_orig_num_tokens, + filter_result_iterator, + total_cost, syn_orig_num_tokens, exclude_token_ids, exclude_token_ids_size, excluded_group_ids, sort_order, field_values, geopoint_indices, id_buff, all_result_ids, all_result_ids_len); @@ -2705,19 +2707,19 @@ Option Index::search(std::vector& field_query_tokens, cons const std::string& collection_name) const { std::shared_lock lock(mutex); - uint32_t filter_ids_length = 0; - auto rearrange_op = rearrange_filter_tree(filter_tree_root, filter_ids_length, collection_name); + uint32_t approx_filter_ids_length = 0; + auto rearrange_op = rearrange_filter_tree(filter_tree_root, approx_filter_ids_length, collection_name); if (!rearrange_op.ok()) { return rearrange_op; } - filter_result_t filter_result; - auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } - if (filter_tree_root != nullptr && filter_result.count == 0) { + if (filter_tree_root != nullptr && !filter_result_iterator.valid()) { return Option(true); } @@ -2729,8 +2731,9 @@ Option Index::search(std::vector& field_query_tokens, cons std::unordered_set excluded_group_ids; 
process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits, - filter_result.docs, filter_result.count, curated_ids, included_ids_map, + filter_result_iterator, curated_ids, included_ids_map, included_ids_vec, excluded_group_ids); + filter_result_iterator.reset(); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2761,7 +2764,10 @@ Option Index::search(std::vector& field_query_tokens, cons field_query_tokens[0].q_include_tokens[0].value == "*"; + // TODO: Do AND with phrase ids at last // handle phrase searches + uint32_t* phrase_result_ids = nullptr; + uint32_t phrase_result_count = 0; if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, @@ -2769,8 +2775,8 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, groups_processed, curated_ids, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, included_ids_map, is_wildcard_query, - filter_result.docs, filter_result.count); - if (filter_result.count == 0) { + phrase_result_ids, phrase_result_count); + if (phrase_result_count == 0) { goto process_search_results; } } @@ -2778,7 +2784,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && filter_result.count == 0); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.valid()); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == 
sort_field_const::seq_id && @@ -2820,15 +2826,18 @@ Option Index::search(std::vector& field_query_tokens, cons goto process_search_results; } - // if filters were not provided, use the seq_ids index to generate the - // list of all document ids + // if filters were not provided, use the seq_ids index to generate the list of all document ids if (no_filters_provided) { - filter_result.count = seq_ids->num_ids(); - filter_result.docs = seq_ids->uncompress(); + const std::string doc_id_prefix = std::to_string(collection_id) + "_" + Collection::DOC_ID_PREFIX + "_"; + Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, + store, doc_id_prefix, filter_tree_root); + + filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); } - curate_filtered_ids(curated_ids, excluded_result_ids, - excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); +// TODO: Curate ids at last +// curate_filtered_ids(curated_ids, excluded_result_ids, +// excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { @@ -2838,37 +2847,46 @@ Option Index::search(std::vector& field_query_tokens, cons k++; } - VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; - if(!no_filters_provided && filter_result.count < vector_query.flat_search_cutoff) { - for(size_t i = 0; i < filter_result.count; i++) { - auto seq_id = filter_result.docs[i]; - std::vector values; + uint32_t filter_id_count = 0; + while (!no_filters_provided && + filter_id_count < vector_query.flat_search_cutoff && + filter_result_iterator.valid()) { + auto seq_id = filter_result_iterator.seq_id; + std::vector values; - try { - values = field_vector_index->vecdex->getDataByLabel(seq_id); - } 
catch(...) { - // likely not found - continue; - } - - float dist; - if(field_vector_index->distance_type == cosine) { - std::vector normalized_q(vector_query.values.size()); - hnsw_index_t::normalize_vector(vector_query.values, normalized_q); - dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(), - &field_vector_index->num_dim); - } else { - dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(), - &field_vector_index->num_dim); - } - - dist_labels.emplace_back(dist, seq_id); + try { + values = field_vector_index->vecdex->getDataByLabel(seq_id); + } catch(...) { + // likely not found + continue; } - } else { + + float dist; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_q(vector_query.values.size()); + hnsw_index_t::normalize_vector(vector_query.values, normalized_q); + dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(), + &field_vector_index->num_dim); + } else { + dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(), + &field_vector_index->num_dim); + } + + dist_labels.emplace_back(dist, seq_id); + filter_result_iterator.next(); + filter_id_count++; + } + + if(no_filters_provided || + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.valid())) { + dist_labels.clear(); + + VectorFilterFunctor filterFunctor(&filter_result_iterator); + if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); hnsw_index_t::normalize_vector(vector_query.values, normalized_q); @@ -2878,6 +2896,8 @@ Option Index::search(std::vector& field_query_tokens, cons } } + filter_result_iterator.reset(); + std::vector nearest_ids; for (const auto& dist_label : dist_labels) { @@ -2929,7 +2949,7 @@ Option Index::search(std::vector& field_query_tokens, cons curated_ids, curated_ids_sorted, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, 
all_result_ids, all_result_ids_len, - filter_result.docs, filter_result.count, concurrency, + filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); } } else { @@ -2971,7 +2991,7 @@ Option Index::search(std::vector& field_query_tokens, cons } fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, {}, match_type, excluded_result_ids, - excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, + excluded_result_ids_size, filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, @@ -2979,6 +2999,7 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -3009,12 +3030,13 @@ Option Index::search(std::vector& field_query_tokens, cons } fuzzy_search_fields(the_fields, resolved_tokens, {}, match_type, excluded_result_ids, - excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, + excluded_result_ids_size, filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); } } @@ -3027,9 +3049,10 @@ Option 
Index::search(std::vector& field_query_tokens, cons excluded_result_ids, excluded_result_ids_size, excluded_group_ids, topster, q_pos_synonyms, syn_orig_num_tokens, groups_processed, searched_queries, all_result_ids, all_result_ids_len, - filter_result.docs, filter_result.count, query_hashes, + filter_result_iterator, query_hashes, sort_order, field_values, geopoint_indices, qtoken_set); + filter_result_iterator.reset(); // gather up both original query and synonym queries and do drop tokens @@ -3078,7 +3101,7 @@ Option Index::search(std::vector& field_query_tokens, cons fuzzy_search_fields(the_fields, truncated_tokens, dropped_tokens, match_type, excluded_result_ids, excluded_result_ids_size, - filter_result.docs, filter_result.count, + filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, @@ -3086,6 +3109,7 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); } else { break; @@ -3098,10 +3122,12 @@ Option Index::search(std::vector& field_query_tokens, cons group_limit, group_by_fields, max_extra_prefix, max_extra_suffix, field_query_tokens[0].q_include_tokens, - topster, filter_result.docs, filter_result.count, + topster, filter_result_iterator, sort_order, field_values, geopoint_indices, curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed); + filter_result_iterator.reset(); + if(!vector_query.field_name.empty()) { // check at least one of sort fields is text match bool has_text_match = false; @@ -3117,7 +3143,7 @@ Option Index::search(std::vector& field_query_tokens, cons constexpr float TEXT_MATCH_WEIGHT = 0.7; constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - 
TEXT_MATCH_WEIGHT; - VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); + VectorFilterFunctor filterFunctor(&filter_result_iterator); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; auto k = std::max(vector_query.k, fetch_size); @@ -3129,6 +3155,7 @@ Option Index::search(std::vector& field_query_tokens, cons } else { dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } + filter_result_iterator.reset(); std::vector> vec_results; for (const auto& dist_label : dist_labels) { @@ -3338,7 +3365,7 @@ Option Index::search(std::vector& field_query_tokens, cons void Index::process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, - const bool filter_curated_hits, const uint32_t* filter_ids, uint32_t filter_ids_length, + const bool filter_curated_hits, filter_result_iterator_t& filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, @@ -3361,19 +3388,18 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_ids_length != 0 && filter_curated_hits) { - uint32_t* included_ids_arr = nullptr; - size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids, - filter_ids_length, &included_ids_arr); + if(filter_result_iterator.valid() && filter_curated_hits) { + for (const auto &included_id: included_ids_vec) { + auto result = filter_result_iterator.valid(included_id); - included_ids_vec.clear(); + if (result == -1) { + break; + } - for(size_t i = 0; i < included_ids_len; i++) { - included_ids_set.insert(included_ids_arr[i]); - included_ids_vec.push_back(included_ids_arr[i]); + if (result == 1) { + 
included_ids_set.insert(included_id); + } } - - delete [] included_ids_arr; } else { included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end()); } @@ -3431,7 +3457,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector & sort_fields, @@ -3574,9 +3600,10 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const auto& prev_token = last_token ? token_candidates_vec.back().candidates[0] : ""; std::vector field_leaves; - art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, + art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, - last_token, prev_token, filter_ids, filter_ids_length, field_leaves, unique_tokens); + last_token, prev_token, filter_result_iterator, field_leaves, unique_tokens); + filter_result_iterator.reset(); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3605,8 +3632,9 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, std::vector prev_token_doc_ids; find_across_fields(token_candidates_vec.back().token, token_candidates_vec.back().candidates[0], - the_fields, num_search_fields, filter_ids, filter_ids_length, exclude_token_ids, + the_fields, num_search_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, prev_token_doc_ids, popular_field_ids); + filter_result_iterator.reset(); for(size_t field_id: query_field_ids) { auto& the_field = the_fields[field_id]; @@ -3626,9 +3654,9 @@ void Index::fuzzy_search_fields(const std::vector& 
the_fields, } std::vector field_leaves; - art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, + art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, - false, "", filter_ids, filter_ids_length, field_leaves, unique_tokens); + false, "", filter_result_iterator, field_leaves, unique_tokens); if(field_leaves.empty()) { // look at the next field @@ -3681,7 +3709,8 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, if(token_candidates_vec.size() == query_tokens.size()) { std::vector id_buff; - search_all_candidates(num_search_fields, match_type, the_fields, filter_ids, filter_ids_length, + + search_all_candidates(num_search_fields, match_type, the_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, excluded_group_ids, sort_fields, token_candidates_vec, searched_queries, qtoken_set, dropped_tokens, topster, @@ -3691,6 +3720,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, exhaustive_search, max_candidates, syn_orig_num_tokens, sort_order, field_values, geopoint_indices, query_hashes, id_buff); + filter_result_iterator.reset(); if(id_buff.size() > 1) { gfx::timsort(id_buff.begin(), id_buff.end()); @@ -3751,7 +3781,7 @@ void Index::find_across_fields(const token_t& previous_token, const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, std::vector& top_prefix_field_ids) const { @@ -3762,7 +3792,7 @@ void Index::find_across_fields(const token_t& previous_token, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t 
istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); const bool prefix_search = previous_token.is_prefix_searched; const uint32_t token_num_typos = previous_token.num_typos; @@ -3843,7 +3873,7 @@ void Index::search_across_fields(const std::vector& query_tokens, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool prioritize_token_position, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, @@ -3904,7 +3934,7 @@ void Index::search_across_fields(const std::vector& query_tokens, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); // for each token, find the posting lists across all query_by fields for(size_t ti = 0; ti < query_tokens.size(); ti++) { @@ -4374,13 +4404,10 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector>& included_ids_map, bool is_wildcard_query, - uint32_t*& filter_ids, uint32_t& filter_ids_length) const { + uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const { std::map phrase_match_id_scores; - uint32_t* phrase_match_ids = nullptr; - size_t phrase_match_ids_size = 0; - for(size_t i = 0; i < num_search_fields; i++) { const std::string& field_name = search_fields[i].name; const size_t field_weight = search_fields[i].weight; @@ -4449,50 +4476,30 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector(10000, filter_ids_length); i++) { - auto seq_id = filter_ids[i]; + 
for(size_t i = 0; i < std::min(10000, phrase_result_count); i++) { + auto seq_id = phrase_result_ids[i]; int64_t match_score = phrase_match_id_scores[seq_id]; int64_t scores[3] = {0}; @@ -4554,7 +4561,7 @@ void Index::do_synonym_search(const std::vector& the_fields, spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - const uint32_t* filter_ids, const uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -4564,7 +4571,7 @@ void Index::do_synonym_search(const std::vector& the_fields, for (const auto& syn_tokens : q_pos_synonyms) { query_hashes.clear(); fuzzy_search_fields(the_fields, syn_tokens, {}, match_type, exclude_token_ids, - exclude_token_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, excluded_group_ids, + exclude_token_ids_size, filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, {0}, searched_queries, qtoken_set, actual_topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, @@ -4582,7 +4589,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - const uint32_t *filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -4616,10 +4623,11 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& 
excluded_group_ids, - uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, - uint32_t filter_ids_length, const size_t concurrency, + uint32_t*& all_result_ids, size_t& all_result_ids_len, + filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices) const { uint32_t token_bits = 0; - const bool check_for_circuit_break = (filter_ids_length > 1000000); + const bool check_for_circuit_break = (approx_filter_ids_length > 1000000); //auto beginF = std::chrono::high_resolution_clock::now(); - const size_t num_threads = std::min(concurrency, filter_ids_length); + const size_t num_threads = std::min(concurrency, approx_filter_ids_length); const size_t window_size = (num_threads == 0) ? 0 : - (filter_ids_length + num_threads - 1) / num_threads; // rounds up + (approx_filter_ids_length + num_threads - 1) / num_threads; // rounds up spp::sparse_hash_map tgroups_processed[num_threads]; Topster* topsters[num_threads]; @@ -4942,14 +4951,15 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - for(size_t thread_id = 0; thread_id < num_threads && filter_index < filter_ids_length; thread_id++) { - size_t batch_res_len = window_size; + for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.valid(); thread_id++) { + std::vector batch_result_ids; + batch_result_ids.reserve(window_size); - if(filter_index + window_size > filter_ids_length) { - batch_res_len = filter_ids_length - filter_index; - } + do { + batch_result_ids.push_back(filter_result_iterator.seq_id); + filter_result_iterator.next(); + } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); - const uint32_t* batch_result_ids = filter_ids + filter_index; num_queued++; 
searched_queries.push_back({}); @@ -4961,7 +4971,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, &group_limit, &group_by_fields, &topsters, &tgroups_processed, &excluded_group_ids, &sort_order, field_values, &geopoint_indices, &plists, check_for_circuit_break, - batch_result_ids, batch_res_len, + batch_result_ids, &num_processed, &m_process, &cv_process]() { search_begin_us = parent_search_begin; @@ -4970,7 +4980,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, size_t filter_index = 0; - for(size_t i = 0; i < batch_res_len; i++) { + for(size_t i = 0; i < batch_result_ids.size(); i++) { const uint32_t seq_id = batch_result_ids[i]; int64_t match_score = 0; @@ -5010,7 +5020,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, cv_process.notify_one(); }); - filter_index += batch_res_len; + filter_index += batch_result_ids.size(); } std::unique_lock lock_process(m_process); @@ -5022,7 +5032,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, //groups_processed.insert(tgroups_processed[thread_id].begin(), tgroups_processed[thread_id].end()); for(const auto& it : tgroups_processed[thread_id]) { groups_processed[it.first]+= it.second; - } + } aggregate_topster(topster, topsters[thread_id]); delete topsters[thread_id]; } @@ -5031,11 +5041,13 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - uint32_t* new_all_result_ids = nullptr; - all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids, - filter_ids_length, &new_all_result_ids); - delete [] all_result_ids; - all_result_ids = new_all_result_ids; +// TODO: OR filter ids with all_results_ids +// +// uint32_t* new_all_result_ids = nullptr; +// all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids, +// 
filter_ids_length, &new_all_result_ids); +// delete [] all_result_ids; +// all_result_ids = new_all_result_ids; } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index 13165370..df4404cb 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -1,5 +1,5 @@ #include "or_iterator.h" - +#include "filter.h" bool or_iterator_t::at_end(const std::vector& its) { // if any iterator is invalid, we stop @@ -208,6 +208,10 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e return false; } + if (istate.fit != nullptr) { + return (istate.fit->valid(id) == 1); + } + return true; } diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 39f3ac00..e983e98b 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -1260,6 +1260,146 @@ void posting_list_t::get_exact_matches(std::vector& its, const bool num_exact_ids = exact_id_index; } +bool posting_list_t::has_exact_match(std::vector& posting_list_iterators, + const bool field_is_array) { + if(posting_list_iterators.size() == 1) { + return is_single_token_verbatim_match(posting_list_iterators[0], field_is_array); + } else { + + if (!field_is_array) { + for (uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + posting_list_t::iterator_t& it = posting_list_iterators[i]; + + block_t* curr_block = it.block(); + uint32_t curr_index = it.index(); + + if(curr_block == nullptr || curr_index == UINT32_MAX) { + return false; + } + + uint32_t* offsets = it.offsets; + + uint32_t start_offset_index = it.offset_index[curr_index]; + uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ? 
+ curr_block->offsets.getLength() : + it.offset_index[curr_index + 1]; + + if(i == posting_list_iterators.size() - 1) { + // check if the last query token is the last offset + if( offsets[end_offset_index-1] != 0 || + (end_offset_index-2 >= 0 && offsets[end_offset_index-2] != posting_list_iterators.size())) { + // not the last token for the document, so skip + return false; + } + } + + // looping handles duplicate query tokens, e.g. "hip hip hurray hurray" + while(start_offset_index < end_offset_index) { + uint32_t offset = offsets[start_offset_index]; + start_offset_index++; + + if(offset == (i + 1)) { + // we have found a matching index, no need to look further + return true; + } + + if(offset > (i + 1)) { + return false; + } + } + } + } + + else { + // field is an array + + struct token_index_meta_t { + std::bitset<32> token_index; + bool has_last_token; + }; + + std::map array_index_to_token_index; + + for(uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + posting_list_t::iterator_t& it = posting_list_iterators[i]; + + block_t* curr_block = it.block(); + uint32_t curr_index = it.index(); + + if(curr_block == nullptr || curr_index == UINT32_MAX) { + return false; + } + + uint32_t* offsets = it.offsets; + uint32_t start_offset_index = it.offset_index[curr_index]; + uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ? 
+ curr_block->offsets.getLength() : + it.offset_index[curr_index + 1]; + + int prev_pos = -1; + bool has_atleast_one_last_token = false; + bool found_matching_index = false; + size_t num_matching_index = 0; + + while(start_offset_index < end_offset_index) { + int pos = offsets[start_offset_index]; + start_offset_index++; + + if(pos == prev_pos) { // indicates end of array index + size_t array_index = (size_t) offsets[start_offset_index]; + + if(start_offset_index+1 < end_offset_index) { + size_t next_offset = (size_t) offsets[start_offset_index + 1]; + if(next_offset == 0 && pos == posting_list_iterators.size()) { + // indicates that token is the last token on the doc + array_index_to_token_index[array_index].has_last_token = true; + has_atleast_one_last_token = true; + start_offset_index++; + } + } + + if(found_matching_index) { + array_index_to_token_index[array_index].token_index.set(i + 1); + } + + start_offset_index++; // skip current value which is the array index or flag for last index + prev_pos = -1; + found_matching_index = false; + continue; + } + + if(pos == (i + 1)) { + // we have found a matching index + found_matching_index = true; + num_matching_index++; + } + + prev_pos = pos; + } + + // check if the last query token is the last offset of ANY array element + if(i == posting_list_iterators.size() - 1 && !has_atleast_one_last_token) { + return false; + } + + if(num_matching_index == 0) { + // not even a single matching index found: can never be an exact match + return false; + } + } + + // iterate array index to token index to check if atleast 1 array position contains all tokens + for(auto& kv: array_index_to_token_index) { + if(kv.second.token_index.count() == posting_list_iterators.size() && kv.second.has_last_token) { + return true; + } + } + } + } + + return false; +} + bool posting_list_t::found_token_sequence(const std::vector& token_positions, const size_t token_index, const uint16_t target_pos) { diff --git a/src/validator.cpp 
b/src/validator.cpp index 844ef3bf..aab4a9ad 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -1,4 +1,5 @@ #include "validator.h" +#include "field.h" Option validator_t::coerce_element(const field& a_field, nlohmann::json& document, nlohmann::json& doc_ele, diff --git a/test/collection_all_fields_test.cpp b/test/collection_all_fields_test.cpp index 3acc794c..4e9a9a22 100644 --- a/test/collection_all_fields_test.cpp +++ b/test/collection_all_fields_test.cpp @@ -64,7 +64,10 @@ TEST_F(CollectionAllFieldsTest, IndexDocsWithoutSchema) { nlohmann::json document = nlohmann::json::parse(json_line); Option add_op = coll1->add(document.dump()); - LOG(INFO) << "Add op: " << add_op.error(); + if (!add_op.ok()) { + LOG(INFO) << "Add op: " << add_op.error(); + } + ASSERT_TRUE(add_op.ok()); } diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 43572401..ab0e951e 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "collection.h" #include "text_embedder_manager.h" #include "http_client.h" diff --git a/test/filter_test.cpp b/test/filter_test.cpp index b4f836f1..d125e814 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -125,15 +125,25 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); - ASSERT_TRUE(iter_exact_match_test.init_status().ok()); + auto iter_exact_match_1_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_1_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_exact_match_test.valid()); - ASSERT_EQ(i, iter_exact_match_test.seq_id); - iter_exact_match_test.next(); + ASSERT_TRUE(iter_exact_match_1_test.valid()); + ASSERT_EQ(i, iter_exact_match_1_test.seq_id); + iter_exact_match_1_test.next(); } - 
ASSERT_FALSE(iter_exact_match_test.valid()); + ASSERT_FALSE(iter_exact_match_1_test.valid()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= PLATINUM", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_2_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_2_test.init_status().ok()); + ASSERT_FALSE(iter_exact_match_2_test.valid()); delete filter_tree_root; filter_tree_root = nullptr; From 97316b8b4248e8cc41058d26c1597ddf15f06010 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 10 Apr 2023 19:13:45 +0530 Subject: [PATCH 18/64] Fix `FacetFieldStringFiltering` test. --- src/index.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index fb01ee09..3b2dc166 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4960,6 +4960,12 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, filter_result_iterator.next(); } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); + uint32_t* new_all_result_ids = nullptr; + all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, batch_result_ids.data(), + batch_result_ids.size(), &new_all_result_ids); + delete [] all_result_ids; + all_result_ids = new_all_result_ids; + num_queued++; searched_queries.push_back({}); @@ -5040,14 +5046,6 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, /*long long int timeMillisF = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - -// TODO: OR filter ids with all_results_ids -// -// uint32_t* new_all_result_ids = nullptr; -// all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids, -// filter_ids_length, 
&new_all_result_ids); -// delete [] all_result_ids; -// all_result_ids = new_all_result_ids; } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, From 37fcb53f6f27d7d006bdc8015f3003911b433091 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 11 Apr 2023 10:50:39 +0530 Subject: [PATCH 19/64] Fix memory leak. --- src/art.cpp | 5 +---- src/posting_list.cpp | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/art.cpp b/src/art.cpp index 48470551..d189d7eb 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1724,10 +1724,7 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le << ", filter_ids_length: " << filter_ids_length; }*/ -// TODO: Figure out this edge case. -// if(allowed_doc_ids != filter_ids) { -// delete [] allowed_doc_ids; -// } + delete [] allowed_doc_ids; return 0; } diff --git a/src/posting_list.cpp b/src/posting_list.cpp index e983e98b..1b5a3323 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -1267,7 +1267,7 @@ bool posting_list_t::has_exact_match(std::vector& po } else { if (!field_is_array) { - for (uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + for (int i = posting_list_iterators.size() - 1; i >= 0; i--) { posting_list_t::iterator_t& it = posting_list_iterators[i]; block_t* curr_block = it.block(); @@ -1320,7 +1320,7 @@ bool posting_list_t::has_exact_match(std::vector& po std::map array_index_to_token_index; - for(uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + for(int i = posting_list_iterators.size() - 1; i >= 0; i--) { posting_list_t::iterator_t& it = posting_list_iterators[i]; block_t* curr_block = it.block(); From 889c1759e283b6cb675c5613f20dc214b682e9bb Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 11 Apr 2023 11:27:31 +0530 Subject: [PATCH 20/64] Refactor `Index::search_wildcard`. 
--- src/index.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 3b2dc166..29938911 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4960,12 +4960,6 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, filter_result_iterator.next(); } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); - uint32_t* new_all_result_ids = nullptr; - all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, batch_result_ids.data(), - batch_result_ids.size(), &new_all_result_ids); - delete [] all_result_ids; - all_result_ids = new_all_result_ids; - num_queued++; searched_queries.push_back({}); @@ -5046,6 +5040,9 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, /*long long int timeMillisF = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ + + filter_result_iterator.reset(); + all_result_ids_len = filter_result_iterator.to_filter_id_array(all_result_ids); } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, From d796391464cc6e044c91cff1f9209220364370f6 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 11 Apr 2023 17:44:32 +0530 Subject: [PATCH 21/64] Add `approx_filter_ids_length` field. 
--- include/filter_result_iterator.h | 10 ++++++++-- src/filter_result_iterator.cpp | 6 ++++-- src/index.cpp | 3 ++- src/or_iterator.cpp | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index b67d67ca..f5ae22a6 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -125,12 +125,18 @@ private: public: uint32_t seq_id = 0; - // Collection name -> references + /// Collection name -> references std::map reference; Option status = Option(true); + /// Holds the upper-bound of the number of seq ids this iterator would match. + /// Useful in a scenario where we need to differentiate between filter iterator not matching any document v/s filter + /// iterator reaching it's end. (is_valid would be false in both these cases) + uint32_t approx_filter_ids_length; + explicit filter_result_iterator_t(const std::string collection_name, - Index const* const index, filter_node_t const* const filter_node); + Index const* const index, filter_node_t const* const filter_node, + uint32_t approx_filter_ids_length = UINT32_MAX); ~filter_result_iterator_t(); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 21a4128f..f5bd8b5f 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -835,10 +835,12 @@ uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& } filter_result_iterator_t::filter_result_iterator_t(const std::string collection_name, const Index *const index, - const filter_node_t *const filter_node) : + const filter_node_t *const filter_node, + uint32_t approx_filter_ids_length) : collection_name(collection_name), index(index), - filter_node(filter_node) { + filter_node(filter_node), + approx_filter_ids_length(approx_filter_ids_length) { if (filter_node == nullptr) { is_valid = false; return; diff --git a/src/index.cpp b/src/index.cpp index 29938911..ab8f8576 100644 --- 
a/src/index.cpp +++ b/src/index.cpp @@ -2713,7 +2713,8 @@ Option Index::search(std::vector& field_query_tokens, cons return rearrange_op; } - auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root, + approx_filter_ids_length); auto filter_init_op = filter_result_iterator.init_status(); if (!filter_init_op.ok()) { return filter_init_op; diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index df4404cb..8dd9d487 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -208,7 +208,7 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e return false; } - if (istate.fit != nullptr) { + if (istate.fit != nullptr && istate.fit->approx_filter_ids_length > 0) { return (istate.fit->valid(id) == 1); } From a9ca96a63e6aed83aa8a7750bdcf4e199d274a73 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 12 Apr 2023 18:47:55 +0530 Subject: [PATCH 22/64] Handle `!=` in `filter_result_iterator_t`. --- include/filter_result_iterator.h | 3 + include/id_list.h | 2 + src/filter_result_iterator.cpp | 145 +++++++++++++++++++++++++------ src/id_list.cpp | 8 ++ test/filter_test.cpp | 40 ++++----- 5 files changed, 153 insertions(+), 45 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index f5ae22a6..46ec11bc 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -120,6 +120,9 @@ private: /// Performs OR on the subtrees of operator. void or_filter_iterators(); + /// Advance all the token iterators that are at seq_id. + void advance_string_filter_token_iterators(); + /// Finds the next match for a filter on string field. 
void doc_matching_string_filter(bool field_is_array); diff --git a/include/id_list.h b/include/id_list.h index 3b0ef7ae..ad890119 100644 --- a/include/id_list.h +++ b/include/id_list.h @@ -126,6 +126,8 @@ public: uint32_t first_id(); + uint32_t last_id(); + block_t* block_of(uint32_t id); bool contains(uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f5bd8b5f..f7217c0a 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -265,6 +265,18 @@ void filter_result_iterator_t::or_filter_iterators() { is_valid = false; } +void filter_result_iterator_t::advance_string_filter_token_iterators() { + for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { + for (auto& iter: filter_value_tokens) { + iter.next(); + } + } + } +} + void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { // If none of the filter value iterators are valid, mark this node as invalid. bool one_is_valid = false; @@ -384,18 +396,38 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { - // Advance all the filter values that are at doc. Then find the next doc. 
- for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { - auto& filter_value_tokens = posting_list_iterators[i]; - - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { - for (auto& iter: filter_value_tokens) { - iter.next(); - } + if (filter_node->filter_exp.apply_not_equals) { + if (++seq_id < result_index) { + return; } + + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id); + + if (!is_valid) { + // We've reached the end of the index, no possible matches pending. + if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + result_index = index->seq_ids->last_id() + 1; + seq_id = previous_match + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + return; } + advance_string_filter_token_iterators(); doc_matching_string_filter(f.is_array()); + return; } } @@ -509,6 +541,47 @@ void filter_result_iterator_t::init() { } doc_matching_string_filter(f.is_array()); + + if (filter_node->filter_exp.apply_not_equals) { + // filter didn't match any id. So by applying not equals, every id in the index is a match. + if (!is_valid) { + is_valid = true; + seq_id = 0; + result_index = index->seq_ids->last_id() + 1; + return; + } + + // [0, seq_id) are a match for not equals. + if (seq_id > 0) { + result_index = seq_id; + seq_id = 0; + return; + } + + // Keep ignoring the consecutive matches. + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id); + + if (!is_valid) { + // filter matched all the ids in the index. So for not equals, there's no match. 
+ if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + result_index = index->seq_ids->last_id() + 1; + seq_id = previous_match + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + } + return; } } @@ -543,6 +616,10 @@ bool filter_result_iterator_t::valid() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { + if (filter_node->filter_exp.apply_not_equals) { + return seq_id < result_index; + } + bool one_is_valid = false; for (auto& filter_value_tokens: posting_list_iterators) { posting_list_t::intersect(filter_value_tokens, one_is_valid); @@ -618,6 +695,41 @@ void filter_result_iterator_t::skip_to(uint32_t id) { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { + if (filter_node->filter_exp.apply_not_equals) { + if (id < seq_id) { + return; + } + + if (id < result_index) { + seq_id = id; + return; + } + + seq_id = result_index; + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id && seq_id >= id); + + if (!is_valid) { + // filter matched all the ids in the index. So for not equals, there's no match. + if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + seq_id = previous_match + 1; + result_index = index->seq_ids->last_id() + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + return; + } + // Skip all the token iterators and find a new match. for (auto& filter_value_tokens : posting_list_iterators) { for (auto& token: filter_value_tokens) { @@ -670,23 +782,6 @@ int filter_result_iterator_t::valid(uint32_t id) { } } - if (filter_node->filter_exp.apply_not_equals) { - // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. 
- if (!valid()) { - is_valid = true; - return 1; - } - - skip_to(id); - - if (!is_valid) { - is_valid = true; - return 1; - } - - return seq_id != id ? 1 : 0; - } - skip_to(id); return is_valid ? (seq_id == id ? 1 : 0) : -1; } diff --git a/src/id_list.cpp b/src/id_list.cpp index 4b308603..82712368 100644 --- a/src/id_list.cpp +++ b/src/id_list.cpp @@ -338,6 +338,14 @@ uint32_t id_list_t::first_id() { return root_block.ids.at(0); } +uint32_t id_list_t::last_id() { + if(id_block_map.empty()) { + return 0; + } + + return id_block_map.rbegin()->first; +} + id_list_t::block_t* id_list_t::block_of(uint32_t id) { const auto it = id_block_map.lower_bound(id); if(it == id_block_map.end()) { diff --git a/test/filter_test.cpp b/test/filter_test.cpp index d125e814..6cab88b5 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -162,23 +162,23 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_exact_match_multi_test.valid()); -// delete filter_tree_root; -// filter_tree_root = nullptr; -// filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, -// filter_tree_root); -// ASSERT_TRUE(filter_op.ok()); -// -// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); -// ASSERT_TRUE(iter_not_equals_test.init_status().ok()); -// -// std::vector expected = {1, 3}; -// for (auto const& i : expected) { -// ASSERT_TRUE(iter_not_equals_test.valid()); -// ASSERT_EQ(i, iter_not_equals_test.seq_id); -// iter_not_equals_test.next(); -// } -// -// ASSERT_FALSE(iter_not_equals_test.valid()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_not_equals_test.init_status().ok()); + + expected = 
{1, 3}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_not_equals_test.valid()); + ASSERT_EQ(i, iter_not_equals_test.seq_id); + iter_not_equals_test.next(); + } + + ASSERT_FALSE(iter_not_equals_test.valid()); delete filter_tree_root; filter_tree_root = nullptr; @@ -287,8 +287,8 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); - validate_ids = {0, 1, 2, 3, 4, 5, 6, 7, 100}; - expected = {0, 1, 0, 1, 0, 1, 1, 1, 1}; + validate_ids = {0, 1, 2, 3, 4, 5, 6}; + expected = {0, 1, 0, 1, 0, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); } @@ -422,7 +422,7 @@ TEST_F(FilterTest, FilterTreeIterator) { for (uint32_t i = 0; i < and_result_length; i++) { ASSERT_EQ(expected[i], and_result[i]); } - ASSERT_FALSE(iter_and_test.valid()); + ASSERT_FALSE(iter_and_scalar_test.valid()); delete and_result; delete filter_tree_root; From f53f6635b7f4a435b98a9520a0aacf3aa98de360 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 13 Apr 2023 14:49:46 +0530 Subject: [PATCH 23/64] Fix approximation logic of filter matches in case of `!=`. --- src/index.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index ab8f8576..9b93709e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2072,9 +2072,8 @@ Option Index::_approximate_filter_ids(const filter& a_filter, } } - if (a_filter.apply_not_equals) { - auto all_ids_size = seq_ids->num_ids(); - filter_ids_length = (all_ids_size - filter_ids_length); + if (a_filter.apply_not_equals && filter_ids_length == 0) { + filter_ids_length = seq_ids->num_ids(); } return Option(true); From c571812696c6cdb4a719141505e6d15f33e822eb Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 17 Apr 2023 20:54:50 +0530 Subject: [PATCH 24/64] Add numeric field support in `filter_result_t`. 
--- include/ids_t.h | 8 +- include/num_tree.h | 10 ++ src/filter_result_iterator.cpp | 199 +++++++++++++++++++++++++++++++-- src/num_tree.cpp | 78 +++++++++++++ 4 files changed, 283 insertions(+), 12 deletions(-) diff --git a/include/ids_t.h b/include/ids_t.h index 15cf8c6e..949c71b8 100644 --- a/include/ids_t.h +++ b/include/ids_t.h @@ -39,11 +39,6 @@ struct compact_id_list_t { }; class ids_t { -private: - - static void to_expanded_id_lists(const std::vector& raw_id_lists, std::vector& id_lists, - std::vector& expanded_id_lists); - public: static constexpr size_t COMPACT_LIST_THRESHOLD_LENGTH = 64; static constexpr size_t MAX_BLOCK_ELEMENTS = 256; @@ -104,6 +99,9 @@ public: static uint32_t* uncompress(void*& obj); static void uncompress(void*& obj, std::vector& ids); + + static void to_expanded_id_lists(const std::vector& raw_id_lists, std::vector& id_lists, + std::vector& expanded_id_lists); }; template diff --git a/include/num_tree.h b/include/num_tree.h index 5406a109..de95307b 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -30,6 +30,11 @@ public: void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len); + void range_inclusive_search_iterators(int64_t start, + int64_t end, + std::vector& id_list_iterators, + std::vector& expanded_id_lists); + void approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len); void range_inclusive_contains(const int64_t& start, const int64_t& end, @@ -42,6 +47,11 @@ public: void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len); + void search_iterators(NUM_COMPARATOR comparator, + int64_t value, + std::vector& id_list_iterators, + std::vector& expanded_id_lists); + void approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len); void remove(uint64_t value, uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f7217c0a..6076807f 100644 --- 
a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1,3 +1,5 @@ +#include +#include #include "filter_result_iterator.h" #include "index.h" #include "posting.h" @@ -395,7 +397,16 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + result_index++; + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + return; + } else if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (++seq_id < result_index) { return; @@ -432,6 +443,40 @@ void filter_result_iterator_t::next() { } } +void merge_id_list_iterators(std::vector& id_list_iterators, + std::vector& result_ids) { + struct comp { + bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { + return lhs->id() > rhs->id(); + } + }; + + std::priority_queue, comp> iter_queue; + for (auto& id_list_iterator: id_list_iterators) { + if (id_list_iterator.valid()) { + iter_queue.push(&id_list_iterator); + } + } + + if (iter_queue.empty()) { + return; + } + + // TODO: Handle != + + do { + id_list_t::iterator_t* iter = iter_queue.top(); + iter_queue.pop(); + + result_ids.push_back(iter->id()); + iter->next(); + + if (iter->valid()) { + iter_queue.push(iter); + } + } while (!iter_queue.empty()); +} + void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -470,7 +515,12 @@ void filter_result_iterator_t::init() { return; } - is_valid = filter_result.count > 0; + if (filter_result.count == 0) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; return; } @@ -491,6 +541,7 @@ void filter_result_iterator_t::init() { filter_result.count = result_ids.size(); filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + seq_id = 
filter_result.docs[result_index]; } if (!index->field_is_indexed(a_filter.field_name)) { @@ -500,7 +551,112 @@ void filter_result_iterator_t::init() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer()) { + auto num_tree = index->numerical_index.at(a_filter.field_name); + + std::vector ids; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + int64_t value = (int64_t)std::stol(filter_value); + std::vector id_list_iterators; + std::vector expanded_id_lists; + + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + auto const range_end_value = (int64_t)std::stol(next_filter_value); + num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); + fi++; + } else { + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + value, id_list_iterators, expanded_id_lists); + } + + merge_id_list_iterators(id_list_iterators, ids); + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } + } + + if (ids.empty()) { + is_valid = false; + return; + } + + filter_result.count = ids.size(); + filter_result.docs = new uint32_t[ids.size()]; + std::copy(ids.begin(), ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + } else if (f.is_float()) { + auto num_tree = index->numerical_index.at(a_filter.field_name); + + std::vector ids; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); + std::vector id_list_iterators; + std::vector expanded_id_lists; + + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& 
next_filter_value = a_filter.values[fi+1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + num_tree->range_inclusive_search_iterators(float_int64, range_end_value, + id_list_iterators, expanded_id_lists); + fi++; + } else { + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + float_int64, id_list_iterators, expanded_id_lists); + } + + merge_id_list_iterators(id_list_iterators, ids); + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } + } + + if (ids.empty()) { + is_valid = false; + return; + } + + filter_result.count = ids.size(); + filter_result.docs = new uint32_t[ids.size()]; + std::copy(ids.begin(), ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + } else if (f.is_bool()) { + auto num_tree = index->numerical_index.at(a_filter.field_name); + + std::vector ids; + size_t value_index = 0; + for (const std::string& filter_value : a_filter.values) { + int64_t bool_int64 = (filter_value == "1") ? 1 : 0; + std::vector id_list_iterators; + std::vector expanded_id_lists; + + num_tree->search_iterators(a_filter.comparators[value_index] == NOT_EQUALS ? 
EQUALS : a_filter.comparators[value_index], + bool_int64, id_list_iterators, expanded_id_lists); + + merge_id_list_iterators(id_list_iterators, ids); + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } + + value_index++; + } + + if (ids.empty()) { + is_valid = false; + return; + } + + filter_result.count = ids.size(); + filter_result.docs = new uint32_t[ids.size()]; + std::copy(ids.begin(), ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); for (const std::string& filter_value : a_filter.values) { @@ -615,7 +771,10 @@ bool filter_result_iterator_t::valid() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + is_valid = result_index < filter_result.count; + return is_valid; + } else if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { return seq_id < result_index; } @@ -694,7 +853,17 @@ void filter_result_iterator_t::skip_to(uint32_t id) { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + while(result_index < filter_result.count && filter_result.docs[result_index] < id) { + result_index++; + } + + if (result_index >= filter_result.count) { + is_valid = false; + } + + return; + } else if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (id < seq_id) { return; @@ -861,8 +1030,14 @@ void filter_result_iterator_t::reset() { bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); if (is_referenced_filter || a_filter.field_name == "id") { + if (filter_result.count == 0) { + is_valid = false; + return; + } + result_index = 0; - is_valid = filter_result.count > 0; + seq_id = filter_result.docs[result_index]; + is_valid = true; return; } @@ -872,7 +1047,17 @@ void filter_result_iterator_t::reset() { field 
f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + if (filter_result.count == 0) { + is_valid = false; + return; + } + + result_index = 0; + seq_id = filter_result.docs[result_index]; + is_valid = true; + return; + } else if (f.is_string()) { posting_list_iterators.clear(); for(auto expanded_plist: expanded_plists) { delete expanded_plist; diff --git a/src/num_tree.cpp b/src/num_tree.cpp index 1bcdbc9f..4214ced6 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -43,6 +43,30 @@ void num_tree_t::range_inclusive_search(int64_t start, int64_t end, uint32_t** i *ids = out; } +void num_tree_t::range_inclusive_search_iterators(int64_t start, + int64_t end, + std::vector& id_list_iterators, + std::vector& expanded_id_lists) { + if (int64map.empty()) { + return; + } + + auto it_start = int64map.lower_bound(start); // iter values will be >= start + + std::vector raw_id_lists; + while (it_start != int64map.end() && it_start->first <= end) { + raw_id_lists.push_back(it_start->second); + it_start++; + } + + std::vector id_lists; + ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); + + for (const auto &id_list: id_lists) { + id_list_iterators.emplace_back(id_list->new_iterator()); + } +} + void num_tree_t::approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len) { if (int64map.empty()) { return; @@ -187,6 +211,60 @@ void num_tree_t::search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids } } +void num_tree_t::search_iterators(NUM_COMPARATOR comparator, + int64_t value, + std::vector& id_list_iterators, + std::vector& expanded_id_lists) { + if (int64map.empty()) { + return ; + } + + std::vector raw_id_lists; + if (comparator == EQUALS) { + const auto& it = int64map.find(value); + if (it != int64map.end()) { + raw_id_lists.emplace_back(it->second); + } + } else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) { + // iter 
entries will be >= value, or end() if all entries are before value + auto iter_ge_value = int64map.lower_bound(value); + + if(iter_ge_value == int64map.end()) { + return ; + } + + if(comparator == GREATER_THAN && iter_ge_value->first == value) { + iter_ge_value++; + } + + while(iter_ge_value != int64map.end()) { + raw_id_lists.emplace_back(iter_ge_value->second); + iter_ge_value++; + } + } else if(comparator == LESS_THAN || comparator == LESS_THAN_EQUALS) { + // iter entries will be >= value, or end() if all entries are before value + auto iter_ge_value = int64map.lower_bound(value); + + auto it = int64map.begin(); + while(it != iter_ge_value) { + raw_id_lists.emplace_back(it->second); + it++; + } + + // for LESS_THAN_EQUALS, check if last iter entry is equal to value + if(it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) { + raw_id_lists.emplace_back(it->second); + } + } + + std::vector id_lists; + ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); + + for (const auto &id_list: id_lists) { + id_list_iterators.emplace_back(id_list->new_iterator()); + } +} + void num_tree_t::approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len) { if (int64map.empty()) { return; From 7ef71b0fa8cc9543280146d687d0c539b67e8e4b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 18 Apr 2023 09:57:39 +0530 Subject: [PATCH 25/64] Optimize `filter_result_iterator_t::to_filter_id_array`. 
--- src/filter_result_iterator.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 6076807f..f2aeb7cd 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1074,6 +1074,19 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { return 0; } + if (!filter_node->isOperator) { + const filter a_filter = filter_node->filter_exp; + field f = index->search_schema.at(a_filter.field_name); + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || + (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { + filter_array = new uint32_t[filter_result.count]; + std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); + + return filter_result.count; + } + } + std::vector filter_ids; do { filter_ids.push_back(seq_id); From ac4cb544368d38d045c06da55ada356d5434292c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 18 Apr 2023 15:16:33 +0530 Subject: [PATCH 26/64] Refactor numeric filter initialization. 
--- src/filter_result_iterator.cpp | 105 +++++---------------------------- 1 file changed, 16 insertions(+), 89 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f2aeb7cd..d57c8895 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -443,40 +443,6 @@ void filter_result_iterator_t::next() { } } -void merge_id_list_iterators(std::vector& id_list_iterators, - std::vector& result_ids) { - struct comp { - bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { - return lhs->id() > rhs->id(); - } - }; - - std::priority_queue, comp> iter_queue; - for (auto& id_list_iterator: id_list_iterators) { - if (id_list_iterator.valid()) { - iter_queue.push(&id_list_iterator); - } - } - - if (iter_queue.empty()) { - return; - } - - // TODO: Handle != - - do { - id_list_t::iterator_t* iter = iter_queue.top(); - iter_queue.pop(); - - result_ids.push_back(iter->id()); - iter->next(); - - if (iter->valid()) { - iter_queue.push(iter); - } - } while (!iter_queue.empty()); -} - void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -554,108 +520,69 @@ void filter_result_iterator_t::init() { if (f.is_integer()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - std::vector ids; + // TODO: Handle not equals + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); - std::vector id_list_iterators; - std::vector expanded_id_lists; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, + 
reinterpret_cast(filter_result.count)); fi++; } else { - num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], - value, id_list_iterators, expanded_id_lists); - } - - merge_id_list_iterators(id_list_iterators, ids); - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; + num_tree->search(a_filter.comparators[fi], value, + &filter_result.docs, reinterpret_cast(filter_result.count)); } } - if (ids.empty()) { + if (filter_result.count == 0) { is_valid = false; return; } - - filter_result.count = ids.size(); - filter_result.docs = new uint32_t[ids.size()]; - std::copy(ids.begin(), ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; } else if (f.is_float()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - std::vector ids; for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; float value = (float)std::atof(filter_value.c_str()); int64_t float_int64 = Index::float_to_int64_t(value); - std::vector id_list_iterators; - std::vector expanded_id_lists; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi+1]; int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search_iterators(float_int64, range_end_value, - id_list_iterators, expanded_id_lists); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, + reinterpret_cast(filter_result.count)); fi++; } else { - num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? 
EQUALS : a_filter.comparators[fi], - float_int64, id_list_iterators, expanded_id_lists); - } - - merge_id_list_iterators(id_list_iterators, ids); - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; + num_tree->search(a_filter.comparators[fi], float_int64, + &filter_result.docs, reinterpret_cast(filter_result.count)); } } - if (ids.empty()) { + if (filter_result.count == 0) { is_valid = false; return; } - - filter_result.count = ids.size(); - filter_result.docs = new uint32_t[ids.size()]; - std::copy(ids.begin(), ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; } else if (f.is_bool()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - std::vector ids; size_t value_index = 0; for (const std::string& filter_value : a_filter.values) { int64_t bool_int64 = (filter_value == "1") ? 1 : 0; - std::vector id_list_iterators; - std::vector expanded_id_lists; - num_tree->search_iterators(a_filter.comparators[value_index] == NOT_EQUALS ? 
EQUALS : a_filter.comparators[value_index], - bool_int64, id_list_iterators, expanded_id_lists); - - merge_id_list_iterators(id_list_iterators, ids); - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; - } + num_tree->search(a_filter.comparators[value_index], bool_int64, + &filter_result.docs, reinterpret_cast(filter_result.count)); value_index++; } - if (ids.empty()) { + if (filter_result.count == 0) { is_valid = false; return; } - - filter_result.count = ids.size(); - filter_result.docs = new uint32_t[ids.size()]; - std::copy(ids.begin(), ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); @@ -1080,9 +1007,9 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { + filter_array = new uint32_t[filter_result.count]; std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); - return filter_result.count; } } From d44e2e4c7a51159e46aebed6e0473aa47aac7c75 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 18 Apr 2023 16:53:24 +0530 Subject: [PATCH 27/64] Expose filter ids from iterator where possible. 
--- include/filter_result_iterator.h | 10 ++++++++++ src/filter_result_iterator.cpp | 33 +++++++++++++++++++++----------- src/index.cpp | 21 +++++++++++++------- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 46ec11bc..bd9e66f0 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -127,6 +127,14 @@ private: void doc_matching_string_filter(bool field_is_array); public: + uint32_t* get_ids() { + return filter_result.docs; + } + + uint32_t get_length() { + return filter_result.count; + } + uint32_t seq_id = 0; /// Collection name -> references std::map reference; @@ -180,4 +188,6 @@ public: /// Performs AND with the contents of A and allocates a new array of results. /// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); + + bool can_get_ids(); }; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d57c8895..787ad578 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1001,17 +1001,10 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { return 0; } - if (!filter_node->isOperator) { - const filter a_filter = filter_node->filter_exp; - field f = index->search_schema.at(a_filter.field_name); - - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || - (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { - - filter_array = new uint32_t[filter_result.count]; - std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); - return filter_result.count; - } + if (can_get_ids()) { + filter_array = new uint32_t[filter_result.count]; + std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); + return filter_result.count; } std::vector filter_ids; @@ -1031,6 +1024,10 @@ 
uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& return 0; } + if (can_get_ids()) { + return ArrayUtils::and_scalar(A, lenA, filter_result.docs, filter_result.count, &results); + } + std::vector filter_ids; for (uint32_t i = 0; i < lenA; i++) { auto result = valid(A[i]); @@ -1121,3 +1118,17 @@ filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iter return *this; } + +bool filter_result_iterator_t::can_get_ids() { + if (!filter_node->isOperator) { + const filter a_filter = filter_node->filter_exp; + field f = index->search_schema.at(a_filter.field_name); + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || + (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { + return true; + } + } + + return false; +} diff --git a/src/index.cpp b/src/index.cpp index 9b93709e..19d9c05a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4951,14 +4951,23 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.valid(); thread_id++) { + for(size_t thread_id = 0; thread_id < num_threads && + (filter_result_iterator.can_get_ids() ? 
+ filter_index < filter_result_iterator.get_length() : + filter_result_iterator.valid()); thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - do { - batch_result_ids.push_back(filter_result_iterator.seq_id); - filter_result_iterator.next(); - } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); + if (filter_result_iterator.can_get_ids()) { + while (batch_result_ids.size() < window_size && filter_index < filter_result_iterator.get_length()) { + batch_result_ids.push_back(filter_result_iterator.get_ids()[filter_index++]); + } + } else { + do { + batch_result_ids.push_back(filter_result_iterator.seq_id); + filter_result_iterator.next(); + } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); + } num_queued++; @@ -5019,8 +5028,6 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, parent_search_cutoff = parent_search_cutoff || search_cutoff; cv_process.notify_one(); }); - - filter_index += batch_result_ids.size(); } std::unique_lock lock_process(m_process); From 6459681a0d0c3c452b6098af60fcba6dc5a76b58 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 20 Apr 2023 13:19:42 +0530 Subject: [PATCH 28/64] Add `filter_result_iterator_t::get_n_ids`. Use `is_valid` instead of `valid()`. Handle special `_all_` field name in filtering logic. --- include/filter_result_iterator.h | 28 ++-- include/index.h | 1 + src/art.cpp | 3 +- src/filter.cpp | 3 + src/filter_result_iterator.cpp | 256 +++++++++++++++++++------------ src/index.cpp | 47 +++--- test/filter_test.cpp | 64 ++++---- 7 files changed, 227 insertions(+), 175 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index bd9e66f0..1184b74a 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -99,6 +99,7 @@ private: /// Stores the result of the filters that cannot be iterated. 
filter_result_t filter_result; + bool is_filter_result_initialized = false; /// Initialized in case of filter on string field. /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator @@ -108,9 +109,6 @@ private: std::vector> posting_list_iterators; std::vector expanded_plists; - /// Set to false when this iterator or it's subtree becomes invalid. - bool is_valid = true; - /// Initializes the state of iterator node after it's creation. void init(); @@ -126,18 +124,18 @@ private: /// Finds the next match for a filter on string field. void doc_matching_string_filter(bool field_is_array); + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). + [[nodiscard]] bool valid(); + public: - uint32_t* get_ids() { - return filter_result.docs; - } - - uint32_t get_length() { - return filter_result.count; - } - uint32_t seq_id = 0; /// Collection name -> references std::map reference; + + /// Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initialization status of the iterator. Option status = Option(true); /// Holds the upper-bound of the number of seq ids this iterator would match. @@ -156,9 +154,6 @@ public: /// Returns the status of the initialization of iterator tree. Option init_status(); - /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). - [[nodiscard]] bool valid(); - /// Returns a tri-state: /// 0: id is not valid /// 1: id is valid @@ -171,6 +166,9 @@ public: /// operation. void next(); + /// Collects n doc ids while advancing the iterator. The iterator may become invalid during this operation. + void get_n_ids(const uint32_t& n, std::vector& results); + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. 
void skip_to(uint32_t id); @@ -188,6 +186,4 @@ public: /// Performs AND with the contents of A and allocates a new array of results. /// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); - - bool can_get_ids(); }; diff --git a/include/index.h b/include/index.h index 1795ecd6..15be92e2 100644 --- a/include/index.h +++ b/include/index.h @@ -562,6 +562,7 @@ public: static const int DROP_TOKENS_THRESHOLD = 1; // "_all_" is a special field that maps to all the ids in the index. + static constexpr const char* SEQ_IDS_FIELD = "_all_"; static constexpr const char* SEQ_IDS_FILTER = "_all_: 1"; Index() = delete; diff --git a/src/art.cpp b/src/art.cpp index d189d7eb..7a7d5d5b 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -991,7 +991,7 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, std::vector prev_leaf_ids; posting_t::merge({prev_leaf->values}, prev_leaf_ids); - if(filter_result_iterator.valid()) { + if(filter_result_iterator.is_valid) { prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), prev_token_doc_ids); } else { @@ -1692,6 +1692,7 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le // documents that contain the previous token and/or filter ids size_t allowed_doc_ids_len = 0; const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); + filter_result_iterator.reset(); for(auto node: nodes) { art_topk_iter(node, token_order, max_words, exact_leaf, diff --git a/src/filter.cpp b/src/filter.cpp index 95fbfefc..18348ed6 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -283,6 +283,9 @@ Option toFilter(const std::string expression, } } return Option(true); + } else if (field_name == Index::SEQ_IDS_FIELD) { + filter_exp = {field_name, {}, {}}; + return Option(true); } auto field_it = search_schema.find(field_name); diff --git 
a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 787ad578..c5098581 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -271,8 +271,12 @@ void filter_result_iterator_t::advance_string_filter_token_iterators() { for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { auto& filter_value_tokens = posting_list_iterators[i]; - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { - for (auto& iter: filter_value_tokens) { + if (!filter_value_tokens[0].valid() || filter_value_tokens[0].id() != seq_id) { + continue; + } + + for (auto& iter: filter_value_tokens) { + if (iter.valid()) { iter.next(); } } @@ -362,10 +366,7 @@ void filter_result_iterator_t::next() { return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { + if (is_filter_result_initialized) { if (++result_index >= filter_result.count) { is_valid = false; return; @@ -380,15 +381,7 @@ void filter_result_iterator_t::next() { return; } - if (a_filter.field_name == "id") { - if (++result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { is_valid = false; @@ -397,16 +390,7 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - result_index++; - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } else if (f.is_string()) { + if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (++seq_id < result_index) { return; @@ -443,6 +427,41 @@ void filter_result_iterator_t::next() { } } +void numeric_not_equals_filter(num_tree_t* const num_tree, + const 
int64_t value, + uint32_t*&& all_ids, + uint32_t&& all_ids_length, + uint32_t*& result_ids, + size_t& result_ids_len) { + uint32_t* to_exclude_ids = nullptr; + size_t to_exclude_ids_len = 0; + + num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len); + + result_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_length, to_exclude_ids, to_exclude_ids_len, &result_ids); + + delete[] all_ids; + delete[] to_exclude_ids; +} + +void apply_not_equals(uint32_t*&& all_ids, + uint32_t&& all_ids_length, + uint32_t*& result_ids, + uint32_t& result_ids_len) { + + uint32_t* to_include_ids = nullptr; + size_t to_include_ids_len = 0; + + to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_length, result_ids, + result_ids_len, &to_include_ids); + + delete[] all_ids; + delete[] result_ids; + + result_ids = to_include_ids; + result_ids_len = to_include_ids_len; +} + void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -487,6 +506,11 @@ void filter_result_iterator_t::init() { } seq_id = filter_result.docs[result_index]; + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + is_filter_result_initialized = true; return; } @@ -507,7 +531,22 @@ void filter_result_iterator_t::init() { filter_result.count = result_ids.size(); filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; + } else if (a_filter.field_name == Index::SEQ_IDS_FIELD) { + if (index->seq_ids->num_ids() == 0) { + is_valid = false; + return; + } + + filter_result.count = index->seq_ids->num_ids(); + filter_result.docs = index->seq_ids->uncompress(); + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } if (!index->field_is_indexed(a_filter.field_name)) { @@ -520,28 +559,40 @@ void 
filter_result_iterator_t::init() { if (f.is_integer()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - // TODO: Handle not equals - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); + size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, - reinterpret_cast(filter_result.count)); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, value, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], value, - &filter_result.docs, reinterpret_cast(filter_result.count)); + num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); } + + filter_result.count = result_size; + } + + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); } if (filter_result.count == 0) { is_valid = false; return; } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } else if (f.is_float()) { auto num_tree = index->numerical_index.at(a_filter.field_name); @@ -550,22 +601,36 @@ void filter_result_iterator_t::init() { float value = (float)std::atof(filter_value.c_str()); int64_t float_int64 = Index::float_to_int64_t(value); + size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = 
a_filter.values[fi+1]; int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, - reinterpret_cast(filter_result.count)); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, float_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], float_int64, - &filter_result.docs, reinterpret_cast(filter_result.count)); + num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); } + + filter_result.count = result_size; + } + + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); } if (filter_result.count == 0) { is_valid = false; return; } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } else if (f.is_bool()) { auto num_tree = index->numerical_index.at(a_filter.field_name); @@ -573,16 +638,32 @@ void filter_result_iterator_t::init() { for (const std::string& filter_value : a_filter.values) { int64_t bool_int64 = (filter_value == "1") ? 
1 : 0; - num_tree->search(a_filter.comparators[value_index], bool_int64, - &filter_result.docs, reinterpret_cast(filter_result.count)); + size_t result_size = filter_result.count; + if (a_filter.comparators[value_index] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, bool_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[value_index], bool_int64, &filter_result.docs, result_size); + } + filter_result.count = result_size; value_index++; } + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); + } + if (filter_result.count == 0) { is_valid = false; return; } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); @@ -684,13 +765,13 @@ bool filter_result_iterator_t::valid() { } } - const filter a_filter = filter_node->filter_exp; - - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + if (is_filter_result_initialized) { is_valid = result_index < filter_result.count; return is_valid; } + const filter a_filter = filter_node->filter_exp; + if (!index->field_is_indexed(a_filter.field_name)) { is_valid = false; return is_valid; @@ -698,10 +779,7 @@ bool filter_result_iterator_t::valid() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - is_valid = result_index < filter_result.count; - return is_valid; - } else if (f.is_string()) { + if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { return seq_id < result_index; } @@ -741,10 +819,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if 
(is_referenced_filter) { + if (is_filter_result_initialized) { while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); if (result_index >= filter_result.count) { @@ -761,17 +836,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - if (a_filter.field_name == "id") { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); - - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { is_valid = false; @@ -780,17 +845,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - while(result_index < filter_result.count && filter_result.docs[result_index] < id) { - result_index++; - } - - if (result_index >= filter_result.count) { - is_valid = false; - } - - return; - } else if (f.is_string()) { + if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (id < seq_id) { return; @@ -897,7 +952,7 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); size_t i = 0; - while(i < list->length && valid()) { + while(i < list->length && is_valid) { size_t num_existing_offsets = list->id_offsets[i]; size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; @@ -916,7 +971,7 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { auto list = (posting_list_t*)(obj); posting_list_t::iterator_t it = list->new_iterator(); - while(it.valid() && valid()) { + while(it.valid() && is_valid) { uint32_t id = it.id(); if(id == seq_id) { @@ -943,6 +998,7 @@ void filter_result_iterator_t::reset() { // Reset the subtrees then apply operators to arrive at the first valid doc. 
left_it->reset(); right_it->reset(); + is_valid = true; if (filter_node->filter_operator == AND) { and_filter_iterators(); @@ -953,10 +1009,7 @@ void filter_result_iterator_t::reset() { return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter || a_filter.field_name == "id") { + if (is_filter_result_initialized) { if (filter_result.count == 0) { is_valid = false; return; @@ -964,27 +1017,25 @@ void filter_result_iterator_t::reset() { result_index = 0; seq_id = filter_result.docs[result_index]; + + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + is_valid = true; return; } + const filter a_filter = filter_node->filter_exp; + if (!index->field_is_indexed(a_filter.field_name)) { return; } field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - if (filter_result.count == 0) { - is_valid = false; - return; - } - - result_index = 0; - seq_id = filter_result.docs[result_index]; - is_valid = true; - return; - } else if (f.is_string()) { + if (f.is_string()) { posting_list_iterators.clear(); for(auto expanded_plist: expanded_plists) { delete expanded_plist; @@ -997,11 +1048,11 @@ void filter_result_iterator_t::reset() { } uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { - if (!valid()) { + if (!is_valid) { return 0; } - if (can_get_ids()) { + if (is_filter_result_initialized) { filter_array = new uint32_t[filter_result.count]; std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); return filter_result.count; @@ -1011,7 +1062,7 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { do { filter_ids.push_back(seq_id); next(); - } while (valid()); + } while (is_valid); filter_array = new uint32_t[filter_ids.size()]; 
std::copy(filter_ids.begin(), filter_ids.end(), filter_array); @@ -1020,11 +1071,11 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { } uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { - if (!valid()) { + if (!is_valid) { return 0; } - if (can_get_ids()) { + if (is_filter_result_initialized) { return ArrayUtils::and_scalar(A, lenA, filter_result.docs, filter_result.count, &results); } @@ -1115,20 +1166,23 @@ filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iter seq_id = obj.seq_id; reference = std::move(obj.reference); status = std::move(obj.status); + is_filter_result_initialized = obj.is_filter_result_initialized; return *this; } -bool filter_result_iterator_t::can_get_ids() { - if (!filter_node->isOperator) { - const filter a_filter = filter_node->filter_exp; - field f = index->search_schema.at(a_filter.field_name); - - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || - (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { - return true; +void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { + if (is_filter_result_initialized) { + for (uint32_t count = 0; count < n && result_index < filter_result.count; count++) { + results.push_back(filter_result.docs[result_index++]); } + + is_valid = result_index < filter_result.count; + return; } - return false; + for (uint32_t count = 0; count < n && is_valid; count++) { + results.push_back(seq_id); + next(); + } } diff --git a/src/index.cpp b/src/index.cpp index 19d9c05a..af8c2c3a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2719,7 +2719,7 @@ Option Index::search(std::vector& field_query_tokens, cons return filter_init_op; } - if (filter_tree_root != nullptr && !filter_result_iterator.valid()) { + if (filter_tree_root != nullptr && !filter_result_iterator.is_valid) { return Option(true); } 
@@ -2784,7 +2784,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.valid()); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.is_valid); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id && @@ -2833,11 +2833,9 @@ Option Index::search(std::vector& field_query_tokens, cons store, doc_id_prefix, filter_tree_root); filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + approx_filter_ids_length = filter_result_iterator.is_valid; } -// TODO: Curate ids at last -// curate_filtered_ids(curated_ids, excluded_result_ids, -// excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { @@ -2853,8 +2851,7 @@ Option Index::search(std::vector& field_query_tokens, cons uint32_t filter_id_count = 0; while (!no_filters_provided && - filter_id_count < vector_query.flat_search_cutoff && - filter_result_iterator.valid()) { + filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { auto seq_id = filter_result_iterator.seq_id; std::vector values; @@ -2882,7 +2879,7 @@ Option Index::search(std::vector& field_query_tokens, cons } if(no_filters_provided || - (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.valid())) { + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.is_valid)) { dist_labels.clear(); VectorFilterFunctor 
filterFunctor(&filter_result_iterator); @@ -2951,7 +2948,19 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); } + + // filter tree was initialized to have all sequence ids in this flow. + if (no_filters_provided) { + delete filter_tree_root; + filter_tree_root = nullptr; + } + + uint32_t _all_result_ids_len = all_result_ids_len; + curate_filtered_ids(curated_ids, excluded_result_ids, + excluded_result_ids_size, all_result_ids, _all_result_ids_len, curated_ids_sorted); + all_result_ids_len = _all_result_ids_len; } else { // Non-wildcard // In multi-field searches, a record can be matched across different fields, so we use this for aggregation @@ -3388,7 +3397,7 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_result_iterator.valid() && filter_curated_hits) { + if(filter_result_iterator.is_valid && filter_curated_hits) { for (const auto &included_id: included_ids_vec) { auto result = filter_result_iterator.valid(included_id); @@ -3657,6 +3666,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, false, "", filter_result_iterator, field_leaves, unique_tokens); + filter_result_iterator.reset(); if(field_leaves.empty()) { // look at the next field @@ -4623,7 +4633,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector batch_result_ids; batch_result_ids.reserve(window_size); - if (filter_result_iterator.can_get_ids()) { - while (batch_result_ids.size() < window_size && filter_index < 
filter_result_iterator.get_length()) { - batch_result_ids.push_back(filter_result_iterator.get_ids()[filter_index++]); - } - } else { - do { - batch_result_ids.push_back(filter_result_iterator.seq_id); - filter_result_iterator.next(); - } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); - } + filter_result_iterator.get_n_ids(window_size, batch_result_ids); num_queued++; diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 6cab88b5..86cdd0f5 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -65,7 +65,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_null_filter_tree_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_null_filter_tree_test.init_status().ok()); - ASSERT_FALSE(iter_null_filter_tree_test.valid()); + ASSERT_FALSE(iter_null_filter_tree_test.is_valid); Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, filter_tree_root); @@ -74,7 +74,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_no_match_test.init_status().ok()); - ASSERT_FALSE(iter_no_match_test.valid()); + ASSERT_FALSE(iter_no_match_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -85,7 +85,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_no_match_multi_test.init_status().ok()); - ASSERT_FALSE(iter_no_match_multi_test.valid()); + ASSERT_FALSE(iter_no_match_multi_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -97,11 +97,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_contains_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_contains_test.valid()); + ASSERT_TRUE(iter_contains_test.is_valid); 
ASSERT_EQ(i, iter_contains_test.seq_id); iter_contains_test.next(); } - ASSERT_FALSE(iter_contains_test.valid()); + ASSERT_FALSE(iter_contains_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -113,11 +113,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_contains_multi_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_contains_multi_test.valid()); + ASSERT_TRUE(iter_contains_multi_test.is_valid); ASSERT_EQ(i, iter_contains_multi_test.seq_id); iter_contains_multi_test.next(); } - ASSERT_FALSE(iter_contains_multi_test.valid()); + ASSERT_FALSE(iter_contains_multi_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -129,11 +129,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_exact_match_1_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_exact_match_1_test.valid()); + ASSERT_TRUE(iter_exact_match_1_test.is_valid); ASSERT_EQ(i, iter_exact_match_1_test.seq_id); iter_exact_match_1_test.next(); } - ASSERT_FALSE(iter_exact_match_1_test.valid()); + ASSERT_FALSE(iter_exact_match_1_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -143,7 +143,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_exact_match_2_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_exact_match_2_test.init_status().ok()); - ASSERT_FALSE(iter_exact_match_2_test.valid()); + ASSERT_FALSE(iter_exact_match_2_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -156,11 +156,11 @@ TEST_F(FilterTest, FilterTreeIterator) { std::vector expected = {0, 2, 3, 4}; for (auto const& i : expected) { - ASSERT_TRUE(iter_exact_match_multi_test.valid()); + ASSERT_TRUE(iter_exact_match_multi_test.is_valid); ASSERT_EQ(i, iter_exact_match_multi_test.seq_id); iter_exact_match_multi_test.next(); } - ASSERT_FALSE(iter_exact_match_multi_test.valid()); + 
ASSERT_FALSE(iter_exact_match_multi_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -173,12 +173,12 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {1, 3}; for (auto const& i : expected) { - ASSERT_TRUE(iter_not_equals_test.valid()); + ASSERT_TRUE(iter_not_equals_test.is_valid); ASSERT_EQ(i, iter_not_equals_test.seq_id); iter_not_equals_test.next(); } - ASSERT_FALSE(iter_not_equals_test.valid()); + ASSERT_FALSE(iter_not_equals_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -189,13 +189,13 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_skip_test.init_status().ok()); - ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_TRUE(iter_skip_test.is_valid); iter_skip_test.skip_to(3); - ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_TRUE(iter_skip_test.is_valid); ASSERT_EQ(4, iter_skip_test.seq_id); iter_skip_test.next(); - ASSERT_FALSE(iter_skip_test.valid()); + ASSERT_FALSE(iter_skip_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -206,11 +206,11 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_and_test.init_status().ok()); - ASSERT_TRUE(iter_and_test.valid()); + ASSERT_TRUE(iter_and_test.is_valid); ASSERT_EQ(1, iter_and_test.seq_id); iter_and_test.next(); - ASSERT_FALSE(iter_and_test.valid()); + ASSERT_FALSE(iter_and_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -234,12 +234,12 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {2, 4, 5}; for (auto const& i : expected) { - ASSERT_TRUE(iter_or_test.valid()); + ASSERT_TRUE(iter_or_test.is_valid); ASSERT_EQ(i, iter_or_test.seq_id); iter_or_test.next(); } - ASSERT_FALSE(iter_or_test.valid()); + ASSERT_FALSE(iter_or_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -250,17 
+250,17 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_skip_complex_filter_test.init_status().ok()); - ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_skip_complex_filter_test.is_valid); iter_skip_complex_filter_test.skip_to(4); expected = {4, 5}; for (auto const& i : expected) { - ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_skip_complex_filter_test.is_valid); ASSERT_EQ(i, iter_skip_complex_filter_test.seq_id); iter_skip_complex_filter_test.next(); } - ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_FALSE(iter_skip_complex_filter_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -358,20 +358,20 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {0, 2, 3, 4}; for (auto const& i : expected) { - ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_TRUE(iter_reset_test.is_valid); ASSERT_EQ(i, iter_reset_test.seq_id); iter_reset_test.next(); } - ASSERT_FALSE(iter_reset_test.valid()); + ASSERT_FALSE(iter_reset_test.is_valid); iter_reset_test.reset(); for (auto const& i : expected) { - ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_TRUE(iter_reset_test.is_valid); ASSERT_EQ(i, iter_reset_test.seq_id); iter_reset_test.next(); } - ASSERT_FALSE(iter_reset_test.valid()); + ASSERT_FALSE(iter_reset_test.is_valid); auto iter_move_assignment_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); @@ -380,11 +380,11 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {0, 2, 3, 4}; for (auto const& i : expected) { - ASSERT_TRUE(iter_move_assignment_test.valid()); + ASSERT_TRUE(iter_move_assignment_test.is_valid); ASSERT_EQ(i, iter_move_assignment_test.seq_id); iter_move_assignment_test.next(); } - ASSERT_FALSE(iter_move_assignment_test.valid()); + ASSERT_FALSE(iter_move_assignment_test.is_valid); delete filter_tree_root; 
filter_tree_root = nullptr; @@ -405,7 +405,7 @@ TEST_F(FilterTest, FilterTreeIterator) { for (uint32_t i = 0; i < filter_ids_length; i++) { ASSERT_EQ(expected[i], filter_ids[i]); } - ASSERT_FALSE(iter_to_array_test.valid()); + ASSERT_FALSE(iter_to_array_test.is_valid); delete filter_ids; @@ -422,7 +422,7 @@ TEST_F(FilterTest, FilterTreeIterator) { for (uint32_t i = 0; i < and_result_length; i++) { ASSERT_EQ(expected[i], and_result[i]); } - ASSERT_FALSE(iter_and_scalar_test.valid()); + ASSERT_FALSE(iter_and_scalar_test.is_valid); delete and_result; delete filter_tree_root; From d85cf706b2a17c0224f09f71cd197a71ad86d32a Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 21 Apr 2023 09:48:19 +0530 Subject: [PATCH 29/64] Add tests. --- src/filter_result_iterator.cpp | 23 +++++++++-- test/filter_test.cpp | 70 ++++++++++++++++++++++++++++++---- 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index c5098581..2e144af8 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -858,11 +858,15 @@ void filter_result_iterator_t::skip_to(uint32_t id) { seq_id = result_index; uint32_t previous_match; + + // Keep ignoring the found gaps till they cannot contain id. do { - previous_match = seq_id; - advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); - } while (is_valid && previous_match + 1 == seq_id && seq_id >= id); + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id); + } while (is_valid && seq_id <= id); if (!is_valid) { // filter matched all the ids in the index. So for not equals, there's no match. @@ -873,11 +877,22 @@ void filter_result_iterator_t::skip_to(uint32_t id) { is_valid = true; seq_id = previous_match + 1; result_index = index->seq_ids->last_id() + 1; + + // Skip to id, if possible. 
+ if (seq_id < id && id < result_index) { + seq_id = id; + } + return; } result_index = seq_id; seq_id = previous_match + 1; + + if (seq_id < id && id < result_index) { + seq_id = id; + } + return; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 86cdd0f5..50e7d6d6 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -186,16 +186,29 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); - ASSERT_TRUE(iter_skip_test.init_status().ok()); + auto iter_skip_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test1.init_status().ok()); - ASSERT_TRUE(iter_skip_test.is_valid); - iter_skip_test.skip_to(3); - ASSERT_TRUE(iter_skip_test.is_valid); - ASSERT_EQ(4, iter_skip_test.seq_id); - iter_skip_test.next(); + ASSERT_TRUE(iter_skip_test1.is_valid); + iter_skip_test1.skip_to(3); + ASSERT_TRUE(iter_skip_test1.is_valid); + ASSERT_EQ(4, iter_skip_test1.seq_id); + iter_skip_test1.next(); - ASSERT_FALSE(iter_skip_test.is_valid); + ASSERT_FALSE(iter_skip_test1.is_valid); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: != silver", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test2.init_status().ok()); + + ASSERT_TRUE(iter_skip_test2.is_valid); + iter_skip_test2.skip_to(3); + ASSERT_FALSE(iter_skip_test2.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -426,4 +439,45 @@ TEST_F(FilterTest, FilterTreeIterator) { delete and_result; delete filter_tree_root; + + doc = R"({ + "name": "James Rowdy", + "age": 36, + "years": [2005, 2022], + "rating": 6.03, + "tags": ["FINE PLATINUM"] + })"_json; + add_op = 
coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: != FINE PLATINUM", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test3 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test3.init_status().ok()); + + ASSERT_TRUE(iter_skip_test3.is_valid); + iter_skip_test3.skip_to(4); + ASSERT_EQ(4, iter_skip_test3.seq_id); + + ASSERT_TRUE(iter_skip_test3.is_valid); + + delete filter_tree_root; + + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: != gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test4 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test4.init_status().ok()); + + ASSERT_TRUE(iter_skip_test4.is_valid); + iter_skip_test4.skip_to(6); + ASSERT_EQ(6, iter_skip_test4.seq_id); + ASSERT_TRUE(iter_skip_test4.is_valid); + + delete filter_tree_root; } \ No newline at end of file From 688d8b6bcf816b5bdd7f691d58d6a517f6916481 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 21 Apr 2023 12:48:46 +0530 Subject: [PATCH 30/64] Support geo filtering. 
--- src/filter_result_iterator.cpp | 143 ++++++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 2 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 2e144af8..1c7cbc7a 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1,5 +1,13 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include #include "filter_result_iterator.h" #include "index.h" #include "posting.h" @@ -306,8 +314,12 @@ void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { break; } else { // Keep advancing token iterators till exact match is not found. - for (auto &item: filter_value_tokens) { - item.next(); + for (auto &iter: filter_value_tokens) { + if (!iter.valid()) { + break; + } + + iter.next(); } } } @@ -661,6 +673,133 @@ void filter_result_iterator_t::init() { return; } + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; + } else if (f.is_geopoint()) { + for (const std::string& filter_value : a_filter.values) { + std::vector geo_result_ids; + + std::vector filter_value_parts; + StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points + + bool is_polygon = StringUtils::is_float(filter_value_parts.back()); + S2Region* query_region; + + if (is_polygon) { + const int num_verts = int(filter_value_parts.size()) / 2; + std::vector vertices; + double sum = 0.0; + + for (size_t point_index = 0; point_index < size_t(num_verts); + point_index++) { + double lat = std::stod(filter_value_parts[point_index * 2]); + double lon = std::stod(filter_value_parts[point_index * 2 + 1]); + S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint(); + vertices.emplace_back(vertex); + } + + auto loop = new S2Loop(vertices, S2Debug::DISABLE); + loop->Normalize(); // if loop is not CCW but CW, change to CCW. 
+ + S2Error error; + if (loop->FindValidationError(&error)) { + LOG(ERROR) << "Query vertex is bad, skipping. Error: " << error; + delete loop; + continue; + } else { + query_region = loop; + } + } else { + double radius = std::stof(filter_value_parts[2]); + const auto& unit = filter_value_parts[3]; + + if (unit == "km") { + radius *= 1000; + } else { + // assume "mi" (validated upstream) + radius *= 1609.34; + } + + S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); + double query_lat = std::stod(filter_value_parts[0]); + double query_lng = std::stod(filter_value_parts[1]); + S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); + query_region = new S2Cap(center, query_radius); + } + + S2RegionTermIndexer::Options options; + options.set_index_contains_points_only(true); + S2RegionTermIndexer indexer(options); + + for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { + auto geo_index = index->geopoint_index.at(a_filter.field_name); + const auto& ids_it = geo_index->find(term); + if(ids_it != geo_index->end()) { + geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); + } + } + + gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); + geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); + + // `geo_result_ids` will contain all IDs that are approximately within the query radius; + // we still need to do another round of exact filtering on them + + std::vector exact_geo_result_ids; + + if (f.is_single_geopoint()) { + spp::sparse_hash_map* sort_field_index = index->sort_index.at(f.name); + + for (auto result_id : geo_result_ids) { + // no need to check for existence of `result_id` because of indexer based pre-filtering above + int64_t lat_lng = sort_field_index->at(result_id); + S2LatLng s2_lat_lng; + GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); + if (query_region->Contains(s2_lat_lng.ToPoint())) { + 
exact_geo_result_ids.push_back(result_id); + } + } + } else { + spp::sparse_hash_map* geo_field_index = index->geo_array_index.at(f.name); + + for (auto result_id : geo_result_ids) { + int64_t* lat_lngs = geo_field_index->at(result_id); + + bool point_found = false; + + // any one point should exist + for (size_t li = 0; li < lat_lngs[0]; li++) { + int64_t lat_lng = lat_lngs[li + 1]; + S2LatLng s2_lat_lng; + GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); + if (query_region->Contains(s2_lat_lng.ToPoint())) { + point_found = true; + break; + } + } + + if (point_found) { + exact_geo_result_ids.push_back(result_id); + } + } + } + + uint32_t* out = nullptr; + filter_result.count = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(), + filter_result.docs, filter_result.count, &out); + + delete[] filter_result.docs; + filter_result.docs = out; + + delete query_region; + } + + if (filter_result.count == 0) { + is_valid = false; + return; + } + seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; return; From 8ba560e896e0835a87ba158b0fa728e9baeb05dc Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 21 Apr 2023 14:31:23 +0530 Subject: [PATCH 31/64] Fix failing tests. 
--- src/filter_result_iterator.cpp | 4 +++- src/index.cpp | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 1c7cbc7a..d0037512 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -553,7 +553,7 @@ void filter_result_iterator_t::init() { return; } - filter_result.count = index->seq_ids->num_ids(); + approx_filter_ids_length = filter_result.count = index->seq_ids->num_ids(); filter_result.docs = index->seq_ids->uncompress(); seq_id = filter_result.docs[result_index]; @@ -1322,6 +1322,8 @@ filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iter status = std::move(obj.status); is_filter_result_initialized = obj.is_filter_result_initialized; + approx_filter_ids_length = obj.approx_filter_ids_length; + return *this; } diff --git a/src/index.cpp b/src/index.cpp index af8c2c3a..9003dfac 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2833,7 +2833,7 @@ Option Index::search(std::vector& field_query_tokens, cons store, doc_id_prefix, filter_tree_root); filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); - approx_filter_ids_length = filter_result_iterator.is_valid; + approx_filter_ids_length = filter_result_iterator.approx_filter_ids_length; } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -2955,6 +2955,7 @@ Option Index::search(std::vector& field_query_tokens, cons if (no_filters_provided) { delete filter_tree_root; filter_tree_root = nullptr; + approx_filter_ids_length = 0; } uint32_t _all_result_ids_len = all_result_ids_len; From c7107a4f10e044cbed4e2e330f8f504b9f725c36 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 24 Apr 2023 13:54:16 +0530 Subject: [PATCH 32/64] Handle excluded ids in `filter_result_iterator_t::get_n_ids`. 
--- include/filter_result_iterator.h | 6 ++++++ src/filter_result_iterator.cpp | 33 ++++++++++++++++++++++++++++++++ src/index.cpp | 12 +++++++++--- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 1184b74a..784aa385 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -169,6 +169,12 @@ public: /// Collects n doc ids while advancing the iterator. The iterator may become invalid during this operation. void get_n_ids(const uint32_t& n, std::vector& results); + /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The + /// iterator may become invalid during this operation. + void get_n_ids(const uint32_t &n, + uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, + std::vector &results); + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. 
void skip_to(uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d0037512..4ea630cd 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1275,6 +1275,10 @@ filter_result_iterator_t::filter_result_iterator_t(const std::string collection_ } init(); + + if (!is_valid) { + this->approx_filter_ids_length = 0; + } } filter_result_iterator_t::~filter_result_iterator_t() { @@ -1342,3 +1346,32 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { + if (excluded_result_ids == nullptr || excluded_result_ids_size == 0) { + return get_n_ids(n, results); + } + + if (is_filter_result_initialized) { + for (uint32_t count = 0; count < n && result_index < filter_result.count;) { + auto id = filter_result.docs[result_index++]; + if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, id)) { + results.push_back(id); + count++; + } + } + + is_valid = result_index < filter_result.count; + return; + } + + for (uint32_t count = 0; count < n && is_valid;) { + if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, seq_id)) { + results.push_back(seq_id); + count++; + } + next(); + } +} diff --git a/src/index.cpp b/src/index.cpp index 9003dfac..eef5cc67 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1970,8 +1970,14 @@ void Index::aproximate_numerical_match(num_tree_t* const num_tree, uint32_t to_exclude_ids_len = 0; num_tree->approx_search_count(EQUALS, value, to_exclude_ids_len); - auto all_ids_size = seq_ids->num_ids(); - filter_ids_length += (all_ids_size - to_exclude_ids_len); + if (to_exclude_ids_len == 0) { + filter_ids_length += seq_ids->num_ids(); + } else if (to_exclude_ids_len >= seq_ids->num_ids()) { + filter_ids_length += 0; + } else { + filter_ids_length += (seq_ids->num_ids() - to_exclude_ids_len); + } + return; } @@ -4965,7 +4971,7 @@ void Index::search_wildcard(filter_node_t const* const& 
filter_tree_root, std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, batch_result_ids); + filter_result_iterator.get_n_ids(window_size, exclude_token_ids, exclude_token_ids_size, batch_result_ids); num_queued++; From b67655c45b3780bef5535208210b04ab0b44181f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 24 Apr 2023 14:21:15 +0530 Subject: [PATCH 33/64] Optimize exclusion in `filter_result_iterator_t::get_n_ids`. --- include/filter_result_iterator.h | 1 + src/filter_result_iterator.cpp | 17 ++++++++++++++--- src/index.cpp | 4 +++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 784aa385..2e30a2cb 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -172,6 +172,7 @@ public: /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The /// iterator may become invalid during this operation. 
void get_n_ids(const uint32_t &n, + size_t& excluded_result_index, uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, std::vector &results); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 4ea630cd..9977c6be 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1348,16 +1348,23 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { - if (excluded_result_ids == nullptr || excluded_result_ids_size == 0) { + if (excluded_result_ids == nullptr || excluded_result_ids_size == 0 || + excluded_result_index >= excluded_result_ids_size) { return get_n_ids(n, results); } if (is_filter_result_initialized) { for (uint32_t count = 0; count < n && result_index < filter_result.count;) { auto id = filter_result.docs[result_index++]; - if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, id)) { + + while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < id) { + excluded_result_index++; + } + + if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != id) { results.push_back(id); count++; } @@ -1368,7 +1375,11 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } for (uint32_t count = 0; count < n && is_valid;) { - if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, seq_id)) { + while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < seq_id) { + excluded_result_index++; + } + + if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != seq_id) { results.push_back(seq_id); count++; } diff --git a/src/index.cpp b/src/index.cpp index eef5cc67..ba405b53 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4966,12 +4966,14 @@ void Index::search_wildcard(filter_node_t const* const& 
filter_tree_root, const auto parent_search_begin = search_begin_us; const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; + size_t excluded_result_index = 0; for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, exclude_token_ids, exclude_token_ids_size, batch_result_ids); + filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, + batch_result_ids); num_queued++; From 33be7e6c6881545f9ce2c284b42788fc8c1ee79c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 09:39:51 +0530 Subject: [PATCH 34/64] Add `ArrayUtils::skip_index_to_id`. --- include/array_utils.h | 6 ++++++ include/filter_result_iterator.h | 2 +- src/array_utils.cpp | 25 +++++++++++++++++++++++++ src/filter_result_iterator.cpp | 16 ++++------------ src/index.cpp | 2 +- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/include/array_utils.h b/include/array_utils.h index 81a0576c..008a8a36 100644 --- a/include/array_utils.h +++ b/include/array_utils.h @@ -16,4 +16,10 @@ public: static size_t exclude_scalar(const uint32_t *src, const size_t lenSrc, const uint32_t *filter, const size_t lenFilter, uint32_t **out); + + /// Performs binary search to find the index of id. If id is not found, curr_index is set to the index of next bigger + /// number than id in the array. + /// \return Whether or not id was found in array. 
+ static bool skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, + const uint32_t& id); }; \ No newline at end of file diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 2e30a2cb..259b93f0 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -172,7 +172,7 @@ public: /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The /// iterator may become invalid during this operation. void get_n_ids(const uint32_t &n, - size_t& excluded_result_index, + uint32_t& excluded_result_index, uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, std::vector &results); diff --git a/src/array_utils.cpp b/src/array_utils.cpp index 9f0c7f4a..ad22a85f 100644 --- a/src/array_utils.cpp +++ b/src/array_utils.cpp @@ -149,4 +149,29 @@ size_t ArrayUtils::exclude_scalar(const uint32_t *A, const size_t lenA, delete[] results; return res_index; +} + +bool ArrayUtils::skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, + const uint32_t& id) { + if (id <= array[curr_index]) { + return id == array[curr_index]; + } + + long start = curr_index, mid, end = array_len; + + while (start <= end) { + mid = start + (end - start) / 2; + + if (array[mid] == id) { + curr_index = mid; + return true; + } else if (array[mid] < id) { + start = mid + 1; + } else { + end = mid - 1; + } + } + + curr_index = start; + return false; } \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 9977c6be..63a76f73 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -959,7 +959,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } if (is_filter_result_initialized) { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + ArrayUtils::skip_index_to_id(result_index, 
filter_result.docs, filter_result.count, id); if (result_index >= filter_result.count) { is_valid = false; @@ -1348,7 +1348,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { if (excluded_result_ids == nullptr || excluded_result_ids_size == 0 || @@ -1360,11 +1360,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, for (uint32_t count = 0; count < n && result_index < filter_result.count;) { auto id = filter_result.docs[result_index++]; - while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < id) { - excluded_result_index++; - } - - if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != id) { + if (!ArrayUtils::skip_index_to_id(excluded_result_index, excluded_result_ids, excluded_result_ids_size, id)) { results.push_back(id); count++; } @@ -1375,11 +1371,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } for (uint32_t count = 0; count < n && is_valid;) { - while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < seq_id) { - excluded_result_index++; - } - - if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != seq_id) { + if (!ArrayUtils::skip_index_to_id(excluded_result_index, excluded_result_ids, excluded_result_ids_size, seq_id)) { results.push_back(seq_id); count++; } diff --git a/src/index.cpp b/src/index.cpp index ba405b53..225a7b7b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4966,7 +4966,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_begin = search_begin_us; const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - size_t excluded_result_index = 0; + uint32_t excluded_result_index = 0; for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { std::vector 
batch_result_ids; From 167edaba6fa320881d33656fc390ed6c00a838c7 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 14:27:50 +0530 Subject: [PATCH 35/64] Add tests for `ArrayUtils::skip_index_to_id`. --- src/array_utils.cpp | 6 +++++- test/array_utils_test.cpp | 42 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/array_utils.cpp b/src/array_utils.cpp index ad22a85f..e034c4d6 100644 --- a/src/array_utils.cpp +++ b/src/array_utils.cpp @@ -153,11 +153,15 @@ size_t ArrayUtils::exclude_scalar(const uint32_t *A, const size_t lenA, bool ArrayUtils::skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, const uint32_t& id) { + if (curr_index >= array_len) { + return false; + } + if (id <= array[curr_index]) { return id == array[curr_index]; } - long start = curr_index, mid, end = array_len; + long start = curr_index, mid, end = array_len - 1; while (start <= end) { mid = start + (end - start) / 2; diff --git a/test/array_utils_test.cpp b/test/array_utils_test.cpp index 0fa4622a..2a961296 100644 --- a/test/array_utils_test.cpp +++ b/test/array_utils_test.cpp @@ -172,4 +172,46 @@ TEST(SortedArrayTest, FilterArray) { delete[] arr2; delete[] arr1; delete[] results; +} + +TEST(SortedArrayTest, SkipToID) { + std::vector array; + for (uint32_t i = 0; i < 10; i++) { + array.push_back(i * 3); + } + + uint32_t index = 0; + bool found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 15); + ASSERT_TRUE(found); + ASSERT_EQ(5, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 3); + ASSERT_FALSE(found); + ASSERT_EQ(4, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 12); + ASSERT_TRUE(found); + ASSERT_EQ(4, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 24); + ASSERT_TRUE(found); + ASSERT_EQ(8, index); + + index = 
4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 25); + ASSERT_FALSE(found); + ASSERT_EQ(9, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 30); + ASSERT_FALSE(found); + ASSERT_EQ(10, index); + + index = 12; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 30); + ASSERT_FALSE(found); + ASSERT_EQ(12, index); } \ No newline at end of file From 5c79f200c37bfb5e11b4bbc1e4fa3a84a6b5f708 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 16:54:12 +0530 Subject: [PATCH 36/64] Fix failing tests. --- include/index.h | 2 +- src/index.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/index.h b/include/index.h index 15be92e2..9b9c1d2e 100644 --- a/include/index.h +++ b/include/index.h @@ -239,7 +239,7 @@ public: explicit VectorFilterFunctor(filter_result_iterator_t* const filter_result_iterator) : filter_result_iterator(filter_result_iterator) {} - bool operator()(unsigned int id) { + bool operator()(hnswlib::labeltype id) override { filter_result_iterator->reset(); return filter_result_iterator->valid(id) == 1; } diff --git a/src/index.cpp b/src/index.cpp index 225a7b7b..569eb7ec 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2859,6 +2859,7 @@ Option Index::search(std::vector& field_query_tokens, cons while (!no_filters_provided && filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { auto seq_id = filter_result_iterator.seq_id; + filter_result_iterator.next(); std::vector values; try { @@ -2880,7 +2881,6 @@ Option Index::search(std::vector& field_query_tokens, cons } dist_labels.emplace_back(dist, seq_id); - filter_result_iterator.next(); filter_id_count++; } From 2f615fe1ffe60c58d5d75de0c3dc50eb494413d3 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 26 Apr 2023 20:26:45 +0530 Subject: [PATCH 37/64] Fix memory leaks: * Handle deletion of `filter_tree_root` in 
`sort_fields_guard_t`. * Handle `filter_tree_root` being updated in `Index::static_filter_query_eval`. * Handle deletion of `phrase_result_ids` in `Index::search`. --- include/field.h | 2 +- include/index.h | 2 +- src/collection.cpp | 7 +++++++ src/index.cpp | 9 +++++++-- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/include/field.h b/include/field.h index 03c7015a..8b4fe04e 100644 --- a/include/field.h +++ b/include/field.h @@ -545,7 +545,7 @@ struct sort_by { }; struct eval_t { - filter_node_t* filter_tree_root; + filter_node_t* filter_tree_root = nullptr; uint32_t* ids = nullptr; uint32_t size = 0; }; diff --git a/include/index.h b/include/index.h index 9b9c1d2e..fb2dfd82 100644 --- a/include/index.h +++ b/include/index.h @@ -644,7 +644,7 @@ public: Option search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t*& filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, diff --git a/src/collection.cpp b/src/collection.cpp index dddbf66b..5eab45f0 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -28,6 +28,8 @@ struct sort_fields_guard_t { ~sort_fields_guard_t() { for(auto& sort_by_clause: sort_fields_std) { + delete sort_by_clause.eval.filter_tree_root; + if(sort_by_clause.eval.ids) { delete [] sort_by_clause.eval.ids; sort_by_clause.eval.ids = nullptr; @@ -1523,6 +1525,11 @@ Option Collection::search(std::string raw_query, std::unique_ptr search_params_guard(search_params); auto search_op = index->run_search(search_params, name); + + // filter_tree_root might be updated in Index::static_filter_query_eval. 
+ filter_tree_root_guard.release(); + filter_tree_root_guard.reset(filter_tree_root); + if (!search_op.ok()) { return Option(search_op.code(), search_op.error()); } diff --git a/src/index.cpp b/src/index.cpp index 569eb7ec..ed3873b7 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2325,7 +2325,7 @@ bool Index::static_filter_query_eval(const override_t* override, if (filter_tree_root == nullptr) { filter_tree_root = new_filter_tree_root; } else { - filter_node_t* root = new filter_node_t(AND, filter_tree_root, + auto root = new filter_node_t(AND, filter_tree_root, new_filter_tree_root); filter_tree_root = root; } @@ -2688,7 +2688,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name Option Index::search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t*& filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, @@ -2774,6 +2774,8 @@ Option Index::search(std::vector& field_query_tokens, cons // handle phrase searches uint32_t* phrase_result_ids = nullptr; uint32_t phrase_result_count = 0; + std::unique_ptr phrase_result_ids_guard; + if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, @@ -2782,6 +2784,9 @@ Option Index::search(std::vector& field_query_tokens, cons excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, included_ids_map, is_wildcard_query, phrase_result_ids, phrase_result_count); + + phrase_result_ids_guard.reset(phrase_result_ids); + if (phrase_result_count == 0) { goto process_search_results; } From 5dbfb9df6396206e5070ef5cc47d41b4035f31ce 
Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 26 Apr 2023 20:36:09 +0530 Subject: [PATCH 38/64] Remove `Index::do_filtering`. Using `filter_result_t` instead. --- include/index.h | 19 -- src/index.cpp | 517 +++--------------------------------------------- 2 files changed, 25 insertions(+), 511 deletions(-) diff --git a/include/index.h b/include/index.h index fb2dfd82..e80d5696 100644 --- a/include/index.h +++ b/include/index.h @@ -472,31 +472,12 @@ private: bool field_is_indexed(const std::string& field_name) const; - Option do_filtering(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name = "", - const uint32_t& context_ids_length = 0, - uint32_t* const& context_ids = nullptr) const; - void aproximate_numerical_match(num_tree_t* const num_tree, const NUM_COMPARATOR& comparator, const int64_t& value, const int64_t& range_end_value, uint32_t& filter_ids_length) const; - /// Traverses through filter tree to get the filter_result. - /// - /// \param filter_tree_root - /// \param filter_result - /// \param collection_name Name of the collection to which current index belongs. Used to find the reference field in other collection. - /// \param context_ids_length Number of docs matching the search query. - /// \param context_ids Array of doc ids matching the search query. 
- Option recursive_filter(filter_node_t* const filter_tree_root, - filter_result_t& filter_result, - const std::string& collection_name = "", - const uint32_t& context_ids_length = 0, - uint32_t* const& context_ids = nullptr) const; - void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; diff --git a/src/index.cpp b/src/index.cpp index ed3873b7..552b3149 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1516,446 +1516,6 @@ bool Index::field_is_indexed(const std::string& field_name) const { geopoint_index.count(field_name) != 0; } -Option Index::do_filtering(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name, - const uint32_t& context_ids_length, - uint32_t* const& context_ids) const { - // auto begin = std::chrono::high_resolution_clock::now(); - const filter a_filter = root->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. 
- auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(a_filter.referenced_collection_name); - if (collection == nullptr) { - return Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); - } - - filter_result_t reference_filter_result; - auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, - reference_filter_result, - collection_name); - if (!reference_filter_op.ok()) { - return Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name - + "` collection: " + reference_filter_op.error()); - } - - if (context_ids_length != 0) { - std::vector include_indexes; - include_indexes.reserve(std::min(context_ids_length, reference_filter_result.count)); - - size_t context_index = 0, reference_result_index = 0; - while (context_index < context_ids_length && reference_result_index < reference_filter_result.count) { - if (context_ids[context_index] == reference_filter_result.docs[reference_result_index]) { - include_indexes.push_back(reference_result_index); - context_index++; - reference_result_index++; - } else if (context_ids[context_index] < reference_filter_result.docs[reference_result_index]) { - context_index++; - } else { - reference_result_index++; - } - } - - result.count = include_indexes.size(); - result.docs = new uint32_t[include_indexes.size()]; - auto& result_references = result.reference_filter_results[a_filter.referenced_collection_name]; - result_references = new reference_filter_result_t[include_indexes.size()]; - - for (uint32_t i = 0; i < include_indexes.size(); i++) { - result.docs[i] = reference_filter_result.docs[include_indexes[i]]; - result_references[i] = reference_filter_result.reference_filter_results[a_filter.referenced_collection_name][include_indexes[i]]; - } - - return Option(true); - } - - result = std::move(reference_filter_result); - return Option(true); - } - - if (a_filter.field_name == "id") { - // we 
handle `ids` separately - std::vector result_ids; - for (const auto& id_str : a_filter.values) { - result_ids.push_back(std::stoul(id_str)); - } - - std::sort(result_ids.begin(), result_ids.end()); - - auto result_array = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), result_array); - - if (context_ids_length != 0) { - uint32_t* out = nullptr; - result.count = ArrayUtils::and_scalar(context_ids, context_ids_length, - result_array, result_ids.size(), &out); - - delete[] result_array; - - result.docs = out; - return Option(true); - } - - result.docs = result_array; - result.count = result_ids.size(); - return Option(true); - } - - if (!field_is_indexed(a_filter.field_name)) { - return Option(true); - } - - field f = search_schema.at(a_filter.field_name); - - uint32_t* result_ids = nullptr; - size_t result_ids_len = 0; - - if (f.is_integer()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - int64_t value = (int64_t)std::stol(filter_value); - - if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi + 1]; - auto const range_end_value = (int64_t)std::stol(next_filter_value); - - if (context_ids_length != 0) { - num_tree->range_inclusive_contains(value, range_end_value, context_ids_length, context_ids, - result_ids_len, result_ids); - } else { - num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len); - } - - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[fi], value, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[fi], 
value, &result_ids, result_ids_len); - } - } - } - } else if (f.is_float()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - float value = (float)std::atof(filter_value.c_str()); - int64_t float_int64 = float_to_int64_t(value); - - if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi+1]; - int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str())); - - if (context_ids_length != 0) { - num_tree->range_inclusive_contains(float_int64, range_end_value, context_ids_length, context_ids, - result_ids_len, result_ids); - } else { - num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len); - } - - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, float_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[fi], float_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len); - } - } - } - } else if (f.is_bool()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - size_t value_index = 0; - for (const std::string& filter_value : a_filter.values) { - int64_t bool_int64 = (filter_value == "1") ? 
1 : 0; - if (a_filter.comparators[value_index] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, bool_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[value_index], bool_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len); - } - } - - value_index++; - } - } else if (f.is_geopoint()) { - for (const std::string& filter_value : a_filter.values) { - std::vector geo_result_ids; - - std::vector filter_value_parts; - StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points - - bool is_polygon = StringUtils::is_float(filter_value_parts.back()); - S2Region* query_region; - - if (is_polygon) { - const int num_verts = int(filter_value_parts.size()) / 2; - std::vector vertices; - double sum = 0.0; - - for (size_t point_index = 0; point_index < size_t(num_verts); - point_index++) { - double lat = std::stod(filter_value_parts[point_index * 2]); - double lon = std::stod(filter_value_parts[point_index * 2 + 1]); - S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint(); - vertices.emplace_back(vertex); - } - - auto loop = new S2Loop(vertices, S2Debug::DISABLE); - loop->Normalize(); // if loop is not CCW but CW, change to CCW. - - S2Error error; - if (loop->FindValidationError(&error)) { - LOG(ERROR) << "Query vertex is bad, skipping. 
Error: " << error; - delete loop; - continue; - } else { - query_region = loop; - } - } else { - double radius = std::stof(filter_value_parts[2]); - const auto& unit = filter_value_parts[3]; - - if (unit == "km") { - radius *= 1000; - } else { - // assume "mi" (validated upstream) - radius *= 1609.34; - } - - S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); - double query_lat = std::stod(filter_value_parts[0]); - double query_lng = std::stod(filter_value_parts[1]); - S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); - query_region = new S2Cap(center, query_radius); - } - - S2RegionTermIndexer::Options options; - options.set_index_contains_points_only(true); - S2RegionTermIndexer indexer(options); - - for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { - auto geo_index = geopoint_index.at(a_filter.field_name); - const auto& ids_it = geo_index->find(term); - if(ids_it != geo_index->end()) { - geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); - } - } - - gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); - geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); - - // `geo_result_ids` will contain all IDs that are within approximately within query radius - // we still need to do another round of exact filtering on them - - if (context_ids_length != 0) { - uint32_t *out = nullptr; - uint32_t count = ArrayUtils::and_scalar(context_ids, context_ids_length, - &geo_result_ids[0], geo_result_ids.size(), &out); - - geo_result_ids = std::vector(out, out + count); - } - - std::vector exact_geo_result_ids; - - if (f.is_single_geopoint()) { - spp::sparse_hash_map* sort_field_index = sort_index.at(f.name); - - for (auto result_id : geo_result_ids) { - // no need to check for existence of `result_id` because of indexer based pre-filtering above - int64_t lat_lng = sort_field_index->at(result_id); - S2LatLng s2_lat_lng; - 
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); - if (query_region->Contains(s2_lat_lng.ToPoint())) { - exact_geo_result_ids.push_back(result_id); - } - } - } else { - spp::sparse_hash_map* geo_field_index = geo_array_index.at(f.name); - - for (auto result_id : geo_result_ids) { - int64_t* lat_lngs = geo_field_index->at(result_id); - - bool point_found = false; - - // any one point should exist - for (size_t li = 0; li < lat_lngs[0]; li++) { - int64_t lat_lng = lat_lngs[li + 1]; - S2LatLng s2_lat_lng; - GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); - if (query_region->Contains(s2_lat_lng.ToPoint())) { - point_found = true; - break; - } - } - - if (point_found) { - exact_geo_result_ids.push_back(result_id); - } - } - } - - uint32_t* out = nullptr; - result_ids_len = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(), - result_ids, result_ids_len, &out); - - delete[] result_ids; - result_ids = out; - - delete query_region; - } - } else if (f.is_string()) { - art_tree* t = search_index.at(a_filter.field_name); - - uint32_t* or_ids = nullptr; - size_t or_ids_size = 0; - - // aggregates IDs across array of filter values and reduces excessive ORing - std::vector f_id_buff; - - for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; - - // there could be multiple tokens in a filter value, which we have to treat as ANDs - // e.g. 
country: South Africa - Tokenizer tokenizer(filter_value, true, false, f.locale, symbols_to_index, token_separators); - - std::string str_token; - size_t token_index = 0; - std::vector str_tokens; - - while (tokenizer.next(str_token, token_index)) { - str_tokens.push_back(str_token); - - art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), - str_token.length()+1); - if (leaf == nullptr) { - continue; - } - - posting_lists.push_back(leaf->values); - } - - if (posting_lists.size() != str_tokens.size()) { - continue; - } - - if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) { - // needs intersection + exact matching (unlike CONTAINS) - std::vector result_id_vec; - posting_t::intersect(posting_lists, result_id_vec, context_ids_length, context_ids); - - if (result_id_vec.empty()) { - continue; - } - - // need to do exact match - uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()]; - size_t exact_str_ids_size = 0; - std::unique_ptr exact_str_ids_guard(exact_str_ids); - - posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(), - exact_str_ids, exact_str_ids_size); - - if (exact_str_ids_size == 0) { - continue; - } - - for (size_t ei = 0; ei < exact_str_ids_size; ei++) { - f_id_buff.push_back(exact_str_ids[ei]); - } - } else { - // CONTAINS - size_t before_size = f_id_buff.size(); - posting_t::intersect(posting_lists, f_id_buff, context_ids_length, context_ids); - if (f_id_buff.size() == before_size) { - continue; - } - } - - if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) { - gfx::timsort(f_id_buff.begin(), f_id_buff.end()); - f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); - - uint32_t* out = nullptr; - or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); - delete[] or_ids; - or_ids = out; - std::vector().swap(f_id_buff); // clears out memory - } - } - - if 
(!f_id_buff.empty()) { - gfx::timsort(f_id_buff.begin(), f_id_buff.end()); - f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); - - uint32_t* out = nullptr; - or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); - delete[] or_ids; - or_ids = out; - std::vector().swap(f_id_buff); // clears out memory - } - - result_ids = or_ids; - result_ids_len = or_ids_size; - } - - if (a_filter.apply_not_equals) { - auto all_ids = seq_ids->uncompress(); - auto all_ids_size = seq_ids->num_ids(); - - uint32_t* to_include_ids = nullptr; - size_t to_include_ids_len = 0; - - to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, result_ids, - result_ids_len, &to_include_ids); - - delete[] all_ids; - delete[] result_ids; - - result_ids = to_include_ids; - result_ids_len = to_include_ids_len; - - if (context_ids_length != 0) { - uint32_t *out = nullptr; - result.count = ArrayUtils::and_scalar(context_ids, context_ids_length, - result_ids, result_ids_len, &out); - - delete[] result_ids; - - result.docs = out; - return Option(true); - } - } - - result.docs = result_ids; - result.count = result_ids_len; - - return Option(true); - /*long long int timeMillis = - std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - - begin).count(); - - LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ -} - void Index::aproximate_numerical_match(num_tree_t* const num_tree, const NUM_COMPARATOR& comparator, const int64_t& value, @@ -2126,54 +1686,19 @@ Option Index::rearrange_filter_tree(filter_node_t* const root, return Option(true); } -Option Index::recursive_filter(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name, - const uint32_t& context_ids_length, - uint32_t* const& context_ids) const { - if (root == nullptr) { - return Option(true); - } - - if (root->isOperator) { - filter_result_t l_result; - if (root->left != nullptr) { - auto 
filter_op = recursive_filter(root->left, l_result , collection_name, context_ids_length, context_ids); - if (!filter_op.ok()) { - return filter_op; - } - } - - filter_result_t r_result; - if (root->right != nullptr) { - auto filter_op = recursive_filter(root->right, r_result , collection_name, context_ids_length, context_ids); - if (!filter_op.ok()) { - return filter_op; - } - } - - if (root->filter_operator == AND) { - filter_result_t::and_filter_results(l_result, r_result, result); - } else { - filter_result_t::or_filter_results(l_result, r_result, result); - } - - return Option(true); - } - - return do_filtering(root, result, collection_name, context_ids_length, context_ids); -} - Option Index::do_filtering_with_lock(filter_node_t* const filter_tree_root, filter_result_t& filter_result, const std::string& collection_name) const { std::shared_lock lock(mutex); - auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } + filter_result.count = filter_result_iterator.to_filter_id_array(filter_result.docs); + return Option(true); } @@ -2183,16 +1708,20 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); - filter_result_t reference_filter_result; - auto filter_op = recursive_filter(filter_tree_root, reference_filter_result); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } + uint32_t* reference_docs = nullptr; + uint32_t count = 
filter_result_iterator.to_filter_id_array(reference_docs); + std::unique_ptr docs_guard(reference_docs); + // doc id -> reference doc ids std::map> reference_map; - for (uint32_t i = 0; i < reference_filter_result.count; i++) { - auto reference_doc_id = reference_filter_result.docs[i]; + for (uint32_t i = 0; i < count; i++) { + auto reference_doc_id = reference_docs[i]; auto doc_id = sort_index.at(reference_helper_field_name)->at(reference_doc_id); reference_map[doc_id].push_back(reference_doc_id); @@ -5079,11 +4608,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - filter_result_t result; - recursive_filter(sort_fields_std[i].eval.filter_tree_root, result); - sort_fields_std[i].eval.ids = result.docs; - sort_fields_std[i].eval.size = result.count; - result.docs = nullptr; + + auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return; + } + + sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids); + } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); From 387c3ede6f65839e07aebcb815dfbbbbfa7c4231 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 27 Apr 2023 14:40:53 +0530 Subject: [PATCH 39/64] Refactor string filter iteration. 
--- include/filter_result_iterator.h | 10 +- src/filter_result_iterator.cpp | 171 ++++++++++++------------------- 2 files changed, 69 insertions(+), 112 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 259b93f0..bc8c4c23 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -85,7 +85,6 @@ struct filter_result_t { static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); }; - class filter_result_iterator_t { private: std::string collection_name; @@ -106,6 +105,7 @@ private: /// for each token. /// /// Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_lists; std::vector> posting_list_iterators; std::vector expanded_plists; @@ -121,11 +121,11 @@ private: /// Advance all the token iterators that are at seq_id. void advance_string_filter_token_iterators(); - /// Finds the next match for a filter on string field. - void doc_matching_string_filter(bool field_is_array); + /// Finds the first match for a filter on string field. + void get_string_filter_first_match(const bool& field_is_array); - /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). - [[nodiscard]] bool valid(); + /// Finds the next match for a filter on string field. + void get_string_filter_next_match(const bool& field_is_array); public: uint32_t seq_id = 0; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 63a76f73..34228916 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -291,7 +291,7 @@ void filter_result_iterator_t::advance_string_filter_token_iterators() { } } -void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { +void filter_result_iterator_t::get_string_filter_next_match(const bool& field_is_array) { // If none of the filter value iterators are valid, mark this node as invalid. 
bool one_is_valid = false; @@ -412,7 +412,7 @@ void filter_result_iterator_t::next() { do { previous_match = seq_id; advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); } while (is_valid && previous_match + 1 == seq_id); if (!is_valid) { @@ -433,7 +433,7 @@ void filter_result_iterator_t::next() { } advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); return; } @@ -474,6 +474,50 @@ void apply_not_equals(uint32_t*&& all_ids, result_ids_len = to_include_ids_len; } +void filter_result_iterator_t::get_string_filter_first_match(const bool& field_is_array) { + get_string_filter_next_match(field_is_array); + + if (filter_node->filter_exp.apply_not_equals) { + // filter didn't match any id. So by applying not equals, every id in the index is a match. + if (!is_valid) { + is_valid = true; + seq_id = 0; + result_index = index->seq_ids->last_id() + 1; + return; + } + + // [0, seq_id) are a match for not equals. + if (seq_id > 0) { + result_index = seq_id; + seq_id = 0; + return; + } + + // Keep ignoring the consecutive matches. + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + get_string_filter_next_match(field_is_array); + } while (is_valid && previous_match + 1 == seq_id); + + if (!is_valid) { + // filter matched all the ids in the index. So for not equals, there's no match. 
+ if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + result_index = index->seq_ids->last_id() + 1; + seq_id = previous_match + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + } +} + void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -807,7 +851,7 @@ void filter_result_iterator_t::init() { art_tree* t = index->search_index.at(a_filter.field_name); for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; + std::vector raw_posting_lists; // there could be multiple tokens in a filter value, which we have to treat as ANDs // e.g. country: South Africa @@ -826,119 +870,29 @@ void filter_result_iterator_t::init() { continue; } - posting_lists.push_back(leaf->values); + raw_posting_lists.push_back(leaf->values); } - if (posting_lists.size() != str_tokens.size()) { + if (raw_posting_lists.size() != str_tokens.size()) { continue; } std::vector plists; - posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + posting_t::to_expanded_plists(raw_posting_lists, plists, expanded_plists); + posting_lists.push_back(plists); posting_list_iterators.emplace_back(std::vector()); - for (auto const& plist: plists) { posting_list_iterators.back().push_back(plist->new_iterator()); } } - doc_matching_string_filter(f.is_array()); - - if (filter_node->filter_exp.apply_not_equals) { - // filter didn't match any id. So by applying not equals, every id in the index is a match. - if (!is_valid) { - is_valid = true; - seq_id = 0; - result_index = index->seq_ids->last_id() + 1; - return; - } - - // [0, seq_id) are a match for not equals. - if (seq_id > 0) { - result_index = seq_id; - seq_id = 0; - return; - } - - // Keep ignoring the consecutive matches. 
- uint32_t previous_match; - do { - previous_match = seq_id; - advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); - } while (is_valid && previous_match + 1 == seq_id); - - if (!is_valid) { - // filter matched all the ids in the index. So for not equals, there's no match. - if (previous_match >= index->seq_ids->last_id()) { - return; - } - - is_valid = true; - result_index = index->seq_ids->last_id() + 1; - seq_id = previous_match + 1; - return; - } - - result_index = seq_id; - seq_id = previous_match + 1; - } + get_string_filter_first_match(f.is_array()); return; } } -bool filter_result_iterator_t::valid() { - if (!is_valid) { - return false; - } - - if (filter_node->isOperator) { - if (filter_node->filter_operator == AND) { - is_valid = left_it->valid() && right_it->valid(); - return is_valid; - } else { - is_valid = left_it->valid() || right_it->valid(); - return is_valid; - } - } - - if (is_filter_result_initialized) { - is_valid = result_index < filter_result.count; - return is_valid; - } - - const filter a_filter = filter_node->filter_exp; - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return is_valid; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - if (filter_node->filter_exp.apply_not_equals) { - return seq_id < result_index; - } - - bool one_is_valid = false; - for (auto& filter_value_tokens: posting_list_iterators) { - posting_list_t::intersect(filter_value_tokens, one_is_valid); - - if (one_is_valid) { - break; - } - } - - is_valid = one_is_valid; - return is_valid; - } - - return false; -} - void filter_result_iterator_t::skip_to(uint32_t id) { if (!is_valid) { return; @@ -1003,7 +957,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { do { previous_match = seq_id; advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); } while (is_valid && previous_match + 1 == 
seq_id); } while (is_valid && seq_id <= id); @@ -1047,7 +1001,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } } - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); return; } } @@ -1190,13 +1144,16 @@ void filter_result_iterator_t::reset() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { - posting_list_iterators.clear(); - for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - expanded_plists.clear(); + for (uint32_t i = 0; i < posting_lists.size(); i++) { + auto const& plists = posting_lists[i]; - init(); + posting_list_iterators[i].clear(); + for (auto const& plist: plists) { + posting_list_iterators[i].push_back(plist->new_iterator()); + } + } + + get_string_filter_first_match(f.is_array()); return; } } From 9baa51664d5bb5ceb042ce4434cf6a094e06e783 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 28 Apr 2023 13:31:25 +0530 Subject: [PATCH 40/64] Add convenience methods in `result_iter_state_t`. 
--- include/or_iterator.h | 36 ++++++++++++++++++------------------ include/posting_list.h | 6 ++++++ src/posting_list.cpp | 29 +++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/include/or_iterator.h b/include/or_iterator.h index 67fd5ddf..c4f27518 100644 --- a/include/or_iterator.h +++ b/include/or_iterator.h @@ -62,8 +62,8 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state case 0: break; case 1: - if(istate.filter_ids_length != 0) { - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); + if(istate.is_filter_provided() && istate.is_filter_valid()) { + its[0].skip_to(istate.get_filter_id()); } while(its.size() == it_size && its[0].valid()) { @@ -79,10 +79,10 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state func(id, its); } - if(istate.filter_ids_length != 0 && !is_excluded) { - if(istate.filter_ids_index < istate.filter_ids_length) { + if(istate.is_filter_provided() && !is_excluded) { + if(istate.is_filter_valid()) { // skip iterator till next id available in filter - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); + its[0].skip_to(istate.get_filter_id()); } else { break; } @@ -92,9 +92,9 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } break; case 2: - if(istate.filter_ids_length != 0) { - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); - its[1].skip_to(istate.filter_ids[istate.filter_ids_index]); + if(istate.is_filter_provided() && istate.is_filter_valid()) { + its[0].skip_to(istate.get_filter_id()); + its[1].skip_to(istate.get_filter_id()); } while(its.size() == it_size && !at_end2(its)) { @@ -111,11 +111,11 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state func(id, its); } - if(istate.filter_ids_length != 0 && !is_excluded) { - if(istate.filter_ids_index < istate.filter_ids_length) { + if(istate.is_filter_provided() != 0 && !is_excluded) { + if(istate.is_filter_valid()) { // skip iterator till next id 
available in filter - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); - its[1].skip_to(istate.filter_ids[istate.filter_ids_index]); + its[0].skip_to(istate.get_filter_id()); + its[1].skip_to(istate.get_filter_id()); } else { break; } @@ -128,9 +128,9 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } break; default: - if(istate.filter_ids_length != 0) { + if(istate.is_filter_provided() && istate.is_filter_valid()) { for(auto& it: its) { - it.skip_to(istate.filter_ids[istate.filter_ids_index]); + it.skip_to(istate.get_filter_id()); } } @@ -148,11 +148,11 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state func(id, its); } - if(istate.filter_ids_length != 0 && !is_excluded) { - if(istate.filter_ids_index < istate.filter_ids_length) { + if(istate.is_filter_provided() && !is_excluded) { + if(istate.is_filter_valid()) { // skip iterator till next id available in filter for(auto& it: its) { - it.skip_to(istate.filter_ids[istate.filter_ids_index]); + it.skip_to(istate.get_filter_id()); } } else { break; @@ -167,4 +167,4 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } return true; -} +} \ No newline at end of file diff --git a/include/posting_list.h b/include/posting_list.h index 11ed9c91..a0df81fa 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -33,6 +33,12 @@ struct result_iter_state_t { filter_result_iterator_t* fit) : excluded_result_ids(excluded_result_ids), excluded_result_ids_size(excluded_result_ids_size), fit(fit){} + + [[nodiscard]] bool is_filter_provided() const; + + [[nodiscard]] bool is_filter_valid() const; + + [[nodiscard]] uint32_t get_filter_id() const; }; /* diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 1b5a3323..dc82fb69 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -2,6 +2,7 @@ #include #include "for.h" #include "array_utils.h" +#include "filter_result_iterator.h" /* block_t operations */ @@ -1781,3 +1782,31 @@ 
posting_list_t::iterator_t posting_list_t::iterator_t::clone() const { uint32_t posting_list_t::iterator_t::get_field_id() const { return field_id; } + +bool result_iter_state_t::is_filter_provided() const { + return filter_ids_length > 0 || (fit != nullptr && fit->approx_filter_ids_length > 0); +} + +bool result_iter_state_t::is_filter_valid() const { + if (filter_ids_length > 0) { + return filter_ids_index < filter_ids_length; + } + + if (fit != nullptr) { + return fit->is_valid; + } + + return false; +} + +uint32_t result_iter_state_t::get_filter_id() const { + if (filter_ids_length > 0 && filter_ids_index < filter_ids_length) { + return filter_ids[filter_ids_index]; + } + + if (fit != nullptr && fit->is_valid) { + return fit->seq_id; + } + + return 0; +} From e9ce1c095521660c4d46a81d09bff6214210668c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 28 Apr 2023 13:35:07 +0530 Subject: [PATCH 41/64] Refactor `or_iterator_t::take_id`. --- src/or_iterator.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index 8dd9d487..6384cba2 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -209,7 +209,16 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e } if (istate.fit != nullptr && istate.fit->approx_filter_ids_length > 0) { - return (istate.fit->valid(id) == 1); + if (istate.fit->valid(id) == -1) { + return false; + } + + if (istate.fit->seq_id == id) { + istate.fit->next(); + return true; + } + + return false; } return true; @@ -245,5 +254,3 @@ or_iterator_t::~or_iterator_t() noexcept { it.reset_cache(); } } - - From b72e31bc5cb02454f4cd8798bcb114cd0aee728e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 28 Apr 2023 14:55:39 +0530 Subject: [PATCH 42/64] Fix `filter_result_iterator_t::valid(uint32_t id)` not updating `seq_id` in case of complex filter. 
--- src/filter_result_iterator.cpp | 8 ++++++++ test/filter_test.cpp | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 34228916..f155d955 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1014,6 +1014,14 @@ int filter_result_iterator_t::valid(uint32_t id) { if (filter_node->isOperator) { auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); + if (left_it->is_valid && right_it->is_valid) { + seq_id = std::min(left_it->seq_id, right_it->seq_id); + } else if (left_it->is_valid) { + seq_id = left_it->seq_id; + } else if (right_it->is_valid) { + seq_id = right_it->seq_id; + } + if (filter_node->filter_operator == AND) { is_valid = left_it->is_valid && right_it->is_valid; diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 50e7d6d6..d3aa3f31 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -284,10 +284,11 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); - std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}; + std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 3, 4, 5, 5}; expected = {1, 0, 1, 0, 1, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test.seq_id); } delete filter_tree_root; @@ -301,9 +302,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); validate_ids = {0, 1, 2, 3, 4, 5, 6}; + seq_ids = {1, 1, 3, 3, 5, 5, 5}; expected = {0, 1, 0, 1, 0, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], 
iter_validate_ids_not_equals_filter_test.seq_id); } delete filter_tree_root; From e5fbf0cff2ed67a5e42caf985945399d0d068564 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 1 May 2023 14:53:19 +0530 Subject: [PATCH 43/64] Refactor `art_fuzzy_search_i`. --- src/art.cpp | 263 ++++++++++++++++++++++++++++------------------------ 1 file changed, 144 insertions(+), 119 deletions(-) diff --git a/src/art.cpp b/src/art.cpp index 7a7d5d5b..8c3b99ad 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -973,36 +973,6 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, return prev_token_doc_ids; } -const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, - size_t& prev_token_doc_ids_len) { - - art_leaf* prev_leaf = static_cast( - art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) - ); - - uint32_t* prev_token_doc_ids = nullptr; - - if(prev_token.empty() || !prev_leaf) { - prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(prev_token_doc_ids); - return prev_token_doc_ids; - } - - std::vector prev_leaf_ids; - posting_t::merge({prev_leaf->values}, prev_leaf_ids); - - if(filter_result_iterator.is_valid) { - prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), - prev_token_doc_ids); - } else { - prev_token_doc_ids_len = prev_leaf_ids.size(); - prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; - std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); - } - - return prev_token_doc_ids; -} - bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, std::set& exclude_leaves, const art_leaf* exact_leaf, @@ -1030,6 +1000,52 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str return true; } +bool validate_and_add_leaf(art_leaf* 
leaf, + const std::string& prev_token, const art_leaf* prev_leaf, + const art_leaf* exact_leaf, + filter_result_iterator_t& filter_result_iterator, + std::set& exclude_leaves, + std::vector& results) { + if(leaf == exact_leaf) { + return false; + } + + std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); + if(exclude_leaves.count(tok) != 0) { + return false; + } + + if(prev_token.empty() || !prev_leaf) { + if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) { + return false; + } + } else if (!filter_result_iterator.is_valid) { + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + if (!posting_t::contains_atleast_one(leaf->values, prev_leaf_ids.data(), prev_leaf_ids.size())) { + return false; + } + } else { + std::vector leaf_ids; + posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); + + bool found = false; + for (uint32_t i = 0; i < leaf_ids.size() && !found; i++) { + found = (filter_result_iterator.valid(leaf_ids[i]) == 1); + } + + if (!found) { + return false; + } + } + + exclude_leaves.emplace(tok); + results.push_back(leaf); + + return true; +} + int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, const art_leaf* exact_leaf, const bool last_token, const std::string& prev_token, @@ -1126,6 +1142,101 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r return 0; } +int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, + const art_leaf* exact_leaf, + const bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + const art_tree* t, std::set& exclude_leaves, std::vector& results) { + + printf("INSIDE art_topk_iter: root->type: %d\n", root->type); + + auto prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + std::priority_queue, + 
decltype(&compare_art_node_score_pq)> q(compare_art_node_score_pq); + + if(token_order == FREQUENCY) { + q = std::priority_queue, + decltype(&compare_art_node_frequency_pq)>(compare_art_node_frequency_pq); + } + + q.push(root); + + size_t num_processed = 0; + + while(!q.empty() && results.size() < max_results*4) { + art_node *n = (art_node *) q.top(); + q.pop(); + + if (!n) continue; + if (IS_LEAF(n)) { + art_leaf *l = (art_leaf *) LEAF_RAW(n); + //LOG(INFO) << "END LEAF SCORE: " << l->max_score; + + validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator, + exclude_leaves, results); + filter_result_iterator.reset(); + + if (++num_processed % 1024 == 0 && (microseconds( + std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { + search_cutoff = true; + break; + } + + continue; + } + + int idx; + switch (n->type) { + case NODE4: + //LOG(INFO) << "NODE4, SCORE: " << n->max_score; + for (int i=0; i < n->num_children; i++) { + art_node* child = ((art_node4*)n)->children[i]; + q.push(child); + } + break; + + case NODE16: + //LOG(INFO) << "NODE16, SCORE: " << n->max_score; + for (int i=0; i < n->num_children; i++) { + q.push(((art_node16*)n)->children[i]); + } + break; + + case NODE48: + //LOG(INFO) << "NODE48, SCORE: " << n->max_score; + for (int i=0; i < 256; i++) { + idx = ((art_node48*)n)->keys[i]; + if (!idx) continue; + art_node *child = ((art_node48*)n)->children[idx - 1]; + q.push(child); + } + break; + + case NODE256: + //LOG(INFO) << "NODE256, SCORE: " << n->max_score; + for (int i=0; i < 256; i++) { + if (!((art_node256*)n)->children[i]) continue; + q.push(((art_node256*)n)->children[i]); + } + break; + + default: + printf("ABORTING BECAUSE OF UNKNOWN NODE TYPE: %d\n", n->type); + abort(); + } + } + + /*LOG(INFO) << "leaf results.size: " << results.size() + << ", filter_ids_length: " << filter_ids_length + << ", num_large_lists: " << num_large_lists;*/ + + printf("OUTSIDE art_topk_iter: 
results size: %d\n", results.size()); + return 0; +} + // Recursively iterates over the tree static int recursive_iter(art_node *n, art_callback cb, void *data) { // Handle base cases @@ -1689,14 +1800,10 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; - // documents that contain the previous token and/or filter ids - size_t allowed_doc_ids_len = 0; - const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); - filter_result_iterator.reset(); - for(auto node: nodes) { art_topk_iter(node, token_order, max_words, exact_leaf, - last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + last_token, prev_token, + filter_result_iterator, t, exclude_leaves, results); } @@ -1722,91 +1829,9 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le if(time_micro > 1000) { LOG(INFO) << "Time taken for art_topk_iter: " << time_micro << "us, size of nodes: " << nodes.size() - << ", filter_ids_length: " << filter_ids_length; + << ", filter_ids_length: " << filter_result_iterator.approx_filter_ids_length; }*/ - delete [] allowed_doc_ids; - - return 0; -} - -int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const size_t max_words, const token_ordering token_order, const bool prefix, - bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, - std::vector &results, std::set& exclude_leaves) { - - std::vector nodes; - int irow[term_len + 1]; - int jrow[term_len + 1]; - for (int i = 0; i <= term_len; i++){ - irow[i] = jrow[i] = i; - } - - //auto begin = std::chrono::high_resolution_clock::now(); - - if(IS_LEAF(t->root)) { - art_leaf *l = (art_leaf *) LEAF_RAW(t->root); - 
art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); - } else { - if(t->root == nullptr) { - return 0; - } - - // send depth as -1 to indicate that this is a root node - art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); - } - - //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); - //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); - - //auto begin = std::chrono::high_resolution_clock::now(); - - size_t key_len = prefix ? term_len + 1 : term_len; - art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); - //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; - - // documents that contain the previous token and/or filter ids - size_t allowed_doc_ids_len = 0; - const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); - - for(auto node: nodes) { - art_topk_iter(node, token_order, max_words, exact_leaf, - last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, - t, exclude_leaves, results); - } - - if(token_order == FREQUENCY) { - std::sort(results.begin(), results.end(), compare_art_leaf_frequency); - } else { - std::sort(results.begin(), results.end(), compare_art_leaf_score); - } - - if(exact_leaf && min_cost == 0) { - std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); - if(exclude_leaves.count(tok) == 0) { - results.insert(results.begin(), exact_leaf); - exclude_leaves.emplace(tok); - } - } - - if(results.size() > max_words) { - results.resize(max_words); - } - - /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); - - if(time_micro > 1000) { - LOG(INFO) << "Time taken for art_topk_iter: " << time_micro - << "us, size of nodes: " << nodes.size() - << ", filter_ids_length: " << 
filter_ids_length; - }*/ - -// TODO: Figure out this edge case. -// if(allowed_doc_ids != filter_ids) { -// delete [] allowed_doc_ids; -// } - return 0; } From fc032efed3f40d44d05e6975c2bdd3e9838e2af0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 2 May 2023 10:49:03 +0530 Subject: [PATCH 44/64] Add test for prefix search with filter. --- src/art.cpp | 2 +- test/collection_filtering_test.cpp | 42 ++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/art.cpp b/src/art.cpp index 8c3b99ad..ee0fb53a 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1031,7 +1031,7 @@ bool validate_and_add_leaf(art_leaf* leaf, posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); bool found = false; - for (uint32_t i = 0; i < leaf_ids.size() && !found; i++) { + for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator.is_valid && !found; i++) { found = (filter_result_iterator.valid(leaf_ids[i]) == 1); } diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 1628c2dc..c644d547 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -2628,4 +2628,46 @@ TEST_F(CollectionFilteringTest, ComplexFilterQuery) { } collectionManager.drop_collection("ComplexFilterQueryCollection"); +} + +TEST_F(CollectionFilteringTest, PrefixSearchWithFilter) { + std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl"); + std::vector search_fields = { + field("title", field_types::STRING, false), + field("points", field_types::INT32, false) + }; + + query_fields = {"title"}; + sort_fields = { sort_by(sort_field_const::text_match, "DESC"), sort_by("points", "DESC") }; + + auto collection = collectionManager.create_collection("collection", 4, search_fields, "points").get(); + + std::string json_line; + + // dummy record for record id 0: to make the test record IDs to match with line numbers + json_line = "{\"points\":10,\"title\":\"z\"}"; + collection->add(json_line); 
+ + while (std::getline(infile, json_line)) { + collection->add(json_line); + } + + infile.close(); + + std::vector facets; + auto results = collection->search("what ex", query_fields, "points: >10", facets, sort_fields, {0}, 10, 1, MAX_SCORE, {true}, 10, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10).get(); + ASSERT_EQ(7, results["hits"].size()); + std::vector ids = {"6", "12", "19", "22", "13", "8", "15"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + collectionManager.drop_collection("collection"); } \ No newline at end of file From 61c2b73d1e8226b798c8912ab0b40a290aa19ed6 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 3 May 2023 18:42:31 +0530 Subject: [PATCH 45/64] Fix phrase search. --- include/art.h | 2 +- include/filter_result_iterator.h | 9 +++ include/index.h | 20 ++--- src/art.cpp | 16 ++-- src/filter_result_iterator.cpp | 45 +++++++++++ src/index.cpp | 129 ++++++++++++++++--------------- test/filter_test.cpp | 14 ++++ 7 files changed, 155 insertions(+), 80 deletions(-) diff --git a/include/art.h b/include/art.h index 11f57a68..92e043e3 100644 --- a/include/art.h +++ b/include/art.h @@ -279,7 +279,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, const size_t max_words, const token_ordering token_order, const bool prefix, bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::vector &results, std::set& exclude_leaves); void encode_int32(int32_t n, unsigned char *chars); diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 
bc8c4c23..b3b12555 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -109,6 +109,8 @@ private: std::vector> posting_list_iterators; std::vector expanded_plists; + bool delete_filter_node = false; + /// Initializes the state of iterator node after it's creation. void init(); @@ -127,6 +129,8 @@ private: /// Finds the next match for a filter on string field. void get_string_filter_next_match(const bool& field_is_array); + explicit filter_result_iterator_t(uint32_t approx_filter_ids_length); + public: uint32_t seq_id = 0; /// Collection name -> references @@ -143,6 +147,8 @@ public: /// iterator reaching it's end. (is_valid would be false in both these cases) uint32_t approx_filter_ids_length; + explicit filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count); + explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node, uint32_t approx_filter_ids_length = UINT32_MAX); @@ -193,4 +199,7 @@ public: /// Performs AND with the contents of A and allocates a new array of results. 
/// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); + + static void add_phrase_ids(filter_result_iterator_t*& filter_result_iterator, + uint32_t* phrase_result_ids, const uint32_t& phrase_result_count); }; diff --git a/include/index.h b/include/index.h index e80d5696..5d825555 100644 --- a/include/index.h +++ b/include/index.h @@ -408,7 +408,7 @@ private: void search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -725,7 +725,7 @@ public: const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, @@ -766,7 +766,7 @@ public: std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -797,7 +797,7 @@ public: spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const 
filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -814,6 +814,7 @@ public: std::array*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, + filter_result_iterator_t*& filter_result_iterator, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed, const std::set& curated_ids, @@ -821,8 +822,7 @@ public: const std::unordered_set& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, - bool is_wildcard_query, - uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const; + bool is_wildcard_query) const; void fuzzy_search_fields(const std::vector& the_fields, const std::vector& query_tokens, @@ -830,7 +830,7 @@ public: const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -859,7 +859,7 @@ public: const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, @@ -881,7 +881,7 @@ public: const std::vector& group_by_fields, bool prioritize_exact_match, const bool search_all_candidates, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, @@ -933,7 +933,7 @@ public: void process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, const bool filter_curated_hits, 
- filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, diff --git a/src/art.cpp b/src/art.cpp index ee0fb53a..65d3ca63 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1003,7 +1003,7 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str bool validate_and_add_leaf(art_leaf* leaf, const std::string& prev_token, const art_leaf* prev_leaf, const art_leaf* exact_leaf, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& exclude_leaves, std::vector& results) { if(leaf == exact_leaf) { @@ -1016,10 +1016,10 @@ bool validate_and_add_leaf(art_leaf* leaf, } if(prev_token.empty() || !prev_leaf) { - if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) { + if (filter_result_iterator->is_valid && !filter_result_iterator->contains_atleast_one(leaf->values)) { return false; } - } else if (!filter_result_iterator.is_valid) { + } else if (!filter_result_iterator->is_valid) { std::vector prev_leaf_ids; posting_t::merge({prev_leaf->values}, prev_leaf_ids); @@ -1031,8 +1031,8 @@ bool validate_and_add_leaf(art_leaf* leaf, posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); bool found = false; - for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator.is_valid && !found; i++) { - found = (filter_result_iterator.valid(leaf_ids[i]) == 1); + for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator->is_valid && !found; i++) { + found = (filter_result_iterator->valid(leaf_ids[i]) == 1); } if (!found) { @@ -1145,7 +1145,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, const art_leaf* exact_leaf, const bool last_token, const std::string& prev_token, - 
filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const art_tree* t, std::set& exclude_leaves, std::vector& results) { printf("INSIDE art_topk_iter: root->type: %d\n", root->type); @@ -1177,7 +1177,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator, exclude_leaves, results); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if (++num_processed % 1024 == 0 && (microseconds( std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { @@ -1767,7 +1767,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, const size_t max_words, const token_ordering token_order, const bool prefix, bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::vector &results, std::set& exclude_leaves) { std::vector nodes; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f155d955..c4dcd0b5 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1252,6 +1252,10 @@ filter_result_iterator_t::~filter_result_iterator_t() { delete expanded_plist; } + if (delete_filter_node) { + delete filter_node; + } + delete left_it; delete right_it; } @@ -1343,3 +1347,44 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, next(); } } + +filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_length) : + approx_filter_ids_length(approx_filter_ids_length) { + filter_node = new filter_node_t(AND, nullptr, nullptr); + delete_filter_node = true; +} + +filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t& 
ids_count) { + filter_result.count = approx_filter_ids_length = ids_count; + filter_result.docs = ids; + is_valid = ids_count > 0; + + if (is_valid) { + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + filter_node = new filter_node_t({"dummy", {}, {}}); + delete_filter_node = true; + } +} + +void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_result_iterator, + uint32_t* phrase_result_ids, const uint32_t& phrase_result_count) { + auto root_iterator = new filter_result_iterator_t(std::min(phrase_result_count, filter_result_iterator->approx_filter_ids_length)); + root_iterator->left_it = new filter_result_iterator_t(phrase_result_ids, phrase_result_count); + root_iterator->right_it = filter_result_iterator; + + auto& left_it = root_iterator->left_it; + auto& right_it = root_iterator->right_it; + + while (left_it->is_valid && right_it->is_valid && left_it->seq_id != right_it->seq_id) { + if (left_it->seq_id < right_it->seq_id) { + left_it->skip_to(right_it->seq_id); + } else { + right_it->skip_to(left_it->seq_id); + } + } + + root_iterator->is_valid = left_it->is_valid && right_it->is_valid; + root_iterator->seq_id = left_it->seq_id; + filter_result_iterator = root_iterator; +} diff --git a/src/index.cpp b/src/index.cpp index 552b3149..9090230d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1263,7 +1263,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { void Index::search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -2247,14 +2247,16 @@ Option Index::search(std::vector& field_query_tokens, cons return rearrange_op; } - auto filter_result_iterator = 
filter_result_iterator_t(collection_name, this, filter_tree_root, + auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root, approx_filter_ids_length); - auto filter_init_op = filter_result_iterator.init_status(); + std::unique_ptr filter_iterator_guard(filter_result_iterator); + + auto filter_init_op = filter_result_iterator->init_status(); if (!filter_init_op.ok()) { return filter_init_op; } - if (filter_tree_root != nullptr && !filter_result_iterator.is_valid) { + if (filter_tree_root != nullptr && !filter_result_iterator->is_valid) { return Option(true); } @@ -2268,7 +2270,7 @@ Option Index::search(std::vector& field_query_tokens, cons process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits, filter_result_iterator, curated_ids, included_ids_map, included_ids_vec, excluded_group_ids); - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2299,24 +2301,19 @@ Option Index::search(std::vector& field_query_tokens, cons field_query_tokens[0].q_include_tokens[0].value == "*"; - // TODO: Do AND with phrase ids at last // handle phrase searches - uint32_t* phrase_result_ids = nullptr; - uint32_t phrase_result_count = 0; - std::unique_ptr phrase_result_ids_guard; - if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, topster, sort_order, field_values, geopoint_indices, curated_ids_sorted, - all_result_ids, all_result_ids_len, groups_processed, curated_ids, + filter_result_iterator, all_result_ids, all_result_ids_len, groups_processed, curated_ids, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, - included_ids_map, is_wildcard_query, - phrase_result_ids, 
phrase_result_count); + included_ids_map, is_wildcard_query); - phrase_result_ids_guard.reset(phrase_result_ids); + filter_iterator_guard.release(); + filter_iterator_guard.reset(filter_result_iterator); - if (phrase_result_count == 0) { + if (filter_result_iterator->approx_filter_ids_length == 0) { goto process_search_results; } } @@ -2324,7 +2321,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.is_valid); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator->is_valid); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id && @@ -2372,8 +2369,10 @@ Option Index::search(std::vector& field_query_tokens, cons Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, store, doc_id_prefix, filter_tree_root); - filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); - approx_filter_ids_length = filter_result_iterator.approx_filter_ids_length; + filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root); + filter_iterator_guard.reset(filter_result_iterator); + + approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -2391,9 +2390,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint32_t filter_id_count = 0; while (!no_filters_provided && - filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { - auto seq_id = filter_result_iterator.seq_id; - 
filter_result_iterator.next(); + filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator->is_valid) { + auto seq_id = filter_result_iterator->seq_id; + filter_result_iterator->next(); std::vector values; try { @@ -2417,12 +2416,13 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels.emplace_back(dist, seq_id); filter_id_count++; } + filter_result_iterator->reset(); if(no_filters_provided || - (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.is_valid)) { + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->is_valid)) { dist_labels.clear(); - VectorFilterFunctor filterFunctor(&filter_result_iterator); + VectorFilterFunctor filterFunctor(filter_result_iterator); if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); @@ -2432,8 +2432,7 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } } - - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector nearest_ids; @@ -2488,7 +2487,7 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } // filter tree was initialized to have all sequence ids in this flow. 
@@ -2549,7 +2548,7 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -2586,7 +2585,7 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } } @@ -2602,7 +2601,7 @@ Option Index::search(std::vector& field_query_tokens, cons filter_result_iterator, query_hashes, sort_order, field_values, geopoint_indices, qtoken_set); - filter_result_iterator.reset(); + filter_result_iterator->reset(); // gather up both original query and synonym queries and do drop tokens @@ -2659,7 +2658,7 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } else { break; @@ -2676,7 +2675,7 @@ Option Index::search(std::vector& field_query_tokens, cons sort_order, field_values, geopoint_indices, curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if(!vector_query.field_name.empty()) { // check at least one of sort fields is text match @@ -2693,7 +2692,7 @@ Option Index::search(std::vector& 
field_query_tokens, cons constexpr float TEXT_MATCH_WEIGHT = 0.7; constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT; - VectorFilterFunctor filterFunctor(&filter_result_iterator); + VectorFilterFunctor filterFunctor(filter_result_iterator); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; auto k = std::max(vector_query.k, fetch_size); @@ -2705,7 +2704,7 @@ Option Index::search(std::vector& field_query_tokens, cons } else { dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector> vec_results; for (const auto& dist_label : dist_labels) { @@ -2915,7 +2914,7 @@ Option Index::search(std::vector& field_query_tokens, cons void Index::process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, - const bool filter_curated_hits, filter_result_iterator_t& filter_result_iterator, + const bool filter_curated_hits, filter_result_iterator_t* const filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, @@ -2938,9 +2937,9 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_result_iterator.is_valid && filter_curated_hits) { + if(filter_result_iterator->is_valid && filter_curated_hits) { for (const auto &included_id: included_ids_vec) { - auto result = filter_result_iterator.valid(included_id); + auto result = filter_result_iterator->valid(included_id); if (result == -1) { break; @@ -3007,7 +3006,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - filter_result_iterator_t& 
filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector & sort_fields, @@ -3153,7 +3152,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, last_token, prev_token, filter_result_iterator, field_leaves, unique_tokens); - filter_result_iterator.reset(); + filter_result_iterator->reset(); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3184,7 +3183,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, token_candidates_vec.back().candidates[0], the_fields, num_search_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, prev_token_doc_ids, popular_field_ids); - filter_result_iterator.reset(); + filter_result_iterator->reset(); for(size_t field_id: query_field_ids) { auto& the_field = the_fields[field_id]; @@ -3207,7 +3206,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, false, "", filter_result_iterator, field_leaves, unique_tokens); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if(field_leaves.empty()) { // look at the next field @@ -3271,7 +3270,6 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, exhaustive_search, max_candidates, syn_orig_num_tokens, sort_order, field_values, geopoint_indices, query_hashes, id_buff); - filter_result_iterator.reset(); if(id_buff.size() > 1) { gfx::timsort(id_buff.begin(), id_buff.end()); @@ -3332,7 +3330,7 @@ void Index::find_across_fields(const token_t& previous_token, const 
std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, std::vector& top_prefix_field_ids) const { @@ -3343,7 +3341,7 @@ void Index::find_across_fields(const token_t& previous_token, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator); const bool prefix_search = previous_token.is_prefix_searched; const uint32_t token_num_typos = previous_token.num_typos; @@ -3424,7 +3422,7 @@ void Index::search_across_fields(const std::vector& query_tokens, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool prioritize_token_position, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, @@ -3485,7 +3483,7 @@ void Index::search_across_fields(const std::vector& query_tokens, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator); // for each token, find the posting lists across all query_by fields for(size_t ti = 0; ti < query_tokens.size(); ti++) { @@ -3947,6 +3945,7 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, 
+ filter_result_iterator_t*& filter_result_iterator, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed, const std::set& curated_ids, @@ -3954,9 +3953,10 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, - bool is_wildcard_query, - uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const { + bool is_wildcard_query) const { + uint32_t* phrase_result_ids = nullptr; + uint32_t phrase_result_count = 0; std::map phrase_match_id_scores; for(size_t i = 0; i < num_search_fields; i++) { @@ -4045,12 +4045,19 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vectoris_valid) { + filter_result_iterator_t::add_phrase_ids(filter_result_iterator, phrase_result_ids, phrase_result_count); + } else { + delete filter_result_iterator; + filter_result_iterator = new filter_result_iterator_t(phrase_result_ids, phrase_result_count); + } + size_t filter_index = 0; if(is_wildcard_query) { - all_result_ids = new uint32_t[phrase_result_count]; - std::copy(phrase_result_ids, phrase_result_ids + phrase_result_count, all_result_ids); - all_result_ids_len = phrase_result_count; + all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids); + filter_result_iterator->reset(); } else { // this means that the there are non-phrase tokens in the query // so we cannot directly copy to the all_result_ids array @@ -4058,8 +4065,8 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector(10000, phrase_result_count); i++) { - auto seq_id = phrase_result_ids[i]; + for(size_t i = 0; i < std::min(10000, all_result_ids_len); i++) { + auto seq_id = all_result_ids[i]; int64_t match_score = phrase_match_id_scores[seq_id]; int64_t scores[3] = {0}; @@ -4112,7 +4119,7 @@ void Index::do_synonym_search(const std::vector& the_fields, spp::sparse_hash_map& groups_processed, std::vector>& 
searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -4140,7 +4147,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -4174,10 +4181,10 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vectoris_valid) { uint32_t *filtered_raw_infix_ids = nullptr; - raw_infix_ids_length = filter_result_iterator.and_scalar(raw_infix_ids, raw_infix_ids_length, + raw_infix_ids_length = filter_result_iterator->and_scalar(raw_infix_ids, raw_infix_ids_length, filtered_raw_infix_ids); if(raw_infix_ids != &infix_ids[0]) { delete [] raw_infix_ids; @@ -4472,7 +4479,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, @@ -4502,11 +4509,11 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, auto parent_search_cutoff = search_cutoff; uint32_t excluded_result_index = 0; - for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { + 
for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator->is_valid; thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, + filter_result_iterator->get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, batch_result_ids); num_queued++; @@ -4588,8 +4595,8 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - filter_result_iterator.reset(); - all_result_ids_len = filter_result_iterator.to_filter_id_array(all_result_ids); + filter_result_iterator->reset(); + all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids); } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, diff --git a/test/filter_test.cpp b/test/filter_test.cpp index d3aa3f31..ac6efdb4 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -482,5 +482,19 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_EQ(6, iter_skip_test4.seq_id); ASSERT_TRUE(iter_skip_test4.is_valid); + auto iter_add_phrase_ids_test = new filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + std::unique_ptr filter_iter_guard(iter_add_phrase_ids_test); + ASSERT_TRUE(iter_add_phrase_ids_test->init_status().ok()); + + auto phrase_ids = new uint32_t[4]; + for (uint32_t i = 0; i < 4; i++) { + phrase_ids[i] = i * 2; + } + filter_result_iterator_t::add_phrase_ids(iter_add_phrase_ids_test, phrase_ids, 4); + filter_iter_guard.reset(iter_add_phrase_ids_test); + + ASSERT_TRUE(iter_add_phrase_ids_test->is_valid); + ASSERT_EQ(6, iter_add_phrase_ids_test->seq_id); + delete filter_tree_root; } \ No newline at end of file From 61357fff938253e4228b22c64c6f9e97da4d0357 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 4 
May 2023 16:32:52 +0530 Subject: [PATCH 46/64] Fix alloc-dealloc-mismatch. --- src/index.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/index.cpp b/src/index.cpp index 9090230d..46ec633e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1716,7 +1716,7 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter uint32_t* reference_docs = nullptr; uint32_t count = filter_result_iterator.to_filter_id_array(reference_docs); - std::unique_ptr docs_guard(reference_docs); + std::unique_ptr docs_guard(reference_docs); // doc id -> reference doc ids std::map> reference_map; From eb298ff9a086edb4030c5dbae6ed735aba40423b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 5 May 2023 12:34:17 +0530 Subject: [PATCH 47/64] Fix ASAN issues. --- src/collection.cpp | 2 +- test/filter_test.cpp | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index 5eab45f0..976ee70e 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -417,7 +417,7 @@ Option Collection::update_matching_filter(const std::string& fil } const auto& dirty_values = parse_dirty_values_option(req_dirty_values); - size_t docs_updated_count; + size_t docs_updated_count = 0; nlohmann::json update_document, dummy; try { diff --git a/test/filter_test.cpp b/test/filter_test.cpp index ac6efdb4..afce0209 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -423,7 +423,7 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_to_array_test.is_valid); - delete filter_ids; + delete[] filter_ids; auto iter_and_scalar_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_and_scalar_test.init_status().ok()); @@ -440,7 +440,7 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_and_scalar_test.is_valid); - delete and_result; + delete[] and_result; delete filter_tree_root; doc = R"({ @@ -491,6 +491,7 @@ TEST_F(FilterTest, 
FilterTreeIterator) { phrase_ids[i] = i * 2; } filter_result_iterator_t::add_phrase_ids(iter_add_phrase_ids_test, phrase_ids, 4); + filter_iter_guard.release(); filter_iter_guard.reset(iter_add_phrase_ids_test); ASSERT_TRUE(iter_add_phrase_ids_test->is_valid); From 77161ec3e2f73ba876e523c38c8e3bdb82fc7460 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 10:28:36 +0530 Subject: [PATCH 48/64] Fix failing join tests. --- src/index.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 46ec633e..3f87b86f 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1339,6 +1339,7 @@ void Index::search_all_candidates(const size_t num_search_fields, id_buff, all_result_ids, all_result_ids_len); query_hashes.insert(qhash); + filter_result_iterator->reset(); } } @@ -2548,7 +2549,6 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator->reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -2585,7 +2585,6 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator->reset(); } } @@ -2658,7 +2657,6 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); - filter_result_iterator->reset(); } else { break; From 
8a9bc0dfb0d375a8670d6eec09875e6186786eb1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 12:13:16 +0530 Subject: [PATCH 49/64] Remove `SEQ_IDS_FILTER` logic. --- include/index.h | 4 ---- src/filter.cpp | 3 --- src/filter_result_iterator.cpp | 12 ------------ src/index.cpp | 13 +------------ 4 files changed, 1 insertion(+), 31 deletions(-) diff --git a/include/index.h b/include/index.h index 5d825555..9ba6ea5b 100644 --- a/include/index.h +++ b/include/index.h @@ -542,10 +542,6 @@ public: // in the query that have the least individual hits one by one until enough results are found. static const int DROP_TOKENS_THRESHOLD = 1; - // "_all_" is a special field that maps to all the ids in the index. - static constexpr const char* SEQ_IDS_FIELD = "_all_"; - static constexpr const char* SEQ_IDS_FILTER = "_all_: 1"; - Index() = delete; Index(const std::string& name, diff --git a/src/filter.cpp b/src/filter.cpp index 18348ed6..95fbfefc 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -283,9 +283,6 @@ Option toFilter(const std::string expression, } } return Option(true); - } else if (field_name == Index::SEQ_IDS_FIELD) { - filter_exp = {field_name, {}, {}}; - return Option(true); } auto field_it = search_schema.find(field_name); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index c4dcd0b5..88485aa4 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -588,18 +588,6 @@ void filter_result_iterator_t::init() { filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; - is_filter_result_initialized = true; - return; - } else if (a_filter.field_name == Index::SEQ_IDS_FIELD) { - if (index->seq_ids->num_ids() == 0) { - is_valid = false; - return; - } - - approx_filter_ids_length = filter_result.count = index->seq_ids->num_ids(); - filter_result.docs = index->seq_ids->uncompress(); - seq_id 
= filter_result.docs[result_index]; is_filter_result_initialized = true; return; diff --git a/src/index.cpp b/src/index.cpp index 3f87b86f..b2a4ff39 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2366,11 +2366,7 @@ Option Index::search(std::vector& field_query_tokens, cons // if filters were not provided, use the seq_ids index to generate the list of all document ids if (no_filters_provided) { - const std::string doc_id_prefix = std::to_string(collection_id) + "_" + Collection::DOC_ID_PREFIX + "_"; - Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, - store, doc_id_prefix, filter_tree_root); - - filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root); + filter_result_iterator = new filter_result_iterator_t(seq_ids->uncompress(), seq_ids->num_ids()); filter_iterator_guard.reset(filter_result_iterator); approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; @@ -2491,13 +2487,6 @@ Option Index::search(std::vector& field_query_tokens, cons filter_result_iterator->reset(); } - // filter tree was initialized to have all sequence ids in this flow. - if (no_filters_provided) { - delete filter_tree_root; - filter_tree_root = nullptr; - approx_filter_ids_length = 0; - } - uint32_t _all_result_ids_len = all_result_ids_len; curate_filtered_ids(curated_ids, excluded_result_ids, excluded_result_ids_size, all_result_ids, _all_result_ids_len, curated_ids_sorted); From 8eea7198356c99bcf9d880908fc0165ec3d973c0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 13:38:18 +0530 Subject: [PATCH 50/64] Fix `HybridSearchRankFusionTest`. 
--- include/index.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/index.h b/include/index.h index 9ba6ea5b..c01ed672 100644 --- a/include/index.h +++ b/include/index.h @@ -240,6 +240,10 @@ public: filter_result_iterator(filter_result_iterator) {} bool operator()(hnswlib::labeltype id) override { + if (filter_result_iterator->approx_filter_ids_length == 0) { + return true; + } + filter_result_iterator->reset(); return filter_result_iterator->valid(id) == 1; } From b97f51f53ad75ec4ae02fbd91acd26283396c6d3 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 19:12:58 +0530 Subject: [PATCH 51/64] Try merging id lists using priority queue --- include/num_tree.h | 5 +++++ src/filter_result_iterator.cpp | 34 +++++++++++++++++++++++-------- src/num_tree.cpp | 37 ++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/include/num_tree.h b/include/num_tree.h index de95307b..5359d7fd 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -63,4 +63,9 @@ public: uint32_t* const& context_ids, size_t& result_ids_len, uint32_t*& result_ids) const; + + void merge_id_list_iterators(std::vector& id_list_iterators, + const NUM_COMPARATOR &comparator, + uint32_t*& result_ids, + uint32_t& result_ids_len) const; }; \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 88485aa4..482c7475 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -606,22 +606,40 @@ void filter_result_iterator_t::init() { for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); + std::vector id_list_iterators; + std::vector expanded_id_lists; - size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const 
range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); + num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + value, id_list_iterators, expanded_id_lists); } - filter_result.count = result_size; + uint32_t* filter_match_ids = nullptr; + uint32_t filter_ids_length; + num_tree->merge_id_list_iterators(id_list_iterators, a_filter.comparators[fi], + filter_match_ids, filter_ids_length); + + if (a_filter.comparators[fi] == NOT_EQUALS) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_match_ids, filter_ids_length); + } + + uint32_t *out = nullptr; + filter_result.count = ArrayUtils::or_scalar(filter_match_ids, filter_ids_length, + filter_result.docs, filter_result.count, &out); + + delete [] filter_match_ids; + delete [] filter_result.docs; + filter_result.docs = out; + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } } if (a_filter.apply_not_equals) { diff --git a/src/num_tree.cpp b/src/num_tree.cpp index 4214ced6..fc54f0fa 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -405,3 +405,40 @@ num_tree_t::~num_tree_t() { ids_t::destroy_list(kv.second); } } + +void num_tree_t::merge_id_list_iterators(std::vector& id_list_iterators, + const NUM_COMPARATOR &comparator, + uint32_t*& result_ids, + uint32_t& result_ids_len) const { + struct comp { + bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { + return 
lhs->id() > rhs->id(); + } + }; + + std::priority_queue, comp> iter_queue; + for (auto& id_list_iterator: id_list_iterators) { + if (id_list_iterator.valid()) { + iter_queue.push(&id_list_iterator); + } + } + + std::vector consolidated_ids; + while (!iter_queue.empty()) { + id_list_t::iterator_t* iter = iter_queue.top(); + iter_queue.pop(); + + consolidated_ids.push_back(iter->id()); + iter->next(); + + if (iter->valid()) { + iter_queue.push(iter); + } + } + + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + result_ids_len = consolidated_ids.size(); + result_ids = new uint32_t[consolidated_ids.size()]; + std::copy(consolidated_ids.begin(), consolidated_ids.end(), result_ids); +} From ab8ac4a5849ea45ffd906fae54aa82ddbb37beca Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 9 May 2023 15:59:36 +0530 Subject: [PATCH 52/64] Optimize string filtering. --- src/filter_result_iterator.cpp | 105 +++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 30 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 482c7475..e58c9c17 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -299,48 +299,93 @@ void filter_result_iterator_t::get_string_filter_next_match(const bool& field_is uint32_t lowest_id = UINT32_MAX; if (filter_node->filter_exp.comparators[0] == EQUALS || filter_node->filter_exp.comparators[0] == NOT_EQUALS) { - for (auto& filter_value_tokens : posting_list_iterators) { - bool tokens_iter_is_valid, exact_match = false; - while(true) { - // Perform AND between tokens of a filter value. - posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + bool exact_match_found = false; + switch (posting_list_iterators.size()) { + case 1: + while(true) { + // Perform AND between tokens of a filter value. 
+ posting_list_t::intersect(posting_list_iterators[0], one_is_valid); - if (!tokens_iter_is_valid) { - break; + if (!one_is_valid) { + break; + } + + if (posting_list_t::has_exact_match(posting_list_iterators[0], field_is_array)) { + exact_match_found = true; + break; + } else { + // Keep advancing token iterators till exact match is not found. + for (auto& iter: posting_list_iterators[0]) { + if (!iter.valid()) { + break; + } + + iter.next(); + } + } } - if (posting_list_t::has_exact_match(filter_value_tokens, field_is_array)) { - exact_match = true; - break; - } else { - // Keep advancing token iterators till exact match is not found. - for (auto &iter: filter_value_tokens) { - if (!iter.valid()) { + if (one_is_valid && exact_match_found) { + lowest_id = posting_list_iterators[0][0].id(); + } + break; + + default : + for (auto& filter_value_tokens : posting_list_iterators) { + bool tokens_iter_is_valid; + while(true) { + // Perform AND between tokens of a filter value. + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + if (!tokens_iter_is_valid) { break; } - iter.next(); + if (posting_list_t::has_exact_match(filter_value_tokens, field_is_array)) { + exact_match_found = true; + break; + } else { + // Keep advancing token iterators till exact match is not found. + for (auto &iter: filter_value_tokens) { + if (!iter.valid()) { + break; + } + + iter.next(); + } + } + } + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && exact_match_found && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); } } - } - - one_is_valid = tokens_iter_is_valid || one_is_valid; - - if (tokens_iter_is_valid && exact_match && filter_value_tokens[0].id() < lowest_id) { - lowest_id = filter_value_tokens[0].id(); - } } } else { - for (auto& filter_value_tokens : posting_list_iterators) { - // Perform AND between tokens of a filter value. 
- bool tokens_iter_is_valid; - posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + switch (posting_list_iterators.size()) { + case 1: + // Perform AND between tokens of a filter value. + posting_list_t::intersect(posting_list_iterators[0], one_is_valid); - one_is_valid = tokens_iter_is_valid || one_is_valid; + if (one_is_valid) { + lowest_id = posting_list_iterators[0][0].id(); + } + break; - if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { - lowest_id = filter_value_tokens[0].id(); - } + default: + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } } } From c075cb3307922e5b33ef8fd391279c6aeabfbb63 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 16 May 2023 18:16:43 +0530 Subject: [PATCH 53/64] Undo id list merge using priority queue. 
--- include/num_tree.h | 15 ----- src/filter_result_iterator.cpp | 36 +++-------- src/num_tree.cpp | 115 --------------------------------- 3 files changed, 9 insertions(+), 157 deletions(-) diff --git a/include/num_tree.h b/include/num_tree.h index 5359d7fd..5406a109 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -30,11 +30,6 @@ public: void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len); - void range_inclusive_search_iterators(int64_t start, - int64_t end, - std::vector& id_list_iterators, - std::vector& expanded_id_lists); - void approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len); void range_inclusive_contains(const int64_t& start, const int64_t& end, @@ -47,11 +42,6 @@ public: void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len); - void search_iterators(NUM_COMPARATOR comparator, - int64_t value, - std::vector& id_list_iterators, - std::vector& expanded_id_lists); - void approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len); void remove(uint64_t value, uint32_t id); @@ -63,9 +53,4 @@ public: uint32_t* const& context_ids, size_t& result_ids_len, uint32_t*& result_ids) const; - - void merge_id_list_iterators(std::vector& id_list_iterators, - const NUM_COMPARATOR &comparator, - uint32_t*& result_ids, - uint32_t& result_ids_len) const; }; \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index e58c9c17..99231679 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -651,40 +651,22 @@ void filter_result_iterator_t::init() { for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); - std::vector id_list_iterators; - std::vector expanded_id_lists; + size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && 
fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, value, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); } else { - num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], - value, id_list_iterators, expanded_id_lists); + num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); } - uint32_t* filter_match_ids = nullptr; - uint32_t filter_ids_length; - num_tree->merge_id_list_iterators(id_list_iterators, a_filter.comparators[fi], - filter_match_ids, filter_ids_length); - - if (a_filter.comparators[fi] == NOT_EQUALS) { - apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_match_ids, filter_ids_length); - } - - uint32_t *out = nullptr; - filter_result.count = ArrayUtils::or_scalar(filter_match_ids, filter_ids_length, - filter_result.docs, filter_result.count, &out); - - delete [] filter_match_ids; - delete [] filter_result.docs; - filter_result.docs = out; - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; - } + filter_result.count = result_size; } if (a_filter.apply_not_equals) { @@ -1400,7 +1382,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_length) : - approx_filter_ids_length(approx_filter_ids_length) { + approx_filter_ids_length(approx_filter_ids_length) { filter_node = new filter_node_t(AND, nullptr, nullptr); delete_filter_node = true; } diff --git a/src/num_tree.cpp 
b/src/num_tree.cpp index fc54f0fa..1bcdbc9f 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -43,30 +43,6 @@ void num_tree_t::range_inclusive_search(int64_t start, int64_t end, uint32_t** i *ids = out; } -void num_tree_t::range_inclusive_search_iterators(int64_t start, - int64_t end, - std::vector& id_list_iterators, - std::vector& expanded_id_lists) { - if (int64map.empty()) { - return; - } - - auto it_start = int64map.lower_bound(start); // iter values will be >= start - - std::vector raw_id_lists; - while (it_start != int64map.end() && it_start->first <= end) { - raw_id_lists.push_back(it_start->second); - it_start++; - } - - std::vector id_lists; - ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); - - for (const auto &id_list: id_lists) { - id_list_iterators.emplace_back(id_list->new_iterator()); - } -} - void num_tree_t::approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len) { if (int64map.empty()) { return; @@ -211,60 +187,6 @@ void num_tree_t::search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids } } -void num_tree_t::search_iterators(NUM_COMPARATOR comparator, - int64_t value, - std::vector& id_list_iterators, - std::vector& expanded_id_lists) { - if (int64map.empty()) { - return ; - } - - std::vector raw_id_lists; - if (comparator == EQUALS) { - const auto& it = int64map.find(value); - if (it != int64map.end()) { - raw_id_lists.emplace_back(it->second); - } - } else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) { - // iter entries will be >= value, or end() if all entries are before value - auto iter_ge_value = int64map.lower_bound(value); - - if(iter_ge_value == int64map.end()) { - return ; - } - - if(comparator == GREATER_THAN && iter_ge_value->first == value) { - iter_ge_value++; - } - - while(iter_ge_value != int64map.end()) { - raw_id_lists.emplace_back(iter_ge_value->second); - iter_ge_value++; - } - } else if(comparator == LESS_THAN || comparator == 
LESS_THAN_EQUALS) { - // iter entries will be >= value, or end() if all entries are before value - auto iter_ge_value = int64map.lower_bound(value); - - auto it = int64map.begin(); - while(it != iter_ge_value) { - raw_id_lists.emplace_back(it->second); - it++; - } - - // for LESS_THAN_EQUALS, check if last iter entry is equal to value - if(it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) { - raw_id_lists.emplace_back(it->second); - } - } - - std::vector id_lists; - ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); - - for (const auto &id_list: id_lists) { - id_list_iterators.emplace_back(id_list->new_iterator()); - } -} - void num_tree_t::approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len) { if (int64map.empty()) { return; @@ -405,40 +327,3 @@ num_tree_t::~num_tree_t() { ids_t::destroy_list(kv.second); } } - -void num_tree_t::merge_id_list_iterators(std::vector& id_list_iterators, - const NUM_COMPARATOR &comparator, - uint32_t*& result_ids, - uint32_t& result_ids_len) const { - struct comp { - bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { - return lhs->id() > rhs->id(); - } - }; - - std::priority_queue, comp> iter_queue; - for (auto& id_list_iterator: id_list_iterators) { - if (id_list_iterator.valid()) { - iter_queue.push(&id_list_iterator); - } - } - - std::vector consolidated_ids; - while (!iter_queue.empty()) { - id_list_t::iterator_t* iter = iter_queue.top(); - iter_queue.pop(); - - consolidated_ids.push_back(iter->id()); - iter->next(); - - if (iter->valid()) { - iter_queue.push(iter); - } - } - - consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); - - result_ids_len = consolidated_ids.size(); - result_ids = new uint32_t[consolidated_ids.size()]; - std::copy(consolidated_ids.begin(), consolidated_ids.end(), result_ids); -} From b5209e6d0c436c451df410bc858f311fd48507d7 Mon Sep 17 
00:00:00 2001 From: Harpreet Sangar Date: Tue, 16 May 2023 18:18:29 +0530 Subject: [PATCH 54/64] Optimize `filter_result_iterator_t::and_filter_iterators`. --- src/filter_result_iterator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 99231679..a55a01de 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -180,7 +180,7 @@ void filter_result_t::or_filter_results(const filter_result_t& a, const filter_r void filter_result_iterator_t::and_filter_iterators() { while (left_it->is_valid && right_it->is_valid) { while (left_it->seq_id < right_it->seq_id) { - left_it->next(); + left_it->skip_to(right_it->seq_id); if (!left_it->is_valid) { is_valid = false; return; @@ -188,7 +188,7 @@ void filter_result_iterator_t::and_filter_iterators() { } while (left_it->seq_id > right_it->seq_id) { - right_it->next(); + right_it->skip_to(left_it->seq_id); if (!right_it->is_valid) { is_valid = false; return; From 460ed6730adb0755277041adea35161d5ffc2ccd Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 17 May 2023 14:44:27 +0530 Subject: [PATCH 55/64] Fix `filter_result_iterator_t::valid` not updating `seq_id` properly for complex filter expressions. 
--- src/filter_result_iterator.cpp | 23 +++++++++++++++-------- src/or_iterator.cpp | 6 +----- test/filter_test.cpp | 2 +- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index a55a01de..f0145b72 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1047,14 +1047,6 @@ int filter_result_iterator_t::valid(uint32_t id) { if (filter_node->isOperator) { auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); - if (left_it->is_valid && right_it->is_valid) { - seq_id = std::min(left_it->seq_id, right_it->seq_id); - } else if (left_it->is_valid) { - seq_id = left_it->seq_id; - } else if (right_it->is_valid) { - seq_id = right_it->seq_id; - } - if (filter_node->filter_operator == AND) { is_valid = left_it->is_valid && right_it->is_valid; @@ -1063,9 +1055,20 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } + // id did not match the filter but both of the sub-iterators are still valid. + // Updating seq_id to the next potential match. + if (left_valid == 0 && right_valid == 0) { + seq_id = std::max(left_it->seq_id, right_it->seq_id); + } else if (left_valid == 0) { + seq_id = left_it->seq_id; + } else if (right_valid == 0) { + seq_id = right_it->seq_id; + } + return 0; } + seq_id = id; return 1; } else { is_valid = left_it->is_valid || right_it->is_valid; @@ -1075,9 +1078,13 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } + // id did not match the filter but both of the sub-iterators are still valid. + // Next seq_id match would be the minimum of the two. 
+ seq_id = std::min(left_it->seq_id, right_it->seq_id); return 0; } + seq_id = id; return 1; } } diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index 6384cba2..5a88d68e 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -209,11 +209,7 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e } if (istate.fit != nullptr && istate.fit->approx_filter_ids_length > 0) { - if (istate.fit->valid(id) == -1) { - return false; - } - - if (istate.fit->seq_id == id) { + if (istate.fit->valid(id) == 1) { istate.fit->next(); return true; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index afce0209..aa10b97d 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -284,7 +284,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); - std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 3, 4, 5, 5}; + std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 4, 4, 5, 5}; expected = {1, 0, 1, 0, 1, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); From 3fe81777373f81dc5ff6ba4eb55b9116b0a20787 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 18 May 2023 07:50:37 +0530 Subject: [PATCH 56/64] Add tests for `filter_result_iterator_t::valid`. 
--- src/filter_result_iterator.cpp | 15 +++++++++---- test/filter_test.cpp | 40 ++++++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f0145b72..707f2f69 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1061,7 +1061,7 @@ int filter_result_iterator_t::valid(uint32_t id) { seq_id = std::max(left_it->seq_id, right_it->seq_id); } else if (left_valid == 0) { seq_id = left_it->seq_id; - } else if (right_valid == 0) { + } else { seq_id = right_it->seq_id; } @@ -1078,9 +1078,16 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } - // id did not match the filter but both of the sub-iterators are still valid. - // Next seq_id match would be the minimum of the two. - seq_id = std::min(left_it->seq_id, right_it->seq_id); + // id did not match the filter; both of the sub-iterators or one of them might be valid. + // Updating seq_id to the next match. 
+ if (left_valid == 0 && right_valid == 0) { + seq_id = std::min(left_it->seq_id, right_it->seq_id); + } else if (left_valid == 0) { + seq_id = left_it->seq_id; + } else { + seq_id = right_it->seq_id; + } + return 0; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index aa10b97d..e1a6aaf7 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -281,14 +281,46 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); - ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); + auto iter_validate_ids_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test1.init_status().ok()); std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 4, 4, 5, 5}; expected = {1, 0, 1, 0, 1, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { - ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); - ASSERT_EQ(seq_ids[i], iter_validate_ids_test.seq_id); + ASSERT_EQ(expected[i], iter_validate_ids_test1.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test1.seq_id); + } + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: platinum || name: James", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test2.init_status().ok()); + + validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {1, 1, 5, 5, 5, 5, 5}; + expected = {0, 1, 0, 0, 0, 1, -1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_test2.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test2.seq_id); + } + + delete filter_tree_root; + 
filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold && rating: < 6", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_test3 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test3.init_status().ok()); + + validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 3, 3, 4, 4, 4, 4}; + expected = {1, 0, 0, 0, 1, -1, -1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_test3.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test3.seq_id); } delete filter_tree_root; From 48cf58f162744297c084e8154bebba587c1b345d Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 19 May 2023 18:52:20 +0530 Subject: [PATCH 57/64] Add `filter::parse_geopoint_filter_value`. --- include/filter.h | 14 ++++ src/filter.cpp | 151 ++++++++++++++++++++++++++++++++++++++++++- src/string_utils.cpp | 38 +++++++++++ 3 files changed, 202 insertions(+), 1 deletion(-) diff --git a/include/filter.h b/include/filter.h index 1961b0ac..0efd5900 100644 --- a/include/filter.h +++ b/include/filter.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "store.h" enum FILTER_OPERATOR { @@ -26,6 +27,15 @@ struct filter { // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` std::string referenced_collection_name = ""; + std::vector params; + + /// For searching places within a given radius of a given latlong (mi for miles and km for kilometers) + static constexpr const char* GEO_FILTER_RADIUS = "radius"; + + /// Radius threshold beyond which exact filtering on geo_result_ids will not be done. 
+ static constexpr const char* EXACT_GEO_FILTER_RADIUS = "exact_filter_radius"; + static constexpr const char* DEFAULT_EXACT_GEO_FILTER_RADIUS = "10km"; + static const std::string RANGE_OPERATOR() { return ".."; } @@ -39,6 +49,10 @@ struct filter { std::string& processed_filter_val, NUM_COMPARATOR& num_comparator); + static Option parse_geopoint_filter_value(std::string& raw_value, + const std::string& format_err_msg, + filter& filter_exp); + static Option parse_filter_query(const std::string& filter_query, const tsl::htrie_map& search_schema, const Store* store, diff --git a/src/filter.cpp b/src/filter.cpp index 95fbfefc..c7cc73ca 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -143,6 +143,142 @@ Option filter::parse_geopoint_filter_value(std::string& raw_value, return Option(true); } +Option filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) { + // FORMAT: + // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ] + + // Every open parenthesis represent a geo filter value. + auto open_parenthesis_count = std::count(raw_value.begin(), raw_value.end(), '('); + if (open_parenthesis_count < 1) { + return Option(400, format_err_msg); + } + + filter_exp.comparators.push_back(LESS_THAN_EQUALS); + bool is_multivalued = raw_value[0] == '['; + size_t i = is_multivalued; + + // Adding polygonal values at last since they don't have any parameters associated with them. 
+ std::vector polygons; + for (auto j = 0; j < open_parenthesis_count; j++) { + if (is_multivalued) { + auto pos = raw_value.find('(', i); + if (pos == std::string::npos) { + return Option(400, format_err_msg); + } + i = pos; + } + + i++; + if (i >= raw_value.size()) { + return Option(400, format_err_msg); + } + + auto value_end_index = raw_value.find(')', i); + if (value_end_index == std::string::npos) { + return Option(400, format_err_msg); + } + + // [48.853, 2.344], radius: 1km, exact_filter_radius: 100km + // [48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469] + std::string value_str = raw_value.substr(i, value_end_index - i); + StringUtils::trim(value_str); + + if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) { + return Option(400, format_err_msg); + } + + auto points_str = value_str.substr(1, value_str.find(']', 1) - 1); + std::vector geo_points; + StringUtils::split(points_str, geo_points, ","); + + for (const auto& geo_point: geo_points) { + if(!StringUtils::is_float(geo_point)) { + return Option(400, format_err_msg); + } + } + + bool is_polygon = value_str.back() == ']'; + if (is_polygon) { + polygons.push_back(points_str); + continue; + } + + // Handle options. 
+ // , radius: 1km, exact_filter_radius: 100km + i = raw_value.find(']', i); + i++; + + std::vector options; + StringUtils::split(raw_value.substr(i, value_end_index - i), options, ","); + + if (options.empty()) { + // Missing radius option + return Option(400, format_err_msg); + } + + bool is_radius_present = false; + for (auto const& option: options) { + if (option.empty()) { + continue; + } + + std::vector key_value; + StringUtils::split(option, key_value, ":"); + + if (key_value.size() < 2) { + continue; + } + + if (key_value[0] == GEO_FILTER_RADIUS) { + is_radius_present = true; + auto& value = key_value[1]; + + if(value.size() < 2) { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::string unit = value.substr(value.size() - 2, 2); + + if(unit != "km" && unit != "mi") { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::vector dist_values; + StringUtils::split(value, dist_values, unit); + + if(dist_values.size() != 1) { + return Option(400, format_err_msg); + } + + if(!StringUtils::is_float(dist_values[0])) { + return Option(400, format_err_msg); + } + + filter_exp.values.push_back(points_str + ", " + dist_values[0] + ", " + unit); + } else if (key_value[0] == EXACT_GEO_FILTER_RADIUS) { + nlohmann::json param; + param[EXACT_GEO_FILTER_RADIUS] = key_value[1]; + filter_exp.params.push_back(param); + } + } + + if (!is_radius_present) { + return Option(400, format_err_msg); + } + if (filter_exp.params.empty()) { + nlohmann::json param; + param[EXACT_GEO_FILTER_RADIUS] = DEFAULT_EXACT_GEO_FILTER_RADIUS; + filter_exp.params.push_back(param); + } + } + + for (auto const& polygon: polygons) { + filter_exp.values.push_back(polygon); + } + + return Option(true); +} + bool isOperator(const std::string& expression) { return expression == "&&" || expression == "||"; } @@ -383,10 +519,23 @@ Option toFilter(const std::string expression, } } else if (_field.is_geopoint()) { filter_exp = {field_name, {}, {}}; + NUM_COMPARATOR 
num_comparator; + + if ((raw_value[0] == '(' && std::count(raw_value.begin(), raw_value.end(), '[') > 0) || + std::count(raw_value.begin(), raw_value.end(), '[') > 1 || + std::count(raw_value.begin(), raw_value.end(), ':') > 0) { + + const std::string& format_err_msg = "Value of filter field `" + _field.name + "`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format."; + + auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, filter_exp); + return parse_op; + } + const std::string& format_err_msg = "Value of filter field `" + _field.name + "`: must be in the `(-44.50, 170.29, 0.75 km)` or " "(56.33, -65.97, 23.82, -127.82) format."; - NUM_COMPARATOR num_comparator; // could be a single value or a list if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { std::vector filter_values; diff --git a/src/string_utils.cpp b/src/string_utils.cpp index b2caf445..f205d774 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -349,6 +349,35 @@ size_t StringUtils::get_num_chars(const std::string& s) { return j; } +Option parse_multi_valued_geopoint_filter(const std::string& filter_query, std::string& tokens, size_t& index) { + // Multi-valued geopoint filter. + // field_name:[ ([points], options), ([points]) ] + auto error = Option(400, "Could not parse the geopoint filter."); + if (filter_query[index] != '[') { + return error; + } + + size_t start_index = index; + auto size = filter_query.size(); + + // Individual geopoint filters have square brackets inside them. 
+ int square_bracket_count = 1; + while (++index < size && square_bracket_count > 0) { + if (filter_query[index] == '[') { + square_bracket_count++; + } else if (filter_query[index] == ']') { + square_bracket_count--; + } + } + + if (square_bracket_count != 0) { + return error; + } + + tokens = filter_query.substr(start_index, index - start_index); + return Option(true); +} + Option parse_reference_filter(const std::string& filter_query, std::queue& tokens, size_t& index) { auto error = Option(400, "Could not parse the reference filter."); if (filter_query[index] != '$') { @@ -440,6 +469,15 @@ Option StringUtils::tokenize_filter_query(const std::string& filter_query, if (preceding_colon && c == '(') { is_geo_value = true; preceding_colon = false; + } else if (preceding_colon && c == '[') { + std::string value; + auto op = parse_multi_valued_geopoint_filter(filter_query, value, i); + if (!op.ok()) { + return op; + } + + ss << value; + break; } else if (preceding_colon && c != ' ') { preceding_colon = false; } From 6e5abe1b9ac908e2bd9c858d966843b113c9fee0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 22 May 2023 12:16:00 +0530 Subject: [PATCH 58/64] Add logic to skip exact geo filtering beyond threshold. 
--- include/filter.h | 8 +-- src/filter.cpp | 82 +++++++++++++++------- src/filter_result_iterator.cpp | 19 ++++- test/collection_filtering_test.cpp | 107 ++++++++++++++++++++++++++++- 4 files changed, 186 insertions(+), 30 deletions(-) diff --git a/include/filter.h b/include/filter.h index 0efd5900..f9f15c1d 100644 --- a/include/filter.h +++ b/include/filter.h @@ -25,16 +25,16 @@ struct filter { bool apply_not_equals = false; // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` - std::string referenced_collection_name = ""; + std::string referenced_collection_name; std::vector params; /// For searching places within a given radius of a given latlong (mi for miles and km for kilometers) - static constexpr const char* GEO_FILTER_RADIUS = "radius"; + static constexpr const char* GEO_FILTER_RADIUS_KEY = "radius"; /// Radius threshold beyond which exact filtering on geo_result_ids will not be done. - static constexpr const char* EXACT_GEO_FILTER_RADIUS = "exact_filter_radius"; - static constexpr const char* DEFAULT_EXACT_GEO_FILTER_RADIUS = "10km"; + static constexpr const char* EXACT_GEO_FILTER_RADIUS_KEY = "exact_filter_radius"; + static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000; static const std::string RANGE_OPERATOR() { return ".."; diff --git a/src/filter.cpp b/src/filter.cpp index c7cc73ca..3b13b400 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -143,6 +143,33 @@ Option filter::parse_geopoint_filter_value(std::string& raw_value, return Option(true); } +Option validate_geofilter_distance(std::string& raw_value, const string& format_err_msg, + std::string& distance, std::string& unit) { + if (raw_value.size() < 2) { + return Option(400, "Unit must be either `km` or `mi`."); + } + + unit = raw_value.substr(raw_value.size() - 2, 2); + + if (unit != "km" && unit != "mi") { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::vector dist_values; + StringUtils::split(raw_value, dist_values, unit); + + 
if (dist_values.size() != 1) { + return Option(400, format_err_msg); + } + + if (!StringUtils::is_float(dist_values[0])) { + return Option(400, format_err_msg); + } + + distance = std::string(dist_values[0]); + return Option(true); +} + Option filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) { // FORMAT: // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ] @@ -185,19 +212,27 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) { return Option(400, format_err_msg); + } else { + std::vector filter_values; + StringUtils::split(value_str, filter_values, ","); + + if(filter_values.size() < 3) { + return Option(400, format_err_msg); + } } auto points_str = value_str.substr(1, value_str.find(']', 1) - 1); std::vector geo_points; StringUtils::split(points_str, geo_points, ","); + bool is_polygon = value_str.back() == ']'; for (const auto& geo_point: geo_points) { - if(!StringUtils::is_float(geo_point)) { + if (!StringUtils::is_float(geo_point) || + (!is_polygon && (geo_point == "nan" || geo_point == "NaN"))) { return Option(400, format_err_msg); } } - bool is_polygon = value_str.back() == ']'; if (is_polygon) { polygons.push_back(points_str); continue; @@ -229,35 +264,34 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string continue; } - if (key_value[0] == GEO_FILTER_RADIUS) { + if (key_value[0] == GEO_FILTER_RADIUS_KEY) { is_radius_present = true; - auto& value = key_value[1]; - if(value.size() < 2) { - return Option(400, "Unit must be either `km` or `mi`."); + std::string distance, unit; + auto validate_op = validate_geofilter_distance(key_value[1], format_err_msg, distance, unit); + if (!validate_op.ok()) { + return validate_op; } - std::string unit = value.substr(value.size() - 2, 2); - - 
if(unit != "km" && unit != "mi") { - return Option(400, "Unit must be either `km` or `mi`."); + filter_exp.values.push_back(points_str + ", " + distance + ", " + unit); + } else if (key_value[0] == EXACT_GEO_FILTER_RADIUS_KEY) { + std::string distance, unit; + auto validate_op = validate_geofilter_distance(key_value[1], format_err_msg, distance, unit); + if (!validate_op.ok()) { + return validate_op; } - std::vector dist_values; - StringUtils::split(value, dist_values, unit); + double exact_under_radius = std::stof(distance); - if(dist_values.size() != 1) { - return Option(400, format_err_msg); + if (unit == "km") { + exact_under_radius *= 1000; + } else { + // assume "mi" (validated upstream) + exact_under_radius *= 1609.34; } - if(!StringUtils::is_float(dist_values[0])) { - return Option(400, format_err_msg); - } - - filter_exp.values.push_back(points_str + ", " + dist_values[0] + ", " + unit); - } else if (key_value[0] == EXACT_GEO_FILTER_RADIUS) { nlohmann::json param; - param[EXACT_GEO_FILTER_RADIUS] = key_value[1]; + param[EXACT_GEO_FILTER_RADIUS_KEY] = exact_under_radius; filter_exp.params.push_back(param); } } @@ -265,9 +299,11 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string if (!is_radius_present) { return Option(400, format_err_msg); } - if (filter_exp.params.empty()) { + + // EXACT_GEO_FILTER_RADIUS_KEY was not present. 
+ if (filter_exp.params.size() < filter_exp.values.size()) { nlohmann::json param; - param[EXACT_GEO_FILTER_RADIUS] = DEFAULT_EXACT_GEO_FILTER_RADIUS; + param[EXACT_GEO_FILTER_RADIUS_KEY] = DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE; filter_exp.params.push_back(param); } } diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 707f2f69..115c22c2 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -754,7 +754,9 @@ void filter_result_iterator_t::init() { is_filter_result_initialized = true; return; } else if (f.is_geopoint()) { - for (const std::string& filter_value : a_filter.values) { + for (uint32_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + std::vector geo_result_ids; std::vector filter_value_parts; @@ -763,6 +765,7 @@ void filter_result_iterator_t::init() { bool is_polygon = StringUtils::is_float(filter_value_parts.back()); S2Region* query_region; + double radius; if (is_polygon) { const int num_verts = int(filter_value_parts.size()) / 2; std::vector vertices; @@ -788,7 +791,7 @@ void filter_result_iterator_t::init() { query_region = loop; } } else { - double radius = std::stof(filter_value_parts[2]); + radius = std::stof(filter_value_parts[2]); const auto& unit = filter_value_parts[3]; if (unit == "km") { @@ -820,6 +823,18 @@ void filter_result_iterator_t::init() { gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); + // Skip exact filtering step if query radius is greater than the threshold. 
+ if (!is_polygon && fi < a_filter.params.size() && + radius > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get()) { + uint32_t* out = nullptr; + filter_result.count = ArrayUtils::or_scalar(geo_result_ids.data(), geo_result_ids.size(), + filter_result.docs, filter_result.count, &out); + + delete[] filter_result.docs; + filter_result.docs = out; + continue; + } + // `geo_result_ids` will contain all IDs that are within approximately within query radius // we still need to do another round of exact filtering on them diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index c644d547..b8439b0f 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1049,7 +1049,7 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) { collectionManager.drop_collection("coll_array_fields"); } -TEST_F(CollectionFilteringTest, GeoPointFiltering) { +TEST_F(CollectionFilteringTest, GeoPointFilteringV1) { Collection *coll1; std::vector fields = {field("title", field_types::STRING, false), @@ -1192,6 +1192,111 @@ TEST_F(CollectionFilteringTest, GeoPointFiltering) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionFilteringTest, GeoPointFilteringV2) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", 
"48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(10, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); + 
ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + + // when geo query had NaN + auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + // when geo query does not send radius key + gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) { Collection *coll1; From b25c072c47b7230a94ca6bdfc3d8ba4331a16d4f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 22 May 2023 15:32:34 +0530 Subject: [PATCH 59/64] Add `geo_filtering_test.cpp`. 
--- src/filter_result_iterator.cpp | 3 +- test/collection_filtering_test.cpp | 107 +------ test/geo_filtering_test.cpp | 474 +++++++++++++++++++++++++++++ 3 files changed, 476 insertions(+), 108 deletions(-) create mode 100644 test/geo_filtering_test.cpp diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 115c22c2..d9b56683 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -807,6 +807,7 @@ void filter_result_iterator_t::init() { S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); query_region = new S2Cap(center, query_radius); } + std::unique_ptr query_region_guard(query_region); S2RegionTermIndexer::Options options; options.set_index_contains_points_only(true); @@ -883,8 +884,6 @@ void filter_result_iterator_t::init() { delete[] filter_result.docs; filter_result.docs = out; - - delete query_region; } if (filter_result.count == 0) { diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index b8439b0f..c644d547 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1049,7 +1049,7 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) { collectionManager.drop_collection("coll_array_fields"); } -TEST_F(CollectionFilteringTest, GeoPointFilteringV1) { +TEST_F(CollectionFilteringTest, GeoPointFiltering) { Collection *coll1; std::vector fields = {field("title", field_types::STRING, false), @@ -1192,111 +1192,6 @@ TEST_F(CollectionFilteringTest, GeoPointFilteringV1) { collectionManager.drop_collection("coll1"); } -TEST_F(CollectionFilteringTest, GeoPointFilteringV2) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, 
"points").get(); - } - - std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - - results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - - // pick location close to none of the spots - results = coll1->search("*", - {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(0, results["found"].get()); - - // pick a large radius covering all points - - results = coll1->search("*", - {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - 
ASSERT_EQ(10, results["found"].get()); - - // 1 mile radius - - results = coll1->search("*", - {}, "loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - - ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); - - // when geo query had NaN - auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY); - - ASSERT_FALSE(gop.ok()); - ASSERT_EQ("Value of filter field `loc`: must be in the " - "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); - - // when geo query does not send radius key - gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", - {}, {}, {0}, 10, 1, FREQUENCY); - - ASSERT_FALSE(gop.ok()); - ASSERT_EQ("Value of filter field `loc`: must be in the " - "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); - - collectionManager.drop_collection("coll1"); -} - TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) { Collection *coll1; diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp new file mode 100644 index 00000000..9db67c0b --- /dev/null +++ b/test/geo_filtering_test.cpp @@ -0,0 +1,474 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class GeoFilteringTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_filtering"; + LOG(INFO) << "Truncating and creating: " << 
state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(GeoFilteringTest, GeoPointFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + 
ASSERT_EQ(1, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(10, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + + // when geo query had NaN + auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + // when geo query does not send radius key + gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + collectionManager.drop_collection("coll1"); +} + 
+TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector>> records = { + { {"Alpha Inc", "Ennore", "13.22112, 80.30511"}, + {"Alpha Inc", "Velachery", "12.98973, 80.23095"} + }, + + { + {"Veera Inc", "Thiruvallur", "13.12752, 79.90136"}, + }, + + { + {"B1 Inc", "Bengaluru", "12.98246, 77.5847"}, + {"B1 Inc", "Hosur", "12.74147, 77.82915"}, + {"B1 Inc", "Vellore", "12.91866, 79.13075"}, + }, + + { + {"M Inc", "Nashik", "20.11282, 73.79458"}, + {"M Inc", "Pune", "18.56309, 73.855"}, + } + }; + + for(size_t i=0; i> lat_lngs; + for(size_t k = 0; k < records[i].size(); k++) { + std::vector lat_lng_str; + StringUtils::split(records[i][k][2], lat_lng_str, ", "); + + std::vector lat_lng = { + std::stod(lat_lng_str[0]), + std::stod(lat_lng_str[1]) + }; + + lat_lngs.push_back(lat_lng); + } + + doc["loc"] = lat_lngs; + auto add_op = coll1->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + } + + // pick a location close to Chennai + auto results = coll1->search("*", + {}, "loc: ([13.12631, 80.20252], radius: 100km, exact_filter_radius: 100km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + // Default value of exact_filter_radius is 10km, exact filtering is not performed. 
+ results = coll1->search("*", + {}, "loc: ([13.12631, 80.20252], radius: 100km,)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + ASSERT_STREQ("2", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: ([13.62601, 79.39559], radius: 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: ([21.20714729927276, 78.99153966917213], radius: 1000 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(4, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: ([12.98941, 80.23073], radius: 1mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringTest, GeoPointRemoval) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc1", field_types::GEOPOINT, false), + field("loc2", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "Palais Garnier"; + doc["loc1"] = {48.872576479306765, 2.332291112241466}; + doc["loc2"] = nlohmann::json::array(); + doc["loc2"][0] = {48.84620987789056, 2.345152755563131}; + doc["points"] = 100; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("*", + {}, "loc1: ([48.87491151802846, 2.343945883701618], radius: 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + 
+ ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: ([48.87491151802846, 2.343945883701618], radius: 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + // remove the document, index another document and try querying again + coll1->remove("0"); + doc["id"] = "1"; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + results = coll1->search("*", + {}, "loc1: ([48.87491151802846, 2.343945883701618], radius: 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: ([48.87491151802846, 2.343945883701618], radius: 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} + +TEST_F(GeoFilteringTest, GeoPolygonFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + 
for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.875223042424125,2.323509661928681, " + "48.85745408145392, 2.3267084486160856, " + "48.859636574404355,2.351469427048221, " + "48.87756059389807, 2.3443610121873206])", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + + // should work even if points of polygon are clockwise + + results = coll1->search("*", + {}, "loc: ([48.87756059389807, 2.3443610121873206, " + "48.859636574404355,2.351469427048221, " + "48.85745408145392, 2.3267084486160856, " + "48.875223042424125,2.323509661928681])", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringTest, GeoPolygonFilteringSouthAmerica) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"North of Equator", "4.48615, -71.38049"}, + {"South of Equator", "-8.48587, -71.02892"}, + }; + + for(size_t i=0; i lat_lng; + 
StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a polygon that covers both points + + auto results = coll1->search("*", + {}, "loc: ([13.3163, -82.3585, " + "-29.134, -82.3585, " + "-29.134, -59.8528, " + "13.3163, -59.8528])", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringTest, GeoPointFilteringWithNonSortableLocationField) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string", "sort": false}, + {"name": "loc", "type": "geopoint", "sort": false}, + {"name": "points", "type": "int32", "sort": false} + ] + })"_json; + + auto coll_op = collectionManager.create_collection(schema); + ASSERT_TRUE(coll_op.ok()); + Collection* coll1 = coll_op.get(); + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.90615915923891, 2.3435897727061175], radius:3 km)", + {}, {}, {0}, 10, 1, 
FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} From 9f5a97459b31b411606bfbd912dd3985baf95ee1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 10:09:28 +0530 Subject: [PATCH 60/64] Skip exact filtering beyond threshold in case of geo polygon. --- src/filter.cpp | 61 +++++++++++++++++----------------- src/filter_result_iterator.cpp | 18 +++++----- test/geo_filtering_test.cpp | 44 +++++++++++++++--------- 3 files changed, 69 insertions(+), 54 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index 3b13b400..5946e5d1 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -172,7 +172,8 @@ Option validate_geofilter_distance(std::string& raw_value, const string& f Option filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) { // FORMAT: - // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ] + // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), + // ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469], exact_filter_radius: 100km) ] // Every open parenthesis represent a geo filter value. auto open_parenthesis_count = std::count(raw_value.begin(), raw_value.end(), '('); @@ -181,11 +182,9 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } filter_exp.comparators.push_back(LESS_THAN_EQUALS); - bool is_multivalued = raw_value[0] == '['; + bool is_multivalued = open_parenthesis_count > 1; size_t i = is_multivalued; - // Adding polygonal values at last since they don't have any parameters associated with them. 
- std::vector polygons; for (auto j = 0; j < open_parenthesis_count; j++) { if (is_multivalued) { auto pos = raw_value.find('(', i); @@ -206,57 +205,55 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } // [48.853, 2.344], radius: 1km, exact_filter_radius: 100km - // [48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469] + // [48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469], exact_filter_radius: 100km std::string value_str = raw_value.substr(i, value_end_index - i); StringUtils::trim(value_str); if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) { return Option(400, format_err_msg); - } else { - std::vector filter_values; - StringUtils::split(value_str, filter_values, ","); - - if(filter_values.size() < 3) { - return Option(400, format_err_msg); - } } auto points_str = value_str.substr(1, value_str.find(']', 1) - 1); std::vector geo_points; StringUtils::split(points_str, geo_points, ","); - bool is_polygon = value_str.back() == ']'; + if (geo_points.size() < 2 || geo_points.size() % 2) { + return Option(400, format_err_msg); + } + + bool is_polygon = geo_points.size() > 2; for (const auto& geo_point: geo_points) { - if (!StringUtils::is_float(geo_point) || - (!is_polygon && (geo_point == "nan" || geo_point == "NaN"))) { + if (geo_point == "nan" || geo_point == "NaN" || !StringUtils::is_float(geo_point)) { return Option(400, format_err_msg); } } if (is_polygon) { - polygons.push_back(points_str); - continue; + filter_exp.values.push_back(points_str); } // Handle options. 
// , radius: 1km, exact_filter_radius: 100km - i = raw_value.find(']', i); - i++; + i = raw_value.find(']', i) + 1; std::vector options; StringUtils::split(raw_value.substr(i, value_end_index - i), options, ","); if (options.empty()) { - // Missing radius option - return Option(400, format_err_msg); + if (!is_polygon) { + // Missing radius option + return Option(400, format_err_msg); + } + + nlohmann::json param; + param[EXACT_GEO_FILTER_RADIUS_KEY] = DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE; + filter_exp.params.push_back(param); + + continue; } bool is_radius_present = false; for (auto const& option: options) { - if (option.empty()) { - continue; - } - std::vector key_value; StringUtils::split(option, key_value, ":"); @@ -264,7 +261,7 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string continue; } - if (key_value[0] == GEO_FILTER_RADIUS_KEY) { + if (key_value[0] == GEO_FILTER_RADIUS_KEY && !is_polygon) { is_radius_present = true; std::string distance, unit; @@ -293,10 +290,16 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string nlohmann::json param; param[EXACT_GEO_FILTER_RADIUS_KEY] = exact_under_radius; filter_exp.params.push_back(param); + + // Only EXACT_GEO_FILTER_RADIUS_KEY option would be present for a polygon. We can also stop if we've + // parsed the radius in case of a single geopoint since there are only two options. 
+ if (is_polygon || is_radius_present) { + break; + } } } - if (!is_radius_present) { + if (!is_radius_present && !is_polygon) { return Option(400, format_err_msg); } @@ -308,10 +311,6 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } } - for (auto const& polygon: polygons) { - filter_exp.values.push_back(polygon); - } - return Option(true); } diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d9b56683..a6b61134 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -765,7 +765,7 @@ void filter_result_iterator_t::init() { bool is_polygon = StringUtils::is_float(filter_value_parts.back()); S2Region* query_region; - double radius; + double query_radius_meters; if (is_polygon) { const int num_verts = int(filter_value_parts.size()) / 2; std::vector vertices; @@ -790,22 +790,24 @@ void filter_result_iterator_t::init() { } else { query_region = loop; } + + query_radius_meters = S2Earth::RadiansToMeters(query_region->GetCapBound().GetRadius().radians()); } else { - radius = std::stof(filter_value_parts[2]); + query_radius_meters = std::stof(filter_value_parts[2]); const auto& unit = filter_value_parts[3]; if (unit == "km") { - radius *= 1000; + query_radius_meters *= 1000; } else { // assume "mi" (validated upstream) - radius *= 1609.34; + query_radius_meters *= 1609.34; } - S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); + S1Angle query_radius_radians = S1Angle::Radians(S2Earth::MetersToRadians(query_radius_meters)); double query_lat = std::stod(filter_value_parts[0]); double query_lng = std::stod(filter_value_parts[1]); S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); - query_region = new S2Cap(center, query_radius); + query_region = new S2Cap(center, query_radius_radians); } std::unique_ptr query_region_guard(query_region); @@ -825,8 +827,8 @@ void filter_result_iterator_t::init() { geo_result_ids.erase(std::unique( 
geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); // Skip exact filtering step if query radius is greater than the threshold. - if (!is_polygon && fi < a_filter.params.size() && - radius > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get()) { + if (fi < a_filter.params.size() && + query_radius_meters > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get()) { uint32_t* out = nullptr; filter_result.count = ArrayUtils::or_scalar(geo_result_ids.data(), geo_result_ids.size(), filter_result.docs, filter_result.count, &out); diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index 9db67c0b..ab200262 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -85,7 +85,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_EQ(1, results["found"].get()); ASSERT_EQ(1, results["hits"].size()); - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", {}, {}, {0}, 10, 1, FREQUENCY).get(); @@ -115,9 +115,9 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_EQ(3, results["found"].get()); - ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_EQ("6", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("5", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("3", results["hits"][2]["document"]["id"].get()); // when geo query had NaN auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", @@ -206,8 +206,8 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ(2, results["found"].get()); ASSERT_EQ(2, results["hits"].size()); - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - 
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); // Default value of exact_filter_radius is 10km, exact filtering is not performed. results = coll1->search("*", @@ -217,9 +217,9 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); - ASSERT_STREQ("2", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][2]["document"]["id"].get()); // pick location close to none of the spots results = coll1->search("*", @@ -244,7 +244,7 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ(1, results["found"].get()); - ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); collectionManager.drop_collection("coll1"); } @@ -355,9 +355,9 @@ TEST_F(GeoFilteringTest, GeoPolygonFiltering) { ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); - ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_EQ("8", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("4", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][2]["document"]["id"].get()); // should work even if points of polygon are clockwise @@ -389,6 +389,8 @@ TEST_F(GeoFilteringTest, GeoPolygonFilteringSouthAmerica) { std::vector> records = { {"North of Equator", "4.48615, -71.38049"}, {"South of Equator", 
"-8.48587, -71.02892"}, + {"North of Equator, outside polygon", "4.13377, -56.00459"}, + {"South of Equator, outside polygon", "-4.5041, -57.34523"}, }; for(size_t i=0; iadd(doc.dump()).ok()); } - // pick a polygon that covers both points - + // polygon only covers 2 points but all points are returned since exact filtering is not performed. auto results = coll1->search("*", {}, "loc: ([13.3163, -82.3585, " "-29.134, -82.3585, " @@ -417,9 +418,22 @@ TEST_F(GeoFilteringTest, GeoPolygonFilteringSouthAmerica) { "13.3163, -59.8528])", {}, {}, {0}, 10, 1, FREQUENCY).get(); + ASSERT_EQ(4, results["found"].get()); + ASSERT_EQ(4, results["hits"].size()); + + results = coll1->search("*", + {}, "loc: ([13.3163, -82.3585, " + "-29.134, -82.3585, " + "-29.134, -59.8528, " + "13.3163, -59.8528], exact_filter_radius: 2703km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + ASSERT_EQ(2, results["found"].get()); ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + collectionManager.drop_collection("coll1"); } From e6f8a52e5f87cc8e0e0edce5949fb90766cc5ce8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:02:47 +0530 Subject: [PATCH 61/64] Add test cases. 
--- src/filter.cpp | 2 +- test/geo_filtering_test.cpp | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index 5946e5d1..f14edb60 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -562,7 +562,7 @@ Option toFilter(const std::string expression, const std::string& format_err_msg = "Value of filter field `" + _field.name + "`: must be in the " "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format."; + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format."; auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, filter_exp); return parse_op; diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index ab200262..f8c187f9 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -126,7 +126,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_FALSE(gop.ok()); ASSERT_EQ("Value of filter field `loc`: must be in the " "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); // when geo query does not send radius key gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", @@ -135,7 +135,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_FALSE(gop.ok()); ASSERT_EQ("Value of filter field `loc`: must be in the " "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); collectionManager.drop_collection("coll1"); } @@ -371,6 +371,21 @@ TEST_F(GeoFilteringTest, GeoPolygonFiltering) { ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); + // when geo query had NaN + auto gop = coll1->search("*", {}, 
"loc: ([48.87756059389807, 2.3443610121873206, NaN, nan])", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); + + gop = coll1->search("*", {}, "loc: ([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7k)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Unit must be either `km` or `mi`.", gop.error()); + collectionManager.drop_collection("coll1"); } From 022d010efc21d13d0705fc8653bc03b068c9bd11 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:07:24 +0530 Subject: [PATCH 62/64] Add comment. --- include/filter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/filter.h b/include/filter.h index f9f15c1d..f52b086f 100644 --- a/include/filter.h +++ b/include/filter.h @@ -34,7 +34,7 @@ struct filter { /// Radius threshold beyond which exact filtering on geo_result_ids will not be done. static constexpr const char* EXACT_GEO_FILTER_RADIUS_KEY = "exact_filter_radius"; - static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000; + static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000; // meters static const std::string RANGE_OPERATOR() { return ".."; From 43047f58f60d43ccedfbbbf43f6f0cf4fd5fbab2 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:14:59 +0530 Subject: [PATCH 63/64] Allow even a single value with square bracket notation. 
--- src/filter.cpp | 2 +- test/geo_filtering_test.cpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index f14edb60..c152d77f 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -182,7 +182,7 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } filter_exp.comparators.push_back(LESS_THAN_EQUALS); - bool is_multivalued = open_parenthesis_count > 1; + bool is_multivalued = raw_value[0] == '['; size_t i = is_multivalued; for (auto j = 0; j < open_parenthesis_count; j++) { diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index f8c187f9..0cd45d69 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -87,6 +87,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + // Multiple queries can be clubbed using square brackets [ filterA, filterB, ... ] results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", {}, {}, {0}, 10, 1, FREQUENCY).get(); @@ -94,7 +95,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { // pick location close to none of the spots results = coll1->search("*", - {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", + {}, "loc: [([48.910544830985785, 2.337218333651177], radius: 2 km)]", {}, {}, {0}, 10, 1, FREQUENCY).get(); ASSERT_EQ(0, results["found"].get()); From aa05071f5c2781c49c532cb997b1cf0b38644bc8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:40:26 +0530 Subject: [PATCH 64/64] Add `geo_filtering_old_test.cpp`. 
--- test/collection_filtering_test.cpp | 508 --------------------------- test/geo_filtering_old_test.cpp | 544 +++++++++++++++++++++++++++++ test/geo_filtering_test.cpp | 90 +++++ 3 files changed, 634 insertions(+), 508 deletions(-) create mode 100644 test/geo_filtering_old_test.cpp diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index c644d547..988b035a 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1049,514 +1049,6 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) { collectionManager.drop_collection("coll_array_fields"); } -TEST_F(CollectionFilteringTest, GeoPointFiltering) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - } - - std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - 
ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - - - results = coll1->search("*", {}, "loc: (48.90615, 2.34358, 1 km) || " - "loc: (48.8462, 2.34515, 1 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - - // pick location close to none of the spots - results = coll1->search("*", - {}, "loc: (48.910544830985785, 2.337218333651177, 2 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(0, results["found"].get()); - - // pick a large radius covering all points - - results = coll1->search("*", - {}, "loc: (48.910544830985785, 2.337218333651177, 20 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(10, results["found"].get()); - - // 1 mile radius - - results = coll1->search("*", - {}, "loc: (48.85825332869331, 2.303816427653377, 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - - ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); - - // when geo query had NaN - auto gop = coll1->search("*", {}, "loc: (NaN, nan, 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY); - - ASSERT_FALSE(gop.ok()); - ASSERT_EQ("Value of filter field `loc`: must be in the `(-44.50, 170.29, 0.75 km)` or " - "(56.33, -65.97, 23.82, -127.82) format.", gop.error()); - - // when geo field is formatted as string, show meaningful error - nlohmann::json bad_doc; - bad_doc["id"] = "1000"; - bad_doc["title"] = "Test record"; - bad_doc["loc"] = {"48.91", "2.33"}; - bad_doc["points"] = 1000; - - auto add_op = 
coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - bad_doc["loc"] = "foobar"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); - - bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2k)"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); - - bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2)"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); - - bad_doc["loc"] = {"foo", "bar"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - bad_doc["loc"] = {"2.33", "bar"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - bad_doc["loc"] = {"foo", "2.33"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - // under coercion mode, it should work - bad_doc["loc"] = {"48.91", "2.33"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_TRUE(add_op.ok()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT_ARRAY, false), - field("points", 
field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - } - - std::vector>> records = { - { {"Alpha Inc", "Ennore", "13.22112, 80.30511"}, - {"Alpha Inc", "Velachery", "12.98973, 80.23095"} - }, - - { - {"Veera Inc", "Thiruvallur", "13.12752, 79.90136"}, - }, - - { - {"B1 Inc", "Bengaluru", "12.98246, 77.5847"}, - {"B1 Inc", "Hosur", "12.74147, 77.82915"}, - {"B1 Inc", "Vellore", "12.91866, 79.13075"}, - }, - - { - {"M Inc", "Nashik", "20.11282, 73.79458"}, - {"M Inc", "Pune", "18.56309, 73.855"}, - } - }; - - for(size_t i=0; i> lat_lngs; - for(size_t k = 0; k < records[i].size(); k++) { - std::vector lat_lng_str; - StringUtils::split(records[i][k][2], lat_lng_str, ", "); - - std::vector lat_lng = { - std::stod(lat_lng_str[0]), - std::stod(lat_lng_str[1]) - }; - - lat_lngs.push_back(lat_lng); - } - - doc["loc"] = lat_lngs; - auto add_op = coll1->add(doc.dump()); - ASSERT_TRUE(add_op.ok()); - } - - // pick a location close to Chennai - auto results = coll1->search("*", - {}, "loc: (13.12631, 80.20252, 100km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - ASSERT_EQ(2, results["hits"].size()); - - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); - - // pick location close to none of the spots - results = coll1->search("*", - {}, "loc: (13.62601, 79.39559, 10 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(0, results["found"].get()); - - // pick a large radius covering all points - - results = coll1->search("*", - {}, "loc: (21.20714729927276, 78.99153966917213, 1000 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(4, results["found"].get()); - - // 1 mile radius - - results = coll1->search("*", - {}, "loc: (12.98941, 80.23073, 1mi)", - {}, {}, {0}, 10, 1, 
FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - - ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); - - // when geo field is formatted badly, show meaningful error - nlohmann::json bad_doc; - bad_doc["id"] = "1000"; - bad_doc["title"] = "Test record"; - bad_doc["loc"] = {"48.91", "2.33"}; - bad_doc["points"] = 1000; - - auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must contain 2 element arrays: [ [lat, lng],... ].", add_op.error()); - - bad_doc["loc"] = "foobar"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array.", add_op.error()); - - bad_doc["loc"] = nlohmann::json::array(); - nlohmann::json points = nlohmann::json::array(); - points.push_back("foo"); - points.push_back("bar"); - bad_doc["loc"].push_back(points); - - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); - - bad_doc["loc"][0][0] = "2.33"; - bad_doc["loc"][0][1] = "bar"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); - - bad_doc["loc"][0][0] = "foo"; - bad_doc["loc"][0][1] = "2.33"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); - - // under coercion mode, it should work - bad_doc["loc"][0][0] = "48.91"; - bad_doc["loc"][0][1] = "2.33"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_TRUE(add_op.ok()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPointRemoval) { - std::vector fields = {field("title", 
field_types::STRING, false), - field("loc1", field_types::GEOPOINT, false), - field("loc2", field_types::GEOPOINT_ARRAY, false), - field("points", field_types::INT32, false),}; - - Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - - nlohmann::json doc; - doc["id"] = "0"; - doc["title"] = "Palais Garnier"; - doc["loc1"] = {48.872576479306765, 2.332291112241466}; - doc["loc2"] = nlohmann::json::array(); - doc["loc2"][0] = {48.84620987789056, 2.345152755563131}; - doc["points"] = 100; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - - auto results = coll1->search("*", - {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - results = coll1->search("*", - {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - // remove the document, index another document and try querying again - coll1->remove("0"); - doc["id"] = "1"; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - - results = coll1->search("*", - {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - results = coll1->search("*", - {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); -} - -TEST_F(CollectionFilteringTest, GeoPolygonFiltering) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, 
fields, "points").get(); - } - - std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: (48.875223042424125,2.323509661928681, " - "48.85745408145392, 2.3267084486160856, " - "48.859636574404355,2.351469427048221, " - "48.87756059389807, 2.3443610121873206)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - ASSERT_EQ(3, results["hits"].size()); - - ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); - - // should work even if points of polygon are clockwise - - results = coll1->search("*", - {}, "loc: (48.87756059389807, 2.3443610121873206, " - "48.859636574404355,2.351469427048221, " - "48.85745408145392, 2.3267084486160856, " - "48.875223042424125,2.323509661928681)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - 
ASSERT_EQ(3, results["hits"].size()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPolygonFilteringSouthAmerica) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - } - - std::vector> records = { - {"North of Equator", "4.48615, -71.38049"}, - {"South of Equator", "-8.48587, -71.02892"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a polygon that covers both points - - auto results = coll1->search("*", - {}, "loc: (13.3163, -82.3585, " - "-29.134, -82.3585, " - "-29.134, -59.8528, " - "13.3163, -59.8528)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - ASSERT_EQ(2, results["hits"].size()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPointFilteringWithNonSortableLocationField) { - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - nlohmann::json schema = R"({ - "name": "coll1", - "fields": [ - {"name": "title", "type": "string", "sort": false}, - {"name": "loc", "type": "geopoint", "sort": false}, - {"name": "points", "type": "int32", "sort": false} - ] - })"_json; - - auto coll_op = collectionManager.create_collection(schema); - ASSERT_TRUE(coll_op.ok()); - Collection* coll1 = coll_op.get(); - - std::vector> records = { - {"Palais 
Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); -} - TEST_F(CollectionFilteringTest, FilteringWithPrefixSearch) { Collection *coll1; diff --git a/test/geo_filtering_old_test.cpp b/test/geo_filtering_old_test.cpp new file mode 100644 index 00000000..f5f37684 --- /dev/null +++ b/test/geo_filtering_old_test.cpp @@ -0,0 +1,544 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class GeoFilteringOldTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_filtering"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(GeoFilteringOldTest, GeoPointFiltering) { + Collection *coll1; + + std::vector fields = 
{field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + + results = coll1->search("*", {}, "loc: (48.90615, 2.34358, 1 km) || " + "loc: (48.8462, 2.34515, 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: (48.910544830985785, 2.337218333651177, 2 km)", + {}, {}, {0}, 10, 1, 
FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: (48.910544830985785, 2.337218333651177, 20 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(10, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: (48.85825332869331, 2.303816427653377, 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + + // when geo query had NaN + auto gop = coll1->search("*", {}, "loc: (NaN, nan, 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the `(-44.50, 170.29, 0.75 km)` or " + "(56.33, -65.97, 23.82, -127.82) format.", gop.error()); + + // when geo field is formatted as string, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2k)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", 
DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = {"foo", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = {"2.33", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = {"foo", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"] = {"48.91", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPointArrayFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector>> records = { + { {"Alpha Inc", "Ennore", "13.22112, 80.30511"}, + {"Alpha Inc", "Velachery", "12.98973, 80.23095"} + }, + + { + {"Veera Inc", "Thiruvallur", "13.12752, 79.90136"}, + }, + + { + {"B1 Inc", "Bengaluru", "12.98246, 77.5847"}, + {"B1 Inc", "Hosur", "12.74147, 77.82915"}, + {"B1 Inc", "Vellore", "12.91866, 79.13075"}, + }, + + { + {"M Inc", "Nashik", "20.11282, 73.79458"}, + {"M Inc", "Pune", "18.56309, 73.855"}, + } + }; + + for(size_t i=0; i> lat_lngs; + for(size_t k = 0; k < records[i].size(); k++) { + std::vector 
lat_lng_str; + StringUtils::split(records[i][k][2], lat_lng_str, ", "); + + std::vector lat_lng = { + std::stod(lat_lng_str[0]), + std::stod(lat_lng_str[1]) + }; + + lat_lngs.push_back(lat_lng); + } + + doc["loc"] = lat_lngs; + auto add_op = coll1->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + } + + // pick a location close to Chennai + auto results = coll1->search("*", + {}, "loc: (13.12631, 80.20252, 100km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: (13.62601, 79.39559, 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: (21.20714729927276, 78.99153966917213, 1000 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(4, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: (12.98941, 80.23073, 1mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + // when geo field is formatted badly, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must contain 2 element arrays: [ [lat, lng],... 
].", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array.", add_op.error()); + + bad_doc["loc"] = nlohmann::json::array(); + nlohmann::json points = nlohmann::json::array(); + points.push_back("foo"); + points.push_back("bar"); + bad_doc["loc"].push_back(points); + + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "2.33"; + bad_doc["loc"][0][1] = "bar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "foo"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"][0][0] = "48.91"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPointRemoval) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc1", field_types::GEOPOINT, false), + field("loc2", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "Palais Garnier"; + doc["loc1"] = {48.872576479306765, 2.332291112241466}; + doc["loc2"] = nlohmann::json::array(); + doc["loc2"][0] = {48.84620987789056, 2.345152755563131}; + doc["points"] = 100; + + 
ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("*", + {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + // remove the document, index another document and try querying again + coll1->remove("0"); + doc["id"] = "1"; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + results = coll1->search("*", + {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} + +TEST_F(GeoFilteringOldTest, GeoPolygonFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 
2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: (48.875223042424125,2.323509661928681, " + "48.85745408145392, 2.3267084486160856, " + "48.859636574404355,2.351469427048221, " + "48.87756059389807, 2.3443610121873206)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + + // should work even if points of polygon are clockwise + + results = coll1->search("*", + {}, "loc: (48.87756059389807, 2.3443610121873206, " + "48.859636574404355,2.351469427048221, " + "48.85745408145392, 2.3267084486160856, " + "48.875223042424125,2.323509661928681)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPolygonFilteringSouthAmerica) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"North of 
Equator", "4.48615, -71.38049"}, + {"South of Equator", "-8.48587, -71.02892"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a polygon that covers both points + + auto results = coll1->search("*", + {}, "loc: (13.3163, -82.3585, " + "-29.134, -82.3585, " + "-29.134, -59.8528, " + "13.3163, -59.8528)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPointFilteringWithNonSortableLocationField) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string", "sort": false}, + {"name": "loc", "type": "geopoint", "sort": false}, + {"name": "points", "type": "int32", "sort": false} + ] + })"_json; + + auto coll_op = collectionManager.create_collection(schema); + ASSERT_TRUE(coll_op.ok()); + Collection* coll1 = coll_op.get(); + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto 
results = coll1->search("*", + {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} \ No newline at end of file diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index 0cd45d69..526d44ae 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -138,6 +138,52 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); + // when geo field is formatted as string, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2k)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = {"foo", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = 
{"2.33", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = {"foo", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"] = {"48.91", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + collectionManager.drop_collection("coll1"); } @@ -247,6 +293,50 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + // when geo field is formatted badly, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must contain 2 element arrays: [ [lat, lng],... 
].", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array.", add_op.error()); + + bad_doc["loc"] = nlohmann::json::array(); + nlohmann::json points = nlohmann::json::array(); + points.push_back("foo"); + points.push_back("bar"); + bad_doc["loc"].push_back(points); + + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "2.33"; + bad_doc["loc"][0][1] = "bar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "foo"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"][0][0] = "48.91"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + collectionManager.drop_collection("coll1"); }