From 331db4f27e7917dd025e3f89218f3d824b374788 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Tue, 27 Jul 2021 19:57:56 +0530 Subject: [PATCH] Add precision option to geo field sorting. --- include/field.h | 10 +++- src/collection.cpp | 70 ++++++++++++++--------- src/index.cpp | 5 ++ test/collection_sorting_test.cpp | 98 +++++++++++++++++++++++++++++++- 4 files changed, 152 insertions(+), 31 deletions(-) diff --git a/include/field.h b/include/field.h index e7200da8..03d82ec8 100644 --- a/include/field.h +++ b/include/field.h @@ -475,6 +475,7 @@ namespace sort_field_const { static const std::string seq_id = "_seq_id"; static const std::string exclude_radius = "exclude_radius"; + static const std::string precision = "precision"; } struct sort_by { @@ -484,14 +485,16 @@ struct sort_by { // geo related fields int64_t geopoint; uint32_t exclude_radius; + uint32_t geo_precision; sort_by(const std::string & name, const std::string & order): - name(name), order(order), geopoint(0), exclude_radius(0) { + name(name), order(order), geopoint(0), exclude_radius(0), geo_precision(0) { } - sort_by(const std::string &name, const std::string &order, int64_t geopoint, uint32_t exclude_radius) : - name(name), order(order), geopoint(geopoint), exclude_radius(exclude_radius) { + sort_by(const std::string &name, const std::string &order, int64_t geopoint, + uint32_t exclude_radius, uint32_t geo_precision) : + name(name), order(order), geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision) { } @@ -500,6 +503,7 @@ struct sort_by { order = other.order; geopoint = other.geopoint; exclude_radius = other.exclude_radius; + geo_precision = other.geo_precision; return *this; } }; diff --git a/src/collection.cpp b/src/collection.cpp index 3373ab30..44590a26 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -753,36 +753,54 @@ Option Collection::search(const std::string & query, const std:: if(geo_parts.size() == 3) { // try to parse the exclude radius option - if(!StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) { - return Option(400, error); - } + bool is_exclude_option = false; - std::vector exclude_parts; - StringUtils::split(geo_parts[2], exclude_parts, ":"); - - if(exclude_parts.size() != 2) { - return Option(400, error); - } - - std::vector exclude_value_parts; - StringUtils::split(exclude_parts[1], exclude_value_parts, " "); - - if(exclude_value_parts.size() != 2) { - return Option(400, error); - } - - if(!StringUtils::is_float(exclude_value_parts[0])) { - return Option(400, error); - } - - if(exclude_value_parts[1] == "km") { - sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1000; - } else if(exclude_value_parts[1] == "mi") { - sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1609.34; + if(StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) { + is_exclude_option = true; + } else if(StringUtils::begins_with(geo_parts[2], sort_field_const::precision)) { + is_exclude_option = false; } else { - return Option(400, "Sort field's exclude radius " + return Option(400, error); + } + + std::vector param_parts; + StringUtils::split(geo_parts[2], param_parts, ":"); + + if(param_parts.size() != 2) { + return Option(400, error); + } + + std::vector param_value_parts; + StringUtils::split(param_parts[1], param_value_parts, " "); + + if(param_value_parts.size() != 2) { + return Option(400, error); + } + + if(!StringUtils::is_float(param_value_parts[0])) { + return Option(400, error); + } + + int32_t value_meters; + + if(param_value_parts[1] == "km") { + value_meters = std::stof(param_value_parts[0]) * 1000; + } else if(param_value_parts[1] == "mi") { + value_meters = std::stof(param_value_parts[0]) * 1609.34; + } else { + return Option(400, "Sort field's parameter " "unit must be either `km` or `mi`."); } + + if(value_meters <= 0) { + return Option(400, "Sort field's parameter must be a positive number."); + } + + if(is_exclude_option) { + sort_field_std.exclude_radius = value_meters; + } else { + sort_field_std.geo_precision = value_meters; + } } double lat = std::stod(geo_parts[0]); diff --git a/src/index.cpp b/src/index.cpp index 6f6f4b7d..a36a7adc 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2145,6 +2145,11 @@ void Index::score_results(const std::vector & sort_fields, const uint16 dist = 0; } + if(sort_fields[i].geo_precision > 0) { + dist = dist + sort_fields[i].geo_precision - 1 - + (dist + sort_fields[i].geo_precision - 1) % sort_fields[i].geo_precision; + } + geopoint_distances[i].emplace(seq_id, dist); } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index d85686a5..5c2283b0 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -756,7 +756,101 @@ TEST_F(CollectionSortingTest, GeoPointSortingWithExcludeRadius) { {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); ASSERT_FALSE(res_op.ok()); - ASSERT_EQ("Sort field's exclude radius unit must be either `km` or `mi`.", res_op.error()); + ASSERT_EQ("Sort field's parameter unit must be either `km` or `mi`.", res_op.error()); + + geo_sort_fields = { sort_by("loc(32.24348, 77.1893, exclude_radius: -10 km)", "ASC") }; + res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Sort field's parameter must be a positive number.", res_op.error()); collectionManager.drop_collection("coll1"); -} \ No newline at end of file +} + +TEST_F(CollectionSortingTest, GeoPointSortingWithPrecision) { + Collection* coll1; + + std::vector fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if (coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Tibetan Colony", "32.24678, 77.19239"}, + {"Civil Hospital", "32.23959, 77.18763"}, + {"Johnson Lodge", "32.24751, 77.18814"}, + + {"Lion King Rock", "32.24493, 77.17038"}, + {"Jai Durga Handloom", "32.25749, 77.17583"}, + {"Panduropa", "32.26059, 77.21798"}, + + {"Police Station", "32.23743, 77.18639"}, + {"Panduropa Post", "32.26263, 77.2196"}, + }; + + for (size_t i = 0; i < records.size(); i++) { + nlohmann::json doc; + + std::vector lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector geo_sort_fields = { + sort_by("loc(32.24348, 77.1893, precision: 0.9 km)", "ASC"), + sort_by("points", "DESC"), + }; + + auto results = coll1->search("*", + {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(8, results["found"].get()); + + std::vector expected_ids = { + "6", "2", "1", "0", "3", "4", "7", "5" + }; + + for (size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_STREQ(expected_ids[i].c_str(), results["hits"][i]["document"]["id"].get().c_str()); + } + + // badly formatted precision + + geo_sort_fields = { sort_by("loc(32.24348, 77.1893, precision 1 km)", "ASC") }; + auto res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Bad syntax for geopoint sorting field `loc`", res_op.error()); + + geo_sort_fields = { sort_by("loc(32.24348, 77.1893, precision: 1 meter)", "ASC") }; + res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Sort field's parameter unit must be either `km` or `mi`.", res_op.error()); + + geo_sort_fields = { sort_by("loc(32.24348, 77.1893, precision: -10 km)", "ASC") }; + res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Sort field's parameter must be a positive number.", res_op.error()); + + collectionManager.drop_collection("coll1"); +}