From f32accdd87b85614de366686e2b723d3ef153427 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Sat, 18 Mar 2023 17:25:25 +0530 Subject: [PATCH 01/93] Add `filter_result_iterator_t`. --- include/filter.h | 81 ++++++++ include/index.h | 2 + include/option.h | 12 ++ include/posting_list.h | 2 + src/filter.cpp | 430 +++++++++++++++++++++++++++++++++++++++++ src/posting_list.cpp | 36 ++++ test/filter_test.cpp | 182 +++++++++++++++++ 7 files changed, 745 insertions(+) create mode 100644 include/filter.h create mode 100644 src/filter.cpp create mode 100644 test/filter_test.cpp diff --git a/include/filter.h b/include/filter.h new file mode 100644 index 00000000..239a7e77 --- /dev/null +++ b/include/filter.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include "index.h" + +class filter_result_iterator_t { +private: + std::string collection_name; + const Index* index; + filter_node_t* filter_node; + filter_result_iterator_t* left_it = nullptr; + filter_result_iterator_t* right_it = nullptr; + + // Used in case of id and reference filter. + uint32_t result_index = 0; + + // Stores the result of the filters that cannot be iterated. + filter_result_t filter_result; + + // Initialized in case of filter on string field. + // Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + // for each token. + // + // Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_list_iterators; + std::vector expanded_plists; + + // Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initializes the state of iterator node after it's creation. + void init(); + + /// Performs AND on the subtrees of operator. + void and_filter_iterators(); + + /// Performs OR on the subtrees of operator. + void or_filter_iterators(); + + /// Finds the next match for a filter on string field. 
+ void doc_matching_string_filter(); + +public: + uint32_t doc; + // Collection name -> references + std::map reference; + Option status; + + explicit filter_result_iterator_t(const std::string& collection_name, + const Index* index, filter_node_t* filter_node, + Option& status) : + collection_name(collection_name), + index(index), + filter_node(filter_node), + status(status) { + // Generate the iterator tree and then initialize each node. + if (filter_node->isOperator) { + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, status); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, status); + } + + init(); + } + + ~filter_result_iterator_t() { + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + } + + [[nodiscard]] bool valid(); + + void next(); + + void skip_to(uint32_t id); +}; diff --git a/include/index.h b/include/index.h index ae14f06c..63681b9e 100644 --- a/include/index.h +++ b/include/index.h @@ -957,6 +957,8 @@ public: Option seq_ids_outside_top_k(const std::string& field_name, size_t k, std::vector& outside_seq_ids); + + friend class filter_result_iterator_t; }; template diff --git a/include/option.h b/include/option.h index 0f8c49e0..ced54ae9 100644 --- a/include/option.h +++ b/include/option.h @@ -31,6 +31,18 @@ public: error_code = obj.error_code; } + Option& operator=(Option&& obj) noexcept { + if (&obj == this) + return *this; + + value = obj.value; + is_ok = obj.is_ok; + error_msg = obj.error_msg; + error_code = obj.error_code; + + return *this; + } + bool ok() const { return is_ok; } diff --git a/include/posting_list.h b/include/posting_list.h index dea742f7..16f42ea7 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -164,6 +164,8 @@ public: static void intersect(const std::vector& posting_lists, std::vector& result_ids); + static void intersect(std::vector& 
posting_list_iterators, bool& is_valid); + template static bool block_intersect( std::vector& its, diff --git a/src/filter.cpp b/src/filter.cpp new file mode 100644 index 00000000..804e19a4 --- /dev/null +++ b/src/filter.cpp @@ -0,0 +1,430 @@ +#include +#include +#include +#include "filter.h" + +void filter_result_iterator_t::and_filter_iterators() { + while (left_it->valid() && right_it->valid()) { + while (left_it->doc < right_it->doc) { + left_it->next(); + if (!left_it->valid()) { + is_valid = false; + return; + } + } + + while (left_it->doc > right_it->doc) { + right_it->next(); + if (!right_it->valid()) { + is_valid = false; + return; + } + } + + if (left_it->doc == right_it->doc) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + } + + is_valid = false; +} + +void filter_result_iterator_t::or_filter_iterators() { + if (left_it->valid() && right_it->valid()) { + if (left_it->doc < right_it->doc) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->doc > right_it->doc) { + doc = right_it->doc; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->valid()) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (right_it->valid()) { + doc = right_it->doc; + reference.clear(); + + for (const auto& item: right_it->reference) 
{ + reference[item.first] = item.second; + } + + return; + } + + is_valid = false; +} + +void filter_result_iterator_t::doc_matching_string_filter() { + // If none of the filter value iterators are valid, mark this node as invalid. + bool one_is_valid = false; + + // Since we do OR between filter values, the lowest doc id from all is selected. + uint32_t lowest_id = UINT32_MAX; + + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + + if (one_is_valid) { + doc = lowest_id; + } + + is_valid = one_is_valid; +} + +void filter_result_iterator_t::next() { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Advance all the filter values that are at doc. Then find the next one. 
+ std::vector doc_matching_indexes; + for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + const auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { + doc_matching_indexes.push_back(i); + } + } + + for (const auto &lowest_id_index: doc_matching_indexes) { + for (auto &iter: posting_list_iterators[lowest_id_index]) { + iter.next(); + } + } + + doc_matching_string_filter(); + return; + } +} + +void filter_result_iterator_t::init() { + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(a_filter.referenced_collection_name); + if (collection == nullptr) { + status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + is_valid = false; + return; + } + + auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, + filter_result, + collection_name); + if (!reference_filter_op.ok()) { + status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name + + "` collection: " + reference_filter_op.error()); + is_valid = false; + return; + } + + is_valid = filter_result.count > 0; + return; + } + + if (a_filter.field_name == "id") { + if (a_filter.values.empty()) { + is_valid = false; + return; + } + + // we handle `ids` separately + std::vector result_ids; + for (const auto& id_str : a_filter.values) { + result_ids.push_back(std::stoul(id_str)); + } + + std::sort(result_ids.begin(), result_ids.end()); + + 
filter_result.count = result_ids.size(); + filter_result.docs = new uint32_t[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + art_tree* t = index->search_index.at(a_filter.field_name); + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + std::vector plists; + posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + + posting_list_iterators.emplace_back(std::vector()); + + for (auto const& plist: plists) { + posting_list_iterators.back().push_back(plist->new_iterator()); + } + } + + doc_matching_string_filter(); + return; + } +} + +bool filter_result_iterator_t::valid() { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + is_valid = left_it->valid() && right_it->valid(); + return is_valid; + } else { + is_valid = left_it->valid() || right_it->valid(); + return is_valid; + } + } + + const filter a_filter = filter_node->filter_exp; + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + is_valid = result_index < filter_result.count; + 
return is_valid; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return is_valid; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + bool one_is_valid = false; + for (auto& filter_value_tokens: posting_list_iterators) { + posting_list_t::intersect(filter_value_tokens, one_is_valid); + + if (one_is_valid) { + break; + } + } + + is_valid = one_is_valid; + return is_valid; + } + + return true; +} + +void filter_result_iterator_t::skip_to(uint32_t id) { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + if (filter_node->filter_operator == AND) { + left_it->skip_to(id); + and_filter_iterators(); + } else { + right_it->skip_to(id); + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Skip all the token iterators and find a new match. 
+ for (auto& filter_value_tokens : posting_list_iterators) { + for (auto& token: filter_value_tokens) { + // We perform AND on tokens. Short-circuiting here. + if (!token.valid()) { + break; + } + + token.skip_to(id); + } + } + + doc_matching_string_filter(); + return; + } +} diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 23f17e8b..39f3ac00 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -754,6 +754,42 @@ void posting_list_t::intersect(const std::vector& posting_lists } } +void posting_list_t::intersect(std::vector& posting_list_iterators, bool& is_valid) { + if (posting_list_iterators.empty()) { + is_valid = false; + return; + } + + if (posting_list_iterators.size() == 1) { + is_valid = posting_list_iterators.front().valid(); + return; + } + + switch (posting_list_iterators.size()) { + case 2: + while(!at_end2(posting_list_iterators)) { + if(equals2(posting_list_iterators)) { + is_valid = true; + return; + } else { + advance_non_largest2(posting_list_iterators); + } + } + is_valid = false; + break; + default: + while(!at_end(posting_list_iterators)) { + if(equals(posting_list_iterators)) { + is_valid = true; + return; + } else { + advance_non_largest(posting_list_iterators); + } + } + is_valid = false; + } +} + bool posting_list_t::take_id(result_iter_state_t& istate, uint32_t id) { // decide if this result id should be excluded if(istate.excluded_result_ids_size != 0) { diff --git a/test/filter_test.cpp b/test/filter_test.cpp new file mode 100644 index 00000000..60cff669 --- /dev/null +++ b/test/filter_test.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class FilterTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = 
"/tmp/typesense_test/collection_join"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(FilterTest, FilterTreeIterator) { + nlohmann::json schema = + R"({ + "name": "Collection", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int32"}, + {"name": "years", "type": "int32[]"}, + {"name": "rating", "type": "float"}, + {"name": "tags", "type": "string[]"} + ] + })"_json; + + Collection* coll = collectionManager.create_collection(schema).get(); + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::string json_line; + while (std::getline(infile, json_line)) { + auto add_op = coll->add(json_line); + ASSERT_TRUE(add_op.ok()); + } + infile.close(); + + const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_"; + filter_node_t* filter_tree_root = nullptr; + Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + Option iter_op(true); + auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_FALSE(iter_no_match_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: [foo bar, baz]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + 
ASSERT_FALSE(iter_no_match_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: Jeremy", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_contains_test.valid()); + ASSERT_EQ(i, iter_contains_test.doc); + iter_contains_test.next(); + } + ASSERT_FALSE(iter_contains_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: [Jeremy, Howard, Richard]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_contains_multi_test.valid()); + ASSERT_EQ(i, iter_contains_multi_test.doc); + iter_contains_multi_test.next(); + } + ASSERT_FALSE(iter_contains_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name:= Jeremy Howard", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_exact_match_test.valid()); + ASSERT_EQ(i, iter_exact_match_test.doc); + iter_exact_match_test.next(); + } + ASSERT_FALSE(iter_exact_match_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= [gold, silver]", coll->get_schema(), store, doc_id_prefix, + 
filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + std::vector expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_exact_match_multi_test.valid()); + ASSERT_EQ(i, iter_exact_match_multi_test.doc); + iter_exact_match_multi_test.next(); + } + ASSERT_FALSE(iter_exact_match_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + +// delete filter_tree_root; +// filter_tree_root = nullptr; +// filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, +// filter_tree_root); +// ASSERT_TRUE(filter_op.ok()); +// +// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); +// +// std::vector expected = {1, 3}; +// for (auto const& i : expected) { +// ASSERT_TRUE(iter_not_equals_test.valid()); +// ASSERT_EQ(i, iter_not_equals_test.doc); +// iter_not_equals_test.next(); +// } +// +// ASSERT_FALSE(iter_not_equals_test.valid()); +// ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_skip_test.valid()); + iter_skip_test.skip_to(3); + ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_EQ(4, iter_skip_test.doc); + iter_skip_test.next(); + + ASSERT_FALSE(iter_skip_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; +} \ No newline at end of file From 647942b44d33a0092c9f3ddb59d8673ad181ac74 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 10:33:01 +0530 Subject: [PATCH 02/93] Add `filter_result_iterator_t::valid(uint32_t id)`. 
--- include/filter.h | 23 +++++++++++++++-------- src/filter.cpp | 26 ++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/include/filter.h b/include/filter.h index 239a7e77..d11a8f9d 100644 --- a/include/filter.h +++ b/include/filter.h @@ -12,21 +12,21 @@ private: filter_result_iterator_t* left_it = nullptr; filter_result_iterator_t* right_it = nullptr; - // Used in case of id and reference filter. + /// Used in case of id and reference filter. uint32_t result_index = 0; - // Stores the result of the filters that cannot be iterated. + /// Stores the result of the filters that cannot be iterated. filter_result_t filter_result; - // Initialized in case of filter on string field. - // Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator - // for each token. - // - // Multiple filter values: Multiple tokens: posting list iterator + /// Initialized in case of filter on string field. + /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + /// for each token. + /// + /// Multiple filter values: Multiple tokens: posting list iterator std::vector> posting_list_iterators; std::vector expanded_plists; - // Set to false when this iterator or it's subtree becomes invalid. + /// Set to false when this iterator or it's subtree becomes invalid. bool is_valid = true; /// Initializes the state of iterator node after it's creation. @@ -73,9 +73,16 @@ public: delete right_it; } + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). [[nodiscard]] bool valid(); + /// Returns true when id is a match to the filter. Handles moving the individual iterators internally. + [[nodiscard]] bool valid(uint32_t id); + + /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this + /// operation. 
void next(); + /// Advances the iterator until the doc value is less than id. The iterator may become invalid during this operation. void skip_to(uint32_t id); }; diff --git a/src/filter.cpp b/src/filter.cpp index 804e19a4..a8fa1eeb 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -361,11 +361,12 @@ void filter_result_iterator_t::skip_to(uint32_t id) { if (filter_node->isOperator) { // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + left_it->skip_to(id); + right_it->skip_to(id); + if (filter_node->filter_operator == AND) { - left_it->skip_to(id); and_filter_iterators(); } else { - right_it->skip_to(id); or_filter_iterators(); } @@ -428,3 +429,24 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } } + +bool filter_result_iterator_t::valid(uint32_t id) { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + auto and_is_valid = left_it->valid(id) && right_it->valid(id); + is_valid = left_it->is_valid && right_it->is_valid; + return and_is_valid; + } else { + auto or_is_valid = left_it->valid(id) || right_it->valid(id); + is_valid = left_it->is_valid || right_it->is_valid; + return or_is_valid; + } + } + + skip_to(id); + return is_valid && doc == id; +} From 1b88b9afe318985cc68502cc66ccf29f648014fd Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 12:56:53 +0530 Subject: [PATCH 03/93] Add test cases for `AND` and `OR`. --- include/filter.h | 3 ++- src/filter.cpp | 24 ++++++++++++++++++------ test/filter_test.cpp | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/include/filter.h b/include/filter.h index d11a8f9d..643c5006 100644 --- a/include/filter.h +++ b/include/filter.h @@ -83,6 +83,7 @@ public: /// operation. void next(); - /// Advances the iterator until the doc value is less than id. The iterator may become invalid during this operation. 
+ /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during + /// this operation. void skip_to(uint32_t id); }; diff --git a/src/filter.cpp b/src/filter.cpp index a8fa1eeb..b42c3e5d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -4,10 +4,10 @@ #include "filter.h" void filter_result_iterator_t::and_filter_iterators() { - while (left_it->valid() && right_it->valid()) { + while (left_it->is_valid && right_it->is_valid) { while (left_it->doc < right_it->doc) { left_it->next(); - if (!left_it->valid()) { + if (!left_it->is_valid) { is_valid = false; return; } @@ -15,7 +15,7 @@ void filter_result_iterator_t::and_filter_iterators() { while (left_it->doc > right_it->doc) { right_it->next(); - if (!right_it->valid()) { + if (!right_it->is_valid) { is_valid = false; return; } @@ -40,7 +40,7 @@ void filter_result_iterator_t::and_filter_iterators() { } void filter_result_iterator_t::or_filter_iterators() { - if (left_it->valid() && right_it->valid()) { + if (left_it->is_valid && right_it->is_valid) { if (left_it->doc < right_it->doc) { doc = left_it->doc; reference.clear(); @@ -76,7 +76,7 @@ void filter_result_iterator_t::or_filter_iterators() { return; } - if (left_it->valid()) { + if (left_it->is_valid) { doc = left_it->doc; reference.clear(); @@ -87,7 +87,7 @@ void filter_result_iterator_t::or_filter_iterators() { return; } - if (right_it->valid()) { + if (right_it->is_valid) { doc = right_it->doc; reference.clear(); @@ -133,9 +133,21 @@ void filter_result_iterator_t::next() { } if (filter_node->isOperator) { + // Advance the subtrees and then apply operators to arrive at the next valid doc. 
if (filter_node->filter_operator == AND) { + left_it->next(); + right_it->next(); and_filter_iterators(); } else { + if (left_it->doc == doc && right_it->doc == doc) { + left_it->next(); + right_it->next(); + } else if (left_it->doc == doc) { + left_it->next(); + } else { + right_it->next(); + } + or_filter_iterators(); } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 60cff669..99edf27f 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -178,5 +178,49 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_skip_test.valid()); ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: jeremy && tags: fine platinum", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_and_test.valid()); + ASSERT_EQ(1, iter_and_test.doc); + iter_and_test.next(); + + ASSERT_FALSE(iter_and_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || tags: bronze", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto doc = + R"({ + "name": "James Rowdy", + "age": 36, + "years": [2005, 2022], + "rating": 6.03, + "tags": ["copper"] + })"_json; + auto add_op = coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + + auto iter_or_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + expected = {2, 4, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_or_test.valid()); + ASSERT_EQ(i, iter_or_test.doc); + iter_or_test.next(); + } + + ASSERT_FALSE(iter_or_test.valid()); + ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; } \ No newline at end of file From c29268c52a4867b645f61269ccc8c72f5b7f0cb7 Mon Sep 17 
00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 13:22:04 +0530 Subject: [PATCH 04/93] Add test case for complex filter. --- test/filter_test.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 99edf27f..ee8b7b00 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -222,5 +222,26 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_or_test.valid()); ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || (tags: gold && tags: silver)", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + iter_skip_complex_filter_test.skip_to(4); + + expected = {4, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + ASSERT_EQ(i, iter_skip_complex_filter_test.doc); + iter_skip_complex_filter_test.next(); + } + + ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; } \ No newline at end of file From 198488d42820207561dff973e9b3504c232817a5 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 15:50:14 +0530 Subject: [PATCH 05/93] Add `filter_result_iterator_t::valid(uint32_t id)` test case. 
--- test/filter_test.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/filter_test.cpp b/test/filter_test.cpp index ee8b7b00..96121e92 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -243,5 +243,21 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_skip_complex_filter_test.valid()); ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || (tags: gold && tags: [silver, bronze])", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + expected = {0, 2, 4, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_validate_ids_test.valid(i)); + } + + ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_op.ok()); + delete filter_tree_root; } \ No newline at end of file From 709e65ccbd750f9706c03e3b8f3aca556fd56178 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 21 Mar 2023 18:01:26 +0530 Subject: [PATCH 06/93] Handle `apply_not_equals`. --- src/filter.cpp | 17 +++++++++++++++++ test/filter_test.cpp | 17 ++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/filter.cpp b/src/filter.cpp index b42c3e5d..e7a46093 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -459,6 +459,23 @@ bool filter_result_iterator_t::valid(uint32_t id) { } } + if (filter_node->filter_exp.apply_not_equals) { + // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. 
+ if (!valid()) { + is_valid = true; + return is_valid; + } + + skip_to(id); + + if (!is_valid) { + is_valid = true; + return is_valid; + } + + return doc != id; + } + skip_to(id); return is_valid && doc == id; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 96121e92..3c4b94a5 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -256,7 +256,22 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_validate_ids_test.valid(i)); } - ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: James || tags: != gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_not_equals_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root, iter_op); + + expected = {1, 3, 5}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.valid(i)); + } + ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; From fdbf6f01b06d8d7703a0ab78cee48e2a07aed81f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 22 Mar 2023 10:46:01 +0530 Subject: [PATCH 07/93] Refactor `filter_result_iterator_t::next()`. --- src/filter.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index e7a46093..2c289761 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -190,19 +190,14 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { - // Advance all the filter values that are at doc. Then find the next one. - std::vector doc_matching_indexes; + // Advance all the filter values that are at doc. Then find the next doc. 
for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { - const auto& filter_value_tokens = posting_list_iterators[i]; + auto& filter_value_tokens = posting_list_iterators[i]; if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { - doc_matching_indexes.push_back(i); - } - } - - for (const auto &lowest_id_index: doc_matching_indexes) { - for (auto &iter: posting_list_iterators[lowest_id_index]) { - iter.next(); + for (auto& iter: filter_value_tokens) { + iter.next(); + } } } @@ -363,7 +358,7 @@ bool filter_result_iterator_t::valid() { return is_valid; } - return true; + return false; } void filter_result_iterator_t::skip_to(uint32_t id) { From f3ddbd44aa33c85d3bf688bb373700df8a616191 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 22 Mar 2023 13:23:40 +0530 Subject: [PATCH 08/93] Add `filter_result_iterator_t::init_status()`. --- include/filter.h | 17 +++++----- src/filter.cpp | 54 ++++++++++++++++++------------- test/filter_test.cpp | 76 ++++++++++++++++++++++---------------------- 3 files changed, 79 insertions(+), 68 deletions(-) diff --git a/include/filter.h b/include/filter.h index 643c5006..c5a20f11 100644 --- a/include/filter.h +++ b/include/filter.h @@ -42,22 +42,20 @@ private: void doc_matching_string_filter(); public: - uint32_t doc; + uint32_t seq_id = 0; // Collection name -> references std::map reference; - Option status; + Option status = Option(true); explicit filter_result_iterator_t(const std::string& collection_name, - const Index* index, filter_node_t* filter_node, - Option& status) : + const Index* index, filter_node_t* filter_node) : collection_name(collection_name), index(index), - filter_node(filter_node), - status(status) { + filter_node(filter_node) { // Generate the iterator tree and then initialize each node. 
if (filter_node->isOperator) { - left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, status); - right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, status); + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right); } init(); @@ -73,6 +71,9 @@ public: delete right_it; } + /// Returns the status of the initialization of iterator tree. + Option init_status(); + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). [[nodiscard]] bool valid(); diff --git a/src/filter.cpp b/src/filter.cpp index 2c289761..16e52c49 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -5,7 +5,7 @@ void filter_result_iterator_t::and_filter_iterators() { while (left_it->is_valid && right_it->is_valid) { - while (left_it->doc < right_it->doc) { + while (left_it->seq_id < right_it->seq_id) { left_it->next(); if (!left_it->is_valid) { is_valid = false; @@ -13,7 +13,7 @@ void filter_result_iterator_t::and_filter_iterators() { } } - while (left_it->doc > right_it->doc) { + while (left_it->seq_id > right_it->seq_id) { right_it->next(); if (!right_it->is_valid) { is_valid = false; @@ -21,8 +21,8 @@ void filter_result_iterator_t::and_filter_iterators() { } } - if (left_it->doc == right_it->doc) { - doc = left_it->doc; + if (left_it->seq_id == right_it->seq_id) { + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -41,8 +41,8 @@ void filter_result_iterator_t::and_filter_iterators() { void filter_result_iterator_t::or_filter_iterators() { if (left_it->is_valid && right_it->is_valid) { - if (left_it->doc < right_it->doc) { - doc = left_it->doc; + if (left_it->seq_id < right_it->seq_id) { + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -52,8 +52,8 @@ void 
filter_result_iterator_t::or_filter_iterators() { return; } - if (left_it->doc > right_it->doc) { - doc = right_it->doc; + if (left_it->seq_id > right_it->seq_id) { + seq_id = right_it->seq_id; reference.clear(); for (const auto& item: right_it->reference) { @@ -63,7 +63,7 @@ void filter_result_iterator_t::or_filter_iterators() { return; } - doc = left_it->doc; + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -77,7 +77,7 @@ void filter_result_iterator_t::or_filter_iterators() { } if (left_it->is_valid) { - doc = left_it->doc; + seq_id = left_it->seq_id; reference.clear(); for (const auto& item: left_it->reference) { @@ -88,7 +88,7 @@ void filter_result_iterator_t::or_filter_iterators() { } if (right_it->is_valid) { - doc = right_it->doc; + seq_id = right_it->seq_id; reference.clear(); for (const auto& item: right_it->reference) { @@ -105,7 +105,7 @@ void filter_result_iterator_t::doc_matching_string_filter() { // If none of the filter value iterators are valid, mark this node as invalid. bool one_is_valid = false; - // Since we do OR between filter values, the lowest doc id from all is selected. + // Since we do OR between filter values, the lowest seq_id id from all is selected. 
uint32_t lowest_id = UINT32_MAX; for (auto& filter_value_tokens : posting_list_iterators) { @@ -121,7 +121,7 @@ void filter_result_iterator_t::doc_matching_string_filter() { } if (one_is_valid) { - doc = lowest_id; + seq_id = lowest_id; } is_valid = one_is_valid; @@ -139,10 +139,10 @@ void filter_result_iterator_t::next() { right_it->next(); and_filter_iterators(); } else { - if (left_it->doc == doc && right_it->doc == doc) { + if (left_it->seq_id == seq_id && right_it->seq_id == seq_id) { left_it->next(); right_it->next(); - } else if (left_it->doc == doc) { + } else if (left_it->seq_id == seq_id) { left_it->next(); } else { right_it->next(); @@ -163,7 +163,7 @@ void filter_result_iterator_t::next() { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; reference.clear(); for (auto const& item: filter_result.reference_filter_results) { reference[item.first] = item.second[result_index]; @@ -178,7 +178,7 @@ void filter_result_iterator_t::next() { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; return; } @@ -194,7 +194,7 @@ void filter_result_iterator_t::next() { for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { auto& filter_value_tokens = posting_list_iterators[i]; - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { for (auto& iter: filter_value_tokens) { iter.next(); } @@ -391,7 +391,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; reference.clear(); for (auto const& item: filter_result.reference_filter_results) { reference[item.first] = item.second[result_index]; @@ -408,7 +408,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - doc = filter_result.docs[result_index]; + seq_id = filter_result.docs[result_index]; return; } @@ -468,9 
+468,19 @@ bool filter_result_iterator_t::valid(uint32_t id) { return is_valid; } - return doc != id; + return seq_id != id; } skip_to(id); - return is_valid && doc == id; + return is_valid && seq_id == id; +} + +Option filter_result_iterator_t::init_status() { + if (filter_node->isOperator) { + auto left_status = left_it->init_status(); + + return !left_status.ok() ? left_status : right_it->init_status(); + } + + return status; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 3c4b94a5..935efff8 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -64,11 +64,10 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - Option iter_op(true); - auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_no_match_test.init_status().ok()); ASSERT_FALSE(iter_no_match_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -76,10 +75,10 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_no_match_multi_test.init_status().ok()); ASSERT_FALSE(iter_no_match_multi_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -87,14 +86,15 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + 
ASSERT_TRUE(iter_contains_test.init_status().ok()); + for (uint32_t i = 0; i < 5; i++) { ASSERT_TRUE(iter_contains_test.valid()); - ASSERT_EQ(i, iter_contains_test.doc); + ASSERT_EQ(i, iter_contains_test.seq_id); iter_contains_test.next(); } ASSERT_FALSE(iter_contains_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -102,14 +102,15 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_contains_multi_test.init_status().ok()); + for (uint32_t i = 0; i < 5; i++) { ASSERT_TRUE(iter_contains_multi_test.valid()); - ASSERT_EQ(i, iter_contains_multi_test.doc); + ASSERT_EQ(i, iter_contains_multi_test.seq_id); iter_contains_multi_test.next(); } ASSERT_FALSE(iter_contains_multi_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -117,14 +118,15 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_test.init_status().ok()); + for (uint32_t i = 0; i < 5; i++) { ASSERT_TRUE(iter_exact_match_test.valid()); - ASSERT_EQ(i, iter_exact_match_test.doc); + ASSERT_EQ(i, iter_exact_match_test.seq_id); iter_exact_match_test.next(); } ASSERT_FALSE(iter_exact_match_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -132,16 +134,16 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_exact_match_multi_test 
= filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_multi_test.init_status().ok()); std::vector expected = {0, 2, 3, 4}; for (auto const& i : expected) { ASSERT_TRUE(iter_exact_match_multi_test.valid()); - ASSERT_EQ(i, iter_exact_match_multi_test.doc); + ASSERT_EQ(i, iter_exact_match_multi_test.seq_id); iter_exact_match_multi_test.next(); } ASSERT_FALSE(iter_exact_match_multi_test.valid()); - ASSERT_TRUE(iter_op.ok()); // delete filter_tree_root; // filter_tree_root = nullptr; @@ -149,17 +151,17 @@ TEST_F(FilterTest, FilterTreeIterator) { // filter_tree_root); // ASSERT_TRUE(filter_op.ok()); // -// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); +// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); +// ASSERT_TRUE(iter_not_equals_test.init_status().ok()); // // std::vector expected = {1, 3}; // for (auto const& i : expected) { // ASSERT_TRUE(iter_not_equals_test.valid()); -// ASSERT_EQ(i, iter_not_equals_test.doc); +// ASSERT_EQ(i, iter_not_equals_test.seq_id); // iter_not_equals_test.next(); // } // // ASSERT_FALSE(iter_not_equals_test.valid()); -// ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -167,16 +169,16 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test.init_status().ok()); ASSERT_TRUE(iter_skip_test.valid()); iter_skip_test.skip_to(3); ASSERT_TRUE(iter_skip_test.valid()); - ASSERT_EQ(4, iter_skip_test.doc); + 
ASSERT_EQ(4, iter_skip_test.seq_id); iter_skip_test.next(); ASSERT_FALSE(iter_skip_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -184,14 +186,14 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_and_test.init_status().ok()); ASSERT_TRUE(iter_and_test.valid()); - ASSERT_EQ(1, iter_and_test.doc); + ASSERT_EQ(1, iter_and_test.seq_id); iter_and_test.next(); ASSERT_FALSE(iter_and_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -210,17 +212,17 @@ TEST_F(FilterTest, FilterTreeIterator) { auto add_op = coll->add(doc.dump()); ASSERT_TRUE(add_op.ok()); - auto iter_or_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_or_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_or_test.init_status().ok()); expected = {2, 4, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_or_test.valid()); - ASSERT_EQ(i, iter_or_test.doc); + ASSERT_EQ(i, iter_or_test.seq_id); iter_or_test.next(); } ASSERT_FALSE(iter_or_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -228,7 +230,8 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_complex_filter_test.init_status().ok()); ASSERT_TRUE(iter_skip_complex_filter_test.valid()); 
iter_skip_complex_filter_test.skip_to(4); @@ -236,12 +239,11 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {4, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_skip_complex_filter_test.valid()); - ASSERT_EQ(i, iter_skip_complex_filter_test.doc); + ASSERT_EQ(i, iter_skip_complex_filter_test.seq_id); iter_skip_complex_filter_test.next(); } ASSERT_FALSE(iter_skip_complex_filter_test.valid()); - ASSERT_TRUE(iter_op.ok()); delete filter_tree_root; filter_tree_root = nullptr; @@ -249,15 +251,14 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); expected = {0, 2, 4, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_validate_ids_test.valid(i)); } - ASSERT_TRUE(iter_op.ok()); - delete filter_tree_root; filter_tree_root = nullptr; filter_op = filter::parse_filter_query("name: James || tags: != gold", coll->get_schema(), store, doc_id_prefix, @@ -265,14 +266,13 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(filter_op.ok()); auto iter_validate_ids_not_equals_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), - filter_tree_root, iter_op); + filter_tree_root); + ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); expected = {1, 3, 5}; for (auto const& i : expected) { ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.valid(i)); } - ASSERT_TRUE(iter_op.ok()); - delete filter_tree_root; } \ No newline at end of file From b3533b5967259cf8c67080bd4ddc3ed8fa43bc65 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 27 Mar 2023 08:25:15 +0530 Subject: [PATCH 09/93] Refactor `valid(id)`. 
--- include/filter.h | 9 +++++++-- src/filter.cpp | 38 ++++++++++++++++++++++++++++---------- test/filter_test.cpp | 16 +++++++++------- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/include/filter.h b/include/filter.h index c5a20f11..9bb78deb 100644 --- a/include/filter.h +++ b/include/filter.h @@ -77,8 +77,13 @@ public: /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). [[nodiscard]] bool valid(); - /// Returns true when id is a match to the filter. Handles moving the individual iterators internally. - [[nodiscard]] bool valid(uint32_t id); + /// Returns a tri-state: + /// 0: id is not valid + /// 1: id is valid + /// -1: end of iterator + /// + /// Handles moving the individual iterators internally. + [[nodiscard]] int valid(uint32_t id); /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this /// operation. diff --git a/src/filter.cpp b/src/filter.cpp index 16e52c49..44393d8d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -437,20 +437,38 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } } -bool filter_result_iterator_t::valid(uint32_t id) { +int filter_result_iterator_t::valid(uint32_t id) { if (!is_valid) { - return false; + return -1; } if (filter_node->isOperator) { + auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); + if (filter_node->filter_operator == AND) { - auto and_is_valid = left_it->valid(id) && right_it->valid(id); is_valid = left_it->is_valid && right_it->is_valid; - return and_is_valid; + + if (left_valid < 1 || right_valid < 1) { + if (left_valid == -1 || right_valid == -1) { + return -1; + } + + return 0; + } + + return 1; } else { - auto or_is_valid = left_it->valid(id) || right_it->valid(id); is_valid = left_it->is_valid || right_it->is_valid; - return or_is_valid; + + if (left_valid < 1 && right_valid < 1) { + if (left_valid == -1 && right_valid == -1) { + return -1; + } + + 
return 0; + } + + return 1; } } @@ -458,21 +476,21 @@ bool filter_result_iterator_t::valid(uint32_t id) { // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. if (!valid()) { is_valid = true; - return is_valid; + return 1; } skip_to(id); if (!is_valid) { is_valid = true; - return is_valid; + return 1; } - return seq_id != id; + return seq_id != id ? 1 : 0; } skip_to(id); - return is_valid && seq_id == id; + return is_valid ? (seq_id == id ? 1 : 0) : -1; } Option filter_result_iterator_t::init_status() { diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 935efff8..e607e03b 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -137,7 +137,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_exact_match_multi_test.init_status().ok()); - std::vector expected = {0, 2, 3, 4}; + std::vector expected = {0, 2, 3, 4}; for (auto const& i : expected) { ASSERT_TRUE(iter_exact_match_multi_test.valid()); ASSERT_EQ(i, iter_exact_match_multi_test.seq_id); @@ -254,9 +254,10 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); - expected = {0, 2, 4, 5}; - for (auto const& i : expected) { - ASSERT_TRUE(iter_validate_ids_test.valid(i)); + std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}; + expected = {1, 0, 1, 0, 1, 1, -1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); } delete filter_tree_root; @@ -269,9 +270,10 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); - expected = {1, 3, 5}; - for (auto const& i : expected) { - 
ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.valid(i)); + validate_ids = {0, 1, 2, 3, 4, 5, 6, 7, 100}; + expected = {0, 1, 0, 1, 0, 1, 1, 1, 1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); } delete filter_tree_root; From b6ee380086a4af188ea589171beb38a6c5c13d18 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 28 Mar 2023 10:00:47 +0530 Subject: [PATCH 10/93] Add `filter_result_iterator_t::contains_atleast_one`. --- include/filter.h | 3 +++ src/filter.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/include/filter.h b/include/filter.h index 9bb78deb..c8d4e3b0 100644 --- a/include/filter.h +++ b/include/filter.h @@ -2,6 +2,7 @@ #include #include +#include "posting_list.h" #include "index.h" class filter_result_iterator_t { @@ -92,4 +93,6 @@ public: /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. 
void skip_to(uint32_t id); + + bool contains_atleast_one(const void* obj); }; diff --git a/src/filter.cpp b/src/filter.cpp index 44393d8d..304799b3 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -502,3 +502,45 @@ Option filter_result_iterator_t::init_status() { return status; } + +bool filter_result_iterator_t::contains_atleast_one(const void *obj) { + if(IS_COMPACT_POSTING(obj)) { + compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); + + size_t i = 0; + while(i < list->length && valid()) { + size_t num_existing_offsets = list->id_offsets[i]; + size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; + + if (existing_id == seq_id) { + return true; + } + + // advance smallest value + if (existing_id < seq_id) { + i += num_existing_offsets + 2; + } else { + skip_to(existing_id); + } + } + } else { + auto list = (posting_list_t*)(obj); + posting_list_t::iterator_t it = list->new_iterator(); + + while(it.valid() && valid()) { + uint32_t id = it.id(); + + if(id == seq_id) { + return true; + } + + if(id < seq_id) { + it.skip_to(seq_id); + } else { + skip_to(id); + } + } + } + + return false; +} From ce6c314771f7add062e30b27201392ae41463cac Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 28 Mar 2023 18:21:58 +0530 Subject: [PATCH 11/93] Add tests for `filter_result_iterator_t::contains_atleast_one`. 
--- test/filter_test.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/test/filter_test.cpp b/test/filter_test.cpp index e607e03b..0c3687db 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "collection.h" class FilterTest : public ::testing::Test { @@ -276,5 +277,59 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); } + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_compact_plist_contains_atleast_one_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root); + ASSERT_TRUE(iter_compact_plist_contains_atleast_one_test1.init_status().ok()); + + std::vector ids = {1, 3, 5}; + std::vector offset_index = {0, 3, 6}; + std::vector offsets = {0, 3, 4, 0, 3, 4, 0, 3, 4}; + + compact_posting_list_t* c_list1 = compact_posting_list_t::create(3, &ids[0], &offset_index[0], 9, &offsets[0]); + ASSERT_FALSE(iter_compact_plist_contains_atleast_one_test1.contains_atleast_one(SET_COMPACT_POSTING(c_list1))); + free(c_list1); + + auto iter_compact_plist_contains_atleast_one_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root); + ASSERT_TRUE(iter_compact_plist_contains_atleast_one_test2.init_status().ok()); + + ids = {1, 3, 4}; + offset_index = {0, 3, 6}; + offsets = {0, 3, 4, 0, 3, 4, 0, 3, 4}; + + compact_posting_list_t* c_list2 = compact_posting_list_t::create(3, &ids[0], &offset_index[0], 9, &offsets[0]); + ASSERT_TRUE(iter_compact_plist_contains_atleast_one_test2.contains_atleast_one(SET_COMPACT_POSTING(c_list2))); + free(c_list2); + + auto iter_plist_contains_atleast_one_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + 
filter_tree_root); + ASSERT_TRUE(iter_plist_contains_atleast_one_test1.init_status().ok()); + + posting_list_t p_list1(2); + ids = {1, 3, 5}; + for (const auto &i: ids) { + p_list1.upsert(i, {1, 2, 3}); + } + + ASSERT_FALSE(iter_plist_contains_atleast_one_test1.contains_atleast_one(&p_list1)); + + auto iter_plist_contains_atleast_one_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), + filter_tree_root); + ASSERT_TRUE(iter_plist_contains_atleast_one_test2.init_status().ok()); + + posting_list_t p_list2(2); + ids = {1, 3, 4}; + for (const auto &i: ids) { + p_list1.upsert(i, {1, 2, 3}); + } + + ASSERT_TRUE(iter_plist_contains_atleast_one_test2.contains_atleast_one(&p_list1)); + delete filter_tree_root; } \ No newline at end of file From 025f4bbd3a4d5cd6a1726e72400c84789d9f5d96 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 29 Mar 2023 13:35:39 +0530 Subject: [PATCH 12/93] Add `filter_result_iterator_t::reset`. --- include/filter.h | 4 ++++ src/filter.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ test/filter_test.cpp | 26 ++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/include/filter.h b/include/filter.h index c8d4e3b0..41ccfac8 100644 --- a/include/filter.h +++ b/include/filter.h @@ -94,5 +94,9 @@ public: /// this operation. void skip_to(uint32_t id); + /// Returns true if at least one id from the posting list object matches the filter. bool contains_atleast_one(const void* obj); + + /// Returns to the initial state of the iterator. + void reset(); }; diff --git a/src/filter.cpp b/src/filter.cpp index 304799b3..d40a5564 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -544,3 +544,45 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { return false; } + +void filter_result_iterator_t::reset() { + if (filter_node->isOperator) { + // Reset the subtrees then apply operators to arrive at the first valid doc. 
+ left_it->reset(); + right_it->reset(); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter || a_filter.field_name == "id") { + result_index = 0; + is_valid = filter_result.count > 0; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + posting_list_iterators.clear(); + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + expanded_plists.clear(); + + init(); + return; + } +} diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 0c3687db..af7145f2 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -331,5 +331,31 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_plist_contains_atleast_one_test2.contains_atleast_one(&p_list1)); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= [gold, silver]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_reset_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_reset_test.init_status().ok()); + + expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_EQ(i, iter_reset_test.seq_id); + iter_reset_test.next(); + } + ASSERT_FALSE(iter_reset_test.valid()); + + iter_reset_test.reset(); + + for (auto const& i : expected) { + ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_EQ(i, iter_reset_test.seq_id); + iter_reset_test.next(); + } + ASSERT_FALSE(iter_reset_test.valid()); + delete filter_tree_root; } \ No newline at end of file From dc74be283f5daab377b9146ef3a5724cf81df5e9 Mon Sep 17 00:00:00 2001 From: Harpreet 
Sangar Date: Thu, 30 Mar 2023 09:52:10 +0530 Subject: [PATCH 13/93] Handle null filter tree. --- include/filter.h | 5 +++++ src/filter.cpp | 10 +++++++++- test/filter_test.cpp | 6 ++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/filter.h b/include/filter.h index 41ccfac8..babee59d 100644 --- a/include/filter.h +++ b/include/filter.h @@ -53,6 +53,11 @@ public: collection_name(collection_name), index(index), filter_node(filter_node) { + if (filter_node == nullptr) { + is_valid = false; + return; + } + // Generate the iterator tree and then initialize each node. if (filter_node->isOperator) { left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); diff --git a/src/filter.cpp b/src/filter.cpp index d40a5564..27c21198 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -207,6 +207,10 @@ void filter_result_iterator_t::next() { } void filter_result_iterator_t::init() { + if (filter_node == nullptr) { + return; + } + if (filter_node->isOperator) { if (filter_node->filter_operator == AND) { and_filter_iterators(); @@ -494,7 +498,7 @@ int filter_result_iterator_t::valid(uint32_t id) { } Option filter_result_iterator_t::init_status() { - if (filter_node->isOperator) { + if (filter_node != nullptr && filter_node->isOperator) { auto left_status = left_it->init_status(); return !left_status.ok() ? left_status : right_it->init_status(); @@ -546,6 +550,10 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { } void filter_result_iterator_t::reset() { + if (filter_node == nullptr) { + return; + } + if (filter_node->isOperator) { // Reset the subtrees then apply operators to arrive at the first valid doc. 
left_it->reset(); diff --git a/test/filter_test.cpp b/test/filter_test.cpp index af7145f2..6a952c57 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -61,6 +61,12 @@ TEST_F(FilterTest, FilterTreeIterator) { const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_"; filter_node_t* filter_tree_root = nullptr; + + auto iter_null_filter_tree_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + + ASSERT_TRUE(iter_null_filter_tree_test.init_status().ok()); + ASSERT_FALSE(iter_null_filter_tree_test.valid()); + Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, filter_tree_root); ASSERT_TRUE(filter_op.ok()); From 36c5f0eeed9d93f5da8500fb7e6ad779a7672811 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 3 Apr 2023 11:35:27 +0530 Subject: [PATCH 14/93] Refactor `filter_result_iterator_t`. --- include/filter.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/filter.h b/include/filter.h index babee59d..901e6e4f 100644 --- a/include/filter.h +++ b/include/filter.h @@ -5,11 +5,13 @@ #include "posting_list.h" #include "index.h" +class Index; + class filter_result_iterator_t { private: - std::string collection_name; - const Index* index; - filter_node_t* filter_node; + const std::string collection_name; + Index const* const index = nullptr; + filter_node_t const* const filter_node = nullptr; filter_result_iterator_t* left_it = nullptr; filter_result_iterator_t* right_it = nullptr; @@ -49,7 +51,7 @@ public: Option status = Option(true); explicit filter_result_iterator_t(const std::string& collection_name, - const Index* index, filter_node_t* filter_node) : + Index const* const index, filter_node_t const* const filter_node) : collection_name(collection_name), index(index), filter_node(filter_node) { From bb15dba25431808e4f76db637c5cf825d594c721 Mon Sep 17 00:00:00 2001 From: 
Harpreet Sangar Date: Tue, 4 Apr 2023 11:53:47 +0530 Subject: [PATCH 15/93] Add move assignment operator in `filter_result_iterator_t`. --- include/filter.h | 45 ++++++++++++++++++++++++++++++++++++++++---- test/filter_test.cpp | 13 +++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/include/filter.h b/include/filter.h index 901e6e4f..5deca7b5 100644 --- a/include/filter.h +++ b/include/filter.h @@ -9,9 +9,9 @@ class Index; class filter_result_iterator_t { private: - const std::string collection_name; - Index const* const index = nullptr; - filter_node_t const* const filter_node = nullptr; + std::string collection_name; + const Index* index = nullptr; + const filter_node_t* filter_node = nullptr; filter_result_iterator_t* left_it = nullptr; filter_result_iterator_t* right_it = nullptr; @@ -50,7 +50,7 @@ public: std::map reference; Option status = Option(true); - explicit filter_result_iterator_t(const std::string& collection_name, + explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node) : collection_name(collection_name), index(index), @@ -79,6 +79,43 @@ public: delete right_it; } + filter_result_iterator_t& operator=(filter_result_iterator_t&& obj) noexcept { + if (&obj == this) + return *this; + + // In case the filter was on string field. 
+ for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + + collection_name = obj.collection_name; + index = obj.index; + filter_node = obj.filter_node; + left_it = obj.left_it; + right_it = obj.right_it; + + obj.left_it = nullptr; + obj.right_it = nullptr; + + result_index = obj.result_index; + + filter_result = std::move(obj.filter_result); + + posting_list_iterators = std::move(obj.posting_list_iterators); + expanded_plists = std::move(obj.expanded_plists); + + is_valid = obj.is_valid; + + seq_id = obj.seq_id; + reference = std::move(obj.reference); + status = std::move(obj.status); + + return *this; + } + /// Returns the status of the initialization of iterator tree. Option init_status(); diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 6a952c57..91dd81c8 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -363,5 +363,18 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_reset_test.valid()); + auto iter_move_assignment_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + + iter_reset_test.reset(); + iter_move_assignment_test = std::move(iter_reset_test); + + expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_move_assignment_test.valid()); + ASSERT_EQ(i, iter_move_assignment_test.seq_id); + iter_move_assignment_test.next(); + } + ASSERT_FALSE(iter_move_assignment_test.valid()); + delete filter_tree_root; } \ No newline at end of file From 6493a0d2cf30c3477184e4636f35e5abb3754766 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 5 Apr 2023 10:10:13 +0530 Subject: [PATCH 16/93] Add `to_filter_id_array` and `and_scalar` methods. 
--- include/filter.h | 16 +++++++++++++--- src/filter.cpp | 45 ++++++++++++++++++++++++++++++++++++++++++++ test/filter_test.cpp | 39 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/include/filter.h b/include/filter.h index 5deca7b5..4103f49f 100644 --- a/include/filter.h +++ b/include/filter.h @@ -2,10 +2,14 @@ #include #include +#include #include "posting_list.h" #include "index.h" class Index; +struct filter_node_t; +struct reference_filter_result_t; +struct filter_result_t; class filter_result_iterator_t { private: @@ -52,9 +56,9 @@ public: explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node) : - collection_name(collection_name), - index(index), - filter_node(filter_node) { + collection_name(collection_name), + index(index), + filter_node(filter_node) { if (filter_node == nullptr) { is_valid = false; return; @@ -143,4 +147,10 @@ public: /// Returns to the initial state of the iterator. void reset(); + + /// Iterates and collects all the filter ids into filter_array. 
+ /// \return size of the filter array + uint32_t to_filter_id_array(uint32_t*& filter_array); + + uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); }; diff --git a/src/filter.cpp b/src/filter.cpp index 27c21198..77ef5d5d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -594,3 +594,48 @@ void filter_result_iterator_t::reset() { return; } } + +uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + do { + filter_ids.push_back(seq_id); + next(); + } while (valid()); + + filter_array = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), filter_array); + + return filter_ids.size(); +} + +uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + for (uint32_t i = 0; i < lenA; i++) { + auto result = valid(A[i]); + + if (result == -1) { + break; + } + + if (result == 1) { + filter_ids.push_back(A[i]); + } + } + + if (filter_ids.empty()) { + return 0; + } + + results = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), results); + + return filter_ids.size(); +} diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 91dd81c8..b4f836f1 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -377,4 +377,43 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_FALSE(iter_move_assignment_test.valid()); delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_to_array_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_to_array_test.init_status().ok()); + + uint32_t* filter_ids = nullptr; + uint32_t filter_ids_length; + + filter_ids_length = 
iter_to_array_test.to_filter_id_array(filter_ids); + ASSERT_EQ(3, filter_ids_length); + + expected = {0, 2, 4}; + for (uint32_t i = 0; i < filter_ids_length; i++) { + ASSERT_EQ(expected[i], filter_ids[i]); + } + ASSERT_FALSE(iter_to_array_test.valid()); + + delete filter_ids; + + auto iter_and_scalar_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_and_scalar_test.init_status().ok()); + + uint32_t a_ids[6] = {0, 1, 3, 4, 5, 6}; + uint32_t* and_result = nullptr; + uint32_t and_result_length; + and_result_length = iter_and_scalar_test.and_scalar(a_ids, 6, and_result); + ASSERT_EQ(2, and_result_length); + + expected = {0, 4}; + for (uint32_t i = 0; i < and_result_length; i++) { + ASSERT_EQ(expected[i], and_result[i]); + } + ASSERT_FALSE(iter_and_test.valid()); + + delete and_result; + delete filter_tree_root; } \ No newline at end of file From 7b3b321aafa32db1ec6728a18d0a67de9064bfd4 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 10 Apr 2023 17:43:39 +0530 Subject: [PATCH 17/93] Refactor filtering logic to overcome circular referencing. Handle exact string filtering in `filter_result_iterator`. 
--- include/art.h | 19 +- include/field.h | 205 +----- include/filter.h | 218 ++----- include/filter_result_iterator.h | 174 +++++ include/index.h | 56 +- include/posting_list.h | 11 + include/topster.h | 1 + include/validator.h | 15 - src/art.cpp | 190 ++++++ src/collection.cpp | 1 + src/field.cpp | 669 ------------------- src/filter.cpp | 1038 ++++++++++++++---------------- src/filter_result_iterator.cpp | 901 ++++++++++++++++++++++++++ src/index.cpp | 284 ++++---- src/or_iterator.cpp | 6 +- src/posting_list.cpp | 140 ++++ src/validator.cpp | 1 + test/collection_test.cpp | 1 + test/filter_test.cpp | 22 +- 19 files changed, 2191 insertions(+), 1761 deletions(-) create mode 100644 include/filter_result_iterator.h create mode 100644 src/filter_result_iterator.cpp diff --git a/include/art.h b/include/art.h index a9715fac..11f57a68 100644 --- a/include/art.h +++ b/include/art.h @@ -1,5 +1,4 @@ -#ifndef ART_H -#define ART_H +#pragma once #include #include @@ -7,6 +6,7 @@ #include #include "array.h" #include "sorted_array.h" +#include "filter_result_iterator.h" #define IGNORE_PRINTF 1 @@ -111,7 +111,7 @@ struct token_leaf { uint32_t num_typos; token_leaf(art_leaf* leaf, uint32_t root_len, uint32_t num_typos, bool is_prefix) : - leaf(leaf), root_len(root_len), num_typos(num_typos), is_prefix(is_prefix) { + leaf(leaf), root_len(root_len), num_typos(num_typos), is_prefix(is_prefix) { } }; @@ -157,11 +157,6 @@ enum NUM_COMPARATOR { RANGE_INCLUSIVE }; -enum FILTER_OPERATOR { - AND, - OR -}; - /** * Initializes an ART tree * @return 0 on success. 
@@ -281,6 +276,12 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const uint32_t *filter_ids, const size_t filter_ids_length, std::vector &results, std::set& exclude_leaves); +int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, + const size_t max_words, const token_ordering token_order, + const bool prefix, bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + std::vector &results, std::set& exclude_leaves); + void encode_int32(int32_t n, unsigned char *chars); void encode_int64(int64_t n, unsigned char *chars); @@ -295,6 +296,4 @@ int art_float_search(art_tree *t, float value, NUM_COMPARATOR comparator, std::v #ifdef __cplusplus } -#endif - #endif \ No newline at end of file diff --git a/include/field.h b/include/field.h index 8d1606e3..e2f9729f 100644 --- a/include/field.h +++ b/include/field.h @@ -2,13 +2,13 @@ #include #include -#include "art.h" #include "option.h" #include "string_utils.h" #include "logger.h" #include "store.h" #include #include +#include #include "json.hpp" #include "text_embedder_manager.h" @@ -515,200 +515,19 @@ struct field { static void compact_nested_fields(tsl::htrie_map& nested_fields); }; -struct filter_node_t; - -struct filter { - std::string field_name; - std::vector values; - std::vector comparators; - // Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the - // case of int and float fields. During filtering, all the results of matching the field against the values are - // aggregated and then this flag is checked if negation on the aggregated result is required. 
- bool apply_not_equals = false; - - // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` - std::string referenced_collection_name = ""; - - static const std::string RANGE_OPERATOR() { - return ".."; - } - - static Option validate_numerical_filter_value(field _field, const std::string& raw_value) { - if(_field.is_int32() && !StringUtils::is_int32_t(raw_value)) { - return Option(400, "Error with filter field `" + _field.name + "`: Not an int32."); - } - - else if(_field.is_int64() && !StringUtils::is_int64_t(raw_value)) { - return Option(400, "Error with filter field `" + _field.name + "`: Not an int64."); - } - - else if(_field.is_float() && !StringUtils::is_float(raw_value)) { - return Option(400, "Error with filter field `" + _field.name + "`: Not a float."); - } - - return Option(true); - } - - static Option extract_num_comparator(std::string & comp_and_value) { - auto num_comparator = EQUALS; - - if(StringUtils::is_integer(comp_and_value) || StringUtils::is_float(comp_and_value)) { - num_comparator = EQUALS; - } - - // the ordering is important - we have to compare 2-letter operators first - else if(comp_and_value.compare(0, 2, "<=") == 0) { - num_comparator = LESS_THAN_EQUALS; - } - - else if(comp_and_value.compare(0, 2, ">=") == 0) { - num_comparator = GREATER_THAN_EQUALS; - } - - else if(comp_and_value.compare(0, 2, "!=") == 0) { - num_comparator = NOT_EQUALS; - } - - else if(comp_and_value.compare(0, 1, "<") == 0) { - num_comparator = LESS_THAN; - } - - else if(comp_and_value.compare(0, 1, ">") == 0) { - num_comparator = GREATER_THAN; - } - - else if(comp_and_value.find("..") != std::string::npos) { - num_comparator = RANGE_INCLUSIVE; - } - - else { - return Option(400, "Numerical field has an invalid comparator."); - } - - if(num_comparator == LESS_THAN || num_comparator == GREATER_THAN) { - comp_and_value = comp_and_value.substr(1); - } else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS || 
num_comparator == NOT_EQUALS) { - comp_and_value = comp_and_value.substr(2); - } - - comp_and_value = StringUtils::trim(comp_and_value); - - return Option(num_comparator); - } - - static Option parse_geopoint_filter_value(std::string& raw_value, - const std::string& format_err_msg, - std::string& processed_filter_val, - NUM_COMPARATOR& num_comparator); - - static Option parse_filter_query(const std::string& filter_query, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix, - filter_node_t*& root); +enum index_operation_t { + CREATE, + UPSERT, + UPDATE, + EMPLACE, + DELETE }; -struct filter_node_t { - filter filter_exp; - FILTER_OPERATOR filter_operator; - bool isOperator; - filter_node_t* left = nullptr; - filter_node_t* right = nullptr; - - filter_node_t(filter filter_exp) - : filter_exp(std::move(filter_exp)), - isOperator(false), - left(nullptr), - right(nullptr) {} - - filter_node_t(FILTER_OPERATOR filter_operator, - filter_node_t* left, - filter_node_t* right) - : filter_operator(filter_operator), - isOperator(true), - left(left), - right(right) {} - - ~filter_node_t() { - delete left; - delete right; - } -}; - -struct reference_filter_result_t { - uint32_t count = 0; - uint32_t* docs = nullptr; - - reference_filter_result_t& operator=(const reference_filter_result_t& obj) noexcept { - if (&obj == this) - return *this; - - count = obj.count; - docs = new uint32_t[count]; - memcpy(docs, obj.docs, count * sizeof(uint32_t)); - - return *this; - } - - ~reference_filter_result_t() { - delete[] docs; - } -}; - -struct filter_result_t { - uint32_t count = 0; - uint32_t* docs = nullptr; - // Collection name -> Reference filter result - std::map reference_filter_results; - - filter_result_t() = default; - - filter_result_t(uint32_t count, uint32_t* docs) : count(count), docs(docs) {} - - filter_result_t& operator=(const filter_result_t& obj) noexcept { - if (&obj == this) - return *this; - - count = obj.count; - docs = new 
uint32_t[count]; - memcpy(docs, obj.docs, count * sizeof(uint32_t)); - - // Copy every collection's references. - for (const auto &item: obj.reference_filter_results) { - reference_filter_results[item.first] = new reference_filter_result_t[count]; - - for (uint32_t i = 0; i < count; i++) { - reference_filter_results[item.first][i] = item.second[i]; - } - } - - return *this; - } - - filter_result_t& operator=(filter_result_t&& obj) noexcept { - if (&obj == this) - return *this; - - count = obj.count; - docs = obj.docs; - reference_filter_results = std::map(obj.reference_filter_results); - - obj.docs = nullptr; - obj.reference_filter_results.clear(); - - return *this; - } - - ~filter_result_t() { - delete[] docs; - for (const auto &item: reference_filter_results) { - delete[] item.second; - } - } - - static void and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); - - static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); +enum class DIRTY_VALUES { + REJECT = 1, + DROP = 2, + COERCE_OR_REJECT = 3, + COERCE_OR_DROP = 4, }; namespace sort_field_const { diff --git a/include/filter.h b/include/filter.h index 4103f49f..1961b0ac 100644 --- a/include/filter.h +++ b/include/filter.h @@ -2,155 +2,73 @@ #include #include -#include -#include "posting_list.h" -#include "index.h" +#include +#include +#include "store.h" -class Index; -struct filter_node_t; -struct reference_filter_result_t; -struct filter_result_t; - -class filter_result_iterator_t { -private: - std::string collection_name; - const Index* index = nullptr; - const filter_node_t* filter_node = nullptr; - filter_result_iterator_t* left_it = nullptr; - filter_result_iterator_t* right_it = nullptr; - - /// Used in case of id and reference filter. - uint32_t result_index = 0; - - /// Stores the result of the filters that cannot be iterated. 
- filter_result_t filter_result; - - /// Initialized in case of filter on string field. - /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator - /// for each token. - /// - /// Multiple filter values: Multiple tokens: posting list iterator - std::vector> posting_list_iterators; - std::vector expanded_plists; - - /// Set to false when this iterator or it's subtree becomes invalid. - bool is_valid = true; - - /// Initializes the state of iterator node after it's creation. - void init(); - - /// Performs AND on the subtrees of operator. - void and_filter_iterators(); - - /// Performs OR on the subtrees of operator. - void or_filter_iterators(); - - /// Finds the next match for a filter on string field. - void doc_matching_string_filter(); - -public: - uint32_t seq_id = 0; - // Collection name -> references - std::map reference; - Option status = Option(true); - - explicit filter_result_iterator_t(const std::string collection_name, - Index const* const index, filter_node_t const* const filter_node) : - collection_name(collection_name), - index(index), - filter_node(filter_node) { - if (filter_node == nullptr) { - is_valid = false; - return; - } - - // Generate the iterator tree and then initialize each node. - if (filter_node->isOperator) { - left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); - right_it = new filter_result_iterator_t(collection_name, index, filter_node->right); - } - - init(); - } - - ~filter_result_iterator_t() { - // In case the filter was on string field. - for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - - delete left_it; - delete right_it; - } - - filter_result_iterator_t& operator=(filter_result_iterator_t&& obj) noexcept { - if (&obj == this) - return *this; - - // In case the filter was on string field. 
- for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - - delete left_it; - delete right_it; - - collection_name = obj.collection_name; - index = obj.index; - filter_node = obj.filter_node; - left_it = obj.left_it; - right_it = obj.right_it; - - obj.left_it = nullptr; - obj.right_it = nullptr; - - result_index = obj.result_index; - - filter_result = std::move(obj.filter_result); - - posting_list_iterators = std::move(obj.posting_list_iterators); - expanded_plists = std::move(obj.expanded_plists); - - is_valid = obj.is_valid; - - seq_id = obj.seq_id; - reference = std::move(obj.reference); - status = std::move(obj.status); - - return *this; - } - - /// Returns the status of the initialization of iterator tree. - Option init_status(); - - /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). - [[nodiscard]] bool valid(); - - /// Returns a tri-state: - /// 0: id is not valid - /// 1: id is valid - /// -1: end of iterator - /// - /// Handles moving the individual iterators internally. - [[nodiscard]] int valid(uint32_t id); - - /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this - /// operation. - void next(); - - /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during - /// this operation. - void skip_to(uint32_t id); - - /// Returns true if at least one id from the posting list object matches the filter. - bool contains_atleast_one(const void* obj); - - /// Returns to the initial state of the iterator. - void reset(); - - /// Iterates and collects all the filter ids into filter_array. 
- /// \return size of the filter array - uint32_t to_filter_id_array(uint32_t*& filter_array); - - uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); +enum FILTER_OPERATOR { + AND, + OR +}; + +struct filter_node_t; +struct field; + +struct filter { + std::string field_name; + std::vector values; + std::vector comparators; + // Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the + // case of int and float fields. During filtering, all the results of matching the field against the values are + // aggregated and then this flag is checked if negation on the aggregated result is required. + bool apply_not_equals = false; + + // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` + std::string referenced_collection_name = ""; + + static const std::string RANGE_OPERATOR() { + return ".."; + } + + static Option validate_numerical_filter_value(field _field, const std::string& raw_value); + + static Option extract_num_comparator(std::string & comp_and_value); + + static Option parse_geopoint_filter_value(std::string& raw_value, + const std::string& format_err_msg, + std::string& processed_filter_val, + NUM_COMPARATOR& num_comparator); + + static Option parse_filter_query(const std::string& filter_query, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix, + filter_node_t*& root); +}; + +struct filter_node_t { + filter filter_exp; + FILTER_OPERATOR filter_operator; + bool isOperator; + filter_node_t* left = nullptr; + filter_node_t* right = nullptr; + + filter_node_t(filter filter_exp) + : filter_exp(std::move(filter_exp)), + isOperator(false), + left(nullptr), + right(nullptr) {} + + filter_node_t(FILTER_OPERATOR filter_operator, + filter_node_t* left, + filter_node_t* right) + : filter_operator(filter_operator), + isOperator(true), + left(left), + right(right) {} + + ~filter_node_t() { + delete left; + delete right; + 
} }; diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h new file mode 100644 index 00000000..b67d67ca --- /dev/null +++ b/include/filter_result_iterator.h @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include +#include "option.h" +#include "posting_list.h" + +class Index; +struct filter_node_t; + +struct reference_filter_result_t { + uint32_t count = 0; + uint32_t* docs = nullptr; + + reference_filter_result_t& operator=(const reference_filter_result_t& obj) noexcept { + if (&obj == this) + return *this; + + count = obj.count; + docs = new uint32_t[count]; + memcpy(docs, obj.docs, count * sizeof(uint32_t)); + + return *this; + } + + ~reference_filter_result_t() { + delete[] docs; + } +}; + +struct filter_result_t { + uint32_t count = 0; + uint32_t* docs = nullptr; + // Collection name -> Reference filter result + std::map reference_filter_results; + + filter_result_t() = default; + + filter_result_t(uint32_t count, uint32_t* docs) : count(count), docs(docs) {} + + filter_result_t& operator=(const filter_result_t& obj) noexcept { + if (&obj == this) + return *this; + + count = obj.count; + docs = new uint32_t[count]; + memcpy(docs, obj.docs, count * sizeof(uint32_t)); + + // Copy every collection's references. 
+ for (const auto &item: obj.reference_filter_results) { + reference_filter_results[item.first] = new reference_filter_result_t[count]; + + for (uint32_t i = 0; i < count; i++) { + reference_filter_results[item.first][i] = item.second[i]; + } + } + + return *this; + } + + filter_result_t& operator=(filter_result_t&& obj) noexcept { + if (&obj == this) + return *this; + + count = obj.count; + docs = obj.docs; + reference_filter_results = std::map(obj.reference_filter_results); + + obj.docs = nullptr; + obj.reference_filter_results.clear(); + + return *this; + } + + ~filter_result_t() { + delete[] docs; + for (const auto &item: reference_filter_results) { + delete[] item.second; + } + } + + static void and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); + + static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); +}; + + +class filter_result_iterator_t { +private: + std::string collection_name; + const Index* index = nullptr; + const filter_node_t* filter_node = nullptr; + filter_result_iterator_t* left_it = nullptr; + filter_result_iterator_t* right_it = nullptr; + + /// Used in case of id and reference filter. + uint32_t result_index = 0; + + /// Stores the result of the filters that cannot be iterated. + filter_result_t filter_result; + + /// Initialized in case of filter on string field. + /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + /// for each token. + /// + /// Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_list_iterators; + std::vector expanded_plists; + + /// Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initializes the state of iterator node after it's creation. + void init(); + + /// Performs AND on the subtrees of operator. 
+ void and_filter_iterators(); + + /// Performs OR on the subtrees of operator. + void or_filter_iterators(); + + /// Finds the next match for a filter on string field. + void doc_matching_string_filter(bool field_is_array); + +public: + uint32_t seq_id = 0; + // Collection name -> references + std::map reference; + Option status = Option(true); + + explicit filter_result_iterator_t(const std::string collection_name, + Index const* const index, filter_node_t const* const filter_node); + + ~filter_result_iterator_t(); + + filter_result_iterator_t& operator=(filter_result_iterator_t&& obj) noexcept; + + /// Returns the status of the initialization of iterator tree. + Option init_status(); + + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). + [[nodiscard]] bool valid(); + + /// Returns a tri-state: + /// 0: id is not valid + /// 1: id is valid + /// -1: end of iterator + /// + /// Handles moving the individual iterators internally. + [[nodiscard]] int valid(uint32_t id); + + /// Advances the iterator to get the next value of doc and reference. The iterator may become invalid during this + /// operation. + void next(); + + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during + /// this operation. + void skip_to(uint32_t id); + + /// Returns true if at least one id from the posting list object matches the filter. + bool contains_atleast_one(const void* obj); + + /// Returns to the initial state of the iterator. + void reset(); + + /// Iterates and collects all the filter ids into filter_array. + /// \return size of the filter array + uint32_t to_filter_id_array(uint32_t*& filter_array); + + /// Performs AND with the contents of A and allocates a new array of results. 
+ /// \return size of the results array + uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); +}; diff --git a/include/index.h b/include/index.h index 63681b9e..c65b84c5 100644 --- a/include/index.h +++ b/include/index.h @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include "string_utils.h" @@ -30,6 +29,7 @@ #include "override.h" #include "vector_query_ops.h" #include "hnswlib/hnswlib.h" +#include "filter.h" static constexpr size_t ARRAY_FACET_DIM = 4; using facet_map_t = spp::sparse_hash_map; @@ -233,19 +233,15 @@ struct index_record { }; class VectorFilterFunctor: public hnswlib::BaseFilterFunctor { - const uint32_t* filter_ids = nullptr; - const uint32_t filter_ids_length = 0; + filter_result_iterator_t* const filter_result_iterator; public: - explicit VectorFilterFunctor(const uint32_t* filter_ids, const uint32_t filter_ids_length) : - filter_ids(filter_ids), filter_ids_length(filter_ids_length) {} + explicit VectorFilterFunctor(filter_result_iterator_t* const filter_result_iterator) : + filter_result_iterator(filter_result_iterator) {} - bool operator()(hnswlib::labeltype id) override { - if(filter_ids_length == 0) { - return true; - } - - return std::binary_search(filter_ids, filter_ids + filter_ids_length, id); + bool operator()(unsigned int id) { + filter_result_iterator->reset(); + return filter_result_iterator->valid(id) == 1; } }; @@ -412,7 +408,7 @@ private: void search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -563,6 +559,9 @@ public: // in the query that have the least individual hits one by one until enough results are found. 
static const int DROP_TOKENS_THRESHOLD = 1; + // "_all_" is a special field that maps to all the ids in the index. + static constexpr const char* SEQ_IDS_FILTER = "_all_: 1"; + Index() = delete; Index(const std::string& name, @@ -741,8 +740,9 @@ public: const std::vector& group_by_fields, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, - uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, - uint32_t filter_ids_length, const size_t concurrency, + uint32_t*& all_result_ids, size_t& all_result_ids_len, + filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices) const; @@ -782,7 +782,7 @@ public: std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - const uint32_t *filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -813,7 +813,7 @@ public: spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -838,7 +838,7 @@ public: Topster* curated_topster, const std::map>& included_ids_map, bool is_wildcard_query, - uint32_t*& filter_ids, uint32_t& filter_ids_length) const; + uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const; void fuzzy_search_fields(const std::vector& the_fields, const std::vector& 
query_tokens, @@ -846,7 +846,7 @@ public: const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -875,7 +875,7 @@ public: const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, @@ -897,7 +897,7 @@ public: const std::vector& group_by_fields, bool prioritize_exact_match, const bool search_all_candidates, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, @@ -946,14 +946,14 @@ public: int64_t* scores, int64_t& match_score_index) const; - void - process_curated_ids(const std::vector>& included_ids, - const std::vector& excluded_ids, const std::vector& group_by_fields, - const size_t group_limit, const bool filter_curated_hits, const uint32_t* filter_ids, - uint32_t filter_ids_length, std::set& curated_ids, - std::map>& included_ids_map, - std::vector& included_ids_vec, - std::unordered_set& excluded_group_ids) const; + void process_curated_ids(const std::vector>& included_ids, + const std::vector& excluded_ids, const std::vector& group_by_fields, + const size_t group_limit, const bool filter_curated_hits, + filter_result_iterator_t& filter_result_iterator, + std::set& curated_ids, + std::map>& included_ids_map, + std::vector& included_ids_vec, + std::unordered_set& excluded_group_ids) const; Option seq_ids_outside_top_k(const std::string& field_name, size_t k, std::vector& outside_seq_ids); diff --git 
a/include/posting_list.h b/include/posting_list.h index 16f42ea7..11ed9c91 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -8,6 +8,7 @@ #include "thread_local_vars.h" typedef uint32_t last_id_t; +class filter_result_iterator_t; struct result_iter_state_t { const uint32_t* excluded_result_ids = nullptr; @@ -19,12 +20,19 @@ struct result_iter_state_t { size_t excluded_result_ids_index = 0; size_t filter_ids_index = 0; + filter_result_iterator_t* fit = nullptr; + result_iter_state_t() = default; result_iter_state_t(const uint32_t* excluded_result_ids, size_t excluded_result_ids_size, const uint32_t* filter_ids, const size_t filter_ids_length) : excluded_result_ids(excluded_result_ids), excluded_result_ids_size(excluded_result_ids_size), filter_ids(filter_ids), filter_ids_length(filter_ids_length) {} + + result_iter_state_t(const uint32_t* excluded_result_ids, size_t excluded_result_ids_size, + filter_result_iterator_t* fit) : excluded_result_ids(excluded_result_ids), + excluded_result_ids_size(excluded_result_ids_size), + fit(fit){} }; /* @@ -186,6 +194,9 @@ public: const uint32_t* ids, const uint32_t num_ids, uint32_t*& exact_ids, size_t& num_exact_ids); + static bool has_exact_match(std::vector& posting_list_iterators, + const bool field_is_array); + static void get_phrase_matches(std::vector& its, bool field_is_array, const uint32_t* ids, const uint32_t num_ids, uint32_t*& phrase_ids, size_t& num_phrase_ids); diff --git a/include/topster.h b/include/topster.h index e59ae74c..8e10abcb 100644 --- a/include/topster.h +++ b/include/topster.h @@ -6,6 +6,7 @@ #include #include #include +#include "filter_result_iterator.h" struct KV { int8_t match_score_index{}; diff --git a/include/validator.h b/include/validator.h index 3998f5dc..a8a4a8f9 100644 --- a/include/validator.h +++ b/include/validator.h @@ -6,21 +6,6 @@ #include "tsl/htrie_map.h" #include "field.h" -enum index_operation_t { - CREATE, - UPSERT, - UPDATE, - EMPLACE, - DELETE -}; - -enum class 
DIRTY_VALUES { - REJECT = 1, - DROP = 2, - COERCE_OR_REJECT = 3, - COERCE_OR_DROP = 4, -}; - class validator_t { public: diff --git a/src/art.cpp b/src/art.cpp index 40b028a3..48470551 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -22,6 +22,7 @@ #include "art.h" #include "logger.h" #include "array_utils.h" +#include "filter_result_iterator.h" /** * Macros to manipulate pointer tags @@ -972,6 +973,36 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, return prev_token_doc_ids; } +const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + size_t& prev_token_doc_ids_len) { + + art_leaf* prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + uint32_t* prev_token_doc_ids = nullptr; + + if(prev_token.empty() || !prev_leaf) { + prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(prev_token_doc_ids); + return prev_token_doc_ids; + } + + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + if(filter_result_iterator.valid()) { + prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), + prev_token_doc_ids); + } else { + prev_token_doc_ids_len = prev_leaf_ids.size(); + prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; + std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); + } + + return prev_token_doc_ids; +} + bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, std::set& exclude_leaves, const art_leaf* exact_leaf, @@ -1622,6 +1653,165 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, return 0; } +int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, + const size_t max_words, const 
token_ordering token_order, + const bool prefix, bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + std::vector &results, std::set& exclude_leaves) { + + std::vector nodes; + int irow[term_len + 1]; + int jrow[term_len + 1]; + for (int i = 0; i <= term_len; i++){ + irow[i] = jrow[i] = i; + } + + //auto begin = std::chrono::high_resolution_clock::now(); + + if(IS_LEAF(t->root)) { + art_leaf *l = (art_leaf *) LEAF_RAW(t->root); + art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } else { + if(t->root == nullptr) { + return 0; + } + + // send depth as -1 to indicate that this is a root node + art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } + + //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); + + //auto begin = std::chrono::high_resolution_clock::now(); + + size_t key_len = prefix ? 
term_len + 1 : term_len; + art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); + //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; + + // documents that contain the previous token and/or filter ids + size_t allowed_doc_ids_len = 0; + const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); + + for(auto node: nodes) { + art_topk_iter(node, token_order, max_words, exact_leaf, + last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + t, exclude_leaves, results); + } + + if(token_order == FREQUENCY) { + std::sort(results.begin(), results.end(), compare_art_leaf_frequency); + } else { + std::sort(results.begin(), results.end(), compare_art_leaf_score); + } + + if(exact_leaf && min_cost == 0) { + std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); + if(exclude_leaves.count(tok) == 0) { + results.insert(results.begin(), exact_leaf); + exclude_leaves.emplace(tok); + } + } + + if(results.size() > max_words) { + results.resize(max_words); + } + + /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + if(time_micro > 1000) { + LOG(INFO) << "Time taken for art_topk_iter: " << time_micro + << "us, size of nodes: " << nodes.size() + << ", filter_ids_length: " << filter_ids_length; + }*/ + +// TODO: Figure out this edge case. 
+// The filter-iterator overload of get_allowed_doc_ids() hands ownership of the returned array to the caller +// (new[] / and_scalar(); to_filter_id_array() is assumed to use new[] too), so release it to avoid a leak. + delete [] allowed_doc_ids; + + return 0; +} + +int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, + const size_t max_words, const token_ordering token_order, const bool prefix, + bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + std::vector &results, std::set& exclude_leaves) { + + std::vector nodes; + int irow[term_len + 1]; + int jrow[term_len + 1]; + for (int i = 0; i <= term_len; i++){ + irow[i] = jrow[i] = i; + } + + //auto begin = std::chrono::high_resolution_clock::now(); + + if(IS_LEAF(t->root)) { + art_leaf *l = (art_leaf *) LEAF_RAW(t->root); + art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } else { + if(t->root == nullptr) { + return 0; + } + + // send depth as -1 to indicate that this is a root node + art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); + } + + //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); + + //auto begin = std::chrono::high_resolution_clock::now(); + + size_t key_len = prefix ?
term_len + 1 : term_len; + art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); + //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; + + // documents that contain the previous token and/or filter ids + size_t allowed_doc_ids_len = 0; + const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); + + for(auto node: nodes) { + art_topk_iter(node, token_order, max_words, exact_leaf, + last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + t, exclude_leaves, results); + } + + if(token_order == FREQUENCY) { + std::sort(results.begin(), results.end(), compare_art_leaf_frequency); + } else { + std::sort(results.begin(), results.end(), compare_art_leaf_score); + } + + if(exact_leaf && min_cost == 0) { + std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); + if(exclude_leaves.count(tok) == 0) { + results.insert(results.begin(), exact_leaf); + exclude_leaves.emplace(tok); + } + } + + if(results.size() > max_words) { + results.resize(max_words); + } + + /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); + + if(time_micro > 1000) { + LOG(INFO) << "Time taken for art_topk_iter: " << time_micro + << "us, size of nodes: " << nodes.size() + << ", filter_ids_length: " << filter_ids_length; + }*/ + +// TODO: Figure out this edge case. 
+// The filter-iterator overload of get_allowed_doc_ids() hands ownership of the returned array to the caller +// (new[] / and_scalar(); to_filter_id_array() is assumed to use new[] too), so release it to avoid a leak. + delete [] allowed_doc_ids; + + return 0; +} + void encode_int32(int32_t n, unsigned char *chars) { unsigned char symbols[16] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 diff --git a/src/collection.cpp b/src/collection.cpp index 63db8c19..21f2d187 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "validator.h" #include "topster.h" #include "logger.h" #include "thread_local_vars.h" diff --git a/src/field.cpp b/src/field.cpp index ee62b7fa..f48ecf50 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -6,511 +6,6 @@ #include #include -Option filter::parse_geopoint_filter_value(std::string& raw_value, - const std::string& format_err_msg, - std::string& processed_filter_val, - NUM_COMPARATOR& num_comparator) { - - num_comparator = LESS_THAN_EQUALS; - - if(!(raw_value[0] == '(' && raw_value[raw_value.size() - 1] == ')')) { - return Option(400, format_err_msg); - } - - std::vector filter_values; - auto raw_val_without_paran = raw_value.substr(1, raw_value.size() - 2); - StringUtils::split(raw_val_without_paran, filter_values, ","); - - // we will end up with: "10.45 34.56 2 km" or "10.45 34.56 2mi" or a geo polygon - - if(filter_values.size() < 3) { - return Option(400, format_err_msg); - } - - // do validation: format should match either a point + radius or polygon - - size_t num_floats = 0; - for(const auto& fvalue: filter_values) { - if(StringUtils::is_float(fvalue)) { - num_floats++; - } - } - - bool is_polygon = (num_floats == filter_values.size()); - if(!is_polygon) { - // we have to ensure that this is a point + radius match - if(!StringUtils::is_float(filter_values[0]) || !StringUtils::is_float(filter_values[1])) { - return Option(400, format_err_msg); - } - - if(filter_values[0] == "nan" || filter_values[0] == "NaN" || - filter_values[1] == "nan" || filter_values[1] == "NaN") { - return Option(400, format_err_msg); - } - } - - if(is_polygon) {
- processed_filter_val = raw_val_without_paran; - } else { - // point + radius - // filter_values[2] is distance, get the unit, validate it and split on that - if(filter_values[2].size() < 2) { - return Option(400, "Unit must be either `km` or `mi`."); - } - - std::string unit = filter_values[2].substr(filter_values[2].size()-2, 2); - - if(unit != "km" && unit != "mi") { - return Option(400, "Unit must be either `km` or `mi`."); - } - - std::vector dist_values; - StringUtils::split(filter_values[2], dist_values, unit); - - if(dist_values.size() != 1) { - return Option(400, format_err_msg); - } - - if(!StringUtils::is_float(dist_values[0])) { - return Option(400, format_err_msg); - } - - processed_filter_val = filter_values[0] + ", " + filter_values[1] + ", " + // co-ords - dist_values[0] + ", " + unit; // X km - } - - return Option(true); -} - -bool isOperator(const std::string& expression) { - return expression == "&&" || expression == "||"; -} - -// https://en.wikipedia.org/wiki/Shunting_yard_algorithm -Option toPostfix(std::queue& tokens, std::queue& postfix) { - std::stack operatorStack; - - while (!tokens.empty()) { - auto expression = tokens.front(); - tokens.pop(); - - if (isOperator(expression)) { - // We only have two operators &&, || having the same precedence and both being left associative. 
- while (!operatorStack.empty() && operatorStack.top() != "(") { - postfix.push(operatorStack.top()); - operatorStack.pop(); - } - - operatorStack.push(expression); - } else if (expression == "(") { - operatorStack.push(expression); - } else if (expression == ")") { - while (!operatorStack.empty() && operatorStack.top() != "(") { - postfix.push(operatorStack.top()); - operatorStack.pop(); - } - - if (operatorStack.empty() || operatorStack.top() != "(") { - return Option(400, "Could not parse the filter query: unbalanced parentheses."); - } - operatorStack.pop(); - } else { - postfix.push(expression); - } - } - - while (!operatorStack.empty()) { - if (operatorStack.top() == "(") { - return Option(400, "Could not parse the filter query: unbalanced parentheses."); - } - postfix.push(operatorStack.top()); - operatorStack.pop(); - } - - return Option(true); -} - -Option toMultiValueNumericFilter(std::string& raw_value, filter& filter_exp, const field& _field) { - std::vector filter_values; - StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); - filter_exp = {_field.name, {}, {}}; - for (std::string& filter_value: filter_values) { - Option op_comparator = filter::extract_num_comparator(filter_value); - if (!op_comparator.ok()) { - return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); - } - if (op_comparator.get() == RANGE_INCLUSIVE) { - // split the value around range operator to extract bounds - std::vector range_values; - StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR()); - for (const std::string& range_value: range_values) { - auto validate_op = filter::validate_numerical_filter_value(_field, range_value); - if (!validate_op.ok()) { - return validate_op; - } - filter_exp.values.push_back(range_value); - filter_exp.comparators.push_back(op_comparator.get()); - } - } else { - auto validate_op = filter::validate_numerical_filter_value(_field, filter_value); - if 
(!validate_op.ok()) { - return validate_op; - } - filter_exp.values.push_back(filter_value); - filter_exp.comparators.push_back(op_comparator.get()); - } - } - - return Option(true); -} - -Option toFilter(const std::string expression, - filter& filter_exp, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix) { - // split into [field_name, value] - size_t found_index = expression.find(':'); - if (found_index == std::string::npos) { - return Option(400, "Could not parse the filter query."); - } - std::string&& field_name = expression.substr(0, found_index); - StringUtils::trim(field_name); - if (field_name == "id") { - std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); - StringUtils::trim(raw_value); - std::string empty_filter_err = "Error with filter field `id`: Filter value cannot be empty."; - if (raw_value.empty()) { - return Option(400, empty_filter_err); - } - filter_exp = {field_name, {}, {}}; - NUM_COMPARATOR id_comparator = EQUALS; - size_t filter_value_index = 0; - if (raw_value[0] == '=') { - id_comparator = EQUALS; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - return Option(400, "Not equals filtering is not supported on the `id` field."); - } - if (filter_value_index != 0) { - raw_value = raw_value.substr(filter_value_index); - } - if (raw_value.empty()) { - return Option(400, empty_filter_err); - } - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector doc_ids; - StringUtils::split_to_values(raw_value.substr(1, raw_value.size() - 2), doc_ids); - for (std::string& doc_id: doc_ids) { - // we have to convert the doc_id to seq id - std::string seq_id_str; - StoreStatus seq_id_status = store->get(doc_id_prefix + doc_id, seq_id_str); - if (seq_id_status != StoreStatus::FOUND) { - continue; - } - filter_exp.values.push_back(seq_id_str); - filter_exp.comparators.push_back(id_comparator); - } - } else { - std::vector doc_ids; - StringUtils::split_to_values(raw_value, doc_ids); // to handle backticks - std::string seq_id_str; - StoreStatus seq_id_status = store->get(doc_id_prefix + doc_ids[0], seq_id_str); - if (seq_id_status == StoreStatus::FOUND) { - filter_exp.values.push_back(seq_id_str); - filter_exp.comparators.push_back(id_comparator); - } - } - return Option(true); - } - - auto field_it = search_schema.find(field_name); - - if (field_it == search_schema.end()) { - return Option(404, "Could not find a filter field named `" + field_name + "` in the schema."); - } - - if (field_it->num_dim > 0) { - return Option(404, "Cannot filter on vector field `" + field_name + "`."); - } - - const field& _field = field_it.value(); - std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); - StringUtils::trim(raw_value); - // skip past optional `:=` operator, which has no meaning for non-string fields - if (!_field.is_string() && raw_value[0] == '=') { - size_t filter_value_index = 0; - while (raw_value[++filter_value_index] == ' '); - raw_value = raw_value.substr(filter_value_index); - } - if (_field.is_integer() || _field.is_float()) { - // could be a 
single value or a list - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); - if (!op.ok()) { - return op; - } - } else { - Option op_comparator = filter::extract_num_comparator(raw_value); - if (!op_comparator.ok()) { - return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); - } - if (op_comparator.get() == RANGE_INCLUSIVE) { - // split the value around range operator to extract bounds - std::vector range_values; - StringUtils::split(raw_value, range_values, filter::RANGE_OPERATOR()); - filter_exp.field_name = field_name; - for (const std::string& range_value: range_values) { - auto validate_op = filter::validate_numerical_filter_value(_field, range_value); - if (!validate_op.ok()) { - return validate_op; - } - filter_exp.values.push_back(range_value); - filter_exp.comparators.push_back(op_comparator.get()); - } - } else if (op_comparator.get() == NOT_EQUALS && raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); - if (!op.ok()) { - return op; - } - filter_exp.apply_not_equals = true; - } else { - auto validate_op = filter::validate_numerical_filter_value(_field, raw_value); - if (!validate_op.ok()) { - return validate_op; - } - filter_exp = {field_name, {raw_value}, {op_comparator.get()}}; - } - } - } else if (_field.is_bool()) { - NUM_COMPARATOR bool_comparator = EQUALS; - size_t filter_value_index = 0; - if (raw_value[0] == '=') { - bool_comparator = EQUALS; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - bool_comparator = NOT_EQUALS; - filter_value_index++; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } - if (filter_value_index != 0) { - raw_value = raw_value.substr(filter_value_index); - } - if (filter_value_index == raw_value.size()) { - return Option(400, "Error with filter field `" + _field.name + - "`: Filter value cannot be empty."); - } - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector filter_values; - StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); - filter_exp = {field_name, {}, {}}; - for (std::string& filter_value: filter_values) { - if (filter_value != "true" && filter_value != "false") { - return Option(400, "Values of filter field `" + _field.name + - "`: must be `true` or `false`."); - } - filter_value = (filter_value == "true") ? "1" : "0"; - filter_exp.values.push_back(filter_value); - filter_exp.comparators.push_back(bool_comparator); - } - } else { - if (raw_value != "true" && raw_value != "false") { - return Option(400, "Value of filter field `" + _field.name + "` must be `true` or `false`."); - } - std::string bool_value = (raw_value == "true") ? 
"1" : "0"; - filter_exp = {field_name, {bool_value}, {bool_comparator}}; - } - } else if (_field.is_geopoint()) { - filter_exp = {field_name, {}, {}}; - const std::string& format_err_msg = "Value of filter field `" + _field.name + - "`: must be in the `(-44.50, 170.29, 0.75 km)` or " - "(56.33, -65.97, 23.82, -127.82) format."; - NUM_COMPARATOR num_comparator; - // could be a single value or a list - if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector filter_values; - StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, "),"); - for (std::string& filter_value: filter_values) { - filter_value += ")"; - std::string processed_filter_val; - auto parse_op = filter::parse_geopoint_filter_value(filter_value, format_err_msg, processed_filter_val, - num_comparator); - if (!parse_op.ok()) { - return parse_op; - } - filter_exp.values.push_back(processed_filter_val); - filter_exp.comparators.push_back(num_comparator); - } - } else { - // single value, e.g. (10.45, 34.56, 2 km) - std::string processed_filter_val; - auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, processed_filter_val, - num_comparator); - if (!parse_op.ok()) { - return parse_op; - } - filter_exp.values.push_back(processed_filter_val); - filter_exp.comparators.push_back(num_comparator); - } - } else if (_field.is_string()) { - size_t filter_value_index = 0; - NUM_COMPARATOR str_comparator = CONTAINS; - if (raw_value[0] == '=') { - // string filter should be evaluated in strict "equals" mode - str_comparator = EQUALS; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - str_comparator = NOT_EQUALS; - filter_value_index++; - while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); - } - if (filter_value_index == raw_value.size()) { - return Option(400, "Error with filter field `" + _field.name + - "`: Filter value cannot be empty."); - } - if (raw_value[filter_value_index] == '[' && raw_value[raw_value.size() - 1] == ']') { - std::vector filter_values; - StringUtils::split_to_values( - raw_value.substr(filter_value_index + 1, raw_value.size() - filter_value_index - 2), filter_values); - filter_exp = {field_name, filter_values, {str_comparator}}; - } else { - filter_exp = {field_name, {raw_value.substr(filter_value_index)}, {str_comparator}}; - } - - filter_exp.apply_not_equals = (str_comparator == NOT_EQUALS); - } else { - return Option(400, "Error with filter field `" + _field.name + - "`: Unidentified field data type, see docs for supported data types."); - } - - return Option(true); -} - -// https://stackoverflow.com/a/423914/11218270 -Option toParseTree(std::queue& postfix, filter_node_t*& root, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix) { - std::stack nodeStack; - bool is_successful = true; - std::string error_message; - - filter_node_t *filter_node = nullptr; - - while (!postfix.empty()) { - const std::string expression = postfix.front(); - postfix.pop(); - - if (isOperator(expression)) { - if (nodeStack.empty()) { - is_successful = false; - error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; - break; - } - auto operandB = nodeStack.top(); - nodeStack.pop(); - - if (nodeStack.empty()) { - delete operandB; - is_successful = false; - error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; - break; - } - auto operandA = nodeStack.top(); - nodeStack.pop(); - - filter_node = new filter_node_t(expression == "&&" ? 
AND : OR, operandA, operandB); - } else { - filter filter_exp; - - // Expected value: $Collection(...) - bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')'); - if (is_referenced_filter) { - size_t parenthesis_index = expression.find('('); - - std::string collection_name = expression.substr(1, parenthesis_index - 1); - auto &cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(collection_name); - if (collection == nullptr) { - is_successful = false; - error_message = "Referenced collection `" + collection_name + "` not found."; - break; - } - - filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)}; - filter_exp.referenced_collection_name = collection_name; - } else { - Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); - if (!toFilter_op.ok()) { - is_successful = false; - error_message = toFilter_op.error(); - break; - } - } - - filter_node = new filter_node_t(filter_exp); - } - - nodeStack.push(filter_node); - } - - if (!is_successful) { - while (!nodeStack.empty()) { - auto filterNode = nodeStack.top(); - delete filterNode; - nodeStack.pop(); - } - - return Option(400, error_message); - } - - if (nodeStack.empty()) { - return Option(400, "Filter query cannot be empty."); - } - root = nodeStack.top(); - - return Option(true); -} - -Option filter::parse_filter_query(const std::string& filter_query, - const tsl::htrie_map& search_schema, - const Store* store, - const std::string& doc_id_prefix, - filter_node_t*& root) { - auto _filter_query = filter_query; - StringUtils::trim(_filter_query); - if (_filter_query.empty()) { - return Option(true); - } - - std::queue tokens; - Option tokenize_op = StringUtils::tokenize_filter_query(filter_query, tokens); - if (!tokenize_op.ok()) { - return tokenize_op; - } - - if (tokens.size() > 100) { - return Option(400, "Filter expression is not valid."); - } - - std::queue postfix; 
- Option toPostfix_op = toPostfix(tokens, postfix); - if (!toPostfix_op.ok()) { - return toPostfix_op; - } - - Option toParseTree_op = toParseTree(postfix, - root, - search_schema, - store, - doc_id_prefix); - if (!toParseTree_op.ok()) { - return toParseTree_op; - } - - return Option(true); -} - Option field::json_field_to_field(bool enable_nested_fields, nlohmann::json& field_json, std::vector& the_fields, string& fallback_field_type, size_t& num_auto_detect_fields) { @@ -1046,167 +541,3 @@ void field::compact_nested_fields(tsl::htrie_map& nested_fields) { nested_fields.erase_prefix(field_name + "."); } } - -void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { - auto lenA = a.count, lenB = b.count; - if (lenA == 0 || lenB == 0) { - return; - } - - result.docs = new uint32_t[std::min(lenA, lenB)]; - - auto A = a.docs, B = b.docs, out = result.docs; - const uint32_t *endA = A + lenA; - const uint32_t *endB = B + lenB; - - // Add an entry of references in the result for each unique collection in a and b. - for (auto const& item: a.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; - } - } - for (auto const& item: b.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; - } - } - - while (true) { - while (*A < *B) { - SKIP_FIRST_COMPARE: - if (++A == endA) { - result.count = out - result.docs; - return; - } - } - while (*A > *B) { - if (++B == endB) { - result.count = out - result.docs; - return; - } - } - if (*A == *B) { - *out = *A; - - // Copy the references of the document from every collection into result. 
- for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][out - result.docs] = item.second[A - a.docs]; - } - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][out - result.docs] = item.second[B - b.docs]; - } - - out++; - - if (++A == endA || ++B == endB) { - result.count = out - result.docs; - return; - } - } else { - goto SKIP_FIRST_COMPARE; - } - } -} - -void filter_result_t::or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { - if (a.count == 0 && b.count == 0) { - return; - } - - // If either one of a or b does not have any matches, copy other into result. - if (a.count == 0) { - result = b; - return; - } - if (b.count == 0) { - result = a; - return; - } - - size_t indexA = 0, indexB = 0, res_index = 0, lenA = a.count, lenB = b.count; - result.docs = new uint32_t[lenA + lenB]; - - // Add an entry of references in the result for each unique collection in a and b. - for (auto const& item: a.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; - } - } - for (auto const& item: b.reference_filter_results) { - if (result.reference_filter_results.count(item.first) == 0) { - result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; - } - } - - while (indexA < lenA && indexB < lenB) { - if (a.docs[indexA] < b.docs[indexB]) { - // check for duplicate - if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { - result.docs[res_index] = a.docs[indexA]; - res_index++; - } - - // Copy references of the last result document from every collection in a. 
- for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; - } - - indexA++; - } else { - if (res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { - result.docs[res_index] = b.docs[indexB]; - res_index++; - } - - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; - } - - indexB++; - } - } - - while (indexA < lenA) { - if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { - result.docs[res_index] = a.docs[indexA]; - res_index++; - } - - for (auto const& item: a.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; - } - - indexA++; - } - - while (indexB < lenB) { - if(res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { - result.docs[res_index] = b.docs[indexB]; - res_index++; - } - - for (auto const& item: b.reference_filter_results) { - result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; - } - - indexB++; - } - - result.count = res_index; - - // shrink fit - auto out = new uint32_t[res_index]; - memcpy(out, result.docs, res_index * sizeof(uint32_t)); - delete[] result.docs; - result.docs = out; - - for (auto &item: result.reference_filter_results) { - auto out_references = new reference_filter_result_t[res_index]; - - for (uint32_t i = 0; i < result.count; i++) { - out_references[i] = item.second[i]; - } - delete[] item.second; - item.second = out_references; - } -} diff --git a/src/filter.cpp b/src/filter.cpp index 77ef5d5d..95fbfefc 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -1,641 +1,573 @@ #include #include #include +#include #include "filter.h" -void filter_result_iterator_t::and_filter_iterators() { - while (left_it->is_valid && right_it->is_valid) { - while (left_it->seq_id < right_it->seq_id) { - left_it->next(); - if (!left_it->is_valid) { - is_valid = 
false; - return; - } - } - - while (left_it->seq_id > right_it->seq_id) { - right_it->next(); - if (!right_it->is_valid) { - is_valid = false; - return; - } - } - - if (left_it->seq_id == right_it->seq_id) { - seq_id = left_it->seq_id; - reference.clear(); - - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; - } +Option filter::validate_numerical_filter_value(field _field, const string &raw_value) { + if(_field.is_int32() && !StringUtils::is_int32_t(raw_value)) { + return Option(400, "Error with filter field `" + _field.name + "`: Not an int32."); } - is_valid = false; + else if(_field.is_int64() && !StringUtils::is_int64_t(raw_value)) { + return Option(400, "Error with filter field `" + _field.name + "`: Not an int64."); + } + + else if(_field.is_float() && !StringUtils::is_float(raw_value)) { + return Option(400, "Error with filter field `" + _field.name + "`: Not a float."); + } + + return Option(true); } -void filter_result_iterator_t::or_filter_iterators() { - if (left_it->is_valid && right_it->is_valid) { - if (left_it->seq_id < right_it->seq_id) { - seq_id = left_it->seq_id; - reference.clear(); +Option filter::extract_num_comparator(string &comp_and_value) { + auto num_comparator = EQUALS; - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - - return; - } - - if (left_it->seq_id > right_it->seq_id) { - seq_id = right_it->seq_id; - reference.clear(); - - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; - } - - seq_id = left_it->seq_id; - reference.clear(); - - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; + if(StringUtils::is_integer(comp_and_value) || 
StringUtils::is_float(comp_and_value)) { + num_comparator = EQUALS; } - if (left_it->is_valid) { - seq_id = left_it->seq_id; - reference.clear(); - - for (const auto& item: left_it->reference) { - reference[item.first] = item.second; - } - - return; + // the ordering is important - we have to compare 2-letter operators first + else if(comp_and_value.compare(0, 2, "<=") == 0) { + num_comparator = LESS_THAN_EQUALS; } - if (right_it->is_valid) { - seq_id = right_it->seq_id; - reference.clear(); - - for (const auto& item: right_it->reference) { - reference[item.first] = item.second; - } - - return; + else if(comp_and_value.compare(0, 2, ">=") == 0) { + num_comparator = GREATER_THAN_EQUALS; } - is_valid = false; + else if(comp_and_value.compare(0, 2, "!=") == 0) { + num_comparator = NOT_EQUALS; + } + + else if(comp_and_value.compare(0, 1, "<") == 0) { + num_comparator = LESS_THAN; + } + + else if(comp_and_value.compare(0, 1, ">") == 0) { + num_comparator = GREATER_THAN; + } + + else if(comp_and_value.find("..") != std::string::npos) { + num_comparator = RANGE_INCLUSIVE; + } + + else { + return Option(400, "Numerical field has an invalid comparator."); + } + + if(num_comparator == LESS_THAN || num_comparator == GREATER_THAN) { + comp_and_value = comp_and_value.substr(1); + } else if(num_comparator == LESS_THAN_EQUALS || num_comparator == GREATER_THAN_EQUALS || num_comparator == NOT_EQUALS) { + comp_and_value = comp_and_value.substr(2); + } + + comp_and_value = StringUtils::trim(comp_and_value); + + return Option(num_comparator); } -void filter_result_iterator_t::doc_matching_string_filter() { - // If none of the filter value iterators are valid, mark this node as invalid. - bool one_is_valid = false; +Option filter::parse_geopoint_filter_value(std::string& raw_value, + const std::string& format_err_msg, + std::string& processed_filter_val, + NUM_COMPARATOR& num_comparator) { - // Since we do OR between filter values, the lowest seq_id id from all is selected. 
- uint32_t lowest_id = UINT32_MAX; + num_comparator = LESS_THAN_EQUALS; - for (auto& filter_value_tokens : posting_list_iterators) { - // Perform AND between tokens of a filter value. - bool tokens_iter_is_valid; - posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + if(!(raw_value[0] == '(' && raw_value[raw_value.size() - 1] == ')')) { + return Option(400, format_err_msg); + } - one_is_valid = tokens_iter_is_valid || one_is_valid; + std::vector filter_values; + auto raw_val_without_paran = raw_value.substr(1, raw_value.size() - 2); + StringUtils::split(raw_val_without_paran, filter_values, ","); - if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { - lowest_id = filter_value_tokens[0].id(); + // we will end up with: "10.45 34.56 2 km" or "10.45 34.56 2mi" or a geo polygon + + if(filter_values.size() < 3) { + return Option(400, format_err_msg); + } + + // do validation: format should match either a point + radius or polygon + + size_t num_floats = 0; + for(const auto& fvalue: filter_values) { + if(StringUtils::is_float(fvalue)) { + num_floats++; } } - if (one_is_valid) { - seq_id = lowest_id; + bool is_polygon = (num_floats == filter_values.size()); + if(!is_polygon) { + // we have to ensure that this is a point + radius match + if(!StringUtils::is_float(filter_values[0]) || !StringUtils::is_float(filter_values[1])) { + return Option(400, format_err_msg); + } + + if(filter_values[0] == "nan" || filter_values[0] == "NaN" || + filter_values[1] == "nan" || filter_values[1] == "NaN") { + return Option(400, format_err_msg); + } } - is_valid = one_is_valid; + if(is_polygon) { + processed_filter_val = raw_val_without_paran; + } else { + // point + radius + // filter_values[2] is distance, get the unit, validate it and split on that + if(filter_values[2].size() < 2) { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::string unit = filter_values[2].substr(filter_values[2].size()-2, 2); + + if(unit != "km" && unit 
!= "mi") { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::vector dist_values; + StringUtils::split(filter_values[2], dist_values, unit); + + if(dist_values.size() != 1) { + return Option(400, format_err_msg); + } + + if(!StringUtils::is_float(dist_values[0])) { + return Option(400, format_err_msg); + } + + processed_filter_val = filter_values[0] + ", " + filter_values[1] + ", " + // co-ords + dist_values[0] + ", " + unit; // X km + } + + return Option(true); } -void filter_result_iterator_t::next() { - if (!is_valid) { - return; - } +bool isOperator(const std::string& expression) { + return expression == "&&" || expression == "||"; +} - if (filter_node->isOperator) { - // Advance the subtrees and then apply operators to arrive at the next valid doc. - if (filter_node->filter_operator == AND) { - left_it->next(); - right_it->next(); - and_filter_iterators(); +// https://en.wikipedia.org/wiki/Shunting_yard_algorithm +Option toPostfix(std::queue& tokens, std::queue& postfix) { + std::stack operatorStack; + + while (!tokens.empty()) { + auto expression = tokens.front(); + tokens.pop(); + + if (isOperator(expression)) { + // We only have two operators &&, || having the same precedence and both being left associative. 
+ while (!operatorStack.empty() && operatorStack.top() != "(") { + postfix.push(operatorStack.top()); + operatorStack.pop(); + } + + operatorStack.push(expression); + } else if (expression == "(") { + operatorStack.push(expression); + } else if (expression == ")") { + while (!operatorStack.empty() && operatorStack.top() != "(") { + postfix.push(operatorStack.top()); + operatorStack.pop(); + } + + if (operatorStack.empty() || operatorStack.top() != "(") { + return Option(400, "Could not parse the filter query: unbalanced parentheses."); + } + operatorStack.pop(); } else { - if (left_it->seq_id == seq_id && right_it->seq_id == seq_id) { - left_it->next(); - right_it->next(); - } else if (left_it->seq_id == seq_id) { - left_it->next(); - } else { - right_it->next(); - } - - or_filter_iterators(); + postfix.push(expression); } - - return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - if (++result_index >= filter_result.count) { - is_valid = false; - return; + while (!operatorStack.empty()) { + if (operatorStack.top() == "(") { + return Option(400, "Could not parse the filter query: unbalanced parentheses."); } - - seq_id = filter_result.docs[result_index]; - reference.clear(); - for (auto const& item: filter_result.reference_filter_results) { - reference[item.first] = item.second[result_index]; - } - - return; + postfix.push(operatorStack.top()); + operatorStack.pop(); } - if (a_filter.field_name == "id") { - if (++result_index >= filter_result.count) { - is_valid = false; - return; + return Option(true); +} + +Option toMultiValueNumericFilter(std::string& raw_value, filter& filter_exp, const field& _field) { + std::vector filter_values; + StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); + filter_exp = {_field.name, {}, {}}; + for (std::string& filter_value: filter_values) { + Option op_comparator = 
filter::extract_num_comparator(filter_value); + if (!op_comparator.ok()) { + return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); } - - seq_id = filter_result.docs[result_index]; - return; - } - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - // Advance all the filter values that are at doc. Then find the next doc. - for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { - auto& filter_value_tokens = posting_list_iterators[i]; - - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { - for (auto& iter: filter_value_tokens) { - iter.next(); + if (op_comparator.get() == RANGE_INCLUSIVE) { + // split the value around range operator to extract bounds + std::vector range_values; + StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR()); + for (const std::string& range_value: range_values) { + auto validate_op = filter::validate_numerical_filter_value(_field, range_value); + if (!validate_op.ok()) { + return validate_op; } + filter_exp.values.push_back(range_value); + filter_exp.comparators.push_back(op_comparator.get()); } + } else { + auto validate_op = filter::validate_numerical_filter_value(_field, filter_value); + if (!validate_op.ok()) { + return validate_op; + } + filter_exp.values.push_back(filter_value); + filter_exp.comparators.push_back(op_comparator.get()); } - - doc_matching_string_filter(); - return; } + + return Option(true); } -void filter_result_iterator_t::init() { - if (filter_node == nullptr) { - return; +Option toFilter(const std::string expression, + filter& filter_exp, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix) { + // split into [field_name, value] + size_t found_index = expression.find(':'); + if (found_index == std::string::npos) { + return Option(400, "Could not parse the 
filter query."); } - - if (filter_node->isOperator) { - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); + std::string&& field_name = expression.substr(0, found_index); + StringUtils::trim(field_name); + if (field_name == "id") { + std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); + StringUtils::trim(raw_value); + std::string empty_filter_err = "Error with filter field `id`: Filter value cannot be empty."; + if (raw_value.empty()) { + return Option(400, empty_filter_err); } - - return; - } - - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. - auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(a_filter.referenced_collection_name); - if (collection == nullptr) { - status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); - is_valid = false; - return; + filter_exp = {field_name, {}, {}}; + NUM_COMPARATOR id_comparator = EQUALS; + size_t filter_value_index = 0; + if (raw_value[0] == '=') { + id_comparator = EQUALS; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { + return Option(400, "Not equals filtering is not supported on the `id` field."); } - - auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, - filter_result, - collection_name); - if (!reference_filter_op.ok()) { - status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name - + "` collection: " + reference_filter_op.error()); - is_valid = false; - return; + if (filter_value_index != 0) { + raw_value = raw_value.substr(filter_value_index); } - - is_valid = filter_result.count > 0; - return; - } - - if (a_filter.field_name == "id") { - if (a_filter.values.empty()) { - is_valid = false; - return; + if (raw_value.empty()) { + return Option(400, empty_filter_err); } - - // we handle `ids` separately - std::vector result_ids; - for (const auto& id_str : a_filter.values) { - result_ids.push_back(std::stoul(id_str)); - } - - std::sort(result_ids.begin(), result_ids.end()); - - filter_result.count = result_ids.size(); - filter_result.docs = new uint32_t[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); - } - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - art_tree* t = index->search_index.at(a_filter.field_name); - - for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; - - // there could be multiple tokens in a filter value, which we have to treat as ANDs - // e.g. 
country: South Africa - Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); - - std::string str_token; - size_t token_index = 0; - std::vector str_tokens; - - while (tokenizer.next(str_token, token_index)) { - str_tokens.push_back(str_token); - - art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), - str_token.length()+1); - if (leaf == nullptr) { + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector doc_ids; + StringUtils::split_to_values(raw_value.substr(1, raw_value.size() - 2), doc_ids); + for (std::string& doc_id: doc_ids) { + // we have to convert the doc_id to seq id + std::string seq_id_str; + StoreStatus seq_id_status = store->get(doc_id_prefix + doc_id, seq_id_str); + if (seq_id_status != StoreStatus::FOUND) { continue; } - - posting_lists.push_back(leaf->values); + filter_exp.values.push_back(seq_id_str); + filter_exp.comparators.push_back(id_comparator); } - - if (posting_lists.size() != str_tokens.size()) { - continue; - } - - std::vector plists; - posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); - - posting_list_iterators.emplace_back(std::vector()); - - for (auto const& plist: plists) { - posting_list_iterators.back().push_back(plist->new_iterator()); + } else { + std::vector doc_ids; + StringUtils::split_to_values(raw_value, doc_ids); // to handle backticks + std::string seq_id_str; + StoreStatus seq_id_status = store->get(doc_id_prefix + doc_ids[0], seq_id_str); + if (seq_id_status == StoreStatus::FOUND) { + filter_exp.values.push_back(seq_id_str); + filter_exp.comparators.push_back(id_comparator); } } - - doc_matching_string_filter(); - return; + return Option(true); } + + auto field_it = search_schema.find(field_name); + + if (field_it == search_schema.end()) { + return Option(404, "Could not find a filter field named `" + field_name + "` in the schema."); + } + + if (field_it->num_dim > 0) { + return 
Option(404, "Cannot filter on vector field `" + field_name + "`."); + } + + const field& _field = field_it.value(); + std::string&& raw_value = expression.substr(found_index + 1, std::string::npos); + StringUtils::trim(raw_value); + // skip past optional `:=` operator, which has no meaning for non-string fields + if (!_field.is_string() && raw_value[0] == '=') { + size_t filter_value_index = 0; + while (raw_value[++filter_value_index] == ' '); + raw_value = raw_value.substr(filter_value_index); + } + if (_field.is_integer() || _field.is_float()) { + // could be a single value or a list + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); + if (!op.ok()) { + return op; + } + } else { + Option op_comparator = filter::extract_num_comparator(raw_value); + if (!op_comparator.ok()) { + return Option(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error()); + } + if (op_comparator.get() == RANGE_INCLUSIVE) { + // split the value around range operator to extract bounds + std::vector range_values; + StringUtils::split(raw_value, range_values, filter::RANGE_OPERATOR()); + filter_exp.field_name = field_name; + for (const std::string& range_value: range_values) { + auto validate_op = filter::validate_numerical_filter_value(_field, range_value); + if (!validate_op.ok()) { + return validate_op; + } + filter_exp.values.push_back(range_value); + filter_exp.comparators.push_back(op_comparator.get()); + } + } else if (op_comparator.get() == NOT_EQUALS && raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + Option op = toMultiValueNumericFilter(raw_value, filter_exp, _field); + if (!op.ok()) { + return op; + } + filter_exp.apply_not_equals = true; + } else { + auto validate_op = filter::validate_numerical_filter_value(_field, raw_value); + if (!validate_op.ok()) { + return validate_op; + } + filter_exp = {field_name, {raw_value}, {op_comparator.get()}}; + } + 
} + } else if (_field.is_bool()) { + NUM_COMPARATOR bool_comparator = EQUALS; + size_t filter_value_index = 0; + if (raw_value[0] == '=') { + bool_comparator = EQUALS; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if (raw_value.size() >= 2 && raw_value[0] == '!' && raw_value[1] == '=') { + bool_comparator = NOT_EQUALS; + filter_value_index++; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } + if (filter_value_index != 0) { + raw_value = raw_value.substr(filter_value_index); + } + if (filter_value_index == raw_value.size()) { + return Option(400, "Error with filter field `" + _field.name + + "`: Filter value cannot be empty."); + } + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector filter_values; + StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, ","); + filter_exp = {field_name, {}, {}}; + for (std::string& filter_value: filter_values) { + if (filter_value != "true" && filter_value != "false") { + return Option(400, "Values of filter field `" + _field.name + + "`: must be `true` or `false`."); + } + filter_value = (filter_value == "true") ? "1" : "0"; + filter_exp.values.push_back(filter_value); + filter_exp.comparators.push_back(bool_comparator); + } + } else { + if (raw_value != "true" && raw_value != "false") { + return Option(400, "Value of filter field `" + _field.name + "` must be `true` or `false`."); + } + std::string bool_value = (raw_value == "true") ? 
"1" : "0"; + filter_exp = {field_name, {bool_value}, {bool_comparator}}; + } + } else if (_field.is_geopoint()) { + filter_exp = {field_name, {}, {}}; + const std::string& format_err_msg = "Value of filter field `" + _field.name + + "`: must be in the `(-44.50, 170.29, 0.75 km)` or " + "(56.33, -65.97, 23.82, -127.82) format."; + NUM_COMPARATOR num_comparator; + // could be a single value or a list + if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector filter_values; + StringUtils::split(raw_value.substr(1, raw_value.size() - 2), filter_values, "),"); + for (std::string& filter_value: filter_values) { + filter_value += ")"; + std::string processed_filter_val; + auto parse_op = filter::parse_geopoint_filter_value(filter_value, format_err_msg, processed_filter_val, + num_comparator); + if (!parse_op.ok()) { + return parse_op; + } + filter_exp.values.push_back(processed_filter_val); + filter_exp.comparators.push_back(num_comparator); + } + } else { + // single value, e.g. (10.45, 34.56, 2 km) + std::string processed_filter_val; + auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, processed_filter_val, + num_comparator); + if (!parse_op.ok()) { + return parse_op; + } + filter_exp.values.push_back(processed_filter_val); + filter_exp.comparators.push_back(num_comparator); + } + } else if (_field.is_string()) { + size_t filter_value_index = 0; + NUM_COMPARATOR str_comparator = CONTAINS; + if (raw_value[0] == '=') { + // string filter should be evaluated in strict "equals" mode + str_comparator = EQUALS; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { + str_comparator = NOT_EQUALS; + filter_value_index++; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); + } + if (filter_value_index == raw_value.size()) { + return Option(400, "Error with filter field `" + _field.name + + "`: Filter value cannot be empty."); + } + if (raw_value[filter_value_index] == '[' && raw_value[raw_value.size() - 1] == ']') { + std::vector filter_values; + StringUtils::split_to_values( + raw_value.substr(filter_value_index + 1, raw_value.size() - filter_value_index - 2), filter_values); + filter_exp = {field_name, filter_values, {str_comparator}}; + } else { + filter_exp = {field_name, {raw_value.substr(filter_value_index)}, {str_comparator}}; + } + + filter_exp.apply_not_equals = (str_comparator == NOT_EQUALS); + } else { + return Option(400, "Error with filter field `" + _field.name + + "`: Unidentified field data type, see docs for supported data types."); + } + + return Option(true); } -bool filter_result_iterator_t::valid() { - if (!is_valid) { - return false; - } +// https://stackoverflow.com/a/423914/11218270 +Option toParseTree(std::queue& postfix, filter_node_t*& root, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix) { + std::stack nodeStack; + bool is_successful = true; + std::string error_message; - if (filter_node->isOperator) { - if (filter_node->filter_operator == AND) { - is_valid = left_it->valid() && right_it->valid(); - return is_valid; - } else { - is_valid = left_it->valid() || right_it->valid(); - return is_valid; - } - } + filter_node_t *filter_node = nullptr; - const filter a_filter = filter_node->filter_exp; + while (!postfix.empty()) { + const std::string expression = postfix.front(); + postfix.pop(); - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { - is_valid = result_index < filter_result.count; - return is_valid; - } - - if 
(!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return is_valid; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - bool one_is_valid = false; - for (auto& filter_value_tokens: posting_list_iterators) { - posting_list_t::intersect(filter_value_tokens, one_is_valid); - - if (one_is_valid) { + if (isOperator(expression)) { + if (nodeStack.empty()) { + is_successful = false; + error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; break; } - } + auto operandB = nodeStack.top(); + nodeStack.pop(); - is_valid = one_is_valid; - return is_valid; - } + if (nodeStack.empty()) { + delete operandB; + is_successful = false; + error_message = "Could not parse the filter query: unbalanced `" + expression + "` operands."; + break; + } + auto operandA = nodeStack.top(); + nodeStack.pop(); - return false; -} - -void filter_result_iterator_t::skip_to(uint32_t id) { - if (!is_valid) { - return; - } - - if (filter_node->isOperator) { - // Skip the subtrees to id and then apply operators to arrive at the next valid doc. - left_it->skip_to(id); - right_it->skip_to(id); - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); + filter_node = new filter_node_t(expression == "&&" ? AND : OR, operandA, operandB); } else { - or_filter_iterators(); - } + filter filter_exp; - return; - } + // Expected value: $Collection(...) 
+ bool is_referenced_filter = (expression[0] == '$' && expression[expression.size() - 1] == ')'); + if (is_referenced_filter) { + size_t parenthesis_index = expression.find('('); - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); - - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - reference.clear(); - for (auto const& item: filter_result.reference_filter_results) { - reference[item.first] = item.second[result_index]; - } - - return; - } - - if (a_filter.field_name == "id") { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); - - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - // Skip all the token iterators and find a new match. - for (auto& filter_value_tokens : posting_list_iterators) { - for (auto& token: filter_value_tokens) { - // We perform AND on tokens. Short-circuiting here. 
- if (!token.valid()) { + std::string collection_name = expression.substr(1, parenthesis_index - 1); + auto &cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(collection_name); + if (collection == nullptr) { + is_successful = false; + error_message = "Referenced collection `" + collection_name + "` not found."; break; } - token.skip_to(id); - } - } - - doc_matching_string_filter(); - return; - } -} - -int filter_result_iterator_t::valid(uint32_t id) { - if (!is_valid) { - return -1; - } - - if (filter_node->isOperator) { - auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); - - if (filter_node->filter_operator == AND) { - is_valid = left_it->is_valid && right_it->is_valid; - - if (left_valid < 1 || right_valid < 1) { - if (left_valid == -1 || right_valid == -1) { - return -1; - } - - return 0; - } - - return 1; - } else { - is_valid = left_it->is_valid || right_it->is_valid; - - if (left_valid < 1 && right_valid < 1) { - if (left_valid == -1 && right_valid == -1) { - return -1; - } - - return 0; - } - - return 1; - } - } - - if (filter_node->filter_exp.apply_not_equals) { - // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. - if (!valid()) { - is_valid = true; - return 1; - } - - skip_to(id); - - if (!is_valid) { - is_valid = true; - return 1; - } - - return seq_id != id ? 1 : 0; - } - - skip_to(id); - return is_valid ? (seq_id == id ? 1 : 0) : -1; -} - -Option filter_result_iterator_t::init_status() { - if (filter_node != nullptr && filter_node->isOperator) { - auto left_status = left_it->init_status(); - - return !left_status.ok() ? 
left_status : right_it->init_status(); - } - - return status; -} - -bool filter_result_iterator_t::contains_atleast_one(const void *obj) { - if(IS_COMPACT_POSTING(obj)) { - compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); - - size_t i = 0; - while(i < list->length && valid()) { - size_t num_existing_offsets = list->id_offsets[i]; - size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; - - if (existing_id == seq_id) { - return true; - } - - // advance smallest value - if (existing_id < seq_id) { - i += num_existing_offsets + 2; + filter_exp = {expression.substr(parenthesis_index + 1, expression.size() - parenthesis_index - 2)}; + filter_exp.referenced_collection_name = collection_name; } else { - skip_to(existing_id); - } - } - } else { - auto list = (posting_list_t*)(obj); - posting_list_t::iterator_t it = list->new_iterator(); - - while(it.valid() && valid()) { - uint32_t id = it.id(); - - if(id == seq_id) { - return true; + Option toFilter_op = toFilter(expression, filter_exp, search_schema, store, doc_id_prefix); + if (!toFilter_op.ok()) { + is_successful = false; + error_message = toFilter_op.error(); + break; + } } - if(id < seq_id) { - it.skip_to(seq_id); - } else { - skip_to(id); - } + filter_node = new filter_node_t(filter_exp); } + + nodeStack.push(filter_node); } - return false; + if (!is_successful) { + while (!nodeStack.empty()) { + auto filterNode = nodeStack.top(); + delete filterNode; + nodeStack.pop(); + } + + return Option(400, error_message); + } + + if (nodeStack.empty()) { + return Option(400, "Filter query cannot be empty."); + } + root = nodeStack.top(); + + return Option(true); } -void filter_result_iterator_t::reset() { - if (filter_node == nullptr) { - return; +Option filter::parse_filter_query(const std::string& filter_query, + const tsl::htrie_map& search_schema, + const Store* store, + const std::string& doc_id_prefix, + filter_node_t*& root) { + auto _filter_query = filter_query; + 
StringUtils::trim(_filter_query); + if (_filter_query.empty()) { + return Option(true); } - if (filter_node->isOperator) { - // Reset the subtrees then apply operators to arrive at the first valid doc. - left_it->reset(); - right_it->reset(); - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); - } - - return; + std::queue tokens; + Option tokenize_op = StringUtils::tokenize_filter_query(filter_query, tokens); + if (!tokenize_op.ok()) { + return tokenize_op; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter || a_filter.field_name == "id") { - result_index = 0; - is_valid = filter_result.count > 0; - return; + if (tokens.size() > 100) { + return Option(400, "Filter expression is not valid."); } - if (!index->field_is_indexed(a_filter.field_name)) { - return; + std::queue postfix; + Option toPostfix_op = toPostfix(tokens, postfix); + if (!toPostfix_op.ok()) { + return toPostfix_op; } - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - posting_list_iterators.clear(); - for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - expanded_plists.clear(); - - init(); - return; + Option toParseTree_op = toParseTree(postfix, + root, + search_schema, + store, + doc_id_prefix); + if (!toParseTree_op.ok()) { + return toParseTree_op; } -} - -uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { - if (!valid()) { - return 0; - } - - std::vector filter_ids; - do { - filter_ids.push_back(seq_id); - next(); - } while (valid()); - - filter_array = new uint32_t[filter_ids.size()]; - std::copy(filter_ids.begin(), filter_ids.end(), filter_array); - - return filter_ids.size(); -} - -uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { - if (!valid()) { - return 0; - } - - std::vector filter_ids; 
- for (uint32_t i = 0; i < lenA; i++) { - auto result = valid(A[i]); - - if (result == -1) { - break; - } - - if (result == 1) { - filter_ids.push_back(A[i]); - } - } - - if (filter_ids.empty()) { - return 0; - } - - results = new uint32_t[filter_ids.size()]; - std::copy(filter_ids.begin(), filter_ids.end(), results); - - return filter_ids.size(); + + return Option(true); } diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp new file mode 100644 index 00000000..21a4128f --- /dev/null +++ b/src/filter_result_iterator.cpp @@ -0,0 +1,901 @@ +#include "filter_result_iterator.h" +#include "index.h" +#include "posting.h" +#include "collection_manager.h" + +void filter_result_t::and_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { + auto lenA = a.count, lenB = b.count; + if (lenA == 0 || lenB == 0) { + return; + } + + result.docs = new uint32_t[std::min(lenA, lenB)]; + + auto A = a.docs, B = b.docs, out = result.docs; + const uint32_t *endA = A + lenA; + const uint32_t *endB = B + lenB; + + // Add an entry of references in the result for each unique collection in a and b. + for (auto const& item: a.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; + } + } + for (auto const& item: b.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[std::min(lenA, lenB)]; + } + } + + while (true) { + while (*A < *B) { + SKIP_FIRST_COMPARE: + if (++A == endA) { + result.count = out - result.docs; + return; + } + } + while (*A > *B) { + if (++B == endB) { + result.count = out - result.docs; + return; + } + } + if (*A == *B) { + *out = *A; + + // Copy the references of the document from every collection into result. 
+ for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][out - result.docs] = item.second[A - a.docs]; + } + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][out - result.docs] = item.second[B - b.docs]; + } + + out++; + + if (++A == endA || ++B == endB) { + result.count = out - result.docs; + return; + } + } else { + goto SKIP_FIRST_COMPARE; + } + } +} + +void filter_result_t::or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result) { + if (a.count == 0 && b.count == 0) { + return; + } + + // If either one of a or b does not have any matches, copy other into result. + if (a.count == 0) { + result = b; + return; + } + if (b.count == 0) { + result = a; + return; + } + + size_t indexA = 0, indexB = 0, res_index = 0, lenA = a.count, lenB = b.count; + result.docs = new uint32_t[lenA + lenB]; + + // Add an entry of references in the result for each unique collection in a and b. + for (auto const& item: a.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; + } + } + for (auto const& item: b.reference_filter_results) { + if (result.reference_filter_results.count(item.first) == 0) { + result.reference_filter_results[item.first] = new reference_filter_result_t[lenA + lenB]; + } + } + + while (indexA < lenA && indexB < lenB) { + if (a.docs[indexA] < b.docs[indexB]) { + // check for duplicate + if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { + result.docs[res_index] = a.docs[indexA]; + res_index++; + } + + // Copy references of the last result document from every collection in a. 
+ for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; + } + + indexA++; + } else { + if (res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { + result.docs[res_index] = b.docs[indexB]; + res_index++; + } + + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; + } + + indexB++; + } + } + + while (indexA < lenA) { + if (res_index == 0 || result.docs[res_index - 1] != a.docs[indexA]) { + result.docs[res_index] = a.docs[indexA]; + res_index++; + } + + for (auto const& item: a.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexA]; + } + + indexA++; + } + + while (indexB < lenB) { + if(res_index == 0 || result.docs[res_index - 1] != b.docs[indexB]) { + result.docs[res_index] = b.docs[indexB]; + res_index++; + } + + for (auto const& item: b.reference_filter_results) { + result.reference_filter_results[item.first][res_index - 1] = item.second[indexB]; + } + + indexB++; + } + + result.count = res_index; + + // shrink fit + auto out = new uint32_t[res_index]; + memcpy(out, result.docs, res_index * sizeof(uint32_t)); + delete[] result.docs; + result.docs = out; + + for (auto &item: result.reference_filter_results) { + auto out_references = new reference_filter_result_t[res_index]; + + for (uint32_t i = 0; i < result.count; i++) { + out_references[i] = item.second[i]; + } + delete[] item.second; + item.second = out_references; + } +} + +void filter_result_iterator_t::and_filter_iterators() { + while (left_it->is_valid && right_it->is_valid) { + while (left_it->seq_id < right_it->seq_id) { + left_it->next(); + if (!left_it->is_valid) { + is_valid = false; + return; + } + } + + while (left_it->seq_id > right_it->seq_id) { + right_it->next(); + if (!right_it->is_valid) { + is_valid = false; + return; + } + } + + if (left_it->seq_id == 
right_it->seq_id) { + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + } + + is_valid = false; +} + +void filter_result_iterator_t::or_filter_iterators() { + if (left_it->is_valid && right_it->is_valid) { + if (left_it->seq_id < right_it->seq_id) { + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->seq_id > right_it->seq_id) { + seq_id = right_it->seq_id; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->is_valid) { + seq_id = left_it->seq_id; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (right_it->is_valid) { + seq_id = right_it->seq_id; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + is_valid = false; +} + +void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { + // If none of the filter value iterators are valid, mark this node as invalid. + bool one_is_valid = false; + + // Since we do OR between filter values, the lowest seq_id id from all is selected. 
+ uint32_t lowest_id = UINT32_MAX; + + if (filter_node->filter_exp.comparators[0] == EQUALS || filter_node->filter_exp.comparators[0] == NOT_EQUALS) { + for (auto& filter_value_tokens : posting_list_iterators) { + bool tokens_iter_is_valid, exact_match = false; + while(true) { + // Perform AND between tokens of a filter value. + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + if (!tokens_iter_is_valid) { + break; + } + + if (posting_list_t::has_exact_match(filter_value_tokens, field_is_array)) { + exact_match = true; + break; + } else { + // Keep advancing token iterators till exact match is not found. + for (auto &item: filter_value_tokens) { + item.next(); + } + } + } + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && exact_match && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + } else { + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + } + + if (one_is_valid) { + seq_id = lowest_id; + } + + is_valid = one_is_valid; +} + +void filter_result_iterator_t::next() { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Advance the subtrees and then apply operators to arrive at the next valid doc. 
+ if (filter_node->filter_operator == AND) { + left_it->next(); + right_it->next(); + and_filter_iterators(); + } else { + if (left_it->seq_id == seq_id && right_it->seq_id == seq_id) { + left_it->next(); + right_it->next(); + } else if (left_it->seq_id == seq_id) { + left_it->next(); + } else { + right_it->next(); + } + + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Advance all the filter values that are at doc. Then find the next doc. 
+ for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { + for (auto& iter: filter_value_tokens) { + iter.next(); + } + } + } + + doc_matching_string_filter(f.is_array()); + return; + } +} + +void filter_result_iterator_t::init() { + if (filter_node == nullptr) { + return; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(a_filter.referenced_collection_name); + if (collection == nullptr) { + status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + is_valid = false; + return; + } + + auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, + filter_result, + collection_name); + if (!reference_filter_op.ok()) { + status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name + + "` collection: " + reference_filter_op.error()); + is_valid = false; + return; + } + + is_valid = filter_result.count > 0; + return; + } + + if (a_filter.field_name == "id") { + if (a_filter.values.empty()) { + is_valid = false; + return; + } + + // we handle `ids` separately + std::vector result_ids; + for (const auto& id_str : a_filter.values) { + result_ids.push_back(std::stoul(id_str)); + } + + std::sort(result_ids.begin(), result_ids.end()); + + filter_result.count = result_ids.size(); + filter_result.docs = new uint32_t[result_ids.size()]; + 
std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + art_tree* t = index->search_index.at(a_filter.field_name); + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + std::vector plists; + posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + + posting_list_iterators.emplace_back(std::vector()); + + for (auto const& plist: plists) { + posting_list_iterators.back().push_back(plist->new_iterator()); + } + } + + doc_matching_string_filter(f.is_array()); + return; + } +} + +bool filter_result_iterator_t::valid() { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + is_valid = left_it->valid() && right_it->valid(); + return is_valid; + } else { + is_valid = left_it->valid() || right_it->valid(); + return is_valid; + } + } + + const filter a_filter = filter_node->filter_exp; + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + is_valid = result_index < filter_result.count; + return is_valid; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + 
is_valid = false; + return is_valid; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + bool one_is_valid = false; + for (auto& filter_value_tokens: posting_list_iterators) { + posting_list_t::intersect(filter_value_tokens, one_is_valid); + + if (one_is_valid) { + break; + } + } + + is_valid = one_is_valid; + return is_valid; + } + + return false; +} + +void filter_result_iterator_t::skip_to(uint32_t id) { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + left_it->skip_to(id); + right_it->skip_to(id); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Skip all the token iterators and find a new match. + for (auto& filter_value_tokens : posting_list_iterators) { + for (auto& token: filter_value_tokens) { + // We perform AND on tokens. Short-circuiting here. 
+ if (!token.valid()) { + break; + } + + token.skip_to(id); + } + } + + doc_matching_string_filter(f.is_array()); + return; + } +} + +int filter_result_iterator_t::valid(uint32_t id) { + if (!is_valid) { + return -1; + } + + if (filter_node->isOperator) { + auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); + + if (filter_node->filter_operator == AND) { + is_valid = left_it->is_valid && right_it->is_valid; + + if (left_valid < 1 || right_valid < 1) { + if (left_valid == -1 || right_valid == -1) { + return -1; + } + + return 0; + } + + return 1; + } else { + is_valid = left_it->is_valid || right_it->is_valid; + + if (left_valid < 1 && right_valid < 1) { + if (left_valid == -1 && right_valid == -1) { + return -1; + } + + return 0; + } + + return 1; + } + } + + if (filter_node->filter_exp.apply_not_equals) { + // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. + if (!valid()) { + is_valid = true; + return 1; + } + + skip_to(id); + + if (!is_valid) { + is_valid = true; + return 1; + } + + return seq_id != id ? 1 : 0; + } + + skip_to(id); + return is_valid ? (seq_id == id ? 1 : 0) : -1; +} + +Option filter_result_iterator_t::init_status() { + if (filter_node != nullptr && filter_node->isOperator) { + auto left_status = left_it->init_status(); + + return !left_status.ok() ? 
left_status : right_it->init_status(); + } + + return status; +} + +bool filter_result_iterator_t::contains_atleast_one(const void *obj) { + if(IS_COMPACT_POSTING(obj)) { + compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); + + size_t i = 0; + while(i < list->length && valid()) { + size_t num_existing_offsets = list->id_offsets[i]; + size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; + + if (existing_id == seq_id) { + return true; + } + + // advance smallest value + if (existing_id < seq_id) { + i += num_existing_offsets + 2; + } else { + skip_to(existing_id); + } + } + } else { + auto list = (posting_list_t*)(obj); + posting_list_t::iterator_t it = list->new_iterator(); + + while(it.valid() && valid()) { + uint32_t id = it.id(); + + if(id == seq_id) { + return true; + } + + if(id < seq_id) { + it.skip_to(seq_id); + } else { + skip_to(id); + } + } + } + + return false; +} + +void filter_result_iterator_t::reset() { + if (filter_node == nullptr) { + return; + } + + if (filter_node->isOperator) { + // Reset the subtrees then apply operators to arrive at the first valid doc. 
+ left_it->reset(); + right_it->reset(); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter || a_filter.field_name == "id") { + result_index = 0; + is_valid = filter_result.count > 0; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + posting_list_iterators.clear(); + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + expanded_plists.clear(); + + init(); + return; + } +} + +uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + do { + filter_ids.push_back(seq_id); + next(); + } while (valid()); + + filter_array = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), filter_array); + + return filter_ids.size(); +} + +uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { + if (!valid()) { + return 0; + } + + std::vector filter_ids; + for (uint32_t i = 0; i < lenA; i++) { + auto result = valid(A[i]); + + if (result == -1) { + break; + } + + if (result == 1) { + filter_ids.push_back(A[i]); + } + } + + if (filter_ids.empty()) { + return 0; + } + + results = new uint32_t[filter_ids.size()]; + std::copy(filter_ids.begin(), filter_ids.end(), results); + + return filter_ids.size(); +} + +filter_result_iterator_t::filter_result_iterator_t(const std::string collection_name, const Index *const index, + const filter_node_t *const filter_node) : + collection_name(collection_name), + index(index), + filter_node(filter_node) { + if (filter_node == nullptr) { + is_valid = false; + return; + } + + // Generate the iterator tree and then 
initialize each node. + if (filter_node->isOperator) { + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right); + } + + init(); +} + +filter_result_iterator_t::~filter_result_iterator_t() { + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; +} + +filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iterator_t &&obj) noexcept { + if (&obj == this) + return *this; + + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + + collection_name = obj.collection_name; + index = obj.index; + filter_node = obj.filter_node; + left_it = obj.left_it; + right_it = obj.right_it; + + obj.left_it = nullptr; + obj.right_it = nullptr; + + result_index = obj.result_index; + + filter_result = std::move(obj.filter_result); + + posting_list_iterators = std::move(obj.posting_list_iterators); + expanded_plists = std::move(obj.expanded_plists); + + is_valid = obj.is_valid; + + seq_id = obj.seq_id; + reference = std::move(obj.reference); + status = std::move(obj.status); + + return *this; +} diff --git a/src/index.cpp b/src/index.cpp index 80c0ab9e..c076f8ed 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -5,10 +5,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include @@ -24,6 +24,7 @@ #include #include "logger.h" #include +#include "validator.h" #define RETURN_CIRCUIT_BREAKER if((std::chrono::duration_cast( \ std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { \ @@ -1285,7 +1286,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { void Index::search_all_candidates(const size_t num_search_fields, const 
text_match_type_t match_type, const std::vector& the_fields, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -1354,7 +1355,8 @@ void Index::search_all_candidates(const size_t num_search_fields, sort_fields, topster,groups_processed, searched_queries, qtoken_set, dropped_tokens, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, - filter_ids, filter_ids_length, total_cost, syn_orig_num_tokens, + filter_result_iterator, + total_cost, syn_orig_num_tokens, exclude_token_ids, exclude_token_ids_size, excluded_group_ids, sort_order, field_values, geopoint_indices, id_buff, all_result_ids, all_result_ids_len); @@ -2728,19 +2730,19 @@ Option Index::search(std::vector& field_query_tokens, cons const std::string& collection_name) const { std::shared_lock lock(mutex); - uint32_t filter_ids_length = 0; - auto rearrange_op = rearrange_filter_tree(filter_tree_root, filter_ids_length, collection_name); + uint32_t approx_filter_ids_length = 0; + auto rearrange_op = rearrange_filter_tree(filter_tree_root, approx_filter_ids_length, collection_name); if (!rearrange_op.ok()) { return rearrange_op; } - filter_result_t filter_result; - auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } - if (filter_tree_root != nullptr && filter_result.count == 0) { + if (filter_tree_root != nullptr && !filter_result_iterator.valid()) { return Option(true); } @@ -2752,8 +2754,9 @@ Option Index::search(std::vector& field_query_tokens, cons std::unordered_set excluded_group_ids; 
process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits, - filter_result.docs, filter_result.count, curated_ids, included_ids_map, + filter_result_iterator, curated_ids, included_ids_map, included_ids_vec, excluded_group_ids); + filter_result_iterator.reset(); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2784,7 +2787,10 @@ Option Index::search(std::vector& field_query_tokens, cons field_query_tokens[0].q_include_tokens[0].value == "*"; + // TODO: Do AND with phrase ids at last // handle phrase searches + uint32_t* phrase_result_ids = nullptr; + uint32_t phrase_result_count = 0; if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, @@ -2792,8 +2798,8 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, groups_processed, curated_ids, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, included_ids_map, is_wildcard_query, - filter_result.docs, filter_result.count); - if (filter_result.count == 0) { + phrase_result_ids, phrase_result_count); + if (phrase_result_count == 0) { goto process_search_results; } } @@ -2801,7 +2807,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && filter_result.count == 0); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.valid()); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == 
sort_field_const::seq_id && @@ -2843,15 +2849,18 @@ Option Index::search(std::vector& field_query_tokens, cons goto process_search_results; } - // if filters were not provided, use the seq_ids index to generate the - // list of all document ids + // if filters were not provided, use the seq_ids index to generate the list of all document ids if (no_filters_provided) { - filter_result.count = seq_ids->num_ids(); - filter_result.docs = seq_ids->uncompress(); + const std::string doc_id_prefix = std::to_string(collection_id) + "_" + Collection::DOC_ID_PREFIX + "_"; + Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, + store, doc_id_prefix, filter_tree_root); + + filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); } - curate_filtered_ids(curated_ids, excluded_result_ids, - excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); +// TODO: Curate ids at last +// curate_filtered_ids(curated_ids, excluded_result_ids, +// excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { @@ -2861,37 +2870,46 @@ Option Index::search(std::vector& field_query_tokens, cons k++; } - VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; - if(!no_filters_provided && filter_result.count < vector_query.flat_search_cutoff) { - for(size_t i = 0; i < filter_result.count; i++) { - auto seq_id = filter_result.docs[i]; - std::vector values; + uint32_t filter_id_count = 0; + while (!no_filters_provided && + filter_id_count < vector_query.flat_search_cutoff && + filter_result_iterator.valid()) { + auto seq_id = filter_result_iterator.seq_id; + std::vector values; - try { - values = field_vector_index->vecdex->getDataByLabel(seq_id); - } 
catch(...) { - // likely not found - continue; - } - - float dist; - if(field_vector_index->distance_type == cosine) { - std::vector normalized_q(vector_query.values.size()); - hnsw_index_t::normalize_vector(vector_query.values, normalized_q); - dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(), - &field_vector_index->num_dim); - } else { - dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(), - &field_vector_index->num_dim); - } - - dist_labels.emplace_back(dist, seq_id); + try { + values = field_vector_index->vecdex->getDataByLabel(seq_id); + } catch(...) { + // likely not found + continue; } - } else { + + float dist; + if(field_vector_index->distance_type == cosine) { + std::vector normalized_q(vector_query.values.size()); + hnsw_index_t::normalize_vector(vector_query.values, normalized_q); + dist = field_vector_index->space->get_dist_func()(normalized_q.data(), values.data(), + &field_vector_index->num_dim); + } else { + dist = field_vector_index->space->get_dist_func()(vector_query.values.data(), values.data(), + &field_vector_index->num_dim); + } + + dist_labels.emplace_back(dist, seq_id); + filter_result_iterator.next(); + filter_id_count++; + } + + if(no_filters_provided || + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.valid())) { + dist_labels.clear(); + + VectorFilterFunctor filterFunctor(&filter_result_iterator); + if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); hnsw_index_t::normalize_vector(vector_query.values, normalized_q); @@ -2901,6 +2919,8 @@ Option Index::search(std::vector& field_query_tokens, cons } } + filter_result_iterator.reset(); + std::vector nearest_ids; for (const auto& dist_label : dist_labels) { @@ -2952,7 +2972,7 @@ Option Index::search(std::vector& field_query_tokens, cons curated_ids, curated_ids_sorted, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, 
all_result_ids, all_result_ids_len, - filter_result.docs, filter_result.count, concurrency, + filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); } } else { @@ -2994,7 +3014,7 @@ Option Index::search(std::vector& field_query_tokens, cons } fuzzy_search_fields(the_fields, field_query_tokens[0].q_include_tokens, {}, match_type, excluded_result_ids, - excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, + excluded_result_ids_size, filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, @@ -3002,6 +3022,7 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -3032,12 +3053,13 @@ Option Index::search(std::vector& field_query_tokens, cons } fuzzy_search_fields(the_fields, resolved_tokens, {}, match_type, excluded_result_ids, - excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted, + excluded_result_ids_size, filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); } } @@ -3050,9 +3072,10 @@ Option 
Index::search(std::vector& field_query_tokens, cons excluded_result_ids, excluded_result_ids_size, excluded_group_ids, topster, q_pos_synonyms, syn_orig_num_tokens, groups_processed, searched_queries, all_result_ids, all_result_ids_len, - filter_result.docs, filter_result.count, query_hashes, + filter_result_iterator, query_hashes, sort_order, field_values, geopoint_indices, qtoken_set); + filter_result_iterator.reset(); // gather up both original query and synonym queries and do drop tokens @@ -3101,7 +3124,7 @@ Option Index::search(std::vector& field_query_tokens, cons fuzzy_search_fields(the_fields, truncated_tokens, dropped_tokens, match_type, excluded_result_ids, excluded_result_ids_size, - filter_result.docs, filter_result.count, + filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, num_typos, searched_queries, qtoken_set, topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, @@ -3109,6 +3132,7 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); } else { break; @@ -3121,10 +3145,12 @@ Option Index::search(std::vector& field_query_tokens, cons group_limit, group_by_fields, max_extra_prefix, max_extra_suffix, field_query_tokens[0].q_include_tokens, - topster, filter_result.docs, filter_result.count, + topster, filter_result_iterator, sort_order, field_values, geopoint_indices, curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed); + filter_result_iterator.reset(); + if(!vector_query.field_name.empty()) { // check at least one of sort fields is text match bool has_text_match = false; @@ -3140,7 +3166,7 @@ Option Index::search(std::vector& field_query_tokens, cons constexpr float TEXT_MATCH_WEIGHT = 0.7; constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - 
TEXT_MATCH_WEIGHT; - VectorFilterFunctor filterFunctor(filter_result.docs, filter_result.count); + VectorFilterFunctor filterFunctor(&filter_result_iterator); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; auto k = std::max(vector_query.k, fetch_size); @@ -3152,6 +3178,7 @@ Option Index::search(std::vector& field_query_tokens, cons } else { dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } + filter_result_iterator.reset(); std::vector> vec_results; for (const auto& dist_label : dist_labels) { @@ -3361,7 +3388,7 @@ Option Index::search(std::vector& field_query_tokens, cons void Index::process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, - const bool filter_curated_hits, const uint32_t* filter_ids, uint32_t filter_ids_length, + const bool filter_curated_hits, filter_result_iterator_t& filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, @@ -3384,19 +3411,18 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_ids_length != 0 && filter_curated_hits) { - uint32_t* included_ids_arr = nullptr; - size_t included_ids_len = ArrayUtils::and_scalar(&included_ids_vec[0], included_ids_vec.size(), filter_ids, - filter_ids_length, &included_ids_arr); + if(filter_result_iterator.valid() && filter_curated_hits) { + for (const auto &included_id: included_ids_vec) { + auto result = filter_result_iterator.valid(included_id); - included_ids_vec.clear(); + if (result == -1) { + break; + } - for(size_t i = 0; i < included_ids_len; i++) { - included_ids_set.insert(included_ids_arr[i]); - included_ids_vec.push_back(included_ids_arr[i]); + if (result == 1) { + 
included_ids_set.insert(included_id); + } } - - delete [] included_ids_arr; } else { included_ids_set.insert(included_ids_vec.begin(), included_ids_vec.end()); } @@ -3454,7 +3480,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - const uint32_t* filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector & sort_fields, @@ -3597,9 +3623,10 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const auto& prev_token = last_token ? token_candidates_vec.back().candidates[0] : ""; std::vector field_leaves; - art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, + art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, - last_token, prev_token, filter_ids, filter_ids_length, field_leaves, unique_tokens); + last_token, prev_token, filter_result_iterator, field_leaves, unique_tokens); + filter_result_iterator.reset(); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3628,8 +3655,9 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, std::vector prev_token_doc_ids; find_across_fields(token_candidates_vec.back().token, token_candidates_vec.back().candidates[0], - the_fields, num_search_fields, filter_ids, filter_ids_length, exclude_token_ids, + the_fields, num_search_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, prev_token_doc_ids, popular_field_ids); + filter_result_iterator.reset(); for(size_t field_id: query_field_ids) { auto& the_field = the_fields[field_id]; @@ -3649,9 +3677,9 @@ void Index::fuzzy_search_fields(const std::vector& 
the_fields, } std::vector field_leaves; - art_fuzzy_search(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, + art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, - false, "", filter_ids, filter_ids_length, field_leaves, unique_tokens); + false, "", filter_result_iterator, field_leaves, unique_tokens); if(field_leaves.empty()) { // look at the next field @@ -3704,7 +3732,8 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, if(token_candidates_vec.size() == query_tokens.size()) { std::vector id_buff; - search_all_candidates(num_search_fields, match_type, the_fields, filter_ids, filter_ids_length, + + search_all_candidates(num_search_fields, match_type, the_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, excluded_group_ids, sort_fields, token_candidates_vec, searched_queries, qtoken_set, dropped_tokens, topster, @@ -3714,6 +3743,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, exhaustive_search, max_candidates, syn_orig_num_tokens, sort_order, field_values, geopoint_indices, query_hashes, id_buff); + filter_result_iterator.reset(); if(id_buff.size() > 1) { gfx::timsort(id_buff.begin(), id_buff.end()); @@ -3774,7 +3804,7 @@ void Index::find_across_fields(const token_t& previous_token, const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, std::vector& top_prefix_field_ids) const { @@ -3785,7 +3815,7 @@ void Index::find_across_fields(const token_t& previous_token, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t 
istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); const bool prefix_search = previous_token.is_prefix_searched; const uint32_t token_num_typos = previous_token.num_typos; @@ -3866,7 +3896,7 @@ void Index::search_across_fields(const std::vector& query_tokens, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool prioritize_token_position, - const uint32_t* filter_ids, uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, @@ -3927,7 +3957,7 @@ void Index::search_across_fields(const std::vector& query_tokens, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_ids, filter_ids_length); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); // for each token, find the posting lists across all query_by fields for(size_t ti = 0; ti < query_tokens.size(); ti++) { @@ -4397,13 +4427,10 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector>& included_ids_map, bool is_wildcard_query, - uint32_t*& filter_ids, uint32_t& filter_ids_length) const { + uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const { std::map phrase_match_id_scores; - uint32_t* phrase_match_ids = nullptr; - size_t phrase_match_ids_size = 0; - for(size_t i = 0; i < num_search_fields; i++) { const std::string& field_name = search_fields[i].name; const size_t field_weight = search_fields[i].weight; @@ -4472,50 +4499,30 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector(10000, filter_ids_length); i++) { - auto seq_id = filter_ids[i]; + 
for(size_t i = 0; i < std::min(10000, phrase_result_count); i++) { + auto seq_id = phrase_result_ids[i]; int64_t match_score = phrase_match_id_scores[seq_id]; int64_t scores[3] = {0}; @@ -4577,7 +4584,7 @@ void Index::do_synonym_search(const std::vector& the_fields, spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - const uint32_t* filter_ids, const uint32_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -4587,7 +4594,7 @@ void Index::do_synonym_search(const std::vector& the_fields, for (const auto& syn_tokens : q_pos_synonyms) { query_hashes.clear(); fuzzy_search_fields(the_fields, syn_tokens, {}, match_type, exclude_token_ids, - exclude_token_ids_size, filter_ids, filter_ids_length, curated_ids_sorted, excluded_group_ids, + exclude_token_ids_size, filter_result_iterator, curated_ids_sorted, excluded_group_ids, sort_fields_std, {0}, searched_queries, qtoken_set, actual_topster, groups_processed, all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, @@ -4605,7 +4612,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - const uint32_t *filter_ids, size_t filter_ids_length, + filter_result_iterator_t& filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -4639,10 +4646,11 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const std::set& curated_ids, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& 
excluded_group_ids, - uint32_t*& all_result_ids, size_t& all_result_ids_len, const uint32_t* filter_ids, - uint32_t filter_ids_length, const size_t concurrency, + uint32_t*& all_result_ids, size_t& all_result_ids_len, + filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices) const { uint32_t token_bits = 0; - const bool check_for_circuit_break = (filter_ids_length > 1000000); + const bool check_for_circuit_break = (approx_filter_ids_length > 1000000); //auto beginF = std::chrono::high_resolution_clock::now(); - const size_t num_threads = std::min(concurrency, filter_ids_length); + const size_t num_threads = std::min(concurrency, approx_filter_ids_length); const size_t window_size = (num_threads == 0) ? 0 : - (filter_ids_length + num_threads - 1) / num_threads; // rounds up + (approx_filter_ids_length + num_threads - 1) / num_threads; // rounds up spp::sparse_hash_map tgroups_processed[num_threads]; Topster* topsters[num_threads]; @@ -4965,14 +4974,15 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - for(size_t thread_id = 0; thread_id < num_threads && filter_index < filter_ids_length; thread_id++) { - size_t batch_res_len = window_size; + for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.valid(); thread_id++) { + std::vector batch_result_ids; + batch_result_ids.reserve(window_size); - if(filter_index + window_size > filter_ids_length) { - batch_res_len = filter_ids_length - filter_index; - } + do { + batch_result_ids.push_back(filter_result_iterator.seq_id); + filter_result_iterator.next(); + } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); - const uint32_t* batch_result_ids = filter_ids + filter_index; num_queued++; 
searched_queries.push_back({}); @@ -4984,7 +4994,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, &group_limit, &group_by_fields, &topsters, &tgroups_processed, &excluded_group_ids, &sort_order, field_values, &geopoint_indices, &plists, check_for_circuit_break, - batch_result_ids, batch_res_len, + batch_result_ids, &num_processed, &m_process, &cv_process]() { search_begin_us = parent_search_begin; @@ -4993,7 +5003,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, size_t filter_index = 0; - for(size_t i = 0; i < batch_res_len; i++) { + for(size_t i = 0; i < batch_result_ids.size(); i++) { const uint32_t seq_id = batch_result_ids[i]; int64_t match_score = 0; @@ -5033,7 +5043,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, cv_process.notify_one(); }); - filter_index += batch_res_len; + filter_index += batch_result_ids.size(); } std::unique_lock lock_process(m_process); @@ -5045,7 +5055,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, //groups_processed.insert(tgroups_processed[thread_id].begin(), tgroups_processed[thread_id].end()); for(const auto& it : tgroups_processed[thread_id]) { groups_processed[it.first]+= it.second; - } + } aggregate_topster(topster, topsters[thread_id]); delete topsters[thread_id]; } @@ -5054,11 +5064,13 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - uint32_t* new_all_result_ids = nullptr; - all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids, - filter_ids_length, &new_all_result_ids); - delete [] all_result_ids; - all_result_ids = new_all_result_ids; +// TODO: OR filter ids with all_results_ids +// +// uint32_t* new_all_result_ids = nullptr; +// all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids, +// 
filter_ids_length, &new_all_result_ids); +// delete [] all_result_ids; +// all_result_ids = new_all_result_ids; } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index 13165370..df4404cb 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -1,5 +1,5 @@ #include "or_iterator.h" - +#include "filter.h" bool or_iterator_t::at_end(const std::vector& its) { // if any iterator is invalid, we stop @@ -208,6 +208,10 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e return false; } + if (istate.fit != nullptr) { + return (istate.fit->valid(id) == 1); + } + return true; } diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 39f3ac00..e983e98b 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -1260,6 +1260,146 @@ void posting_list_t::get_exact_matches(std::vector& its, const bool num_exact_ids = exact_id_index; } +bool posting_list_t::has_exact_match(std::vector& posting_list_iterators, + const bool field_is_array) { + if(posting_list_iterators.size() == 1) { + return is_single_token_verbatim_match(posting_list_iterators[0], field_is_array); + } else { + + if (!field_is_array) { + for (uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + posting_list_t::iterator_t& it = posting_list_iterators[i]; + + block_t* curr_block = it.block(); + uint32_t curr_index = it.index(); + + if(curr_block == nullptr || curr_index == UINT32_MAX) { + return false; + } + + uint32_t* offsets = it.offsets; + + uint32_t start_offset_index = it.offset_index[curr_index]; + uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ? 
+ curr_block->offsets.getLength() : + it.offset_index[curr_index + 1]; + + if(i == posting_list_iterators.size() - 1) { + // check if the last query token is the last offset + if( offsets[end_offset_index-1] != 0 || + (end_offset_index-2 >= 0 && offsets[end_offset_index-2] != posting_list_iterators.size())) { + // not the last token for the document, so skip + return false; + } + } + + // looping handles duplicate query tokens, e.g. "hip hip hurray hurray" + while(start_offset_index < end_offset_index) { + uint32_t offset = offsets[start_offset_index]; + start_offset_index++; + + if(offset == (i + 1)) { + // we have found a matching index, no need to look further + return true; + } + + if(offset > (i + 1)) { + return false; + } + } + } + } + + else { + // field is an array + + struct token_index_meta_t { + std::bitset<32> token_index; + bool has_last_token; + }; + + std::map array_index_to_token_index; + + for(uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + posting_list_t::iterator_t& it = posting_list_iterators[i]; + + block_t* curr_block = it.block(); + uint32_t curr_index = it.index(); + + if(curr_block == nullptr || curr_index == UINT32_MAX) { + return false; + } + + uint32_t* offsets = it.offsets; + uint32_t start_offset_index = it.offset_index[curr_index]; + uint32_t end_offset_index = (curr_index == curr_block->size() - 1) ? 
+ curr_block->offsets.getLength() : + it.offset_index[curr_index + 1]; + + int prev_pos = -1; + bool has_atleast_one_last_token = false; + bool found_matching_index = false; + size_t num_matching_index = 0; + + while(start_offset_index < end_offset_index) { + int pos = offsets[start_offset_index]; + start_offset_index++; + + if(pos == prev_pos) { // indicates end of array index + size_t array_index = (size_t) offsets[start_offset_index]; + + if(start_offset_index+1 < end_offset_index) { + size_t next_offset = (size_t) offsets[start_offset_index + 1]; + if(next_offset == 0 && pos == posting_list_iterators.size()) { + // indicates that token is the last token on the doc + array_index_to_token_index[array_index].has_last_token = true; + has_atleast_one_last_token = true; + start_offset_index++; + } + } + + if(found_matching_index) { + array_index_to_token_index[array_index].token_index.set(i + 1); + } + + start_offset_index++; // skip current value which is the array index or flag for last index + prev_pos = -1; + found_matching_index = false; + continue; + } + + if(pos == (i + 1)) { + // we have found a matching index + found_matching_index = true; + num_matching_index++; + } + + prev_pos = pos; + } + + // check if the last query token is the last offset of ANY array element + if(i == posting_list_iterators.size() - 1 && !has_atleast_one_last_token) { + return false; + } + + if(num_matching_index == 0) { + // not even a single matching index found: can never be an exact match + return false; + } + } + + // iterate array index to token index to check if atleast 1 array position contains all tokens + for(auto& kv: array_index_to_token_index) { + if(kv.second.token_index.count() == posting_list_iterators.size() && kv.second.has_last_token) { + return true; + } + } + } + } + + return false; +} + bool posting_list_t::found_token_sequence(const std::vector& token_positions, const size_t token_index, const uint16_t target_pos) { diff --git a/src/validator.cpp 
b/src/validator.cpp index 04d99ab2..0caa1302 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -1,4 +1,5 @@ #include "validator.h" +#include "field.h" Option validator_t::coerce_element(const field& a_field, nlohmann::json& document, nlohmann::json& doc_ele, diff --git a/test/collection_test.cpp b/test/collection_test.cpp index e2ce9044..930c8176 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "collection.h" #include "text_embedder_manager.h" #include "http_client.h" diff --git a/test/filter_test.cpp b/test/filter_test.cpp index b4f836f1..d125e814 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -125,15 +125,25 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); - ASSERT_TRUE(iter_exact_match_test.init_status().ok()); + auto iter_exact_match_1_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_1_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_exact_match_test.valid()); - ASSERT_EQ(i, iter_exact_match_test.seq_id); - iter_exact_match_test.next(); + ASSERT_TRUE(iter_exact_match_1_test.valid()); + ASSERT_EQ(i, iter_exact_match_1_test.seq_id); + iter_exact_match_1_test.next(); } - ASSERT_FALSE(iter_exact_match_test.valid()); + ASSERT_FALSE(iter_exact_match_1_test.valid()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= PLATINUM", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_2_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_exact_match_2_test.init_status().ok()); + ASSERT_FALSE(iter_exact_match_2_test.valid()); delete 
filter_tree_root; filter_tree_root = nullptr; From a749d834013624282283eb3ebcad440b482872c1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 10 Apr 2023 19:13:45 +0530 Subject: [PATCH 18/93] Fix `FacetFieldStringFiltering` test. --- src/index.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index c076f8ed..2e79e987 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4983,6 +4983,12 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, filter_result_iterator.next(); } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); + uint32_t* new_all_result_ids = nullptr; + all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, batch_result_ids.data(), + batch_result_ids.size(), &new_all_result_ids); + delete [] all_result_ids; + all_result_ids = new_all_result_ids; + num_queued++; searched_queries.push_back({}); @@ -5063,14 +5069,6 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, /*long long int timeMillisF = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - -// TODO: OR filter ids with all_results_ids -// -// uint32_t* new_all_result_ids = nullptr; -// all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, filter_ids, -// filter_ids_length, &new_all_result_ids); -// delete [] all_result_ids; -// all_result_ids = new_all_result_ids; } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, From 25c9eba1a564a69abc5436498fbaa73e948bcfa7 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 11 Apr 2023 10:50:39 +0530 Subject: [PATCH 19/93] Fix memory leak. 
--- src/art.cpp | 5 +---- src/posting_list.cpp | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/art.cpp b/src/art.cpp index 48470551..d189d7eb 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1724,10 +1724,7 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le << ", filter_ids_length: " << filter_ids_length; }*/ -// TODO: Figure out this edge case. -// if(allowed_doc_ids != filter_ids) { -// delete [] allowed_doc_ids; -// } + delete [] allowed_doc_ids; return 0; } diff --git a/src/posting_list.cpp b/src/posting_list.cpp index e983e98b..1b5a3323 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -1267,7 +1267,7 @@ bool posting_list_t::has_exact_match(std::vector& po } else { if (!field_is_array) { - for (uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + for (int i = posting_list_iterators.size() - 1; i >= 0; i--) { posting_list_t::iterator_t& it = posting_list_iterators[i]; block_t* curr_block = it.block(); @@ -1320,7 +1320,7 @@ bool posting_list_t::has_exact_match(std::vector& po std::map array_index_to_token_index; - for(uint32_t i = posting_list_iterators.size() - 1; i >= 0; i--) { + for(int i = posting_list_iterators.size() - 1; i >= 0; i--) { posting_list_t::iterator_t& it = posting_list_iterators[i]; block_t* curr_block = it.block(); From 95c452c3dbd6d7ab51a0f9f407d4a0a6f5567df0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 11 Apr 2023 11:27:31 +0530 Subject: [PATCH 20/93] Refactor `Index::search_wildcard`. 
--- src/index.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 2e79e987..d27e9336 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4983,12 +4983,6 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, filter_result_iterator.next(); } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); - uint32_t* new_all_result_ids = nullptr; - all_result_ids_len = ArrayUtils::or_scalar(all_result_ids, all_result_ids_len, batch_result_ids.data(), - batch_result_ids.size(), &new_all_result_ids); - delete [] all_result_ids; - all_result_ids = new_all_result_ids; - num_queued++; searched_queries.push_back({}); @@ -5069,6 +5063,9 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, /*long long int timeMillisF = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ + + filter_result_iterator.reset(); + all_result_ids_len = filter_result_iterator.to_filter_id_array(all_result_ids); } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, From 0f8bb94b1eb424d9d2c38eb5dd07744d07664fce Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 11 Apr 2023 17:44:32 +0530 Subject: [PATCH 21/93] Add `approx_filter_ids_length` field. 
--- include/filter_result_iterator.h | 10 ++++++++-- src/filter_result_iterator.cpp | 6 ++++-- src/index.cpp | 3 ++- src/or_iterator.cpp | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index b67d67ca..f5ae22a6 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -125,12 +125,18 @@ private: public: uint32_t seq_id = 0; - // Collection name -> references + /// Collection name -> references std::map reference; Option status = Option(true); + /// Holds the upper-bound of the number of seq ids this iterator would match. + /// Useful in a scenario where we need to differentiate between filter iterator not matching any document v/s filter + /// iterator reaching it's end. (is_valid would be false in both these cases) + uint32_t approx_filter_ids_length; + explicit filter_result_iterator_t(const std::string collection_name, - Index const* const index, filter_node_t const* const filter_node); + Index const* const index, filter_node_t const* const filter_node, + uint32_t approx_filter_ids_length = UINT32_MAX); ~filter_result_iterator_t(); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 21a4128f..f5bd8b5f 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -835,10 +835,12 @@ uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& } filter_result_iterator_t::filter_result_iterator_t(const std::string collection_name, const Index *const index, - const filter_node_t *const filter_node) : + const filter_node_t *const filter_node, + uint32_t approx_filter_ids_length) : collection_name(collection_name), index(index), - filter_node(filter_node) { + filter_node(filter_node), + approx_filter_ids_length(approx_filter_ids_length) { if (filter_node == nullptr) { is_valid = false; return; diff --git a/src/index.cpp b/src/index.cpp index d27e9336..44fb93dc 100644 --- 
a/src/index.cpp +++ b/src/index.cpp @@ -2736,7 +2736,8 @@ Option Index::search(std::vector& field_query_tokens, cons return rearrange_op; } - auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root, + approx_filter_ids_length); auto filter_init_op = filter_result_iterator.init_status(); if (!filter_init_op.ok()) { return filter_init_op; diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index df4404cb..8dd9d487 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -208,7 +208,7 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e return false; } - if (istate.fit != nullptr) { + if (istate.fit != nullptr && istate.fit->approx_filter_ids_length > 0) { return (istate.fit->valid(id) == 1); } From f2d5ae961b1da1ae672fd8a7e2f17ebfd9b0b6ba Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 12 Apr 2023 18:47:55 +0530 Subject: [PATCH 22/93] Handle `!=` in `filter_result_iterator_t`. --- include/filter_result_iterator.h | 3 + include/id_list.h | 2 + src/filter_result_iterator.cpp | 145 +++++++++++++++++++++++++------ src/id_list.cpp | 8 ++ test/filter_test.cpp | 40 ++++----- 5 files changed, 153 insertions(+), 45 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index f5ae22a6..46ec11bc 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -120,6 +120,9 @@ private: /// Performs OR on the subtrees of operator. void or_filter_iterators(); + /// Advance all the token iterators that are at seq_id. + void advance_string_filter_token_iterators(); + /// Finds the next match for a filter on string field. 
void doc_matching_string_filter(bool field_is_array); diff --git a/include/id_list.h b/include/id_list.h index 3b0ef7ae..ad890119 100644 --- a/include/id_list.h +++ b/include/id_list.h @@ -126,6 +126,8 @@ public: uint32_t first_id(); + uint32_t last_id(); + block_t* block_of(uint32_t id); bool contains(uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f5bd8b5f..f7217c0a 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -265,6 +265,18 @@ void filter_result_iterator_t::or_filter_iterators() { is_valid = false; } +void filter_result_iterator_t::advance_string_filter_token_iterators() { + for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { + for (auto& iter: filter_value_tokens) { + iter.next(); + } + } + } +} + void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { // If none of the filter value iterators are valid, mark this node as invalid. bool one_is_valid = false; @@ -384,18 +396,38 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { - // Advance all the filter values that are at doc. Then find the next doc. 
- for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { - auto& filter_value_tokens = posting_list_iterators[i]; - - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { - for (auto& iter: filter_value_tokens) { - iter.next(); - } + if (filter_node->filter_exp.apply_not_equals) { + if (++seq_id < result_index) { + return; } + + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id); + + if (!is_valid) { + // We've reached the end of the index, no possible matches pending. + if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + result_index = index->seq_ids->last_id() + 1; + seq_id = previous_match + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + return; } + advance_string_filter_token_iterators(); doc_matching_string_filter(f.is_array()); + return; } } @@ -509,6 +541,47 @@ void filter_result_iterator_t::init() { } doc_matching_string_filter(f.is_array()); + + if (filter_node->filter_exp.apply_not_equals) { + // filter didn't match any id. So by applying not equals, every id in the index is a match. + if (!is_valid) { + is_valid = true; + seq_id = 0; + result_index = index->seq_ids->last_id() + 1; + return; + } + + // [0, seq_id) are a match for not equals. + if (seq_id > 0) { + result_index = seq_id; + seq_id = 0; + return; + } + + // Keep ignoring the consecutive matches. + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id); + + if (!is_valid) { + // filter matched all the ids in the index. So for not equals, there's no match. 
+ if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + result_index = index->seq_ids->last_id() + 1; + seq_id = previous_match + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + } + return; } } @@ -543,6 +616,10 @@ bool filter_result_iterator_t::valid() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { + if (filter_node->filter_exp.apply_not_equals) { + return seq_id < result_index; + } + bool one_is_valid = false; for (auto& filter_value_tokens: posting_list_iterators) { posting_list_t::intersect(filter_value_tokens, one_is_valid); @@ -618,6 +695,41 @@ void filter_result_iterator_t::skip_to(uint32_t id) { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { + if (filter_node->filter_exp.apply_not_equals) { + if (id < seq_id) { + return; + } + + if (id < result_index) { + seq_id = id; + return; + } + + seq_id = result_index; + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id && seq_id >= id); + + if (!is_valid) { + // filter matched all the ids in the index. So for not equals, there's no match. + if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + seq_id = previous_match + 1; + result_index = index->seq_ids->last_id() + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + return; + } + // Skip all the token iterators and find a new match. for (auto& filter_value_tokens : posting_list_iterators) { for (auto& token: filter_value_tokens) { @@ -670,23 +782,6 @@ int filter_result_iterator_t::valid(uint32_t id) { } } - if (filter_node->filter_exp.apply_not_equals) { - // Even when iterator becomes invalid, we keep it marked as valid since we are evaluating not equals. 
- if (!valid()) { - is_valid = true; - return 1; - } - - skip_to(id); - - if (!is_valid) { - is_valid = true; - return 1; - } - - return seq_id != id ? 1 : 0; - } - skip_to(id); return is_valid ? (seq_id == id ? 1 : 0) : -1; } diff --git a/src/id_list.cpp b/src/id_list.cpp index 4b308603..82712368 100644 --- a/src/id_list.cpp +++ b/src/id_list.cpp @@ -338,6 +338,14 @@ uint32_t id_list_t::first_id() { return root_block.ids.at(0); } +uint32_t id_list_t::last_id() { + if(id_block_map.empty()) { + return 0; + } + + return id_block_map.rbegin()->first; +} + id_list_t::block_t* id_list_t::block_of(uint32_t id) { const auto it = id_block_map.lower_bound(id); if(it == id_block_map.end()) { diff --git a/test/filter_test.cpp b/test/filter_test.cpp index d125e814..6cab88b5 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -162,23 +162,23 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_exact_match_multi_test.valid()); -// delete filter_tree_root; -// filter_tree_root = nullptr; -// filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, -// filter_tree_root); -// ASSERT_TRUE(filter_op.ok()); -// -// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); -// ASSERT_TRUE(iter_not_equals_test.init_status().ok()); -// -// std::vector expected = {1, 3}; -// for (auto const& i : expected) { -// ASSERT_TRUE(iter_not_equals_test.valid()); -// ASSERT_EQ(i, iter_not_equals_test.seq_id); -// iter_not_equals_test.next(); -// } -// -// ASSERT_FALSE(iter_not_equals_test.valid()); + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_not_equals_test.init_status().ok()); + + expected = 
{1, 3}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_not_equals_test.valid()); + ASSERT_EQ(i, iter_not_equals_test.seq_id); + iter_not_equals_test.next(); + } + + ASSERT_FALSE(iter_not_equals_test.valid()); delete filter_tree_root; filter_tree_root = nullptr; @@ -287,8 +287,8 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); - validate_ids = {0, 1, 2, 3, 4, 5, 6, 7, 100}; - expected = {0, 1, 0, 1, 0, 1, 1, 1, 1}; + validate_ids = {0, 1, 2, 3, 4, 5, 6}; + expected = {0, 1, 0, 1, 0, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); } @@ -422,7 +422,7 @@ TEST_F(FilterTest, FilterTreeIterator) { for (uint32_t i = 0; i < and_result_length; i++) { ASSERT_EQ(expected[i], and_result[i]); } - ASSERT_FALSE(iter_and_test.valid()); + ASSERT_FALSE(iter_and_scalar_test.valid()); delete and_result; delete filter_tree_root; From 9a9154b6310c27a58fe5a79dacd28e477262eda7 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 13 Apr 2023 14:49:46 +0530 Subject: [PATCH 23/93] Fix approximation logic of filter matches in case of `!=`. --- src/index.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 44fb93dc..5fd2a0ac 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2095,9 +2095,8 @@ Option Index::_approximate_filter_ids(const filter& a_filter, } } - if (a_filter.apply_not_equals) { - auto all_ids_size = seq_ids->num_ids(); - filter_ids_length = (all_ids_size - filter_ids_length); + if (a_filter.apply_not_equals && filter_ids_length == 0) { + filter_ids_length = seq_ids->num_ids(); } return Option(true); From f4c229793a6e9828b3a730538fb1733bc04cf9a1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 17 Apr 2023 20:54:50 +0530 Subject: [PATCH 24/93] Add numeric field support in `filter_result_t`. 
--- include/ids_t.h | 8 +- include/num_tree.h | 10 ++ src/filter_result_iterator.cpp | 199 +++++++++++++++++++++++++++++++-- src/num_tree.cpp | 78 +++++++++++++ 4 files changed, 283 insertions(+), 12 deletions(-) diff --git a/include/ids_t.h b/include/ids_t.h index 15cf8c6e..949c71b8 100644 --- a/include/ids_t.h +++ b/include/ids_t.h @@ -39,11 +39,6 @@ struct compact_id_list_t { }; class ids_t { -private: - - static void to_expanded_id_lists(const std::vector& raw_id_lists, std::vector& id_lists, - std::vector& expanded_id_lists); - public: static constexpr size_t COMPACT_LIST_THRESHOLD_LENGTH = 64; static constexpr size_t MAX_BLOCK_ELEMENTS = 256; @@ -104,6 +99,9 @@ public: static uint32_t* uncompress(void*& obj); static void uncompress(void*& obj, std::vector& ids); + + static void to_expanded_id_lists(const std::vector& raw_id_lists, std::vector& id_lists, + std::vector& expanded_id_lists); }; template diff --git a/include/num_tree.h b/include/num_tree.h index 2170a30e..444f6266 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -30,6 +30,11 @@ public: void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len); + void range_inclusive_search_iterators(int64_t start, + int64_t end, + std::vector& id_list_iterators, + std::vector& expanded_id_lists); + void approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len); void range_inclusive_contains(const int64_t& start, const int64_t& end, @@ -42,6 +47,11 @@ public: void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len); + void search_iterators(NUM_COMPARATOR comparator, + int64_t value, + std::vector& id_list_iterators, + std::vector& expanded_id_lists); + void approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len); void remove(uint64_t value, uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f7217c0a..6076807f 100644 --- 
a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1,3 +1,5 @@ +#include +#include #include "filter_result_iterator.h" #include "index.h" #include "posting.h" @@ -395,7 +397,16 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + result_index++; + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + return; + } else if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (++seq_id < result_index) { return; @@ -432,6 +443,40 @@ void filter_result_iterator_t::next() { } } +void merge_id_list_iterators(std::vector& id_list_iterators, + std::vector& result_ids) { + struct comp { + bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { + return lhs->id() > rhs->id(); + } + }; + + std::priority_queue, comp> iter_queue; + for (auto& id_list_iterator: id_list_iterators) { + if (id_list_iterator.valid()) { + iter_queue.push(&id_list_iterator); + } + } + + if (iter_queue.empty()) { + return; + } + + // TODO: Handle != + + do { + id_list_t::iterator_t* iter = iter_queue.top(); + iter_queue.pop(); + + result_ids.push_back(iter->id()); + iter->next(); + + if (iter->valid()) { + iter_queue.push(iter); + } + } while (!iter_queue.empty()); +} + void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -470,7 +515,12 @@ void filter_result_iterator_t::init() { return; } - is_valid = filter_result.count > 0; + if (filter_result.count == 0) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; return; } @@ -491,6 +541,7 @@ void filter_result_iterator_t::init() { filter_result.count = result_ids.size(); filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + seq_id = 
filter_result.docs[result_index]; } if (!index->field_is_indexed(a_filter.field_name)) { @@ -500,7 +551,112 @@ void filter_result_iterator_t::init() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer()) { + auto num_tree = index->numerical_index.at(a_filter.field_name); + + std::vector ids; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + int64_t value = (int64_t)std::stol(filter_value); + std::vector id_list_iterators; + std::vector expanded_id_lists; + + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + auto const range_end_value = (int64_t)std::stol(next_filter_value); + num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); + fi++; + } else { + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + value, id_list_iterators, expanded_id_lists); + } + + merge_id_list_iterators(id_list_iterators, ids); + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } + } + + if (ids.empty()) { + is_valid = false; + return; + } + + filter_result.count = ids.size(); + filter_result.docs = new uint32_t[ids.size()]; + std::copy(ids.begin(), ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + } else if (f.is_float()) { + auto num_tree = index->numerical_index.at(a_filter.field_name); + + std::vector ids; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); + std::vector id_list_iterators; + std::vector expanded_id_lists; + + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& 
next_filter_value = a_filter.values[fi+1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + num_tree->range_inclusive_search_iterators(float_int64, range_end_value, + id_list_iterators, expanded_id_lists); + fi++; + } else { + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + float_int64, id_list_iterators, expanded_id_lists); + } + + merge_id_list_iterators(id_list_iterators, ids); + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } + } + + if (ids.empty()) { + is_valid = false; + return; + } + + filter_result.count = ids.size(); + filter_result.docs = new uint32_t[ids.size()]; + std::copy(ids.begin(), ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + } else if (f.is_bool()) { + auto num_tree = index->numerical_index.at(a_filter.field_name); + + std::vector ids; + size_t value_index = 0; + for (const std::string& filter_value : a_filter.values) { + int64_t bool_int64 = (filter_value == "1") ? 1 : 0; + std::vector id_list_iterators; + std::vector expanded_id_lists; + + num_tree->search_iterators(a_filter.comparators[value_index] == NOT_EQUALS ? 
EQUALS : a_filter.comparators[value_index], + bool_int64, id_list_iterators, expanded_id_lists); + + merge_id_list_iterators(id_list_iterators, ids); + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } + + value_index++; + } + + if (ids.empty()) { + is_valid = false; + return; + } + + filter_result.count = ids.size(); + filter_result.docs = new uint32_t[ids.size()]; + std::copy(ids.begin(), ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); for (const std::string& filter_value : a_filter.values) { @@ -615,7 +771,10 @@ bool filter_result_iterator_t::valid() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + is_valid = result_index < filter_result.count; + return is_valid; + } else if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { return seq_id < result_index; } @@ -694,7 +853,17 @@ void filter_result_iterator_t::skip_to(uint32_t id) { field f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + while(result_index < filter_result.count && filter_result.docs[result_index] < id) { + result_index++; + } + + if (result_index >= filter_result.count) { + is_valid = false; + } + + return; + } else if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (id < seq_id) { return; @@ -861,8 +1030,14 @@ void filter_result_iterator_t::reset() { bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); if (is_referenced_filter || a_filter.field_name == "id") { + if (filter_result.count == 0) { + is_valid = false; + return; + } + result_index = 0; - is_valid = filter_result.count > 0; + seq_id = filter_result.docs[result_index]; + is_valid = true; return; } @@ -872,7 +1047,17 @@ void filter_result_iterator_t::reset() { field 
f = index->search_schema.at(a_filter.field_name); - if (f.is_string()) { + if (f.is_integer() || f.is_float() || f.is_bool()) { + if (filter_result.count == 0) { + is_valid = false; + return; + } + + result_index = 0; + seq_id = filter_result.docs[result_index]; + is_valid = true; + return; + } else if (f.is_string()) { posting_list_iterators.clear(); for(auto expanded_plist: expanded_plists) { delete expanded_plist; diff --git a/src/num_tree.cpp b/src/num_tree.cpp index c59cb008..89c5e3a0 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -43,6 +43,30 @@ void num_tree_t::range_inclusive_search(int64_t start, int64_t end, uint32_t** i *ids = out; } +void num_tree_t::range_inclusive_search_iterators(int64_t start, + int64_t end, + std::vector& id_list_iterators, + std::vector& expanded_id_lists) { + if (int64map.empty()) { + return; + } + + auto it_start = int64map.lower_bound(start); // iter values will be >= start + + std::vector raw_id_lists; + while (it_start != int64map.end() && it_start->first <= end) { + raw_id_lists.push_back(it_start->second); + it_start++; + } + + std::vector id_lists; + ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); + + for (const auto &id_list: id_lists) { + id_list_iterators.emplace_back(id_list->new_iterator()); + } +} + void num_tree_t::approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len) { if (int64map.empty()) { return; @@ -187,6 +211,60 @@ void num_tree_t::search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids } } +void num_tree_t::search_iterators(NUM_COMPARATOR comparator, + int64_t value, + std::vector& id_list_iterators, + std::vector& expanded_id_lists) { + if (int64map.empty()) { + return ; + } + + std::vector raw_id_lists; + if (comparator == EQUALS) { + const auto& it = int64map.find(value); + if (it != int64map.end()) { + raw_id_lists.emplace_back(it->second); + } + } else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) { + // iter 
entries will be >= value, or end() if all entries are before value + auto iter_ge_value = int64map.lower_bound(value); + + if(iter_ge_value == int64map.end()) { + return ; + } + + if(comparator == GREATER_THAN && iter_ge_value->first == value) { + iter_ge_value++; + } + + while(iter_ge_value != int64map.end()) { + raw_id_lists.emplace_back(iter_ge_value->second); + iter_ge_value++; + } + } else if(comparator == LESS_THAN || comparator == LESS_THAN_EQUALS) { + // iter entries will be >= value, or end() if all entries are before value + auto iter_ge_value = int64map.lower_bound(value); + + auto it = int64map.begin(); + while(it != iter_ge_value) { + raw_id_lists.emplace_back(it->second); + it++; + } + + // for LESS_THAN_EQUALS, check if last iter entry is equal to value + if(it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) { + raw_id_lists.emplace_back(it->second); + } + } + + std::vector id_lists; + ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); + + for (const auto &id_list: id_lists) { + id_list_iterators.emplace_back(id_list->new_iterator()); + } +} + void num_tree_t::approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len) { if (int64map.empty()) { return; From 2740262cb6864976c440f344dc4385dbc994edc0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 18 Apr 2023 09:57:39 +0530 Subject: [PATCH 25/93] Optimize `filter_result_iterator_t::to_filter_id_array`. 
--- src/filter_result_iterator.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 6076807f..f2aeb7cd 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1074,6 +1074,19 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { return 0; } + if (!filter_node->isOperator) { + const filter a_filter = filter_node->filter_exp; + field f = index->search_schema.at(a_filter.field_name); + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || + (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { + filter_array = new uint32_t[filter_result.count]; + std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); + + return filter_result.count; + } + } + std::vector filter_ids; do { filter_ids.push_back(seq_id); From b71ad7fd507c1e30a201f81dff69032ad4cb5855 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 18 Apr 2023 15:16:33 +0530 Subject: [PATCH 26/93] Refactor numeric filter initialization. 
--- src/filter_result_iterator.cpp | 105 +++++---------------------------- 1 file changed, 16 insertions(+), 89 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f2aeb7cd..d57c8895 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -443,40 +443,6 @@ void filter_result_iterator_t::next() { } } -void merge_id_list_iterators(std::vector& id_list_iterators, - std::vector& result_ids) { - struct comp { - bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { - return lhs->id() > rhs->id(); - } - }; - - std::priority_queue, comp> iter_queue; - for (auto& id_list_iterator: id_list_iterators) { - if (id_list_iterator.valid()) { - iter_queue.push(&id_list_iterator); - } - } - - if (iter_queue.empty()) { - return; - } - - // TODO: Handle != - - do { - id_list_t::iterator_t* iter = iter_queue.top(); - iter_queue.pop(); - - result_ids.push_back(iter->id()); - iter->next(); - - if (iter->valid()) { - iter_queue.push(iter); - } - } while (!iter_queue.empty()); -} - void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -554,108 +520,69 @@ void filter_result_iterator_t::init() { if (f.is_integer()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - std::vector ids; + // TODO: Handle not equals + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); - std::vector id_list_iterators; - std::vector expanded_id_lists; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, + 
reinterpret_cast(filter_result.count)); fi++; } else { - num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], - value, id_list_iterators, expanded_id_lists); - } - - merge_id_list_iterators(id_list_iterators, ids); - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; + num_tree->search(a_filter.comparators[fi], value, + &filter_result.docs, reinterpret_cast(filter_result.count)); } } - if (ids.empty()) { + if (filter_result.count == 0) { is_valid = false; return; } - - filter_result.count = ids.size(); - filter_result.docs = new uint32_t[ids.size()]; - std::copy(ids.begin(), ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; } else if (f.is_float()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - std::vector ids; for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; float value = (float)std::atof(filter_value.c_str()); int64_t float_int64 = Index::float_to_int64_t(value); - std::vector id_list_iterators; - std::vector expanded_id_lists; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi+1]; int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search_iterators(float_int64, range_end_value, - id_list_iterators, expanded_id_lists); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, + reinterpret_cast(filter_result.count)); fi++; } else { - num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? 
EQUALS : a_filter.comparators[fi], - float_int64, id_list_iterators, expanded_id_lists); - } - - merge_id_list_iterators(id_list_iterators, ids); - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; + num_tree->search(a_filter.comparators[fi], float_int64, + &filter_result.docs, reinterpret_cast(filter_result.count)); } } - if (ids.empty()) { + if (filter_result.count == 0) { is_valid = false; return; } - - filter_result.count = ids.size(); - filter_result.docs = new uint32_t[ids.size()]; - std::copy(ids.begin(), ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; } else if (f.is_bool()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - std::vector ids; size_t value_index = 0; for (const std::string& filter_value : a_filter.values) { int64_t bool_int64 = (filter_value == "1") ? 1 : 0; - std::vector id_list_iterators; - std::vector expanded_id_lists; - num_tree->search_iterators(a_filter.comparators[value_index] == NOT_EQUALS ? 
EQUALS : a_filter.comparators[value_index], - bool_int64, id_list_iterators, expanded_id_lists); - - merge_id_list_iterators(id_list_iterators, ids); - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; - } + num_tree->search(a_filter.comparators[value_index], bool_int64, + &filter_result.docs, reinterpret_cast(filter_result.count)); value_index++; } - if (ids.empty()) { + if (filter_result.count == 0) { is_valid = false; return; } - - filter_result.count = ids.size(); - filter_result.docs = new uint32_t[ids.size()]; - std::copy(ids.begin(), ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); @@ -1080,9 +1007,9 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { + filter_array = new uint32_t[filter_result.count]; std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); - return filter_result.count; } } From 45975327ff48466e0d751f4ea76bcbfc1cb9832f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 18 Apr 2023 16:53:24 +0530 Subject: [PATCH 27/93] Expose filter ids from iterator where possible. 
--- include/filter_result_iterator.h | 10 ++++++++++ src/filter_result_iterator.cpp | 33 +++++++++++++++++++++----------- src/index.cpp | 21 +++++++++++++------- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 46ec11bc..bd9e66f0 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -127,6 +127,14 @@ private: void doc_matching_string_filter(bool field_is_array); public: + uint32_t* get_ids() { + return filter_result.docs; + } + + uint32_t get_length() { + return filter_result.count; + } + uint32_t seq_id = 0; /// Collection name -> references std::map reference; @@ -180,4 +188,6 @@ public: /// Performs AND with the contents of A and allocates a new array of results. /// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); + + bool can_get_ids(); }; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d57c8895..787ad578 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1001,17 +1001,10 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { return 0; } - if (!filter_node->isOperator) { - const filter a_filter = filter_node->filter_exp; - field f = index->search_schema.at(a_filter.field_name); - - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || - (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { - - filter_array = new uint32_t[filter_result.count]; - std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); - return filter_result.count; - } + if (can_get_ids()) { + filter_array = new uint32_t[filter_result.count]; + std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); + return filter_result.count; } std::vector filter_ids; @@ -1031,6 +1024,10 @@ 
uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& return 0; } + if (can_get_ids()) { + return ArrayUtils::and_scalar(A, lenA, filter_result.docs, filter_result.count, &results); + } + std::vector filter_ids; for (uint32_t i = 0; i < lenA; i++) { auto result = valid(A[i]); @@ -1121,3 +1118,17 @@ filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iter return *this; } + +bool filter_result_iterator_t::can_get_ids() { + if (!filter_node->isOperator) { + const filter a_filter = filter_node->filter_exp; + field f = index->search_schema.at(a_filter.field_name); + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || + (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { + return true; + } + } + + return false; +} diff --git a/src/index.cpp b/src/index.cpp index 5fd2a0ac..0763e822 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4974,14 +4974,23 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.valid(); thread_id++) { + for(size_t thread_id = 0; thread_id < num_threads && + (filter_result_iterator.can_get_ids() ? 
+ filter_index < filter_result_iterator.get_length() : + filter_result_iterator.valid()); thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - do { - batch_result_ids.push_back(filter_result_iterator.seq_id); - filter_result_iterator.next(); - } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); + if (filter_result_iterator.can_get_ids()) { + while (batch_result_ids.size() < window_size && filter_index < filter_result_iterator.get_length()) { + batch_result_ids.push_back(filter_result_iterator.get_ids()[filter_index++]); + } + } else { + do { + batch_result_ids.push_back(filter_result_iterator.seq_id); + filter_result_iterator.next(); + } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); + } num_queued++; @@ -5042,8 +5051,6 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, parent_search_cutoff = parent_search_cutoff || search_cutoff; cv_process.notify_one(); }); - - filter_index += batch_result_ids.size(); } std::unique_lock lock_process(m_process); From 9c34236f2e8cac19b09dc5cffc428b5610364b9d Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 20 Apr 2023 13:19:42 +0530 Subject: [PATCH 28/93] Add `filter_result_iterator_t::get_n_ids`. Use `is_valid` instead of `valid()`. Handle special `_all_` field name in filtering logic. --- include/filter_result_iterator.h | 28 ++-- include/index.h | 1 + src/art.cpp | 3 +- src/filter.cpp | 3 + src/filter_result_iterator.cpp | 256 +++++++++++++++++++------------ src/index.cpp | 47 +++--- test/filter_test.cpp | 64 ++++---- 7 files changed, 227 insertions(+), 175 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index bd9e66f0..1184b74a 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -99,6 +99,7 @@ private: /// Stores the result of the filters that cannot be iterated. 
filter_result_t filter_result; + bool is_filter_result_initialized = false; /// Initialized in case of filter on string field. /// Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator @@ -108,9 +109,6 @@ private: std::vector> posting_list_iterators; std::vector expanded_plists; - /// Set to false when this iterator or it's subtree becomes invalid. - bool is_valid = true; - /// Initializes the state of iterator node after it's creation. void init(); @@ -126,18 +124,18 @@ private: /// Finds the next match for a filter on string field. void doc_matching_string_filter(bool field_is_array); + /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). + [[nodiscard]] bool valid(); + public: - uint32_t* get_ids() { - return filter_result.docs; - } - - uint32_t get_length() { - return filter_result.count; - } - uint32_t seq_id = 0; /// Collection name -> references std::map reference; + + /// Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initialization status of the iterator. Option status = Option(true); /// Holds the upper-bound of the number of seq ids this iterator would match. @@ -156,9 +154,6 @@ public: /// Returns the status of the initialization of iterator tree. Option init_status(); - /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). - [[nodiscard]] bool valid(); - /// Returns a tri-state: /// 0: id is not valid /// 1: id is valid @@ -171,6 +166,9 @@ public: /// operation. void next(); + /// Collects n doc ids while advancing the iterator. The iterator may become invalid during this operation. + void get_n_ids(const uint32_t& n, std::vector& results); + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. 
void skip_to(uint32_t id); @@ -188,6 +186,4 @@ public: /// Performs AND with the contents of A and allocates a new array of results. /// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); - - bool can_get_ids(); }; diff --git a/include/index.h b/include/index.h index c65b84c5..4f689c33 100644 --- a/include/index.h +++ b/include/index.h @@ -560,6 +560,7 @@ public: static const int DROP_TOKENS_THRESHOLD = 1; // "_all_" is a special field that maps to all the ids in the index. + static constexpr const char* SEQ_IDS_FIELD = "_all_"; static constexpr const char* SEQ_IDS_FILTER = "_all_: 1"; Index() = delete; diff --git a/src/art.cpp b/src/art.cpp index d189d7eb..7a7d5d5b 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -991,7 +991,7 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, std::vector prev_leaf_ids; posting_t::merge({prev_leaf->values}, prev_leaf_ids); - if(filter_result_iterator.valid()) { + if(filter_result_iterator.is_valid) { prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), prev_token_doc_ids); } else { @@ -1692,6 +1692,7 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le // documents that contain the previous token and/or filter ids size_t allowed_doc_ids_len = 0; const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); + filter_result_iterator.reset(); for(auto node: nodes) { art_topk_iter(node, token_order, max_words, exact_leaf, diff --git a/src/filter.cpp b/src/filter.cpp index 95fbfefc..18348ed6 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -283,6 +283,9 @@ Option toFilter(const std::string expression, } } return Option(true); + } else if (field_name == Index::SEQ_IDS_FIELD) { + filter_exp = {field_name, {}, {}}; + return Option(true); } auto field_it = search_schema.find(field_name); diff --git 
a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 787ad578..c5098581 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -271,8 +271,12 @@ void filter_result_iterator_t::advance_string_filter_token_iterators() { for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { auto& filter_value_tokens = posting_list_iterators[i]; - if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == seq_id) { - for (auto& iter: filter_value_tokens) { + if (!filter_value_tokens[0].valid() || filter_value_tokens[0].id() != seq_id) { + continue; + } + + for (auto& iter: filter_value_tokens) { + if (iter.valid()) { iter.next(); } } @@ -362,10 +366,7 @@ void filter_result_iterator_t::next() { return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { + if (is_filter_result_initialized) { if (++result_index >= filter_result.count) { is_valid = false; return; @@ -380,15 +381,7 @@ void filter_result_iterator_t::next() { return; } - if (a_filter.field_name == "id") { - if (++result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { is_valid = false; @@ -397,16 +390,7 @@ void filter_result_iterator_t::next() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - result_index++; - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } else if (f.is_string()) { + if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (++seq_id < result_index) { return; @@ -443,6 +427,41 @@ void filter_result_iterator_t::next() { } } +void numeric_not_equals_filter(num_tree_t* const num_tree, + const 
int64_t value, + uint32_t*&& all_ids, + uint32_t&& all_ids_length, + uint32_t*& result_ids, + size_t& result_ids_len) { + uint32_t* to_exclude_ids = nullptr; + size_t to_exclude_ids_len = 0; + + num_tree->search(EQUALS, value, &to_exclude_ids, to_exclude_ids_len); + + result_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_length, to_exclude_ids, to_exclude_ids_len, &result_ids); + + delete[] all_ids; + delete[] to_exclude_ids; +} + +void apply_not_equals(uint32_t*&& all_ids, + uint32_t&& all_ids_length, + uint32_t*& result_ids, + uint32_t& result_ids_len) { + + uint32_t* to_include_ids = nullptr; + size_t to_include_ids_len = 0; + + to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_length, result_ids, + result_ids_len, &to_include_ids); + + delete[] all_ids; + delete[] result_ids; + + result_ids = to_include_ids; + result_ids_len = to_include_ids_len; +} + void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -487,6 +506,11 @@ void filter_result_iterator_t::init() { } seq_id = filter_result.docs[result_index]; + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + is_filter_result_initialized = true; return; } @@ -507,7 +531,22 @@ void filter_result_iterator_t::init() { filter_result.count = result_ids.size(); filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; + } else if (a_filter.field_name == Index::SEQ_IDS_FIELD) { + if (index->seq_ids->num_ids() == 0) { + is_valid = false; + return; + } + + filter_result.count = index->seq_ids->num_ids(); + filter_result.docs = index->seq_ids->uncompress(); + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } if (!index->field_is_indexed(a_filter.field_name)) { @@ -520,28 +559,40 @@ void 
filter_result_iterator_t::init() { if (f.is_integer()) { auto num_tree = index->numerical_index.at(a_filter.field_name); - // TODO: Handle not equals - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); + size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, - reinterpret_cast(filter_result.count)); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, value, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], value, - &filter_result.docs, reinterpret_cast(filter_result.count)); + num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); } + + filter_result.count = result_size; + } + + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); } if (filter_result.count == 0) { is_valid = false; return; } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } else if (f.is_float()) { auto num_tree = index->numerical_index.at(a_filter.field_name); @@ -550,22 +601,36 @@ void filter_result_iterator_t::init() { float value = (float)std::atof(filter_value.c_str()); int64_t float_int64 = Index::float_to_int64_t(value); + size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = 
a_filter.values[fi+1]; int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, - reinterpret_cast(filter_result.count)); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, float_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], float_int64, - &filter_result.docs, reinterpret_cast(filter_result.count)); + num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); } + + filter_result.count = result_size; + } + + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); } if (filter_result.count == 0) { is_valid = false; return; } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } else if (f.is_bool()) { auto num_tree = index->numerical_index.at(a_filter.field_name); @@ -573,16 +638,32 @@ void filter_result_iterator_t::init() { for (const std::string& filter_value : a_filter.values) { int64_t bool_int64 = (filter_value == "1") ? 
1 : 0; - num_tree->search(a_filter.comparators[value_index], bool_int64, - &filter_result.docs, reinterpret_cast(filter_result.count)); + size_t result_size = filter_result.count; + if (a_filter.comparators[value_index] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, bool_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[value_index], bool_int64, &filter_result.docs, result_size); + } + filter_result.count = result_size; value_index++; } + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); + } + if (filter_result.count == 0) { is_valid = false; return; } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); @@ -684,13 +765,13 @@ bool filter_result_iterator_t::valid() { } } - const filter a_filter = filter_node->filter_exp; - - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + if (is_filter_result_initialized) { is_valid = result_index < filter_result.count; return is_valid; } + const filter a_filter = filter_node->filter_exp; + if (!index->field_is_indexed(a_filter.field_name)) { is_valid = false; return is_valid; @@ -698,10 +779,7 @@ bool filter_result_iterator_t::valid() { field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - is_valid = result_index < filter_result.count; - return is_valid; - } else if (f.is_string()) { + if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { return seq_id < result_index; } @@ -741,10 +819,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if 
(is_referenced_filter) { + if (is_filter_result_initialized) { while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); if (result_index >= filter_result.count) { @@ -761,17 +836,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - if (a_filter.field_name == "id") { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); - - if (result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - return; - } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { is_valid = false; @@ -780,17 +845,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - while(result_index < filter_result.count && filter_result.docs[result_index] < id) { - result_index++; - } - - if (result_index >= filter_result.count) { - is_valid = false; - } - - return; - } else if (f.is_string()) { + if (f.is_string()) { if (filter_node->filter_exp.apply_not_equals) { if (id < seq_id) { return; @@ -897,7 +952,7 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { compact_posting_list_t* list = COMPACT_POSTING_PTR(obj); size_t i = 0; - while(i < list->length && valid()) { + while(i < list->length && is_valid) { size_t num_existing_offsets = list->id_offsets[i]; size_t existing_id = list->id_offsets[i + num_existing_offsets + 1]; @@ -916,7 +971,7 @@ bool filter_result_iterator_t::contains_atleast_one(const void *obj) { auto list = (posting_list_t*)(obj); posting_list_t::iterator_t it = list->new_iterator(); - while(it.valid() && valid()) { + while(it.valid() && is_valid) { uint32_t id = it.id(); if(id == seq_id) { @@ -943,6 +998,7 @@ void filter_result_iterator_t::reset() { // Reset the subtrees then apply operators to arrive at the first valid doc. 
left_it->reset(); right_it->reset(); + is_valid = true; if (filter_node->filter_operator == AND) { and_filter_iterators(); @@ -953,10 +1009,7 @@ void filter_result_iterator_t::reset() { return; } - const filter a_filter = filter_node->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter || a_filter.field_name == "id") { + if (is_filter_result_initialized) { if (filter_result.count == 0) { is_valid = false; return; @@ -964,27 +1017,25 @@ void filter_result_iterator_t::reset() { result_index = 0; seq_id = filter_result.docs[result_index]; + + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + is_valid = true; return; } + const filter a_filter = filter_node->filter_exp; + if (!index->field_is_indexed(a_filter.field_name)) { return; } field f = index->search_schema.at(a_filter.field_name); - if (f.is_integer() || f.is_float() || f.is_bool()) { - if (filter_result.count == 0) { - is_valid = false; - return; - } - - result_index = 0; - seq_id = filter_result.docs[result_index]; - is_valid = true; - return; - } else if (f.is_string()) { + if (f.is_string()) { posting_list_iterators.clear(); for(auto expanded_plist: expanded_plists) { delete expanded_plist; @@ -997,11 +1048,11 @@ void filter_result_iterator_t::reset() { } uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { - if (!valid()) { + if (!is_valid) { return 0; } - if (can_get_ids()) { + if (is_filter_result_initialized) { filter_array = new uint32_t[filter_result.count]; std::copy(filter_result.docs, filter_result.docs + filter_result.count, filter_array); return filter_result.count; @@ -1011,7 +1062,7 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { do { filter_ids.push_back(seq_id); next(); - } while (valid()); + } while (is_valid); filter_array = new uint32_t[filter_ids.size()]; 
std::copy(filter_ids.begin(), filter_ids.end(), filter_array); @@ -1020,11 +1071,11 @@ uint32_t filter_result_iterator_t::to_filter_id_array(uint32_t*& filter_array) { } uint32_t filter_result_iterator_t::and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results) { - if (!valid()) { + if (!is_valid) { return 0; } - if (can_get_ids()) { + if (is_filter_result_initialized) { return ArrayUtils::and_scalar(A, lenA, filter_result.docs, filter_result.count, &results); } @@ -1115,20 +1166,23 @@ filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iter seq_id = obj.seq_id; reference = std::move(obj.reference); status = std::move(obj.status); + is_filter_result_initialized = obj.is_filter_result_initialized; return *this; } -bool filter_result_iterator_t::can_get_ids() { - if (!filter_node->isOperator) { - const filter a_filter = filter_node->filter_exp; - field f = index->search_schema.at(a_filter.field_name); - - if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id" || - (index->field_is_indexed(a_filter.field_name) && (f.is_integer() || f.is_float() || f.is_bool()))) { - return true; +void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { + if (is_filter_result_initialized) { + for (uint32_t count = 0; count < n && result_index < filter_result.count; count++) { + results.push_back(filter_result.docs[result_index++]); } + + is_valid = result_index < filter_result.count; + return; } - return false; + for (uint32_t count = 0; count < n && is_valid; count++) { + results.push_back(seq_id); + next(); + } } diff --git a/src/index.cpp b/src/index.cpp index 0763e822..2c03d03b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2742,7 +2742,7 @@ Option Index::search(std::vector& field_query_tokens, cons return filter_init_op; } - if (filter_tree_root != nullptr && !filter_result_iterator.valid()) { + if (filter_tree_root != nullptr && !filter_result_iterator.is_valid) { return Option(true); } 
@@ -2807,7 +2807,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.valid()); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.is_valid); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id && @@ -2856,11 +2856,9 @@ Option Index::search(std::vector& field_query_tokens, cons store, doc_id_prefix, filter_tree_root); filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + approx_filter_ids_length = filter_result_iterator.is_valid; } -// TODO: Curate ids at last -// curate_filtered_ids(curated_ids, excluded_result_ids, -// excluded_result_ids_size, filter_result.docs, filter_result.count, curated_ids_sorted); collate_included_ids({}, included_ids_map, curated_topster, searched_queries); if (!vector_query.field_name.empty()) { @@ -2876,8 +2874,7 @@ Option Index::search(std::vector& field_query_tokens, cons uint32_t filter_id_count = 0; while (!no_filters_provided && - filter_id_count < vector_query.flat_search_cutoff && - filter_result_iterator.valid()) { + filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { auto seq_id = filter_result_iterator.seq_id; std::vector values; @@ -2905,7 +2902,7 @@ Option Index::search(std::vector& field_query_tokens, cons } if(no_filters_provided || - (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.valid())) { + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.is_valid)) { dist_labels.clear(); VectorFilterFunctor 
filterFunctor(&filter_result_iterator); @@ -2974,7 +2971,19 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); + filter_result_iterator.reset(); } + + // filter tree was initialized to have all sequence ids in this flow. + if (no_filters_provided) { + delete filter_tree_root; + filter_tree_root = nullptr; + } + + uint32_t _all_result_ids_len = all_result_ids_len; + curate_filtered_ids(curated_ids, excluded_result_ids, + excluded_result_ids_size, all_result_ids, _all_result_ids_len, curated_ids_sorted); + all_result_ids_len = _all_result_ids_len; } else { // Non-wildcard // In multi-field searches, a record can be matched across different fields, so we use this for aggregation @@ -3411,7 +3420,7 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_result_iterator.valid() && filter_curated_hits) { + if(filter_result_iterator.is_valid && filter_curated_hits) { for (const auto &included_id: included_ids_vec) { auto result = filter_result_iterator.valid(included_id); @@ -3680,6 +3689,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, false, "", filter_result_iterator, field_leaves, unique_tokens); + filter_result_iterator.reset(); if(field_leaves.empty()) { // look at the next field @@ -4646,7 +4656,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector batch_result_ids; batch_result_ids.reserve(window_size); - if (filter_result_iterator.can_get_ids()) { - while (batch_result_ids.size() < window_size && filter_index < 
filter_result_iterator.get_length()) { - batch_result_ids.push_back(filter_result_iterator.get_ids()[filter_index++]); - } - } else { - do { - batch_result_ids.push_back(filter_result_iterator.seq_id); - filter_result_iterator.next(); - } while (batch_result_ids.size() < window_size && filter_result_iterator.valid()); - } + filter_result_iterator.get_n_ids(window_size, batch_result_ids); num_queued++; diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 6cab88b5..86cdd0f5 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -65,7 +65,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_null_filter_tree_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_null_filter_tree_test.init_status().ok()); - ASSERT_FALSE(iter_null_filter_tree_test.valid()); + ASSERT_FALSE(iter_null_filter_tree_test.is_valid); Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, filter_tree_root); @@ -74,7 +74,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_no_match_test.init_status().ok()); - ASSERT_FALSE(iter_no_match_test.valid()); + ASSERT_FALSE(iter_no_match_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -85,7 +85,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_no_match_multi_test.init_status().ok()); - ASSERT_FALSE(iter_no_match_multi_test.valid()); + ASSERT_FALSE(iter_no_match_multi_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -97,11 +97,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_contains_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_contains_test.valid()); + ASSERT_TRUE(iter_contains_test.is_valid); 
ASSERT_EQ(i, iter_contains_test.seq_id); iter_contains_test.next(); } - ASSERT_FALSE(iter_contains_test.valid()); + ASSERT_FALSE(iter_contains_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -113,11 +113,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_contains_multi_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_contains_multi_test.valid()); + ASSERT_TRUE(iter_contains_multi_test.is_valid); ASSERT_EQ(i, iter_contains_multi_test.seq_id); iter_contains_multi_test.next(); } - ASSERT_FALSE(iter_contains_multi_test.valid()); + ASSERT_FALSE(iter_contains_multi_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -129,11 +129,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_exact_match_1_test.init_status().ok()); for (uint32_t i = 0; i < 5; i++) { - ASSERT_TRUE(iter_exact_match_1_test.valid()); + ASSERT_TRUE(iter_exact_match_1_test.is_valid); ASSERT_EQ(i, iter_exact_match_1_test.seq_id); iter_exact_match_1_test.next(); } - ASSERT_FALSE(iter_exact_match_1_test.valid()); + ASSERT_FALSE(iter_exact_match_1_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -143,7 +143,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_exact_match_2_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_exact_match_2_test.init_status().ok()); - ASSERT_FALSE(iter_exact_match_2_test.valid()); + ASSERT_FALSE(iter_exact_match_2_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -156,11 +156,11 @@ TEST_F(FilterTest, FilterTreeIterator) { std::vector expected = {0, 2, 3, 4}; for (auto const& i : expected) { - ASSERT_TRUE(iter_exact_match_multi_test.valid()); + ASSERT_TRUE(iter_exact_match_multi_test.is_valid); ASSERT_EQ(i, iter_exact_match_multi_test.seq_id); iter_exact_match_multi_test.next(); } - ASSERT_FALSE(iter_exact_match_multi_test.valid()); + 
ASSERT_FALSE(iter_exact_match_multi_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -173,12 +173,12 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {1, 3}; for (auto const& i : expected) { - ASSERT_TRUE(iter_not_equals_test.valid()); + ASSERT_TRUE(iter_not_equals_test.is_valid); ASSERT_EQ(i, iter_not_equals_test.seq_id); iter_not_equals_test.next(); } - ASSERT_FALSE(iter_not_equals_test.valid()); + ASSERT_FALSE(iter_not_equals_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -189,13 +189,13 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_skip_test.init_status().ok()); - ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_TRUE(iter_skip_test.is_valid); iter_skip_test.skip_to(3); - ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_TRUE(iter_skip_test.is_valid); ASSERT_EQ(4, iter_skip_test.seq_id); iter_skip_test.next(); - ASSERT_FALSE(iter_skip_test.valid()); + ASSERT_FALSE(iter_skip_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -206,11 +206,11 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_and_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_and_test.init_status().ok()); - ASSERT_TRUE(iter_and_test.valid()); + ASSERT_TRUE(iter_and_test.is_valid); ASSERT_EQ(1, iter_and_test.seq_id); iter_and_test.next(); - ASSERT_FALSE(iter_and_test.valid()); + ASSERT_FALSE(iter_and_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -234,12 +234,12 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {2, 4, 5}; for (auto const& i : expected) { - ASSERT_TRUE(iter_or_test.valid()); + ASSERT_TRUE(iter_or_test.is_valid); ASSERT_EQ(i, iter_or_test.seq_id); iter_or_test.next(); } - ASSERT_FALSE(iter_or_test.valid()); + ASSERT_FALSE(iter_or_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -250,17 
+250,17 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_skip_complex_filter_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_skip_complex_filter_test.init_status().ok()); - ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_skip_complex_filter_test.is_valid); iter_skip_complex_filter_test.skip_to(4); expected = {4, 5}; for (auto const& i : expected) { - ASSERT_TRUE(iter_skip_complex_filter_test.valid()); + ASSERT_TRUE(iter_skip_complex_filter_test.is_valid); ASSERT_EQ(i, iter_skip_complex_filter_test.seq_id); iter_skip_complex_filter_test.next(); } - ASSERT_FALSE(iter_skip_complex_filter_test.valid()); + ASSERT_FALSE(iter_skip_complex_filter_test.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -358,20 +358,20 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {0, 2, 3, 4}; for (auto const& i : expected) { - ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_TRUE(iter_reset_test.is_valid); ASSERT_EQ(i, iter_reset_test.seq_id); iter_reset_test.next(); } - ASSERT_FALSE(iter_reset_test.valid()); + ASSERT_FALSE(iter_reset_test.is_valid); iter_reset_test.reset(); for (auto const& i : expected) { - ASSERT_TRUE(iter_reset_test.valid()); + ASSERT_TRUE(iter_reset_test.is_valid); ASSERT_EQ(i, iter_reset_test.seq_id); iter_reset_test.next(); } - ASSERT_FALSE(iter_reset_test.valid()); + ASSERT_FALSE(iter_reset_test.is_valid); auto iter_move_assignment_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); @@ -380,11 +380,11 @@ TEST_F(FilterTest, FilterTreeIterator) { expected = {0, 2, 3, 4}; for (auto const& i : expected) { - ASSERT_TRUE(iter_move_assignment_test.valid()); + ASSERT_TRUE(iter_move_assignment_test.is_valid); ASSERT_EQ(i, iter_move_assignment_test.seq_id); iter_move_assignment_test.next(); } - ASSERT_FALSE(iter_move_assignment_test.valid()); + ASSERT_FALSE(iter_move_assignment_test.is_valid); delete filter_tree_root; 
filter_tree_root = nullptr; @@ -405,7 +405,7 @@ TEST_F(FilterTest, FilterTreeIterator) { for (uint32_t i = 0; i < filter_ids_length; i++) { ASSERT_EQ(expected[i], filter_ids[i]); } - ASSERT_FALSE(iter_to_array_test.valid()); + ASSERT_FALSE(iter_to_array_test.is_valid); delete filter_ids; @@ -422,7 +422,7 @@ TEST_F(FilterTest, FilterTreeIterator) { for (uint32_t i = 0; i < and_result_length; i++) { ASSERT_EQ(expected[i], and_result[i]); } - ASSERT_FALSE(iter_and_scalar_test.valid()); + ASSERT_FALSE(iter_and_scalar_test.is_valid); delete and_result; delete filter_tree_root; From 66e830c591c10ac21cdd1b01f521823124046f65 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 21 Apr 2023 09:48:19 +0530 Subject: [PATCH 29/93] Add tests. --- src/filter_result_iterator.cpp | 23 +++++++++-- test/filter_test.cpp | 70 ++++++++++++++++++++++++++++++---- 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index c5098581..2e144af8 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -858,11 +858,15 @@ void filter_result_iterator_t::skip_to(uint32_t id) { seq_id = result_index; uint32_t previous_match; + + // Keep ignoring the found gaps till they cannot contain id. do { - previous_match = seq_id; - advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); - } while (is_valid && previous_match + 1 == seq_id && seq_id >= id); + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + doc_matching_string_filter(f.is_array()); + } while (is_valid && previous_match + 1 == seq_id); + } while (is_valid && seq_id <= id); if (!is_valid) { // filter matched all the ids in the index. So for not equals, there's no match. @@ -873,11 +877,22 @@ void filter_result_iterator_t::skip_to(uint32_t id) { is_valid = true; seq_id = previous_match + 1; result_index = index->seq_ids->last_id() + 1; + + // Skip to id, if possible. 
+ if (seq_id < id && id < result_index) { + seq_id = id; + } + return; } result_index = seq_id; seq_id = previous_match + 1; + + if (seq_id < id && id < result_index) { + seq_id = id; + } + return; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 86cdd0f5..50e7d6d6 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -186,16 +186,29 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); - ASSERT_TRUE(iter_skip_test.init_status().ok()); + auto iter_skip_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test1.init_status().ok()); - ASSERT_TRUE(iter_skip_test.is_valid); - iter_skip_test.skip_to(3); - ASSERT_TRUE(iter_skip_test.is_valid); - ASSERT_EQ(4, iter_skip_test.seq_id); - iter_skip_test.next(); + ASSERT_TRUE(iter_skip_test1.is_valid); + iter_skip_test1.skip_to(3); + ASSERT_TRUE(iter_skip_test1.is_valid); + ASSERT_EQ(4, iter_skip_test1.seq_id); + iter_skip_test1.next(); - ASSERT_FALSE(iter_skip_test.is_valid); + ASSERT_FALSE(iter_skip_test1.is_valid); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: != silver", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test2.init_status().ok()); + + ASSERT_TRUE(iter_skip_test2.is_valid); + iter_skip_test2.skip_to(3); + ASSERT_FALSE(iter_skip_test2.is_valid); delete filter_tree_root; filter_tree_root = nullptr; @@ -426,4 +439,45 @@ TEST_F(FilterTest, FilterTreeIterator) { delete and_result; delete filter_tree_root; + + doc = R"({ + "name": "James Rowdy", + "age": 36, + "years": [2005, 2022], + "rating": 6.03, + "tags": ["FINE PLATINUM"] + })"_json; + add_op = 
coll->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: != FINE PLATINUM", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test3 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test3.init_status().ok()); + + ASSERT_TRUE(iter_skip_test3.is_valid); + iter_skip_test3.skip_to(4); + ASSERT_EQ(4, iter_skip_test3.seq_id); + + ASSERT_TRUE(iter_skip_test3.is_valid); + + delete filter_tree_root; + + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: != gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test4 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_skip_test4.init_status().ok()); + + ASSERT_TRUE(iter_skip_test4.is_valid); + iter_skip_test4.skip_to(6); + ASSERT_EQ(6, iter_skip_test4.seq_id); + ASSERT_TRUE(iter_skip_test4.is_valid); + + delete filter_tree_root; } \ No newline at end of file From 4d21853f627364558d8b5cbd7ec64b8651f0d4c9 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 21 Apr 2023 12:48:46 +0530 Subject: [PATCH 30/93] Support geo filtering. 
--- src/filter_result_iterator.cpp | 143 ++++++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 2 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 2e144af8..1c7cbc7a 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1,5 +1,13 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include #include "filter_result_iterator.h" #include "index.h" #include "posting.h" @@ -306,8 +314,12 @@ void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { break; } else { // Keep advancing token iterators till exact match is not found. - for (auto &item: filter_value_tokens) { - item.next(); + for (auto &iter: filter_value_tokens) { + if (!iter.valid()) { + break; + } + + iter.next(); } } } @@ -661,6 +673,133 @@ void filter_result_iterator_t::init() { return; } + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + return; + } else if (f.is_geopoint()) { + for (const std::string& filter_value : a_filter.values) { + std::vector geo_result_ids; + + std::vector filter_value_parts; + StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points + + bool is_polygon = StringUtils::is_float(filter_value_parts.back()); + S2Region* query_region; + + if (is_polygon) { + const int num_verts = int(filter_value_parts.size()) / 2; + std::vector vertices; + double sum = 0.0; + + for (size_t point_index = 0; point_index < size_t(num_verts); + point_index++) { + double lat = std::stod(filter_value_parts[point_index * 2]); + double lon = std::stod(filter_value_parts[point_index * 2 + 1]); + S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint(); + vertices.emplace_back(vertex); + } + + auto loop = new S2Loop(vertices, S2Debug::DISABLE); + loop->Normalize(); // if loop is not CCW but CW, change to CCW. 
+ + S2Error error; + if (loop->FindValidationError(&error)) { + LOG(ERROR) << "Query vertex is bad, skipping. Error: " << error; + delete loop; + continue; + } else { + query_region = loop; + } + } else { + double radius = std::stof(filter_value_parts[2]); + const auto& unit = filter_value_parts[3]; + + if (unit == "km") { + radius *= 1000; + } else { + // assume "mi" (validated upstream) + radius *= 1609.34; + } + + S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); + double query_lat = std::stod(filter_value_parts[0]); + double query_lng = std::stod(filter_value_parts[1]); + S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); + query_region = new S2Cap(center, query_radius); + } + + S2RegionTermIndexer::Options options; + options.set_index_contains_points_only(true); + S2RegionTermIndexer indexer(options); + + for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { + auto geo_index = index->geopoint_index.at(a_filter.field_name); + const auto& ids_it = geo_index->find(term); + if(ids_it != geo_index->end()) { + geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); + } + } + + gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); + geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); + + // `geo_result_ids` will contain all IDs that are within approximately within query radius + // we still need to do another round of exact filtering on them + + std::vector exact_geo_result_ids; + + if (f.is_single_geopoint()) { + spp::sparse_hash_map* sort_field_index = index->sort_index.at(f.name); + + for (auto result_id : geo_result_ids) { + // no need to check for existence of `result_id` because of indexer based pre-filtering above + int64_t lat_lng = sort_field_index->at(result_id); + S2LatLng s2_lat_lng; + GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); + if (query_region->Contains(s2_lat_lng.ToPoint())) { + 
exact_geo_result_ids.push_back(result_id); + } + } + } else { + spp::sparse_hash_map* geo_field_index = index->geo_array_index.at(f.name); + + for (auto result_id : geo_result_ids) { + int64_t* lat_lngs = geo_field_index->at(result_id); + + bool point_found = false; + + // any one point should exist + for (size_t li = 0; li < lat_lngs[0]; li++) { + int64_t lat_lng = lat_lngs[li + 1]; + S2LatLng s2_lat_lng; + GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); + if (query_region->Contains(s2_lat_lng.ToPoint())) { + point_found = true; + break; + } + } + + if (point_found) { + exact_geo_result_ids.push_back(result_id); + } + } + } + + uint32_t* out = nullptr; + filter_result.count = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(), + filter_result.docs, filter_result.count, &out); + + delete[] filter_result.docs; + filter_result.docs = out; + + delete query_region; + } + + if (filter_result.count == 0) { + is_valid = false; + return; + } + seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; return; From 26f50d517859d82cc296899fb088e43987b494c8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 21 Apr 2023 14:31:23 +0530 Subject: [PATCH 31/93] Fix failing tests. 
--- src/filter_result_iterator.cpp | 4 +++- src/index.cpp | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 1c7cbc7a..d0037512 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -553,7 +553,7 @@ void filter_result_iterator_t::init() { return; } - filter_result.count = index->seq_ids->num_ids(); + approx_filter_ids_length = filter_result.count = index->seq_ids->num_ids(); filter_result.docs = index->seq_ids->uncompress(); seq_id = filter_result.docs[result_index]; @@ -1322,6 +1322,8 @@ filter_result_iterator_t &filter_result_iterator_t::operator=(filter_result_iter status = std::move(obj.status); is_filter_result_initialized = obj.is_filter_result_initialized; + approx_filter_ids_length = obj.approx_filter_ids_length; + return *this; } diff --git a/src/index.cpp b/src/index.cpp index 2c03d03b..c9df3288 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2856,7 +2856,7 @@ Option Index::search(std::vector& field_query_tokens, cons store, doc_id_prefix, filter_tree_root); filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); - approx_filter_ids_length = filter_result_iterator.is_valid; + approx_filter_ids_length = filter_result_iterator.approx_filter_ids_length; } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -2978,6 +2978,7 @@ Option Index::search(std::vector& field_query_tokens, cons if (no_filters_provided) { delete filter_tree_root; filter_tree_root = nullptr; + approx_filter_ids_length = 0; } uint32_t _all_result_ids_len = all_result_ids_len; From 92c0a837b14df1b9ac896be083c4e0d1fce79ba1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 24 Apr 2023 13:54:16 +0530 Subject: [PATCH 32/93] Handle excluded ids in `filter_result_iterator_t::get_n_ids`. 
--- include/filter_result_iterator.h | 6 ++++++ src/filter_result_iterator.cpp | 33 ++++++++++++++++++++++++++++++++ src/index.cpp | 12 +++++++++--- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 1184b74a..784aa385 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -169,6 +169,12 @@ public: /// Collects n doc ids while advancing the iterator. The iterator may become invalid during this operation. void get_n_ids(const uint32_t& n, std::vector& results); + /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The + /// iterator may become invalid during this operation. + void get_n_ids(const uint32_t &n, + uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, + std::vector &results); + /// Advances the iterator until the doc value reaches or just overshoots id. The iterator may become invalid during /// this operation. 
void skip_to(uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d0037512..4ea630cd 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1275,6 +1275,10 @@ filter_result_iterator_t::filter_result_iterator_t(const std::string collection_ } init(); + + if (!is_valid) { + this->approx_filter_ids_length = 0; + } } filter_result_iterator_t::~filter_result_iterator_t() { @@ -1342,3 +1346,32 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { + if (excluded_result_ids == nullptr || excluded_result_ids_size == 0) { + return get_n_ids(n, results); + } + + if (is_filter_result_initialized) { + for (uint32_t count = 0; count < n && result_index < filter_result.count;) { + auto id = filter_result.docs[result_index++]; + if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, id)) { + results.push_back(id); + count++; + } + } + + is_valid = result_index < filter_result.count; + return; + } + + for (uint32_t count = 0; count < n && is_valid;) { + if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, seq_id)) { + results.push_back(seq_id); + count++; + } + next(); + } +} diff --git a/src/index.cpp b/src/index.cpp index c9df3288..eeb4fe32 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1993,8 +1993,14 @@ void Index::aproximate_numerical_match(num_tree_t* const num_tree, uint32_t to_exclude_ids_len = 0; num_tree->approx_search_count(EQUALS, value, to_exclude_ids_len); - auto all_ids_size = seq_ids->num_ids(); - filter_ids_length += (all_ids_size - to_exclude_ids_len); + if (to_exclude_ids_len == 0) { + filter_ids_length += seq_ids->num_ids(); + } else if (to_exclude_ids_len >= seq_ids->num_ids()) { + filter_ids_length += 0; + } else { + filter_ids_length += (seq_ids->num_ids() - to_exclude_ids_len); + } + return; } @@ -4988,7 +4994,7 @@ void Index::search_wildcard(filter_node_t const* const& 
filter_tree_root, std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, batch_result_ids); + filter_result_iterator.get_n_ids(window_size, exclude_token_ids, exclude_token_ids_size, batch_result_ids); num_queued++; From 63119b0eb11007817ba59e5f302b8179b7285d26 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 24 Apr 2023 14:21:15 +0530 Subject: [PATCH 33/93] Optimize exclusion in `filter_result_iterator_t::get_n_ids`. --- include/filter_result_iterator.h | 1 + src/filter_result_iterator.cpp | 17 ++++++++++++++--- src/index.cpp | 4 +++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 784aa385..2e30a2cb 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -172,6 +172,7 @@ public: /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The /// iterator may become invalid during this operation. 
void get_n_ids(const uint32_t &n, + size_t& excluded_result_index, uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, std::vector &results); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 4ea630cd..9977c6be 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1348,16 +1348,23 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { - if (excluded_result_ids == nullptr || excluded_result_ids_size == 0) { + if (excluded_result_ids == nullptr || excluded_result_ids_size == 0 || + excluded_result_index >= excluded_result_ids_size) { return get_n_ids(n, results); } if (is_filter_result_initialized) { for (uint32_t count = 0; count < n && result_index < filter_result.count;) { auto id = filter_result.docs[result_index++]; - if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, id)) { + + while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < id) { + excluded_result_index++; + } + + if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != id) { results.push_back(id); count++; } @@ -1368,7 +1375,11 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } for (uint32_t count = 0; count < n && is_valid;) { - if (!std::binary_search(excluded_result_ids, excluded_result_ids + excluded_result_ids_size, seq_id)) { + while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < seq_id) { + excluded_result_index++; + } + + if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != seq_id) { results.push_back(seq_id); count++; } diff --git a/src/index.cpp b/src/index.cpp index eeb4fe32..986b31cf 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4989,12 +4989,14 @@ void Index::search_wildcard(filter_node_t const* const& 
filter_tree_root, const auto parent_search_begin = search_begin_us; const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; + size_t excluded_result_index = 0; for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, exclude_token_ids, exclude_token_ids_size, batch_result_ids); + filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, + batch_result_ids); num_queued++; From 158791e376cc73057ec37db4f02c249b5207ee95 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 09:39:51 +0530 Subject: [PATCH 34/93] Add `ArrayUtils::skip_index_to_id`. --- include/array_utils.h | 6 ++++++ include/filter_result_iterator.h | 2 +- src/array_utils.cpp | 25 +++++++++++++++++++++++++ src/filter_result_iterator.cpp | 16 ++++------------ src/index.cpp | 2 +- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/include/array_utils.h b/include/array_utils.h index 81a0576c..008a8a36 100644 --- a/include/array_utils.h +++ b/include/array_utils.h @@ -16,4 +16,10 @@ public: static size_t exclude_scalar(const uint32_t *src, const size_t lenSrc, const uint32_t *filter, const size_t lenFilter, uint32_t **out); + + /// Performs binary search to find the index of id. If id is not found, curr_index is set to the index of next bigger + /// number than id in the array. + /// \return Whether or not id was found in array. 
+ static bool skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, + const uint32_t& id); }; \ No newline at end of file diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 2e30a2cb..259b93f0 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -172,7 +172,7 @@ public: /// Collects n doc ids while advancing the iterator. The ids present in excluded_result_ids are ignored. The /// iterator may become invalid during this operation. void get_n_ids(const uint32_t &n, - size_t& excluded_result_index, + uint32_t& excluded_result_index, uint32_t const* const excluded_result_ids, const size_t& excluded_result_ids_size, std::vector &results); diff --git a/src/array_utils.cpp b/src/array_utils.cpp index 9f0c7f4a..ad22a85f 100644 --- a/src/array_utils.cpp +++ b/src/array_utils.cpp @@ -149,4 +149,29 @@ size_t ArrayUtils::exclude_scalar(const uint32_t *A, const size_t lenA, delete[] results; return res_index; +} + +bool ArrayUtils::skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, + const uint32_t& id) { + if (id <= array[curr_index]) { + return id == array[curr_index]; + } + + long start = curr_index, mid, end = array_len; + + while (start <= end) { + mid = start + (end - start) / 2; + + if (array[mid] == id) { + curr_index = mid; + return true; + } else if (array[mid] < id) { + start = mid + 1; + } else { + end = mid - 1; + } + } + + curr_index = start; + return false; } \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 9977c6be..63a76f73 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -959,7 +959,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } if (is_filter_result_initialized) { - while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + ArrayUtils::skip_index_to_id(result_index, 
filter_result.docs, filter_result.count, id); if (result_index >= filter_result.count) { is_valid = false; @@ -1348,7 +1348,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, std::vector& results) { if (excluded_result_ids == nullptr || excluded_result_ids_size == 0 || @@ -1360,11 +1360,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, for (uint32_t count = 0; count < n && result_index < filter_result.count;) { auto id = filter_result.docs[result_index++]; - while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < id) { - excluded_result_index++; - } - - if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != id) { + if (!ArrayUtils::skip_index_to_id(excluded_result_index, excluded_result_ids, excluded_result_ids_size, id)) { results.push_back(id); count++; } @@ -1375,11 +1371,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } for (uint32_t count = 0; count < n && is_valid;) { - while (excluded_result_index < excluded_result_ids_size && excluded_result_ids[excluded_result_index] < seq_id) { - excluded_result_index++; - } - - if (excluded_result_index >= excluded_result_ids_size || excluded_result_ids[excluded_result_index] != seq_id) { + if (!ArrayUtils::skip_index_to_id(excluded_result_index, excluded_result_ids, excluded_result_ids_size, seq_id)) { results.push_back(seq_id); count++; } diff --git a/src/index.cpp b/src/index.cpp index 986b31cf..694ab9bb 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4989,7 +4989,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const auto parent_search_begin = search_begin_us; const auto parent_search_stop_ms = search_stop_us; auto parent_search_cutoff = search_cutoff; - size_t excluded_result_index = 0; + uint32_t excluded_result_index = 0; for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { std::vector 
batch_result_ids; From 321b4034e5bfb5dd5ae02b65426aab8f37070b9a Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 14:27:50 +0530 Subject: [PATCH 35/93] Add tests for `ArrayUtils::skip_index_to_id`. --- src/array_utils.cpp | 6 +++++- test/array_utils_test.cpp | 42 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/array_utils.cpp b/src/array_utils.cpp index ad22a85f..e034c4d6 100644 --- a/src/array_utils.cpp +++ b/src/array_utils.cpp @@ -153,11 +153,15 @@ size_t ArrayUtils::exclude_scalar(const uint32_t *A, const size_t lenA, bool ArrayUtils::skip_index_to_id(uint32_t& curr_index, uint32_t const* const array, const uint32_t& array_len, const uint32_t& id) { + if (curr_index >= array_len) { + return false; + } + if (id <= array[curr_index]) { return id == array[curr_index]; } - long start = curr_index, mid, end = array_len; + long start = curr_index, mid, end = array_len - 1; while (start <= end) { mid = start + (end - start) / 2; diff --git a/test/array_utils_test.cpp b/test/array_utils_test.cpp index 0fa4622a..2a961296 100644 --- a/test/array_utils_test.cpp +++ b/test/array_utils_test.cpp @@ -172,4 +172,46 @@ TEST(SortedArrayTest, FilterArray) { delete[] arr2; delete[] arr1; delete[] results; +} + +TEST(SortedArrayTest, SkipToID) { + std::vector array; + for (uint32_t i = 0; i < 10; i++) { + array.push_back(i * 3); + } + + uint32_t index = 0; + bool found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 15); + ASSERT_TRUE(found); + ASSERT_EQ(5, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 3); + ASSERT_FALSE(found); + ASSERT_EQ(4, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 12); + ASSERT_TRUE(found); + ASSERT_EQ(4, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 24); + ASSERT_TRUE(found); + ASSERT_EQ(8, index); + + index = 
4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 25); + ASSERT_FALSE(found); + ASSERT_EQ(9, index); + + index = 4; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 30); + ASSERT_FALSE(found); + ASSERT_EQ(10, index); + + index = 12; + found = ArrayUtils::skip_index_to_id(index, array.data(), array.size(), 30); + ASSERT_FALSE(found); + ASSERT_EQ(12, index); } \ No newline at end of file From 173e6436df938afe8bab5c4fdfc1a85dbfef07a5 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 25 Apr 2023 16:54:12 +0530 Subject: [PATCH 36/93] Fix failing tests. --- include/index.h | 2 +- src/index.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/index.h b/include/index.h index 4f689c33..7f54e094 100644 --- a/include/index.h +++ b/include/index.h @@ -239,7 +239,7 @@ public: explicit VectorFilterFunctor(filter_result_iterator_t* const filter_result_iterator) : filter_result_iterator(filter_result_iterator) {} - bool operator()(unsigned int id) { + bool operator()(hnswlib::labeltype id) override { filter_result_iterator->reset(); return filter_result_iterator->valid(id) == 1; } diff --git a/src/index.cpp b/src/index.cpp index 694ab9bb..4fb91d9b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2882,6 +2882,7 @@ Option Index::search(std::vector& field_query_tokens, cons while (!no_filters_provided && filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { auto seq_id = filter_result_iterator.seq_id; + filter_result_iterator.next(); std::vector values; try { @@ -2903,7 +2904,6 @@ Option Index::search(std::vector& field_query_tokens, cons } dist_labels.emplace_back(dist, seq_id); - filter_result_iterator.next(); filter_id_count++; } From 6acc7d8557d6c1914fd6c4082a4818d92501ef4c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 26 Apr 2023 20:26:45 +0530 Subject: [PATCH 37/93] Fix memory leaks: * Handle deletion of `filter_tree_root` in 
`sort_fields_guard_t`. * Handle `filter_tree_root` being updated in `Index::static_filter_query_eval`. * Handle deletion of `phrase_result_ids` in `Index::search`. --- include/field.h | 2 +- include/index.h | 2 +- src/collection.cpp | 7 +++++++ src/index.cpp | 9 +++++++-- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/include/field.h b/include/field.h index e2f9729f..cbd62af2 100644 --- a/include/field.h +++ b/include/field.h @@ -555,7 +555,7 @@ struct sort_by { }; struct eval_t { - filter_node_t* filter_tree_root; + filter_node_t* filter_tree_root = nullptr; uint32_t* ids = nullptr; uint32_t size = 0; }; diff --git a/include/index.h b/include/index.h index 7f54e094..386490cd 100644 --- a/include/index.h +++ b/include/index.h @@ -642,7 +642,7 @@ public: Option search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t*& filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, diff --git a/src/collection.cpp b/src/collection.cpp index 21f2d187..f62a1753 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -28,6 +28,8 @@ struct sort_fields_guard_t { ~sort_fields_guard_t() { for(auto& sort_by_clause: sort_fields_std) { + delete sort_by_clause.eval.filter_tree_root; + if(sort_by_clause.eval.ids) { delete [] sort_by_clause.eval.ids; sort_by_clause.eval.ids = nullptr; @@ -1542,6 +1544,11 @@ Option Collection::search(std::string raw_query, std::unique_ptr search_params_guard(search_params); auto search_op = index->run_search(search_params, name); + + // filter_tree_root might be updated in Index::static_filter_query_eval. 
+ filter_tree_root_guard.release(); + filter_tree_root_guard.reset(filter_tree_root); + if (!search_op.ok()) { return Option(search_op.code(), search_op.error()); } diff --git a/src/index.cpp b/src/index.cpp index 4fb91d9b..f5c5ae40 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2348,7 +2348,7 @@ bool Index::static_filter_query_eval(const override_t* override, if (filter_tree_root == nullptr) { filter_tree_root = new_filter_tree_root; } else { - filter_node_t* root = new filter_node_t(AND, filter_tree_root, + auto root = new filter_node_t(AND, filter_tree_root, new_filter_tree_root); filter_tree_root = root; } @@ -2711,7 +2711,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name Option Index::search(std::vector& field_query_tokens, const std::vector& the_fields, const text_match_type_t match_type, - filter_node_t* filter_tree_root, std::vector& facets, facet_query_t& facet_query, + filter_node_t*& filter_tree_root, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, @@ -2797,6 +2797,8 @@ Option Index::search(std::vector& field_query_tokens, cons // handle phrase searches uint32_t* phrase_result_ids = nullptr; uint32_t phrase_result_count = 0; + std::unique_ptr phrase_result_ids_guard; + if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, @@ -2805,6 +2807,9 @@ Option Index::search(std::vector& field_query_tokens, cons excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, included_ids_map, is_wildcard_query, phrase_result_ids, phrase_result_count); + + phrase_result_ids_guard.reset(phrase_result_ids); + if (phrase_result_count == 0) { goto process_search_results; } From 5c3333058d82e03e214edeae5296f553294c7a31 
Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 26 Apr 2023 20:36:09 +0530 Subject: [PATCH 38/93] Remove `Index::do_filtering`. Using `filter_result_t` instead. --- include/index.h | 19 -- src/index.cpp | 517 +++--------------------------------------------- 2 files changed, 25 insertions(+), 511 deletions(-) diff --git a/include/index.h b/include/index.h index 386490cd..f2f8a4b8 100644 --- a/include/index.h +++ b/include/index.h @@ -472,31 +472,12 @@ private: bool field_is_indexed(const std::string& field_name) const; - Option do_filtering(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name = "", - const uint32_t& context_ids_length = 0, - uint32_t* const& context_ids = nullptr) const; - void aproximate_numerical_match(num_tree_t* const num_tree, const NUM_COMPARATOR& comparator, const int64_t& value, const int64_t& range_end_value, uint32_t& filter_ids_length) const; - /// Traverses through filter tree to get the filter_result. - /// - /// \param filter_tree_root - /// \param filter_result - /// \param collection_name Name of the collection to which current index belongs. Used to find the reference field in other collection. - /// \param context_ids_length Number of docs matching the search query. - /// \param context_ids Array of doc ids matching the search query. 
- Option recursive_filter(filter_node_t* const filter_tree_root, - filter_result_t& filter_result, - const std::string& collection_name = "", - const uint32_t& context_ids_length = 0, - uint32_t* const& context_ids = nullptr) const; - void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; diff --git a/src/index.cpp b/src/index.cpp index f5c5ae40..d1ca8688 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1539,446 +1539,6 @@ bool Index::field_is_indexed(const std::string& field_name) const { geopoint_index.count(field_name) != 0; } -Option Index::do_filtering(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name, - const uint32_t& context_ids_length, - uint32_t* const& context_ids) const { - // auto begin = std::chrono::high_resolution_clock::now(); - const filter a_filter = root->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. 
- auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(a_filter.referenced_collection_name); - if (collection == nullptr) { - return Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); - } - - filter_result_t reference_filter_result; - auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, - reference_filter_result, - collection_name); - if (!reference_filter_op.ok()) { - return Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name - + "` collection: " + reference_filter_op.error()); - } - - if (context_ids_length != 0) { - std::vector include_indexes; - include_indexes.reserve(std::min(context_ids_length, reference_filter_result.count)); - - size_t context_index = 0, reference_result_index = 0; - while (context_index < context_ids_length && reference_result_index < reference_filter_result.count) { - if (context_ids[context_index] == reference_filter_result.docs[reference_result_index]) { - include_indexes.push_back(reference_result_index); - context_index++; - reference_result_index++; - } else if (context_ids[context_index] < reference_filter_result.docs[reference_result_index]) { - context_index++; - } else { - reference_result_index++; - } - } - - result.count = include_indexes.size(); - result.docs = new uint32_t[include_indexes.size()]; - auto& result_references = result.reference_filter_results[a_filter.referenced_collection_name]; - result_references = new reference_filter_result_t[include_indexes.size()]; - - for (uint32_t i = 0; i < include_indexes.size(); i++) { - result.docs[i] = reference_filter_result.docs[include_indexes[i]]; - result_references[i] = reference_filter_result.reference_filter_results[a_filter.referenced_collection_name][include_indexes[i]]; - } - - return Option(true); - } - - result = std::move(reference_filter_result); - return Option(true); - } - - if (a_filter.field_name == "id") { - // we 
handle `ids` separately - std::vector result_ids; - for (const auto& id_str : a_filter.values) { - result_ids.push_back(std::stoul(id_str)); - } - - std::sort(result_ids.begin(), result_ids.end()); - - auto result_array = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), result_array); - - if (context_ids_length != 0) { - uint32_t* out = nullptr; - result.count = ArrayUtils::and_scalar(context_ids, context_ids_length, - result_array, result_ids.size(), &out); - - delete[] result_array; - - result.docs = out; - return Option(true); - } - - result.docs = result_array; - result.count = result_ids.size(); - return Option(true); - } - - if (!field_is_indexed(a_filter.field_name)) { - return Option(true); - } - - field f = search_schema.at(a_filter.field_name); - - uint32_t* result_ids = nullptr; - size_t result_ids_len = 0; - - if (f.is_integer()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - int64_t value = (int64_t)std::stol(filter_value); - - if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi + 1]; - auto const range_end_value = (int64_t)std::stol(next_filter_value); - - if (context_ids_length != 0) { - num_tree->range_inclusive_contains(value, range_end_value, context_ids_length, context_ids, - result_ids_len, result_ids); - } else { - num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len); - } - - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[fi], value, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[fi], 
value, &result_ids, result_ids_len); - } - } - } - } else if (f.is_float()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - float value = (float)std::atof(filter_value.c_str()); - int64_t float_int64 = float_to_int64_t(value); - - if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi+1]; - int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str())); - - if (context_ids_length != 0) { - num_tree->range_inclusive_contains(float_int64, range_end_value, context_ids_length, context_ids, - result_ids_len, result_ids); - } else { - num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len); - } - - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, float_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[fi], float_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len); - } - } - } - } else if (f.is_bool()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - size_t value_index = 0; - for (const std::string& filter_value : a_filter.values) { - int64_t bool_int64 = (filter_value == "1") ? 
1 : 0; - if (a_filter.comparators[value_index] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, bool_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[value_index], bool_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len); - } - } - - value_index++; - } - } else if (f.is_geopoint()) { - for (const std::string& filter_value : a_filter.values) { - std::vector geo_result_ids; - - std::vector filter_value_parts; - StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points - - bool is_polygon = StringUtils::is_float(filter_value_parts.back()); - S2Region* query_region; - - if (is_polygon) { - const int num_verts = int(filter_value_parts.size()) / 2; - std::vector vertices; - double sum = 0.0; - - for (size_t point_index = 0; point_index < size_t(num_verts); - point_index++) { - double lat = std::stod(filter_value_parts[point_index * 2]); - double lon = std::stod(filter_value_parts[point_index * 2 + 1]); - S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint(); - vertices.emplace_back(vertex); - } - - auto loop = new S2Loop(vertices, S2Debug::DISABLE); - loop->Normalize(); // if loop is not CCW but CW, change to CCW. - - S2Error error; - if (loop->FindValidationError(&error)) { - LOG(ERROR) << "Query vertex is bad, skipping. 
Error: " << error; - delete loop; - continue; - } else { - query_region = loop; - } - } else { - double radius = std::stof(filter_value_parts[2]); - const auto& unit = filter_value_parts[3]; - - if (unit == "km") { - radius *= 1000; - } else { - // assume "mi" (validated upstream) - radius *= 1609.34; - } - - S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); - double query_lat = std::stod(filter_value_parts[0]); - double query_lng = std::stod(filter_value_parts[1]); - S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); - query_region = new S2Cap(center, query_radius); - } - - S2RegionTermIndexer::Options options; - options.set_index_contains_points_only(true); - S2RegionTermIndexer indexer(options); - - for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { - auto geo_index = geopoint_index.at(a_filter.field_name); - const auto& ids_it = geo_index->find(term); - if(ids_it != geo_index->end()) { - geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); - } - } - - gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); - geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); - - // `geo_result_ids` will contain all IDs that are within approximately within query radius - // we still need to do another round of exact filtering on them - - if (context_ids_length != 0) { - uint32_t *out = nullptr; - uint32_t count = ArrayUtils::and_scalar(context_ids, context_ids_length, - &geo_result_ids[0], geo_result_ids.size(), &out); - - geo_result_ids = std::vector(out, out + count); - } - - std::vector exact_geo_result_ids; - - if (f.is_single_geopoint()) { - spp::sparse_hash_map* sort_field_index = sort_index.at(f.name); - - for (auto result_id : geo_result_ids) { - // no need to check for existence of `result_id` because of indexer based pre-filtering above - int64_t lat_lng = sort_field_index->at(result_id); - S2LatLng s2_lat_lng; - 
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); - if (query_region->Contains(s2_lat_lng.ToPoint())) { - exact_geo_result_ids.push_back(result_id); - } - } - } else { - spp::sparse_hash_map* geo_field_index = geo_array_index.at(f.name); - - for (auto result_id : geo_result_ids) { - int64_t* lat_lngs = geo_field_index->at(result_id); - - bool point_found = false; - - // any one point should exist - for (size_t li = 0; li < lat_lngs[0]; li++) { - int64_t lat_lng = lat_lngs[li + 1]; - S2LatLng s2_lat_lng; - GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); - if (query_region->Contains(s2_lat_lng.ToPoint())) { - point_found = true; - break; - } - } - - if (point_found) { - exact_geo_result_ids.push_back(result_id); - } - } - } - - uint32_t* out = nullptr; - result_ids_len = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(), - result_ids, result_ids_len, &out); - - delete[] result_ids; - result_ids = out; - - delete query_region; - } - } else if (f.is_string()) { - art_tree* t = search_index.at(a_filter.field_name); - - uint32_t* or_ids = nullptr; - size_t or_ids_size = 0; - - // aggregates IDs across array of filter values and reduces excessive ORing - std::vector f_id_buff; - - for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; - - // there could be multiple tokens in a filter value, which we have to treat as ANDs - // e.g. 
country: South Africa - Tokenizer tokenizer(filter_value, true, false, f.locale, symbols_to_index, token_separators); - - std::string str_token; - size_t token_index = 0; - std::vector str_tokens; - - while (tokenizer.next(str_token, token_index)) { - str_tokens.push_back(str_token); - - art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), - str_token.length()+1); - if (leaf == nullptr) { - continue; - } - - posting_lists.push_back(leaf->values); - } - - if (posting_lists.size() != str_tokens.size()) { - continue; - } - - if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) { - // needs intersection + exact matching (unlike CONTAINS) - std::vector result_id_vec; - posting_t::intersect(posting_lists, result_id_vec, context_ids_length, context_ids); - - if (result_id_vec.empty()) { - continue; - } - - // need to do exact match - uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()]; - size_t exact_str_ids_size = 0; - std::unique_ptr exact_str_ids_guard(exact_str_ids); - - posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(), - exact_str_ids, exact_str_ids_size); - - if (exact_str_ids_size == 0) { - continue; - } - - for (size_t ei = 0; ei < exact_str_ids_size; ei++) { - f_id_buff.push_back(exact_str_ids[ei]); - } - } else { - // CONTAINS - size_t before_size = f_id_buff.size(); - posting_t::intersect(posting_lists, f_id_buff, context_ids_length, context_ids); - if (f_id_buff.size() == before_size) { - continue; - } - } - - if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) { - gfx::timsort(f_id_buff.begin(), f_id_buff.end()); - f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); - - uint32_t* out = nullptr; - or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); - delete[] or_ids; - or_ids = out; - std::vector().swap(f_id_buff); // clears out memory - } - } - - if 
(!f_id_buff.empty()) { - gfx::timsort(f_id_buff.begin(), f_id_buff.end()); - f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); - - uint32_t* out = nullptr; - or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); - delete[] or_ids; - or_ids = out; - std::vector().swap(f_id_buff); // clears out memory - } - - result_ids = or_ids; - result_ids_len = or_ids_size; - } - - if (a_filter.apply_not_equals) { - auto all_ids = seq_ids->uncompress(); - auto all_ids_size = seq_ids->num_ids(); - - uint32_t* to_include_ids = nullptr; - size_t to_include_ids_len = 0; - - to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, result_ids, - result_ids_len, &to_include_ids); - - delete[] all_ids; - delete[] result_ids; - - result_ids = to_include_ids; - result_ids_len = to_include_ids_len; - - if (context_ids_length != 0) { - uint32_t *out = nullptr; - result.count = ArrayUtils::and_scalar(context_ids, context_ids_length, - result_ids, result_ids_len, &out); - - delete[] result_ids; - - result.docs = out; - return Option(true); - } - } - - result.docs = result_ids; - result.count = result_ids_len; - - return Option(true); - /*long long int timeMillis = - std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - - begin).count(); - - LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ -} - void Index::aproximate_numerical_match(num_tree_t* const num_tree, const NUM_COMPARATOR& comparator, const int64_t& value, @@ -2149,54 +1709,19 @@ Option Index::rearrange_filter_tree(filter_node_t* const root, return Option(true); } -Option Index::recursive_filter(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name, - const uint32_t& context_ids_length, - uint32_t* const& context_ids) const { - if (root == nullptr) { - return Option(true); - } - - if (root->isOperator) { - filter_result_t l_result; - if (root->left != nullptr) { - auto 
filter_op = recursive_filter(root->left, l_result , collection_name, context_ids_length, context_ids); - if (!filter_op.ok()) { - return filter_op; - } - } - - filter_result_t r_result; - if (root->right != nullptr) { - auto filter_op = recursive_filter(root->right, r_result , collection_name, context_ids_length, context_ids); - if (!filter_op.ok()) { - return filter_op; - } - } - - if (root->filter_operator == AND) { - filter_result_t::and_filter_results(l_result, r_result, result); - } else { - filter_result_t::or_filter_results(l_result, r_result, result); - } - - return Option(true); - } - - return do_filtering(root, result, collection_name, context_ids_length, context_ids); -} - Option Index::do_filtering_with_lock(filter_node_t* const filter_tree_root, filter_result_t& filter_result, const std::string& collection_name) const { std::shared_lock lock(mutex); - auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } + filter_result.count = filter_result_iterator.to_filter_id_array(filter_result.docs); + return Option(true); } @@ -2206,16 +1731,20 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); - filter_result_t reference_filter_result; - auto filter_op = recursive_filter(filter_tree_root, reference_filter_result); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } + uint32_t* reference_docs = nullptr; + uint32_t count = 
filter_result_iterator.to_filter_id_array(reference_docs); + std::unique_ptr docs_guard(reference_docs); + // doc id -> reference doc ids std::map> reference_map; - for (uint32_t i = 0; i < reference_filter_result.count; i++) { - auto reference_doc_id = reference_filter_result.docs[i]; + for (uint32_t i = 0; i < count; i++) { + auto reference_doc_id = reference_docs[i]; auto doc_id = sort_index.at(reference_helper_field_name)->at(reference_doc_id); reference_map[doc_id].push_back(reference_doc_id); @@ -5102,11 +4631,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - filter_result_t result; - recursive_filter(sort_fields_std[i].eval.filter_tree_root, result); - sort_fields_std[i].eval.ids = result.docs; - sort_fields_std[i].eval.size = result.count; - result.docs = nullptr; + + auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return; + } + + sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids); + } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); From 9aa226a461be9054ff169f42441d58670d85cf55 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 27 Apr 2023 14:40:53 +0530 Subject: [PATCH 39/93] Refactor string filter iteration. 
--- include/filter_result_iterator.h | 10 +- src/filter_result_iterator.cpp | 171 ++++++++++++------------------- 2 files changed, 69 insertions(+), 112 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 259b93f0..bc8c4c23 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -85,7 +85,6 @@ struct filter_result_t { static void or_filter_results(const filter_result_t& a, const filter_result_t& b, filter_result_t& result); }; - class filter_result_iterator_t { private: std::string collection_name; @@ -106,6 +105,7 @@ private: /// for each token. /// /// Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_lists; std::vector> posting_list_iterators; std::vector expanded_plists; @@ -121,11 +121,11 @@ private: /// Advance all the token iterators that are at seq_id. void advance_string_filter_token_iterators(); - /// Finds the next match for a filter on string field. - void doc_matching_string_filter(bool field_is_array); + /// Finds the first match for a filter on string field. + void get_string_filter_first_match(const bool& field_is_array); - /// Returns true when doc and reference hold valid values. Used in conjunction with next() and skip_to(id). - [[nodiscard]] bool valid(); + /// Finds the next match for a filter on string field. + void get_string_filter_next_match(const bool& field_is_array); public: uint32_t seq_id = 0; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 63a76f73..34228916 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -291,7 +291,7 @@ void filter_result_iterator_t::advance_string_filter_token_iterators() { } } -void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) { +void filter_result_iterator_t::get_string_filter_next_match(const bool& field_is_array) { // If none of the filter value iterators are valid, mark this node as invalid. 
bool one_is_valid = false; @@ -412,7 +412,7 @@ void filter_result_iterator_t::next() { do { previous_match = seq_id; advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); } while (is_valid && previous_match + 1 == seq_id); if (!is_valid) { @@ -433,7 +433,7 @@ void filter_result_iterator_t::next() { } advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); return; } @@ -474,6 +474,50 @@ void apply_not_equals(uint32_t*&& all_ids, result_ids_len = to_include_ids_len; } +void filter_result_iterator_t::get_string_filter_first_match(const bool& field_is_array) { + get_string_filter_next_match(field_is_array); + + if (filter_node->filter_exp.apply_not_equals) { + // filter didn't match any id. So by applying not equals, every id in the index is a match. + if (!is_valid) { + is_valid = true; + seq_id = 0; + result_index = index->seq_ids->last_id() + 1; + return; + } + + // [0, seq_id) are a match for not equals. + if (seq_id > 0) { + result_index = seq_id; + seq_id = 0; + return; + } + + // Keep ignoring the consecutive matches. + uint32_t previous_match; + do { + previous_match = seq_id; + advance_string_filter_token_iterators(); + get_string_filter_next_match(field_is_array); + } while (is_valid && previous_match + 1 == seq_id); + + if (!is_valid) { + // filter matched all the ids in the index. So for not equals, there's no match. 
+ if (previous_match >= index->seq_ids->last_id()) { + return; + } + + is_valid = true; + result_index = index->seq_ids->last_id() + 1; + seq_id = previous_match + 1; + return; + } + + result_index = seq_id; + seq_id = previous_match + 1; + } +} + void filter_result_iterator_t::init() { if (filter_node == nullptr) { return; @@ -807,7 +851,7 @@ void filter_result_iterator_t::init() { art_tree* t = index->search_index.at(a_filter.field_name); for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; + std::vector raw_posting_lists; // there could be multiple tokens in a filter value, which we have to treat as ANDs // e.g. country: South Africa @@ -826,119 +870,29 @@ void filter_result_iterator_t::init() { continue; } - posting_lists.push_back(leaf->values); + raw_posting_lists.push_back(leaf->values); } - if (posting_lists.size() != str_tokens.size()) { + if (raw_posting_lists.size() != str_tokens.size()) { continue; } std::vector plists; - posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + posting_t::to_expanded_plists(raw_posting_lists, plists, expanded_plists); + posting_lists.push_back(plists); posting_list_iterators.emplace_back(std::vector()); - for (auto const& plist: plists) { posting_list_iterators.back().push_back(plist->new_iterator()); } } - doc_matching_string_filter(f.is_array()); - - if (filter_node->filter_exp.apply_not_equals) { - // filter didn't match any id. So by applying not equals, every id in the index is a match. - if (!is_valid) { - is_valid = true; - seq_id = 0; - result_index = index->seq_ids->last_id() + 1; - return; - } - - // [0, seq_id) are a match for not equals. - if (seq_id > 0) { - result_index = seq_id; - seq_id = 0; - return; - } - - // Keep ignoring the consecutive matches. 
- uint32_t previous_match; - do { - previous_match = seq_id; - advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); - } while (is_valid && previous_match + 1 == seq_id); - - if (!is_valid) { - // filter matched all the ids in the index. So for not equals, there's no match. - if (previous_match >= index->seq_ids->last_id()) { - return; - } - - is_valid = true; - result_index = index->seq_ids->last_id() + 1; - seq_id = previous_match + 1; - return; - } - - result_index = seq_id; - seq_id = previous_match + 1; - } + get_string_filter_first_match(f.is_array()); return; } } -bool filter_result_iterator_t::valid() { - if (!is_valid) { - return false; - } - - if (filter_node->isOperator) { - if (filter_node->filter_operator == AND) { - is_valid = left_it->valid() && right_it->valid(); - return is_valid; - } else { - is_valid = left_it->valid() || right_it->valid(); - return is_valid; - } - } - - if (is_filter_result_initialized) { - is_valid = result_index < filter_result.count; - return is_valid; - } - - const filter a_filter = filter_node->filter_exp; - - if (!index->field_is_indexed(a_filter.field_name)) { - is_valid = false; - return is_valid; - } - - field f = index->search_schema.at(a_filter.field_name); - - if (f.is_string()) { - if (filter_node->filter_exp.apply_not_equals) { - return seq_id < result_index; - } - - bool one_is_valid = false; - for (auto& filter_value_tokens: posting_list_iterators) { - posting_list_t::intersect(filter_value_tokens, one_is_valid); - - if (one_is_valid) { - break; - } - } - - is_valid = one_is_valid; - return is_valid; - } - - return false; -} - void filter_result_iterator_t::skip_to(uint32_t id) { if (!is_valid) { return; @@ -1003,7 +957,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { do { previous_match = seq_id; advance_string_filter_token_iterators(); - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); } while (is_valid && previous_match + 1 == 
seq_id); } while (is_valid && seq_id <= id); @@ -1047,7 +1001,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { } } - doc_matching_string_filter(f.is_array()); + get_string_filter_next_match(f.is_array()); return; } } @@ -1190,13 +1144,16 @@ void filter_result_iterator_t::reset() { field f = index->search_schema.at(a_filter.field_name); if (f.is_string()) { - posting_list_iterators.clear(); - for(auto expanded_plist: expanded_plists) { - delete expanded_plist; - } - expanded_plists.clear(); + for (uint32_t i = 0; i < posting_lists.size(); i++) { + auto const& plists = posting_lists[i]; - init(); + posting_list_iterators[i].clear(); + for (auto const& plist: plists) { + posting_list_iterators[i].push_back(plist->new_iterator()); + } + } + + get_string_filter_first_match(f.is_array()); return; } } From 329e16652d68953673515a190a12eee660879145 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 28 Apr 2023 13:31:25 +0530 Subject: [PATCH 40/93] Add convenience methods in `result_iter_state_t`. 
--- include/or_iterator.h | 36 ++++++++++++++++++------------------ include/posting_list.h | 6 ++++++ src/posting_list.cpp | 29 +++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/include/or_iterator.h b/include/or_iterator.h index 67fd5ddf..c4f27518 100644 --- a/include/or_iterator.h +++ b/include/or_iterator.h @@ -62,8 +62,8 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state case 0: break; case 1: - if(istate.filter_ids_length != 0) { - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); + if(istate.is_filter_provided() && istate.is_filter_valid()) { + its[0].skip_to(istate.get_filter_id()); } while(its.size() == it_size && its[0].valid()) { @@ -79,10 +79,10 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state func(id, its); } - if(istate.filter_ids_length != 0 && !is_excluded) { - if(istate.filter_ids_index < istate.filter_ids_length) { + if(istate.is_filter_provided() && !is_excluded) { + if(istate.is_filter_valid()) { // skip iterator till next id available in filter - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); + its[0].skip_to(istate.get_filter_id()); } else { break; } @@ -92,9 +92,9 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } break; case 2: - if(istate.filter_ids_length != 0) { - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); - its[1].skip_to(istate.filter_ids[istate.filter_ids_index]); + if(istate.is_filter_provided() && istate.is_filter_valid()) { + its[0].skip_to(istate.get_filter_id()); + its[1].skip_to(istate.get_filter_id()); } while(its.size() == it_size && !at_end2(its)) { @@ -111,11 +111,11 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state func(id, its); } - if(istate.filter_ids_length != 0 && !is_excluded) { - if(istate.filter_ids_index < istate.filter_ids_length) { + if(istate.is_filter_provided() != 0 && !is_excluded) { + if(istate.is_filter_valid()) { // skip iterator till next id 
available in filter - its[0].skip_to(istate.filter_ids[istate.filter_ids_index]); - its[1].skip_to(istate.filter_ids[istate.filter_ids_index]); + its[0].skip_to(istate.get_filter_id()); + its[1].skip_to(istate.get_filter_id()); } else { break; } @@ -128,9 +128,9 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } break; default: - if(istate.filter_ids_length != 0) { + if(istate.is_filter_provided() && istate.is_filter_valid()) { for(auto& it: its) { - it.skip_to(istate.filter_ids[istate.filter_ids_index]); + it.skip_to(istate.get_filter_id()); } } @@ -148,11 +148,11 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state func(id, its); } - if(istate.filter_ids_length != 0 && !is_excluded) { - if(istate.filter_ids_index < istate.filter_ids_length) { + if(istate.is_filter_provided() && !is_excluded) { + if(istate.is_filter_valid()) { // skip iterator till next id available in filter for(auto& it: its) { - it.skip_to(istate.filter_ids[istate.filter_ids_index]); + it.skip_to(istate.get_filter_id()); } } else { break; @@ -167,4 +167,4 @@ bool or_iterator_t::intersect(std::vector& its, result_iter_state } return true; -} +} \ No newline at end of file diff --git a/include/posting_list.h b/include/posting_list.h index 11ed9c91..a0df81fa 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -33,6 +33,12 @@ struct result_iter_state_t { filter_result_iterator_t* fit) : excluded_result_ids(excluded_result_ids), excluded_result_ids_size(excluded_result_ids_size), fit(fit){} + + [[nodiscard]] bool is_filter_provided() const; + + [[nodiscard]] bool is_filter_valid() const; + + [[nodiscard]] uint32_t get_filter_id() const; }; /* diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 1b5a3323..dc82fb69 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -2,6 +2,7 @@ #include #include "for.h" #include "array_utils.h" +#include "filter_result_iterator.h" /* block_t operations */ @@ -1781,3 +1782,31 @@ 
posting_list_t::iterator_t posting_list_t::iterator_t::clone() const { uint32_t posting_list_t::iterator_t::get_field_id() const { return field_id; } + +bool result_iter_state_t::is_filter_provided() const { + return filter_ids_length > 0 || (fit != nullptr && fit->approx_filter_ids_length > 0); +} + +bool result_iter_state_t::is_filter_valid() const { + if (filter_ids_length > 0) { + return filter_ids_index < filter_ids_length; + } + + if (fit != nullptr) { + return fit->is_valid; + } + + return false; +} + +uint32_t result_iter_state_t::get_filter_id() const { + if (filter_ids_length > 0 && filter_ids_index < filter_ids_length) { + return filter_ids[filter_ids_index]; + } + + if (fit != nullptr && fit->is_valid) { + return fit->seq_id; + } + + return 0; +} From cf7086c295fa15b459b8393cc8b4777e5b63acd8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 28 Apr 2023 13:35:07 +0530 Subject: [PATCH 41/93] Refactor `or_iterator_t::take_id`. --- src/or_iterator.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index 8dd9d487..6384cba2 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -209,7 +209,16 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e } if (istate.fit != nullptr && istate.fit->approx_filter_ids_length > 0) { - return (istate.fit->valid(id) == 1); + if (istate.fit->valid(id) == -1) { + return false; + } + + if (istate.fit->seq_id == id) { + istate.fit->next(); + return true; + } + + return false; } return true; @@ -245,5 +254,3 @@ or_iterator_t::~or_iterator_t() noexcept { it.reset_cache(); } } - - From 56de9b126508acc075c9af2d6193d7e1a445a2b4 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 28 Apr 2023 14:55:39 +0530 Subject: [PATCH 42/93] Fix `filter_result_iterator_t::valid(uint32_t id)` not updating `seq_id` in case of complex filter. 
--- src/filter_result_iterator.cpp | 8 ++++++++ test/filter_test.cpp | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 34228916..f155d955 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1014,6 +1014,14 @@ int filter_result_iterator_t::valid(uint32_t id) { if (filter_node->isOperator) { auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); + if (left_it->is_valid && right_it->is_valid) { + seq_id = std::min(left_it->seq_id, right_it->seq_id); + } else if (left_it->is_valid) { + seq_id = left_it->seq_id; + } else if (right_it->is_valid) { + seq_id = right_it->seq_id; + } + if (filter_node->filter_operator == AND) { is_valid = left_it->is_valid && right_it->is_valid; diff --git a/test/filter_test.cpp b/test/filter_test.cpp index 50e7d6d6..d3aa3f31 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -284,10 +284,11 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); - std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}; + std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 3, 4, 5, 5}; expected = {1, 0, 1, 0, 1, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test.seq_id); } delete filter_tree_root; @@ -301,9 +302,11 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_TRUE(iter_validate_ids_not_equals_filter_test.init_status().ok()); validate_ids = {0, 1, 2, 3, 4, 5, 6}; + seq_ids = {1, 1, 3, 3, 5, 5, 5}; expected = {0, 1, 0, 1, 0, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_not_equals_filter_test.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], 
iter_validate_ids_not_equals_filter_test.seq_id); } delete filter_tree_root; From 85de14c8c345e81240df1d3c753f8dc120d28ee8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 1 May 2023 14:53:19 +0530 Subject: [PATCH 43/93] Refactor `art_fuzzy_search_i`. --- src/art.cpp | 263 ++++++++++++++++++++++++++++------------------------ 1 file changed, 144 insertions(+), 119 deletions(-) diff --git a/src/art.cpp b/src/art.cpp index 7a7d5d5b..8c3b99ad 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -973,36 +973,6 @@ const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, return prev_token_doc_ids; } -const uint32_t* get_allowed_doc_ids(art_tree *t, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, - size_t& prev_token_doc_ids_len) { - - art_leaf* prev_leaf = static_cast( - art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) - ); - - uint32_t* prev_token_doc_ids = nullptr; - - if(prev_token.empty() || !prev_leaf) { - prev_token_doc_ids_len = filter_result_iterator.to_filter_id_array(prev_token_doc_ids); - return prev_token_doc_ids; - } - - std::vector prev_leaf_ids; - posting_t::merge({prev_leaf->values}, prev_leaf_ids); - - if(filter_result_iterator.is_valid) { - prev_token_doc_ids_len = filter_result_iterator.and_scalar(prev_leaf_ids.data(), prev_leaf_ids.size(), - prev_token_doc_ids); - } else { - prev_token_doc_ids_len = prev_leaf_ids.size(); - prev_token_doc_ids = new uint32_t[prev_token_doc_ids_len]; - std::copy(prev_leaf_ids.begin(), prev_leaf_ids.end(), prev_token_doc_ids); - } - - return prev_token_doc_ids; -} - bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::string& prev_token, const uint32_t* allowed_doc_ids, const size_t allowed_doc_ids_len, std::set& exclude_leaves, const art_leaf* exact_leaf, @@ -1030,6 +1000,52 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str return true; } +bool validate_and_add_leaf(art_leaf* 
leaf, + const std::string& prev_token, const art_leaf* prev_leaf, + const art_leaf* exact_leaf, + filter_result_iterator_t& filter_result_iterator, + std::set& exclude_leaves, + std::vector& results) { + if(leaf == exact_leaf) { + return false; + } + + std::string tok(reinterpret_cast(leaf->key), leaf->key_len - 1); + if(exclude_leaves.count(tok) != 0) { + return false; + } + + if(prev_token.empty() || !prev_leaf) { + if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) { + return false; + } + } else if (!filter_result_iterator.is_valid) { + std::vector prev_leaf_ids; + posting_t::merge({prev_leaf->values}, prev_leaf_ids); + + if (!posting_t::contains_atleast_one(leaf->values, prev_leaf_ids.data(), prev_leaf_ids.size())) { + return false; + } + } else { + std::vector leaf_ids; + posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); + + bool found = false; + for (uint32_t i = 0; i < leaf_ids.size() && !found; i++) { + found = (filter_result_iterator.valid(leaf_ids[i]) == 1); + } + + if (!found) { + return false; + } + } + + exclude_leaves.emplace(tok); + results.push_back(leaf); + + return true; +} + int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, const art_leaf* exact_leaf, const bool last_token, const std::string& prev_token, @@ -1126,6 +1142,101 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r return 0; } +int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, + const art_leaf* exact_leaf, + const bool last_token, const std::string& prev_token, + filter_result_iterator_t& filter_result_iterator, + const art_tree* t, std::set& exclude_leaves, std::vector& results) { + + printf("INSIDE art_topk_iter: root->type: %d\n", root->type); + + auto prev_leaf = static_cast( + art_search(t, reinterpret_cast(prev_token.c_str()), prev_token.size() + 1) + ); + + std::priority_queue, + 
decltype(&compare_art_node_score_pq)> q(compare_art_node_score_pq); + + if(token_order == FREQUENCY) { + q = std::priority_queue, + decltype(&compare_art_node_frequency_pq)>(compare_art_node_frequency_pq); + } + + q.push(root); + + size_t num_processed = 0; + + while(!q.empty() && results.size() < max_results*4) { + art_node *n = (art_node *) q.top(); + q.pop(); + + if (!n) continue; + if (IS_LEAF(n)) { + art_leaf *l = (art_leaf *) LEAF_RAW(n); + //LOG(INFO) << "END LEAF SCORE: " << l->max_score; + + validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator, + exclude_leaves, results); + filter_result_iterator.reset(); + + if (++num_processed % 1024 == 0 && (microseconds( + std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { + search_cutoff = true; + break; + } + + continue; + } + + int idx; + switch (n->type) { + case NODE4: + //LOG(INFO) << "NODE4, SCORE: " << n->max_score; + for (int i=0; i < n->num_children; i++) { + art_node* child = ((art_node4*)n)->children[i]; + q.push(child); + } + break; + + case NODE16: + //LOG(INFO) << "NODE16, SCORE: " << n->max_score; + for (int i=0; i < n->num_children; i++) { + q.push(((art_node16*)n)->children[i]); + } + break; + + case NODE48: + //LOG(INFO) << "NODE48, SCORE: " << n->max_score; + for (int i=0; i < 256; i++) { + idx = ((art_node48*)n)->keys[i]; + if (!idx) continue; + art_node *child = ((art_node48*)n)->children[idx - 1]; + q.push(child); + } + break; + + case NODE256: + //LOG(INFO) << "NODE256, SCORE: " << n->max_score; + for (int i=0; i < 256; i++) { + if (!((art_node256*)n)->children[i]) continue; + q.push(((art_node256*)n)->children[i]); + } + break; + + default: + printf("ABORTING BECAUSE OF UNKNOWN NODE TYPE: %d\n", n->type); + abort(); + } + } + + /*LOG(INFO) << "leaf results.size: " << results.size() + << ", filter_ids_length: " << filter_ids_length + << ", num_large_lists: " << num_large_lists;*/ + + printf("OUTSIDE art_topk_iter: 
results size: %d\n", results.size()); + return 0; +} + // Recursively iterates over the tree static int recursive_iter(art_node *n, art_callback cb, void *data) { // Handle base cases @@ -1689,14 +1800,10 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; - // documents that contain the previous token and/or filter ids - size_t allowed_doc_ids_len = 0; - const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); - filter_result_iterator.reset(); - for(auto node: nodes) { art_topk_iter(node, token_order, max_words, exact_leaf, - last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, + last_token, prev_token, + filter_result_iterator, t, exclude_leaves, results); } @@ -1722,91 +1829,9 @@ int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_le if(time_micro > 1000) { LOG(INFO) << "Time taken for art_topk_iter: " << time_micro << "us, size of nodes: " << nodes.size() - << ", filter_ids_length: " << filter_ids_length; + << ", filter_ids_length: " << filter_result_iterator.approx_filter_ids_length; }*/ - delete [] allowed_doc_ids; - - return 0; -} - -int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, - const size_t max_words, const token_ordering token_order, const bool prefix, - bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, - std::vector &results, std::set& exclude_leaves) { - - std::vector nodes; - int irow[term_len + 1]; - int jrow[term_len + 1]; - for (int i = 0; i <= term_len; i++){ - irow[i] = jrow[i] = i; - } - - //auto begin = std::chrono::high_resolution_clock::now(); - - if(IS_LEAF(t->root)) { - art_leaf *l = (art_leaf *) LEAF_RAW(t->root); - 
art_fuzzy_recurse(0, l->key[0], t->root, 0, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); - } else { - if(t->root == nullptr) { - return 0; - } - - // send depth as -1 to indicate that this is a root node - art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); - } - - //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); - //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); - - //auto begin = std::chrono::high_resolution_clock::now(); - - size_t key_len = prefix ? term_len + 1 : term_len; - art_leaf* exact_leaf = (art_leaf *) art_search(t, term, key_len); - //LOG(INFO) << "exact_leaf: " << exact_leaf << ", term: " << term << ", term_len: " << term_len; - - // documents that contain the previous token and/or filter ids - size_t allowed_doc_ids_len = 0; - const uint32_t* allowed_doc_ids = get_allowed_doc_ids(t, prev_token, filter_result_iterator, allowed_doc_ids_len); - - for(auto node: nodes) { - art_topk_iter(node, token_order, max_words, exact_leaf, - last_token, prev_token, allowed_doc_ids, allowed_doc_ids_len, - t, exclude_leaves, results); - } - - if(token_order == FREQUENCY) { - std::sort(results.begin(), results.end(), compare_art_leaf_frequency); - } else { - std::sort(results.begin(), results.end(), compare_art_leaf_score); - } - - if(exact_leaf && min_cost == 0) { - std::string tok(reinterpret_cast(exact_leaf->key), exact_leaf->key_len - 1); - if(exclude_leaves.count(tok) == 0) { - results.insert(results.begin(), exact_leaf); - exclude_leaves.emplace(tok); - } - } - - if(results.size() > max_words) { - results.resize(max_words); - } - - /*auto time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); - - if(time_micro > 1000) { - LOG(INFO) << "Time taken for art_topk_iter: " << time_micro - << "us, size of nodes: " << nodes.size() - << ", filter_ids_length: " << 
filter_ids_length; - }*/ - -// TODO: Figure out this edge case. -// if(allowed_doc_ids != filter_ids) { -// delete [] allowed_doc_ids; -// } - return 0; } From 91c1c321dc8c9113a0ea66f3f522155363fa717a Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 2 May 2023 10:49:03 +0530 Subject: [PATCH 44/93] Add test for prefix search with filter. --- src/art.cpp | 2 +- test/collection_filtering_test.cpp | 42 ++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/art.cpp b/src/art.cpp index 8c3b99ad..ee0fb53a 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1031,7 +1031,7 @@ bool validate_and_add_leaf(art_leaf* leaf, posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); bool found = false; - for (uint32_t i = 0; i < leaf_ids.size() && !found; i++) { + for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator.is_valid && !found; i++) { found = (filter_result_iterator.valid(leaf_ids[i]) == 1); } diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 1628c2dc..c644d547 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -2628,4 +2628,46 @@ TEST_F(CollectionFilteringTest, ComplexFilterQuery) { } collectionManager.drop_collection("ComplexFilterQueryCollection"); +} + +TEST_F(CollectionFilteringTest, PrefixSearchWithFilter) { + std::ifstream infile(std::string(ROOT_DIR)+"test/documents.jsonl"); + std::vector search_fields = { + field("title", field_types::STRING, false), + field("points", field_types::INT32, false) + }; + + query_fields = {"title"}; + sort_fields = { sort_by(sort_field_const::text_match, "DESC"), sort_by("points", "DESC") }; + + auto collection = collectionManager.create_collection("collection", 4, search_fields, "points").get(); + + std::string json_line; + + // dummy record for record id 0: to make the test record IDs to match with line numbers + json_line = "{\"points\":10,\"title\":\"z\"}"; + collection->add(json_line); 
+ + while (std::getline(infile, json_line)) { + collection->add(json_line); + } + + infile.close(); + + std::vector facets; + auto results = collection->search("what ex", query_fields, "points: >10", facets, sort_fields, {0}, 10, 1, MAX_SCORE, {true}, 10, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10).get(); + ASSERT_EQ(7, results["hits"].size()); + std::vector ids = {"6", "12", "19", "22", "13", "8", "15"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + collectionManager.drop_collection("collection"); } \ No newline at end of file From 9362c5a5e0f349f90239934d8a940199efbaab78 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 3 May 2023 18:42:31 +0530 Subject: [PATCH 45/93] Fix phrase search. --- include/art.h | 2 +- include/filter_result_iterator.h | 9 +++ include/index.h | 20 ++--- src/art.cpp | 16 ++-- src/filter_result_iterator.cpp | 45 +++++++++++ src/index.cpp | 129 ++++++++++++++++--------------- test/filter_test.cpp | 14 ++++ 7 files changed, 155 insertions(+), 80 deletions(-) diff --git a/include/art.h b/include/art.h index 11f57a68..92e043e3 100644 --- a/include/art.h +++ b/include/art.h @@ -279,7 +279,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, const size_t max_words, const token_ordering token_order, const bool prefix, bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::vector &results, std::set& exclude_leaves); void encode_int32(int32_t n, unsigned char *chars); diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 
bc8c4c23..b3b12555 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -109,6 +109,8 @@ private: std::vector> posting_list_iterators; std::vector expanded_plists; + bool delete_filter_node = false; + /// Initializes the state of iterator node after it's creation. void init(); @@ -127,6 +129,8 @@ private: /// Finds the next match for a filter on string field. void get_string_filter_next_match(const bool& field_is_array); + explicit filter_result_iterator_t(uint32_t approx_filter_ids_length); + public: uint32_t seq_id = 0; /// Collection name -> references @@ -143,6 +147,8 @@ public: /// iterator reaching it's end. (is_valid would be false in both these cases) uint32_t approx_filter_ids_length; + explicit filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count); + explicit filter_result_iterator_t(const std::string collection_name, Index const* const index, filter_node_t const* const filter_node, uint32_t approx_filter_ids_length = UINT32_MAX); @@ -193,4 +199,7 @@ public: /// Performs AND with the contents of A and allocates a new array of results. 
/// \return size of the results array uint32_t and_scalar(const uint32_t* A, const uint32_t& lenA, uint32_t*& results); + + static void add_phrase_ids(filter_result_iterator_t*& filter_result_iterator, + uint32_t* phrase_result_ids, const uint32_t& phrase_result_count); }; diff --git a/include/index.h b/include/index.h index f2f8a4b8..45ce8172 100644 --- a/include/index.h +++ b/include/index.h @@ -408,7 +408,7 @@ private: void search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -723,7 +723,7 @@ public: const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, @@ -764,7 +764,7 @@ public: std::vector>& searched_queries, const size_t group_limit, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -795,7 +795,7 @@ public: spp::sparse_hash_map& groups_processed, std::vector>& searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const 
filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -812,6 +812,7 @@ public: std::array*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, + filter_result_iterator_t*& filter_result_iterator, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed, const std::set& curated_ids, @@ -819,8 +820,7 @@ public: const std::unordered_set& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, - bool is_wildcard_query, - uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const; + bool is_wildcard_query) const; void fuzzy_search_fields(const std::vector& the_fields, const std::vector& query_tokens, @@ -828,7 +828,7 @@ public: const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -857,7 +857,7 @@ public: const std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, @@ -879,7 +879,7 @@ public: const std::vector& group_by_fields, bool prioritize_exact_match, const bool search_all_candidates, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, @@ -931,7 +931,7 @@ public: void process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, const bool filter_curated_hits, 
- filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, diff --git a/src/art.cpp b/src/art.cpp index ee0fb53a..65d3ca63 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -1003,7 +1003,7 @@ bool validate_and_add_leaf(art_leaf* leaf, const bool last_token, const std::str bool validate_and_add_leaf(art_leaf* leaf, const std::string& prev_token, const art_leaf* prev_leaf, const art_leaf* exact_leaf, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& exclude_leaves, std::vector& results) { if(leaf == exact_leaf) { @@ -1016,10 +1016,10 @@ bool validate_and_add_leaf(art_leaf* leaf, } if(prev_token.empty() || !prev_leaf) { - if (filter_result_iterator.is_valid && !filter_result_iterator.contains_atleast_one(leaf->values)) { + if (filter_result_iterator->is_valid && !filter_result_iterator->contains_atleast_one(leaf->values)) { return false; } - } else if (!filter_result_iterator.is_valid) { + } else if (!filter_result_iterator->is_valid) { std::vector prev_leaf_ids; posting_t::merge({prev_leaf->values}, prev_leaf_ids); @@ -1031,8 +1031,8 @@ bool validate_and_add_leaf(art_leaf* leaf, posting_t::merge({prev_leaf->values, leaf->values}, leaf_ids); bool found = false; - for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator.is_valid && !found; i++) { - found = (filter_result_iterator.valid(leaf_ids[i]) == 1); + for (uint32_t i = 0; i < leaf_ids.size() && filter_result_iterator->is_valid && !found; i++) { + found = (filter_result_iterator->valid(leaf_ids[i]) == 1); } if (!found) { @@ -1145,7 +1145,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_results, const art_leaf* exact_leaf, const bool last_token, const std::string& prev_token, - 
filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const art_tree* t, std::set& exclude_leaves, std::vector& results) { printf("INSIDE art_topk_iter: root->type: %d\n", root->type); @@ -1177,7 +1177,7 @@ int art_topk_iter(const art_node *root, token_ordering token_order, size_t max_r validate_and_add_leaf(l, prev_token, prev_leaf, exact_leaf, filter_result_iterator, exclude_leaves, results); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if (++num_processed % 1024 == 0 && (microseconds( std::chrono::system_clock::now().time_since_epoch()).count() - search_begin_us) > search_stop_us) { @@ -1767,7 +1767,7 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, int art_fuzzy_search_i(art_tree *t, const unsigned char *term, const int term_len, const int min_cost, const int max_cost, const size_t max_words, const token_ordering token_order, const bool prefix, bool last_token, const std::string& prev_token, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::vector &results, std::set& exclude_leaves) { std::vector nodes; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f155d955..c4dcd0b5 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1252,6 +1252,10 @@ filter_result_iterator_t::~filter_result_iterator_t() { delete expanded_plist; } + if (delete_filter_node) { + delete filter_node; + } + delete left_it; delete right_it; } @@ -1343,3 +1347,44 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, next(); } } + +filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_length) : + approx_filter_ids_length(approx_filter_ids_length) { + filter_node = new filter_node_t(AND, nullptr, nullptr); + delete_filter_node = true; +} + +filter_result_iterator_t::filter_result_iterator_t(uint32_t* ids, const uint32_t& 
ids_count) { + filter_result.count = approx_filter_ids_length = ids_count; + filter_result.docs = ids; + is_valid = ids_count > 0; + + if (is_valid) { + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + filter_node = new filter_node_t({"dummy", {}, {}}); + delete_filter_node = true; + } +} + +void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_result_iterator, + uint32_t* phrase_result_ids, const uint32_t& phrase_result_count) { + auto root_iterator = new filter_result_iterator_t(std::min(phrase_result_count, filter_result_iterator->approx_filter_ids_length)); + root_iterator->left_it = new filter_result_iterator_t(phrase_result_ids, phrase_result_count); + root_iterator->right_it = filter_result_iterator; + + auto& left_it = root_iterator->left_it; + auto& right_it = root_iterator->right_it; + + while (left_it->is_valid && right_it->is_valid && left_it->seq_id != right_it->seq_id) { + if (left_it->seq_id < right_it->seq_id) { + left_it->skip_to(right_it->seq_id); + } else { + right_it->skip_to(left_it->seq_id); + } + } + + root_iterator->is_valid = left_it->is_valid && right_it->is_valid; + root_iterator->seq_id = left_it->seq_id; + filter_result_iterator = root_iterator; +} diff --git a/src/index.cpp b/src/index.cpp index d1ca8688..b117d4bd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1286,7 +1286,7 @@ void Index::aggregate_topster(Topster* agg_topster, Topster* index_topster) { void Index::search_all_candidates(const size_t num_search_fields, const text_match_type_t match_type, const std::vector& the_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, const std::vector& sort_fields, @@ -2270,14 +2270,16 @@ Option Index::search(std::vector& field_query_tokens, cons return rearrange_op; } - auto filter_result_iterator = 
filter_result_iterator_t(collection_name, this, filter_tree_root, + auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root, approx_filter_ids_length); - auto filter_init_op = filter_result_iterator.init_status(); + std::unique_ptr filter_iterator_guard(filter_result_iterator); + + auto filter_init_op = filter_result_iterator->init_status(); if (!filter_init_op.ok()) { return filter_init_op; } - if (filter_tree_root != nullptr && !filter_result_iterator.is_valid) { + if (filter_tree_root != nullptr && !filter_result_iterator->is_valid) { return Option(true); } @@ -2291,7 +2293,7 @@ Option Index::search(std::vector& field_query_tokens, cons process_curated_ids(included_ids, excluded_ids, group_by_fields, group_limit, filter_curated_hits, filter_result_iterator, curated_ids, included_ids_map, included_ids_vec, excluded_group_ids); - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector curated_ids_sorted(curated_ids.begin(), curated_ids.end()); std::sort(curated_ids_sorted.begin(), curated_ids_sorted.end()); @@ -2322,24 +2324,19 @@ Option Index::search(std::vector& field_query_tokens, cons field_query_tokens[0].q_include_tokens[0].value == "*"; - // TODO: Do AND with phrase ids at last // handle phrase searches - uint32_t* phrase_result_ids = nullptr; - uint32_t phrase_result_count = 0; - std::unique_ptr phrase_result_ids_guard; - if (!field_query_tokens[0].q_phrases.empty()) { do_phrase_search(num_search_fields, the_fields, field_query_tokens, sort_fields_std, searched_queries, group_limit, group_by_fields, topster, sort_order, field_values, geopoint_indices, curated_ids_sorted, - all_result_ids, all_result_ids_len, groups_processed, curated_ids, + filter_result_iterator, all_result_ids, all_result_ids_len, groups_processed, curated_ids, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, curated_topster, - included_ids_map, is_wildcard_query, - phrase_result_ids, 
phrase_result_count); + included_ids_map, is_wildcard_query); - phrase_result_ids_guard.reset(phrase_result_ids); + filter_iterator_guard.release(); + filter_iterator_guard.reset(filter_result_iterator); - if (phrase_result_count == 0) { + if (filter_result_iterator->approx_filter_ids_length == 0) { goto process_search_results; } } @@ -2347,7 +2344,7 @@ Option Index::search(std::vector& field_query_tokens, cons // for phrase query, parser will set field_query_tokens to "*", need to handle that if (is_wildcard_query && field_query_tokens[0].q_phrases.empty()) { const uint8_t field_id = (uint8_t)(FIELD_LIMIT_NUM - 0); - bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator.is_valid); + bool no_filters_provided = (filter_tree_root == nullptr && !filter_result_iterator->is_valid); if(no_filters_provided && facets.empty() && curated_ids.empty() && vector_query.field_name.empty() && sort_fields_std.size() == 1 && sort_fields_std[0].name == sort_field_const::seq_id && @@ -2395,8 +2392,10 @@ Option Index::search(std::vector& field_query_tokens, cons Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, store, doc_id_prefix, filter_tree_root); - filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); - approx_filter_ids_length = filter_result_iterator.approx_filter_ids_length; + filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root); + filter_iterator_guard.reset(filter_result_iterator); + + approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -2414,9 +2413,9 @@ Option Index::search(std::vector& field_query_tokens, cons uint32_t filter_id_count = 0; while (!no_filters_provided && - filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator.is_valid) { - auto seq_id = filter_result_iterator.seq_id; - 
filter_result_iterator.next(); + filter_id_count < vector_query.flat_search_cutoff && filter_result_iterator->is_valid) { + auto seq_id = filter_result_iterator->seq_id; + filter_result_iterator->next(); std::vector values; try { @@ -2440,12 +2439,13 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels.emplace_back(dist, seq_id); filter_id_count++; } + filter_result_iterator->reset(); if(no_filters_provided || - (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator.is_valid)) { + (filter_id_count >= vector_query.flat_search_cutoff && filter_result_iterator->is_valid)) { dist_labels.clear(); - VectorFilterFunctor filterFunctor(&filter_result_iterator); + VectorFilterFunctor filterFunctor(filter_result_iterator); if(field_vector_index->distance_type == cosine) { std::vector normalized_q(vector_query.values.size()); @@ -2455,8 +2455,7 @@ Option Index::search(std::vector& field_query_tokens, cons dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } } - - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector nearest_ids; @@ -2511,7 +2510,7 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, filter_result_iterator, approx_filter_ids_length, concurrency, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } // filter tree was initialized to have all sequence ids in this flow. 
@@ -2572,7 +2571,7 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -2609,7 +2608,7 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } } @@ -2625,7 +2624,7 @@ Option Index::search(std::vector& field_query_tokens, cons filter_result_iterator, query_hashes, sort_order, field_values, geopoint_indices, qtoken_set); - filter_result_iterator.reset(); + filter_result_iterator->reset(); // gather up both original query and synonym queries and do drop tokens @@ -2682,7 +2681,7 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); - filter_result_iterator.reset(); + filter_result_iterator->reset(); } else { break; @@ -2699,7 +2698,7 @@ Option Index::search(std::vector& field_query_tokens, cons sort_order, field_values, geopoint_indices, curated_ids_sorted, excluded_group_ids, all_result_ids, all_result_ids_len, groups_processed); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if(!vector_query.field_name.empty()) { // check at least one of sort fields is text match @@ -2716,7 +2715,7 @@ Option Index::search(std::vector& 
field_query_tokens, cons constexpr float TEXT_MATCH_WEIGHT = 0.7; constexpr float VECTOR_SEARCH_WEIGHT = 1.0 - TEXT_MATCH_WEIGHT; - VectorFilterFunctor filterFunctor(&filter_result_iterator); + VectorFilterFunctor filterFunctor(filter_result_iterator); auto& field_vector_index = vector_index.at(vector_query.field_name); std::vector> dist_labels; auto k = std::max(vector_query.k, fetch_size); @@ -2728,7 +2727,7 @@ Option Index::search(std::vector& field_query_tokens, cons } else { dist_labels = field_vector_index->vecdex->searchKnnCloserFirst(vector_query.values.data(), k, &filterFunctor); } - filter_result_iterator.reset(); + filter_result_iterator->reset(); std::vector> vec_results; for (const auto& dist_label : dist_labels) { @@ -2938,7 +2937,7 @@ Option Index::search(std::vector& field_query_tokens, cons void Index::process_curated_ids(const std::vector>& included_ids, const std::vector& excluded_ids, const std::vector& group_by_fields, const size_t group_limit, - const bool filter_curated_hits, filter_result_iterator_t& filter_result_iterator, + const bool filter_curated_hits, filter_result_iterator_t* const filter_result_iterator, std::set& curated_ids, std::map>& included_ids_map, std::vector& included_ids_vec, @@ -2961,9 +2960,9 @@ void Index::process_curated_ids(const std::vector> // if `filter_curated_hits` is enabled, we will remove curated hits that don't match filter condition std::set included_ids_set; - if(filter_result_iterator.is_valid && filter_curated_hits) { + if(filter_result_iterator->is_valid && filter_curated_hits) { for (const auto &included_id: included_ids_vec) { - auto result = filter_result_iterator.valid(included_id); + auto result = filter_result_iterator->valid(included_id); if (result == -1) { break; @@ -3030,7 +3029,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, const text_match_type_t match_type, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, - filter_result_iterator_t& 
filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const std::vector& curated_ids, const std::unordered_set& excluded_group_ids, const std::vector & sort_fields, @@ -3176,7 +3175,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, last_token, prev_token, filter_result_iterator, field_leaves, unique_tokens); - filter_result_iterator.reset(); + filter_result_iterator->reset(); /*auto timeMillis = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - begin).count(); @@ -3207,7 +3206,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, token_candidates_vec.back().candidates[0], the_fields, num_search_fields, filter_result_iterator, exclude_token_ids, exclude_token_ids_size, prev_token_doc_ids, popular_field_ids); - filter_result_iterator.reset(); + filter_result_iterator->reset(); for(size_t field_id: query_field_ids) { auto& the_field = the_fields[field_id]; @@ -3230,7 +3229,7 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, art_fuzzy_search_i(search_index.at(the_field.name), (const unsigned char *) token.c_str(), token_len, costs[token_index], costs[token_index], max_candidates, token_order, prefix_search, false, "", filter_result_iterator, field_leaves, unique_tokens); - filter_result_iterator.reset(); + filter_result_iterator->reset(); if(field_leaves.empty()) { // look at the next field @@ -3294,7 +3293,6 @@ void Index::fuzzy_search_fields(const std::vector& the_fields, exhaustive_search, max_candidates, syn_orig_num_tokens, sort_order, field_values, geopoint_indices, query_hashes, id_buff); - filter_result_iterator.reset(); if(id_buff.size() > 1) { gfx::timsort(id_buff.begin(), id_buff.end()); @@ -3355,7 +3353,7 @@ void Index::find_across_fields(const token_t& previous_token, const 
std::string& previous_token_str, const std::vector& the_fields, const size_t num_search_fields, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, std::vector& prev_token_doc_ids, std::vector& top_prefix_field_ids) const { @@ -3366,7 +3364,7 @@ void Index::find_across_fields(const token_t& previous_token, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator); const bool prefix_search = previous_token.is_prefix_searched; const uint32_t token_num_typos = previous_token.num_typos; @@ -3447,7 +3445,7 @@ void Index::search_across_fields(const std::vector& query_tokens, const std::vector& group_by_fields, const bool prioritize_exact_match, const bool prioritize_token_position, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const uint32_t total_cost, const int syn_orig_num_tokens, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, @@ -3508,7 +3506,7 @@ void Index::search_across_fields(const std::vector& query_tokens, // used to track plists that must be destructed once done std::vector expanded_plists; - result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, &filter_result_iterator); + result_iter_state_t istate(exclude_token_ids, exclude_token_ids_size, filter_result_iterator); // for each token, find the posting lists across all query_by fields for(size_t ti = 0; ti < query_tokens.size(); ti++) { @@ -3970,6 +3968,7 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector*, 3> field_values, const std::vector& geopoint_indices, const std::vector& curated_ids_sorted, 
+ filter_result_iterator_t*& filter_result_iterator, uint32_t*& all_result_ids, size_t& all_result_ids_len, spp::sparse_hash_map& groups_processed, const std::set& curated_ids, @@ -3977,9 +3976,10 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector& excluded_group_ids, Topster* curated_topster, const std::map>& included_ids_map, - bool is_wildcard_query, - uint32_t*& phrase_result_ids, uint32_t& phrase_result_count) const { + bool is_wildcard_query) const { + uint32_t* phrase_result_ids = nullptr; + uint32_t phrase_result_count = 0; std::map phrase_match_id_scores; for(size_t i = 0; i < num_search_fields; i++) { @@ -4068,12 +4068,19 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vectoris_valid) { + filter_result_iterator_t::add_phrase_ids(filter_result_iterator, phrase_result_ids, phrase_result_count); + } else { + delete filter_result_iterator; + filter_result_iterator = new filter_result_iterator_t(phrase_result_ids, phrase_result_count); + } + size_t filter_index = 0; if(is_wildcard_query) { - all_result_ids = new uint32_t[phrase_result_count]; - std::copy(phrase_result_ids, phrase_result_ids + phrase_result_count, all_result_ids); - all_result_ids_len = phrase_result_count; + all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids); + filter_result_iterator->reset(); } else { // this means that the there are non-phrase tokens in the query // so we cannot directly copy to the all_result_ids array @@ -4081,8 +4088,8 @@ void Index::do_phrase_search(const size_t num_search_fields, const std::vector(10000, phrase_result_count); i++) { - auto seq_id = phrase_result_ids[i]; + for(size_t i = 0; i < std::min(10000, all_result_ids_len); i++) { + auto seq_id = all_result_ids[i]; int64_t match_score = phrase_match_id_scores[seq_id]; int64_t scores[3] = {0}; @@ -4135,7 +4142,7 @@ void Index::do_synonym_search(const std::vector& the_fields, spp::sparse_hash_map& groups_processed, std::vector>& 
searched_queries, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, std::set& query_hashes, const int* sort_order, std::array*, 3>& field_values, @@ -4163,7 +4170,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& group_by_fields, const size_t max_extra_prefix, const size_t max_extra_suffix, const std::vector& query_tokens, Topster* actual_topster, - filter_result_iterator_t& filter_result_iterator, + filter_result_iterator_t* const filter_result_iterator, const int sort_order[3], std::array*, 3> field_values, const std::vector& geopoint_indices, @@ -4197,10 +4204,10 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vectoris_valid) { uint32_t *filtered_raw_infix_ids = nullptr; - raw_infix_ids_length = filter_result_iterator.and_scalar(raw_infix_ids, raw_infix_ids_length, + raw_infix_ids_length = filter_result_iterator->and_scalar(raw_infix_ids, raw_infix_ids_length, filtered_raw_infix_ids); if(raw_infix_ids != &infix_ids[0]) { delete [] raw_infix_ids; @@ -4495,7 +4502,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t& filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, @@ -4525,11 +4532,11 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, auto parent_search_cutoff = search_cutoff; uint32_t excluded_result_index = 0; - for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator.is_valid; thread_id++) { + 
for(size_t thread_id = 0; thread_id < num_threads && filter_result_iterator->is_valid; thread_id++) { std::vector batch_result_ids; batch_result_ids.reserve(window_size); - filter_result_iterator.get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, + filter_result_iterator->get_n_ids(window_size, excluded_result_index, exclude_token_ids, exclude_token_ids_size, batch_result_ids); num_queued++; @@ -4611,8 +4618,8 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::chrono::high_resolution_clock::now() - beginF).count(); LOG(INFO) << "Time for raw scoring: " << timeMillisF;*/ - filter_result_iterator.reset(); - all_result_ids_len = filter_result_iterator.to_filter_id_array(all_result_ids); + filter_result_iterator->reset(); + all_result_ids_len = filter_result_iterator->to_filter_id_array(all_result_ids); } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, diff --git a/test/filter_test.cpp b/test/filter_test.cpp index d3aa3f31..ac6efdb4 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -482,5 +482,19 @@ TEST_F(FilterTest, FilterTreeIterator) { ASSERT_EQ(6, iter_skip_test4.seq_id); ASSERT_TRUE(iter_skip_test4.is_valid); + auto iter_add_phrase_ids_test = new filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + std::unique_ptr filter_iter_guard(iter_add_phrase_ids_test); + ASSERT_TRUE(iter_add_phrase_ids_test->init_status().ok()); + + auto phrase_ids = new uint32_t[4]; + for (uint32_t i = 0; i < 4; i++) { + phrase_ids[i] = i * 2; + } + filter_result_iterator_t::add_phrase_ids(iter_add_phrase_ids_test, phrase_ids, 4); + filter_iter_guard.reset(iter_add_phrase_ids_test); + + ASSERT_TRUE(iter_add_phrase_ids_test->is_valid); + ASSERT_EQ(6, iter_add_phrase_ids_test->seq_id); + delete filter_tree_root; } \ No newline at end of file From 1b5a47181d34d870243071fbc8c9e2c6b02af655 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 4 
May 2023 16:32:52 +0530 Subject: [PATCH 46/93] Fix alloc-dealloc-mismatch. --- src/index.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/index.cpp b/src/index.cpp index b117d4bd..ddfd75f8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1739,7 +1739,7 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter uint32_t* reference_docs = nullptr; uint32_t count = filter_result_iterator.to_filter_id_array(reference_docs); - std::unique_ptr docs_guard(reference_docs); + std::unique_ptr docs_guard(reference_docs); // doc id -> reference doc ids std::map> reference_map; From d32b73270121a3dfacaf4124efa15ef83a72600e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 5 May 2023 12:34:17 +0530 Subject: [PATCH 47/93] Fix ASAN issues. --- src/collection.cpp | 2 +- test/filter_test.cpp | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/collection.cpp b/src/collection.cpp index f62a1753..4dfdc00b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -421,7 +421,7 @@ Option Collection::update_matching_filter(const std::string& fil } const auto& dirty_values = parse_dirty_values_option(req_dirty_values); - size_t docs_updated_count; + size_t docs_updated_count = 0; nlohmann::json update_document, dummy; try { diff --git a/test/filter_test.cpp b/test/filter_test.cpp index ac6efdb4..afce0209 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -423,7 +423,7 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_to_array_test.is_valid); - delete filter_ids; + delete[] filter_ids; auto iter_and_scalar_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_and_scalar_test.init_status().ok()); @@ -440,7 +440,7 @@ TEST_F(FilterTest, FilterTreeIterator) { } ASSERT_FALSE(iter_and_scalar_test.is_valid); - delete and_result; + delete[] and_result; delete filter_tree_root; doc = R"({ @@ -491,6 +491,7 @@ TEST_F(FilterTest, 
FilterTreeIterator) { phrase_ids[i] = i * 2; } filter_result_iterator_t::add_phrase_ids(iter_add_phrase_ids_test, phrase_ids, 4); + filter_iter_guard.release(); filter_iter_guard.reset(iter_add_phrase_ids_test); ASSERT_TRUE(iter_add_phrase_ids_test->is_valid); From 1520be463ba2de848a65054623d83aeca0c382be Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 10:28:36 +0530 Subject: [PATCH 48/93] Fix failing join tests. --- src/index.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index ddfd75f8..207c44bb 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1362,6 +1362,7 @@ void Index::search_all_candidates(const size_t num_search_fields, id_buff, all_result_ids, all_result_ids_len); query_hashes.insert(qhash); + filter_result_iterator->reset(); } } @@ -2571,7 +2572,6 @@ Option Index::search(std::vector& field_query_tokens, cons typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator->reset(); // try split/joining tokens if no results are found if(split_join_tokens == always || (all_result_ids_len == 0 && split_join_tokens == fallback)) { @@ -2608,7 +2608,6 @@ Option Index::search(std::vector& field_query_tokens, cons all_result_ids, all_result_ids_len, group_limit, group_by_fields, prioritize_exact_match, prioritize_token_position, query_hashes, token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, syn_orig_num_tokens, sort_order, field_values, geopoint_indices); - filter_result_iterator->reset(); } } @@ -2681,7 +2680,6 @@ Option Index::search(std::vector& field_query_tokens, cons token_order, prefixes, typo_tokens_threshold, exhaustive_search, max_candidates, min_len_1typo, min_len_2typo, -1, sort_order, field_values, geopoint_indices); - filter_result_iterator->reset(); } else { break; From 
fdd643b563d0b10cec9d345ea9128bfad733d199 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 12:13:16 +0530 Subject: [PATCH 49/93] Remove `SEQ_IDS_FILTER` logic. --- include/index.h | 4 ---- src/filter.cpp | 3 --- src/filter_result_iterator.cpp | 12 ------------ src/index.cpp | 13 +------------ 4 files changed, 1 insertion(+), 31 deletions(-) diff --git a/include/index.h b/include/index.h index 45ce8172..3a511ba0 100644 --- a/include/index.h +++ b/include/index.h @@ -540,10 +540,6 @@ public: // in the query that have the least individual hits one by one until enough results are found. static const int DROP_TOKENS_THRESHOLD = 1; - // "_all_" is a special field that maps to all the ids in the index. - static constexpr const char* SEQ_IDS_FIELD = "_all_"; - static constexpr const char* SEQ_IDS_FILTER = "_all_: 1"; - Index() = delete; Index(const std::string& name, diff --git a/src/filter.cpp b/src/filter.cpp index 18348ed6..95fbfefc 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -283,9 +283,6 @@ Option toFilter(const std::string expression, } } return Option(true); - } else if (field_name == Index::SEQ_IDS_FIELD) { - filter_exp = {field_name, {}, {}}; - return Option(true); } auto field_it = search_schema.find(field_name); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index c4dcd0b5..88485aa4 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -588,18 +588,6 @@ void filter_result_iterator_t::init() { filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); - seq_id = filter_result.docs[result_index]; - is_filter_result_initialized = true; - return; - } else if (a_filter.field_name == Index::SEQ_IDS_FIELD) { - if (index->seq_ids->num_ids() == 0) { - is_valid = false; - return; - } - - approx_filter_ids_length = filter_result.count = index->seq_ids->num_ids(); - filter_result.docs = index->seq_ids->uncompress(); - seq_id 
= filter_result.docs[result_index]; is_filter_result_initialized = true; return; diff --git a/src/index.cpp b/src/index.cpp index 207c44bb..e8634a3b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2389,11 +2389,7 @@ Option Index::search(std::vector& field_query_tokens, cons // if filters were not provided, use the seq_ids index to generate the list of all document ids if (no_filters_provided) { - const std::string doc_id_prefix = std::to_string(collection_id) + "_" + Collection::DOC_ID_PREFIX + "_"; - Option parse_filter_op = filter::parse_filter_query(SEQ_IDS_FILTER, search_schema, - store, doc_id_prefix, filter_tree_root); - - filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root); + filter_result_iterator = new filter_result_iterator_t(seq_ids->uncompress(), seq_ids->num_ids()); filter_iterator_guard.reset(filter_result_iterator); approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; @@ -2514,13 +2510,6 @@ Option Index::search(std::vector& field_query_tokens, cons filter_result_iterator->reset(); } - // filter tree was initialized to have all sequence ids in this flow. - if (no_filters_provided) { - delete filter_tree_root; - filter_tree_root = nullptr; - approx_filter_ids_length = 0; - } - uint32_t _all_result_ids_len = all_result_ids_len; curate_filtered_ids(curated_ids, excluded_result_ids, excluded_result_ids_size, all_result_ids, _all_result_ids_len, curated_ids_sorted); From 5bf8746d0cb07fa655012dbea549a14ab771df3a Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 13:38:18 +0530 Subject: [PATCH 50/93] Fix `HybridSearchRankFusionTest`. 
--- include/index.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/index.h b/include/index.h index 3a511ba0..53c4c561 100644 --- a/include/index.h +++ b/include/index.h @@ -240,6 +240,10 @@ public: filter_result_iterator(filter_result_iterator) {} bool operator()(hnswlib::labeltype id) override { + if (filter_result_iterator->approx_filter_ids_length == 0) { + return true; + } + filter_result_iterator->reset(); return filter_result_iterator->valid(id) == 1; } From 928150008a7ab04b2c553c434d074ddcebea5057 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 8 May 2023 19:12:58 +0530 Subject: [PATCH 51/93] Try merging id lists using priority queue --- include/num_tree.h | 5 +++++ src/filter_result_iterator.cpp | 34 +++++++++++++++++++++++-------- src/num_tree.cpp | 37 ++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/include/num_tree.h b/include/num_tree.h index 444f6266..38c57c8a 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -65,4 +65,9 @@ public: uint32_t* const& context_ids, size_t& result_ids_len, uint32_t*& result_ids) const; + + void merge_id_list_iterators(std::vector& id_list_iterators, + const NUM_COMPARATOR &comparator, + uint32_t*& result_ids, + uint32_t& result_ids_len) const; }; \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 88485aa4..482c7475 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -606,22 +606,40 @@ void filter_result_iterator_t::init() { for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); + std::vector id_list_iterators; + std::vector expanded_id_lists; - size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const 
range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); + num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); } else { - num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], + value, id_list_iterators, expanded_id_lists); } - filter_result.count = result_size; + uint32_t* filter_match_ids = nullptr; + uint32_t filter_ids_length; + num_tree->merge_id_list_iterators(id_list_iterators, a_filter.comparators[fi], + filter_match_ids, filter_ids_length); + + if (a_filter.comparators[fi] == NOT_EQUALS) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_match_ids, filter_ids_length); + } + + uint32_t *out = nullptr; + filter_result.count = ArrayUtils::or_scalar(filter_match_ids, filter_ids_length, + filter_result.docs, filter_result.count, &out); + + delete [] filter_match_ids; + delete [] filter_result.docs; + filter_result.docs = out; + + for(id_list_t* expanded_id_list: expanded_id_lists) { + delete expanded_id_list; + } } if (a_filter.apply_not_equals) { diff --git a/src/num_tree.cpp b/src/num_tree.cpp index 89c5e3a0..d254ca43 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -429,3 +429,40 @@ num_tree_t::~num_tree_t() { ids_t::destroy_list(kv.second); } } + +void num_tree_t::merge_id_list_iterators(std::vector& id_list_iterators, + const NUM_COMPARATOR &comparator, + uint32_t*& result_ids, + uint32_t& result_ids_len) const { + struct comp { + bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { + return 
lhs->id() > rhs->id(); + } + }; + + std::priority_queue, comp> iter_queue; + for (auto& id_list_iterator: id_list_iterators) { + if (id_list_iterator.valid()) { + iter_queue.push(&id_list_iterator); + } + } + + std::vector consolidated_ids; + while (!iter_queue.empty()) { + id_list_t::iterator_t* iter = iter_queue.top(); + iter_queue.pop(); + + consolidated_ids.push_back(iter->id()); + iter->next(); + + if (iter->valid()) { + iter_queue.push(iter); + } + } + + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + result_ids_len = consolidated_ids.size(); + result_ids = new uint32_t[consolidated_ids.size()]; + std::copy(consolidated_ids.begin(), consolidated_ids.end(), result_ids); +} From 548aee4f992b4bcfa9723c9a6e93c3ec434c5858 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 9 May 2023 15:59:36 +0530 Subject: [PATCH 52/93] Optimize string filtering. --- src/filter_result_iterator.cpp | 105 +++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 30 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 482c7475..e58c9c17 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -299,48 +299,93 @@ void filter_result_iterator_t::get_string_filter_next_match(const bool& field_is uint32_t lowest_id = UINT32_MAX; if (filter_node->filter_exp.comparators[0] == EQUALS || filter_node->filter_exp.comparators[0] == NOT_EQUALS) { - for (auto& filter_value_tokens : posting_list_iterators) { - bool tokens_iter_is_valid, exact_match = false; - while(true) { - // Perform AND between tokens of a filter value. - posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + bool exact_match_found = false; + switch (posting_list_iterators.size()) { + case 1: + while(true) { + // Perform AND between tokens of a filter value. 
+ posting_list_t::intersect(posting_list_iterators[0], one_is_valid); - if (!tokens_iter_is_valid) { - break; + if (!one_is_valid) { + break; + } + + if (posting_list_t::has_exact_match(posting_list_iterators[0], field_is_array)) { + exact_match_found = true; + break; + } else { + // Keep advancing token iterators till exact match is not found. + for (auto& iter: posting_list_iterators[0]) { + if (!iter.valid()) { + break; + } + + iter.next(); + } + } } - if (posting_list_t::has_exact_match(filter_value_tokens, field_is_array)) { - exact_match = true; - break; - } else { - // Keep advancing token iterators till exact match is not found. - for (auto &iter: filter_value_tokens) { - if (!iter.valid()) { + if (one_is_valid && exact_match_found) { + lowest_id = posting_list_iterators[0][0].id(); + } + break; + + default : + for (auto& filter_value_tokens : posting_list_iterators) { + bool tokens_iter_is_valid; + while(true) { + // Perform AND between tokens of a filter value. + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + if (!tokens_iter_is_valid) { break; } - iter.next(); + if (posting_list_t::has_exact_match(filter_value_tokens, field_is_array)) { + exact_match_found = true; + break; + } else { + // Keep advancing token iterators till exact match is not found. + for (auto &iter: filter_value_tokens) { + if (!iter.valid()) { + break; + } + + iter.next(); + } + } + } + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && exact_match_found && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); } } - } - - one_is_valid = tokens_iter_is_valid || one_is_valid; - - if (tokens_iter_is_valid && exact_match && filter_value_tokens[0].id() < lowest_id) { - lowest_id = filter_value_tokens[0].id(); - } } } else { - for (auto& filter_value_tokens : posting_list_iterators) { - // Perform AND between tokens of a filter value. 
- bool tokens_iter_is_valid; - posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + switch (posting_list_iterators.size()) { + case 1: + // Perform AND between tokens of a filter value. + posting_list_t::intersect(posting_list_iterators[0], one_is_valid); - one_is_valid = tokens_iter_is_valid || one_is_valid; + if (one_is_valid) { + lowest_id = posting_list_iterators[0][0].id(); + } + break; - if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { - lowest_id = filter_value_tokens[0].id(); - } + default: + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } } } From ef77b58f2b6215de20d8caaa8a9ac8f8bac52179 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 16 May 2023 18:16:43 +0530 Subject: [PATCH 53/93] Undo id list merge using priority queue. 
--- include/num_tree.h | 15 ----- src/filter_result_iterator.cpp | 36 +++-------- src/num_tree.cpp | 115 --------------------------------- 3 files changed, 9 insertions(+), 157 deletions(-) diff --git a/include/num_tree.h b/include/num_tree.h index 38c57c8a..2170a30e 100644 --- a/include/num_tree.h +++ b/include/num_tree.h @@ -30,11 +30,6 @@ public: void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len); - void range_inclusive_search_iterators(int64_t start, - int64_t end, - std::vector& id_list_iterators, - std::vector& expanded_id_lists); - void approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len); void range_inclusive_contains(const int64_t& start, const int64_t& end, @@ -47,11 +42,6 @@ public: void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len); - void search_iterators(NUM_COMPARATOR comparator, - int64_t value, - std::vector& id_list_iterators, - std::vector& expanded_id_lists); - void approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len); void remove(uint64_t value, uint32_t id); @@ -65,9 +55,4 @@ public: uint32_t* const& context_ids, size_t& result_ids_len, uint32_t*& result_ids) const; - - void merge_id_list_iterators(std::vector& id_list_iterators, - const NUM_COMPARATOR &comparator, - uint32_t*& result_ids, - uint32_t& result_ids_len) const; }; \ No newline at end of file diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index e58c9c17..99231679 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -651,40 +651,22 @@ void filter_result_iterator_t::init() { for (size_t fi = 0; fi < a_filter.values.size(); fi++) { const std::string& filter_value = a_filter.values[fi]; int64_t value = (int64_t)std::stol(filter_value); - std::vector id_list_iterators; - std::vector expanded_id_lists; + size_t result_size = filter_result.count; if (a_filter.comparators[fi] == RANGE_INCLUSIVE && 
fi+1 < a_filter.values.size()) { const std::string& next_filter_value = a_filter.values[fi + 1]; auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search_iterators(value, range_end_value, id_list_iterators, expanded_id_lists); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, value, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); } else { - num_tree->search_iterators(a_filter.comparators[fi] == NOT_EQUALS ? EQUALS : a_filter.comparators[fi], - value, id_list_iterators, expanded_id_lists); + num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); } - uint32_t* filter_match_ids = nullptr; - uint32_t filter_ids_length; - num_tree->merge_id_list_iterators(id_list_iterators, a_filter.comparators[fi], - filter_match_ids, filter_ids_length); - - if (a_filter.comparators[fi] == NOT_EQUALS) { - apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_match_ids, filter_ids_length); - } - - uint32_t *out = nullptr; - filter_result.count = ArrayUtils::or_scalar(filter_match_ids, filter_ids_length, - filter_result.docs, filter_result.count, &out); - - delete [] filter_match_ids; - delete [] filter_result.docs; - filter_result.docs = out; - - for(id_list_t* expanded_id_list: expanded_id_lists) { - delete expanded_id_list; - } + filter_result.count = result_size; } if (a_filter.apply_not_equals) { @@ -1400,7 +1382,7 @@ void filter_result_iterator_t::get_n_ids(const uint32_t& n, } filter_result_iterator_t::filter_result_iterator_t(uint32_t approx_filter_ids_length) : - approx_filter_ids_length(approx_filter_ids_length) { + approx_filter_ids_length(approx_filter_ids_length) { filter_node = new filter_node_t(AND, nullptr, nullptr); delete_filter_node = true; } diff --git a/src/num_tree.cpp 
b/src/num_tree.cpp index d254ca43..c59cb008 100644 --- a/src/num_tree.cpp +++ b/src/num_tree.cpp @@ -43,30 +43,6 @@ void num_tree_t::range_inclusive_search(int64_t start, int64_t end, uint32_t** i *ids = out; } -void num_tree_t::range_inclusive_search_iterators(int64_t start, - int64_t end, - std::vector& id_list_iterators, - std::vector& expanded_id_lists) { - if (int64map.empty()) { - return; - } - - auto it_start = int64map.lower_bound(start); // iter values will be >= start - - std::vector raw_id_lists; - while (it_start != int64map.end() && it_start->first <= end) { - raw_id_lists.push_back(it_start->second); - it_start++; - } - - std::vector id_lists; - ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); - - for (const auto &id_list: id_lists) { - id_list_iterators.emplace_back(id_list->new_iterator()); - } -} - void num_tree_t::approx_range_inclusive_search_count(int64_t start, int64_t end, uint32_t& ids_len) { if (int64map.empty()) { return; @@ -211,60 +187,6 @@ void num_tree_t::search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids } } -void num_tree_t::search_iterators(NUM_COMPARATOR comparator, - int64_t value, - std::vector& id_list_iterators, - std::vector& expanded_id_lists) { - if (int64map.empty()) { - return ; - } - - std::vector raw_id_lists; - if (comparator == EQUALS) { - const auto& it = int64map.find(value); - if (it != int64map.end()) { - raw_id_lists.emplace_back(it->second); - } - } else if (comparator == GREATER_THAN || comparator == GREATER_THAN_EQUALS) { - // iter entries will be >= value, or end() if all entries are before value - auto iter_ge_value = int64map.lower_bound(value); - - if(iter_ge_value == int64map.end()) { - return ; - } - - if(comparator == GREATER_THAN && iter_ge_value->first == value) { - iter_ge_value++; - } - - while(iter_ge_value != int64map.end()) { - raw_id_lists.emplace_back(iter_ge_value->second); - iter_ge_value++; - } - } else if(comparator == LESS_THAN || comparator == 
LESS_THAN_EQUALS) { - // iter entries will be >= value, or end() if all entries are before value - auto iter_ge_value = int64map.lower_bound(value); - - auto it = int64map.begin(); - while(it != iter_ge_value) { - raw_id_lists.emplace_back(it->second); - it++; - } - - // for LESS_THAN_EQUALS, check if last iter entry is equal to value - if(it != int64map.end() && comparator == LESS_THAN_EQUALS && it->first == value) { - raw_id_lists.emplace_back(it->second); - } - } - - std::vector id_lists; - ids_t::to_expanded_id_lists(raw_id_lists, id_lists, expanded_id_lists); - - for (const auto &id_list: id_lists) { - id_list_iterators.emplace_back(id_list->new_iterator()); - } -} - void num_tree_t::approx_search_count(NUM_COMPARATOR comparator, int64_t value, uint32_t& ids_len) { if (int64map.empty()) { return; @@ -429,40 +351,3 @@ num_tree_t::~num_tree_t() { ids_t::destroy_list(kv.second); } } - -void num_tree_t::merge_id_list_iterators(std::vector& id_list_iterators, - const NUM_COMPARATOR &comparator, - uint32_t*& result_ids, - uint32_t& result_ids_len) const { - struct comp { - bool operator()(const id_list_t::iterator_t *lhs, const id_list_t::iterator_t *rhs) const { - return lhs->id() > rhs->id(); - } - }; - - std::priority_queue, comp> iter_queue; - for (auto& id_list_iterator: id_list_iterators) { - if (id_list_iterator.valid()) { - iter_queue.push(&id_list_iterator); - } - } - - std::vector consolidated_ids; - while (!iter_queue.empty()) { - id_list_t::iterator_t* iter = iter_queue.top(); - iter_queue.pop(); - - consolidated_ids.push_back(iter->id()); - iter->next(); - - if (iter->valid()) { - iter_queue.push(iter); - } - } - - consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); - - result_ids_len = consolidated_ids.size(); - result_ids = new uint32_t[consolidated_ids.size()]; - std::copy(consolidated_ids.begin(), consolidated_ids.end(), result_ids); -} From be9975ac1d24f95da2564e65513247ece13771d7 Mon Sep 17 
00:00:00 2001 From: Harpreet Sangar Date: Tue, 16 May 2023 18:18:29 +0530 Subject: [PATCH 54/93] Optimize `filter_result_iterator_t::and_filter_iterators`. --- src/filter_result_iterator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 99231679..a55a01de 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -180,7 +180,7 @@ void filter_result_t::or_filter_results(const filter_result_t& a, const filter_r void filter_result_iterator_t::and_filter_iterators() { while (left_it->is_valid && right_it->is_valid) { while (left_it->seq_id < right_it->seq_id) { - left_it->next(); + left_it->skip_to(right_it->seq_id); if (!left_it->is_valid) { is_valid = false; return; @@ -188,7 +188,7 @@ void filter_result_iterator_t::and_filter_iterators() { } while (left_it->seq_id > right_it->seq_id) { - right_it->next(); + right_it->skip_to(left_it->seq_id); if (!right_it->is_valid) { is_valid = false; return; From 90c2b4f7e9014cbd70c58220b44f4a8f425615f8 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 17 May 2023 14:44:27 +0530 Subject: [PATCH 55/93] Fix `filter_result_iterator_t::valid` not updating `seq_id` properly for complex filter expressions. 
--- src/filter_result_iterator.cpp | 23 +++++++++++++++-------- src/or_iterator.cpp | 6 +----- test/filter_test.cpp | 2 +- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index a55a01de..f0145b72 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1047,14 +1047,6 @@ int filter_result_iterator_t::valid(uint32_t id) { if (filter_node->isOperator) { auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); - if (left_it->is_valid && right_it->is_valid) { - seq_id = std::min(left_it->seq_id, right_it->seq_id); - } else if (left_it->is_valid) { - seq_id = left_it->seq_id; - } else if (right_it->is_valid) { - seq_id = right_it->seq_id; - } - if (filter_node->filter_operator == AND) { is_valid = left_it->is_valid && right_it->is_valid; @@ -1063,9 +1055,20 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } + // id did not match the filter but both of the sub-iterators are still valid. + // Updating seq_id to the next potential match. + if (left_valid == 0 && right_valid == 0) { + seq_id = std::max(left_it->seq_id, right_it->seq_id); + } else if (left_valid == 0) { + seq_id = left_it->seq_id; + } else if (right_valid == 0) { + seq_id = right_it->seq_id; + } + return 0; } + seq_id = id; return 1; } else { is_valid = left_it->is_valid || right_it->is_valid; @@ -1075,9 +1078,13 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } + // id did not match the filter but both of the sub-iterators are still valid. + // Next seq_id match would be the minimum of the two. 
+ seq_id = std::min(left_it->seq_id, right_it->seq_id); return 0; } + seq_id = id; return 1; } } diff --git a/src/or_iterator.cpp b/src/or_iterator.cpp index 6384cba2..5a88d68e 100644 --- a/src/or_iterator.cpp +++ b/src/or_iterator.cpp @@ -209,11 +209,7 @@ bool or_iterator_t::take_id(result_iter_state_t& istate, uint32_t id, bool& is_e } if (istate.fit != nullptr && istate.fit->approx_filter_ids_length > 0) { - if (istate.fit->valid(id) == -1) { - return false; - } - - if (istate.fit->seq_id == id) { + if (istate.fit->valid(id) == 1) { istate.fit->next(); return true; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index afce0209..aa10b97d 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -284,7 +284,7 @@ TEST_F(FilterTest, FilterTreeIterator) { auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); - std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 3, 4, 5, 5}; + std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 4, 4, 5, 5}; expected = {1, 0, 1, 0, 1, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); From 2d7fc818d55de02167cca3c5dffaadb27cae0836 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 18 May 2023 07:50:37 +0530 Subject: [PATCH 56/93] Add tests for `filter_result_iterator_t::valid`. 
--- src/filter_result_iterator.cpp | 15 +++++++++---- test/filter_test.cpp | 40 ++++++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index f0145b72..707f2f69 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1061,7 +1061,7 @@ int filter_result_iterator_t::valid(uint32_t id) { seq_id = std::max(left_it->seq_id, right_it->seq_id); } else if (left_valid == 0) { seq_id = left_it->seq_id; - } else if (right_valid == 0) { + } else { seq_id = right_it->seq_id; } @@ -1078,9 +1078,16 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } - // id did not match the filter but both of the sub-iterators are still valid. - // Next seq_id match would be the minimum of the two. - seq_id = std::min(left_it->seq_id, right_it->seq_id); + // id did not match the filter; both of the sub-iterators or one of them might be valid. + // Updating seq_id to the next match. 
+ if (left_valid == 0 && right_valid == 0) { + seq_id = std::min(left_it->seq_id, right_it->seq_id); + } else if (left_valid == 0) { + seq_id = left_it->seq_id; + } else { + seq_id = right_it->seq_id; + } + return 0; } diff --git a/test/filter_test.cpp b/test/filter_test.cpp index aa10b97d..e1a6aaf7 100644 --- a/test/filter_test.cpp +++ b/test/filter_test.cpp @@ -281,14 +281,46 @@ TEST_F(FilterTest, FilterTreeIterator) { filter_tree_root); ASSERT_TRUE(filter_op.ok()); - auto iter_validate_ids_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); - ASSERT_TRUE(iter_validate_ids_test.init_status().ok()); + auto iter_validate_ids_test1 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test1.init_status().ok()); std::vector validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 2, 2, 4, 4, 5, 5}; expected = {1, 0, 1, 0, 1, 1, -1}; for (uint32_t i = 0; i < validate_ids.size(); i++) { - ASSERT_EQ(expected[i], iter_validate_ids_test.valid(validate_ids[i])); - ASSERT_EQ(seq_ids[i], iter_validate_ids_test.seq_id); + ASSERT_EQ(expected[i], iter_validate_ids_test1.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test1.seq_id); + } + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: platinum || name: James", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_test2 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test2.init_status().ok()); + + validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {1, 1, 5, 5, 5, 5, 5}; + expected = {0, 1, 0, 0, 0, 1, -1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_test2.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test2.seq_id); + } + + delete filter_tree_root; + 
filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold && rating: < 6", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_validate_ids_test3 = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root); + ASSERT_TRUE(iter_validate_ids_test3.init_status().ok()); + + validate_ids = {0, 1, 2, 3, 4, 5, 6}, seq_ids = {0, 3, 3, 4, 4, 4, 4}; + expected = {1, 0, 0, 0, 1, -1, -1}; + for (uint32_t i = 0; i < validate_ids.size(); i++) { + ASSERT_EQ(expected[i], iter_validate_ids_test3.valid(validate_ids[i])); + ASSERT_EQ(seq_ids[i], iter_validate_ids_test3.seq_id); } delete filter_tree_root; From 87159b556fd9628ea765608e49b261b32178d77c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 19 May 2023 18:52:20 +0530 Subject: [PATCH 57/93] Add `filter::parse_geopoint_filter_value`. --- include/filter.h | 14 ++++ src/filter.cpp | 151 ++++++++++++++++++++++++++++++++++++++++++- src/string_utils.cpp | 38 +++++++++++ 3 files changed, 202 insertions(+), 1 deletion(-) diff --git a/include/filter.h b/include/filter.h index 1961b0ac..0efd5900 100644 --- a/include/filter.h +++ b/include/filter.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "store.h" enum FILTER_OPERATOR { @@ -26,6 +27,15 @@ struct filter { // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` std::string referenced_collection_name = ""; + std::vector params; + + /// For searching places within a given radius of a given latlong (mi for miles and km for kilometers) + static constexpr const char* GEO_FILTER_RADIUS = "radius"; + + /// Radius threshold beyond which exact filtering on geo_result_ids will not be done. 
+ static constexpr const char* EXACT_GEO_FILTER_RADIUS = "exact_filter_radius"; + static constexpr const char* DEFAULT_EXACT_GEO_FILTER_RADIUS = "10km"; + static const std::string RANGE_OPERATOR() { return ".."; } @@ -39,6 +49,10 @@ struct filter { std::string& processed_filter_val, NUM_COMPARATOR& num_comparator); + static Option parse_geopoint_filter_value(std::string& raw_value, + const std::string& format_err_msg, + filter& filter_exp); + static Option parse_filter_query(const std::string& filter_query, const tsl::htrie_map& search_schema, const Store* store, diff --git a/src/filter.cpp b/src/filter.cpp index 95fbfefc..c7cc73ca 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -143,6 +143,142 @@ Option filter::parse_geopoint_filter_value(std::string& raw_value, return Option(true); } +Option filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) { + // FORMAT: + // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ] + + // Every open parenthesis represent a geo filter value. + auto open_parenthesis_count = std::count(raw_value.begin(), raw_value.end(), '('); + if (open_parenthesis_count < 1) { + return Option(400, format_err_msg); + } + + filter_exp.comparators.push_back(LESS_THAN_EQUALS); + bool is_multivalued = raw_value[0] == '['; + size_t i = is_multivalued; + + // Adding polygonal values at last since they don't have any parameters associated with them. 
+ std::vector polygons; + for (auto j = 0; j < open_parenthesis_count; j++) { + if (is_multivalued) { + auto pos = raw_value.find('(', i); + if (pos == std::string::npos) { + return Option(400, format_err_msg); + } + i = pos; + } + + i++; + if (i >= raw_value.size()) { + return Option(400, format_err_msg); + } + + auto value_end_index = raw_value.find(')', i); + if (value_end_index == std::string::npos) { + return Option(400, format_err_msg); + } + + // [48.853, 2.344], radius: 1km, exact_filter_radius: 100km + // [48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469] + std::string value_str = raw_value.substr(i, value_end_index - i); + StringUtils::trim(value_str); + + if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) { + return Option(400, format_err_msg); + } + + auto points_str = value_str.substr(1, value_str.find(']', 1) - 1); + std::vector geo_points; + StringUtils::split(points_str, geo_points, ","); + + for (const auto& geo_point: geo_points) { + if(!StringUtils::is_float(geo_point)) { + return Option(400, format_err_msg); + } + } + + bool is_polygon = value_str.back() == ']'; + if (is_polygon) { + polygons.push_back(points_str); + continue; + } + + // Handle options. 
+ // , radius: 1km, exact_filter_radius: 100km + i = raw_value.find(']', i); + i++; + + std::vector options; + StringUtils::split(raw_value.substr(i, value_end_index - i), options, ","); + + if (options.empty()) { + // Missing radius option + return Option(400, format_err_msg); + } + + bool is_radius_present = false; + for (auto const& option: options) { + if (option.empty()) { + continue; + } + + std::vector key_value; + StringUtils::split(option, key_value, ":"); + + if (key_value.size() < 2) { + continue; + } + + if (key_value[0] == GEO_FILTER_RADIUS) { + is_radius_present = true; + auto& value = key_value[1]; + + if(value.size() < 2) { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::string unit = value.substr(value.size() - 2, 2); + + if(unit != "km" && unit != "mi") { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::vector dist_values; + StringUtils::split(value, dist_values, unit); + + if(dist_values.size() != 1) { + return Option(400, format_err_msg); + } + + if(!StringUtils::is_float(dist_values[0])) { + return Option(400, format_err_msg); + } + + filter_exp.values.push_back(points_str + ", " + dist_values[0] + ", " + unit); + } else if (key_value[0] == EXACT_GEO_FILTER_RADIUS) { + nlohmann::json param; + param[EXACT_GEO_FILTER_RADIUS] = key_value[1]; + filter_exp.params.push_back(param); + } + } + + if (!is_radius_present) { + return Option(400, format_err_msg); + } + if (filter_exp.params.empty()) { + nlohmann::json param; + param[EXACT_GEO_FILTER_RADIUS] = DEFAULT_EXACT_GEO_FILTER_RADIUS; + filter_exp.params.push_back(param); + } + } + + for (auto const& polygon: polygons) { + filter_exp.values.push_back(polygon); + } + + return Option(true); +} + bool isOperator(const std::string& expression) { return expression == "&&" || expression == "||"; } @@ -383,10 +519,23 @@ Option toFilter(const std::string expression, } } else if (_field.is_geopoint()) { filter_exp = {field_name, {}, {}}; + NUM_COMPARATOR 
num_comparator; + + if ((raw_value[0] == '(' && std::count(raw_value.begin(), raw_value.end(), '[') > 0) || + std::count(raw_value.begin(), raw_value.end(), '[') > 1 || + std::count(raw_value.begin(), raw_value.end(), ':') > 0) { + + const std::string& format_err_msg = "Value of filter field `" + _field.name + "`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format."; + + auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, filter_exp); + return parse_op; + } + const std::string& format_err_msg = "Value of filter field `" + _field.name + "`: must be in the `(-44.50, 170.29, 0.75 km)` or " "(56.33, -65.97, 23.82, -127.82) format."; - NUM_COMPARATOR num_comparator; // could be a single value or a list if (raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') { std::vector filter_values; diff --git a/src/string_utils.cpp b/src/string_utils.cpp index b2caf445..f205d774 100644 --- a/src/string_utils.cpp +++ b/src/string_utils.cpp @@ -349,6 +349,35 @@ size_t StringUtils::get_num_chars(const std::string& s) { return j; } +Option parse_multi_valued_geopoint_filter(const std::string& filter_query, std::string& tokens, size_t& index) { + // Multi-valued geopoint filter. + // field_name:[ ([points], options), ([points]) ] + auto error = Option(400, "Could not parse the geopoint filter."); + if (filter_query[index] != '[') { + return error; + } + + size_t start_index = index; + auto size = filter_query.size(); + + // Individual geopoint filters have square brackets inside them. 
+ int square_bracket_count = 1; + while (++index < size && square_bracket_count > 0) { + if (filter_query[index] == '[') { + square_bracket_count++; + } else if (filter_query[index] == ']') { + square_bracket_count--; + } + } + + if (square_bracket_count != 0) { + return error; + } + + tokens = filter_query.substr(start_index, index - start_index); + return Option(true); +} + Option parse_reference_filter(const std::string& filter_query, std::queue& tokens, size_t& index) { auto error = Option(400, "Could not parse the reference filter."); if (filter_query[index] != '$') { @@ -440,6 +469,15 @@ Option StringUtils::tokenize_filter_query(const std::string& filter_query, if (preceding_colon && c == '(') { is_geo_value = true; preceding_colon = false; + } else if (preceding_colon && c == '[') { + std::string value; + auto op = parse_multi_valued_geopoint_filter(filter_query, value, i); + if (!op.ok()) { + return op; + } + + ss << value; + break; } else if (preceding_colon && c != ' ') { preceding_colon = false; } From de42309059b79ed5e5d753a1844682f5d1c0dd5b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 22 May 2023 12:16:00 +0530 Subject: [PATCH 58/93] Add logic to skip exact geo filtering beyond threshold. 
--- include/filter.h | 8 +-- src/filter.cpp | 82 +++++++++++++++------- src/filter_result_iterator.cpp | 19 ++++- test/collection_filtering_test.cpp | 107 ++++++++++++++++++++++++++++- 4 files changed, 186 insertions(+), 30 deletions(-) diff --git a/include/filter.h b/include/filter.h index 0efd5900..f9f15c1d 100644 --- a/include/filter.h +++ b/include/filter.h @@ -25,16 +25,16 @@ struct filter { bool apply_not_equals = false; // Would store `Foo` in case of a filter expression like `$Foo(bar := baz)` - std::string referenced_collection_name = ""; + std::string referenced_collection_name; std::vector params; /// For searching places within a given radius of a given latlong (mi for miles and km for kilometers) - static constexpr const char* GEO_FILTER_RADIUS = "radius"; + static constexpr const char* GEO_FILTER_RADIUS_KEY = "radius"; /// Radius threshold beyond which exact filtering on geo_result_ids will not be done. - static constexpr const char* EXACT_GEO_FILTER_RADIUS = "exact_filter_radius"; - static constexpr const char* DEFAULT_EXACT_GEO_FILTER_RADIUS = "10km"; + static constexpr const char* EXACT_GEO_FILTER_RADIUS_KEY = "exact_filter_radius"; + static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000; static const std::string RANGE_OPERATOR() { return ".."; diff --git a/src/filter.cpp b/src/filter.cpp index c7cc73ca..3b13b400 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -143,6 +143,33 @@ Option filter::parse_geopoint_filter_value(std::string& raw_value, return Option(true); } +Option validate_geofilter_distance(std::string& raw_value, const string& format_err_msg, + std::string& distance, std::string& unit) { + if (raw_value.size() < 2) { + return Option(400, "Unit must be either `km` or `mi`."); + } + + unit = raw_value.substr(raw_value.size() - 2, 2); + + if (unit != "km" && unit != "mi") { + return Option(400, "Unit must be either `km` or `mi`."); + } + + std::vector dist_values; + StringUtils::split(raw_value, dist_values, unit); + + 
if (dist_values.size() != 1) { + return Option(400, format_err_msg); + } + + if (!StringUtils::is_float(dist_values[0])) { + return Option(400, format_err_msg); + } + + distance = std::string(dist_values[0]); + return Option(true); +} + Option filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) { // FORMAT: // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ] @@ -185,19 +212,27 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) { return Option(400, format_err_msg); + } else { + std::vector filter_values; + StringUtils::split(value_str, filter_values, ","); + + if(filter_values.size() < 3) { + return Option(400, format_err_msg); + } } auto points_str = value_str.substr(1, value_str.find(']', 1) - 1); std::vector geo_points; StringUtils::split(points_str, geo_points, ","); + bool is_polygon = value_str.back() == ']'; for (const auto& geo_point: geo_points) { - if(!StringUtils::is_float(geo_point)) { + if (!StringUtils::is_float(geo_point) || + (!is_polygon && (geo_point == "nan" || geo_point == "NaN"))) { return Option(400, format_err_msg); } } - bool is_polygon = value_str.back() == ']'; if (is_polygon) { polygons.push_back(points_str); continue; @@ -229,35 +264,34 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string continue; } - if (key_value[0] == GEO_FILTER_RADIUS) { + if (key_value[0] == GEO_FILTER_RADIUS_KEY) { is_radius_present = true; - auto& value = key_value[1]; - if(value.size() < 2) { - return Option(400, "Unit must be either `km` or `mi`."); + std::string distance, unit; + auto validate_op = validate_geofilter_distance(key_value[1], format_err_msg, distance, unit); + if (!validate_op.ok()) { + return validate_op; } - std::string unit = value.substr(value.size() - 2, 2); - - 
if(unit != "km" && unit != "mi") { - return Option(400, "Unit must be either `km` or `mi`."); + filter_exp.values.push_back(points_str + ", " + distance + ", " + unit); + } else if (key_value[0] == EXACT_GEO_FILTER_RADIUS_KEY) { + std::string distance, unit; + auto validate_op = validate_geofilter_distance(key_value[1], format_err_msg, distance, unit); + if (!validate_op.ok()) { + return validate_op; } - std::vector dist_values; - StringUtils::split(value, dist_values, unit); + double exact_under_radius = std::stof(distance); - if(dist_values.size() != 1) { - return Option(400, format_err_msg); + if (unit == "km") { + exact_under_radius *= 1000; + } else { + // assume "mi" (validated upstream) + exact_under_radius *= 1609.34; } - if(!StringUtils::is_float(dist_values[0])) { - return Option(400, format_err_msg); - } - - filter_exp.values.push_back(points_str + ", " + dist_values[0] + ", " + unit); - } else if (key_value[0] == EXACT_GEO_FILTER_RADIUS) { nlohmann::json param; - param[EXACT_GEO_FILTER_RADIUS] = key_value[1]; + param[EXACT_GEO_FILTER_RADIUS_KEY] = exact_under_radius; filter_exp.params.push_back(param); } } @@ -265,9 +299,11 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string if (!is_radius_present) { return Option(400, format_err_msg); } - if (filter_exp.params.empty()) { + + // EXACT_GEO_FILTER_RADIUS_KEY was not present. 
+ if (filter_exp.params.size() < filter_exp.values.size()) { nlohmann::json param; - param[EXACT_GEO_FILTER_RADIUS] = DEFAULT_EXACT_GEO_FILTER_RADIUS; + param[EXACT_GEO_FILTER_RADIUS_KEY] = DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE; filter_exp.params.push_back(param); } } diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 707f2f69..115c22c2 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -754,7 +754,9 @@ void filter_result_iterator_t::init() { is_filter_result_initialized = true; return; } else if (f.is_geopoint()) { - for (const std::string& filter_value : a_filter.values) { + for (uint32_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + std::vector geo_result_ids; std::vector filter_value_parts; @@ -763,6 +765,7 @@ void filter_result_iterator_t::init() { bool is_polygon = StringUtils::is_float(filter_value_parts.back()); S2Region* query_region; + double radius; if (is_polygon) { const int num_verts = int(filter_value_parts.size()) / 2; std::vector vertices; @@ -788,7 +791,7 @@ void filter_result_iterator_t::init() { query_region = loop; } } else { - double radius = std::stof(filter_value_parts[2]); + radius = std::stof(filter_value_parts[2]); const auto& unit = filter_value_parts[3]; if (unit == "km") { @@ -820,6 +823,18 @@ void filter_result_iterator_t::init() { gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); + // Skip exact filtering step if query radius is greater than the threshold. 
+ if (!is_polygon && fi < a_filter.params.size() && + radius > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get()) { + uint32_t* out = nullptr; + filter_result.count = ArrayUtils::or_scalar(geo_result_ids.data(), geo_result_ids.size(), + filter_result.docs, filter_result.count, &out); + + delete[] filter_result.docs; + filter_result.docs = out; + continue; + } + // `geo_result_ids` will contain all IDs that are within approximately within query radius // we still need to do another round of exact filtering on them diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index c644d547..b8439b0f 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1049,7 +1049,7 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) { collectionManager.drop_collection("coll_array_fields"); } -TEST_F(CollectionFilteringTest, GeoPointFiltering) { +TEST_F(CollectionFilteringTest, GeoPointFilteringV1) { Collection *coll1; std::vector fields = {field("title", field_types::STRING, false), @@ -1192,6 +1192,111 @@ TEST_F(CollectionFilteringTest, GeoPointFiltering) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionFilteringTest, GeoPointFilteringV2) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", 
"48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(10, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); + 
ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + + // when geo query had NaN + auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + // when geo query does not send radius key + gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) { Collection *coll1; From 4c2811f46d7acd87485b95cb4ad4a06542617956 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 22 May 2023 15:32:34 +0530 Subject: [PATCH 59/93] Add `geo_filtering_test.cpp`. 
--- src/filter_result_iterator.cpp | 3 +- test/collection_filtering_test.cpp | 107 +------ test/geo_filtering_test.cpp | 474 +++++++++++++++++++++++++++++ 3 files changed, 476 insertions(+), 108 deletions(-) create mode 100644 test/geo_filtering_test.cpp diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 115c22c2..d9b56683 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -807,6 +807,7 @@ void filter_result_iterator_t::init() { S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); query_region = new S2Cap(center, query_radius); } + std::unique_ptr query_region_guard(query_region); S2RegionTermIndexer::Options options; options.set_index_contains_points_only(true); @@ -883,8 +884,6 @@ void filter_result_iterator_t::init() { delete[] filter_result.docs; filter_result.docs = out; - - delete query_region; } if (filter_result.count == 0) { diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index b8439b0f..c644d547 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1049,7 +1049,7 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) { collectionManager.drop_collection("coll_array_fields"); } -TEST_F(CollectionFilteringTest, GeoPointFilteringV1) { +TEST_F(CollectionFilteringTest, GeoPointFiltering) { Collection *coll1; std::vector fields = {field("title", field_types::STRING, false), @@ -1192,111 +1192,6 @@ TEST_F(CollectionFilteringTest, GeoPointFilteringV1) { collectionManager.drop_collection("coll1"); } -TEST_F(CollectionFilteringTest, GeoPointFilteringV2) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, 
"points").get(); - } - - std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - - results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - - // pick location close to none of the spots - results = coll1->search("*", - {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(0, results["found"].get()); - - // pick a large radius covering all points - - results = coll1->search("*", - {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - 
ASSERT_EQ(10, results["found"].get()); - - // 1 mile radius - - results = coll1->search("*", - {}, "loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - - ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); - - // when geo query had NaN - auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY); - - ASSERT_FALSE(gop.ok()); - ASSERT_EQ("Value of filter field `loc`: must be in the " - "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); - - // when geo query does not send radius key - gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", - {}, {}, {0}, 10, 1, FREQUENCY); - - ASSERT_FALSE(gop.ok()); - ASSERT_EQ("Value of filter field `loc`: must be in the " - "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); - - collectionManager.drop_collection("coll1"); -} - TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) { Collection *coll1; diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp new file mode 100644 index 00000000..9db67c0b --- /dev/null +++ b/test/geo_filtering_test.cpp @@ -0,0 +1,474 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class GeoFilteringTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_filtering"; + LOG(INFO) << "Truncating and creating: " << 
state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(GeoFilteringTest, GeoPointFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + 
ASSERT_EQ(1, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(10, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + + // when geo query had NaN + auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + // when geo query does not send radius key + gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + + collectionManager.drop_collection("coll1"); +} + 
+TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector>> records = { + { {"Alpha Inc", "Ennore", "13.22112, 80.30511"}, + {"Alpha Inc", "Velachery", "12.98973, 80.23095"} + }, + + { + {"Veera Inc", "Thiruvallur", "13.12752, 79.90136"}, + }, + + { + {"B1 Inc", "Bengaluru", "12.98246, 77.5847"}, + {"B1 Inc", "Hosur", "12.74147, 77.82915"}, + {"B1 Inc", "Vellore", "12.91866, 79.13075"}, + }, + + { + {"M Inc", "Nashik", "20.11282, 73.79458"}, + {"M Inc", "Pune", "18.56309, 73.855"}, + } + }; + + for(size_t i=0; i> lat_lngs; + for(size_t k = 0; k < records[i].size(); k++) { + std::vector lat_lng_str; + StringUtils::split(records[i][k][2], lat_lng_str, ", "); + + std::vector lat_lng = { + std::stod(lat_lng_str[0]), + std::stod(lat_lng_str[1]) + }; + + lat_lngs.push_back(lat_lng); + } + + doc["loc"] = lat_lngs; + auto add_op = coll1->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + } + + // pick a location close to Chennai + auto results = coll1->search("*", + {}, "loc: ([13.12631, 80.20252], radius: 100km, exact_filter_radius: 100km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + // Default value of exact_filter_radius is 10km, exact filtering is not performed. 
+ results = coll1->search("*", + {}, "loc: ([13.12631, 80.20252], radius: 100km,)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + ASSERT_STREQ("2", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: ([13.62601, 79.39559], radius: 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: ([21.20714729927276, 78.99153966917213], radius: 1000 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(4, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: ([12.98941, 80.23073], radius: 1mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringTest, GeoPointRemoval) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc1", field_types::GEOPOINT, false), + field("loc2", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "Palais Garnier"; + doc["loc1"] = {48.872576479306765, 2.332291112241466}; + doc["loc2"] = nlohmann::json::array(); + doc["loc2"][0] = {48.84620987789056, 2.345152755563131}; + doc["points"] = 100; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("*", + {}, "loc1: ([48.87491151802846, 2.343945883701618], radius: 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + 
+ ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: ([48.87491151802846, 2.343945883701618], radius: 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + // remove the document, index another document and try querying again + coll1->remove("0"); + doc["id"] = "1"; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + results = coll1->search("*", + {}, "loc1: ([48.87491151802846, 2.343945883701618], radius: 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: ([48.87491151802846, 2.343945883701618], radius: 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} + +TEST_F(GeoFilteringTest, GeoPolygonFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + 
for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.875223042424125,2.323509661928681, " + "48.85745408145392, 2.3267084486160856, " + "48.859636574404355,2.351469427048221, " + "48.87756059389807, 2.3443610121873206])", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + + // should work even if points of polygon are clockwise + + results = coll1->search("*", + {}, "loc: ([48.87756059389807, 2.3443610121873206, " + "48.859636574404355,2.351469427048221, " + "48.85745408145392, 2.3267084486160856, " + "48.875223042424125,2.323509661928681])", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringTest, GeoPolygonFilteringSouthAmerica) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"North of Equator", "4.48615, -71.38049"}, + {"South of Equator", "-8.48587, -71.02892"}, + }; + + for(size_t i=0; i lat_lng; + 
StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a polygon that covers both points + + auto results = coll1->search("*", + {}, "loc: ([13.3163, -82.3585, " + "-29.134, -82.3585, " + "-29.134, -59.8528, " + "13.3163, -59.8528])", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringTest, GeoPointFilteringWithNonSortableLocationField) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string", "sort": false}, + {"name": "loc", "type": "geopoint", "sort": false}, + {"name": "points", "type": "int32", "sort": false} + ] + })"_json; + + auto coll_op = collectionManager.create_collection(schema); + ASSERT_TRUE(coll_op.ok()); + Collection* coll1 = coll_op.get(); + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: ([48.90615915923891, 2.3435897727061175], radius:3 km)", + {}, {}, {0}, 10, 1, 
FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} From b825ce752a3593e3be7c509f12ee7b0bc790cccb Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 10:09:28 +0530 Subject: [PATCH 60/93] Skip exact filtering beyond threshold in case of geo polygon. --- src/filter.cpp | 61 +++++++++++++++++----------------- src/filter_result_iterator.cpp | 18 +++++----- test/geo_filtering_test.cpp | 44 +++++++++++++++--------- 3 files changed, 69 insertions(+), 54 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index 3b13b400..5946e5d1 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -172,7 +172,8 @@ Option validate_geofilter_distance(std::string& raw_value, const string& f Option filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) { // FORMAT: - // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ] + // [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), + // ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469], exact_filter_radius: 100km) ] // Every open parenthesis represent a geo filter value. auto open_parenthesis_count = std::count(raw_value.begin(), raw_value.end(), '('); @@ -181,11 +182,9 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } filter_exp.comparators.push_back(LESS_THAN_EQUALS); - bool is_multivalued = raw_value[0] == '['; + bool is_multivalued = open_parenthesis_count > 1; size_t i = is_multivalued; - // Adding polygonal values at last since they don't have any parameters associated with them. 
- std::vector polygons; for (auto j = 0; j < open_parenthesis_count; j++) { if (is_multivalued) { auto pos = raw_value.find('(', i); @@ -206,57 +205,55 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } // [48.853, 2.344], radius: 1km, exact_filter_radius: 100km - // [48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469] + // [48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469], exact_filter_radius: 100km std::string value_str = raw_value.substr(i, value_end_index - i); StringUtils::trim(value_str); if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) { return Option(400, format_err_msg); - } else { - std::vector filter_values; - StringUtils::split(value_str, filter_values, ","); - - if(filter_values.size() < 3) { - return Option(400, format_err_msg); - } } auto points_str = value_str.substr(1, value_str.find(']', 1) - 1); std::vector geo_points; StringUtils::split(points_str, geo_points, ","); - bool is_polygon = value_str.back() == ']'; + if (geo_points.size() < 2 || geo_points.size() % 2) { + return Option(400, format_err_msg); + } + + bool is_polygon = geo_points.size() > 2; for (const auto& geo_point: geo_points) { - if (!StringUtils::is_float(geo_point) || - (!is_polygon && (geo_point == "nan" || geo_point == "NaN"))) { + if (geo_point == "nan" || geo_point == "NaN" || !StringUtils::is_float(geo_point)) { return Option(400, format_err_msg); } } if (is_polygon) { - polygons.push_back(points_str); - continue; + filter_exp.values.push_back(points_str); } // Handle options. 
// , radius: 1km, exact_filter_radius: 100km - i = raw_value.find(']', i); - i++; + i = raw_value.find(']', i) + 1; std::vector options; StringUtils::split(raw_value.substr(i, value_end_index - i), options, ","); if (options.empty()) { - // Missing radius option - return Option(400, format_err_msg); + if (!is_polygon) { + // Missing radius option + return Option(400, format_err_msg); + } + + nlohmann::json param; + param[EXACT_GEO_FILTER_RADIUS_KEY] = DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE; + filter_exp.params.push_back(param); + + continue; } bool is_radius_present = false; for (auto const& option: options) { - if (option.empty()) { - continue; - } - std::vector key_value; StringUtils::split(option, key_value, ":"); @@ -264,7 +261,7 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string continue; } - if (key_value[0] == GEO_FILTER_RADIUS_KEY) { + if (key_value[0] == GEO_FILTER_RADIUS_KEY && !is_polygon) { is_radius_present = true; std::string distance, unit; @@ -293,10 +290,16 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string nlohmann::json param; param[EXACT_GEO_FILTER_RADIUS_KEY] = exact_under_radius; filter_exp.params.push_back(param); + + // Only EXACT_GEO_FILTER_RADIUS_KEY option would be present for a polygon. We can also stop if we've + // parsed the radius in case of a single geopoint since there are only two options. 
+ if (is_polygon || is_radius_present) { + break; + } } } - if (!is_radius_present) { + if (!is_radius_present && !is_polygon) { return Option(400, format_err_msg); } @@ -308,10 +311,6 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } } - for (auto const& polygon: polygons) { - filter_exp.values.push_back(polygon); - } - return Option(true); } diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index d9b56683..a6b61134 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -765,7 +765,7 @@ void filter_result_iterator_t::init() { bool is_polygon = StringUtils::is_float(filter_value_parts.back()); S2Region* query_region; - double radius; + double query_radius_meters; if (is_polygon) { const int num_verts = int(filter_value_parts.size()) / 2; std::vector vertices; @@ -790,22 +790,24 @@ void filter_result_iterator_t::init() { } else { query_region = loop; } + + query_radius_meters = S2Earth::RadiansToMeters(query_region->GetCapBound().GetRadius().radians()); } else { - radius = std::stof(filter_value_parts[2]); + query_radius_meters = std::stof(filter_value_parts[2]); const auto& unit = filter_value_parts[3]; if (unit == "km") { - radius *= 1000; + query_radius_meters *= 1000; } else { // assume "mi" (validated upstream) - radius *= 1609.34; + query_radius_meters *= 1609.34; } - S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); + S1Angle query_radius_radians = S1Angle::Radians(S2Earth::MetersToRadians(query_radius_meters)); double query_lat = std::stod(filter_value_parts[0]); double query_lng = std::stod(filter_value_parts[1]); S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); - query_region = new S2Cap(center, query_radius); + query_region = new S2Cap(center, query_radius_radians); } std::unique_ptr query_region_guard(query_region); @@ -825,8 +827,8 @@ void filter_result_iterator_t::init() { geo_result_ids.erase(std::unique( 
geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); // Skip exact filtering step if query radius is greater than the threshold. - if (!is_polygon && fi < a_filter.params.size() && - radius > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get()) { + if (fi < a_filter.params.size() && + query_radius_meters > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get()) { uint32_t* out = nullptr; filter_result.count = ArrayUtils::or_scalar(geo_result_ids.data(), geo_result_ids.size(), filter_result.docs, filter_result.count, &out); diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index 9db67c0b..ab200262 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -85,7 +85,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_EQ(1, results["found"].get()); ASSERT_EQ(1, results["hits"].size()); - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", {}, {}, {0}, 10, 1, FREQUENCY).get(); @@ -115,9 +115,9 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_EQ(3, results["found"].get()); - ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_EQ("6", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("5", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("3", results["hits"][2]["document"]["id"].get()); // when geo query had NaN auto gop = coll1->search("*", {}, "loc: ([NaN, nan], radius: 1 mi)", @@ -206,8 +206,8 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ(2, results["found"].get()); ASSERT_EQ(2, results["hits"].size()); - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - 
ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); // Default value of exact_filter_radius is 10km, exact filtering is not performed. results = coll1->search("*", @@ -217,9 +217,9 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); - ASSERT_STREQ("2", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("1", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_EQ("2", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][2]["document"]["id"].get()); // pick location close to none of the spots results = coll1->search("*", @@ -244,7 +244,7 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ(1, results["found"].get()); - ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); collectionManager.drop_collection("coll1"); } @@ -355,9 +355,9 @@ TEST_F(GeoFilteringTest, GeoPolygonFiltering) { ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); - ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + ASSERT_EQ("8", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("4", results["hits"][1]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][2]["document"]["id"].get()); // should work even if points of polygon are clockwise @@ -389,6 +389,8 @@ TEST_F(GeoFilteringTest, GeoPolygonFilteringSouthAmerica) { std::vector> records = { {"North of Equator", "4.48615, -71.38049"}, {"South of Equator", 
"-8.48587, -71.02892"}, + {"North of Equator, outside polygon", "4.13377, -56.00459"}, + {"South of Equator, outside polygon", "-4.5041, -57.34523"}, }; for(size_t i=0; iadd(doc.dump()).ok()); } - // pick a polygon that covers both points - + // polygon only covers 2 points but all points are returned since exact filtering is not performed. auto results = coll1->search("*", {}, "loc: ([13.3163, -82.3585, " "-29.134, -82.3585, " @@ -417,9 +418,22 @@ TEST_F(GeoFilteringTest, GeoPolygonFilteringSouthAmerica) { "13.3163, -59.8528])", {}, {}, {0}, 10, 1, FREQUENCY).get(); + ASSERT_EQ(4, results["found"].get()); + ASSERT_EQ(4, results["hits"].size()); + + results = coll1->search("*", + {}, "loc: ([13.3163, -82.3585, " + "-29.134, -82.3585, " + "-29.134, -59.8528, " + "13.3163, -59.8528], exact_filter_radius: 2703km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + ASSERT_EQ(2, results["found"].get()); ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + collectionManager.drop_collection("coll1"); } From 2af35a36249b13c1412d5f240471101f25935c23 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:02:47 +0530 Subject: [PATCH 61/93] Add test cases. 
--- src/filter.cpp | 2 +- test/geo_filtering_test.cpp | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index 5946e5d1..f14edb60 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -562,7 +562,7 @@ Option toFilter(const std::string expression, const std::string& format_err_msg = "Value of filter field `" + _field.name + "`: must be in the " "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format."; + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format."; auto parse_op = filter::parse_geopoint_filter_value(raw_value, format_err_msg, filter_exp); return parse_op; diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index ab200262..f8c187f9 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -126,7 +126,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_FALSE(gop.ok()); ASSERT_EQ("Value of filter field `loc`: must be in the " "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); // when geo query does not send radius key gop = coll1->search("*", {}, "loc: ([48.85825332869331, 2.303816427653377])", @@ -135,7 +135,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_FALSE(gop.ok()); ASSERT_EQ("Value of filter field `loc`: must be in the " "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " - "([56.33, -65.97, 23.82, -127.82]) format.", gop.error()); + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); collectionManager.drop_collection("coll1"); } @@ -371,6 +371,21 @@ TEST_F(GeoFilteringTest, GeoPolygonFiltering) { ASSERT_EQ(3, results["found"].get()); ASSERT_EQ(3, results["hits"].size()); + // when geo query had NaN + auto gop = coll1->search("*", {}, 
"loc: ([48.87756059389807, 2.3443610121873206, NaN, nan])", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the " + "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " + "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); + + gop = coll1->search("*", {}, "loc: ([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7k)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Unit must be either `km` or `mi`.", gop.error()); + collectionManager.drop_collection("coll1"); } From ccf1f971939ee568ea60269b5072495c61925a23 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:07:24 +0530 Subject: [PATCH 62/93] Add comment. --- include/filter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/filter.h b/include/filter.h index f9f15c1d..f52b086f 100644 --- a/include/filter.h +++ b/include/filter.h @@ -34,7 +34,7 @@ struct filter { /// Radius threshold beyond which exact filtering on geo_result_ids will not be done. static constexpr const char* EXACT_GEO_FILTER_RADIUS_KEY = "exact_filter_radius"; - static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000; + static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000; // meters static const std::string RANGE_OPERATOR() { return ".."; From 61b813ecf726d3a847cb7bb845b564b83d00ccc1 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:14:59 +0530 Subject: [PATCH 63/93] Allow even a single value with square bracket notation. 
--- src/filter.cpp | 2 +- test/geo_filtering_test.cpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index f14edb60..c152d77f 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -182,7 +182,7 @@ Option filter::parse_geopoint_filter_value(string& raw_value, const string } filter_exp.comparators.push_back(LESS_THAN_EQUALS); - bool is_multivalued = open_parenthesis_count > 1; + bool is_multivalued = raw_value[0] == '['; size_t i = is_multivalued; for (auto j = 0; j < open_parenthesis_count; j++) { diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index f8c187f9..0cd45d69 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -87,6 +87,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + // Multiple queries can be clubbed using square brackets [ filterA, filterB, ... ] results = coll1->search("*", {}, "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]", {}, {}, {0}, 10, 1, FREQUENCY).get(); @@ -94,7 +95,7 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { // pick location close to none of the spots results = coll1->search("*", - {}, "loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)", + {}, "loc: [([48.910544830985785, 2.337218333651177], radius: 2 km)]", {}, {}, {0}, 10, 1, FREQUENCY).get(); ASSERT_EQ(0, results["found"].get()); From fedf8f4ec19a9c6446b3738bab3091d22936bf1c Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 23 May 2023 11:40:26 +0530 Subject: [PATCH 64/93] Add `geo_filtering_old_test.cpp`. 
--- test/collection_filtering_test.cpp | 508 --------------------------- test/geo_filtering_old_test.cpp | 544 +++++++++++++++++++++++++++++ test/geo_filtering_test.cpp | 90 +++++ 3 files changed, 634 insertions(+), 508 deletions(-) create mode 100644 test/geo_filtering_old_test.cpp diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index c644d547..988b035a 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1049,514 +1049,6 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) { collectionManager.drop_collection("coll_array_fields"); } -TEST_F(CollectionFilteringTest, GeoPointFiltering) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - } - - std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - 
ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - - - results = coll1->search("*", {}, "loc: (48.90615, 2.34358, 1 km) || " - "loc: (48.8462, 2.34515, 1 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - - // pick location close to none of the spots - results = coll1->search("*", - {}, "loc: (48.910544830985785, 2.337218333651177, 2 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(0, results["found"].get()); - - // pick a large radius covering all points - - results = coll1->search("*", - {}, "loc: (48.910544830985785, 2.337218333651177, 20 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(10, results["found"].get()); - - // 1 mile radius - - results = coll1->search("*", - {}, "loc: (48.85825332869331, 2.303816427653377, 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - - ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); - - // when geo query had NaN - auto gop = coll1->search("*", {}, "loc: (NaN, nan, 1 mi)", - {}, {}, {0}, 10, 1, FREQUENCY); - - ASSERT_FALSE(gop.ok()); - ASSERT_EQ("Value of filter field `loc`: must be in the `(-44.50, 170.29, 0.75 km)` or " - "(56.33, -65.97, 23.82, -127.82) format.", gop.error()); - - // when geo field is formatted as string, show meaningful error - nlohmann::json bad_doc; - bad_doc["id"] = "1000"; - bad_doc["title"] = "Test record"; - bad_doc["loc"] = {"48.91", "2.33"}; - bad_doc["points"] = 1000; - - auto add_op = 
coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - bad_doc["loc"] = "foobar"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); - - bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2k)"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); - - bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2)"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); - - bad_doc["loc"] = {"foo", "bar"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - bad_doc["loc"] = {"2.33", "bar"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - bad_doc["loc"] = {"foo", "2.33"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); - - // under coercion mode, it should work - bad_doc["loc"] = {"48.91", "2.33"}; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_TRUE(add_op.ok()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT_ARRAY, false), - field("points", 
field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - } - - std::vector>> records = { - { {"Alpha Inc", "Ennore", "13.22112, 80.30511"}, - {"Alpha Inc", "Velachery", "12.98973, 80.23095"} - }, - - { - {"Veera Inc", "Thiruvallur", "13.12752, 79.90136"}, - }, - - { - {"B1 Inc", "Bengaluru", "12.98246, 77.5847"}, - {"B1 Inc", "Hosur", "12.74147, 77.82915"}, - {"B1 Inc", "Vellore", "12.91866, 79.13075"}, - }, - - { - {"M Inc", "Nashik", "20.11282, 73.79458"}, - {"M Inc", "Pune", "18.56309, 73.855"}, - } - }; - - for(size_t i=0; i> lat_lngs; - for(size_t k = 0; k < records[i].size(); k++) { - std::vector lat_lng_str; - StringUtils::split(records[i][k][2], lat_lng_str, ", "); - - std::vector lat_lng = { - std::stod(lat_lng_str[0]), - std::stod(lat_lng_str[1]) - }; - - lat_lngs.push_back(lat_lng); - } - - doc["loc"] = lat_lngs; - auto add_op = coll1->add(doc.dump()); - ASSERT_TRUE(add_op.ok()); - } - - // pick a location close to Chennai - auto results = coll1->search("*", - {}, "loc: (13.12631, 80.20252, 100km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - ASSERT_EQ(2, results["hits"].size()); - - ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); - - // pick location close to none of the spots - results = coll1->search("*", - {}, "loc: (13.62601, 79.39559, 10 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(0, results["found"].get()); - - // pick a large radius covering all points - - results = coll1->search("*", - {}, "loc: (21.20714729927276, 78.99153966917213, 1000 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(4, results["found"].get()); - - // 1 mile radius - - results = coll1->search("*", - {}, "loc: (12.98941, 80.23073, 1mi)", - {}, {}, {0}, 10, 1, 
FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - - ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); - - // when geo field is formatted badly, show meaningful error - nlohmann::json bad_doc; - bad_doc["id"] = "1000"; - bad_doc["title"] = "Test record"; - bad_doc["loc"] = {"48.91", "2.33"}; - bad_doc["points"] = 1000; - - auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must contain 2 element arrays: [ [lat, lng],... ].", add_op.error()); - - bad_doc["loc"] = "foobar"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array.", add_op.error()); - - bad_doc["loc"] = nlohmann::json::array(); - nlohmann::json points = nlohmann::json::array(); - points.push_back("foo"); - points.push_back("bar"); - bad_doc["loc"].push_back(points); - - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); - - bad_doc["loc"][0][0] = "2.33"; - bad_doc["loc"][0][1] = "bar"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); - - bad_doc["loc"][0][0] = "foo"; - bad_doc["loc"][0][1] = "2.33"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_FALSE(add_op.ok()); - ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); - - // under coercion mode, it should work - bad_doc["loc"][0][0] = "48.91"; - bad_doc["loc"][0][1] = "2.33"; - add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); - ASSERT_TRUE(add_op.ok()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPointRemoval) { - std::vector fields = {field("title", 
field_types::STRING, false), - field("loc1", field_types::GEOPOINT, false), - field("loc2", field_types::GEOPOINT_ARRAY, false), - field("points", field_types::INT32, false),}; - - Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - - nlohmann::json doc; - doc["id"] = "0"; - doc["title"] = "Palais Garnier"; - doc["loc1"] = {48.872576479306765, 2.332291112241466}; - doc["loc2"] = nlohmann::json::array(); - doc["loc2"][0] = {48.84620987789056, 2.345152755563131}; - doc["points"] = 100; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - - auto results = coll1->search("*", - {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - results = coll1->search("*", - {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - // remove the document, index another document and try querying again - coll1->remove("0"); - doc["id"] = "1"; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - - results = coll1->search("*", - {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); - - results = coll1->search("*", - {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); -} - -TEST_F(CollectionFilteringTest, GeoPolygonFiltering) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, 
fields, "points").get(); - } - - std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: (48.875223042424125,2.323509661928681, " - "48.85745408145392, 2.3267084486160856, " - "48.859636574404355,2.351469427048221, " - "48.87756059389807, 2.3443610121873206)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - ASSERT_EQ(3, results["hits"].size()); - - ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); - ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); - ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); - - // should work even if points of polygon are clockwise - - results = coll1->search("*", - {}, "loc: (48.87756059389807, 2.3443610121873206, " - "48.859636574404355,2.351469427048221, " - "48.85745408145392, 2.3267084486160856, " - "48.875223042424125,2.323509661928681)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(3, results["found"].get()); - 
ASSERT_EQ(3, results["hits"].size()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPolygonFilteringSouthAmerica) { - Collection *coll1; - - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - coll1 = collectionManager.get_collection("coll1").get(); - if(coll1 == nullptr) { - coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); - } - - std::vector> records = { - {"North of Equator", "4.48615, -71.38049"}, - {"South of Equator", "-8.48587, -71.02892"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a polygon that covers both points - - auto results = coll1->search("*", - {}, "loc: (13.3163, -82.3585, " - "-29.134, -82.3585, " - "-29.134, -59.8528, " - "13.3163, -59.8528)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(2, results["found"].get()); - ASSERT_EQ(2, results["hits"].size()); - - collectionManager.drop_collection("coll1"); -} - -TEST_F(CollectionFilteringTest, GeoPointFilteringWithNonSortableLocationField) { - std::vector fields = {field("title", field_types::STRING, false), - field("loc", field_types::GEOPOINT, false), - field("points", field_types::INT32, false),}; - - nlohmann::json schema = R"({ - "name": "coll1", - "fields": [ - {"name": "title", "type": "string", "sort": false}, - {"name": "loc", "type": "geopoint", "sort": false}, - {"name": "points", "type": "int32", "sort": false} - ] - })"_json; - - auto coll_op = collectionManager.create_collection(schema); - ASSERT_TRUE(coll_op.ok()); - Collection* coll1 = coll_op.get(); - - std::vector> records = { - {"Palais 
Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - }; - - for(size_t i=0; i lat_lng; - StringUtils::split(records[i][1], lat_lng, ", "); - - double lat = std::stod(lat_lng[0]); - double lng = std::stod(lat_lng[1]); - - doc["id"] = std::to_string(i); - doc["title"] = records[i][0]; - doc["loc"] = {lat, lng}; - doc["points"] = i; - - ASSERT_TRUE(coll1->add(doc.dump()).ok()); - } - - // pick a location close to only the Sacre Coeur - auto results = coll1->search("*", - {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", - {}, {}, {0}, 10, 1, FREQUENCY).get(); - - ASSERT_EQ(1, results["found"].get()); - ASSERT_EQ(1, results["hits"].size()); -} - TEST_F(CollectionFilteringTest, FilteringWithPrefixSearch) { Collection *coll1; diff --git a/test/geo_filtering_old_test.cpp b/test/geo_filtering_old_test.cpp new file mode 100644 index 00000000..f5f37684 --- /dev/null +++ b/test/geo_filtering_old_test.cpp @@ -0,0 +1,544 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class GeoFilteringOldTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_filtering"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(GeoFilteringOldTest, GeoPointFiltering) { + Collection *coll1; + + std::vector fields = 
{field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + + + results = coll1->search("*", {}, "loc: (48.90615, 2.34358, 1 km) || " + "loc: (48.8462, 2.34515, 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: (48.910544830985785, 2.337218333651177, 2 km)", + {}, {}, {0}, 10, 1, 
FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: (48.910544830985785, 2.337218333651177, 20 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(10, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: (48.85825332869331, 2.303816427653377, 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + + ASSERT_STREQ("6", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("5", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("3", results["hits"][2]["document"]["id"].get().c_str()); + + // when geo query had NaN + auto gop = coll1->search("*", {}, "loc: (NaN, nan, 1 mi)", + {}, {}, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(gop.ok()); + ASSERT_EQ("Value of filter field `loc`: must be in the `(-44.50, 170.29, 0.75 km)` or " + "(56.33, -65.97, 23.82, -127.82) format.", gop.error()); + + // when geo field is formatted as string, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2k)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", 
DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = {"foo", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = {"2.33", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = {"foo", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"] = {"48.91", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPointArrayFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector>> records = { + { {"Alpha Inc", "Ennore", "13.22112, 80.30511"}, + {"Alpha Inc", "Velachery", "12.98973, 80.23095"} + }, + + { + {"Veera Inc", "Thiruvallur", "13.12752, 79.90136"}, + }, + + { + {"B1 Inc", "Bengaluru", "12.98246, 77.5847"}, + {"B1 Inc", "Hosur", "12.74147, 77.82915"}, + {"B1 Inc", "Vellore", "12.91866, 79.13075"}, + }, + + { + {"M Inc", "Nashik", "20.11282, 73.79458"}, + {"M Inc", "Pune", "18.56309, 73.855"}, + } + }; + + for(size_t i=0; i> lat_lngs; + for(size_t k = 0; k < records[i].size(); k++) { + std::vector 
lat_lng_str; + StringUtils::split(records[i][k][2], lat_lng_str, ", "); + + std::vector lat_lng = { + std::stod(lat_lng_str[0]), + std::stod(lat_lng_str[1]) + }; + + lat_lngs.push_back(lat_lng); + } + + doc["loc"] = lat_lngs; + auto add_op = coll1->add(doc.dump()); + ASSERT_TRUE(add_op.ok()); + } + + // pick a location close to Chennai + auto results = coll1->search("*", + {}, "loc: (13.12631, 80.20252, 100km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][1]["document"]["id"].get().c_str()); + + // pick location close to none of the spots + results = coll1->search("*", + {}, "loc: (13.62601, 79.39559, 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(0, results["found"].get()); + + // pick a large radius covering all points + + results = coll1->search("*", + {}, "loc: (21.20714729927276, 78.99153966917213, 1000 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(4, results["found"].get()); + + // 1 mile radius + + results = coll1->search("*", + {}, "loc: (12.98941, 80.23073, 1mi)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + + ASSERT_STREQ("0", results["hits"][0]["document"]["id"].get().c_str()); + + // when geo field is formatted badly, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must contain 2 element arrays: [ [lat, lng],... 
].", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array.", add_op.error()); + + bad_doc["loc"] = nlohmann::json::array(); + nlohmann::json points = nlohmann::json::array(); + points.push_back("foo"); + points.push_back("bar"); + bad_doc["loc"].push_back(points); + + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "2.33"; + bad_doc["loc"][0][1] = "bar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "foo"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"][0][0] = "48.91"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPointRemoval) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc1", field_types::GEOPOINT, false), + field("loc2", field_types::GEOPOINT_ARRAY, false), + field("points", field_types::INT32, false),}; + + Collection* coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + + nlohmann::json doc; + doc["id"] = "0"; + doc["title"] = "Palais Garnier"; + doc["loc1"] = {48.872576479306765, 2.332291112241466}; + doc["loc2"] = nlohmann::json::array(); + doc["loc2"][0] = {48.84620987789056, 2.345152755563131}; + doc["points"] = 100; + + 
ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + auto results = coll1->search("*", + {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + // remove the document, index another document and try querying again + coll1->remove("0"); + doc["id"] = "1"; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + + results = coll1->search("*", + {}, "loc1: (48.87491151802846, 2.343945883701618, 1 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + + results = coll1->search("*", + {}, "loc2: (48.87491151802846, 2.343945883701618, 10 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} + +TEST_F(GeoFilteringOldTest, GeoPolygonFiltering) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 
2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto results = coll1->search("*", + {}, "loc: (48.875223042424125,2.323509661928681, " + "48.85745408145392, 2.3267084486160856, " + "48.859636574404355,2.351469427048221, " + "48.87756059389807, 2.3443610121873206)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + ASSERT_STREQ("8", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("4", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("0", results["hits"][2]["document"]["id"].get().c_str()); + + // should work even if points of polygon are clockwise + + results = coll1->search("*", + {}, "loc: (48.87756059389807, 2.3443610121873206, " + "48.859636574404355,2.351469427048221, " + "48.85745408145392, 2.3267084486160856, " + "48.875223042424125,2.323509661928681)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPolygonFilteringSouthAmerica) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"North of 
Equator", "4.48615, -71.38049"}, + {"South of Equator", "-8.48587, -71.02892"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a polygon that covers both points + + auto results = coll1->search("*", + {}, "loc: (13.3163, -82.3585, " + "-29.134, -82.3585, " + "-29.134, -59.8528, " + "13.3163, -59.8528)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(2, results["found"].get()); + ASSERT_EQ(2, results["hits"].size()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(GeoFilteringOldTest, GeoPointFilteringWithNonSortableLocationField) { + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + nlohmann::json schema = R"({ + "name": "coll1", + "fields": [ + {"name": "title", "type": "string", "sort": false}, + {"name": "loc", "type": "geopoint", "sort": false}, + {"name": "points", "type": "int32", "sort": false} + ] + })"_json; + + auto coll_op = collectionManager.create_collection(schema); + ASSERT_TRUE(coll_op.ok()); + Collection* coll1 = coll_op.get(); + + std::vector> records = { + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + }; + + for(size_t i=0; i lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + // pick a location close to only the Sacre Coeur + auto 
results = coll1->search("*", + {}, "loc: (48.90615915923891, 2.3435897727061175, 3 km)", + {}, {}, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); +} \ No newline at end of file diff --git a/test/geo_filtering_test.cpp b/test/geo_filtering_test.cpp index 0cd45d69..526d44ae 100644 --- a/test/geo_filtering_test.cpp +++ b/test/geo_filtering_test.cpp @@ -138,6 +138,52 @@ TEST_F(GeoFilteringTest, GeoPointFiltering) { "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or " "([56.33, -65.97, 23.82, -127.82], exact_filter_radius: 7 km) format.", gop.error()); + // when geo field is formatted as string, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2k)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = "loc: (48.910544830985785, 2.337218333651177, 2)"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a 2 element array: [lat, lng].", add_op.error()); + + bad_doc["loc"] = {"foo", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = 
{"2.33", "bar"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + bad_doc["loc"] = {"foo", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be a geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"] = {"48.91", "2.33"}; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + collectionManager.drop_collection("coll1"); } @@ -247,6 +293,50 @@ TEST_F(GeoFilteringTest, GeoPointArrayFiltering) { ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + // when geo field is formatted badly, show meaningful error + nlohmann::json bad_doc; + bad_doc["id"] = "1000"; + bad_doc["title"] = "Test record"; + bad_doc["loc"] = {"48.91", "2.33"}; + bad_doc["points"] = 1000; + + auto add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must contain 2 element arrays: [ [lat, lng],... 
].", add_op.error()); + + bad_doc["loc"] = "foobar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array.", add_op.error()); + + bad_doc["loc"] = nlohmann::json::array(); + nlohmann::json points = nlohmann::json::array(); + points.push_back("foo"); + points.push_back("bar"); + bad_doc["loc"].push_back(points); + + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "2.33"; + bad_doc["loc"][0][1] = "bar"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + bad_doc["loc"][0][0] = "foo"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_FALSE(add_op.ok()); + ASSERT_EQ("Field `loc` must be an array of geopoint.", add_op.error()); + + // under coercion mode, it should work + bad_doc["loc"][0][0] = "48.91"; + bad_doc["loc"][0][1] = "2.33"; + add_op = coll1->add(bad_doc.dump(), CREATE, "", DIRTY_VALUES::COERCE_OR_REJECT); + ASSERT_TRUE(add_op.ok()); + collectionManager.drop_collection("coll1"); } From c9180a0541aadb618a6b6369e131ef747ec9b981 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 26 May 2023 16:00:22 +0530 Subject: [PATCH 65/93] Add `NumericTrie`. 
--- include/numeric_range_trie_test.h | 60 +++++++++++++++ src/numeric_range_trie.cpp | 122 ++++++++++++++++++++++++++++++ test/numeric_range_trie_test.cpp | 41 ++++++++++ 3 files changed, 223 insertions(+) create mode 100644 include/numeric_range_trie_test.h create mode 100644 src/numeric_range_trie.cpp create mode 100644 test/numeric_range_trie_test.cpp diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h new file mode 100644 index 00000000..ce1da83d --- /dev/null +++ b/include/numeric_range_trie_test.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include "sorted_array.h" + +constexpr char MAX_LEVEL = 4; +constexpr short EXPANSE = 256; + +class NumericTrieNode { + NumericTrieNode** children = nullptr; + sorted_array seq_ids; + + void insert(const int32_t& value, const uint32_t& seq_id, char& level); + + void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); +public: + + ~NumericTrieNode() { + if (children != nullptr) { + for (auto i = 0; i < EXPANSE; i++) { + delete children[i]; + } + } + + delete [] children; + } + + void insert(const int32_t& value, const uint32_t& seq_id); + + void search_range(const int32_t& low,const int32_t& high, + uint32_t*& ids, uint32_t& ids_length); + + void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + + void search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); +}; + +class NumericTrie { + NumericTrieNode* negative_trie = nullptr; + NumericTrieNode* positive_trie = nullptr; + +public: + + ~NumericTrie() { + delete negative_trie; + delete positive_trie; + } + + void insert(const int32_t& value, const uint32_t& seq_id); + + void search_range(const int32_t& low, const bool& low_inclusive, + const int32_t& high, const bool& high_inclusive, + uint32_t*& ids, uint32_t& ids_length); + + void search_lesser(const int32_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); + + void search_greater(const int32_t& 
value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); +}; diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp new file mode 100644 index 00000000..c62765a5 --- /dev/null +++ b/src/numeric_range_trie.cpp @@ -0,0 +1,122 @@ +#include "numeric_range_trie_test.h" +#include "array_utils.h" + +void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { + if (value < 0) { + if (negative_trie == nullptr) { + negative_trie = new NumericTrieNode(); + } + + negative_trie->insert(std::abs(value), seq_id); + } else { + if (positive_trie == nullptr) { + positive_trie = new NumericTrieNode(); + } + + positive_trie->insert(value, seq_id); + } +} + +void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, + const int32_t& high, const bool& high_inclusive, + uint32_t*& ids, uint32_t& ids_length) { + if (low < 0 && high >= 0) { + // Have to combine the results of >low from negative_trie and low from negative_trie. + negative_trie->search_lesser(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + } + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + if (!(high == 0 && !high_inclusive)) { + positive_trie->search_lesser(high_inclusive ? 
high : high - 1, positive_ids, positive_ids_length); + } + + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + + delete [] negative_ids; + delete [] positive_ids; + return; + } +} + +void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { + char level = 0; + return insert(value, seq_id, level); +} + +inline int get_index(const int32_t& value, char& level) { + return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; +} + +void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& level) { + if (level > MAX_LEVEL) { + return; + } + + if (!seq_ids.contains(seq_id)) { + seq_ids.append(seq_id); + } + + if (++level <= MAX_LEVEL) { + if (children == nullptr) { + children = new NumericTrieNode* [EXPANSE]{nullptr}; + } + + auto index = get_index(value, level); + if (children[index] == nullptr) { + children[index] = new NumericTrieNode(); + } + + return children[index]->insert(value, seq_id, level); + } +} + +void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { + char level = 0; + std::vector matches; + search_lesser_helper(value, level, matches); + + for (auto const& match: matches) { + uint32_t* out = nullptr; + auto const& m_seq_ids = match->seq_ids.uncompress(); + ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + + delete [] m_seq_ids; + delete [] ids; + ids = out; + } +} + +void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, std::vector& matches) { + if (level > MAX_LEVEL) { + return; + } else if (level == MAX_LEVEL) { + matches.push_back(this); + return; + } + + if (children == nullptr) { + return; + } + + auto index = get_index(value, ++level); + if (children[index] == nullptr) { + return; + } + + children[index]->search_lesser_helper(value, level, matches); + + while (--index >= 0) { + if (children[index] != nullptr) { + 
matches.push_back(children[index]); + } + } + + --level; +} diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp new file mode 100644 index 00000000..b21b659a --- /dev/null +++ b/test/numeric_range_trie_test.cpp @@ -0,0 +1,41 @@ +#include +#include "numeric_range_trie_test.h" + +class NumericRangeTrieTest : public ::testing::Test { +protected: + + virtual void SetUp() {} + + virtual void TearDown() {} +}; + +TEST_F(NumericRangeTrieTest, Insert) { + auto trie = new NumericTrie(); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32768, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_range(-32768, true, 32768, true, ids, ids_length); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + delete [] ids; + delete trie; +} \ No newline at end of file From 8ec5bd2efe3d24e17803ff11e451662f2204791d Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 29 May 2023 12:36:13 +0530 Subject: [PATCH 66/93] Add test cases. --- src/numeric_range_trie.cpp | 18 +++----- test/numeric_range_trie_test.cpp | 78 ++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index c62765a5..d7476a28 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -25,7 +25,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, uint32_t* negative_ids = nullptr; uint32_t negative_ids_length = 0; - if (!(low == -1 && !low_inclusive)) { + if (!(low == -1 && !low_inclusive)) { // No need to search for (-1, ... auto abs_low = std::abs(low); // Since we store absolute values, search_lesser would yield result for >low from negative_trie. 
negative_trie->search_lesser(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); @@ -33,7 +33,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - if (!(high == 0 && !high_inclusive)) { + if (!(high == 0 && !high_inclusive)) { // No need to search for ..., 0) positive_trie->search_lesser(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); } @@ -94,24 +94,18 @@ void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32 } void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, std::vector& matches) { - if (level > MAX_LEVEL) { - return; - } else if (level == MAX_LEVEL) { + if (level == MAX_LEVEL) { matches.push_back(this); return; - } - - if (children == nullptr) { + } else if (level > MAX_LEVEL || children == nullptr) { return; } auto index = get_index(value, ++level); - if (children[index] == nullptr) { - return; + if (children[index] != nullptr) { + children[index]->search_lesser_helper(value, level, matches); } - children[index]->search_lesser_helper(value, level, matches); - while (--index >= 0) { if (children[index] != nullptr) { matches.push_back(children[index]); diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index b21b659a..526cdb32 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -9,8 +9,9 @@ protected: virtual void TearDown() {} }; -TEST_F(NumericRangeTrieTest, Insert) { +TEST_F(NumericRangeTrieTest, SearchRange) { auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); std::vector> pairs = { {-8192, 8}, {-16384, 32}, @@ -30,12 +31,81 @@ TEST_F(NumericRangeTrieTest, Insert) { uint32_t ids_length = 0; trie->search_range(-32768, true, 32768, true, ids, ids_length); + std::unique_ptr ids_guard(ids); ASSERT_EQ(pairs.size(), ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { 
ASSERT_EQ(pairs[i].second, ids[i]); } - delete [] ids; - delete trie; -} \ No newline at end of file + trie->search_range(-32768, true, 32768, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(pairs.size() - 1, ids_length); + for (uint32_t i = 0; i < pairs.size() - 1; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_range(-32768, true, 134217728, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_range(-32768, true, 0, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < 4; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_range(-32768, true, 0, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < 4; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_range(-32768, false, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(pairs.size() - 1, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + trie->search_range(-134217728, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_range(-1, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_range(-1, false, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_range(-1, true, 0, true, ids, ids_length); + ASSERT_EQ(0, 
ids_length); + + trie->search_range(-1, false, 0, false, ids, ids_length); + ASSERT_EQ(0, ids_length); +} From 8520e785cc1ae4bc2efa4dbd4ee6876b99254b79 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 29 May 2023 14:44:55 +0530 Subject: [PATCH 67/93] Add `NumericTrie::search_greater`. --- include/numeric_range_trie_test.h | 5 ++ src/numeric_range_trie.cpp | 90 +++++++++++++++++++++++++++++++ test/numeric_range_trie_test.cpp | 80 +++++++++++++++++++++++++++ 3 files changed, 175 insertions(+) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index ce1da83d..b3d4436c 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -13,6 +13,9 @@ class NumericTrieNode { void insert(const int32_t& value, const uint32_t& seq_id, char& level); void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); + + void search_greater_helper(const int32_t& value, char& level, std::vector& matches); + public: ~NumericTrieNode() { @@ -27,6 +30,8 @@ public: void insert(const int32_t& value, const uint32_t& seq_id); + void get_all_ids(uint32_t*& ids, uint32_t& ids_length); + void search_range(const int32_t& low,const int32_t& high, uint32_t*& ids, uint32_t& ids_length); diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index d7476a28..50cf3315 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -39,6 +39,41 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + delete [] negative_ids; + delete [] positive_ids; + return; + } else if (low >= 0) { + // Search only in positive_trie + } else { + // Search only in negative_trie + } +} + +void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { + if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // 
[0, ∞), (-1, ∞) + positive_trie->get_all_ids(ids, ids_length); + return; + } + + if (value >= 0) { + uint32_t* positive_ids = nullptr; + positive_trie->search_greater(inclusive ? value : value + 1, positive_ids, ids_length); + ids = positive_ids; + } else { + // Have to combine the results of >value from negative_trie and all the ids in positive_trie + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + auto abs_low = std::abs(value); + // Since we store absolute values, search_lesser would yield result for >low from negative_trie. + negative_trie->search_lesser(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->get_all_ids(positive_ids, positive_ids_length); + + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + delete [] negative_ids; delete [] positive_ids; return; @@ -51,6 +86,13 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { } inline int get_index(const int32_t& value, char& level) { + // Values are index considering higher order of the bytes first. + // 0x01020408 (16909320) would be indexed in the trie as follows: + // Level Index + // 1 1 + // 2 2 + // 3 4 + // 4 8 return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; } @@ -59,6 +101,7 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& return; } + // Root node contains all the sequence ids present in the tree. 
if (!seq_ids.contains(seq_id)) { seq_ids.append(seq_id); } @@ -77,6 +120,11 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& } } +void NumericTrieNode::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { + ids = seq_ids.uncompress(); + ids_length = seq_ids.getLength(); +} + void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; @@ -114,3 +162,45 @@ void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, st --level; } + +void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { + +} + +void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { + char level = 0; + std::vector matches; + search_greater_helper(value, level, matches); + + for (auto const& match: matches) { + uint32_t* out = nullptr; + auto const& m_seq_ids = match->seq_ids.uncompress(); + ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + + delete [] m_seq_ids; + delete [] ids; + ids = out; + } +} + +void NumericTrieNode::search_greater_helper(const int32_t& value, char& level, std::vector& matches) { + if (level == MAX_LEVEL) { + matches.push_back(this); + return; + } else if (level > MAX_LEVEL || children == nullptr) { + return; + } + + auto index = get_index(value, ++level); + if (children[index] != nullptr) { + children[index]->search_greater_helper(value, level, matches); + } + + while (++index < EXPANSE) { + if (children[index] != nullptr) { + matches.push_back(children[index]); + } + } + + --level; +} diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 526cdb32..690209a8 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -109,3 +109,83 @@ TEST_F(NumericRangeTrieTest, SearchRange) { trie->search_range(-1, false, 0, false, ids, ids_length); 
ASSERT_EQ(0, ids_length); } + +TEST_F(NumericRangeTrieTest, SearchGreater) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32768, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_greater(0, true, ids, ids_length); + std::unique_ptr ids_guard(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(-1, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(-1, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(-24576, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(7, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + trie->search_greater(-32768, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(7, ids_length); + for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { + if (i == 3) continue; // id for -32768 would not be present + ASSERT_EQ(pairs[i].second, ids[j++]); + } + + trie->search_greater(8192, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_greater(8192, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(3, ids_length); + for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) { + 
ASSERT_EQ(pairs[i].second, ids[j]); + } +} From 06b46ab96100cba9b99022ded1daa5e713efdb55 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 29 May 2023 16:30:12 +0530 Subject: [PATCH 68/93] Add test case. --- src/numeric_range_trie.cpp | 5 ++++- test/numeric_range_trie_test.cpp | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 50cf3315..53090b5d 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -57,7 +57,10 @@ void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, ui if (value >= 0) { uint32_t* positive_ids = nullptr; - positive_trie->search_greater(inclusive ? value : value + 1, positive_ids, ids_length); + uint32_t positive_ids_length = 0; + positive_trie->search_greater(inclusive ? value : value + 1, positive_ids, positive_ids_length); + + ids_length = positive_ids_length; ids = positive_ids; } else { // Have to combine the results of >value from negative_trie and all the ids in positive_trie diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 690209a8..bc47f35a 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -188,4 +188,9 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + + trie->search_greater(100000, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); } From 64ec856097a9a37f15d4397ea57686c09024cb16 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 29 May 2023 17:32:48 +0530 Subject: [PATCH 69/93] Add `NumericTrie::search_lesser`. 
--- src/numeric_range_trie.cpp | 66 ++++++++++++-- test/numeric_range_trie_test.cpp | 146 ++++++++++++++++++++++++++++++- 2 files changed, 202 insertions(+), 10 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 53090b5d..254bffd5 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -25,7 +25,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, uint32_t* negative_ids = nullptr; uint32_t negative_ids_length = 0; - if (!(low == -1 && !low_inclusive)) { // No need to search for (-1, ... + if (negative_trie != nullptr && !(low == -1 && !low_inclusive)) { // No need to search for (-1, ... auto abs_low = std::abs(low); // Since we store absolute values, search_lesser would yield result for >low from negative_trie. negative_trie->search_lesser(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); @@ -33,7 +33,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - if (!(high == 0 && !high_inclusive)) { // No need to search for ..., 0) + if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) positive_trie->search_lesser(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); } @@ -51,14 +51,18 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) - positive_trie->get_all_ids(ids, ids_length); + if (positive_trie != nullptr) { + positive_trie->get_all_ids(ids, ids_length); + } return; } if (value >= 0) { uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_greater(inclusive ? 
value : value + 1, positive_ids, positive_ids_length); + if (positive_trie != nullptr) { + positive_trie->search_greater(inclusive ? value : value + 1, positive_ids, positive_ids_length); + } ids_length = positive_ids_length; ids = positive_ids; @@ -67,13 +71,17 @@ void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, ui uint32_t* negative_ids = nullptr; uint32_t negative_ids_length = 0; - auto abs_low = std::abs(value); - // Since we store absolute values, search_lesser would yield result for >low from negative_trie. - negative_trie->search_lesser(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + negative_trie->search_lesser(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + } uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->get_all_ids(positive_ids, positive_ids_length); + if (positive_trie != nullptr) { + positive_trie->get_all_ids(positive_ids, positive_ids_length); + } ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); @@ -83,6 +91,48 @@ void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, ui } } +void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { + if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] + if (negative_trie != nullptr) { + negative_trie->get_all_ids(ids, ids_length); + } + return; + } + + if (value < 0) { + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + // Since we store absolute values, search_greater would yield result for search_greater(inclusive ? 
abs_low : abs_low + 1, negative_ids, negative_ids_length); + } + + ids_length = negative_ids_length; + ids = negative_ids; + } else { + // Have to combine the results of search_lesser(inclusive ? value : value - 1, positive_ids, positive_ids_length); + } + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + if (negative_trie != nullptr) { + negative_trie->get_all_ids(negative_ids, negative_ids_length); + } + + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, negative_ids, negative_ids_length, &ids); + + delete [] negative_ids; + delete [] positive_ids; + return; + } +} + void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { char level = 0; return insert(value, seq_id, level); diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index bc47f35a..6b6fa45f 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -23,7 +23,7 @@ TEST_F(NumericRangeTrieTest, SearchRange) { {32768, 91} }; - for (auto const pair: pairs) { + for (auto const& pair: pairs) { trie->insert(pair.first, pair.second); } @@ -124,7 +124,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { {32768, 91} }; - for (auto const pair: pairs) { + for (auto const& pair: pairs) { trie->insert(pair.first, pair.second); } @@ -193,4 +193,146 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + + trie->search_greater(-100000, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(8, ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } +} + +TEST_F(NumericRangeTrieTest, SearchLesser) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* 
ids = nullptr; + uint32_t ids_length = 0; + + trie->search_lesser(0, true, ids, ids_length); + std::unique_ptr ids_guard(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_lesser(0, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_lesser(-1, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_lesser(-16384, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(3, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_lesser(-16384, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(2, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_lesser(8192, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_lesser(8192, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + + trie->search_lesser(-100000, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_lesser(100000, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(8, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } +} + +TEST_F(NumericRangeTrieTest, Validation) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_range(-32768, true, 32768, true, ids, ids_length); + 
std::unique_ptr ids_guard(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_greater(0, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_greater(15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_greater(-15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_lesser(0, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_lesser(-15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_lesser(15, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); } From 596f77898e8967061a13a8b75b39c3aaffd94932 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 30 May 2023 09:56:09 +0530 Subject: [PATCH 70/93] Add test case. --- src/numeric_range_trie.cpp | 4 ++++ test/numeric_range_trie_test.cpp | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 254bffd5..be0c7f2c 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -20,6 +20,10 @@ void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, const int32_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length) { + if (low >= high) { + return; + } + if (low < 0 && high >= 0) { // Have to combine the results of >low from negative_trie and search_range(-32768, true, 32768, true, ids, ids_length); + trie->search_range(32768, true, -32768, true, ids, ids_length); std::unique_ptr ids_guard(ids); + ASSERT_EQ(0, ids_length); + + trie->search_range(-32768, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + ASSERT_EQ(pairs.size(), ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); From 
29c58cee0fe4ddc428982cb8b314236347a2fc41 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 30 May 2023 18:04:55 +0530 Subject: [PATCH 71/93] Add `NumericTrieNode::search_range`. --- include/numeric_range_trie_test.h | 6 +- src/numeric_range_trie.cpp | 94 +++++++++++++++++++++++++++++-- test/numeric_range_trie_test.cpp | 66 ++++++++++++++++++++-- 3 files changed, 156 insertions(+), 10 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index b3d4436c..2144c8a9 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -10,7 +10,9 @@ class NumericTrieNode { NumericTrieNode** children = nullptr; sorted_array seq_ids; - void insert(const int32_t& value, const uint32_t& seq_id, char& level); + void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level); + + void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); @@ -32,7 +34,7 @@ public: void get_all_ids(uint32_t*& ids, uint32_t& ids_length); - void search_range(const int32_t& low,const int32_t& high, + void search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length); void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index be0c7f2c..f975df7b 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -20,7 +20,7 @@ void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, const int32_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length) { - if (low >= high) { + if (low > high) { return; } @@ -48,8 +48,30 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, return; } else if (low >= 0) { // Search only in positive_trie + + 
uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + if (positive_trie != nullptr) { + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, + positive_ids, positive_ids_length); + } + + ids = positive_ids; + ids_length = positive_ids_length; } else { // Search only in negative_trie + + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + if (negative_trie != nullptr) { + // Since we store absolute values, switching low and high would produce the correct result. + auto abs_high = std::abs(high), abs_low = std::abs(low); + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, + negative_ids, negative_ids_length); + } + + ids = negative_ids; + ids_length = negative_ids_length; } } @@ -139,7 +161,7 @@ void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uin void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { char level = 0; - return insert(value, seq_id, level); + return insert_helper(value, seq_id, level); } inline int get_index(const int32_t& value, char& level) { @@ -153,7 +175,7 @@ inline int get_index(const int32_t& value, char& level) { return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; } -void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& level) { +void NumericTrieNode::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) { if (level > MAX_LEVEL) { return; } @@ -173,7 +195,7 @@ void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id, char& children[index] = new NumericTrieNode(); } - return children[index]->insert(value, seq_id, level); + return children[index]->insert_helper(value, seq_id, level); } } @@ -221,7 +243,71 @@ void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, st } void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) 
{ + if (low > high) { + return; + } + std::vector matches; + search_range_helper(low, high, matches); + for (auto const& match: matches) { + uint32_t* out = nullptr; + auto const& m_seq_ids = match->seq_ids.uncompress(); + ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + + delete [] m_seq_ids; + delete [] ids; + ids = out; + } +} + +void NumericTrieNode::search_range_helper(const int32_t& low, const int32_t& high, + std::vector& matches) { + // Segregating the nodes into matching low, in-between, and matching high. + + NumericTrieNode* root = this; + char level = 1; + auto low_index = get_index(low, level), high_index = get_index(high, level); + + // Keep updating the root while the range is contained within a single child node. + while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) { + if (root->children[low_index] == nullptr) { + return; + } + + root = root->children[low_index]; + level++; + low_index = get_index(low, level); + high_index = get_index(high, level); + } + + if (root->children == nullptr) { + return; + } else if (low_index == high_index) { // low and high are equal + if (root->children[low_index] != nullptr) { + matches.push_back(root->children[low_index]); + } + return; + } + + if (root->children[low_index] != nullptr) { + // Collect all the sub-nodes that are greater than low. + root->children[low_index]->search_greater_helper(low, level, matches); + } + + auto index = low_index + 1; + // All the nodes in-between low and high are a match by default. + while (index < std::min(high_index, (int)EXPANSE)) { + if (root->children[index] != nullptr) { + matches.push_back(root->children[index]); + } + + index++; + } + + if (index < EXPANSE && index == high_index && root->children[index] != nullptr) { + // Collect all the sub-nodes that are lesser than high. 
+ root->children[index]->search_lesser_helper(high, level, matches); + } } void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index ff297aec..875ed544 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -92,6 +92,14 @@ TEST_F(NumericRangeTrieTest, SearchRange) { ASSERT_EQ(pairs[i].second, ids[i]); } + trie->search_range(-134217728, true, 134217728, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(pairs.size(), ids_length); + for (uint32_t i = 0; i < pairs.size(); i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } + trie->search_range(-1, true, 32768, true, ids, ids_length); ids_guard.reset(ids); @@ -113,6 +121,56 @@ TEST_F(NumericRangeTrieTest, SearchRange) { trie->search_range(-1, false, 0, false, ids, ids_length); ASSERT_EQ(0, ids_length); + + trie->search_range(8192, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_range(8192, true, 0x2000000, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { + ASSERT_EQ(pairs[i].second, ids[j]); + } + + trie->search_range(16384, true, 16384, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(56, ids[0]); + + trie->search_range(16384, true, 16384, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(16384, false, 16384, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(16383, true, 16383, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(8193, true, 16383, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, 
ids_length); + + trie->search_range(-32768, true, -8192, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(pairs[i].second, ids[i]); + } } TEST_F(NumericRangeTrieTest, SearchGreater) { @@ -194,12 +252,12 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(100000, false, ids, ids_length); + trie->search_greater(1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_greater(-100000, false, ids, ids_length); + trie->search_greater(-1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(8, ids_length); @@ -285,12 +343,12 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(-100000, false, ids, ids_length); + trie->search_lesser(-1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_lesser(100000, true, ids, ids_length); + trie->search_lesser(1000000, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(8, ids_length); From 377665b6710e6f856f6af48aef8d813ea151c4d6 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 31 May 2023 10:19:19 +0530 Subject: [PATCH 72/93] Add `MultivalueData` test. 
--- test/numeric_range_trie_test.cpp | 111 ++++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 875ed544..ae59589b 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -357,7 +357,106 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { } } -TEST_F(NumericRangeTrieTest, Validation) { +TEST_F(NumericRangeTrieTest, MultivalueData) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-0x202020, 32}, + {-32768, 5}, + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {0, 43}, + {0, 49}, + {1, 8}, + {256, 91}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91}, + {0x202020, 35}, + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_lesser(0, false, ids, ids_length); + std::unique_ptr ids_guard(ids); + + std::vector expected = {5, 8, 32, 35, 43}; + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->search_lesser(-16380, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(4, ids_length); + + expected = {5, 8, 32, 35}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->search_lesser(16384, false, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(7, ids_length); + + expected = {5, 8, 32, 35, 43, 49, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->search_greater(0, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(7, ids_length); + + expected = {8, 35, 43, 49, 56, 58, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->search_greater(256, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(5, ids_length); + + expected = 
{35, 49, 56, 58, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->search_greater(-32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(9, ids_length); + + expected = {5, 8, 32, 35, 43, 49, 56, 58, 91}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->search_range(-32768, true, 0, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(6, ids_length); + + expected = {5, 8, 32, 35, 43, 49}; + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } +} + +TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); @@ -369,6 +468,16 @@ TEST_F(NumericRangeTrieTest, Validation) { ASSERT_EQ(0, ids_length); + trie->search_range(-32768, true, -1, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_range(1, true, 32768, true, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); + trie->search_greater(0, true, ids, ids_length); ids_guard.reset(ids); From 4e240cfb145a0529fb7d142b1cf9cff281111143 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 31 May 2023 12:23:45 +0530 Subject: [PATCH 73/93] Refactor `NumericTrie`. 
--- include/numeric_range_trie_test.h | 50 +++++++++++++++---------------- src/numeric_range_trie.cpp | 36 +++++++++++----------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index 2144c8a9..7511a712 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -6,45 +6,45 @@ constexpr char MAX_LEVEL = 4; constexpr short EXPANSE = 256; -class NumericTrieNode { - NumericTrieNode** children = nullptr; - sorted_array seq_ids; +class NumericTrie { + class Node { + Node** children = nullptr; + sorted_array seq_ids; - void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level); + void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level); - void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); + void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); - void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); + void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); - void search_greater_helper(const int32_t& value, char& level, std::vector& matches); + void search_greater_helper(const int32_t& value, char& level, std::vector& matches); -public: + public: - ~NumericTrieNode() { - if (children != nullptr) { - for (auto i = 0; i < EXPANSE; i++) { - delete children[i]; + ~Node() { + if (children != nullptr) { + for (auto i = 0; i < EXPANSE; i++) { + delete children[i]; + } } + + delete [] children; } - delete [] children; - } + void insert(const int32_t& value, const uint32_t& seq_id); - void insert(const int32_t& value, const uint32_t& seq_id); + void get_all_ids(uint32_t*& ids, uint32_t& ids_length); - void get_all_ids(uint32_t*& ids, uint32_t& ids_length); + void search_range(const int32_t& low, const int32_t& high, + uint32_t*& ids, uint32_t& ids_length); - void search_range(const int32_t& low, const 
int32_t& high, - uint32_t*& ids, uint32_t& ids_length); + void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); - void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + }; - void search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); -}; - -class NumericTrie { - NumericTrieNode* negative_trie = nullptr; - NumericTrieNode* positive_trie = nullptr; + Node* negative_trie = nullptr; + Node* positive_trie = nullptr; public: diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index f975df7b..ce7e6b98 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -4,13 +4,13 @@ void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { if (value < 0) { if (negative_trie == nullptr) { - negative_trie = new NumericTrieNode(); + negative_trie = new NumericTrie::Node(); } negative_trie->insert(std::abs(value), seq_id); } else { if (positive_trie == nullptr) { - positive_trie = new NumericTrieNode(); + positive_trie = new NumericTrie::Node(); } positive_trie->insert(value, seq_id); @@ -159,7 +159,7 @@ void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uin } } -void NumericTrieNode::insert(const int32_t& value, const uint32_t& seq_id) { +void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) { char level = 0; return insert_helper(value, seq_id, level); } @@ -175,7 +175,7 @@ inline int get_index(const int32_t& value, char& level) { return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; } -void NumericTrieNode::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) { +void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) { if (level > MAX_LEVEL) { return; } @@ -187,26 +187,26 @@ void NumericTrieNode::insert_helper(const int32_t& value, const uint32_t& seq_id if (++level <= 
MAX_LEVEL) { if (children == nullptr) { - children = new NumericTrieNode* [EXPANSE]{nullptr}; + children = new NumericTrie::Node* [EXPANSE]{nullptr}; } auto index = get_index(value, level); if (children[index] == nullptr) { - children[index] = new NumericTrieNode(); + children[index] = new NumericTrie::Node(); } return children[index]->insert_helper(value, seq_id, level); } } -void NumericTrieNode::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { ids = seq_ids.uncompress(); ids_length = seq_ids.getLength(); } -void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { char level = 0; - std::vector matches; + std::vector matches; search_lesser_helper(value, level, matches); for (auto const& match: matches) { @@ -220,7 +220,7 @@ void NumericTrieNode::search_lesser(const int32_t& value, uint32_t*& ids, uint32 } } -void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, std::vector& matches) { +void NumericTrie::Node::search_lesser_helper(const int32_t& value, char& level, std::vector& matches) { if (level == MAX_LEVEL) { matches.push_back(this); return; @@ -242,11 +242,11 @@ void NumericTrieNode::search_lesser_helper(const int32_t& value, char& level, st --level; } -void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { if (low > high) { return; } - std::vector matches; + std::vector matches; search_range_helper(low, high, matches); for (auto const& match: matches) { @@ -260,11 +260,11 @@ void NumericTrieNode::search_range(const int32_t& low, const int32_t& high, uint } } -void NumericTrieNode::search_range_helper(const int32_t& low, const int32_t& high, 
- std::vector& matches) { +void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& high, + std::vector& matches) { // Segregating the nodes into matching low, in-between, and matching high. - NumericTrieNode* root = this; + NumericTrie::Node* root = this; char level = 1; auto low_index = get_index(low, level), high_index = get_index(high, level); @@ -310,9 +310,9 @@ void NumericTrieNode::search_range_helper(const int32_t& low, const int32_t& hig } } -void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { char level = 0; - std::vector matches; + std::vector matches; search_greater_helper(value, level, matches); for (auto const& match: matches) { @@ -326,7 +326,7 @@ void NumericTrieNode::search_greater(const int32_t& value, uint32_t*& ids, uint3 } } -void NumericTrieNode::search_greater_helper(const int32_t& value, char& level, std::vector& matches) { +void NumericTrie::Node::search_greater_helper(const int32_t& value, char& level, std::vector& matches) { if (level == MAX_LEVEL) { matches.push_back(this); return; From 3c0f597b520bc0175d2dd9d563b934faebe73f8d Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 31 May 2023 13:57:18 +0530 Subject: [PATCH 74/93] Rename methods. 
--- include/numeric_range_trie_test.h | 16 ++++----- src/numeric_range_trie.cpp | 36 +++++++++---------- test/numeric_range_trie_test.cpp | 60 +++++++++++++++---------------- 3 files changed, 56 insertions(+), 56 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index 7511a712..cc980e3e 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -15,9 +15,9 @@ class NumericTrie { void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); - void search_lesser_helper(const int32_t& value, char& level, std::vector& matches); + void search_less_than_helper(const int32_t& value, char& level, std::vector& matches); - void search_greater_helper(const int32_t& value, char& level, std::vector& matches); + void search_greater_than_helper(const int32_t& value, char& level, std::vector& matches); public: @@ -38,9 +38,9 @@ class NumericTrie { void search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length); - void search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); - void search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); }; Node* negative_trie = nullptr; @@ -59,9 +59,9 @@ public: const int32_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_lesser(const int32_t& value, const bool& inclusive, - uint32_t*& ids, uint32_t& ids_length); + void search_less_than(const int32_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); - void search_greater(const int32_t& value, const bool& inclusive, - uint32_t*& ids, uint32_t& ids_length); + void search_greater_than(const int32_t& value, const bool& inclusive, + uint32_t*& ids, uint32_t& ids_length); }; diff --git 
a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index ce7e6b98..0c985765 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -32,13 +32,13 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, if (negative_trie != nullptr && !(low == -1 && !low_inclusive)) { // No need to search for (-1, ... auto abs_low = std::abs(low); // Since we store absolute values, search_lesser would yield result for >low from negative_trie. - negative_trie->search_lesser(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); } uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) - positive_trie->search_lesser(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); + positive_trie->search_less_than(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); } ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); @@ -75,7 +75,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, } } -void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { positive_trie->get_all_ids(ids, ids_length); @@ -87,7 +87,7 @@ void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, ui uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; if (positive_trie != nullptr) { - positive_trie->search_greater(inclusive ? 
value : value + 1, positive_ids, positive_ids_length); + positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length); } ids_length = positive_ids_length; @@ -100,7 +100,7 @@ void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, ui // Since we store absolute values, search_lesser would yield result for >value from negative_trie. if (negative_trie != nullptr) { auto abs_low = std::abs(value); - negative_trie->search_lesser(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); } uint32_t* positive_ids = nullptr; @@ -117,7 +117,7 @@ void NumericTrie::search_greater(const int32_t& value, const bool& inclusive, ui } } -void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] if (negative_trie != nullptr) { negative_trie->get_all_ids(ids, ids_length); @@ -131,7 +131,7 @@ void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uin // Since we store absolute values, search_greater would yield result for search_greater(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length); + negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length); } ids_length = negative_ids_length; @@ -142,7 +142,7 @@ void NumericTrie::search_lesser(const int32_t& value, const bool& inclusive, uin uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; if (positive_trie != nullptr) { - positive_trie->search_lesser(inclusive ? value : value - 1, positive_ids, positive_ids_length); + positive_trie->search_less_than(inclusive ? 
value : value - 1, positive_ids, positive_ids_length); } uint32_t* negative_ids = nullptr; @@ -204,10 +204,10 @@ void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { ids_length = seq_ids.getLength(); } -void NumericTrie::Node::search_lesser(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; - search_lesser_helper(value, level, matches); + search_less_than_helper(value, level, matches); for (auto const& match: matches) { uint32_t* out = nullptr; @@ -220,7 +220,7 @@ void NumericTrie::Node::search_lesser(const int32_t& value, uint32_t*& ids, uint } } -void NumericTrie::Node::search_lesser_helper(const int32_t& value, char& level, std::vector& matches) { +void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& level, std::vector& matches) { if (level == MAX_LEVEL) { matches.push_back(this); return; @@ -230,7 +230,7 @@ void NumericTrie::Node::search_lesser_helper(const int32_t& value, char& level, auto index = get_index(value, ++level); if (children[index] != nullptr) { - children[index]->search_lesser_helper(value, level, matches); + children[index]->search_less_than_helper(value, level, matches); } while (--index >= 0) { @@ -291,7 +291,7 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h if (root->children[low_index] != nullptr) { // Collect all the sub-nodes that are greater than low. - root->children[low_index]->search_greater_helper(low, level, matches); + root->children[low_index]->search_greater_than_helper(low, level, matches); } auto index = low_index + 1; @@ -306,14 +306,14 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h if (index < EXPANSE && index == high_index && root->children[index] != nullptr) { // Collect all the sub-nodes that are lesser than high. 
- root->children[index]->search_lesser_helper(high, level, matches); + root->children[index]->search_less_than_helper(high, level, matches); } } -void NumericTrie::Node::search_greater(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; - search_greater_helper(value, level, matches); + search_greater_than_helper(value, level, matches); for (auto const& match: matches) { uint32_t* out = nullptr; @@ -326,7 +326,7 @@ void NumericTrie::Node::search_greater(const int32_t& value, uint32_t*& ids, uin } } -void NumericTrie::Node::search_greater_helper(const int32_t& value, char& level, std::vector& matches) { +void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& level, std::vector& matches) { if (level == MAX_LEVEL) { matches.push_back(this); return; @@ -336,7 +336,7 @@ void NumericTrie::Node::search_greater_helper(const int32_t& value, char& level, auto index = get_index(value, ++level); if (children[index] != nullptr) { - children[index]->search_greater_helper(value, level, matches); + children[index]->search_greater_than_helper(value, level, matches); } while (++index < EXPANSE) { diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index ae59589b..28344882 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -194,7 +194,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { uint32_t* ids = nullptr; uint32_t ids_length = 0; - trie->search_greater(0, true, ids, ids_length); + trie->search_greater_than(0, true, ids, ids_length); std::unique_ptr ids_guard(ids); ASSERT_EQ(4, ids_length); @@ -202,7 +202,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(-1, false, ids, ids_length); + trie->search_greater_than(-1, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, 
ids_length); @@ -210,7 +210,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(-1, true, ids, ids_length); + trie->search_greater_than(-1, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -218,7 +218,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(-24576, true, ids, ids_length); + trie->search_greater_than(-24576, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(7, ids_length); @@ -227,7 +227,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j++]); } - trie->search_greater(-32768, false, ids, ids_length); + trie->search_greater_than(-32768, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(7, ids_length); @@ -236,7 +236,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j++]); } - trie->search_greater(8192, true, ids, ids_length); + trie->search_greater_than(8192, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -244,7 +244,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(8192, false, ids, ids_length); + trie->search_greater_than(8192, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(3, ids_length); @@ -252,12 +252,12 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_greater(1000000, false, ids, ids_length); + trie->search_greater_than(1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_greater(-1000000, false, ids, ids_length); + trie->search_greater_than(-1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(8, ids_length); @@ -287,7 +287,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { uint32_t* ids = nullptr; uint32_t ids_length = 0; - trie->search_lesser(0, true, ids, ids_length); + trie->search_less_than(0, true, ids, ids_length); 
std::unique_ptr ids_guard(ids); ASSERT_EQ(4, ids_length); @@ -295,7 +295,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[j]); } - trie->search_lesser(0, false, ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -303,7 +303,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(-1, true, ids, ids_length); + trie->search_less_than(-1, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -311,7 +311,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(-16384, true, ids, ids_length); + trie->search_less_than(-16384, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(3, ids_length); @@ -319,7 +319,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(-16384, false, ids, ids_length); + trie->search_less_than(-16384, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(2, ids_length); @@ -327,7 +327,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(8192, true, ids, ids_length); + trie->search_less_than(8192, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(5, ids_length); @@ -335,7 +335,7 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(8192, false, ids, ids_length); + trie->search_less_than(8192, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -343,12 +343,12 @@ TEST_F(NumericRangeTrieTest, SearchLesser) { ASSERT_EQ(pairs[i].second, ids[i]); } - trie->search_lesser(-1000000, false, ids, ids_length); + trie->search_less_than(-1000000, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_lesser(1000000, true, ids, ids_length); + trie->search_less_than(1000000, true, ids, ids_length); 
ids_guard.reset(ids); ASSERT_EQ(8, ids_length); @@ -385,7 +385,7 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { uint32_t* ids = nullptr; uint32_t ids_length = 0; - trie->search_lesser(0, false, ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); std::unique_ptr ids_guard(ids); std::vector expected = {5, 8, 32, 35, 43}; @@ -395,7 +395,7 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } - trie->search_lesser(-16380, false, ids, ids_length); + trie->search_less_than(-16380, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -405,7 +405,7 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } - trie->search_lesser(16384, false, ids, ids_length); + trie->search_less_than(16384, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(7, ids_length); @@ -415,7 +415,7 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } - trie->search_greater(0, true, ids, ids_length); + trie->search_greater_than(0, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(7, ids_length); @@ -425,7 +425,7 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } - trie->search_greater(256, true, ids, ids_length); + trie->search_greater_than(256, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(5, ids_length); @@ -435,7 +435,7 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } - trie->search_greater(-32768, true, ids, ids_length); + trie->search_greater_than(-32768, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(9, ids_length); @@ -478,32 +478,32 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { ASSERT_EQ(0, ids_length); - trie->search_greater(0, true, ids, ids_length); + trie->search_greater_than(0, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_greater(15, true, ids, ids_length); + trie->search_greater_than(15, true, ids, 
ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_greater(-15, true, ids, ids_length); + trie->search_greater_than(-15, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_lesser(0, false, ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_lesser(-15, true, ids, ids_length); + trie->search_less_than(-15, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); - trie->search_lesser(15, true, ids, ids_length); + trie->search_less_than(15, true, ids, ids_length); ids_guard.reset(ids); ASSERT_EQ(0, ids_length); From 8daca328313a4761e06c33952b243f3a0caea5e0 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 31 May 2023 15:19:06 +0530 Subject: [PATCH 75/93] Add `NumericTrie::search_equal_to`. --- include/numeric_range_trie_test.h | 4 +++ src/numeric_range_trie.cpp | 30 +++++++++++++++++ test/numeric_range_trie_test.cpp | 55 +++++++++++++++++++++++++++++-- 3 files changed, 87 insertions(+), 2 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index cc980e3e..737cf5f9 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -41,6 +41,8 @@ class NumericTrie { void search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); void search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + + void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); }; Node* negative_trie = nullptr; @@ -64,4 +66,6 @@ public: void search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); + + void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); }; diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 0c985765..cfa9e7e4 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ 
-159,6 +159,19 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, } } +void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { + uint32_t* match_ids = nullptr; + uint32_t match_ids_length = 0; + if (value < 0 && negative_trie != nullptr) { + negative_trie->search_equal_to(std::abs(value), match_ids, match_ids_length); + } else if (value >= 0 && positive_trie != nullptr) { + positive_trie->search_equal_to(value, match_ids, match_ids_length); + } + + ids = match_ids; + ids_length = match_ids_length; +} + void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) { char level = 0; return insert_helper(value, seq_id, level); @@ -347,3 +360,20 @@ void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& l --level; } + +void NumericTrie::Node::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { + char level = 1; + Node* root = this; + auto index = get_index(value, level); + + while (level <= MAX_LEVEL) { + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level); + } + + root->get_all_ids(ids, ids_length); +} diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 28344882..baad5ef7 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -173,7 +173,7 @@ TEST_F(NumericRangeTrieTest, SearchRange) { } } -TEST_F(NumericRangeTrieTest, SearchGreater) { +TEST_F(NumericRangeTrieTest, SearchGreaterThan) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); std::vector> pairs = { @@ -266,7 +266,7 @@ TEST_F(NumericRangeTrieTest, SearchGreater) { } } -TEST_F(NumericRangeTrieTest, SearchLesser) { +TEST_F(NumericRangeTrieTest, SearchLessThan) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); std::vector> pairs = { @@ -357,6 +357,52 @@ TEST_F(NumericRangeTrieTest, 
SearchLesser) { } } +TEST_F(NumericRangeTrieTest, SearchEqualTo) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32769, 41}, + {-32768, 43}, + {-32767, 45}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_equal_to(0, ids, ids_length); + std::unique_ptr ids_guard(ids); + + ASSERT_EQ(0, ids_length); + + trie->search_equal_to(-32768, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(43, ids[0]); + + trie->search_equal_to(24576, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(58, ids[0]); + + trie->search_equal_to(0x202020, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); +} + TEST_F(NumericRangeTrieTest, MultivalueData) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); @@ -507,4 +553,9 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + + trie->search_equal_to(15, ids, ids_length); + ids_guard.reset(ids); + + ASSERT_EQ(0, ids_length); } From aeb473deb2c0aa995624fc9050413e8e25ee6ac5 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 1 Jun 2023 11:03:21 +0530 Subject: [PATCH 76/93] Refactor `NumericTrie::search_greater_than`. 
--- src/numeric_range_trie.cpp | 58 +++++++++++++++++++++++--------- test/numeric_range_trie_test.cpp | 25 +++++++++----- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index cfa9e7e4..3c4e7cd1 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -78,42 +78,68 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { - positive_trie->get_all_ids(ids, ids_length); + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->get_all_ids(positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; } return; } if (value >= 0) { - uint32_t* positive_ids = nullptr; - uint32_t positive_ids_length = 0; - if (positive_trie != nullptr) { - positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length); + if (positive_trie == nullptr) { + return; } - ids_length = positive_ids_length; - ids = positive_ids; + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; + positive_trie->search_greater_than(inclusive ? 
value : value + 1, positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; } else { // Have to combine the results of >value from negative_trie and all the ids in positive_trie - uint32_t* negative_ids = nullptr; - uint32_t negative_ids_length = 0; - // Since we store absolute values, search_lesser would yield result for >value from negative_trie. if (negative_trie != nullptr) { + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; auto abs_low = std::abs(value); + + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; + } + + if (positive_trie == nullptr) { + return; } uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - if (positive_trie != nullptr) { - positive_trie->get_all_ids(positive_ids, positive_ids_length); - } + positive_trie->get_all_ids(positive_ids, positive_ids_length); - ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); - delete [] negative_ids; delete [] positive_ids; - return; + delete [] ids; + ids = out; } } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index baad5ef7..3f591dfc 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -9,6 +9,12 @@ protected: virtual void TearDown() {} }; +void reset(uint32_t*& ids, uint32_t& ids_length) { + delete [] ids; + ids = nullptr; + 
ids_length = 0; +} + TEST_F(NumericRangeTrieTest, SearchRange) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); @@ -195,31 +201,30 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) { uint32_t ids_length = 0; trie->search_greater_than(0, true, ids, ids_length); - std::unique_ptr ids_guard(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(-1, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(-1, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(-24576, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(7, ids_length); for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { @@ -227,8 +232,8 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) { ASSERT_EQ(pairs[i].second, ids[j++]); } + reset(ids, ids_length); trie->search_greater_than(-32768, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(7, ids_length); for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { @@ -236,34 +241,36 @@ TEST_F(NumericRangeTrieTest, SearchGreaterThan) { ASSERT_EQ(pairs[i].second, ids[j++]); } + reset(ids, ids_length); trie->search_greater_than(8192, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_greater_than(8192, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(3, ids_length); for (uint32_t i = 5, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); 
trie->search_greater_than(1000000, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_greater_than(-1000000, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(8, ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + + reset(ids, ids_length); } TEST_F(NumericRangeTrieTest, SearchLessThan) { From 25e36eea0fcd3419e697553a592a84ecc5b0b4fa Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 1 Jun 2023 11:21:31 +0530 Subject: [PATCH 77/93] Refactor `NumericTrie::search_equal_to`. --- src/numeric_range_trie.cpp | 25 +++++++++++++++++-------- test/numeric_range_trie_test.cpp | 7 +++---- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 3c4e7cd1..5ce81287 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -186,16 +186,25 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, } void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { - uint32_t* match_ids = nullptr; - uint32_t match_ids_length = 0; - if (value < 0 && negative_trie != nullptr) { - negative_trie->search_equal_to(std::abs(value), match_ids, match_ids_length); - } else if (value >= 0 && positive_trie != nullptr) { - positive_trie->search_equal_to(value, match_ids, match_ids_length); + if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { + return; } - ids = match_ids; - ids_length = match_ids_length; + uint32_t* equal_ids = nullptr; + uint32_t equal_ids_length = 0; + + if (value < 0) { + negative_trie->search_equal_to(std::abs(value), equal_ids, equal_ids_length); + } else { + positive_trie->search_equal_to(value, equal_ids, equal_ids_length); + } + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(equal_ids, equal_ids_length, ids, ids_length, &out); + + delete [] equal_ids; + 
delete [] ids; + ids = out; } void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) { diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 3f591dfc..6a45054a 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -388,24 +388,23 @@ TEST_F(NumericRangeTrieTest, SearchEqualTo) { uint32_t ids_length = 0; trie->search_equal_to(0, ids, ids_length); - std::unique_ptr ids_guard(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_equal_to(-32768, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(1, ids_length); ASSERT_EQ(43, ids[0]); + reset(ids, ids_length); trie->search_equal_to(24576, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(1, ids_length); ASSERT_EQ(58, ids[0]); + reset(ids, ids_length); trie->search_equal_to(0x202020, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); } From f2e27500c4f6d312ca0156b9b44d5b8fb4f64942 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 1 Jun 2023 11:42:26 +0530 Subject: [PATCH 78/93] Refactor `NumericTrie::search_less_than`. 
--- src/numeric_range_trie.cpp | 60 +++++++++++++++++++++++--------- test/numeric_range_trie_test.cpp | 19 +++++----- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 5ce81287..5b397716 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -146,42 +146,68 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] if (negative_trie != nullptr) { - negative_trie->get_all_ids(ids, ids_length); + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + negative_trie->get_all_ids(negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; } return; } if (value < 0) { - uint32_t* negative_ids = nullptr; - uint32_t negative_ids_length = 0; - // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length); + if (negative_trie == nullptr) { + return; } - ids_length = negative_ids_length; - ids = negative_ids; + uint32_t* negative_ids = nullptr; + uint32_t negative_ids_length = 0; + auto abs_low = std::abs(value); + + // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; } else { // Have to combine the results of search_less_than(inclusive ? 
value : value - 1, positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; + } + + if (negative_trie == nullptr) { + return; } uint32_t* negative_ids = nullptr; uint32_t negative_ids_length = 0; - if (negative_trie != nullptr) { - negative_trie->get_all_ids(negative_ids, negative_ids_length); - } + negative_trie->get_all_ids(negative_ids, negative_ids_length); - ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, negative_ids, negative_ids_length, &ids); + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); delete [] negative_ids; - delete [] positive_ids; - return; + delete [] ids; + ids = out; } } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 6a45054a..a4e5cb8a 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -295,73 +295,74 @@ TEST_F(NumericRangeTrieTest, SearchLessThan) { uint32_t ids_length = 0; trie->search_less_than(0, true, ids, ids_length); - std::unique_ptr ids_guard(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_less_than(0, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_less_than(-1, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_less_than(-16384, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(3, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, 
ids_length); trie->search_less_than(-16384, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(2, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_less_than(8192, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(5, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_less_than(8192, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_less_than(-1000000, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_less_than(1000000, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(8, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + + reset(ids, ids_length); } TEST_F(NumericRangeTrieTest, SearchEqualTo) { From 5e36ea85f0fe13a452942971e203c727a524071e Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 1 Jun 2023 12:52:09 +0530 Subject: [PATCH 79/93] Refactor `NumericTrie::search_range`. --- src/numeric_range_trie.cpp | 67 ++++++++++++++++++++------------ test/numeric_range_trie_test.cpp | 56 +++++++++++++------------- 2 files changed, 73 insertions(+), 50 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 5b397716..d5f93e36 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -27,51 +27,70 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, if (low < 0 && high >= 0) { // Have to combine the results of >low from negative_trie and low from negative_trie. negative_trie->search_less_than(low_inclusive ? 
abs_low : abs_low - 1, negative_ids, negative_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; } - uint32_t* positive_ids = nullptr; - uint32_t positive_ids_length = 0; if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) + uint32_t* positive_ids = nullptr; + uint32_t positive_ids_length = 0; positive_trie->search_less_than(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; } - - ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, positive_ids, positive_ids_length, &ids); - - delete [] negative_ids; - delete [] positive_ids; - return; } else if (low >= 0) { // Search only in positive_trie + if (positive_trie == nullptr) { + return; + } uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - if (positive_trie != nullptr) { - positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, - positive_ids, positive_ids_length); - } + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, + positive_ids, positive_ids_length); - ids = positive_ids; - ids_length = positive_ids_length; + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); + + delete [] positive_ids; + delete [] ids; + ids = out; } else { // Search only in negative_trie + if (negative_trie == nullptr) { + return; + } uint32_t* negative_ids = nullptr; uint32_t negative_ids_length = 0; - if (negative_trie != nullptr) { - // Since we store absolute values, switching low and high would produce the correct result. 
- auto abs_high = std::abs(high), abs_low = std::abs(low); - negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, - negative_ids, negative_ids_length); - } + // Since we store absolute values, switching low and high would produce the correct result. + auto abs_high = std::abs(high), abs_low = std::abs(low); + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? abs_low : abs_low - 1, + negative_ids, negative_ids_length); - ids = negative_ids; - ids_length = negative_ids_length; + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); + + delete [] negative_ids; + delete [] ids; + ids = out; } } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index a4e5cb8a..b8ba9186 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -37,52 +37,51 @@ TEST_F(NumericRangeTrieTest, SearchRange) { uint32_t ids_length = 0; trie->search_range(32768, true, -32768, true, ids, ids_length); - std::unique_ptr ids_guard(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(-32768, true, 32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(pairs.size(), ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-32768, true, 32768, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(pairs.size() - 1, ids_length); for (uint32_t i = 0; i < pairs.size() - 1; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-32768, true, 134217728, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(pairs.size(), ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-32768, true, 0, true, ids, ids_length); - ids_guard.reset(ids); 
ASSERT_EQ(4, ids_length); for (uint32_t i = 0; i < 4; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-32768, true, 0, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 0; i < 4; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-32768, false, 32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(pairs.size() - 1, ids_length); for (uint32_t i = 0, j = 0; i < pairs.size(); i++) { @@ -90,93 +89,97 @@ TEST_F(NumericRangeTrieTest, SearchRange) { ASSERT_EQ(pairs[i].second, ids[j++]); } + reset(ids, ids_length); trie->search_range(-134217728, true, 32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(pairs.size(), ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-134217728, true, 134217728, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(pairs.size(), ids_length); for (uint32_t i = 0; i < pairs.size(); i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + reset(ids, ids_length); trie->search_range(-1, true, 32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_range(-1, false, 32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < pairs.size(); i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_range(-1, true, 0, true, ids, ids_length); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(-1, false, 0, false, ids, ids_length); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(8192, true, 32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { 
ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_range(8192, true, 0x2000000, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 4, j = 0; i < ids_length; i++, j++) { ASSERT_EQ(pairs[i].second, ids[j]); } + reset(ids, ids_length); trie->search_range(16384, true, 16384, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(1, ids_length); ASSERT_EQ(56, ids[0]); + reset(ids, ids_length); trie->search_range(16384, true, 16384, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(16384, false, 16384, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(16383, true, 16383, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(8193, true, 16383, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + reset(ids, ids_length); trie->search_range(-32768, true, -8192, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(pairs[i].second, ids[i]); } + + reset(ids, ids_length); } TEST_F(NumericRangeTrieTest, SearchGreaterThan) { @@ -439,7 +442,6 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { uint32_t ids_length = 0; trie->search_less_than(0, false, ids, ids_length); - std::unique_ptr ids_guard(ids); std::vector expected = {5, 8, 32, 35, 43}; @@ -448,8 +450,8 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } + reset(ids, ids_length); trie->search_less_than(-16380, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(4, ids_length); @@ -458,8 +460,8 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } + reset(ids, ids_length); trie->search_less_than(16384, false, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(7, ids_length); @@ -468,8 +470,8 @@ 
TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } + reset(ids, ids_length); trie->search_greater_than(0, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(7, ids_length); @@ -478,8 +480,8 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } + reset(ids, ids_length); trie->search_greater_than(256, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(5, ids_length); @@ -488,8 +490,8 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } + reset(ids, ids_length); trie->search_greater_than(-32768, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(9, ids_length); @@ -498,8 +500,8 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { ASSERT_EQ(expected[i], ids[i]); } + reset(ids, ids_length); trie->search_range(-32768, true, 0, true, ids, ids_length); - ids_guard.reset(ids); ASSERT_EQ(6, ids_length); @@ -507,6 +509,8 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { for (uint32_t i = 0; i < ids_length; i++) { ASSERT_EQ(expected[i], ids[i]); } + + reset(ids, ids_length); } TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { From bc8a5fc96dad779f987b233a9e29bc81d48988f5 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 1 Jun 2023 16:46:04 +0530 Subject: [PATCH 80/93] Add `range_index` property. 
--- include/field.h | 9 ++- include/index.h | 3 + src/field.cpp | 19 ++++++- src/filter_result_iterator.cpp | 69 +++++++++++++++++------ src/index.cpp | 29 +++++++++- test/numeric_range_trie_test.cpp | 94 +++++++++++++++++++++++++++++++- 6 files changed, 200 insertions(+), 23 deletions(-) diff --git a/include/field.h b/include/field.h index cbd62af2..a0eca2af 100644 --- a/include/field.h +++ b/include/field.h @@ -54,6 +54,7 @@ namespace fields { static const std::string from = "from"; static const std::string embed_from = "embed_from"; static const std::string model_name = "model_name"; + static const std::string range_index = "range_index"; // Some models require additional parameters to be passed to the model during indexing/querying // For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying @@ -93,13 +94,17 @@ struct field { std::string reference; // Foo.bar (reference to bar field in Foo collection). + bool range_index; + field() {} field(const std::string &name, const std::string &type, const bool facet, const bool optional = false, bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false, - int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const nlohmann::json& embed = nlohmann::json()) : + int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, + std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) : name(name), type(type), facet(facet), optional(optional), index(index), locale(locale), - nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed(embed) { + nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), + embed(embed), range_index(range_index) { set_computed_defaults(sort, infix); } diff --git a/include/index.h b/include/index.h index 
53c4c561..3554d2c5 100644 --- a/include/index.h +++ b/include/index.h @@ -30,6 +30,7 @@ #include "vector_query_ops.h" #include "hnswlib/hnswlib.h" #include "filter.h" +#include "numeric_range_trie_test.h" static constexpr size_t ARRAY_FACET_DIM = 4; using facet_map_t = spp::sparse_hash_map; @@ -302,6 +303,8 @@ private: spp::sparse_hash_map numerical_index; + spp::sparse_hash_map range_index; + spp::sparse_hash_map>*> geopoint_index; // geo_array_field => (seq_id => values) used for exact filtering of geo array records diff --git a/src/field.cpp b/src/field.cpp index f48ecf50..3882b45f 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -75,6 +75,23 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::reference] = ""; } + if (field_json.count(fields::range_index) != 0) { + if (!field_json.at(fields::range_index).is_boolean()) { + return Option(400, std::string("The `range_index` property of the field `") + + field_json[fields::name].get() + + std::string("` should be a boolean.")); + } + + auto const& type = field_json["type"]; + if (type != field_types::INT32 && type != field_types::INT32_ARRAY && + type != field_types::INT64 && type != field_types::INT64_ARRAY && + type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) { + return Option(400, std::string("The `range_index` property is only allowed for the numerical fields`")); + } + } else { + field_json[fields::range_index] = false; + } + if(field_json["name"] == ".*") { if(field_json.count(fields::facet) == 0) { field_json[fields::facet] = false; @@ -297,7 +314,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso field_json[fields::optional], field_json[fields::index], field_json[fields::locale], field_json[fields::sort], field_json[fields::infix], field_json[fields::nested], field_json[fields::nested_array], field_json[fields::num_dim], vec_dist, - field_json[fields::reference], field_json[fields::embed]) + 
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index]) ); if (!field_json[fields::reference].get().empty()) { diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index a6b61134..8a3c6d89 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -646,27 +646,62 @@ void filter_result_iterator_t::init() { field f = index->search_schema.at(a_filter.field_name); if (f.is_integer()) { - auto num_tree = index->numerical_index.at(a_filter.field_name); + if (f.is_int32() && f.range_index) { + auto const& trie = index->range_index.at(a_filter.field_name); - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - int64_t value = (int64_t)std::stol(filter_value); + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + auto const& value = (int32_t)std::stoi(filter_value); - size_t result_size = filter_result.count; - if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi + 1]; - auto const range_end_value = (int64_t)std::stol(next_filter_value); - num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); - } else { - num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + auto const& range_end_value = (int32_t)std::stoi(next_filter_value); + trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count); + fi++; + } else if 
(a_filter.comparators[fi] == EQUALS) { + trie->search_equal_to(value, filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + uint32_t* to_exclude_ids = nullptr; + uint32_t to_exclude_ids_len = 0; + trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len); + + auto all_ids = index->seq_ids->uncompress(); + filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(), + to_exclude_ids, to_exclude_ids_len, &filter_result.docs); + + delete[] all_ids; + delete[] to_exclude_ids; + } else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) { + trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS, + filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) { + trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS, + filter_result.docs, filter_result.count); + } } + } else { + auto num_tree = index->numerical_index.at(a_filter.field_name); - filter_result.count = result_size; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + int64_t value = (int64_t)std::stol(filter_value); + + size_t result_size = filter_result.count; + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + auto const range_end_value = (int64_t)std::stol(next_filter_value); + num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size); + fi++; + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, value, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size); + } + + filter_result.count = 
result_size; + } } if (a_filter.apply_not_equals) { diff --git a/src/index.cpp b/src/index.cpp index e8634a3b..f146564a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -88,6 +88,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* } else { num_tree_t* num_tree = new num_tree_t; numerical_index.emplace(a_field.name, num_tree); + + if (a_field.range_index) { + auto trie = new NumericTrie(); + range_index.emplace(a_field.name, trie); + } } if(a_field.sort) { @@ -161,6 +166,13 @@ Index::~Index() { numerical_index.clear(); + for(auto & name_tree: range_index) { + delete name_tree.second; + name_tree.second = nullptr; + } + + range_index.clear(); + for(auto & name_map: sort_index) { delete name_map.second; name_map.second = nullptr; @@ -738,6 +750,15 @@ void Index::index_field_in_memory(const field& afield, std::vector if(!afield.is_string()) { if (afield.type == field_types::INT32) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + int32_t value = record.doc[afield.name].get(); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -900,13 +921,19 @@ void Index::index_field_in_memory(const field& afield, std::vector // all other numerical arrays auto num_tree = numerical_index.at(afield.name); - iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] + auto trie = range_index.count(afield.name) > 0 ? 
range_index.at(afield.name) : nullptr; + iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie] (const index_record& record, uint32_t seq_id) { for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) { const auto& arr_value = record.doc[afield.name][arr_i]; if(afield.type == field_types::INT32_ARRAY) { const int32_t value = arr_value; + + if (afield.range_index) { + trie->insert(value, seq_id); + } + num_tree->insert(value, seq_id); } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index b8ba9186..5d9cca7d 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -1,12 +1,35 @@ #include +#include +#include "collection.h" #include "numeric_range_trie_test.h" class NumericRangeTrieTest : public ::testing::Test { protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; - virtual void SetUp() {} + std::vector query_fields; + std::vector sort_fields; - virtual void TearDown() {} + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_filtering"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } }; void reset(uint32_t*& ids, uint32_t& ids_length) { @@ -570,3 +593,70 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { ASSERT_EQ(0, ids_length); } + +TEST_F(NumericRangeTrieTest, Integration) { + Collection *coll_array_fields; + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::vector fields = { + field("name", field_types::STRING, false), + field("rating", 
field_types::FLOAT, false), + field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(), + true), // Setting range index true. + field("years", field_types::INT32_ARRAY, false), + field("timestamps", field_types::INT64_ARRAY, false), + field("tags", field_types::STRING_ARRAY, true) + }; + + std::vector sort_fields = { sort_by("age", "DESC") }; + + coll_array_fields = collectionManager.get_collection("coll_array_fields").get(); + if(coll_array_fields == nullptr) { + // ensure that default_sorting_field is a non-array numerical field + auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years"); + ASSERT_EQ(false, coll_op.ok()); + ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str()); + + // let's try again properly + coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age"); + coll_array_fields = coll_op.get(); + } + + std::string json_line; + + while (std::getline(infile, json_line)) { + auto add_op = coll_array_fields->add(json_line); + LOG(INFO) << add_op.error(); + ASSERT_TRUE(add_op.ok()); + } + + infile.close(); + + // Plain search with no filters - results should be sorted by rank fields + query_fields = {"name"}; + std::vector facets; + nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(5, results["hits"].size()); + + std::vector ids = {"3", "1", "4", "0", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // Searching on an int32 field + results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); + + ids = 
{"3", "1", "4"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } +} From 93fd3373174a0dd0fcdb2b2b690eabb14652b83f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Fri, 2 Jun 2023 11:13:58 +0530 Subject: [PATCH 81/93] Optimize methods. --- src/numeric_range_trie.cpp | 55 +++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index d5f93e36..6ee805d0 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -1,3 +1,4 @@ +#include #include "numeric_range_trie_test.h" #include "array_utils.h" @@ -302,15 +303,25 @@ void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, u std::vector matches; search_less_than_helper(value, level, matches); + std::vector consolidated_ids; for (auto const& match: matches) { - uint32_t* out = nullptr; auto const& m_seq_ids = match->seq_ids.uncompress(); - ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + consolidated_ids.push_back(m_seq_ids[i]); + } delete [] m_seq_ids; - delete [] ids; - ids = out; } + + gfx::timsort(consolidated_ids.begin(), consolidated_ids.end()); + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(), + ids, ids_length, &out); + + delete [] ids; + ids = out; } void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& level, std::vector& matches) { @@ -342,15 +353,25 @@ void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, ui std::vector matches; search_range_helper(low, 
high, matches); + std::vector consolidated_ids; for (auto const& match: matches) { - uint32_t* out = nullptr; auto const& m_seq_ids = match->seq_ids.uncompress(); - ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + consolidated_ids.push_back(m_seq_ids[i]); + } delete [] m_seq_ids; - delete [] ids; - ids = out; } + + gfx::timsort(consolidated_ids.begin(), consolidated_ids.end()); + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(), + ids, ids_length, &out); + + delete [] ids; + ids = out; } void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& high, @@ -408,15 +429,25 @@ void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids std::vector matches; search_greater_than_helper(value, level, matches); + std::vector consolidated_ids; for (auto const& match: matches) { - uint32_t* out = nullptr; auto const& m_seq_ids = match->seq_ids.uncompress(); - ids_length = ArrayUtils::or_scalar(m_seq_ids, match->seq_ids.getLength(), ids, ids_length, &out); + for (uint32_t i = 0; i < match->seq_ids.getLength(); i++) { + consolidated_ids.push_back(m_seq_ids[i]); + } delete [] m_seq_ids; - delete [] ids; - ids = out; } + + gfx::timsort(consolidated_ids.begin(), consolidated_ids.end()); + consolidated_ids.erase(unique(consolidated_ids.begin(), consolidated_ids.end()), consolidated_ids.end()); + + uint32_t* out = nullptr; + ids_length = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(), + ids, ids_length, &out); + + delete [] ids; + ids = out; } void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& level, std::vector& matches) { From 03edc150270091735f89fe80d1e967846db7b817 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar 
Date: Fri, 2 Jun 2023 19:14:11 +0530 Subject: [PATCH 82/93] Add support for `int64` and `float` fields in `NumericTrie`. --- include/numeric_range_trie_test.h | 40 ++++++---- src/filter_result_iterator.cpp | 74 +++++++++++++----- src/index.cpp | 32 +++++++- src/numeric_range_trie.cpp | 120 ++++++++++++++++-------------- test/numeric_range_trie_test.cpp | 37 ++++----- 5 files changed, 194 insertions(+), 109 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index 737cf5f9..f5d6add2 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -3,21 +3,25 @@ #include #include "sorted_array.h" -constexpr char MAX_LEVEL = 4; constexpr short EXPANSE = 256; class NumericTrie { + char max_level = 4; + class Node { Node** children = nullptr; sorted_array seq_ids; - void insert_helper(const int32_t& value, const uint32_t& seq_id, char& level); + void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level); - void search_range_helper(const int32_t& low,const int32_t& high, std::vector& matches); + void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, + std::vector& matches); - void search_less_than_helper(const int32_t& value, char& level, std::vector& matches); + void search_less_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches); - void search_greater_than_helper(const int32_t& value, char& level, std::vector& matches); + void search_greater_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches); public: @@ -31,18 +35,18 @@ class NumericTrie { delete [] children; } - void insert(const int32_t& value, const uint32_t& seq_id); + void insert(const int64_t& value, const uint32_t& seq_id, const char& max_level); void get_all_ids(uint32_t*& ids, uint32_t& ids_length); - void search_range(const int32_t& low, const int32_t& high, + void search_range(const 
int64_t& low, const int64_t& high, const char& max_level, uint32_t*& ids, uint32_t& ids_length); - void search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); - void search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); - void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); }; Node* negative_trie = nullptr; @@ -50,22 +54,26 @@ class NumericTrie { public: + explicit NumericTrie(char num_bits = 32) { + max_level = num_bits / 8; + } + ~NumericTrie() { delete negative_trie; delete positive_trie; } - void insert(const int32_t& value, const uint32_t& seq_id); + void insert(const int64_t& value, const uint32_t& seq_id); - void search_range(const int32_t& low, const bool& low_inclusive, - const int32_t& high, const bool& high_inclusive, + void search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_less_than(const int32_t& value, const bool& inclusive, + void search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_greater_than(const int32_t& value, const bool& inclusive, + void search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); - void search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length); + void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length); }; diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 8a3c6d89..0e0a8b9a 100644 --- a/src/filter_result_iterator.cpp +++ 
b/src/filter_result_iterator.cpp @@ -646,7 +646,7 @@ void filter_result_iterator_t::init() { field f = index->search_schema.at(a_filter.field_name); if (f.is_integer()) { - if (f.is_int32() && f.range_index) { + if (f.range_index) { auto const& trie = index->range_index.at(a_filter.field_name); for (size_t fi = 0; fi < a_filter.values.size(); fi++) { @@ -718,28 +718,64 @@ void filter_result_iterator_t::init() { is_filter_result_initialized = true; return; } else if (f.is_float()) { - auto num_tree = index->numerical_index.at(a_filter.field_name); + if (f.range_index) { + auto const& trie = index->range_index.at(a_filter.field_name); - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - float value = (float)std::atof(filter_value.c_str()); - int64_t float_int64 = Index::float_to_int64_t(value); + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); - size_t result_size = filter_result.count; - if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi+1]; - int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); - num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, float_int64, - index->seq_ids->uncompress(), index->seq_ids->num_ids(), - filter_result.docs, result_size); - } else { - num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi + 1]; + int64_t range_end_value = 
Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + trie->search_range(float_int64, true, range_end_value, true, filter_result.docs, filter_result.count); + fi++; + } else if (a_filter.comparators[fi] == EQUALS) { + trie->search_equal_to(float_int64, filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == NOT_EQUALS) { + uint32_t* to_exclude_ids = nullptr; + uint32_t to_exclude_ids_len = 0; + trie->search_equal_to(float_int64, to_exclude_ids, to_exclude_ids_len); + + auto all_ids = index->seq_ids->uncompress(); + filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(), + to_exclude_ids, to_exclude_ids_len, &filter_result.docs); + + delete[] all_ids; + delete[] to_exclude_ids; + } else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) { + trie->search_greater_than(float_int64, a_filter.comparators[fi] == GREATER_THAN_EQUALS, + filter_result.docs, filter_result.count); + } else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) { + trie->search_less_than(float_int64, a_filter.comparators[fi] == LESS_THAN_EQUALS, + filter_result.docs, filter_result.count); + } } + } else { + auto num_tree = index->numerical_index.at(a_filter.field_name); - filter_result.count = result_size; + for (size_t fi = 0; fi < a_filter.values.size(); fi++) { + const std::string& filter_value = a_filter.values[fi]; + float value = (float)std::atof(filter_value.c_str()); + int64_t float_int64 = Index::float_to_int64_t(value); + + size_t result_size = filter_result.count; + if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { + const std::string& next_filter_value = a_filter.values[fi+1]; + int64_t range_end_value = Index::float_to_int64_t((float) std::atof(next_filter_value.c_str())); + num_tree->range_inclusive_search(float_int64, range_end_value, &filter_result.docs, result_size); + fi++; + } else if 
(a_filter.comparators[fi] == NOT_EQUALS) { + numeric_not_equals_filter(num_tree, float_int64, + index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, result_size); + } else { + num_tree->search(a_filter.comparators[fi], float_int64, &filter_result.docs, result_size); + } + + filter_result.count = result_size; + } } if (a_filter.apply_not_equals) { diff --git a/src/index.cpp b/src/index.cpp index f146564a..84d1803a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -90,7 +90,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* numerical_index.emplace(a_field.name, num_tree); if (a_field.range_index) { - auto trie = new NumericTrie(); + auto trie = a_field.is_int32() ? new NumericTrie() : new NumericTrie(64); range_index.emplace(a_field.name, trie); } } @@ -768,6 +768,15 @@ void Index::index_field_in_memory(const field& afield, std::vector } else if(afield.type == field_types::INT64) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + int64_t value = record.doc[afield.name].get(); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree] (const index_record& record, uint32_t seq_id) { @@ -777,6 +786,16 @@ void Index::index_field_in_memory(const field& afield, std::vector } else if(afield.type == field_types::FLOAT) { + if (afield.range_index) { + auto const& trie = range_index.at(afield.name); + iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie] + (const index_record& record, uint32_t seq_id) { + float fvalue = record.doc[afield.name].get(); + int64_t value = float_to_int64_t(fvalue); + trie->insert(value, seq_id); + }); + } + auto num_tree = numerical_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, [&afield, 
num_tree] (const index_record& record, uint32_t seq_id) { @@ -929,23 +948,30 @@ void Index::index_field_in_memory(const field& afield, std::vector if(afield.type == field_types::INT32_ARRAY) { const int32_t value = arr_value; + num_tree->insert(value, seq_id); if (afield.range_index) { trie->insert(value, seq_id); } - - num_tree->insert(value, seq_id); } else if(afield.type == field_types::INT64_ARRAY) { const int64_t value = arr_value; num_tree->insert(value, seq_id); + + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::FLOAT_ARRAY) { const float fvalue = arr_value; int64_t value = float_to_int64_t(fvalue); num_tree->insert(value, seq_id); + + if (afield.range_index) { + trie->insert(value, seq_id); + } } else if(afield.type == field_types::BOOL_ARRAY) { diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 6ee805d0..3076f873 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -2,24 +2,24 @@ #include "numeric_range_trie_test.h" #include "array_utils.h" -void NumericTrie::insert(const int32_t& value, const uint32_t& seq_id) { +void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) { if (value < 0) { if (negative_trie == nullptr) { negative_trie = new NumericTrie::Node(); } - negative_trie->insert(std::abs(value), seq_id); + negative_trie->insert(std::abs(value), seq_id, max_level); } else { if (positive_trie == nullptr) { positive_trie = new NumericTrie::Node(); } - positive_trie->insert(value, seq_id); + positive_trie->insert(value, seq_id, max_level); } } -void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, - const int32_t& high, const bool& high_inclusive, +void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length) { if (low > high) { return; @@ -34,7 +34,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& 
low_inclusive, auto abs_low = std::abs(low); // Since we store absolute values, search_lesser would yield result for >low from negative_trie. - negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, + negative_ids, negative_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); @@ -47,7 +48,8 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_less_than(high_inclusive ? high : high - 1, positive_ids, positive_ids_length); + positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, + positive_ids, positive_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); @@ -64,7 +66,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, positive_ids, positive_ids_length); uint32_t* out = nullptr; @@ -84,6 +86,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, // Since we store absolute values, switching low and high would produce the correct result. auto abs_high = std::abs(high), abs_low = std::abs(low); negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? 
abs_low : abs_low - 1, + max_level, negative_ids, negative_ids_length); uint32_t* out = nullptr; @@ -95,7 +98,7 @@ void NumericTrie::search_range(const int32_t& low, const bool& low_inclusive, } } -void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { uint32_t* positive_ids = nullptr; @@ -119,7 +122,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_greater_than(inclusive ? value : value + 1, positive_ids, positive_ids_length); + positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, positive_ids, positive_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); @@ -136,7 +139,8 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv auto abs_low = std::abs(value); // Since we store absolute values, search_lesser would yield result for >value from negative_trie. - negative_trie->search_less_than(inclusive ? abs_low : abs_low - 1, negative_ids, negative_ids_length); + negative_trie->search_less_than(inclusive ? 
abs_low : abs_low - 1, max_level, + negative_ids, negative_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); @@ -163,7 +167,7 @@ void NumericTrie::search_greater_than(const int32_t& value, const bool& inclusiv } } -void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] if (negative_trie != nullptr) { uint32_t* negative_ids = nullptr; @@ -190,7 +194,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, auto abs_low = std::abs(value); // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, negative_ids, negative_ids_length); + negative_trie->search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, + negative_ids, negative_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(negative_ids, negative_ids_length, ids, ids_length, &out); @@ -204,7 +209,8 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, if (positive_trie != nullptr) { uint32_t* positive_ids = nullptr; uint32_t positive_ids_length = 0; - positive_trie->search_less_than(inclusive ? value : value - 1, positive_ids, positive_ids_length); + positive_trie->search_less_than(inclusive ? 
value : value - 1, max_level, + positive_ids, positive_ids_length); uint32_t* out = nullptr; ids_length = ArrayUtils::or_scalar(positive_ids, positive_ids_length, ids, ids_length, &out); @@ -231,7 +237,7 @@ void NumericTrie::search_less_than(const int32_t& value, const bool& inclusive, } } -void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) { if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { return; } @@ -240,9 +246,9 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t uint32_t equal_ids_length = 0; if (value < 0) { - negative_trie->search_equal_to(std::abs(value), equal_ids, equal_ids_length); + negative_trie->search_equal_to(std::abs(value), max_level, equal_ids, equal_ids_length); } else { - positive_trie->search_equal_to(value, equal_ids, equal_ids_length); + positive_trie->search_equal_to(value, max_level, equal_ids, equal_ids_length); } uint32_t* out = nullptr; @@ -253,12 +259,12 @@ void NumericTrie::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t ids = out; } -void NumericTrie::Node::insert(const int32_t& value, const uint32_t& seq_id) { +void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) { char level = 0; - return insert_helper(value, seq_id, level); + return insert_helper(value, seq_id, level, max_level); } -inline int get_index(const int32_t& value, char& level) { +inline int get_index(const int64_t& value, const char& level, const char& max_level) { // Values are index considering higher order of the bytes first. 
// 0x01020408 (16909320) would be indexed in the trie as follows: // Level Index @@ -266,11 +272,11 @@ inline int get_index(const int32_t& value, char& level) { // 2 2 // 3 4 // 4 8 - return (value >> (8 * (MAX_LEVEL - level))) & 0xFF; + return (value >> (8 * (max_level - level))) & 0xFF; } -void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_id, char& level) { - if (level > MAX_LEVEL) { +void NumericTrie::Node::insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) { + if (level > max_level) { return; } @@ -279,17 +285,17 @@ void NumericTrie::Node::insert_helper(const int32_t& value, const uint32_t& seq_ seq_ids.append(seq_id); } - if (++level <= MAX_LEVEL) { + if (++level <= max_level) { if (children == nullptr) { children = new NumericTrie::Node* [EXPANSE]{nullptr}; } - auto index = get_index(value, level); + auto index = get_index(value, level, max_level); if (children[index] == nullptr) { children[index] = new NumericTrie::Node(); } - return children[index]->insert_helper(value, seq_id, level); + return children[index]->insert_helper(value, seq_id, level, max_level); } } @@ -298,10 +304,11 @@ void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { ids_length = seq_ids.getLength(); } -void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; - search_less_than_helper(value, level, matches); + search_less_than_helper(value, level, max_level, matches); std::vector consolidated_ids; for (auto const& match: matches) { @@ -324,17 +331,18 @@ void NumericTrie::Node::search_less_than(const int32_t& value, uint32_t*& ids, u ids = out; } -void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& level, std::vector& matches) { - if (level == MAX_LEVEL) { +void 
NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches) { + if (level == max_level) { matches.push_back(this); return; - } else if (level > MAX_LEVEL || children == nullptr) { + } else if (level > max_level || children == nullptr) { return; } - auto index = get_index(value, ++level); + auto index = get_index(value, ++level, max_level); if (children[index] != nullptr) { - children[index]->search_less_than_helper(value, level, matches); + children[index]->search_less_than_helper(value, level, max_level, matches); } while (--index >= 0) { @@ -346,12 +354,13 @@ void NumericTrie::Node::search_less_than_helper(const int32_t& value, char& leve --level; } -void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { if (low > high) { return; } std::vector matches; - search_range_helper(low, high, matches); + search_range_helper(low, high, max_level, matches); std::vector consolidated_ids; for (auto const& match: matches) { @@ -374,24 +383,24 @@ void NumericTrie::Node::search_range(const int32_t& low, const int32_t& high, ui ids = out; } -void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& high, - std::vector& matches) { +void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, + std::vector& matches) { // Segregating the nodes into matching low, in-between, and matching high. NumericTrie::Node* root = this; char level = 1; - auto low_index = get_index(low, level), high_index = get_index(high, level); + auto low_index = get_index(low, level, max_level), high_index = get_index(high, level, max_level); // Keep updating the root while the range is contained within a single child node. 
- while (root->children != nullptr && low_index == high_index && level < MAX_LEVEL) { + while (root->children != nullptr && low_index == high_index && level < max_level) { if (root->children[low_index] == nullptr) { return; } root = root->children[low_index]; level++; - low_index = get_index(low, level); - high_index = get_index(high, level); + low_index = get_index(low, level, max_level); + high_index = get_index(high, level, max_level); } if (root->children == nullptr) { @@ -405,7 +414,7 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h if (root->children[low_index] != nullptr) { // Collect all the sub-nodes that are greater than low. - root->children[low_index]->search_greater_than_helper(low, level, matches); + root->children[low_index]->search_greater_than_helper(low, level, max_level, matches); } auto index = low_index + 1; @@ -420,14 +429,15 @@ void NumericTrie::Node::search_range_helper(const int32_t& low, const int32_t& h if (index < EXPANSE && index == high_index && root->children[index] != nullptr) { // Collect all the sub-nodes that are lesser than high. 
- root->children[index]->search_less_than_helper(high, level, matches); + root->children[index]->search_less_than_helper(high, level, max_level, matches); } } -void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { char level = 0; std::vector matches; - search_greater_than_helper(value, level, matches); + search_greater_than_helper(value, level, max_level, matches); std::vector consolidated_ids; for (auto const& match: matches) { @@ -450,17 +460,18 @@ void NumericTrie::Node::search_greater_than(const int32_t& value, uint32_t*& ids ids = out; } -void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& level, std::vector& matches) { - if (level == MAX_LEVEL) { +void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level, + std::vector& matches) { + if (level == max_level) { matches.push_back(this); return; - } else if (level > MAX_LEVEL || children == nullptr) { + } else if (level > max_level || children == nullptr) { return; } - auto index = get_index(value, ++level); + auto index = get_index(value, ++level, max_level); if (children[index] != nullptr) { - children[index]->search_greater_than_helper(value, level, matches); + children[index]->search_greater_than_helper(value, level, max_level, matches); } while (++index < EXPANSE) { @@ -472,18 +483,19 @@ void NumericTrie::Node::search_greater_than_helper(const int32_t& value, char& l --level; } -void NumericTrie::Node::search_equal_to(const int32_t& value, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, + uint32_t*& ids, uint32_t& ids_length) { char level = 1; Node* root = this; - auto index = get_index(value, level); + auto index = get_index(value, level, max_level); - while (level <= 
MAX_LEVEL) { + while (level <= max_level) { if (root->children == nullptr || root->children[index] == nullptr) { return; } root = root->children[index]; - index = get_index(value, ++level); + index = get_index(value, ++level, max_level); } root->get_all_ids(ids, ids_length); diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 5d9cca7d..29dff68a 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -604,7 +604,8 @@ TEST_F(NumericRangeTrieTest, Integration) { field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(), true), // Setting range index true. field("years", field_types::INT32_ARRAY, false), - field("timestamps", field_types::INT64_ARRAY, false), + field("timestamps", field_types::INT64_ARRAY, false, false, true, "", -1, -1, false, 0, 0, cosine, "", + nlohmann::json(), true), field("tags", field_types::STRING_ARRAY, true) }; @@ -626,32 +627,18 @@ TEST_F(NumericRangeTrieTest, Integration) { while (std::getline(infile, json_line)) { auto add_op = coll_array_fields->add(json_line); - LOG(INFO) << add_op.error(); ASSERT_TRUE(add_op.ok()); } infile.close(); - // Plain search with no filters - results should be sorted by rank fields query_fields = {"name"}; std::vector facets; - nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); - ASSERT_EQ(5, results["hits"].size()); - - std::vector ids = {"3", "1", "4", "0", "2"}; - - for(size_t i = 0; i < results["hits"].size(); i++) { - nlohmann::json result = results["hits"].at(i); - std::string result_id = result["document"]["id"]; - std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); - } - // Searching on an int32 field - results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + nlohmann::json results = 
coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); ASSERT_EQ(3, results["hits"].size()); - ids = {"3", "1", "4"}; + std::vector ids = {"3", "1", "4"}; for(size_t i = 0; i < results["hits"].size(); i++) { nlohmann::json result = results["hits"].at(i); @@ -659,4 +646,20 @@ TEST_F(NumericRangeTrieTest, Integration) { std::string id = ids.at(i); ASSERT_STREQ(id.c_str(), result_id.c_str()); } + + // searching on an int64 array field - also ensure that padded space causes no issues + results = coll_array_fields->search("Jeremy", query_fields, "timestamps : > 475205222", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(4, results["hits"].size()); + + ids = {"1", "4", "0", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get(); + ASSERT_EQ(3, results["hits"].size()); } From cb9f7c7507f21f5db98c21989bff22362091746d Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 5 Jun 2023 13:23:09 +0530 Subject: [PATCH 83/93] Fix wildcard search with geo-filter producing maximum 100 results. 
--- include/index.h | 2 +- src/filter_result_iterator.cpp | 8 ++++++++ src/index.cpp | 14 ++++++++------ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/include/index.h b/include/index.h index 3554d2c5..3db9cea4 100644 --- a/include/index.h +++ b/include/index.h @@ -726,7 +726,7 @@ public: const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 0e0a8b9a..ee326b32 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -571,8 +571,10 @@ void filter_result_iterator_t::init() { if (filter_node->isOperator) { if (filter_node->filter_operator == AND) { and_filter_iterators(); + approx_filter_ids_length = std::min(left_it->approx_filter_ids_length, right_it->approx_filter_ids_length); } else { or_filter_iterators(); + approx_filter_ids_length = std::max(left_it->approx_filter_ids_length, right_it->approx_filter_ids_length); } return; @@ -612,6 +614,7 @@ void filter_result_iterator_t::init() { } is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; return; } @@ -635,6 +638,7 @@ void filter_result_iterator_t::init() { seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; return; } @@ -716,6 +720,7 @@ void filter_result_iterator_t::init() { seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; return; } else if (f.is_float()) { if (f.range_index) { @@ -790,6 +795,7 @@ void 
filter_result_iterator_t::init() { seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; return; } else if (f.is_bool()) { auto num_tree = index->numerical_index.at(a_filter.field_name); @@ -823,6 +829,7 @@ void filter_result_iterator_t::init() { seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; return; } else if (f.is_geopoint()) { for (uint32_t fi = 0; fi < a_filter.values.size(); fi++) { @@ -966,6 +973,7 @@ void filter_result_iterator_t::init() { seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; return; } else if (f.is_string()) { art_tree* t = index->search_index.at(a_filter.field_name); diff --git a/src/index.cpp b/src/index.cpp index 84d1803a..66370c99 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1693,7 +1693,9 @@ Option Index::_approximate_filter_ids(const filter& a_filter, value_index++; } } else if (f.is_geopoint()) { - filter_ids_length = 100; + // Optimistically setting a value greater than 0. Exact count would be found during initialization of + // filter_result_iterator. 
+ filter_ids_length = 1; } else if (f.is_string()) { art_tree* t = search_index.at(a_filter.field_name); @@ -2325,7 +2327,7 @@ Option Index::search(std::vector& field_query_tokens, cons } auto filter_result_iterator = new filter_result_iterator_t(collection_name, this, filter_tree_root, - approx_filter_ids_length); + approx_filter_ids_length); std::unique_ptr filter_iterator_guard(filter_result_iterator); auto filter_init_op = filter_result_iterator->init_status(); @@ -2444,8 +2446,6 @@ Option Index::search(std::vector& field_query_tokens, cons if (no_filters_provided) { filter_result_iterator = new filter_result_iterator_t(seq_ids->uncompress(), seq_ids->num_ids()); filter_iterator_guard.reset(filter_result_iterator); - - approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; } collate_included_ids({}, included_ids_map, curated_topster, searched_queries); @@ -2558,7 +2558,7 @@ Option Index::search(std::vector& field_query_tokens, cons curated_ids, curated_ids_sorted, excluded_result_ids, excluded_result_ids_size, excluded_group_ids, all_result_ids, all_result_ids_len, - filter_result_iterator, approx_filter_ids_length, concurrency, + filter_result_iterator, concurrency, sort_order, field_values, geopoint_indices); filter_result_iterator->reset(); } @@ -4542,12 +4542,14 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, const std::vector& curated_ids_sorted, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::unordered_set& excluded_group_ids, uint32_t*& all_result_ids, size_t& all_result_ids_len, - filter_result_iterator_t* const filter_result_iterator, const uint32_t& approx_filter_ids_length, + filter_result_iterator_t* const filter_result_iterator, const size_t concurrency, const int* sort_order, std::array*, 3>& field_values, const std::vector& geopoint_indices) const { + auto const& approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; + uint32_t token_bits = 0; 
const bool check_for_circuit_break = (approx_filter_ids_length > 1000000); From 3e78413e1f26cb1475c11f0bf7e316a26554ab41 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 5 Jun 2023 13:33:00 +0530 Subject: [PATCH 84/93] Update test. --- test/collection_specific_more_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/collection_specific_more_test.cpp b/test/collection_specific_more_test.cpp index 65a94a71..b8419b37 100644 --- a/test/collection_specific_more_test.cpp +++ b/test/collection_specific_more_test.cpp @@ -2329,7 +2329,7 @@ TEST_F(CollectionSpecificMoreTest, ApproxFilterMatchCount) { ASSERT_TRUE(filter_op.ok()); coll->_get_index()->_approximate_filter_ids(filter_tree_root->filter_exp, approx_count); - ASSERT_EQ(approx_count, 100); + ASSERT_EQ(approx_count, 1); delete filter_tree_root; filter_op = filter::parse_filter_query("years:>2000 && ((age:<30 && rating:>5) || (age:>50 && rating:<5))", From 1d2dd365e21cb39c72b08141ed18226bcc8d9f83 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Mon, 5 Jun 2023 14:17:03 +0530 Subject: [PATCH 85/93] Fix failing tests. 
--- src/field.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/field.cpp b/src/field.cpp index 3882b45f..b62ddd77 100644 --- a/src/field.cpp +++ b/src/field.cpp @@ -83,7 +83,8 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso } auto const& type = field_json["type"]; - if (type != field_types::INT32 && type != field_types::INT32_ARRAY && + if (field_json[fields::range_index] && + type != field_types::INT32 && type != field_types::INT32_ARRAY && type != field_types::INT64 && type != field_types::INT64_ARRAY && type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) { return Option(400, std::string("The `range_index` property is only allowed for the numerical fields`")); From 257e1189fef7d55a55fb9acdd5dd76cd101fdf41 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 7 Jun 2023 14:42:46 +0530 Subject: [PATCH 86/93] Add `NumericTrie::iterator_t`. --- include/numeric_range_trie_test.h | 54 +++++++ src/numeric_range_trie.cpp | 226 ++++++++++++++++++++++++++++++ test/numeric_range_trie_test.cpp | 65 +++++++++ 3 files changed, 345 insertions(+) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index f5d6add2..a8422524 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -42,11 +42,19 @@ class NumericTrie { void search_range(const int64_t& low, const int64_t& high, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + void search_range(const int64_t& low, const int64_t& high, const char& max_level, std::vector& matches); + void search_less_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + void search_less_than(const int64_t& value, const char& max_level, std::vector& matches); + void search_greater_than(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + void search_greater_than(const int64_t& value, const char& max_level, std::vector& matches); + void 
search_equal_to(const int64_t& value, const char& max_level, uint32_t*& ids, uint32_t& ids_length); + + void search_equal_to(const int64_t& value, const char& max_level, std::vector& matches); }; Node* negative_trie = nullptr; @@ -63,17 +71,63 @@ public: delete positive_trie; } + class iterator_t { + struct match_state { + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + uint32_t index = 0; + + explicit match_state(uint32_t*& ids, uint32_t& ids_length) : ids(ids), ids_length(ids_length) {} + + ~match_state() { + delete [] ids; + } + }; + + std::vector matches; + + void set_seq_id(); + + public: + + explicit iterator_t(std::vector& matches); + + ~iterator_t() { + for (auto& match: matches) { + delete match; + } + } + + iterator_t& operator=(iterator_t&& obj) noexcept; + + uint32_t seq_id = 0; + bool is_valid = true; + + void next(); + void skip_to(uint32_t id); + void reset(); + }; + void insert(const int64_t& value, const uint32_t& seq_id); void search_range(const int64_t& low, const bool& low_inclusive, const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length); + iterator_t search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive); + void search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); + iterator_t search_less_than(const int64_t& value, const bool& inclusive); + void search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length); + iterator_t search_greater_than(const int64_t& value, const bool& inclusive); + void search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length); + + iterator_t search_equal_to(const int64_t& value); }; diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 3076f873..9d9f4aa0 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -98,6 +98,47 @@ void NumericTrie::search_range(const int64_t& low, const 
bool& low_inclusive, } } +NumericTrie::iterator_t NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, + const int64_t& high, const bool& high_inclusive) { + std::vector matches; + if (low > high) { + return NumericTrie::iterator_t(matches); + } + + if (low < 0 && high >= 0) { + // Have to combine the results of >low from negative_trie and low from negative_trie. + negative_trie->search_less_than(low_inclusive ? abs_low : abs_low - 1, max_level, matches); + } + + if (positive_trie != nullptr && !(high == 0 && !high_inclusive)) { // No need to search for ..., 0) + positive_trie->search_less_than(high_inclusive ? high : high - 1, max_level, matches); + } + } else if (low >= 0) { + // Search only in positive_trie + if (positive_trie == nullptr) { + return NumericTrie::iterator_t(matches); + } + + positive_trie->search_range(low_inclusive ? low : low + 1, high_inclusive ? high : high - 1, max_level, matches); + } else { + // Search only in negative_trie + if (negative_trie == nullptr) { + return NumericTrie::iterator_t(matches); + } + + auto abs_high = std::abs(high), abs_low = std::abs(low); + // Since we store absolute values, switching low and high would produce the correct result. + negative_trie->search_range(high_inclusive ? abs_high : abs_high + 1, low_inclusive ? 
abs_low : abs_low - 1, + max_level, matches); + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) if (positive_trie != nullptr) { @@ -167,6 +208,35 @@ void NumericTrie::search_greater_than(const int64_t& value, const bool& inclusiv } } +NumericTrie::iterator_t NumericTrie::search_greater_than(const int64_t& value, const bool& inclusive) { + std::vector matches; + + if ((value == 0 && inclusive) || (value == -1 && !inclusive)) { // [0, ∞), (-1, ∞) + if (positive_trie != nullptr) { + matches.push_back(positive_trie); + } + return NumericTrie::iterator_t(matches); + } + + if (value >= 0) { + if (positive_trie != nullptr) { + positive_trie->search_greater_than(inclusive ? value : value + 1, max_level, matches); + } + } else { + // Have to combine the results of >value from negative_trie and all the ids in positive_trie + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + // Since we store absolute values, search_lesser would yield result for >value from negative_trie. + negative_trie->search_less_than(inclusive ? 
abs_low : abs_low - 1, max_level, matches); + } + if (positive_trie != nullptr) { + matches.push_back(positive_trie); + } + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, uint32_t*& ids, uint32_t& ids_length) { if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] if (negative_trie != nullptr) { @@ -237,6 +307,35 @@ void NumericTrie::search_less_than(const int64_t& value, const bool& inclusive, } } +NumericTrie::iterator_t NumericTrie::search_less_than(const int64_t& value, const bool& inclusive) { + std::vector matches; + + if ((value == 0 && !inclusive) || (value == -1 && inclusive)) { // (-∞, 0), (-∞, -1] + if (negative_trie != nullptr) { + matches.push_back(negative_trie); + } + return NumericTrie::iterator_t(matches); + } + + if (value < 0) { + if (negative_trie != nullptr) { + auto abs_low = std::abs(value); + // Since we store absolute values, search_greater would yield result for search_greater_than(inclusive ? abs_low : abs_low + 1, max_level, matches); + } + } else { + // Have to combine the results of search_less_than(inclusive ? 
value : value - 1, max_level, matches); + } + if (negative_trie != nullptr) { + matches.push_back(negative_trie); + } + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t& ids_length) { if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { return; @@ -259,6 +358,17 @@ void NumericTrie::search_equal_to(const int64_t& value, uint32_t*& ids, uint32_t ids = out; } +NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) { + std::vector matches; + if (value < 0 && negative_trie != nullptr) { + negative_trie->search_equal_to(std::abs(value), max_level, matches); + } else if (value >= 0 && positive_trie != nullptr) { + positive_trie->search_equal_to(value, max_level, matches); + } + + return NumericTrie::iterator_t(matches); +} + void NumericTrie::Node::insert(const int64_t& value, const uint32_t& seq_id, const char& max_level) { char level = 0; return insert_helper(value, seq_id, level, max_level); @@ -331,6 +441,11 @@ void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_l ids = out; } +void NumericTrie::Node::search_less_than(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 0; + search_less_than_helper(value, level, max_level, matches); +} + void NumericTrie::Node::search_less_than_helper(const int64_t& value, char& level, const char& max_level, std::vector& matches) { if (level == max_level) { @@ -383,6 +498,15 @@ void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, co ids = out; } +void NumericTrie::Node::search_range(const int64_t& low, const int64_t& high, const char& max_level, + std::vector& matches) { + if (low > high) { + return; + } + + search_range_helper(low, high, max_level, matches); +} + void NumericTrie::Node::search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, std::vector& matches) { // Segregating the 
nodes into matching low, in-between, and matching high. @@ -460,6 +584,11 @@ void NumericTrie::Node::search_greater_than(const int64_t& value, const char& ma ids = out; } +void NumericTrie::Node::search_greater_than(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 0; + search_greater_than_helper(value, level, max_level, matches); +} + void NumericTrie::Node::search_greater_than_helper(const int64_t& value, char& level, const char& max_level, std::vector& matches) { if (level == max_level) { @@ -500,3 +629,100 @@ void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_le root->get_all_ids(ids, ids_length); } + +void NumericTrie::Node::search_equal_to(const int64_t& value, const char& max_level, std::vector& matches) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level <= max_level) { + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + matches.push_back(root); +} + +void NumericTrie::iterator_t::reset() { + for (auto& match: matches) { + match->index = 0; + } + + is_valid = true; + set_seq_id(); +} + +void NumericTrie::iterator_t::skip_to(uint32_t id) { + for (auto& match: matches) { + ArrayUtils::skip_index_to_id(match->index, match->ids, match->ids_length, id); + } + + set_seq_id(); +} + +void NumericTrie::iterator_t::next() { + // Advance all the matches at seq_id. 
+ for (auto& match: matches) { + if (match->index < match->ids_length && match->ids[match->index] == seq_id) { + match->index++; + } + } + + set_seq_id(); +} + +NumericTrie::iterator_t::iterator_t(std::vector& node_matches) { + for (auto const& node_match: node_matches) { + uint32_t* ids = nullptr; + uint32_t ids_length; + node_match->get_all_ids(ids, ids_length); + if (ids_length > 0) { + matches.emplace_back(new match_state(ids, ids_length)); + } + } + + set_seq_id(); +} + +void NumericTrie::iterator_t::set_seq_id() { + // Find the lowest id of all the matches and update the seq_id. + bool one_is_valid = false; + uint32_t lowest_id = UINT32_MAX; + + for (auto& match: matches) { + if (match->index < match->ids_length) { + one_is_valid = true; + + if (match->ids[match->index] < lowest_id) { + lowest_id = match->ids[match->index]; + } + } + } + + if (one_is_valid) { + seq_id = lowest_id; + } + + is_valid = one_is_valid; +} + +NumericTrie::iterator_t& NumericTrie::iterator_t::operator=(NumericTrie::iterator_t&& obj) noexcept { + if (&obj == this) + return *this; + + for (auto& match: matches) { + delete match; + } + matches.clear(); + + matches = std::move(obj.matches); + seq_id = obj.seq_id; + is_valid = obj.is_valid; + + return *this; +} + diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index 29dff68a..d2fc6e16 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -436,6 +436,71 @@ TEST_F(NumericRangeTrieTest, SearchEqualTo) { ASSERT_EQ(0, ids_length); } +TEST_F(NumericRangeTrieTest, IterateSearchEqualTo) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-8192, 8}, + {-16384, 32}, + {-24576, 35}, + {-32769, 41}, + {-32768, 43}, + {-32767, 45}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {24576, 60}, + {32768, 91} + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 
0; + + auto iterator = trie->search_equal_to(0); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(0x202020); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(-32768); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(43, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(false, iterator.is_valid); + + iterator = trie->search_equal_to(24576); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(60, iterator.seq_id); + + iterator.next(); + ASSERT_EQ(false, iterator.is_valid); + + + iterator.reset(); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.skip_to(4); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(58, iterator.seq_id); + + iterator.skip_to(59); + ASSERT_EQ(true, iterator.is_valid); + ASSERT_EQ(60, iterator.seq_id); + + iterator.skip_to(66); + ASSERT_EQ(false, iterator.is_valid); +} + TEST_F(NumericRangeTrieTest, MultivalueData) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); From b45e7c07d49619f3a82e7c210d99fa42ac3a5885 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Tue, 13 Jun 2023 16:00:25 +0530 Subject: [PATCH 87/93] Add `geo_range_index`. 
--- include/index.h | 4 +- include/numeric_range_trie_test.h | 16 ++++- src/filter_result_iterator.cpp | 16 +++-- src/index.cpp | 46 +++++------- src/numeric_range_trie.cpp | 116 +++++++++++++++++++++++++++++- 5 files changed, 161 insertions(+), 37 deletions(-) diff --git a/include/index.h b/include/index.h index 3db9cea4..e179f271 100644 --- a/include/index.h +++ b/include/index.h @@ -305,7 +305,9 @@ private: spp::sparse_hash_map range_index; - spp::sparse_hash_map>*> geopoint_index; + spp::sparse_hash_map geo_range_index; + +// spp::sparse_hash_map>*> geopoint_index; // geo_array_field => (seq_id => values) used for exact filtering of geo array records spp::sparse_hash_map*> geo_array_index; diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index a8422524..ed695a70 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -14,6 +14,8 @@ class NumericTrie { void insert_helper(const int64_t& value, const uint32_t& seq_id, char& level, const char& max_level); + void insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level, const char& max_level); + void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, std::vector& matches); @@ -35,7 +37,13 @@ class NumericTrie { delete [] children; } - void insert(const int64_t& value, const uint32_t& seq_id, const char& max_level); + void insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level); + + void search_geopoint(const uint64_t& cell_id, const char& max_index_level, uint32_t*& ids, uint32_t& ids_length); + + void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level); void get_all_ids(uint32_t*& ids, uint32_t& ids_length); @@ -110,6 +118,12 @@ public: void insert(const int64_t& value, const uint32_t& seq_id); + void insert_geopoint(const uint64_t& cell_id, const 
uint32_t& seq_id); + + void search_geopoint(const uint64_t& cell_id, uint32_t*& ids, uint32_t& ids_length); + + void delete_geopoint(const uint64_t& cell_id, uint32_t id); + void search_range(const int64_t& low, const bool& low_inclusive, const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index ee326b32..5fb432f2 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -892,13 +892,21 @@ void filter_result_iterator_t::init() { S2RegionTermIndexer::Options options; options.set_index_contains_points_only(true); S2RegionTermIndexer indexer(options); + auto const& geo_range_index = index->geo_range_index.at(a_filter.field_name); for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { - auto geo_index = index->geopoint_index.at(a_filter.field_name); - const auto& ids_it = geo_index->find(term); - if(ids_it != geo_index->end()) { - geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); + auto cell = S2CellId::FromToken(term); + uint32_t* geo_ids = nullptr; + uint32_t geo_ids_length = 0; + + geo_range_index->search_geopoint(cell.id(), geo_ids, geo_ids_length); + + geo_result_ids.reserve(geo_result_ids.size() + geo_ids_length); + for (uint32_t i = 0; i < geo_ids_length; i++) { + geo_result_ids.push_back(geo_ids[i]); } + + delete [] geo_ids; } gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); diff --git a/src/index.cpp b/src/index.cpp index 7d90c7b8..6a5a642e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -78,8 +78,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* art_tree_init(t); search_index.emplace(a_field.name, t); } else if(a_field.is_geopoint()) { - auto field_geo_index = new spp::sparse_hash_map>(); - geopoint_index.emplace(a_field.name, field_geo_index); + geo_range_index.emplace(a_field.name, new NumericTrie(64)); 
if(!a_field.is_single_geopoint()) { spp::sparse_hash_map * doc_to_geos = new spp::sparse_hash_map(); @@ -141,12 +140,12 @@ Index::~Index() { search_index.clear(); - for(auto & name_index: geopoint_index) { + for(auto & name_index: geo_range_index) { delete name_index.second; name_index.second = nullptr; } - geopoint_index.clear(); + geo_range_index.clear(); for(auto& name_index: geo_array_index) { for(auto& kv: *name_index.second) { @@ -811,10 +810,10 @@ void Index::index_field_in_memory(const field& afield, std::vector num_tree->insert(value, seq_id); }); } else if(afield.type == field_types::GEOPOINT || afield.type == field_types::GEOPOINT_ARRAY) { - auto geo_index = geopoint_index.at(afield.name); + auto geopoint_range_index = geo_range_index.at(afield.name); iterate_and_index_numerical_field(iter_batch, afield, - [&afield, &geo_array_index=geo_array_index, geo_index](const index_record& record, uint32_t seq_id) { + [&afield, &geo_array_index=geo_array_index, geopoint_range_index](const index_record& record, uint32_t seq_id) { // nested geopoint value inside an array of object will be a simple array so must be treated as geopoint bool nested_obj_arr_geopoint = (afield.nested && afield.type == field_types::GEOPOINT_ARRAY && !record.doc[afield.name].empty() && record.doc[afield.name][0].is_number()); @@ -828,9 +827,8 @@ void Index::index_field_in_memory(const field& afield, std::vector S2RegionTermIndexer indexer(options); S2Point point = S2LatLng::FromDegrees(latlongs[li], latlongs[li+1]).ToPoint(); - for(const auto& term: indexer.GetIndexTerms(point, "")) { - (*geo_index)[term].push_back(seq_id); - } + auto cell = S2CellId(point); + geopoint_range_index->insert_geopoint(cell.id(), seq_id); } if(nested_obj_arr_geopoint) { @@ -858,9 +856,9 @@ void Index::index_field_in_memory(const field& afield, std::vector for(size_t li = 0; li < latlongs.size(); li++) { auto& latlong = latlongs[li]; S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint(); - 
for(const auto& term: indexer.GetIndexTerms(point, "")) { - (*geo_index)[term].push_back(seq_id); - } + + auto cell = S2CellId(point); + geopoint_range_index->insert_geopoint(cell.id(), seq_id); int64_t packed_latlong = GeoPoint::pack_lat_lng(latlong[0], latlong[1]); packed_latlongs[li + 1] = packed_latlong; @@ -1590,7 +1588,7 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree, bool Index::field_is_indexed(const std::string& field_name) const { return search_index.count(field_name) != 0 || numerical_index.count(field_name) != 0 || - geopoint_index.count(field_name) != 0; + geo_range_index.count(field_name) != 0; } void Index::aproximate_numerical_match(num_tree_t* const num_tree, @@ -5468,7 +5466,7 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const } } } else if(search_field.is_geopoint()) { - auto geo_index = geopoint_index[field_name]; + auto geopoint_range_index = geo_range_index[field_name]; S2RegionTermIndexer::Options options; options.set_index_contains_points_only(true); S2RegionTermIndexer indexer(options); @@ -5479,17 +5477,8 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const for(const std::vector& latlong: latlongs) { S2Point point = S2LatLng::FromDegrees(latlong[0], latlong[1]).ToPoint(); - for(const auto& term: indexer.GetIndexTerms(point, "")) { - auto term_it = geo_index->find(term); - if(term_it == geo_index->end()) { - continue; - } - std::vector& ids = term_it->second; - ids.erase(std::remove(ids.begin(), ids.end(), seq_id), ids.end()); - if(ids.empty()) { - geo_index->erase(term); - } - } + auto cell = S2CellId(point); + geopoint_range_index->delete_geopoint(cell.id(), seq_id); } if(!search_field.is_single_geopoint()) { @@ -5641,8 +5630,7 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec art_tree_init(t); search_index.emplace(new_field.name, t); } else if(new_field.is_geopoint()) { - auto field_geo_index = new spp::sparse_hash_map>(); - 
geopoint_index.emplace(new_field.name, field_geo_index); + geo_range_index.emplace(new_field.name, new NumericTrie(64)); if(!new_field.is_single_geopoint()) { auto geo_array_map = new spp::sparse_hash_map(); geo_array_index.emplace(new_field.name, geo_array_map); @@ -5692,8 +5680,8 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec delete search_index[del_field.name]; search_index.erase(del_field.name); } else if(del_field.is_geopoint()) { - delete geopoint_index[del_field.name]; - geopoint_index.erase(del_field.name); + delete geo_range_index[del_field.name]; + geo_range_index.erase(del_field.name); if(!del_field.is_single_geopoint()) { spp::sparse_hash_map* geo_array_map = geo_array_index[del_field.name]; diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 9d9f4aa0..86090304 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -18,6 +18,30 @@ void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) { } } +void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id) { + if (positive_trie == nullptr) { + positive_trie = new NumericTrie::Node(); + } + + positive_trie->insert_geopoint(cell_id, seq_id, max_level); +} + +void NumericTrie::search_geopoint(const uint64_t& cell_id, uint32_t*& ids, uint32_t& ids_length) { + if (positive_trie == nullptr) { + return; + } + + positive_trie->search_geopoint(cell_id, max_level, ids, ids_length); +} + +void NumericTrie::delete_geopoint(const uint64_t& cell_id, uint32_t id) { + if (positive_trie == nullptr) { + return; + } + + positive_trie->delete_geopoint(cell_id, id, max_level); +} + void NumericTrie::search_range(const int64_t& low, const bool& low_inclusive, const int64_t& high, const bool& high_inclusive, uint32_t*& ids, uint32_t& ids_length) { @@ -369,9 +393,14 @@ NumericTrie::iterator_t NumericTrie::search_equal_to(const int64_t& value) { return NumericTrie::iterator_t(matches); } -void 
/// Returns the child slot (0-255) used at trie depth `level` for `cell_id`.
/// Level `max_level` maps to the least significant byte; each level above it moves
/// one byte towards the most significant end.
inline int get_geopoint_index(const uint64_t& cell_id, const char& level, const char& max_level) {
    const auto bits_to_drop = 8 * (max_level - level);
    const auto shifted = cell_id >> bits_to_drop;
    return shifted & 0xFF;
}
+ if (!seq_ids.contains(seq_id)) { + seq_ids.append(seq_id); + } + + if (++level <= max_level) { + if (children == nullptr) { + children = new NumericTrie::Node* [EXPANSE]{nullptr}; + } + + auto index = get_geopoint_index(cell_id, level, max_level); + if (children[index] == nullptr) { + children[index] = new NumericTrie::Node(); + } + + return children[index]->insert_geopoint_helper(cell_id, seq_id, level, max_level); + } +} + +char get_max_search_level(const uint64_t& cell_id, const char& max_level) { + // For cell id 0x47E66C3000000000, we only have to prefix match the top four bytes since rest of the bytes are 0. + // So the max search level would be 4 in this case. + + uint64_t mask = 0xff; + char i = max_level; + while (((cell_id & mask) == 0) && --i > 0) { + mask <<= 8; + } + + return i; +} + +void NumericTrie::Node::search_geopoint(const uint64_t& cell_id, const char& max_index_level, + uint32_t*& ids, uint32_t& ids_length) { + char level = 1; + Node* root = this; + auto index = get_geopoint_index(cell_id, level, max_index_level); + auto max_search_level = get_max_search_level(cell_id, max_index_level); + + while (level < max_search_level) { + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_geopoint_index(cell_id, ++level, max_index_level); + } + + root->get_all_ids(ids, ids_length); +} + +void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level) { + char level = 1; + Node* root = this; + auto index = get_geopoint_index(cell_id, level, max_level); + + while (level < max_level) { + root->seq_ids.remove_value(id); + + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_geopoint_index(cell_id, ++level, max_level); + } + + if (root->children != nullptr || root->children[index] != nullptr) { + delete root->children[index]; + root->children[index] = nullptr; + } 
+} + void NumericTrie::Node::get_all_ids(uint32_t*& ids, uint32_t& ids_length) { ids = seq_ids.uncompress(); ids_length = seq_ids.getLength(); From a7af973338c8a8fc739077fc2c7fd67ed4118925 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 14 Jun 2023 11:34:19 +0530 Subject: [PATCH 88/93] Optimize geo filtering. --- include/numeric_range_trie_test.h | 7 +++++-- src/filter_result_iterator.cpp | 16 +++------------- src/numeric_range_trie.cpp | 31 ++++++++++++++++++++++++++----- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index ed695a70..3b6f8f68 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -16,6 +16,8 @@ class NumericTrie { void insert_geopoint_helper(const uint64_t& cell_id, const uint32_t& seq_id, char& level, const char& max_level); + void search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level, std::set& matches); + void search_range_helper(const int64_t& low,const int64_t& high, const char& max_level, std::vector& matches); @@ -41,7 +43,8 @@ class NumericTrie { void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level); - void search_geopoint(const uint64_t& cell_id, const char& max_index_level, uint32_t*& ids, uint32_t& ids_length); + void search_geopoints(const std::vector& cell_ids, const char& max_index_level, + std::vector& geo_result_ids); void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level); @@ -120,7 +123,7 @@ public: void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id); - void search_geopoint(const uint64_t& cell_id, uint32_t*& ids, uint32_t& ids_length); + void search_geopoints(const std::vector& cell_ids, std::vector& geo_result_ids); void delete_geopoint(const uint64_t& cell_id, uint32_t id); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 5fb432f2..4794e030 100644 --- 
a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -894,23 +894,13 @@ void filter_result_iterator_t::init() { S2RegionTermIndexer indexer(options); auto const& geo_range_index = index->geo_range_index.at(a_filter.field_name); + std::vector cell_ids; for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { auto cell = S2CellId::FromToken(term); - uint32_t* geo_ids = nullptr; - uint32_t geo_ids_length = 0; - - geo_range_index->search_geopoint(cell.id(), geo_ids, geo_ids_length); - - geo_result_ids.reserve(geo_result_ids.size() + geo_ids_length); - for (uint32_t i = 0; i < geo_ids_length; i++) { - geo_result_ids.push_back(geo_ids[i]); - } - - delete [] geo_ids; + cell_ids.push_back(cell.id()); } - gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); - geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); + geo_range_index->search_geopoints(cell_ids, geo_result_ids); // Skip exact filtering step if query radius is greater than the threshold. 
if (fi < a_filter.params.size() && diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 86090304..71970894 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -1,4 +1,5 @@ #include +#include #include "numeric_range_trie_test.h" #include "array_utils.h" @@ -26,12 +27,12 @@ void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_i positive_trie->insert_geopoint(cell_id, seq_id, max_level); } -void NumericTrie::search_geopoint(const uint64_t& cell_id, uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::search_geopoints(const std::vector& cell_ids, std::vector& geo_result_ids) { if (positive_trie == nullptr) { return; } - positive_trie->search_geopoint(cell_id, max_level, ids, ids_length); + positive_trie->search_geopoints(cell_ids, max_level, geo_result_ids); } void NumericTrie::delete_geopoint(const uint64_t& cell_id, uint32_t id) { @@ -480,8 +481,8 @@ char get_max_search_level(const uint64_t& cell_id, const char& max_level) { return i; } -void NumericTrie::Node::search_geopoint(const uint64_t& cell_id, const char& max_index_level, - uint32_t*& ids, uint32_t& ids_length) { +void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const char& max_index_level, + std::set& matches) { char level = 1; Node* root = this; auto index = get_geopoint_index(cell_id, level, max_index_level); @@ -496,7 +497,27 @@ void NumericTrie::Node::search_geopoint(const uint64_t& cell_id, const char& max index = get_geopoint_index(cell_id, ++level, max_index_level); } - root->get_all_ids(ids, ids_length); + matches.insert(root); +} + +void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, const char& max_index_level, + std::vector& geo_result_ids) { + std::set matches; + for (const auto &cell_id: cell_ids) { + search_geopoints_helper(cell_id, max_index_level, matches); + } + + for (auto const& match: matches) { + auto const& m_seq_ids = match->seq_ids.uncompress(); + for (uint32_t i = 
0; i < match->seq_ids.getLength(); i++) { + geo_result_ids.push_back(m_seq_ids[i]); + } + + delete [] m_seq_ids; + } + + gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); + geo_result_ids.erase(unique(geo_result_ids.begin(), geo_result_ids.end()), geo_result_ids.end()); } void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level) { From 9695a0b4d62ec0b86e2cc83c96527baf98505770 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 14 Jun 2023 19:14:08 +0530 Subject: [PATCH 89/93] Update `geo_range_index` to be a 32 bit trie. --- include/index.h | 2 -- src/index.cpp | 4 ++-- src/numeric_range_trie.cpp | 18 ++++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/index.h b/include/index.h index e179f271..38c02238 100644 --- a/include/index.h +++ b/include/index.h @@ -307,8 +307,6 @@ private: spp::sparse_hash_map geo_range_index; -// spp::sparse_hash_map>*> geopoint_index; - // geo_array_field => (seq_id => values) used for exact filtering of geo array records spp::sparse_hash_map*> geo_array_index; diff --git a/src/index.cpp b/src/index.cpp index 6a5a642e..cf5ec95a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -78,7 +78,7 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store* art_tree_init(t); search_index.emplace(a_field.name, t); } else if(a_field.is_geopoint()) { - geo_range_index.emplace(a_field.name, new NumericTrie(64)); + geo_range_index.emplace(a_field.name, new NumericTrie()); if(!a_field.is_single_geopoint()) { spp::sparse_hash_map * doc_to_geos = new spp::sparse_hash_map(); @@ -5630,7 +5630,7 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec art_tree_init(t); search_index.emplace(new_field.name, t); } else if(new_field.is_geopoint()) { - geo_range_index.emplace(new_field.name, new NumericTrie(64)); + geo_range_index.emplace(new_field.name, new NumericTrie()); if(!new_field.is_single_geopoint()) { auto 
/// Returns the child slot (0-255) used at trie depth `level` for `cell_id`:
/// level 1 is the most significant byte of the 64-bit id, level 8 the least significant.
inline int get_geopoint_index(const uint64_t& cell_id, const char& level) {
    // Doing 8-level since cell_id is a 64 bit number.
    const auto bits_to_drop = 8 * (8 - level);
    const auto shifted = cell_id >> bits_to_drop;
    return shifted & 0xFF;
}
char i = max_level; while (((cell_id & mask) == 0) && --i > 0) { mask <<= 8; @@ -485,7 +486,7 @@ void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const c std::set& matches) { char level = 1; Node* root = this; - auto index = get_geopoint_index(cell_id, level, max_index_level); + auto index = get_geopoint_index(cell_id, level); auto max_search_level = get_max_search_level(cell_id, max_index_level); while (level < max_search_level) { @@ -494,7 +495,7 @@ void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const c } root = root->children[index]; - index = get_geopoint_index(cell_id, ++level, max_index_level); + index = get_geopoint_index(cell_id, ++level); } matches.insert(root); @@ -523,7 +524,7 @@ void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level) { char level = 1; Node* root = this; - auto index = get_geopoint_index(cell_id, level, max_level); + auto index = get_geopoint_index(cell_id, level); while (level < max_level) { root->seq_ids.remove_value(id); @@ -533,9 +534,10 @@ void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, co } root = root->children[index]; - index = get_geopoint_index(cell_id, ++level, max_level); + index = get_geopoint_index(cell_id, ++level); } + root->seq_ids.remove_value(id); if (root->children != nullptr || root->children[index] != nullptr) { delete root->children[index]; root->children[index] = nullptr; From 101e064884b7bc1870c5a11de42f9a55493cd31f Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 15 Jun 2023 14:25:37 +0530 Subject: [PATCH 90/93] Add `NumericTrie::remove`. 
--- include/numeric_range_trie_test.h | 6 ++- src/index.cpp | 27 +++++++++++- src/numeric_range_trie.cpp | 55 ++++++++++++++++++++--- test/numeric_range_trie_test.cpp | 72 +++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 7 deletions(-) diff --git a/include/numeric_range_trie_test.h b/include/numeric_range_trie_test.h index 3b6f8f68..8b7bd22c 100644 --- a/include/numeric_range_trie_test.h +++ b/include/numeric_range_trie_test.h @@ -41,9 +41,11 @@ class NumericTrie { void insert(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + void remove(const int64_t& cell_id, const uint32_t& seq_id, const char& max_level); + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id, const char& max_level); - void search_geopoints(const std::vector& cell_ids, const char& max_index_level, + void search_geopoints(const std::vector& cell_ids, const char& max_level, std::vector& geo_result_ids); void delete_geopoint(const uint64_t& cell_id, uint32_t id, const char& max_level); @@ -121,6 +123,8 @@ public: void insert(const int64_t& value, const uint32_t& seq_id); + void remove(const int64_t& value, const uint32_t& seq_id); + void insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id); void search_geopoints(const std::vector& cell_ids, std::vector& geo_result_ids); diff --git a/src/index.cpp b/src/index.cpp index cf5ec95a..3bf363e8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -5420,6 +5420,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const std::vector{document[field_name].get()} : document[field_name].get>(); for(int32_t value: values) { + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(value, seq_id); + } + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(value, seq_id); if(search_field.facet) { @@ -5431,6 +5436,11 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const 
std::vector{document[field_name].get()} : document[field_name].get>(); for(int64_t value: values) { + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(value, seq_id); + } + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(value, seq_id); if(search_field.facet) { @@ -5445,8 +5455,14 @@ void Index::remove_field(uint32_t seq_id, const nlohmann::json& document, const document[field_name].get>(); for(float value: values) { - num_tree_t* num_tree = numerical_index.at(field_name); int64_t fintval = float_to_int64_t(value); + + if (search_field.range_index) { + auto const& trie = range_index.at(search_field.name); + trie->remove(fintval, seq_id); + } + + num_tree_t* num_tree = numerical_index.at(field_name); num_tree->remove(fintval, seq_id); if(search_field.facet) { remove_facet_token(search_field, search_index, StringUtils::float_to_str(value), seq_id); @@ -5638,6 +5654,10 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec } else { num_tree_t* num_tree = new num_tree_t; numerical_index.emplace(new_field.name, num_tree); + + if (new_field.range_index) { + range_index.emplace(new_field.name, new NumericTrie(new_field.is_int32() ? 
32 : 64)); + } } } @@ -5694,6 +5714,11 @@ void Index::refresh_schemas(const std::vector& new_fields, const std::vec } else { delete numerical_index[del_field.name]; numerical_index.erase(del_field.name); + + if (del_field.range_index) { + delete range_index[del_field.name]; + range_index.erase(del_field.name); + } } if(del_field.is_sortable()) { diff --git a/src/numeric_range_trie.cpp b/src/numeric_range_trie.cpp index 7ac88590..f70de113 100644 --- a/src/numeric_range_trie.cpp +++ b/src/numeric_range_trie.cpp @@ -19,6 +19,18 @@ void NumericTrie::insert(const int64_t& value, const uint32_t& seq_id) { } } +void NumericTrie::remove(const int64_t& value, const uint32_t& seq_id) { + if ((value < 0 && negative_trie == nullptr) || (value >= 0 && positive_trie == nullptr)) { + return; + } + + if (value < 0) { + negative_trie->remove(std::abs(value), seq_id, max_level); + } else { + positive_trie->remove(value, seq_id, max_level); + } +} + void NumericTrie::insert_geopoint(const uint64_t& cell_id, const uint32_t& seq_id) { if (positive_trie == nullptr) { positive_trie = new NumericTrie::Node(); @@ -420,6 +432,34 @@ inline int get_geopoint_index(const uint64_t& cell_id, const char& level) { return (cell_id >> (8 * (8 - level))) & 0xFF; } +void NumericTrie::Node::remove(const int64_t& value, const uint32_t& id, const char& max_level) { + char level = 1; + Node* root = this; + auto index = get_index(value, level, max_level); + + while (level < max_level) { + root->seq_ids.remove_value(id); + + if (root->children == nullptr || root->children[index] == nullptr) { + return; + } + + root = root->children[index]; + index = get_index(value, ++level, max_level); + } + + root->seq_ids.remove_value(id); + if (root->children != nullptr && root->children[index] != nullptr) { + auto& child = root->children[index]; + + child->seq_ids.remove_value(id); + if (child->seq_ids.getLength() == 0) { + delete child; + child = nullptr; + } + } +} + void NumericTrie::Node::insert_helper(const 
int64_t& value, const uint32_t& seq_id, char& level, const char& max_level) { if (level > max_level) { return; @@ -501,11 +541,11 @@ void NumericTrie::Node::search_geopoints_helper(const uint64_t& cell_id, const c matches.insert(root); } -void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, const char& max_index_level, +void NumericTrie::Node::search_geopoints(const std::vector& cell_ids, const char& max_level, std::vector& geo_result_ids) { std::set matches; for (const auto &cell_id: cell_ids) { - search_geopoints_helper(cell_id, max_index_level, matches); + search_geopoints_helper(cell_id, max_level, matches); } for (auto const& match: matches) { @@ -538,9 +578,14 @@ void NumericTrie::Node::delete_geopoint(const uint64_t& cell_id, uint32_t id, co } root->seq_ids.remove_value(id); - if (root->children != nullptr || root->children[index] != nullptr) { - delete root->children[index]; - root->children[index] = nullptr; + if (root->children != nullptr && root->children[index] != nullptr) { + auto& child = root->children[index]; + + child->seq_ids.remove_value(id); + if (child->seq_ids.getLength() == 0) { + delete child; + child = nullptr; + } } } diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp index d2fc6e16..2412b5a5 100644 --- a/test/numeric_range_trie_test.cpp +++ b/test/numeric_range_trie_test.cpp @@ -601,6 +601,75 @@ TEST_F(NumericRangeTrieTest, MultivalueData) { reset(ids, ids_length); } +TEST_F(NumericRangeTrieTest, Remove) { + auto trie = new NumericTrie(); + std::unique_ptr trie_guard(trie); + std::vector> pairs = { + {-0x202020, 32}, + {-32768, 5}, + {-32768, 8}, + {-24576, 32}, + {-16384, 35}, + {-8192, 43}, + {0, 2}, + {0, 49}, + {1, 8}, + {256, 91}, + {8192, 49}, + {16384, 56}, + {24576, 58}, + {32768, 91}, + {0x202020, 35}, + }; + + for (auto const& pair: pairs) { + trie->insert(pair.first, pair.second); + } + + uint32_t* ids = nullptr; + uint32_t ids_length = 0; + + trie->search_less_than(0, false, ids, 
ids_length); + + std::vector expected = {5, 8, 32, 35, 43}; + + ASSERT_EQ(5, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->remove(-24576, 32); + trie->remove(-0x202020, 32); + + reset(ids, ids_length); + trie->search_less_than(0, false, ids, ids_length); + + expected = {5, 8, 35, 43}; + ASSERT_EQ(4, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + reset(ids, ids_length); + trie->search_equal_to(0, ids, ids_length); + + expected = {2, 49}; + ASSERT_EQ(2, ids_length); + for (uint32_t i = 0; i < ids_length; i++) { + ASSERT_EQ(expected[i], ids[i]); + } + + trie->remove(0, 2); + + reset(ids, ids_length); + trie->search_equal_to(0, ids, ids_length); + + ASSERT_EQ(1, ids_length); + ASSERT_EQ(49, ids[0]); + + reset(ids, ids_length); +} + TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { auto trie = new NumericTrie(); std::unique_ptr trie_guard(trie); @@ -657,6 +726,9 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) { ids_guard.reset(ids); ASSERT_EQ(0, ids_length); + + trie->remove(15, 0); + trie->remove(-15, 0); } TEST_F(NumericRangeTrieTest, Integration) { From c7ff2e708c85ff03b9f28e4d11f2ba2791200c53 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Thu, 15 Jun 2023 18:59:53 +0530 Subject: [PATCH 91/93] Add `filter_result_iterator_t::compute_result`. --- include/filter_result_iterator.h | 3 + src/filter_result_iterator.cpp | 128 +++++++++++++++++++++++++++++++ src/index.cpp | 1 + 3 files changed, 132 insertions(+) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index b3b12555..0e06efe4 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -160,6 +160,9 @@ public: /// Returns the status of the initialization of iterator tree. Option init_status(); + /// Recursively computes the result of each node and stores the final result in root node. 
+ void compute_result(); + /// Returns a tri-state: /// 0: id is not valid /// 1: id is valid diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 4794e030..2355e871 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -1528,3 +1528,131 @@ void filter_result_iterator_t::add_phrase_ids(filter_result_iterator_t*& filter_ root_iterator->seq_id = left_it->seq_id; filter_result_iterator = root_iterator; } + +void filter_result_iterator_t::compute_result() { + if (filter_node->isOperator) { + left_it->compute_result(); + right_it->compute_result(); + + if (filter_node->filter_operator == AND) { + filter_result_t::and_filter_results(left_it->filter_result, right_it->filter_result, filter_result); + } else { + filter_result_t::or_filter_results(left_it->filter_result, right_it->filter_result, filter_result); + } + + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; + return; + } + + // Only string field filter needs to be evaluated. + if (is_filter_result_initialized || index->search_index.count(filter_node->filter_exp.field_name) == 0) { + return; + } + + auto const& a_filter = filter_node->filter_exp; + auto const& f = index->search_schema.at(a_filter.field_name); + art_tree* t = index->search_index.at(a_filter.field_name); + + uint32_t* or_ids = nullptr; + size_t or_ids_size = 0; + + // aggregates IDs across array of filter values and reduces excessive ORing + std::vector f_id_buff; + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. 
country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) { + // needs intersection + exact matching (unlike CONTAINS) + std::vector result_id_vec; + posting_t::intersect(posting_lists, result_id_vec); + + if (result_id_vec.empty()) { + continue; + } + + // need to do exact match + uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()]; + size_t exact_str_ids_size = 0; + std::unique_ptr exact_str_ids_guard(exact_str_ids); + + posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(), + exact_str_ids, exact_str_ids_size); + + if (exact_str_ids_size == 0) { + continue; + } + + for (size_t ei = 0; ei < exact_str_ids_size; ei++) { + f_id_buff.push_back(exact_str_ids[ei]); + } + } else { + // CONTAINS + size_t before_size = f_id_buff.size(); + posting_t::intersect(posting_lists, f_id_buff); + if (f_id_buff.size() == before_size) { + continue; + } + } + + if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) { + gfx::timsort(f_id_buff.begin(), f_id_buff.end()); + f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); + + uint32_t* out = nullptr; + or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); + delete[] or_ids; + or_ids = out; + std::vector().swap(f_id_buff); // clears out memory + } + } + + if (!f_id_buff.empty()) { + 
gfx::timsort(f_id_buff.begin(), f_id_buff.end()); + f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); + + uint32_t* out = nullptr; + or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); + delete[] or_ids; + or_ids = out; + std::vector().swap(f_id_buff); // clears out memory + } + + filter_result.docs = or_ids; + filter_result.count = or_ids_size; + + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), filter_result.docs, filter_result.count); + } + + result_index = 0; + seq_id = filter_result.docs[result_index]; + is_filter_result_initialized = true; + approx_filter_ids_length = filter_result.count; +} diff --git a/src/index.cpp b/src/index.cpp index 3bf363e8..8f2eeb1b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4546,6 +4546,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root, std::array*, 3>& field_values, const std::vector& geopoint_indices) const { + filter_result_iterator->compute_result(); auto const& approx_filter_ids_length = filter_result_iterator->approx_filter_ids_length; uint32_t token_bits = 0; From 91dd04add209356a48ddf81a69e056895edcc851 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Sat, 17 Jun 2023 10:43:06 +0530 Subject: [PATCH 92/93] Refactor `filter_result_iterator_t` methods. --- include/filter_result_iterator.h | 2 +- src/filter_result_iterator.cpp | 97 +++++++++++++++++--------------- 2 files changed, 54 insertions(+), 45 deletions(-) diff --git a/include/filter_result_iterator.h b/include/filter_result_iterator.h index 0e06efe4..d74cb523 100644 --- a/include/filter_result_iterator.h +++ b/include/filter_result_iterator.h @@ -160,7 +160,7 @@ public: /// Returns the status of the initialization of iterator tree. Option init_status(); - /// Recursively computes the result of each node and stores the final result in root node. 
+ /// Recursively computes the result of each node and stores the final result in the root node. void compute_result(); /// Returns a tri-state: diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 2355e871..25cfd2c3 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -401,6 +401,22 @@ void filter_result_iterator_t::next() { return; } + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. + if (is_filter_result_initialized) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + seq_id = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + if (filter_node->isOperator) { // Advance the subtrees and then apply operators to arrive at the next valid doc. if (filter_node->filter_operator == AND) { @@ -423,21 +439,6 @@ void filter_result_iterator_t::next() { return; } - if (is_filter_result_initialized) { - if (++result_index >= filter_result.count) { - is_valid = false; - return; - } - - seq_id = filter_result.docs[result_index]; - reference.clear(); - for (auto const& item: filter_result.reference_filter_results) { - reference[item.first] = item.second[result_index]; - } - - return; - } - const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { @@ -1024,20 +1025,7 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } - if (filter_node->isOperator) { - // Skip the subtrees to id and then apply operators to arrive at the next valid doc. - left_it->skip_to(id); - right_it->skip_to(id); - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); - } - - return; - } - + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. 
if (is_filter_result_initialized) { ArrayUtils::skip_index_to_id(result_index, filter_result.docs, filter_result.count, id); @@ -1055,6 +1043,20 @@ void filter_result_iterator_t::skip_to(uint32_t id) { return; } + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + left_it->skip_to(id); + right_it->skip_to(id); + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { @@ -1137,6 +1139,12 @@ int filter_result_iterator_t::valid(uint32_t id) { return -1; } + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. + if (is_filter_result_initialized) { + skip_to(id); + return is_valid ? (seq_id == id ? 1 : 0) : -1; + } + if (filter_node->isOperator) { auto left_valid = left_it->valid(id), right_valid = right_it->valid(id); @@ -1250,21 +1258,7 @@ void filter_result_iterator_t::reset() { return; } - if (filter_node->isOperator) { - // Reset the subtrees then apply operators to arrive at the first valid doc. - left_it->reset(); - right_it->reset(); - is_valid = true; - - if (filter_node->filter_operator == AND) { - and_filter_iterators(); - } else { - or_filter_iterators(); - } - - return; - } - + // No need to traverse iterator tree if there's only one filter or compute_result() has been called. if (is_filter_result_initialized) { if (filter_result.count == 0) { is_valid = false; @@ -1283,6 +1277,21 @@ void filter_result_iterator_t::reset() { return; } + if (filter_node->isOperator) { + // Reset the subtrees then apply operators to arrive at the first valid doc. 
+ left_it->reset(); + right_it->reset(); + is_valid = true; + + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + const filter a_filter = filter_node->filter_exp; if (!index->field_is_indexed(a_filter.field_name)) { From efac704f1bf16ab63631343d7d9036b608e53294 Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Sat, 17 Jun 2023 12:39:18 +0530 Subject: [PATCH 93/93] Support id != --- include/filter.h | 2 +- src/filter.cpp | 5 ++++- src/filter_result_iterator.cpp | 20 ++++++++++++++----- test/collection_filtering_test.cpp | 32 ++++++++++++++++++++---------- 4 files changed, 42 insertions(+), 17 deletions(-) diff --git a/include/filter.h b/include/filter.h index f52b086f..01469d5e 100644 --- a/include/filter.h +++ b/include/filter.h @@ -19,7 +19,7 @@ struct filter { std::string field_name; std::vector values; std::vector comparators; - // Would be set when `field: != ...` is encountered with a string field or `field: != [ ... ]` is encountered in the + // Would be set when `field: != ...` is encountered with id/string field or `field: != [ ... ]` is encountered in the // case of int and float fields. During filtering, all the results of matching the field against the values are // aggregated and then this flag is checked if negation on the aggregated result is required. bool apply_not_equals = false; diff --git a/src/filter.cpp b/src/filter.cpp index c152d77f..5d94c66c 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -422,7 +422,10 @@ Option toFilter(const std::string expression, id_comparator = EQUALS; while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); } else if (raw_value.size() >= 2 && raw_value[0] == '!' 
&& raw_value[1] == '=') { - return Option(400, "Not equals filtering is not supported on the `id` field."); + id_comparator = NOT_EQUALS; + filter_exp.apply_not_equals = true; + filter_value_index++; + while (++filter_value_index < raw_value.size() && raw_value[filter_value_index] == ' '); } if (filter_value_index != 0) { raw_value = raw_value.substr(filter_value_index); diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp index 25cfd2c3..42816867 100644 --- a/src/filter_result_iterator.cpp +++ b/src/filter_result_iterator.cpp @@ -620,11 +620,6 @@ void filter_result_iterator_t::init() { } if (a_filter.field_name == "id") { - if (a_filter.values.empty()) { - is_valid = false; - return; - } - // we handle `ids` separately std::vector result_ids; for (const auto& id_str : a_filter.values) { @@ -637,6 +632,16 @@ void filter_result_iterator_t::init() { filter_result.docs = new uint32_t[result_ids.size()]; std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + if (a_filter.apply_not_equals) { + apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), + filter_result.docs, filter_result.count); + } + + if (filter_result.count == 0) { + is_valid = false; + return; + } + seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; approx_filter_ids_length = filter_result.count; @@ -1660,6 +1665,11 @@ void filter_result_iterator_t::compute_result() { apply_not_equals(index->seq_ids->uncompress(), index->seq_ids->num_ids(), filter_result.docs, filter_result.count); } + if (filter_result.count == 0) { + is_valid = false; + return; + } + result_index = 0; seq_id = filter_result.docs[result_index]; is_filter_result_initialized = true; diff --git a/test/collection_filtering_test.cpp b/test/collection_filtering_test.cpp index 988b035a..b3a5e600 100644 --- a/test/collection_filtering_test.cpp +++ b/test/collection_filtering_test.cpp @@ -1231,6 +1231,16 @@ TEST_F(CollectionFilteringTest, 
FilteringViaDocumentIds) { ASSERT_EQ(1, results["hits"].size()); ASSERT_STREQ("123", results["hits"][0]["document"]["id"].get().c_str()); + results = coll1->search("*", + {}, "id: != 123", + {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(3, results["found"].get()); + ASSERT_EQ(3, results["hits"].size()); + ASSERT_STREQ("125", results["hits"][0]["document"]["id"].get().c_str()); + ASSERT_STREQ("127", results["hits"][1]["document"]["id"].get().c_str()); + ASSERT_STREQ("129", results["hits"][2]["document"]["id"].get().c_str()); + // single ID with backtick results = coll1->search("*", @@ -1283,6 +1293,14 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) { ASSERT_STREQ("125", results["hits"][1]["document"]["id"].get().c_str()); ASSERT_STREQ("127", results["hits"][2]["document"]["id"].get().c_str()); + results = coll1->search("*", + {}, "id:!= [123,125] && num_employees: <300", + {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(1, results["found"].get()); + ASSERT_EQ(1, results["hits"].size()); + ASSERT_STREQ("127", results["hits"][0]["document"]["id"].get().c_str()); + // empty id list not allowed auto res_op = coll1->search("*", {}, "id:=", {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}); ASSERT_FALSE(res_op.ok()); @@ -1296,13 +1314,6 @@ TEST_F(CollectionFilteringTest, FilteringViaDocumentIds) { ASSERT_FALSE(res_op.ok()); ASSERT_EQ("Error with filter field `id`: Filter value cannot be empty.", res_op.error()); - // not equals is not supported yet - res_op = coll1->search("*", - {}, "id:!= [123,125] && num_employees: <300", - {}, sort_fields, {0}, 10, 1, FREQUENCY, {true}); - ASSERT_FALSE(res_op.ok()); - ASSERT_EQ("Not equals filtering is not supported on the `id` field.", res_op.error()); - // when no IDs exist results = coll1->search("*", {}, "id: [1000] && num_employees: <300", @@ -1397,9 +1408,10 @@ TEST_F(CollectionFilteringTest, NumericalFilteringWithArray) { TEST_F(CollectionFilteringTest, 
NegationOperatorBasics) { Collection *coll1; - std::vector fields = {field("title", field_types::STRING, false), - field("artist", field_types::STRING, false), - field("points", field_types::INT32, false),}; + std::vector fields = { + field("title", field_types::STRING, false), + field("artist", field_types::STRING, false), + field("points", field_types::INT32, false),}; coll1 = collectionManager.get_collection("coll1").get(); if(coll1 == nullptr) {