Add Index::rearranging_recursive_filter.

This commit is contained in:
Harpreet Sangar 2023-01-27 12:57:13 +05:30
parent 6c5662bc95
commit 5c5f43195c
4 changed files with 188 additions and 46 deletions

View File

@ -536,6 +536,7 @@ struct filter_node_t {
bool isOperator;
filter_node_t* left;
filter_node_t* right;
std::pair<uint32_t, uint32_t*> match_index_ids;
filter_node_t(filter filter_exp)
: filter_exp(std::move(filter_exp)),
@ -552,6 +553,7 @@ struct filter_node_t {
right(right) {}
~filter_node_t() {
delete[] match_index_ids.second;
delete left;
delete right;
}

View File

@ -99,7 +99,7 @@ struct search_args {
std::vector<query_tokens_t> field_query_tokens;
std::vector<search_field_t> search_fields;
const text_match_type_t match_type;
const filter_node_t* filter_tree_root;
filter_node_t* filter_tree_root;
std::vector<facet>& facets;
std::vector<std::pair<uint32_t, uint32_t>>& included_ids;
std::vector<uint32_t> excluded_ids;
@ -484,14 +484,16 @@ private:
uint32_t*& ids,
size_t& ids_len) const;
void do_filtering(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const root) const;
void do_filtering(filter_node_t* const root) const;
void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const;
void recursive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const root,
const bool enable_short_circuit) const;
filter_node_t* const root,
const bool enable_short_circuit = false) const;
void get_filter_matches(filter_node_t* const root, std::vector<std::pair<uint32_t, filter_node_t*>>& vec) const;
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;
@ -653,7 +655,7 @@ public:
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
filter_node_t const* const& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
filter_node_t* filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
@ -713,10 +715,10 @@ public:
void do_filtering_with_lock(
uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const& filter_tree_root) const;
filter_node_t* filter_tree_root) const;
void do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& reference_index_ids,
filter_node_t const* const& filter_tree_root,
filter_node_t* filter_tree_root,
const std::string& reference_field_name) const;
void refresh_schemas(const std::vector<field>& new_fields, const std::vector<field>& del_fields);

View File

@ -1661,11 +1661,9 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
ids = out;
}
void Index::do_filtering(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const root) const {
void Index::do_filtering(filter_node_t* const root) const {
// auto begin = std::chrono::high_resolution_clock::now();
/**/ const filter a_filter = root->filter_exp;
const filter a_filter = root->filter_exp;
bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
if (is_referenced_filter) {
@ -1673,16 +1671,12 @@ void Index::do_filtering(uint32_t*& filter_ids,
auto& cm = CollectionManager::get_instance();
auto collection = cm.get_collection(a_filter.referenced_collection_name);
std::pair<uint32_t, uint32_t*> documents;
auto op = collection->get_reference_filter_ids(a_filter.field_name,
cm.get_collection_with_id(collection_id)->get_name(),
documents);
root->match_index_ids);
if (!op.ok()) {
return;
}
filter_ids_length = documents.first;
filter_ids = documents.second;
return;
}
@ -1695,17 +1689,9 @@ void Index::do_filtering(uint32_t*& filter_ids,
std::sort(result_ids.begin(), result_ids.end());
if (filter_ids_length == 0) {
filter_ids = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), filter_ids);
filter_ids_length = result_ids.size();
} else {
uint32_t* filtered_results = nullptr;
filter_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, &result_ids[0],
result_ids.size(), &filtered_results);
delete[] filter_ids;
filter_ids = filtered_results;
}
root->match_index_ids.second = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), root->match_index_ids.second);
root->match_index_ids.first = result_ids.size();
return;
}
@ -2005,8 +1991,8 @@ void Index::do_filtering(uint32_t*& filter_ids,
result_ids_len = to_include_ids_len;
}
filter_ids = result_ids;
filter_ids_length = result_ids_len;
root->match_index_ids.first = result_ids_len;
root->match_index_ids.second = result_ids;
/*long long int timeMillis =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now()
@ -2015,38 +2001,131 @@ void Index::do_filtering(uint32_t*& filter_ids,
LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/
}
void Index::recursive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
const filter_node_t* root,
const bool enable_short_circuit) const {
/**
 * Collects the operands of the AND-chain rooted at `root`.
 *
 * An OR operator node is treated as a single operand: both of its subtrees are
 * evaluated through recursive_filter(), the two id lists are unioned and the
 * union is stored on the OR node itself. A node without children is evaluated
 * with do_filtering(), which stores the result on that node. Every other node
 * is only descended into (left subtree first, then right).
 *
 * @param root Node to inspect; a null node is ignored.
 * @param vec  Receives one (match count, node) entry per operand, in the order
 *             the operands are visited.
 */
void Index::get_filter_matches(filter_node_t* const root, std::vector<std::pair<uint32_t, filter_node_t*>>& vec) const {
    if (root == nullptr) {
        return;
    }

    const bool is_or_node = root->isOperator && root->filter_operator == OR;
    const bool has_no_children = root->left == nullptr && root->right == nullptr;

    if (!is_or_node && !has_no_children) {
        // Not an operand itself: keep looking for operands further down.
        get_filter_matches(root->left, vec);
        get_filter_matches(root->right, vec);
        return;
    }

    if (is_or_node) {
        uint32_t* left_ids = nullptr;
        uint32_t left_ids_length = 0;
        if (root->left != nullptr) {
            recursive_filter(left_ids, left_ids_length, root->left);
        }

        uint32_t* right_ids = nullptr;
        uint32_t right_ids_length = 0;
        if (root->right != nullptr) {
            recursive_filter(right_ids, right_ids_length, root->right);
        }

        root->match_index_ids.first = ArrayUtils::or_scalar(
                left_ids, left_ids_length, right_ids,
                right_ids_length, &(root->match_index_ids.second));

        // The union now lives on the OR node; the per-side lists are no longer needed.
        delete[] left_ids;
        delete[] right_ids;
    } else {
        do_filtering(root);
    }

    vec.emplace_back(root->match_index_ids.first, root);
}
/**
 * Produces the id list for `root` from results already stored on the tree.
 *
 * An AND node evaluates both subtrees, intersects them, stores the
 * intersection on itself and reports it. Any other operator node, and any node
 * without children, consumes one slot: when `is_rearranged` is set the slot is
 * `vec[index]`, otherwise it is the node's own stored result. `index` advances
 * by one per consumed slot. A non-operator node that has a child is malformed
 * and leaves the outputs untouched.
 *
 * The reported `filter_ids` always points at a buffer stored on some tree
 * node; this function never copies it.
 *
 * @param filter_ids        Receives the pointer to the matching ids.
 * @param filter_ids_length Receives the number of matching ids.
 * @param root              Node to evaluate; a null node is ignored.
 * @param is_rearranged     Whether slots are read from `vec` instead of the node.
 * @param vec               (match count, node) entries read when `is_rearranged` is set.
 * @param index             Position of the next slot to consume.
 */
void evaluate_filter_tree(uint32_t*& filter_ids,
                          uint32_t& filter_ids_length,
                          filter_node_t* const root,
                          bool is_rearranged,
                          std::vector<std::pair<uint32_t, filter_node_t*>>& vec,
                          size_t& index) {
    if (root == nullptr) {
        return;
    }

    const bool has_no_children = root->left == nullptr && root->right == nullptr;
    if (!root->isOperator && !has_no_children) {
        // malformed
        return;
    }

    if (root->isOperator && root->filter_operator == AND) {
        uint32_t* left_ids = nullptr;
        uint32_t left_ids_length = 0;
        if (root->left != nullptr) {
            evaluate_filter_tree(left_ids, left_ids_length, root->left, is_rearranged, vec, index);
        }

        uint32_t* right_ids = nullptr;
        uint32_t right_ids_length = 0;
        if (root->right != nullptr) {
            evaluate_filter_tree(right_ids, right_ids_length, root->right, is_rearranged, vec, index);
        }

        root->match_index_ids.first = ArrayUtils::and_scalar(
                left_ids, left_ids_length, right_ids,
                right_ids_length, &(root->match_index_ids.second));

        filter_ids_length = root->match_index_ids.first;
        filter_ids = root->match_index_ids.second;
        return;
    }

    // Remaining cases consume exactly one slot: a non-AND operator, or a node without children.
    if (is_rearranged) {
        filter_ids_length = vec[index].first;
        filter_ids = vec[index].second->match_index_ids.second;
    } else {
        filter_ids_length = root->match_index_ids.first;
        filter_ids = root->match_index_ids.second;
    }
    index++;
}
void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const {
std::vector<std::pair<uint32_t, filter_node_t*>> vec;
get_filter_matches(root, vec);
bool should_rearrange = vec.size() > 2;
if (should_rearrange) {
std::sort(vec.begin(), vec.end(),
[](const std::pair<uint32_t, filter_node_t*>& lhs, const std::pair<uint32_t, filter_node_t*>& rhs) {
return lhs.first < rhs.first;
});
}
size_t index = 0;
evaluate_filter_tree(filter_ids, filter_ids_length, root, should_rearrange, vec, index);
}
void Index::recursive_filter(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t* const root,
const bool enable_short_circuit) const {
if (root == nullptr) {
return;
}
uint32_t* l_filter_ids = nullptr;
uint32_t l_filter_ids_length = 0;
if (root->left != nullptr) {
recursive_filter(l_filter_ids, l_filter_ids_length, root->left,
enable_short_circuit);
}
uint32_t* r_filter_ids = nullptr;
uint32_t r_filter_ids_length = 0;
if (root->right != nullptr) {
recursive_filter(r_filter_ids, r_filter_ids_length, root->right,
enable_short_circuit);
}
if (root->isOperator) {
uint32_t* filtered_results = nullptr;
if (root->filter_operator == AND) {
filter_ids_length = ArrayUtils::and_scalar(
l_filter_ids, l_filter_ids_length, r_filter_ids,
r_filter_ids_length, &filtered_results);
l_filter_ids, l_filter_ids_length, r_filter_ids,
r_filter_ids_length, &filtered_results);
} else {
filter_ids_length = ArrayUtils::or_scalar(
l_filter_ids, l_filter_ids_length, r_filter_ids,
r_filter_ids_length, &filtered_results);
l_filter_ids, l_filter_ids_length, r_filter_ids,
r_filter_ids_length, &filtered_results);
}
delete[] l_filter_ids;
@ -2054,7 +2133,10 @@ void Index::recursive_filter(uint32_t*& filter_ids,
filter_ids = filtered_results;
} else if (root->left == nullptr && root->right == nullptr) {
do_filtering(filter_ids, filter_ids_length, root);
do_filtering(root);
filter_ids_length = root->match_index_ids.first;
filter_ids = root->match_index_ids.second;
root->match_index_ids.second = nullptr;
} else {
// malformed
}
@ -2062,13 +2144,13 @@ void Index::recursive_filter(uint32_t*& filter_ids,
void Index::do_filtering_with_lock(uint32_t*& filter_ids,
uint32_t& filter_ids_length,
filter_node_t const* const& filter_tree_root) const {
filter_node_t* filter_tree_root) const {
std::shared_lock lock(mutex);
recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false);
}
void Index::do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& reference_index_ids,
filter_node_t const* const& filter_tree_root,
filter_node_t* filter_tree_root,
const std::string& reference_field_name) const {
std::shared_lock lock(mutex);
recursive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false);
@ -2077,7 +2159,7 @@ void Index::do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& ref
vector.reserve(reference_index_ids.first);
for (uint32_t i = 0; i < reference_index_ids.first; i++) {
auto filtered_doc_id = *(reference_index_ids.second + i);
auto filtered_doc_id = reference_index_ids.second[i];
// Extract the sequence_id from the reference field.
vector.push_back(sort_index.at(reference_field_name)->at(filtered_doc_id));
@ -2550,7 +2632,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name
void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
filter_node_t const* const& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
filter_node_t* filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,

View File

@ -2536,3 +2536,59 @@ TEST_F(CollectionFilteringTest, FilteringAfterUpsertOnArrayWithSymbolsToIndex) {
collectionManager.drop_collection("coll1");
}
// Exercises filter expressions that nest && and || with parentheses.
TEST_F(CollectionFilteringTest, ComplexFilterQuery) {
    nlohmann::json schema_json =
            R"({
"name": "ComplexFilterQueryCollection",
"fields": [
{"name": "name", "type": "string"},
{"name": "age", "type": "int32"},
{"name": "years", "type": "int32[]"},
{"name": "rating", "type": "float"}
]
})"_json;

    auto create_op = collectionManager.create_collection(schema_json);
    ASSERT_TRUE(create_op.ok());
    auto coll = create_op.get();

    // Index every line of the fixture file as one document.
    std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
    for (std::string json_line; std::getline(infile, json_line);) {
        auto add_op = coll->add(json_line);
        ASSERT_TRUE(add_op.ok());
    }
    infile.close();

    std::vector<sort_by> sort_fields_desc = {sort_by("rating", "DESC")};

    // AND of a parenthesised AND with a third condition: expects no hits.
    nlohmann::json results = coll->search("Jeremy", {"name"}, "(rating:>=0 && years:>2000) && age:>50",
                                          {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(0, results["hits"].size());

    // Parenthesised OR combined with an AND.
    results = coll->search("Jeremy", {"name"}, "(age:>50 || rating:>5) && years:<2000",
                           {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(2, results["hits"].size());

    std::vector<std::string> ids = {"4", "3"};
    for (size_t hit_index = 0; hit_index < results["hits"].size(); hit_index++) {
        nlohmann::json hit = results["hits"].at(hit_index);
        std::string actual_id = hit["document"]["id"];
        std::string expected_id = ids.at(hit_index);
        ASSERT_STREQ(expected_id.c_str(), actual_id.c_str());
    }

    // OR of two parenthesised ANDs.
    results = coll->search("Jeremy", {"name"}, "(age:<50 && rating:10) || (years:>2000 && rating:<5)",
                           {}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(1, results["hits"].size());

    ids = {"0"};
    for (size_t hit_index = 0; hit_index < results["hits"].size(); hit_index++) {
        nlohmann::json hit = results["hits"].at(hit_index);
        std::string actual_id = hit["document"]["id"];
        std::string expected_id = ids.at(hit_index);
        ASSERT_STREQ(expected_id.c_str(), actual_id.c_str());
    }

    collectionManager.drop_collection("ComplexFilterQueryCollection");
}