mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 14:12:27 +08:00
Add Index::rearranging_recursive_filter
.
This commit is contained in:
parent
6c5662bc95
commit
5c5f43195c
@ -536,6 +536,7 @@ struct filter_node_t {
|
||||
bool isOperator;
|
||||
filter_node_t* left;
|
||||
filter_node_t* right;
|
||||
std::pair<uint32_t, uint32_t*> match_index_ids;
|
||||
|
||||
filter_node_t(filter filter_exp)
|
||||
: filter_exp(std::move(filter_exp)),
|
||||
@ -552,6 +553,7 @@ struct filter_node_t {
|
||||
right(right) {}
|
||||
|
||||
~filter_node_t() {
|
||||
delete[] match_index_ids.second;
|
||||
delete left;
|
||||
delete right;
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ struct search_args {
|
||||
std::vector<query_tokens_t> field_query_tokens;
|
||||
std::vector<search_field_t> search_fields;
|
||||
const text_match_type_t match_type;
|
||||
const filter_node_t* filter_tree_root;
|
||||
filter_node_t* filter_tree_root;
|
||||
std::vector<facet>& facets;
|
||||
std::vector<std::pair<uint32_t, uint32_t>>& included_ids;
|
||||
std::vector<uint32_t> excluded_ids;
|
||||
@ -484,14 +484,16 @@ private:
|
||||
uint32_t*& ids,
|
||||
size_t& ids_len) const;
|
||||
|
||||
void do_filtering(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const root) const;
|
||||
void do_filtering(filter_node_t* const root) const;
|
||||
|
||||
void rearranging_recursive_filter (uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const;
|
||||
|
||||
void recursive_filter(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const root,
|
||||
const bool enable_short_circuit) const;
|
||||
filter_node_t* const root,
|
||||
const bool enable_short_circuit = false) const;
|
||||
|
||||
void get_filter_matches(filter_node_t* const root, std::vector<std::pair<uint32_t, filter_node_t*>>& vec) const;
|
||||
|
||||
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
|
||||
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;
|
||||
@ -653,7 +655,7 @@ public:
|
||||
|
||||
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
|
||||
const text_match_type_t match_type,
|
||||
filter_node_t const* const& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
|
||||
filter_node_t* filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
|
||||
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
|
||||
@ -713,10 +715,10 @@ public:
|
||||
void do_filtering_with_lock(
|
||||
uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const& filter_tree_root) const;
|
||||
filter_node_t* filter_tree_root) const;
|
||||
|
||||
void do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& reference_index_ids,
|
||||
filter_node_t const* const& filter_tree_root,
|
||||
filter_node_t* filter_tree_root,
|
||||
const std::string& reference_field_name) const;
|
||||
|
||||
void refresh_schemas(const std::vector<field>& new_fields, const std::vector<field>& del_fields);
|
||||
|
156
src/index.cpp
156
src/index.cpp
@ -1661,11 +1661,9 @@ void Index::numeric_not_equals_filter(num_tree_t* const num_tree,
|
||||
ids = out;
|
||||
}
|
||||
|
||||
void Index::do_filtering(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const root) const {
|
||||
void Index::do_filtering(filter_node_t* const root) const {
|
||||
// auto begin = std::chrono::high_resolution_clock::now();
|
||||
/**/ const filter a_filter = root->filter_exp;
|
||||
const filter a_filter = root->filter_exp;
|
||||
|
||||
bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
|
||||
if (is_referenced_filter) {
|
||||
@ -1673,16 +1671,12 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto collection = cm.get_collection(a_filter.referenced_collection_name);
|
||||
|
||||
std::pair<uint32_t, uint32_t*> documents;
|
||||
auto op = collection->get_reference_filter_ids(a_filter.field_name,
|
||||
cm.get_collection_with_id(collection_id)->get_name(),
|
||||
documents);
|
||||
root->match_index_ids);
|
||||
if (!op.ok()) {
|
||||
return;
|
||||
}
|
||||
|
||||
filter_ids_length = documents.first;
|
||||
filter_ids = documents.second;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1695,17 +1689,9 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
|
||||
std::sort(result_ids.begin(), result_ids.end());
|
||||
|
||||
if (filter_ids_length == 0) {
|
||||
filter_ids = new uint32[result_ids.size()];
|
||||
std::copy(result_ids.begin(), result_ids.end(), filter_ids);
|
||||
filter_ids_length = result_ids.size();
|
||||
} else {
|
||||
uint32_t* filtered_results = nullptr;
|
||||
filter_ids_length = ArrayUtils::and_scalar(filter_ids, filter_ids_length, &result_ids[0],
|
||||
result_ids.size(), &filtered_results);
|
||||
delete[] filter_ids;
|
||||
filter_ids = filtered_results;
|
||||
}
|
||||
root->match_index_ids.second = new uint32[result_ids.size()];
|
||||
std::copy(result_ids.begin(), result_ids.end(), root->match_index_ids.second);
|
||||
root->match_index_ids.first = result_ids.size();
|
||||
|
||||
return;
|
||||
}
|
||||
@ -2005,8 +1991,8 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
result_ids_len = to_include_ids_len;
|
||||
}
|
||||
|
||||
filter_ids = result_ids;
|
||||
filter_ids_length = result_ids_len;
|
||||
root->match_index_ids.first = result_ids_len;
|
||||
root->match_index_ids.second = result_ids;
|
||||
|
||||
/*long long int timeMillis =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now()
|
||||
@ -2015,38 +2001,131 @@ void Index::do_filtering(uint32_t*& filter_ids,
|
||||
LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/
|
||||
}
|
||||
|
||||
void Index::recursive_filter(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
const filter_node_t* root,
|
||||
const bool enable_short_circuit) const {
|
||||
void Index::get_filter_matches(filter_node_t* const root, std::vector<std::pair<uint32_t, filter_node_t*>>& vec) const {
|
||||
if (root == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (root->isOperator && root->filter_operator == OR) {
|
||||
uint32_t* l_filter_ids = nullptr;
|
||||
uint32_t l_filter_ids_length = 0;
|
||||
if (root->left != nullptr) {
|
||||
recursive_filter(l_filter_ids, l_filter_ids_length, root->left);
|
||||
}
|
||||
|
||||
uint32_t* r_filter_ids = nullptr;
|
||||
uint32_t r_filter_ids_length = 0;
|
||||
if (root->right != nullptr) {
|
||||
recursive_filter(r_filter_ids, r_filter_ids_length, root->right);
|
||||
}
|
||||
|
||||
root->match_index_ids.first = ArrayUtils::or_scalar(
|
||||
l_filter_ids, l_filter_ids_length, r_filter_ids,
|
||||
r_filter_ids_length, &(root->match_index_ids.second));
|
||||
|
||||
delete[] l_filter_ids;
|
||||
delete[] r_filter_ids;
|
||||
|
||||
vec.emplace_back(root->match_index_ids.first, root);
|
||||
} else if (root->left == nullptr && root->right == nullptr) {
|
||||
do_filtering(root);
|
||||
vec.emplace_back(root->match_index_ids.first, root);
|
||||
} else {
|
||||
get_filter_matches(root->left, vec);
|
||||
get_filter_matches(root->right, vec);
|
||||
}
|
||||
}
|
||||
|
||||
void evaluate_filter_tree(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t* const root,
|
||||
bool is_rearranged,
|
||||
std::vector<std::pair<uint32_t, filter_node_t*>>& vec,
|
||||
size_t& index) {
|
||||
if (root == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (root->isOperator) {
|
||||
if (root->filter_operator == AND) {
|
||||
|
||||
uint32_t* l_filter_ids = nullptr;
|
||||
uint32_t l_filter_ids_length = 0;
|
||||
if (root->left != nullptr) {
|
||||
evaluate_filter_tree(l_filter_ids, l_filter_ids_length, root->left, is_rearranged, vec, index);
|
||||
}
|
||||
|
||||
uint32_t* r_filter_ids = nullptr;
|
||||
uint32_t r_filter_ids_length = 0;
|
||||
if (root->right != nullptr) {
|
||||
evaluate_filter_tree(r_filter_ids, r_filter_ids_length, root->right, is_rearranged, vec, index);
|
||||
}
|
||||
|
||||
root->match_index_ids.first = ArrayUtils::and_scalar(
|
||||
l_filter_ids, l_filter_ids_length, r_filter_ids,
|
||||
r_filter_ids_length, &(root->match_index_ids.second));
|
||||
|
||||
filter_ids_length = root->match_index_ids.first;
|
||||
filter_ids = root->match_index_ids.second;
|
||||
} else {
|
||||
filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first;
|
||||
filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second;
|
||||
index++;
|
||||
}
|
||||
} else if (root->left == nullptr && root->right == nullptr) {
|
||||
filter_ids_length = is_rearranged ? vec[index].first : root->match_index_ids.first;
|
||||
filter_ids = is_rearranged ? vec[index].second->match_index_ids.second : root->match_index_ids.second;
|
||||
index++;
|
||||
} else {
|
||||
// malformed
|
||||
}
|
||||
}
|
||||
|
||||
void Index::rearranging_recursive_filter(uint32_t*& filter_ids, uint32_t& filter_ids_length, filter_node_t* const root) const {
|
||||
std::vector<std::pair<uint32_t, filter_node_t*>> vec;
|
||||
get_filter_matches(root, vec);
|
||||
|
||||
bool should_rearrange = vec.size() > 2;
|
||||
if (should_rearrange) {
|
||||
std::sort(vec.begin(), vec.end(),
|
||||
[](const std::pair<uint32_t, filter_node_t*>& lhs, const std::pair<uint32_t, filter_node_t*>& rhs) {
|
||||
return lhs.first < rhs.first;
|
||||
});
|
||||
}
|
||||
|
||||
size_t index = 0;
|
||||
evaluate_filter_tree(filter_ids, filter_ids_length, root, should_rearrange, vec, index);
|
||||
}
|
||||
|
||||
void Index::recursive_filter(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t* const root,
|
||||
const bool enable_short_circuit) const {
|
||||
if (root == nullptr) {
|
||||
return;
|
||||
}
|
||||
uint32_t* l_filter_ids = nullptr;
|
||||
uint32_t l_filter_ids_length = 0;
|
||||
if (root->left != nullptr) {
|
||||
recursive_filter(l_filter_ids, l_filter_ids_length, root->left,
|
||||
enable_short_circuit);
|
||||
}
|
||||
|
||||
uint32_t* r_filter_ids = nullptr;
|
||||
uint32_t r_filter_ids_length = 0;
|
||||
if (root->right != nullptr) {
|
||||
recursive_filter(r_filter_ids, r_filter_ids_length, root->right,
|
||||
enable_short_circuit);
|
||||
}
|
||||
|
||||
if (root->isOperator) {
|
||||
uint32_t* filtered_results = nullptr;
|
||||
if (root->filter_operator == AND) {
|
||||
filter_ids_length = ArrayUtils::and_scalar(
|
||||
l_filter_ids, l_filter_ids_length, r_filter_ids,
|
||||
r_filter_ids_length, &filtered_results);
|
||||
l_filter_ids, l_filter_ids_length, r_filter_ids,
|
||||
r_filter_ids_length, &filtered_results);
|
||||
} else {
|
||||
filter_ids_length = ArrayUtils::or_scalar(
|
||||
l_filter_ids, l_filter_ids_length, r_filter_ids,
|
||||
r_filter_ids_length, &filtered_results);
|
||||
l_filter_ids, l_filter_ids_length, r_filter_ids,
|
||||
r_filter_ids_length, &filtered_results);
|
||||
}
|
||||
|
||||
delete[] l_filter_ids;
|
||||
@ -2054,7 +2133,10 @@ void Index::recursive_filter(uint32_t*& filter_ids,
|
||||
|
||||
filter_ids = filtered_results;
|
||||
} else if (root->left == nullptr && root->right == nullptr) {
|
||||
do_filtering(filter_ids, filter_ids_length, root);
|
||||
do_filtering(root);
|
||||
filter_ids_length = root->match_index_ids.first;
|
||||
filter_ids = root->match_index_ids.second;
|
||||
root->match_index_ids.second = nullptr;
|
||||
} else {
|
||||
// malformed
|
||||
}
|
||||
@ -2062,13 +2144,13 @@ void Index::recursive_filter(uint32_t*& filter_ids,
|
||||
|
||||
void Index::do_filtering_with_lock(uint32_t*& filter_ids,
|
||||
uint32_t& filter_ids_length,
|
||||
filter_node_t const* const& filter_tree_root) const {
|
||||
filter_node_t* filter_tree_root) const {
|
||||
std::shared_lock lock(mutex);
|
||||
recursive_filter(filter_ids, filter_ids_length, filter_tree_root, false);
|
||||
}
|
||||
|
||||
void Index::do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& reference_index_ids,
|
||||
filter_node_t const* const& filter_tree_root,
|
||||
filter_node_t* filter_tree_root,
|
||||
const std::string& reference_field_name) const {
|
||||
std::shared_lock lock(mutex);
|
||||
recursive_filter(reference_index_ids.second, reference_index_ids.first, filter_tree_root, false);
|
||||
@ -2077,7 +2159,7 @@ void Index::do_reference_filtering_with_lock(std::pair<uint32_t, uint32_t*>& ref
|
||||
vector.reserve(reference_index_ids.first);
|
||||
|
||||
for (uint32_t i = 0; i < reference_index_ids.first; i++) {
|
||||
auto filtered_doc_id = *(reference_index_ids.second + i);
|
||||
auto filtered_doc_id = reference_index_ids.second[i];
|
||||
|
||||
// Extract the sequence_id from the reference field.
|
||||
vector.push_back(sort_index.at(reference_field_name)->at(filtered_doc_id));
|
||||
@ -2550,7 +2632,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name
|
||||
|
||||
void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
|
||||
const text_match_type_t match_type,
|
||||
filter_node_t const* const& filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
|
||||
filter_node_t* filter_tree_root, std::vector<facet>& facets, facet_query_t& facet_query,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
|
||||
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
|
||||
|
@ -2536,3 +2536,59 @@ TEST_F(CollectionFilteringTest, FilteringAfterUpsertOnArrayWithSymbolsToIndex) {
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionFilteringTest, ComplexFilterQuery) {
|
||||
nlohmann::json schema_json =
|
||||
R"({
|
||||
"name": "ComplexFilterQueryCollection",
|
||||
"fields": [
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "age", "type": "int32"},
|
||||
{"name": "years", "type": "int32[]"},
|
||||
{"name": "rating", "type": "float"}
|
||||
]
|
||||
})"_json;
|
||||
|
||||
auto op = collectionManager.create_collection(schema_json);
|
||||
ASSERT_TRUE(op.ok());
|
||||
auto coll = op.get();
|
||||
|
||||
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
|
||||
std::string json_line;
|
||||
while (std::getline(infile, json_line)) {
|
||||
auto add_op = coll->add(json_line);
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
infile.close();
|
||||
|
||||
std::vector<sort_by> sort_fields_desc = {sort_by("rating", "DESC")};
|
||||
nlohmann::json results = coll->search("Jeremy", {"name"}, "(rating:>=0 && years:>2000) && age:>50",
|
||||
{}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(0, results["hits"].size());
|
||||
|
||||
results = coll->search("Jeremy", {"name"}, "(age:>50 || rating:>5) && years:<2000",
|
||||
{}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
std::vector<std::string> ids = {"4", "3"};
|
||||
for (size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = coll->search("Jeremy", {"name"}, "(age:<50 && rating:10) || (years:>2000 && rating:<5)",
|
||||
{}, sort_fields_desc, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(1, results["hits"].size());
|
||||
|
||||
ids = {"0"};
|
||||
for (size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
collectionManager.drop_collection("ComplexFilterQueryCollection");
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user