diff --git a/include/filter.h b/include/filter.h new file mode 100644 index 00000000..239a7e77 --- /dev/null +++ b/include/filter.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include "index.h" + +class filter_result_iterator_t { +private: + std::string collection_name; + const Index* index; + filter_node_t* filter_node; + filter_result_iterator_t* left_it = nullptr; + filter_result_iterator_t* right_it = nullptr; + + // Used in case of id and reference filter. + uint32_t result_index = 0; + + // Stores the result of the filters that cannot be iterated. + filter_result_t filter_result; + + // Initialized in case of filter on string field. + // Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator + // for each token. + // + // Multiple filter values: Multiple tokens: posting list iterator + std::vector> posting_list_iterators; + std::vector expanded_plists; + + // Set to false when this iterator or it's subtree becomes invalid. + bool is_valid = true; + + /// Initializes the state of iterator node after it's creation. + void init(); + + /// Performs AND on the subtrees of operator. + void and_filter_iterators(); + + /// Performs OR on the subtrees of operator. + void or_filter_iterators(); + + /// Finds the next match for a filter on string field. + void doc_matching_string_filter(); + +public: + uint32_t doc; + // Collection name -> references + std::map reference; + Option status; + + explicit filter_result_iterator_t(const std::string& collection_name, + const Index* index, filter_node_t* filter_node, + Option& status) : + collection_name(collection_name), + index(index), + filter_node(filter_node), + status(status) { + // Generate the iterator tree and then initialize each node. + if (filter_node->isOperator) { + left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, status); + right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, status); + } + + init(); + } + + ~filter_result_iterator_t() { + // In case the filter was on string field. + for(auto expanded_plist: expanded_plists) { + delete expanded_plist; + } + + delete left_it; + delete right_it; + } + + [[nodiscard]] bool valid(); + + void next(); + + void skip_to(uint32_t id); +}; diff --git a/include/index.h b/include/index.h index de591a0b..09b29fed 100644 --- a/include/index.h +++ b/include/index.h @@ -958,6 +958,8 @@ public: std::unordered_set& excluded_group_ids) const; int64_t get_doc_val_from_sort_index(sort_index_iterator it, uint32_t doc_seq_id) const; + + friend class filter_result_iterator_t; }; template diff --git a/include/option.h b/include/option.h index 0f8c49e0..ced54ae9 100644 --- a/include/option.h +++ b/include/option.h @@ -31,6 +31,18 @@ public: error_code = obj.error_code; } + Option& operator=(Option&& obj) noexcept { + if (&obj == this) + return *this; + + value = obj.value; + is_ok = obj.is_ok; + error_msg = obj.error_msg; + error_code = obj.error_code; + + return *this; + } + bool ok() const { return is_ok; } diff --git a/include/posting_list.h b/include/posting_list.h index dea742f7..16f42ea7 100644 --- a/include/posting_list.h +++ b/include/posting_list.h @@ -164,6 +164,8 @@ public: static void intersect(const std::vector& posting_lists, std::vector& result_ids); + static void intersect(std::vector& posting_list_iterators, bool& is_valid); + template static bool block_intersect( std::vector& its, diff --git a/src/filter.cpp b/src/filter.cpp new file mode 100644 index 00000000..804e19a4 --- /dev/null +++ b/src/filter.cpp @@ -0,0 +1,430 @@ +#include +#include +#include +#include "filter.h" + +void filter_result_iterator_t::and_filter_iterators() { + while (left_it->valid() && right_it->valid()) { + while (left_it->doc < right_it->doc) { + left_it->next(); + if (!left_it->valid()) { + is_valid = false; + return; + } + } + + while (left_it->doc > right_it->doc) { + right_it->next(); + if (!right_it->valid()) { + is_valid = false; + return; + } + } + + if (left_it->doc == right_it->doc) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + } + + is_valid = false; +} + +void filter_result_iterator_t::or_filter_iterators() { + if (left_it->valid() && right_it->valid()) { + if (left_it->doc < right_it->doc) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->doc > right_it->doc) { + doc = right_it->doc; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (left_it->valid()) { + doc = left_it->doc; + reference.clear(); + + for (const auto& item: left_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + if (right_it->valid()) { + doc = right_it->doc; + reference.clear(); + + for (const auto& item: right_it->reference) { + reference[item.first] = item.second; + } + + return; + } + + is_valid = false; +} + +void filter_result_iterator_t::doc_matching_string_filter() { + // If none of the filter value iterators are valid, mark this node as invalid. + bool one_is_valid = false; + + // Since we do OR between filter values, the lowest doc id from all is selected. + uint32_t lowest_id = UINT32_MAX; + + for (auto& filter_value_tokens : posting_list_iterators) { + // Perform AND between tokens of a filter value. + bool tokens_iter_is_valid; + posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid); + + one_is_valid = tokens_iter_is_valid || one_is_valid; + + if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) { + lowest_id = filter_value_tokens[0].id(); + } + } + + if (one_is_valid) { + doc = lowest_id; + } + + is_valid = one_is_valid; +} + +void filter_result_iterator_t::next() { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + if (++result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Advance all the filter values that are at doc. Then find the next one. + std::vector doc_matching_indexes; + for (uint32_t i = 0; i < posting_list_iterators.size(); i++) { + const auto& filter_value_tokens = posting_list_iterators[i]; + + if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) { + doc_matching_indexes.push_back(i); + } + } + + for (const auto &lowest_id_index: doc_matching_indexes) { + for (auto &iter: posting_list_iterators[lowest_id_index]) { + iter.next(); + } + } + + doc_matching_string_filter(); + return; + } +} + +void filter_result_iterator_t::init() { + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + and_filter_iterators(); + } else { + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. + auto& cm = CollectionManager::get_instance(); + auto collection = cm.get_collection(a_filter.referenced_collection_name); + if (collection == nullptr) { + status = Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); + is_valid = false; + return; + } + + auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, + filter_result, + collection_name); + if (!reference_filter_op.ok()) { + status = Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name + + "` collection: " + reference_filter_op.error()); + is_valid = false; + return; + } + + is_valid = filter_result.count > 0; + return; + } + + if (a_filter.field_name == "id") { + if (a_filter.values.empty()) { + is_valid = false; + return; + } + + // we handle `ids` separately + std::vector result_ids; + for (const auto& id_str : a_filter.values) { + result_ids.push_back(std::stoul(id_str)); + } + + std::sort(result_ids.begin(), result_ids.end()); + + filter_result.count = result_ids.size(); + filter_result.docs = new uint32_t[result_ids.size()]; + std::copy(result_ids.begin(), result_ids.end(), filter_result.docs); + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + art_tree* t = index->search_index.at(a_filter.field_name); + + for (const std::string& filter_value : a_filter.values) { + std::vector posting_lists; + + // there could be multiple tokens in a filter value, which we have to treat as ANDs + // e.g. country: South Africa + Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators); + + std::string str_token; + size_t token_index = 0; + std::vector str_tokens; + + while (tokenizer.next(str_token, token_index)) { + str_tokens.push_back(str_token); + + art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), + str_token.length()+1); + if (leaf == nullptr) { + continue; + } + + posting_lists.push_back(leaf->values); + } + + if (posting_lists.size() != str_tokens.size()) { + continue; + } + + std::vector plists; + posting_t::to_expanded_plists(posting_lists, plists, expanded_plists); + + posting_list_iterators.emplace_back(std::vector()); + + for (auto const& plist: plists) { + posting_list_iterators.back().push_back(plist->new_iterator()); + } + } + + doc_matching_string_filter(); + return; + } +} + +bool filter_result_iterator_t::valid() { + if (!is_valid) { + return false; + } + + if (filter_node->isOperator) { + if (filter_node->filter_operator == AND) { + is_valid = left_it->valid() && right_it->valid(); + return is_valid; + } else { + is_valid = left_it->valid() || right_it->valid(); + return is_valid; + } + } + + const filter a_filter = filter_node->filter_exp; + + if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") { + is_valid = result_index < filter_result.count; + return is_valid; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return is_valid; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + bool one_is_valid = false; + for (auto& filter_value_tokens: posting_list_iterators) { + posting_list_t::intersect(filter_value_tokens, one_is_valid); + + if (one_is_valid) { + break; + } + } + + is_valid = one_is_valid; + return is_valid; + } + + return true; +} + +void filter_result_iterator_t::skip_to(uint32_t id) { + if (!is_valid) { + return; + } + + if (filter_node->isOperator) { + // Skip the subtrees to id and then apply operators to arrive at the next valid doc. + if (filter_node->filter_operator == AND) { + left_it->skip_to(id); + and_filter_iterators(); + } else { + right_it->skip_to(id); + or_filter_iterators(); + } + + return; + } + + const filter a_filter = filter_node->filter_exp; + + bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); + if (is_referenced_filter) { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + reference.clear(); + for (auto const& item: filter_result.reference_filter_results) { + reference[item.first] = item.second[result_index]; + } + + return; + } + + if (a_filter.field_name == "id") { + while (filter_result.docs[result_index] < id && ++result_index < filter_result.count); + + if (result_index >= filter_result.count) { + is_valid = false; + return; + } + + doc = filter_result.docs[result_index]; + return; + } + + if (!index->field_is_indexed(a_filter.field_name)) { + is_valid = false; + return; + } + + field f = index->search_schema.at(a_filter.field_name); + + if (f.is_string()) { + // Skip all the token iterators and find a new match. + for (auto& filter_value_tokens : posting_list_iterators) { + for (auto& token: filter_value_tokens) { + // We perform AND on tokens. Short-circuiting here. + if (!token.valid()) { + break; + } + + token.skip_to(id); + } + } + + doc_matching_string_filter(); + return; + } +} diff --git a/src/posting_list.cpp b/src/posting_list.cpp index 23f17e8b..39f3ac00 100644 --- a/src/posting_list.cpp +++ b/src/posting_list.cpp @@ -754,6 +754,42 @@ void posting_list_t::intersect(const std::vector& posting_lists } } +void posting_list_t::intersect(std::vector& posting_list_iterators, bool& is_valid) { + if (posting_list_iterators.empty()) { + is_valid = false; + return; + } + + if (posting_list_iterators.size() == 1) { + is_valid = posting_list_iterators.front().valid(); + return; + } + + switch (posting_list_iterators.size()) { + case 2: + while(!at_end2(posting_list_iterators)) { + if(equals2(posting_list_iterators)) { + is_valid = true; + return; + } else { + advance_non_largest2(posting_list_iterators); + } + } + is_valid = false; + break; + default: + while(!at_end(posting_list_iterators)) { + if(equals(posting_list_iterators)) { + is_valid = true; + return; + } else { + advance_non_largest(posting_list_iterators); + } + } + is_valid = false; + } +} + bool posting_list_t::take_id(result_iter_state_t& istate, uint32_t id) { // decide if this result id should be excluded if(istate.excluded_result_ids_size != 0) { diff --git a/test/filter_test.cpp b/test/filter_test.cpp new file mode 100644 index 00000000..60cff669 --- /dev/null +++ b/test/filter_test.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include +#include +#include +#include "collection.h" + +class FilterTest : public ::testing::Test { +protected: + Store *store; + CollectionManager & collectionManager = CollectionManager::get_instance(); + std::atomic quit = false; + + std::vector query_fields; + std::vector sort_fields; + + void setupCollection() { + std::string state_dir_path = "/tmp/typesense_test/collection_join"; + LOG(INFO) << "Truncating and creating: " << state_dir_path; + system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str()); + + store = new Store(state_dir_path); + collectionManager.init(store, 1.0, "auth_key", quit); + collectionManager.load(8, 1000); + } + + virtual void SetUp() { + setupCollection(); + } + + virtual void TearDown() { + collectionManager.dispose(); + delete store; + } +}; + +TEST_F(FilterTest, FilterTreeIterator) { + nlohmann::json schema = + R"({ + "name": "Collection", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int32"}, + {"name": "years", "type": "int32[]"}, + {"name": "rating", "type": "float"}, + {"name": "tags", "type": "string[]"} + ] + })"_json; + + Collection* coll = collectionManager.create_collection(schema).get(); + + std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl"); + std::string json_line; + while (std::getline(infile, json_line)) { + auto add_op = coll->add(json_line); + ASSERT_TRUE(add_op.ok()); + } + infile.close(); + + const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_"; + filter_node_t* filter_tree_root = nullptr; + Option filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + Option iter_op(true); + auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_FALSE(iter_no_match_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: [foo bar, baz]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_FALSE(iter_no_match_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: Jeremy", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_contains_test.valid()); + ASSERT_EQ(i, iter_contains_test.doc); + iter_contains_test.next(); + } + ASSERT_FALSE(iter_contains_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name: [Jeremy, Howard, Richard]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_contains_multi_test.valid()); + ASSERT_EQ(i, iter_contains_multi_test.doc); + iter_contains_multi_test.next(); + } + ASSERT_FALSE(iter_contains_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("name:= Jeremy Howard", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + for (uint32_t i = 0; i < 5; i++) { + ASSERT_TRUE(iter_exact_match_test.valid()); + ASSERT_EQ(i, iter_exact_match_test.doc); + iter_exact_match_test.next(); + } + ASSERT_FALSE(iter_exact_match_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags:= [gold, silver]", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + std::vector expected = {0, 2, 3, 4}; + for (auto const& i : expected) { + ASSERT_TRUE(iter_exact_match_multi_test.valid()); + ASSERT_EQ(i, iter_exact_match_multi_test.doc); + iter_exact_match_multi_test.next(); + } + ASSERT_FALSE(iter_exact_match_multi_test.valid()); + ASSERT_TRUE(iter_op.ok()); + +// delete filter_tree_root; +// filter_tree_root = nullptr; +// filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix, +// filter_tree_root); +// ASSERT_TRUE(filter_op.ok()); +// +// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); +// +// std::vector expected = {1, 3}; +// for (auto const& i : expected) { +// ASSERT_TRUE(iter_not_equals_test.valid()); +// ASSERT_EQ(i, iter_not_equals_test.doc); +// iter_not_equals_test.next(); +// } +// +// ASSERT_FALSE(iter_not_equals_test.valid()); +// ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; + filter_tree_root = nullptr; + filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix, + filter_tree_root); + ASSERT_TRUE(filter_op.ok()); + + auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op); + + ASSERT_TRUE(iter_skip_test.valid()); + iter_skip_test.skip_to(3); + ASSERT_TRUE(iter_skip_test.valid()); + ASSERT_EQ(4, iter_skip_test.doc); + iter_skip_test.next(); + + ASSERT_FALSE(iter_skip_test.valid()); + ASSERT_TRUE(iter_op.ok()); + + delete filter_tree_root; +} \ No newline at end of file