mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 05:08:43 +08:00
Implement exclude_radius option for geopoint sorting.
This commit is contained in:
parent
da3de68129
commit
7c4aff5268
@ -476,28 +476,36 @@ namespace sort_field_const {
|
||||
static const std::string order = "order";
|
||||
static const std::string asc = "ASC";
|
||||
static const std::string desc = "DESC";
|
||||
|
||||
static const std::string text_match = "_text_match";
|
||||
static const std::string seq_id = "_seq_id";
|
||||
|
||||
static const std::string exclude_radius = "exclude_radius";
|
||||
}
|
||||
|
||||
struct sort_by {
|
||||
std::string name;
|
||||
std::string order;
|
||||
|
||||
// geo related fields
|
||||
int64_t geopoint;
|
||||
uint32_t exclude_radius;
|
||||
|
||||
sort_by(const std::string & name, const std::string & order): name(name), order(order), geopoint(0) {
|
||||
sort_by(const std::string & name, const std::string & order):
|
||||
name(name), order(order), geopoint(0), exclude_radius(0) {
|
||||
|
||||
}
|
||||
|
||||
sort_by(const std::string &name, const std::string &order, int64_t geopoint) :
|
||||
name(name), order(order), geopoint(geopoint) {
|
||||
sort_by(const std::string &name, const std::string &order, int64_t geopoint, uint32_t exclude_radius) :
|
||||
name(name), order(order), geopoint(geopoint), exclude_radius(exclude_radius) {
|
||||
|
||||
}
|
||||
|
||||
sort_by& operator=(sort_by other) {
|
||||
sort_by& operator=(const sort_by& other) {
|
||||
name = other.name;
|
||||
order = other.order;
|
||||
geopoint = other.geopoint;
|
||||
exclude_radius = other.exclude_radius;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
@ -719,23 +719,58 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
|
||||
}
|
||||
|
||||
const std::string& geo_coordstr = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2);
|
||||
std::vector<std::string> geo_coords;
|
||||
StringUtils::split(geo_coordstr, geo_coords, ",");
|
||||
|
||||
if(geo_coords.size() != 2) {
|
||||
std::string error = "Geopoint sorting field `" + actual_field_name +
|
||||
"` must be in the `field(24.56,10.45):ASC` format.";
|
||||
return Option<nlohmann::json>(404, error);
|
||||
// e.g. geopoint_field(lat1, lng1, exclude_radius: 10 miles)
|
||||
|
||||
std::vector<std::string> geo_parts;
|
||||
StringUtils::split(geo_coordstr, geo_parts, ",");
|
||||
|
||||
std::string error = "Bad syntax for geopoint sorting field `" + actual_field_name + "`";
|
||||
|
||||
if(geo_parts.size() != 2 && geo_parts.size() != 3) {
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
if(!StringUtils::is_float(geo_coords[0]) || !StringUtils::is_float(geo_coords[1])) {
|
||||
std::string error = "Geopoint sorting field `" + actual_field_name +
|
||||
"` must be in the `field(24.56,10.45):ASC` format.";
|
||||
return Option<nlohmann::json>(404, error);
|
||||
if(!StringUtils::is_float(geo_parts[0]) || !StringUtils::is_float(geo_parts[1])) {
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
double lat = std::stod(geo_coords[0]);
|
||||
double lng = std::stod(geo_coords[1]);
|
||||
if(geo_parts.size() == 3) {
|
||||
// try to parse the exclude radius option
|
||||
if(!StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) {
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
std::vector<std::string> exclude_parts;
|
||||
StringUtils::split(geo_parts[2], exclude_parts, ":");
|
||||
|
||||
if(exclude_parts.size() != 2) {
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
std::vector<std::string> exclude_value_parts;
|
||||
StringUtils::split(exclude_parts[1], exclude_value_parts, " ");
|
||||
|
||||
if(exclude_value_parts.size() != 2) {
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
if(!StringUtils::is_float(exclude_value_parts[0])) {
|
||||
return Option<nlohmann::json>(400, error);
|
||||
}
|
||||
|
||||
if(exclude_value_parts[1] == "km") {
|
||||
sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1000;
|
||||
} else if(exclude_value_parts[1] == "mi") {
|
||||
sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1609.34;
|
||||
} else {
|
||||
return Option<nlohmann::json>(400, "Sort field's exclude radius "
|
||||
"unit must be either `km` or `mi`.");
|
||||
}
|
||||
}
|
||||
|
||||
double lat = std::stod(geo_parts[0]);
|
||||
double lng = std::stod(geo_parts[1]);
|
||||
int64_t lat_lng = GeoPoint::pack_lat_lng(lat, lng);
|
||||
sort_field_std.name = actual_field_name;
|
||||
sort_field_std.geopoint = lat_lng;
|
||||
|
@ -2069,6 +2069,10 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
dist = GeoPoint::distance(s2_lat_lng, reference_lat_lng);
|
||||
}
|
||||
|
||||
if(dist < sort_fields[i].exclude_radius) {
|
||||
dist = 0;
|
||||
}
|
||||
|
||||
geopoint_distances[i].emplace(seq_id, dist);
|
||||
}
|
||||
|
||||
|
@ -524,7 +524,7 @@ TEST_F(CollectionSortingTest, NegativeInt64Value) {
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, GeoPointFiltering) {
|
||||
TEST_F(CollectionSortingTest, GeoPointSorting) {
|
||||
Collection *coll1;
|
||||
|
||||
std::vector<field> fields = {field("title", field_types::STRING, false),
|
||||
@ -537,16 +537,16 @@ TEST_F(CollectionSortingTest, GeoPointFiltering) {
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> records = {
|
||||
{"Palais Garnier", "48.872576479306765, 2.332291112241466"},
|
||||
{"Sacre Coeur", "48.888286721920934, 2.342340862419206"},
|
||||
{"Arc de Triomphe", "48.87538726829884, 2.296113163780903"},
|
||||
{"Place de la Concorde", "48.86536119187326, 2.321850747347093"},
|
||||
{"Louvre Musuem", "48.86065813197502, 2.3381285349616725"},
|
||||
{"Les Invalides", "48.856648379569904, 2.3118555692631357"},
|
||||
{"Eiffel Tower", "48.85821022164442, 2.294239067890161"},
|
||||
{"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"},
|
||||
{"Musee Grevin", "48.872370541246816, 2.3431536410008906"},
|
||||
{"Pantheon", "48.84620987789056, 2.345152755563131"},
|
||||
{"Palais Garnier", "48.872576479306765, 2.332291112241466"},
|
||||
{"Sacre Coeur", "48.888286721920934, 2.342340862419206"},
|
||||
{"Arc de Triomphe", "48.87538726829884, 2.296113163780903"},
|
||||
{"Place de la Concorde", "48.86536119187326, 2.321850747347093"},
|
||||
{"Louvre Musuem", "48.86065813197502, 2.3381285349616725"},
|
||||
{"Les Invalides", "48.856648379569904, 2.3118555692631357"},
|
||||
{"Eiffel Tower", "48.85821022164442, 2.294239067890161"},
|
||||
{"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"},
|
||||
{"Musee Grevin", "48.872370541246816, 2.3431536410008906"},
|
||||
{"Pantheon", "48.84620987789056, 2.345152755563131"},
|
||||
};
|
||||
|
||||
for(size_t i=0; i<records.size(); i++) {
|
||||
@ -610,7 +610,7 @@ TEST_F(CollectionSortingTest, GeoPointFiltering) {
|
||||
{}, bad_geo_sort_fields, {0}, 10, 1, FREQUENCY);
|
||||
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_STREQ("Geopoint sorting field `loc` must be in the `field(24.56,10.45):ASC` format.", res_op.error().c_str());
|
||||
ASSERT_STREQ("Bad syntax for geopoint sorting field `loc`", res_op.error().c_str());
|
||||
|
||||
bad_geo_sort_fields = {
|
||||
sort_by("loc(x, y)", "ASC")
|
||||
@ -621,7 +621,7 @@ TEST_F(CollectionSortingTest, GeoPointFiltering) {
|
||||
{}, bad_geo_sort_fields, {0}, 10, 1, FREQUENCY);
|
||||
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_STREQ("Geopoint sorting field `loc` must be in the `field(24.56,10.45):ASC` format.", res_op.error().c_str());
|
||||
ASSERT_STREQ("Bad syntax for geopoint sorting field `loc`", res_op.error().c_str());
|
||||
|
||||
bad_geo_sort_fields = {
|
||||
sort_by("loc(", "ASC")
|
||||
@ -656,5 +656,103 @@ TEST_F(CollectionSortingTest, GeoPointFiltering) {
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_STREQ("Could not find a field named `l` in the schema for sorting.", res_op.error().c_str());
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, GeoPointSortingWithExcludeRadius) {
|
||||
Collection* coll1;
|
||||
|
||||
std::vector<field> fields = {field("title", field_types::STRING, false),
|
||||
field("loc", field_types::GEOPOINT, false),
|
||||
field("points", field_types::INT32, false),};
|
||||
|
||||
coll1 = collectionManager.get_collection("coll1").get();
|
||||
if (coll1 == nullptr) {
|
||||
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> records = {
|
||||
{"Tibetan Colony", "32.24678, 77.19239"},
|
||||
{"Civil Hospital", "32.23959, 77.18763"},
|
||||
{"Johnson Lodge", "32.24751, 77.18814"},
|
||||
|
||||
{"Lion King Rock", "32.24493, 77.17038"},
|
||||
{"Jai Durga Handloom", "32.25749, 77.17583"},
|
||||
{"Panduropa", "32.26059, 77.21798"},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < records.size(); i++) {
|
||||
nlohmann::json doc;
|
||||
|
||||
std::vector<std::string> lat_lng;
|
||||
StringUtils::split(records[i][1], lat_lng, ", ");
|
||||
|
||||
double lat = std::stod(lat_lng[0]);
|
||||
double lng = std::stod(lat_lng[1]);
|
||||
|
||||
doc["id"] = std::to_string(i);
|
||||
doc["title"] = records[i][0];
|
||||
doc["loc"] = {lat, lng};
|
||||
doc["points"] = i;
|
||||
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
std::vector<sort_by> geo_sort_fields = {
|
||||
sort_by("loc(32.24348, 77.1893, exclude_radius: 1 km)", "ASC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
auto results = coll1->search("*",
|
||||
{}, "loc: (32.24348, 77.1893, 20 km)",
|
||||
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY).get();
|
||||
|
||||
ASSERT_EQ(6, results["found"].get<size_t>());
|
||||
|
||||
std::vector<std::string> expected_ids = {
|
||||
"2", "1", "0", "3", "4", "5"
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < expected_ids.size(); i++) {
|
||||
ASSERT_STREQ(expected_ids[i].c_str(), results["hits"][i]["document"]["id"].get<std::string>().c_str());
|
||||
}
|
||||
|
||||
// without exclusion filter
|
||||
|
||||
geo_sort_fields = {
|
||||
sort_by("loc(32.24348, 77.1893)", "ASC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
results = coll1->search("*",
|
||||
{}, "loc: (32.24348, 77.1893, 20 km)",
|
||||
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY).get();
|
||||
|
||||
ASSERT_EQ(6, results["found"].get<size_t>());
|
||||
|
||||
expected_ids = {
|
||||
"1", "2", "0", "3", "4", "5"
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < expected_ids.size(); i++) {
|
||||
ASSERT_STREQ(expected_ids[i].c_str(), results["hits"][i]["document"]["id"].get<std::string>().c_str());
|
||||
}
|
||||
|
||||
// badly formatted exclusion filter
|
||||
|
||||
geo_sort_fields = { sort_by("loc(32.24348, 77.1893, exclude_radius 1 km)", "ASC") };
|
||||
auto res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)",
|
||||
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
|
||||
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Bad syntax for geopoint sorting field `loc`", res_op.error());
|
||||
|
||||
geo_sort_fields = { sort_by("loc(32.24348, 77.1893, exclude_radius: 1 meter)", "ASC") };
|
||||
res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)",
|
||||
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
|
||||
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Sort field's exclude radius unit must be either `km` or `mi`.", res_op.error());
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user