From 12debac7197cb9ef30df7d0cc4e5d52d8d0b651b Mon Sep 17 00:00:00 2001 From: Harpreet Sangar Date: Wed, 26 Apr 2023 20:36:09 +0530 Subject: [PATCH] Remove `Index::do_filtering` and `Index::recursive_filter`. Use `filter_result_iterator_t` instead. --- include/index.h | 19 -- src/index.cpp | 517 +++--------------------------------------------- 2 files changed, 25 insertions(+), 511 deletions(-) diff --git a/include/index.h b/include/index.h index 41e08080..dd4bf224 100644 --- a/include/index.h +++ b/include/index.h @@ -475,31 +475,12 @@ private: bool field_is_indexed(const std::string& field_name) const; - Option do_filtering(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name = "", - const uint32_t& context_ids_length = 0, - uint32_t* const& context_ids = nullptr) const; - void aproximate_numerical_match(num_tree_t* const num_tree, const NUM_COMPARATOR& comparator, const int64_t& value, const int64_t& range_end_value, uint32_t& filter_ids_length) const; - /// Traverses through filter tree to get the filter_result. - /// - /// \param filter_tree_root - /// \param filter_result - /// \param collection_name Name of the collection to which current index belongs. Used to find the reference field in other collection. - /// \param context_ids_length Number of docs matching the search query. - /// \param context_ids Array of doc ids matching the search query. 
- Option recursive_filter(filter_node_t* const filter_tree_root, - filter_result_t& filter_result, - const std::string& collection_name = "", - const uint32_t& context_ids_length = 0, - uint32_t* const& context_ids = nullptr) const; - void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id, const std::unordered_map> &token_to_offsets) const; diff --git a/src/index.cpp b/src/index.cpp index 759e3f30..a4937412 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1538,446 +1538,6 @@ bool Index::field_is_indexed(const std::string& field_name) const { geopoint_index.count(field_name) != 0; } -Option Index::do_filtering(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name, - const uint32_t& context_ids_length, - uint32_t* const& context_ids) const { - // auto begin = std::chrono::high_resolution_clock::now(); - const filter a_filter = root->filter_exp; - - bool is_referenced_filter = !a_filter.referenced_collection_name.empty(); - if (is_referenced_filter) { - // Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents. 
- auto& cm = CollectionManager::get_instance(); - auto collection = cm.get_collection(a_filter.referenced_collection_name); - if (collection == nullptr) { - return Option(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found."); - } - - filter_result_t reference_filter_result; - auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name, - reference_filter_result, - collection_name); - if (!reference_filter_op.ok()) { - return Option(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name - + "` collection: " + reference_filter_op.error()); - } - - if (context_ids_length != 0) { - std::vector include_indexes; - include_indexes.reserve(std::min(context_ids_length, reference_filter_result.count)); - - size_t context_index = 0, reference_result_index = 0; - while (context_index < context_ids_length && reference_result_index < reference_filter_result.count) { - if (context_ids[context_index] == reference_filter_result.docs[reference_result_index]) { - include_indexes.push_back(reference_result_index); - context_index++; - reference_result_index++; - } else if (context_ids[context_index] < reference_filter_result.docs[reference_result_index]) { - context_index++; - } else { - reference_result_index++; - } - } - - result.count = include_indexes.size(); - result.docs = new uint32_t[include_indexes.size()]; - auto& result_references = result.reference_filter_results[a_filter.referenced_collection_name]; - result_references = new reference_filter_result_t[include_indexes.size()]; - - for (uint32_t i = 0; i < include_indexes.size(); i++) { - result.docs[i] = reference_filter_result.docs[include_indexes[i]]; - result_references[i] = reference_filter_result.reference_filter_results[a_filter.referenced_collection_name][include_indexes[i]]; - } - - return Option(true); - } - - result = std::move(reference_filter_result); - return Option(true); - } - - if (a_filter.field_name == "id") { - // we 
handle `ids` separately - std::vector result_ids; - for (const auto& id_str : a_filter.values) { - result_ids.push_back(std::stoul(id_str)); - } - - std::sort(result_ids.begin(), result_ids.end()); - - auto result_array = new uint32[result_ids.size()]; - std::copy(result_ids.begin(), result_ids.end(), result_array); - - if (context_ids_length != 0) { - uint32_t* out = nullptr; - result.count = ArrayUtils::and_scalar(context_ids, context_ids_length, - result_array, result_ids.size(), &out); - - delete[] result_array; - - result.docs = out; - return Option(true); - } - - result.docs = result_array; - result.count = result_ids.size(); - return Option(true); - } - - if (!field_is_indexed(a_filter.field_name)) { - return Option(true); - } - - field f = search_schema.at(a_filter.field_name); - - uint32_t* result_ids = nullptr; - size_t result_ids_len = 0; - - if (f.is_integer()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - int64_t value = (int64_t)std::stol(filter_value); - - if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi + 1]; - auto const range_end_value = (int64_t)std::stol(next_filter_value); - - if (context_ids_length != 0) { - num_tree->range_inclusive_contains(value, range_end_value, context_ids_length, context_ids, - result_ids_len, result_ids); - } else { - num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len); - } - - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, value, context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[fi], value, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[fi], 
value, &result_ids, result_ids_len); - } - } - } - } else if (f.is_float()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - for (size_t fi = 0; fi < a_filter.values.size(); fi++) { - const std::string& filter_value = a_filter.values[fi]; - float value = (float)std::atof(filter_value.c_str()); - int64_t float_int64 = float_to_int64_t(value); - - if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) { - const std::string& next_filter_value = a_filter.values[fi+1]; - int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str())); - - if (context_ids_length != 0) { - num_tree->range_inclusive_contains(float_int64, range_end_value, context_ids_length, context_ids, - result_ids_len, result_ids); - } else { - num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len); - } - - fi++; - } else if (a_filter.comparators[fi] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, float_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[fi], float_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len); - } - } - } - } else if (f.is_bool()) { - auto num_tree = numerical_index.at(a_filter.field_name); - - size_t value_index = 0; - for (const std::string& filter_value : a_filter.values) { - int64_t bool_int64 = (filter_value == "1") ? 
1 : 0; - if (a_filter.comparators[value_index] == NOT_EQUALS) { - numeric_not_equals_filter(num_tree, bool_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - if (context_ids_length != 0) { - num_tree->contains(a_filter.comparators[value_index], bool_int64, - context_ids_length, context_ids, result_ids_len, result_ids); - } else { - num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len); - } - } - - value_index++; - } - } else if (f.is_geopoint()) { - for (const std::string& filter_value : a_filter.values) { - std::vector geo_result_ids; - - std::vector filter_value_parts; - StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points - - bool is_polygon = StringUtils::is_float(filter_value_parts.back()); - S2Region* query_region; - - if (is_polygon) { - const int num_verts = int(filter_value_parts.size()) / 2; - std::vector vertices; - double sum = 0.0; - - for (size_t point_index = 0; point_index < size_t(num_verts); - point_index++) { - double lat = std::stod(filter_value_parts[point_index * 2]); - double lon = std::stod(filter_value_parts[point_index * 2 + 1]); - S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint(); - vertices.emplace_back(vertex); - } - - auto loop = new S2Loop(vertices, S2Debug::DISABLE); - loop->Normalize(); // if loop is not CCW but CW, change to CCW. - - S2Error error; - if (loop->FindValidationError(&error)) { - LOG(ERROR) << "Query vertex is bad, skipping. 
Error: " << error; - delete loop; - continue; - } else { - query_region = loop; - } - } else { - double radius = std::stof(filter_value_parts[2]); - const auto& unit = filter_value_parts[3]; - - if (unit == "km") { - radius *= 1000; - } else { - // assume "mi" (validated upstream) - radius *= 1609.34; - } - - S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius)); - double query_lat = std::stod(filter_value_parts[0]); - double query_lng = std::stod(filter_value_parts[1]); - S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint(); - query_region = new S2Cap(center, query_radius); - } - - S2RegionTermIndexer::Options options; - options.set_index_contains_points_only(true); - S2RegionTermIndexer indexer(options); - - for (const auto& term : indexer.GetQueryTerms(*query_region, "")) { - auto geo_index = geopoint_index.at(a_filter.field_name); - const auto& ids_it = geo_index->find(term); - if(ids_it != geo_index->end()) { - geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end()); - } - } - - gfx::timsort(geo_result_ids.begin(), geo_result_ids.end()); - geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end()); - - // `geo_result_ids` will contain all IDs that are within approximately within query radius - // we still need to do another round of exact filtering on them - - if (context_ids_length != 0) { - uint32_t *out = nullptr; - uint32_t count = ArrayUtils::and_scalar(context_ids, context_ids_length, - &geo_result_ids[0], geo_result_ids.size(), &out); - - geo_result_ids = std::vector(out, out + count); - } - - std::vector exact_geo_result_ids; - - if (f.is_single_geopoint()) { - spp::sparse_hash_map* sort_field_index = sort_index.at(f.name); - - for (auto result_id : geo_result_ids) { - // no need to check for existence of `result_id` because of indexer based pre-filtering above - int64_t lat_lng = sort_field_index->at(result_id); - S2LatLng s2_lat_lng; - 
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); - if (query_region->Contains(s2_lat_lng.ToPoint())) { - exact_geo_result_ids.push_back(result_id); - } - } - } else { - spp::sparse_hash_map* geo_field_index = geo_array_index.at(f.name); - - for (auto result_id : geo_result_ids) { - int64_t* lat_lngs = geo_field_index->at(result_id); - - bool point_found = false; - - // any one point should exist - for (size_t li = 0; li < lat_lngs[0]; li++) { - int64_t lat_lng = lat_lngs[li + 1]; - S2LatLng s2_lat_lng; - GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng); - if (query_region->Contains(s2_lat_lng.ToPoint())) { - point_found = true; - break; - } - } - - if (point_found) { - exact_geo_result_ids.push_back(result_id); - } - } - } - - uint32_t* out = nullptr; - result_ids_len = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(), - result_ids, result_ids_len, &out); - - delete[] result_ids; - result_ids = out; - - delete query_region; - } - } else if (f.is_string()) { - art_tree* t = search_index.at(a_filter.field_name); - - uint32_t* or_ids = nullptr; - size_t or_ids_size = 0; - - // aggregates IDs across array of filter values and reduces excessive ORing - std::vector f_id_buff; - - for (const std::string& filter_value : a_filter.values) { - std::vector posting_lists; - - // there could be multiple tokens in a filter value, which we have to treat as ANDs - // e.g. 
country: South Africa - Tokenizer tokenizer(filter_value, true, false, f.locale, symbols_to_index, token_separators); - - std::string str_token; - size_t token_index = 0; - std::vector str_tokens; - - while (tokenizer.next(str_token, token_index)) { - str_tokens.push_back(str_token); - - art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(), - str_token.length()+1); - if (leaf == nullptr) { - continue; - } - - posting_lists.push_back(leaf->values); - } - - if (posting_lists.size() != str_tokens.size()) { - continue; - } - - if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) { - // needs intersection + exact matching (unlike CONTAINS) - std::vector result_id_vec; - posting_t::intersect(posting_lists, result_id_vec, context_ids_length, context_ids); - - if (result_id_vec.empty()) { - continue; - } - - // need to do exact match - uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()]; - size_t exact_str_ids_size = 0; - std::unique_ptr exact_str_ids_guard(exact_str_ids); - - posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(), - exact_str_ids, exact_str_ids_size); - - if (exact_str_ids_size == 0) { - continue; - } - - for (size_t ei = 0; ei < exact_str_ids_size; ei++) { - f_id_buff.push_back(exact_str_ids[ei]); - } - } else { - // CONTAINS - size_t before_size = f_id_buff.size(); - posting_t::intersect(posting_lists, f_id_buff, context_ids_length, context_ids); - if (f_id_buff.size() == before_size) { - continue; - } - } - - if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) { - gfx::timsort(f_id_buff.begin(), f_id_buff.end()); - f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); - - uint32_t* out = nullptr; - or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); - delete[] or_ids; - or_ids = out; - std::vector().swap(f_id_buff); // clears out memory - } - } - - if 
(!f_id_buff.empty()) { - gfx::timsort(f_id_buff.begin(), f_id_buff.end()); - f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end()); - - uint32_t* out = nullptr; - or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out); - delete[] or_ids; - or_ids = out; - std::vector().swap(f_id_buff); // clears out memory - } - - result_ids = or_ids; - result_ids_len = or_ids_size; - } - - if (a_filter.apply_not_equals) { - auto all_ids = seq_ids->uncompress(); - auto all_ids_size = seq_ids->num_ids(); - - uint32_t* to_include_ids = nullptr; - size_t to_include_ids_len = 0; - - to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, result_ids, - result_ids_len, &to_include_ids); - - delete[] all_ids; - delete[] result_ids; - - result_ids = to_include_ids; - result_ids_len = to_include_ids_len; - - if (context_ids_length != 0) { - uint32_t *out = nullptr; - result.count = ArrayUtils::and_scalar(context_ids, context_ids_length, - result_ids, result_ids_len, &out); - - delete[] result_ids; - - result.docs = out; - return Option(true); - } - } - - result.docs = result_ids; - result.count = result_ids_len; - - return Option(true); - /*long long int timeMillis = - std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - - begin).count(); - - LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/ -} - void Index::aproximate_numerical_match(num_tree_t* const num_tree, const NUM_COMPARATOR& comparator, const int64_t& value, @@ -2148,54 +1708,19 @@ Option Index::rearrange_filter_tree(filter_node_t* const root, return Option(true); } -Option Index::recursive_filter(filter_node_t* const root, - filter_result_t& result, - const std::string& collection_name, - const uint32_t& context_ids_length, - uint32_t* const& context_ids) const { - if (root == nullptr) { - return Option(true); - } - - if (root->isOperator) { - filter_result_t l_result; - if (root->left != nullptr) { - auto 
filter_op = recursive_filter(root->left, l_result , collection_name, context_ids_length, context_ids); - if (!filter_op.ok()) { - return filter_op; - } - } - - filter_result_t r_result; - if (root->right != nullptr) { - auto filter_op = recursive_filter(root->right, r_result , collection_name, context_ids_length, context_ids); - if (!filter_op.ok()) { - return filter_op; - } - } - - if (root->filter_operator == AND) { - filter_result_t::and_filter_results(l_result, r_result, result); - } else { - filter_result_t::or_filter_results(l_result, r_result, result); - } - - return Option(true); - } - - return do_filtering(root, result, collection_name, context_ids_length, context_ids); -} - Option Index::do_filtering_with_lock(filter_node_t* const filter_tree_root, filter_result_t& filter_result, const std::string& collection_name) const { std::shared_lock lock(mutex); - auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } + filter_result.count = filter_result_iterator.to_filter_id_array(filter_result.docs); + return Option(true); } @@ -2205,16 +1730,20 @@ Option Index::do_reference_filtering_with_lock(filter_node_t* const filter const std::string& reference_helper_field_name) const { std::shared_lock lock(mutex); - filter_result_t reference_filter_result; - auto filter_op = recursive_filter(filter_tree_root, reference_filter_result); - if (!filter_op.ok()) { - return filter_op; + auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return filter_init_op; } + uint32_t* reference_docs = nullptr; + uint32_t count = 
filter_result_iterator.to_filter_id_array(reference_docs); + std::unique_ptr docs_guard(reference_docs); + // doc id -> reference doc ids std::map> reference_map; - for (uint32_t i = 0; i < reference_filter_result.count; i++) { - auto reference_doc_id = reference_filter_result.docs[i]; + for (uint32_t i = 0; i < count; i++) { + auto reference_doc_id = reference_docs[i]; auto doc_id = sort_index.at(reference_helper_field_name)->at(reference_doc_id); reference_map[doc_id].push_back(reference_doc_id); @@ -5105,11 +4634,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &seq_id_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::eval) { field_values[i] = &eval_sentinel_value; - filter_result_t result; - recursive_filter(sort_fields_std[i].eval.filter_tree_root, result); - sort_fields_std[i].eval.ids = result.docs; - sort_fields_std[i].eval.size = result.count; - result.docs = nullptr; + + auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root); + auto filter_init_op = filter_result_iterator.init_status(); + if (!filter_init_op.ok()) { + return; + } + + sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids); + } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i);