Support geo filtering.

This commit is contained in:
Harpreet Sangar 2023-04-21 12:48:46 +05:30
parent 6db0b108a1
commit 1eafc22348

View File

@ -1,5 +1,13 @@
#include <queue>
#include <id_list.h>
#include <s2/s2point.h>
#include <s2/s2latlng.h>
#include <s2/s2region_term_indexer.h>
#include <s2/s2cap.h>
#include <s2/s2earth.h>
#include <s2/s2loop.h>
#include <s2/s2builder.h>
#include <timsort.hpp>
#include "filter_result_iterator.h"
#include "index.h"
#include "posting.h"
@ -306,8 +314,12 @@ void filter_result_iterator_t::doc_matching_string_filter(bool field_is_array) {
break;
} else {
// Keep advancing token iterators till exact match is not found.
for (auto &item: filter_value_tokens) {
item.next();
for (auto &iter: filter_value_tokens) {
if (!iter.valid()) {
break;
}
iter.next();
}
}
}
@ -661,6 +673,133 @@ void filter_result_iterator_t::init() {
return;
}
seq_id = filter_result.docs[result_index];
is_filter_result_initialized = true;
return;
} else if (f.is_geopoint()) {
for (const std::string& filter_value : a_filter.values) {
std::vector<uint32_t> geo_result_ids;
std::vector<std::string> filter_value_parts;
StringUtils::split(filter_value, filter_value_parts, ","); // x, y, 2, km (or) list of points
bool is_polygon = StringUtils::is_float(filter_value_parts.back());
S2Region* query_region;
if (is_polygon) {
const int num_verts = int(filter_value_parts.size()) / 2;
std::vector<S2Point> vertices;
double sum = 0.0;
for (size_t point_index = 0; point_index < size_t(num_verts);
point_index++) {
double lat = std::stod(filter_value_parts[point_index * 2]);
double lon = std::stod(filter_value_parts[point_index * 2 + 1]);
S2Point vertex = S2LatLng::FromDegrees(lat, lon).ToPoint();
vertices.emplace_back(vertex);
}
auto loop = new S2Loop(vertices, S2Debug::DISABLE);
loop->Normalize(); // if loop is not CCW but CW, change to CCW.
S2Error error;
if (loop->FindValidationError(&error)) {
LOG(ERROR) << "Query vertex is bad, skipping. Error: " << error;
delete loop;
continue;
} else {
query_region = loop;
}
} else {
double radius = std::stof(filter_value_parts[2]);
const auto& unit = filter_value_parts[3];
if (unit == "km") {
radius *= 1000;
} else {
// assume "mi" (validated upstream)
radius *= 1609.34;
}
S1Angle query_radius = S1Angle::Radians(S2Earth::MetersToRadians(radius));
double query_lat = std::stod(filter_value_parts[0]);
double query_lng = std::stod(filter_value_parts[1]);
S2Point center = S2LatLng::FromDegrees(query_lat, query_lng).ToPoint();
query_region = new S2Cap(center, query_radius);
}
S2RegionTermIndexer::Options options;
options.set_index_contains_points_only(true);
S2RegionTermIndexer indexer(options);
for (const auto& term : indexer.GetQueryTerms(*query_region, "")) {
auto geo_index = index->geopoint_index.at(a_filter.field_name);
const auto& ids_it = geo_index->find(term);
if(ids_it != geo_index->end()) {
geo_result_ids.insert(geo_result_ids.end(), ids_it->second.begin(), ids_it->second.end());
}
}
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end());
// `geo_result_ids` will contain all IDs that are within approximately within query radius
// we still need to do another round of exact filtering on them
std::vector<uint32_t> exact_geo_result_ids;
if (f.is_single_geopoint()) {
spp::sparse_hash_map<uint32_t, int64_t>* sort_field_index = index->sort_index.at(f.name);
for (auto result_id : geo_result_ids) {
// no need to check for existence of `result_id` because of indexer based pre-filtering above
int64_t lat_lng = sort_field_index->at(result_id);
S2LatLng s2_lat_lng;
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng);
if (query_region->Contains(s2_lat_lng.ToPoint())) {
exact_geo_result_ids.push_back(result_id);
}
}
} else {
spp::sparse_hash_map<uint32_t, int64_t*>* geo_field_index = index->geo_array_index.at(f.name);
for (auto result_id : geo_result_ids) {
int64_t* lat_lngs = geo_field_index->at(result_id);
bool point_found = false;
// any one point should exist
for (size_t li = 0; li < lat_lngs[0]; li++) {
int64_t lat_lng = lat_lngs[li + 1];
S2LatLng s2_lat_lng;
GeoPoint::unpack_lat_lng(lat_lng, s2_lat_lng);
if (query_region->Contains(s2_lat_lng.ToPoint())) {
point_found = true;
break;
}
}
if (point_found) {
exact_geo_result_ids.push_back(result_id);
}
}
}
uint32_t* out = nullptr;
filter_result.count = ArrayUtils::or_scalar(&exact_geo_result_ids[0], exact_geo_result_ids.size(),
filter_result.docs, filter_result.count, &out);
delete[] filter_result.docs;
filter_result.docs = out;
delete query_region;
}
if (filter_result.count == 0) {
is_valid = false;
return;
}
seq_id = filter_result.docs[result_index];
is_filter_result_initialized = true;
return;