Add precision option to geo field sorting.

This commit is contained in:
Kishore Nallan 2021-07-27 19:57:56 +05:30
parent 13cb7b9364
commit 331db4f27e
4 changed files with 152 additions and 31 deletions

View File

@ -475,6 +475,7 @@ namespace sort_field_const {
static const std::string seq_id = "_seq_id";
static const std::string exclude_radius = "exclude_radius";
static const std::string precision = "precision";
}
struct sort_by {
@ -484,14 +485,16 @@ struct sort_by {
// geo related fields
int64_t geopoint;
uint32_t exclude_radius;
uint32_t geo_precision;
sort_by(const std::string & name, const std::string & order):
name(name), order(order), geopoint(0), exclude_radius(0) {
name(name), order(order), geopoint(0), exclude_radius(0), geo_precision(0) {
}
sort_by(const std::string &name, const std::string &order, int64_t geopoint, uint32_t exclude_radius) :
name(name), order(order), geopoint(geopoint), exclude_radius(exclude_radius) {
sort_by(const std::string &name, const std::string &order, int64_t geopoint,
uint32_t exclude_radius, uint32_t geo_precision) :
name(name), order(order), geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision) {
}
@ -500,6 +503,7 @@ struct sort_by {
order = other.order;
geopoint = other.geopoint;
exclude_radius = other.exclude_radius;
geo_precision = other.geo_precision;
return *this;
}
};

View File

@ -753,36 +753,54 @@ Option<nlohmann::json> Collection::search(const std::string & query, const std::
if(geo_parts.size() == 3) {
// try to parse the exclude radius option
if(!StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) {
return Option<nlohmann::json>(400, error);
}
bool is_exclude_option = false;
std::vector<std::string> exclude_parts;
StringUtils::split(geo_parts[2], exclude_parts, ":");
if(exclude_parts.size() != 2) {
return Option<nlohmann::json>(400, error);
}
std::vector<std::string> exclude_value_parts;
StringUtils::split(exclude_parts[1], exclude_value_parts, " ");
if(exclude_value_parts.size() != 2) {
return Option<nlohmann::json>(400, error);
}
if(!StringUtils::is_float(exclude_value_parts[0])) {
return Option<nlohmann::json>(400, error);
}
if(exclude_value_parts[1] == "km") {
sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1000;
} else if(exclude_value_parts[1] == "mi") {
sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1609.34;
if(StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) {
is_exclude_option = true;
} else if(StringUtils::begins_with(geo_parts[2], sort_field_const::precision)) {
is_exclude_option = false;
} else {
return Option<nlohmann::json>(400, "Sort field's exclude radius "
return Option<nlohmann::json>(400, error);
}
std::vector<std::string> param_parts;
StringUtils::split(geo_parts[2], param_parts, ":");
if(param_parts.size() != 2) {
return Option<nlohmann::json>(400, error);
}
std::vector<std::string> param_value_parts;
StringUtils::split(param_parts[1], param_value_parts, " ");
if(param_value_parts.size() != 2) {
return Option<nlohmann::json>(400, error);
}
if(!StringUtils::is_float(param_value_parts[0])) {
return Option<nlohmann::json>(400, error);
}
int32_t value_meters;
if(param_value_parts[1] == "km") {
value_meters = std::stof(param_value_parts[0]) * 1000;
} else if(param_value_parts[1] == "mi") {
value_meters = std::stof(param_value_parts[0]) * 1609.34;
} else {
return Option<nlohmann::json>(400, "Sort field's parameter "
"unit must be either `km` or `mi`.");
}
if(value_meters <= 0) {
return Option<nlohmann::json>(400, "Sort field's parameter must be a positive number.");
}
if(is_exclude_option) {
sort_field_std.exclude_radius = value_meters;
} else {
sort_field_std.geo_precision = value_meters;
}
}
double lat = std::stod(geo_parts[0]);

View File

@ -2145,6 +2145,11 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
dist = 0;
}
if(sort_fields[i].geo_precision > 0) {
dist = dist + sort_fields[i].geo_precision - 1 -
(dist + sort_fields[i].geo_precision - 1) % sort_fields[i].geo_precision;
}
geopoint_distances[i].emplace(seq_id, dist);
}

View File

@ -756,7 +756,101 @@ TEST_F(CollectionSortingTest, GeoPointSortingWithExcludeRadius) {
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Sort field's exclude radius unit must be either `km` or `mi`.", res_op.error());
ASSERT_EQ("Sort field's parameter unit must be either `km` or `mi`.", res_op.error());
geo_sort_fields = { sort_by("loc(32.24348, 77.1893, exclude_radius: -10 km)", "ASC") };
res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)",
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Sort field's parameter must be a positive number.", res_op.error());
collectionManager.drop_collection("coll1");
}
}
TEST_F(CollectionSortingTest, GeoPointSortingWithPrecision) {
Collection* coll1;
std::vector<field> fields = {field("title", field_types::STRING, false),
field("loc", field_types::GEOPOINT, false),
field("points", field_types::INT32, false),};
coll1 = collectionManager.get_collection("coll1").get();
if (coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
std::vector<std::vector<std::string>> records = {
{"Tibetan Colony", "32.24678, 77.19239"},
{"Civil Hospital", "32.23959, 77.18763"},
{"Johnson Lodge", "32.24751, 77.18814"},
{"Lion King Rock", "32.24493, 77.17038"},
{"Jai Durga Handloom", "32.25749, 77.17583"},
{"Panduropa", "32.26059, 77.21798"},
{"Police Station", "32.23743, 77.18639"},
{"Panduropa Post", "32.26263, 77.2196"},
};
for (size_t i = 0; i < records.size(); i++) {
nlohmann::json doc;
std::vector<std::string> lat_lng;
StringUtils::split(records[i][1], lat_lng, ", ");
double lat = std::stod(lat_lng[0]);
double lng = std::stod(lat_lng[1]);
doc["id"] = std::to_string(i);
doc["title"] = records[i][0];
doc["loc"] = {lat, lng};
doc["points"] = i;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
std::vector<sort_by> geo_sort_fields = {
sort_by("loc(32.24348, 77.1893, precision: 0.9 km)", "ASC"),
sort_by("points", "DESC"),
};
auto results = coll1->search("*",
{}, "loc: (32.24348, 77.1893, 20 km)",
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY).get();
ASSERT_EQ(8, results["found"].get<size_t>());
std::vector<std::string> expected_ids = {
"6", "2", "1", "0", "3", "4", "7", "5"
};
for (size_t i = 0; i < expected_ids.size(); i++) {
ASSERT_STREQ(expected_ids[i].c_str(), results["hits"][i]["document"]["id"].get<std::string>().c_str());
}
// badly formatted precision
geo_sort_fields = { sort_by("loc(32.24348, 77.1893, precision 1 km)", "ASC") };
auto res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)",
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Bad syntax for geopoint sorting field `loc`", res_op.error());
geo_sort_fields = { sort_by("loc(32.24348, 77.1893, precision: 1 meter)", "ASC") };
res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)",
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Sort field's parameter unit must be either `km` or `mi`.", res_op.error());
geo_sort_fields = { sort_by("loc(32.24348, 77.1893, precision: -10 km)", "ASC") };
res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)",
{}, geo_sort_fields, {0}, 10, 1, FREQUENCY);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Sort field's parameter must be a positive number.", res_op.error());
collectionManager.drop_collection("coll1");
}