mirror of
https://github.com/typesense/typesense.git
synced 2025-05-26 00:36:22 +08:00
Add logic to skip exact geo filtering beyond threshold.
This commit is contained in:
parent
7fb193402d
commit
baffcf2969
@ -25,16 +25,16 @@ struct filter {
|
||||
bool apply_not_equals = false;
|
||||
|
||||
// Would store `Foo` in case of a filter expression like `$Foo(bar := baz)`
|
||||
std::string referenced_collection_name = "";
|
||||
std::string referenced_collection_name;
|
||||
|
||||
std::vector<nlohmann::json> params;
|
||||
|
||||
/// For searching places within a given radius of a given latlong (mi for miles and km for kilometers)
|
||||
static constexpr const char* GEO_FILTER_RADIUS = "radius";
|
||||
static constexpr const char* GEO_FILTER_RADIUS_KEY = "radius";
|
||||
|
||||
/// Radius threshold beyond which exact filtering on geo_result_ids will not be done.
|
||||
static constexpr const char* EXACT_GEO_FILTER_RADIUS = "exact_filter_radius";
|
||||
static constexpr const char* DEFAULT_EXACT_GEO_FILTER_RADIUS = "10km";
|
||||
static constexpr const char* EXACT_GEO_FILTER_RADIUS_KEY = "exact_filter_radius";
|
||||
static constexpr double DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE = 10000;
|
||||
|
||||
static const std::string RANGE_OPERATOR() {
|
||||
return "..";
|
||||
|
@ -143,6 +143,33 @@ Option<bool> filter::parse_geopoint_filter_value(std::string& raw_value,
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
/// Breaks a geo distance token (a number followed by a two-character unit) into its parts.
///
/// @param raw_value       token to validate; its last two characters are treated as the unit.
/// @param format_err_msg  error text returned when the numeric portion is malformed.
/// @param distance        out: the numeric portion as text. Written only on success.
/// @param unit            out: the trailing two characters of `raw_value`. Written as soon as the
///                        token is at least two characters long, even if validation later fails.
/// @return `Option<bool>(true)` when the token is valid, otherwise a 400 error.
Option<bool> validate_geofilter_distance(std::string& raw_value, const string& format_err_msg,
                                         std::string& distance, std::string& unit) {
    // Both unit-related failures report the same message.
    const auto unit_error = [] {
        return Option<bool>(400, "Unit must be either `km` or `mi`.");
    };

    const auto value_len = raw_value.size();
    if (value_len < 2) {
        // Too short to even hold a unit suffix.
        return unit_error();
    }

    // The suffix is exactly the final two characters of the token.
    unit = raw_value.substr(value_len - 2);

    const bool unit_is_known = (unit == "km") || (unit == "mi");
    if (!unit_is_known) {
        return unit_error();
    }

    // Splitting on the unit must leave exactly one piece, and that piece must be a float.
    std::vector<std::string> magnitudes;
    StringUtils::split(raw_value, magnitudes, unit);

    const bool well_formed = magnitudes.size() == 1 && StringUtils::is_float(magnitudes.front());
    if (!well_formed) {
        return Option<bool>(400, format_err_msg);
    }

    distance = magnitudes.front();
    return Option<bool>(true);
}
|
||||
|
||||
Option<bool> filter::parse_geopoint_filter_value(string& raw_value, const string& format_err_msg, filter& filter_exp) {
|
||||
// FORMAT:
|
||||
// [ ([48.853, 2.344], radius: 1km, exact_filter_radius: 100km), ([48.8662, 2.3255, 48.8581, 2.3209, 48.8561, 2.3448, 48.8641, 2.3469]) ]
|
||||
@ -185,19 +212,27 @@ Option<bool> filter::parse_geopoint_filter_value(string& raw_value, const string
|
||||
|
||||
if (value_str.empty() || value_str[0] != '[' || value_str.find(']', 1) == std::string::npos) {
|
||||
return Option<bool>(400, format_err_msg);
|
||||
} else {
|
||||
std::vector<std::string> filter_values;
|
||||
StringUtils::split(value_str, filter_values, ",");
|
||||
|
||||
if(filter_values.size() < 3) {
|
||||
return Option<bool>(400, format_err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
auto points_str = value_str.substr(1, value_str.find(']', 1) - 1);
|
||||
std::vector<std::string> geo_points;
|
||||
StringUtils::split(points_str, geo_points, ",");
|
||||
|
||||
bool is_polygon = value_str.back() == ']';
|
||||
for (const auto& geo_point: geo_points) {
|
||||
if(!StringUtils::is_float(geo_point)) {
|
||||
if (!StringUtils::is_float(geo_point) ||
|
||||
(!is_polygon && (geo_point == "nan" || geo_point == "NaN"))) {
|
||||
return Option<bool>(400, format_err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_polygon = value_str.back() == ']';
|
||||
if (is_polygon) {
|
||||
polygons.push_back(points_str);
|
||||
continue;
|
||||
@ -229,35 +264,34 @@ Option<bool> filter::parse_geopoint_filter_value(string& raw_value, const string
|
||||
continue;
|
||||
}
|
||||
|
||||
if (key_value[0] == GEO_FILTER_RADIUS) {
|
||||
if (key_value[0] == GEO_FILTER_RADIUS_KEY) {
|
||||
is_radius_present = true;
|
||||
auto& value = key_value[1];
|
||||
|
||||
if(value.size() < 2) {
|
||||
return Option<bool>(400, "Unit must be either `km` or `mi`.");
|
||||
std::string distance, unit;
|
||||
auto validate_op = validate_geofilter_distance(key_value[1], format_err_msg, distance, unit);
|
||||
if (!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
|
||||
std::string unit = value.substr(value.size() - 2, 2);
|
||||
|
||||
if(unit != "km" && unit != "mi") {
|
||||
return Option<bool>(400, "Unit must be either `km` or `mi`.");
|
||||
filter_exp.values.push_back(points_str + ", " + distance + ", " + unit);
|
||||
} else if (key_value[0] == EXACT_GEO_FILTER_RADIUS_KEY) {
|
||||
std::string distance, unit;
|
||||
auto validate_op = validate_geofilter_distance(key_value[1], format_err_msg, distance, unit);
|
||||
if (!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
|
||||
std::vector<std::string> dist_values;
|
||||
StringUtils::split(value, dist_values, unit);
|
||||
double exact_under_radius = std::stof(distance);
|
||||
|
||||
if(dist_values.size() != 1) {
|
||||
return Option<bool>(400, format_err_msg);
|
||||
if (unit == "km") {
|
||||
exact_under_radius *= 1000;
|
||||
} else {
|
||||
// assume "mi" (validated upstream)
|
||||
exact_under_radius *= 1609.34;
|
||||
}
|
||||
|
||||
if(!StringUtils::is_float(dist_values[0])) {
|
||||
return Option<bool>(400, format_err_msg);
|
||||
}
|
||||
|
||||
filter_exp.values.push_back(points_str + ", " + dist_values[0] + ", " + unit);
|
||||
} else if (key_value[0] == EXACT_GEO_FILTER_RADIUS) {
|
||||
nlohmann::json param;
|
||||
param[EXACT_GEO_FILTER_RADIUS] = key_value[1];
|
||||
param[EXACT_GEO_FILTER_RADIUS_KEY] = exact_under_radius;
|
||||
filter_exp.params.push_back(param);
|
||||
}
|
||||
}
|
||||
@ -265,9 +299,11 @@ Option<bool> filter::parse_geopoint_filter_value(string& raw_value, const string
|
||||
if (!is_radius_present) {
|
||||
return Option<bool>(400, format_err_msg);
|
||||
}
|
||||
if (filter_exp.params.empty()) {
|
||||
|
||||
// EXACT_GEO_FILTER_RADIUS_KEY was not present.
|
||||
if (filter_exp.params.size() < filter_exp.values.size()) {
|
||||
nlohmann::json param;
|
||||
param[EXACT_GEO_FILTER_RADIUS] = DEFAULT_EXACT_GEO_FILTER_RADIUS;
|
||||
param[EXACT_GEO_FILTER_RADIUS_KEY] = DEFAULT_EXACT_GEO_FILTER_RADIUS_VALUE;
|
||||
filter_exp.params.push_back(param);
|
||||
}
|
||||
}
|
||||
|
@ -754,7 +754,9 @@ void filter_result_iterator_t::init() {
|
||||
is_filter_result_initialized = true;
|
||||
return;
|
||||
} else if (f.is_geopoint()) {
|
||||
for (const std::string& filter_value : a_filter.values) {
|
||||
for (uint32_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
|
||||
std::vector<uint32_t> geo_result_ids;
|
||||
|
||||
std::vector<std::string> filter_value_parts;
|
||||
@ -763,6 +765,7 @@ void filter_result_iterator_t::init() {
|
||||
bool is_polygon = StringUtils::is_float(filter_value_parts.back());
|
||||
S2Region* query_region;
|
||||
|
||||
double radius;
|
||||
if (is_polygon) {
|
||||
const int num_verts = int(filter_value_parts.size()) / 2;
|
||||
std::vector<S2Point> vertices;
|
||||
@ -788,7 +791,7 @@ void filter_result_iterator_t::init() {
|
||||
query_region = loop;
|
||||
}
|
||||
} else {
|
||||
double radius = std::stof(filter_value_parts[2]);
|
||||
radius = std::stof(filter_value_parts[2]);
|
||||
const auto& unit = filter_value_parts[3];
|
||||
|
||||
if (unit == "km") {
|
||||
@ -820,6 +823,18 @@ void filter_result_iterator_t::init() {
|
||||
gfx::timsort(geo_result_ids.begin(), geo_result_ids.end());
|
||||
geo_result_ids.erase(std::unique( geo_result_ids.begin(), geo_result_ids.end() ), geo_result_ids.end());
|
||||
|
||||
// Skip exact filtering step if query radius is greater than the threshold.
|
||||
if (!is_polygon && fi < a_filter.params.size() &&
|
||||
radius > a_filter.params[fi][filter::EXACT_GEO_FILTER_RADIUS_KEY].get<double>()) {
|
||||
uint32_t* out = nullptr;
|
||||
filter_result.count = ArrayUtils::or_scalar(geo_result_ids.data(), geo_result_ids.size(),
|
||||
filter_result.docs, filter_result.count, &out);
|
||||
|
||||
delete[] filter_result.docs;
|
||||
filter_result.docs = out;
|
||||
continue;
|
||||
}
|
||||
|
||||
// `geo_result_ids` will contain all IDs that are approximately within the query radius
|
||||
// we still need to do another round of exact filtering on them
|
||||
|
||||
|
@ -1049,7 +1049,7 @@ TEST_F(CollectionFilteringTest, ComparatorsOnMultiValuedNumericalField) {
|
||||
collectionManager.drop_collection("coll_array_fields");
|
||||
}
|
||||
|
||||
TEST_F(CollectionFilteringTest, GeoPointFiltering) {
|
||||
TEST_F(CollectionFilteringTest, GeoPointFilteringV1) {
|
||||
Collection *coll1;
|
||||
|
||||
std::vector<field> fields = {field("title", field_types::STRING, false),
|
||||
@ -1192,6 +1192,111 @@ TEST_F(CollectionFilteringTest, GeoPointFiltering) {
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
// Exercises radius-based geopoint filters on a single `loc` field: km and mi units,
// multiple geo filters in one expression, and the two malformed-filter error paths.
TEST_F(CollectionFilteringTest, GeoPointFilteringV2) {
    std::vector<field> fields = {
        field("title", field_types::STRING, false),
        field("loc", field_types::GEOPOINT, false),
        field("points", field_types::INT32, false),
    };

    // Reuse the collection when it already exists, otherwise create it.
    Collection* coll1 = collectionManager.get_collection("coll1").get();
    if (coll1 == nullptr) {
        coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
    }

    // Each entry is {title, "lat, lng"}.
    std::vector<std::vector<std::string>> landmarks = {
        {"Palais Garnier", "48.872576479306765, 2.332291112241466"},
        {"Sacre Coeur", "48.888286721920934, 2.342340862419206"},
        {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"},
        {"Place de la Concorde", "48.86536119187326, 2.321850747347093"},
        {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"},
        {"Les Invalides", "48.856648379569904, 2.3118555692631357"},
        {"Eiffel Tower", "48.85821022164442, 2.294239067890161"},
        {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"},
        {"Musee Grevin", "48.872370541246816, 2.3431536410008906"},
        {"Pantheon", "48.84620987789056, 2.345152755563131"},
    };

    // Index one document per landmark; the position in the list doubles as `id` and `points`.
    for (size_t doc_idx = 0; doc_idx < landmarks.size(); ++doc_idx) {
        auto& landmark = landmarks[doc_idx];

        std::vector<std::string> coords;
        StringUtils::split(landmark[1], coords, ", ");

        const double lat = std::stod(coords[0]);
        const double lng = std::stod(coords[1]);

        nlohmann::json doc;
        doc["id"] = std::to_string(doc_idx);
        doc["title"] = landmark[0];
        doc["loc"] = {lat, lng};
        doc["points"] = doc_idx;

        ASSERT_TRUE(coll1->add(doc.dump()).ok());
    }

    // Every query below differs only in its filter expression.
    const auto search_by_filter = [&coll1](const std::string& filter_query) {
        return coll1->search("*", {}, filter_query, {}, {}, {0}, 10, 1, FREQUENCY);
    };

    // Message expected for both malformed-filter cases at the end of the test.
    const std::string expected_format_error =
            "Value of filter field `loc`: must be in the "
            "`([-44.50, 170.29], radius: 0.75 km, exact_filter_radius: 5 km)` or "
            "([56.33, -65.97, 23.82, -127.82]) format.";

    // A point that is near only the Sacre Coeur.
    auto results = search_by_filter("loc: ([48.90615915923891, 2.3435897727061175], radius: 3 km)").get();

    ASSERT_EQ(1, results["found"].get<size_t>());
    ASSERT_EQ(1, results["hits"].size());
    ASSERT_STREQ("1", results["hits"][0]["document"]["id"].get<std::string>().c_str());

    // Two geo filters combined in a single expression.
    results = search_by_filter(
            "loc: [([48.90615, 2.34358], radius: 1 km), ([48.8462, 2.34515], radius: 1 km)]").get();

    ASSERT_EQ(2, results["found"].get<size_t>());

    // A point that is near none of the landmarks.
    results = search_by_filter("loc: ([48.910544830985785, 2.337218333651177], radius: 2 km)").get();

    ASSERT_EQ(0, results["found"].get<size_t>());

    // Same point, but with a radius large enough to cover every landmark.
    results = search_by_filter("loc: ([48.910544830985785, 2.337218333651177], radius: 20 km)").get();

    ASSERT_EQ(10, results["found"].get<size_t>());

    // Radius given in miles.
    results = search_by_filter("loc: ([48.85825332869331, 2.303816427653377], radius: 1 mi)").get();

    ASSERT_EQ(3, results["found"].get<size_t>());

    const std::vector<std::string> expected_ids = {"6", "5", "3"};
    for (size_t hit = 0; hit < expected_ids.size(); ++hit) {
        ASSERT_STREQ(expected_ids[hit].c_str(),
                     results["hits"][hit]["document"]["id"].get<std::string>().c_str());
    }

    // NaN coordinates must be rejected.
    auto gop = search_by_filter("loc: ([NaN, nan], radius: 1 mi)");

    ASSERT_FALSE(gop.ok());
    ASSERT_EQ(expected_format_error, gop.error());

    // A point without the `radius` key must be rejected.
    gop = search_by_filter("loc: ([48.85825332869331, 2.303816427653377])");

    ASSERT_FALSE(gop.ok());
    ASSERT_EQ(expected_format_error, gop.error());

    collectionManager.drop_collection("coll1");
}
|
||||
|
||||
TEST_F(CollectionFilteringTest, GeoPointArrayFiltering) {
|
||||
Collection *coll1;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user