Add filter_result_iterator_t.

This commit is contained in:
Harpreet Sangar 2023-03-18 17:25:25 +05:30
parent fbbcc0b64d
commit c662a655b0
7 changed files with 745 additions and 0 deletions

81
include/filter.h Normal file
View File

@ -0,0 +1,81 @@
#pragma once
#include <string>
#include <map>
#include "index.h"
/// Iterator over the documents matched by a filter tree.
///
/// One iterator node is created per `filter_node_t`: operator nodes own two child iterators (left and right
/// subtree) and combine them with AND/OR, while leaf nodes iterate the matches of a single filter expression.
/// While `valid()` returns true, `doc` (and `reference`) describe the current match; move forward with
/// `next()` or `skip_to()`.
///
/// A node owns its children and the expanded posting lists through raw pointers, hence it is non-copyable.
class filter_result_iterator_t {
private:
    std::string collection_name;
    const Index* index;
    filter_node_t* filter_node;
    filter_result_iterator_t* left_it = nullptr;
    filter_result_iterator_t* right_it = nullptr;

    // Position inside `filter_result`. Used in case of id and reference filter.
    uint32_t result_index = 0;

    // Stores the result of the filters that cannot be iterated.
    filter_result_t filter_result;

    // Initialized in case of filter on string field.
    // Sample filter values: ["foo bar", "baz"]. Each filter value is split into tokens. We get posting list iterator
    // for each token.
    //
    // Multiple filter values: Multiple tokens: posting list iterator
    std::vector<std::vector<posting_list_t::iterator_t>> posting_list_iterators;
    std::vector<posting_list_t*> expanded_plists;

    // Set to false when this iterator or its subtree becomes invalid.
    bool is_valid = true;

    /// Initializes the state of iterator node after its creation.
    void init();

    /// Performs AND on the subtrees of operator.
    void and_filter_iterators();

    /// Performs OR on the subtrees of operator.
    void or_filter_iterators();

    /// Finds the next match for a filter on string field.
    void doc_matching_string_filter();

public:
    // Id of the current match. Only meaningful while the iterator is valid; zero-initialized so that it
    // never holds an indeterminate value.
    uint32_t doc = 0;

    // Collection name -> references
    std::map<std::string, reference_filter_result_t> reference;

    // NOTE(review): this is a copy of the `status` constructor argument. Errors detected by init() are
    // stored here, not in the caller's object.
    Option<bool> status;

    explicit filter_result_iterator_t(const std::string& collection_name,
                                      const Index* index, filter_node_t* filter_node,
                                      Option<bool>& status) :
            collection_name(collection_name),
            index(index),
            filter_node(filter_node),
            status(status) {
        // Without a filter tree there is nothing to iterate. valid(), next() and skip_to() all return
        // on `!is_valid` before touching `filter_node`.
        if (filter_node == nullptr) {
            is_valid = false;
            return;
        }

        // Generate the iterator tree and then initialize each node.
        if (filter_node->isOperator) {
            left_it = new filter_result_iterator_t(collection_name, index, filter_node->left, status);
            right_it = new filter_result_iterator_t(collection_name, index, filter_node->right, status);
        }

        init();
    }

    ~filter_result_iterator_t() {
        // In case the filter was on string field.
        for (auto expanded_plist: expanded_plists) {
            delete expanded_plist;
        }

        delete left_it;
        delete right_it;
    }

    // Copying would make two nodes delete the same children and posting lists.
    filter_result_iterator_t(const filter_result_iterator_t&) = delete;
    filter_result_iterator_t& operator=(const filter_result_iterator_t&) = delete;

    /// Returns true while `doc` refers to a matching document.
    [[nodiscard]] bool valid();

    /// Advances to the next matching document; no-op once the iterator is invalid.
    void next();

    /// Advances the iterator towards `id`; no-op once the iterator is invalid.
    void skip_to(uint32_t id);
};

View File

@ -958,6 +958,8 @@ public:
std::unordered_set<uint32_t>& excluded_group_ids) const;
int64_t get_doc_val_from_sort_index(sort_index_iterator it, uint32_t doc_seq_id) const;
friend class filter_result_iterator_t;
};
template<class T>

View File

@ -31,6 +31,18 @@ public:
error_code = obj.error_code;
}
/// Move-assignment operator: overwrites the state of this Option with the state of `obj`.
/// Self-assignment is a no-op.
/// NOTE(review): every member is copy-assigned from `obj` (nothing is actually moved), so `obj` keeps its
/// contents afterwards.
Option& operator=(Option&& obj) noexcept {
    if (this != &obj) {
        value = obj.value;
        is_ok = obj.is_ok;
        error_msg = obj.error_msg;
        error_code = obj.error_code;
    }

    return *this;
}
/// Returns the `is_ok` flag of this Option.
/// NOTE(review): presumably true when the Option carries a value and false when it carries an error - confirm
/// against the constructors, which are not visible here.
bool ok() const {
return is_ok;
}

View File

@ -164,6 +164,8 @@ public:
static void intersect(const std::vector<posting_list_t*>& posting_lists, std::vector<uint32_t>& result_ids);
/// Iterator based overload of intersect: works on the given iterators in place and reports through `is_valid`
/// whether they are positioned on a match. An empty vector yields `is_valid == false`; for a single iterator
/// `is_valid` is that iterator's own validity.
static void intersect(std::vector<posting_list_t::iterator_t>& posting_list_iterators, bool& is_valid);
template<class T>
static bool block_intersect(
std::vector<posting_list_t::iterator_t>& its,

430
src/filter.cpp Normal file
View File

@ -0,0 +1,430 @@
#include <collection_manager.h>
#include <posting.h>
#include <timsort.hpp>
#include "filter.h"
void filter_result_iterator_t::and_filter_iterators() {
while (left_it->valid() && right_it->valid()) {
while (left_it->doc < right_it->doc) {
left_it->next();
if (!left_it->valid()) {
is_valid = false;
return;
}
}
while (left_it->doc > right_it->doc) {
right_it->next();
if (!right_it->valid()) {
is_valid = false;
return;
}
}
if (left_it->doc == right_it->doc) {
doc = left_it->doc;
reference.clear();
for (const auto& item: left_it->reference) {
reference[item.first] = item.second;
}
for (const auto& item: right_it->reference) {
reference[item.first] = item.second;
}
return;
}
}
is_valid = false;
}
void filter_result_iterator_t::or_filter_iterators() {
if (left_it->valid() && right_it->valid()) {
if (left_it->doc < right_it->doc) {
doc = left_it->doc;
reference.clear();
for (const auto& item: left_it->reference) {
reference[item.first] = item.second;
}
return;
}
if (left_it->doc > right_it->doc) {
doc = right_it->doc;
reference.clear();
for (const auto& item: right_it->reference) {
reference[item.first] = item.second;
}
return;
}
doc = left_it->doc;
reference.clear();
for (const auto& item: left_it->reference) {
reference[item.first] = item.second;
}
for (const auto& item: right_it->reference) {
reference[item.first] = item.second;
}
return;
}
if (left_it->valid()) {
doc = left_it->doc;
reference.clear();
for (const auto& item: left_it->reference) {
reference[item.first] = item.second;
}
return;
}
if (right_it->valid()) {
doc = right_it->doc;
reference.clear();
for (const auto& item: right_it->reference) {
reference[item.first] = item.second;
}
return;
}
is_valid = false;
}
void filter_result_iterator_t::doc_matching_string_filter() {
// If none of the filter value iterators are valid, mark this node as invalid.
bool one_is_valid = false;
// Since we do OR between filter values, the lowest doc id from all is selected.
uint32_t lowest_id = UINT32_MAX;
for (auto& filter_value_tokens : posting_list_iterators) {
// Perform AND between tokens of a filter value.
bool tokens_iter_is_valid;
posting_list_t::intersect(filter_value_tokens, tokens_iter_is_valid);
one_is_valid = tokens_iter_is_valid || one_is_valid;
if (tokens_iter_is_valid && filter_value_tokens[0].id() < lowest_id) {
lowest_id = filter_value_tokens[0].id();
}
}
if (one_is_valid) {
doc = lowest_id;
}
is_valid = one_is_valid;
}
void filter_result_iterator_t::next() {
if (!is_valid) {
return;
}
if (filter_node->isOperator) {
if (filter_node->filter_operator == AND) {
and_filter_iterators();
} else {
or_filter_iterators();
}
return;
}
const filter a_filter = filter_node->filter_exp;
bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
if (is_referenced_filter) {
if (++result_index >= filter_result.count) {
is_valid = false;
return;
}
doc = filter_result.docs[result_index];
reference.clear();
for (auto const& item: filter_result.reference_filter_results) {
reference[item.first] = item.second[result_index];
}
return;
}
if (a_filter.field_name == "id") {
if (++result_index >= filter_result.count) {
is_valid = false;
return;
}
doc = filter_result.docs[result_index];
return;
}
if (!index->field_is_indexed(a_filter.field_name)) {
is_valid = false;
return;
}
field f = index->search_schema.at(a_filter.field_name);
if (f.is_string()) {
// Advance all the filter values that are at doc. Then find the next one.
std::vector<uint32_t> doc_matching_indexes;
for (uint32_t i = 0; i < posting_list_iterators.size(); i++) {
const auto& filter_value_tokens = posting_list_iterators[i];
if (filter_value_tokens[0].valid() && filter_value_tokens[0].id() == doc) {
doc_matching_indexes.push_back(i);
}
}
for (const auto &lowest_id_index: doc_matching_indexes) {
for (auto &iter: posting_list_iterators[lowest_id_index]) {
iter.next();
}
}
doc_matching_string_filter();
return;
}
}
void filter_result_iterator_t::init() {
if (filter_node->isOperator) {
if (filter_node->filter_operator == AND) {
and_filter_iterators();
} else {
or_filter_iterators();
}
return;
}
const filter a_filter = filter_node->filter_exp;
bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
if (is_referenced_filter) {
// Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents.
auto& cm = CollectionManager::get_instance();
auto collection = cm.get_collection(a_filter.referenced_collection_name);
if (collection == nullptr) {
status = Option<bool>(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found.");
is_valid = false;
return;
}
auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name,
filter_result,
collection_name);
if (!reference_filter_op.ok()) {
status = Option<bool>(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name
+ "` collection: " + reference_filter_op.error());
is_valid = false;
return;
}
is_valid = filter_result.count > 0;
return;
}
if (a_filter.field_name == "id") {
if (a_filter.values.empty()) {
is_valid = false;
return;
}
// we handle `ids` separately
std::vector<uint32_t> result_ids;
for (const auto& id_str : a_filter.values) {
result_ids.push_back(std::stoul(id_str));
}
std::sort(result_ids.begin(), result_ids.end());
filter_result.count = result_ids.size();
filter_result.docs = new uint32_t[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), filter_result.docs);
}
if (!index->field_is_indexed(a_filter.field_name)) {
is_valid = false;
return;
}
field f = index->search_schema.at(a_filter.field_name);
if (f.is_string()) {
art_tree* t = index->search_index.at(a_filter.field_name);
for (const std::string& filter_value : a_filter.values) {
std::vector<void*> posting_lists;
// there could be multiple tokens in a filter value, which we have to treat as ANDs
// e.g. country: South Africa
Tokenizer tokenizer(filter_value, true, false, f.locale, index->symbols_to_index, index->token_separators);
std::string str_token;
size_t token_index = 0;
std::vector<std::string> str_tokens;
while (tokenizer.next(str_token, token_index)) {
str_tokens.push_back(str_token);
art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(),
str_token.length()+1);
if (leaf == nullptr) {
continue;
}
posting_lists.push_back(leaf->values);
}
if (posting_lists.size() != str_tokens.size()) {
continue;
}
std::vector<posting_list_t*> plists;
posting_t::to_expanded_plists(posting_lists, plists, expanded_plists);
posting_list_iterators.emplace_back(std::vector<posting_list_t::iterator_t>());
for (auto const& plist: plists) {
posting_list_iterators.back().push_back(plist->new_iterator());
}
}
doc_matching_string_filter();
return;
}
}
bool filter_result_iterator_t::valid() {
if (!is_valid) {
return false;
}
if (filter_node->isOperator) {
if (filter_node->filter_operator == AND) {
is_valid = left_it->valid() && right_it->valid();
return is_valid;
} else {
is_valid = left_it->valid() || right_it->valid();
return is_valid;
}
}
const filter a_filter = filter_node->filter_exp;
if (!a_filter.referenced_collection_name.empty() || a_filter.field_name == "id") {
is_valid = result_index < filter_result.count;
return is_valid;
}
if (!index->field_is_indexed(a_filter.field_name)) {
is_valid = false;
return is_valid;
}
field f = index->search_schema.at(a_filter.field_name);
if (f.is_string()) {
bool one_is_valid = false;
for (auto& filter_value_tokens: posting_list_iterators) {
posting_list_t::intersect(filter_value_tokens, one_is_valid);
if (one_is_valid) {
break;
}
}
is_valid = one_is_valid;
return is_valid;
}
return true;
}
void filter_result_iterator_t::skip_to(uint32_t id) {
if (!is_valid) {
return;
}
if (filter_node->isOperator) {
// Skip the subtrees to id and then apply operators to arrive at the next valid doc.
if (filter_node->filter_operator == AND) {
left_it->skip_to(id);
and_filter_iterators();
} else {
right_it->skip_to(id);
or_filter_iterators();
}
return;
}
const filter a_filter = filter_node->filter_exp;
bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
if (is_referenced_filter) {
while (filter_result.docs[result_index] < id && ++result_index < filter_result.count);
if (result_index >= filter_result.count) {
is_valid = false;
return;
}
doc = filter_result.docs[result_index];
reference.clear();
for (auto const& item: filter_result.reference_filter_results) {
reference[item.first] = item.second[result_index];
}
return;
}
if (a_filter.field_name == "id") {
while (filter_result.docs[result_index] < id && ++result_index < filter_result.count);
if (result_index >= filter_result.count) {
is_valid = false;
return;
}
doc = filter_result.docs[result_index];
return;
}
if (!index->field_is_indexed(a_filter.field_name)) {
is_valid = false;
return;
}
field f = index->search_schema.at(a_filter.field_name);
if (f.is_string()) {
// Skip all the token iterators and find a new match.
for (auto& filter_value_tokens : posting_list_iterators) {
for (auto& token: filter_value_tokens) {
// We perform AND on tokens. Short-circuiting here.
if (!token.valid()) {
break;
}
token.skip_to(id);
}
}
doc_matching_string_filter();
return;
}
}

View File

@ -754,6 +754,42 @@ void posting_list_t::intersect(const std::vector<posting_list_t*>& posting_lists
}
}
/// Iterator based intersection. The iterators are modified in place; `is_valid` is always assigned:
/// - no iterators: false;
/// - one iterator: that iterator's own validity;
/// - two or more: the iterators are advanced (via the `advance_non_largest*` helpers) until the `equals*`
///   helper reports a match (true) or the `at_end*` helper reports the end (false). The `*2` helpers are
///   used for exactly two iterators.
void posting_list_t::intersect(std::vector<posting_list_t::iterator_t>& posting_list_iterators, bool& is_valid) {
    const auto num_iterators = posting_list_iterators.size();

    if (num_iterators == 0) {
        is_valid = false;
        return;
    }

    if (num_iterators == 1) {
        is_valid = posting_list_iterators.front().valid();
        return;
    }

    if (num_iterators == 2) {
        for (; !at_end2(posting_list_iterators); advance_non_largest2(posting_list_iterators)) {
            if (equals2(posting_list_iterators)) {
                is_valid = true;
                return;
            }
        }

        is_valid = false;
        return;
    }

    for (; !at_end(posting_list_iterators); advance_non_largest(posting_list_iterators)) {
        if (equals(posting_list_iterators)) {
            is_valid = true;
            return;
        }
    }

    is_valid = false;
}
bool posting_list_t::take_id(result_iter_state_t& istate, uint32_t id) {
// decide if this result id should be excluded
if(istate.excluded_result_ids_size != 0) {

182
test/filter_test.cpp Normal file
View File

@ -0,0 +1,182 @@
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <fstream>
#include <collection_manager.h>
#include <filter.h>
#include "collection.h"
class FilterTest : public ::testing::Test {
protected:
Store *store;
CollectionManager & collectionManager = CollectionManager::get_instance();
std::atomic<bool> quit = false;
std::vector<std::string> query_fields;
std::vector<sort_by> sort_fields;
void setupCollection() {
std::string state_dir_path = "/tmp/typesense_test/collection_join";
LOG(INFO) << "Truncating and creating: " << state_dir_path;
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
store = new Store(state_dir_path);
collectionManager.init(store, 1.0, "auth_key", quit);
collectionManager.load(8, 1000);
}
virtual void SetUp() {
setupCollection();
}
virtual void TearDown() {
collectionManager.dispose();
delete store;
}
};
// Exercises filter_result_iterator_t on string filters: no match, single/multiple filter values,
// exact match (`:=`) and skip_to(). Documents are loaded from test/numeric_array_documents.jsonl.
//
// NOTE(review): the iterator stores a copy of the `Option<bool>` passed to its constructor, so the
// `ASSERT_TRUE(iter_op.ok())` checks below cannot observe an error recorded by the iterator itself.
TEST_F(FilterTest, FilterTreeIterator) {
nlohmann::json schema =
R"({
"name": "Collection",
"fields": [
{"name": "name", "type": "string"},
{"name": "age", "type": "int32"},
{"name": "years", "type": "int32[]"},
{"name": "rating", "type": "float"},
{"name": "tags", "type": "string[]"}
]
})"_json;
Collection* coll = collectionManager.create_collection(schema).get();
// Index every line of the fixture file as one document.
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
std::string json_line;
while (std::getline(infile, json_line)) {
auto add_op = coll->add(json_line);
ASSERT_TRUE(add_op.ok());
}
infile.close();
const std::string doc_id_prefix = std::to_string(coll->get_collection_id()) + "_" + Collection::DOC_ID_PREFIX + "_";
// Case 1: a single filter value that is expected to match nothing -> iterator is invalid right away.
filter_node_t* filter_tree_root = nullptr;
Option<bool> filter_op = filter::parse_filter_query("name: foo", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
Option<bool> iter_op(true);
auto iter_no_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
ASSERT_FALSE(iter_no_match_test.valid());
ASSERT_TRUE(iter_op.ok());
// Case 2: several filter values, none of which is expected to match.
delete filter_tree_root;
filter_tree_root = nullptr;
filter_op = filter::parse_filter_query("name: [foo bar, baz]", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
auto iter_no_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
ASSERT_FALSE(iter_no_match_multi_test.valid());
ASSERT_TRUE(iter_op.ok());
// Case 3: single filter value; the iterator must yield docs 0..4 in order and then become invalid.
delete filter_tree_root;
filter_tree_root = nullptr;
filter_op = filter::parse_filter_query("name: Jeremy", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
auto iter_contains_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
for (uint32_t i = 0; i < 5; i++) {
ASSERT_TRUE(iter_contains_test.valid());
ASSERT_EQ(i, iter_contains_test.doc);
iter_contains_test.next();
}
ASSERT_FALSE(iter_contains_test.valid());
ASSERT_TRUE(iter_op.ok());
// Case 4: multiple filter values (OR between values); each of docs 0..4 must be yielded exactly once.
delete filter_tree_root;
filter_tree_root = nullptr;
filter_op = filter::parse_filter_query("name: [Jeremy, Howard, Richard]", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
auto iter_contains_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
for (uint32_t i = 0; i < 5; i++) {
ASSERT_TRUE(iter_contains_multi_test.valid());
ASSERT_EQ(i, iter_contains_multi_test.doc);
iter_contains_multi_test.next();
}
ASSERT_FALSE(iter_contains_multi_test.valid());
ASSERT_TRUE(iter_op.ok());
// Case 5: `:=` filter with a multi-token value; docs 0..4 are expected.
delete filter_tree_root;
filter_tree_root = nullptr;
filter_op = filter::parse_filter_query("name:= Jeremy Howard", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
auto iter_exact_match_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
for (uint32_t i = 0; i < 5; i++) {
ASSERT_TRUE(iter_exact_match_test.valid());
ASSERT_EQ(i, iter_exact_match_test.doc);
iter_exact_match_test.next();
}
ASSERT_FALSE(iter_exact_match_test.valid());
ASSERT_TRUE(iter_op.ok());
// Case 6: `:=` filter with multiple values on a string[] field; docs 0, 2, 3 and 4 are expected.
delete filter_tree_root;
filter_tree_root = nullptr;
filter_op = filter::parse_filter_query("tags:= [gold, silver]", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
auto iter_exact_match_multi_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
std::vector<uint32_t> expected = {0, 2, 3, 4};
for (auto const& i : expected) {
ASSERT_TRUE(iter_exact_match_multi_test.valid());
ASSERT_EQ(i, iter_exact_match_multi_test.doc);
iter_exact_match_multi_test.next();
}
ASSERT_FALSE(iter_exact_match_multi_test.valid());
ASSERT_TRUE(iter_op.ok());
// Disabled case: `!=` filter (kept commented out, exactly as committed).
// delete filter_tree_root;
// filter_tree_root = nullptr;
// filter_op = filter::parse_filter_query("tags:!= gold", coll->get_schema(), store, doc_id_prefix,
// filter_tree_root);
// ASSERT_TRUE(filter_op.ok());
//
// auto iter_not_equals_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
//
// std::vector<uint32_t> expected = {1, 3};
// for (auto const& i : expected) {
// ASSERT_TRUE(iter_not_equals_test.valid());
// ASSERT_EQ(i, iter_not_equals_test.doc);
// iter_not_equals_test.next();
// }
//
// ASSERT_FALSE(iter_not_equals_test.valid());
// ASSERT_TRUE(iter_op.ok());
// Case 7: skip_to(3) must land on doc 4 (the first match that is not below 3), which is also the last match.
delete filter_tree_root;
filter_tree_root = nullptr;
filter_op = filter::parse_filter_query("tags: gold", coll->get_schema(), store, doc_id_prefix,
filter_tree_root);
ASSERT_TRUE(filter_op.ok());
auto iter_skip_test = filter_result_iterator_t(coll->get_name(), coll->_get_index(), filter_tree_root, iter_op);
ASSERT_TRUE(iter_skip_test.valid());
iter_skip_test.skip_to(3);
ASSERT_TRUE(iter_skip_test.valid());
ASSERT_EQ(4, iter_skip_test.doc);
iter_skip_test.next();
ASSERT_FALSE(iter_skip_test.valid());
ASSERT_TRUE(iter_op.ok());
delete filter_tree_root;
}