Remove Index::do_filtering. Using filter_result_t instead.

This commit is contained in:
Harpreet Sangar 2023-04-26 20:36:09 +05:30
parent 2f615fe1ff
commit 5dbfb9df63
2 changed files with 25 additions and 511 deletions

View File

@ -472,31 +472,12 @@ private:
bool field_is_indexed(const std::string& field_name) const;
Option<bool> do_filtering(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name = "",
const uint32_t& context_ids_length = 0,
uint32_t* const& context_ids = nullptr) const;
void aproximate_numerical_match(num_tree_t* const num_tree,
const NUM_COMPARATOR& comparator,
const int64_t& value,
const int64_t& range_end_value,
uint32_t& filter_ids_length) const;
/// Traverses through filter tree to get the filter_result.
///
/// \param filter_tree_root
/// \param filter_result
/// \param collection_name Name of the collection to which current index belongs. Used to find the reference field in other collection.
/// \param context_ids_length Number of docs matching the search query.
/// \param context_ids Array of doc ids matching the search query.
Option<bool> recursive_filter(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string& collection_name = "",
const uint32_t& context_ids_length = 0,
uint32_t* const& context_ids = nullptr) const;
void insert_doc(const int64_t score, art_tree *t, uint32_t seq_id,
const std::unordered_map<std::string, std::vector<uint32_t>> &token_to_offsets) const;

View File

@ -1516,446 +1516,6 @@ bool Index::field_is_indexed(const std::string& field_name) const {
geopoint_index.count(field_name) != 0;
}
Option<bool> Index::do_filtering(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name,
const uint32_t& context_ids_length,
uint32_t* const& context_ids) const {
// auto begin = std::chrono::high_resolution_clock::now();
const filter a_filter = root->filter_exp;
bool is_referenced_filter = !a_filter.referenced_collection_name.empty();
if (is_referenced_filter) {
// Apply filter on referenced collection and get the sequence ids of current collection from the filtered documents.
auto& cm = CollectionManager::get_instance();
auto collection = cm.get_collection(a_filter.referenced_collection_name);
if (collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + a_filter.referenced_collection_name + "` not found.");
}
filter_result_t reference_filter_result;
auto reference_filter_op = collection->get_reference_filter_ids(a_filter.field_name,
reference_filter_result,
collection_name);
if (!reference_filter_op.ok()) {
return Option<bool>(400, "Failed to apply reference filter on `" + a_filter.referenced_collection_name
+ "` collection: " + reference_filter_op.error());
}
if (context_ids_length != 0) {
std::vector<uint32_t> include_indexes;
include_indexes.reserve(std::min(context_ids_length, reference_filter_result.count));
size_t context_index = 0, reference_result_index = 0;
while (context_index < context_ids_length && reference_result_index < reference_filter_result.count) {
if (context_ids[context_index] == reference_filter_result.docs[reference_result_index]) {
include_indexes.push_back(reference_result_index);
context_index++;
reference_result_index++;
} else if (context_ids[context_index] < reference_filter_result.docs[reference_result_index]) {
context_index++;
} else {
reference_result_index++;
}
}
result.count = include_indexes.size();
result.docs = new uint32_t[include_indexes.size()];
auto& result_references = result.reference_filter_results[a_filter.referenced_collection_name];
result_references = new reference_filter_result_t[include_indexes.size()];
for (uint32_t i = 0; i < include_indexes.size(); i++) {
result.docs[i] = reference_filter_result.docs[include_indexes[i]];
result_references[i] = reference_filter_result.reference_filter_results[a_filter.referenced_collection_name][include_indexes[i]];
}
return Option(true);
}
result = std::move(reference_filter_result);
return Option(true);
}
if (a_filter.field_name == "id") {
// we handle `ids` separately
std::vector<uint32> result_ids;
for (const auto& id_str : a_filter.values) {
result_ids.push_back(std::stoul(id_str));
}
std::sort(result_ids.begin(), result_ids.end());
auto result_array = new uint32[result_ids.size()];
std::copy(result_ids.begin(), result_ids.end(), result_array);
if (context_ids_length != 0) {
uint32_t* out = nullptr;
result.count = ArrayUtils::and_scalar(context_ids, context_ids_length,
result_array, result_ids.size(), &out);
delete[] result_array;
result.docs = out;
return Option(true);
}
result.docs = result_array;
result.count = result_ids.size();
return Option(true);
}
if (!field_is_indexed(a_filter.field_name)) {
return Option(true);
}
field f = search_schema.at(a_filter.field_name);
uint32_t* result_ids = nullptr;
size_t result_ids_len = 0;
if (f.is_integer()) {
auto num_tree = numerical_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
int64_t value = (int64_t)std::stol(filter_value);
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = (int64_t)std::stol(next_filter_value);
if (context_ids_length != 0) {
num_tree->range_inclusive_contains(value, range_end_value, context_ids_length, context_ids,
result_ids_len, result_ids);
} else {
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
}
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value, context_ids_length, context_ids, result_ids_len, result_ids);
} else {
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[fi], value,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
}
}
}
} else if (f.is_float()) {
auto num_tree = numerical_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
float value = (float)std::atof(filter_value.c_str());
int64_t float_int64 = float_to_int64_t(value);
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = float_to_int64_t((float) std::atof(next_filter_value.c_str()));
if (context_ids_length != 0) {
num_tree->range_inclusive_contains(float_int64, range_end_value, context_ids_length, context_ids,
result_ids_len, result_ids);
} else {
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
}
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, float_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[fi], float_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
}
}
}
} else if (f.is_bool()) {
auto num_tree = numerical_index.at(a_filter.field_name);
size_t value_index = 0;
for (const std::string& filter_value : a_filter.values) {
int64_t bool_int64 = (filter_value == "1") ? 1 : 0;
if (a_filter.comparators[value_index] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, bool_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
if (context_ids_length != 0) {
num_tree->contains(a_filter.comparators[value_index], bool_int64,
context_ids_length, context_ids, result_ids_len, result_ids);
} else {
num_tree->search(a_filter.comparators[value_index], bool_int64, &result_ids, result_ids_len);
}
}
value_index++;
}
} else if (f.is_geopoint()) {
for (const std::string& filter_value : a_filter.values) {
std::vector<uint32_t> geo_result_ids;
std::vector<std::string> filter_value_parts;
StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points
bool is_polygon = StringUtils::is_float(filter_value_parts.back());
S2Region* query_region;
if (is_polygon) {
const int num_verts = int(filter_value_parts.size()) / 2;
std::vector<S2Point> vertices;
double sum = 0.0;
for (size_t point_index = 0; point_index < size_t(num_verts);
point_index++) {
double lat = std::stod(filter_value_parts[point_index * 2]);
double lon = std::stod(filter_value_parts[point_index * 2 + 1]);
S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint();
vertices.emplace_back(vertex);
}
auto loop = new S2Loop(vertices, S2Debug::DISABLE);
loop->Normalize(); // if loop is not CCW but CW, change to CCW.
S2Error error;
if (loop->FindValidationError(&error)) {
LOG(ERROR) << "Query vertex is bad, skipping. Error: " << error;
delete loop;
continue;
} else {
query_region = loop;
}
} else {
double radius = std::stof(filter_value_parts[2]);
const auto& unit = filter_value_parts[3];
if (unit == "km") {
radius *= 1000;
} else {
// assume "mi" (validated upstream)
radius *= 1609.34;
}
S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius));
double query_lat = std::stod(filter_value_parts[0]);
double query_lng = std::stod(filter_value_parts[1]);
S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint();
query_region = new S2Cap(center, query_radius);
}
S2RegionTermIndexer::Options options;
options.set_index_contains_points_only(true);
S2RegionTermIndexer indexer(options);
for (const auto& term : indexer.GetQueryTerms(*query_region, "")) {
auto geo_index = geopoint_index.at(a_filter.field_name);
const auto& ids_it = geo_index->find(term);
if(ids_it != geo_index->end()) {
geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end());
}
}
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end());
// `geo_result_ids` will contain all IDs that are within approximately within query radius
// we still need to do another round of exact filtering on them
if (context_ids_length != 0) {
uint32_t *out = nullptr;
uint32_t count = ArrayUtils::and_scalar(context_ids, context_ids_length,
&geo_result_ids[0], geo_result_ids.size(), &out);
geo_result_ids = std::vector<uint32_t>(out, out + count);
}
std::vector<uint32_t> exact_geo_result_ids;
if (f.is_single_geopoint()) {
spp::sparse_hash_map<uint32_t, int64_t>* sort_field_index = sort_index.at(f.name);
for (auto result_id : geo_result_ids) {
// no need to check for existence of `result_id` because of indexer based pre-filtering above
int64_t lat_lng = sort_field_index->at(result_id);
S2LatLng s2_lat_lng;
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng);
if (query_region->Contains(s2_lat_lng.ToPoint())) {
exact_geo_result_ids.push_back(result_id);
}
}
} else {
spp::sparse_hash_map<uint32_t, int64_t*>* geo_field_index = geo_array_index.at(f.name);
for (auto result_id : geo_result_ids) {
int64_t* lat_lngs = geo_field_index->at(result_id);
bool point_found = false;
// any one point should exist
for (size_t li = 0; li < lat_lngs[0]; li++) {
int64_t lat_lng = lat_lngs[li + 1];
S2LatLng s2_lat_lng;
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng);
if (query_region->Contains(s2_lat_lng.ToPoint())) {
point_found = true;
break;
}
}
if (point_found) {
exact_geo_result_ids.push_back(result_id);
}
}
}
uint32_t* out = nullptr;
result_ids_len = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(),
result_ids, result_ids_len, &out);
delete[] result_ids;
result_ids = out;
delete query_region;
}
} else if (f.is_string()) {
art_tree* t = search_index.at(a_filter.field_name);
uint32_t* or_ids = nullptr;
size_t or_ids_size = 0;
// aggregates IDs across array of filter values and reduces excessive ORing
std::vector<uint32_t> f_id_buff;
for (const std::string& filter_value : a_filter.values) {
std::vector<void*> posting_lists;
// there could be multiple tokens in a filter value, which we have to treat as ANDs
// e.g. country: South Africa
Tokenizer tokenizer(filter_value, true, false, f.locale, symbols_to_index, token_separators);
std::string str_token;
size_t token_index = 0;
std::vector<std::string> str_tokens;
while (tokenizer.next(str_token, token_index)) {
str_tokens.push_back(str_token);
art_leaf* leaf = (art_leaf *) art_search(t, (const unsigned char*) str_token.c_str(),
str_token.length()+1);
if (leaf == nullptr) {
continue;
}
posting_lists.push_back(leaf->values);
}
if (posting_lists.size() != str_tokens.size()) {
continue;
}
if(a_filter.comparators[0] == EQUALS || a_filter.comparators[0] == NOT_EQUALS) {
// needs intersection + exact matching (unlike CONTAINS)
std::vector<uint32_t> result_id_vec;
posting_t::intersect(posting_lists, result_id_vec, context_ids_length, context_ids);
if (result_id_vec.empty()) {
continue;
}
// need to do exact match
uint32_t* exact_str_ids = new uint32_t[result_id_vec.size()];
size_t exact_str_ids_size = 0;
std::unique_ptr<uint32_t[]> exact_str_ids_guard(exact_str_ids);
posting_t::get_exact_matches(posting_lists, f.is_array(), result_id_vec.data(), result_id_vec.size(),
exact_str_ids, exact_str_ids_size);
if (exact_str_ids_size == 0) {
continue;
}
for (size_t ei = 0; ei < exact_str_ids_size; ei++) {
f_id_buff.push_back(exact_str_ids[ei]);
}
} else {
// CONTAINS
size_t before_size = f_id_buff.size();
posting_t::intersect(posting_lists, f_id_buff, context_ids_length, context_ids);
if (f_id_buff.size() == before_size) {
continue;
}
}
if (f_id_buff.size() > 100000 || a_filter.values.size() == 1) {
gfx::timsort(f_id_buff.begin(), f_id_buff.end());
f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end());
uint32_t* out = nullptr;
or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out);
delete[] or_ids;
or_ids = out;
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
}
}
if (!f_id_buff.empty()) {
gfx::timsort(f_id_buff.begin(), f_id_buff.end());
f_id_buff.erase(std::unique( f_id_buff.begin(), f_id_buff.end() ), f_id_buff.end());
uint32_t* out = nullptr;
or_ids_size = ArrayUtils::or_scalar(or_ids, or_ids_size, f_id_buff.data(), f_id_buff.size(), &out);
delete[] or_ids;
or_ids = out;
std::vector<uint32_t>().swap(f_id_buff); // clears out memory
}
result_ids = or_ids;
result_ids_len = or_ids_size;
}
if (a_filter.apply_not_equals) {
auto all_ids = seq_ids->uncompress();
auto all_ids_size = seq_ids->num_ids();
uint32_t* to_include_ids = nullptr;
size_t to_include_ids_len = 0;
to_include_ids_len = ArrayUtils::exclude_scalar(all_ids, all_ids_size, result_ids,
result_ids_len, &to_include_ids);
delete[] all_ids;
delete[] result_ids;
result_ids = to_include_ids;
result_ids_len = to_include_ids_len;
if (context_ids_length != 0) {
uint32_t *out = nullptr;
result.count = ArrayUtils::and_scalar(context_ids, context_ids_length,
result_ids, result_ids_len, &out);
delete[] result_ids;
result.docs = out;
return Option(true);
}
}
result.docs = result_ids;
result.count = result_ids_len;
return Option(true);
/*long long int timeMillis =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now()
- begin).count();
LOG(INFO) << "Time taken for filtering: " << timeMillis << "ms";*/
}
void Index::aproximate_numerical_match(num_tree_t* const num_tree,
const NUM_COMPARATOR& comparator,
const int64_t& value,
@ -2126,54 +1686,19 @@ Option<bool> Index::rearrange_filter_tree(filter_node_t* const root,
return Option(true);
}
Option<bool> Index::recursive_filter(filter_node_t* const root,
filter_result_t& result,
const std::string& collection_name,
const uint32_t& context_ids_length,
uint32_t* const& context_ids) const {
if (root == nullptr) {
return Option(true);
}
if (root->isOperator) {
filter_result_t l_result;
if (root->left != nullptr) {
auto filter_op = recursive_filter(root->left, l_result , collection_name, context_ids_length, context_ids);
if (!filter_op.ok()) {
return filter_op;
}
}
filter_result_t r_result;
if (root->right != nullptr) {
auto filter_op = recursive_filter(root->right, r_result , collection_name, context_ids_length, context_ids);
if (!filter_op.ok()) {
return filter_op;
}
}
if (root->filter_operator == AND) {
filter_result_t::and_filter_results(l_result, r_result, result);
} else {
filter_result_t::or_filter_results(l_result, r_result, result);
}
return Option(true);
}
return do_filtering(root, result, collection_name, context_ids_length, context_ids);
}
Option<bool> Index::do_filtering_with_lock(filter_node_t* const filter_tree_root,
filter_result_t& filter_result,
const std::string& collection_name) const {
std::shared_lock lock(mutex);
auto filter_op = recursive_filter(filter_tree_root, filter_result, collection_name);
if (!filter_op.ok()) {
return filter_op;
auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
return filter_init_op;
}
filter_result.count = filter_result_iterator.to_filter_id_array(filter_result.docs);
return Option(true);
}
@ -2183,16 +1708,20 @@ Option<bool> Index::do_reference_filtering_with_lock(filter_node_t* const filter
const std::string& reference_helper_field_name) const {
std::shared_lock lock(mutex);
filter_result_t reference_filter_result;
auto filter_op = recursive_filter(filter_tree_root, reference_filter_result);
if (!filter_op.ok()) {
return filter_op;
auto filter_result_iterator = filter_result_iterator_t(collection_name, this, filter_tree_root);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
return filter_init_op;
}
uint32_t* reference_docs = nullptr;
uint32_t count = filter_result_iterator.to_filter_id_array(reference_docs);
std::unique_ptr<uint32_t> docs_guard(reference_docs);
// doc id -> reference doc ids
std::map<uint32_t, std::vector<uint32_t>> reference_map;
for (uint32_t i = 0; i < reference_filter_result.count; i++) {
auto reference_doc_id = reference_filter_result.docs[i];
for (uint32_t i = 0; i < count; i++) {
auto reference_doc_id = reference_docs[i];
auto doc_id = sort_index.at(reference_helper_field_name)->at(reference_doc_id);
reference_map[doc_id].push_back(reference_doc_id);
@ -5079,11 +4608,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
field_values[i] = &seq_id_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
filter_result_t result;
recursive_filter(sort_fields_std[i].eval.filter_tree_root, result);
sort_fields_std[i].eval.ids = result.docs;
sort_fields_std[i].eval.size = result.count;
result.docs = nullptr;
auto filter_result_iterator = filter_result_iterator_t("", this, sort_fields_std[i].eval.filter_tree_root);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
return;
}
sort_fields_std[i].eval.size = filter_result_iterator.to_filter_id_array(sort_fields_std[i].eval.ids);
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
geopoint_indices.push_back(i);