diff --git a/include/field.h b/include/field.h index 59522e3a..fc1d2b30 100644 --- a/include/field.h +++ b/include/field.h @@ -476,28 +476,36 @@ namespace sort_field_const { static const std::string order = "order"; static const std::string asc = "ASC"; static const std::string desc = "DESC"; + static const std::string text_match = "_text_match"; static const std::string seq_id = "_seq_id"; + + static const std::string exclude_radius = "exclude_radius"; } struct sort_by { std::string name; std::string order; + + // geo related fields int64_t geopoint; + uint32_t exclude_radius; - sort_by(const std::string & name, const std::string & order): name(name), order(order), geopoint(0) { + sort_by(const std::string & name, const std::string & order): + name(name), order(order), geopoint(0), exclude_radius(0) { } - sort_by(const std::string &name, const std::string &order, int64_t geopoint) : - name(name), order(order), geopoint(geopoint) { + sort_by(const std::string &name, const std::string &order, int64_t geopoint, uint32_t exclude_radius) : + name(name), order(order), geopoint(geopoint), exclude_radius(exclude_radius) { } - sort_by& operator=(sort_by other) { + sort_by& operator=(const sort_by& other) { name = other.name; order = other.order; geopoint = other.geopoint; + exclude_radius = other.exclude_radius; return *this; } }; diff --git a/src/collection.cpp b/src/collection.cpp index 186353ef..a9760d92 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -719,23 +719,58 @@ Option Collection::search(const std::string & query, const std:: } const std::string& geo_coordstr = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2); - std::vector geo_coords; - StringUtils::split(geo_coordstr, geo_coords, ","); - if(geo_coords.size() != 2) { - std::string error = "Geopoint sorting field `" + actual_field_name + - "` must be in the `field(24.56,10.45):ASC` format."; - return Option(404, error); + // e.g. geopoint_field(lat1, lng1, exclude_radius: 10 miles) + + std::vector geo_parts; + StringUtils::split(geo_coordstr, geo_parts, ","); + + std::string error = "Bad syntax for geopoint sorting field `" + actual_field_name + "`"; + + if(geo_parts.size() != 2 && geo_parts.size() != 3) { + return Option(400, error); } - if(!StringUtils::is_float(geo_coords[0]) || !StringUtils::is_float(geo_coords[1])) { - std::string error = "Geopoint sorting field `" + actual_field_name + - "` must be in the `field(24.56,10.45):ASC` format."; - return Option(404, error); + if(!StringUtils::is_float(geo_parts[0]) || !StringUtils::is_float(geo_parts[1])) { + return Option(400, error); } - double lat = std::stod(geo_coords[0]); - double lng = std::stod(geo_coords[1]); + if(geo_parts.size() == 3) { + // try to parse the exclude radius option + if(!StringUtils::begins_with(geo_parts[2], sort_field_const::exclude_radius)) { + return Option(400, error); + } + + std::vector exclude_parts; + StringUtils::split(geo_parts[2], exclude_parts, ":"); + + if(exclude_parts.size() != 2) { + return Option(400, error); + } + + std::vector exclude_value_parts; + StringUtils::split(exclude_parts[1], exclude_value_parts, " "); + + if(exclude_value_parts.size() != 2) { + return Option(400, error); + } + + if(!StringUtils::is_float(exclude_value_parts[0])) { + return Option(400, error); + } + + if(exclude_value_parts[1] == "km") { + sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1000; + } else if(exclude_value_parts[1] == "mi") { + sort_field_std.exclude_radius = std::stof(exclude_value_parts[0]) * 1609.34; + } else { + return Option(400, "Sort field's exclude radius " + "unit must be either `km` or `mi`."); + } + } + + double lat = std::stod(geo_parts[0]); + double lng = std::stod(geo_parts[1]); int64_t lat_lng = GeoPoint::pack_lat_lng(lat, lng); sort_field_std.name = actual_field_name; sort_field_std.geopoint = lat_lng; diff --git a/src/index.cpp b/src/index.cpp index bfeed79c..d123614d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2069,6 +2069,10 @@ void Index::score_results(const std::vector & sort_fields, const uint16 dist = GeoPoint::distance(s2_lat_lng, reference_lat_lng); } + if(dist < sort_fields[i].exclude_radius) { + dist = 0; + } + geopoint_distances[i].emplace(seq_id, dist); } diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index d516dfae..40a71ba7 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -524,7 +524,7 @@ TEST_F(CollectionSortingTest, NegativeInt64Value) { collectionManager.drop_collection("coll1"); } -TEST_F(CollectionSortingTest, GeoPointFiltering) { +TEST_F(CollectionSortingTest, GeoPointSorting) { Collection *coll1; std::vector fields = {field("title", field_types::STRING, false), @@ -537,16 +537,16 @@ TEST_F(CollectionSortingTest, GeoPointFiltering) { } std::vector> records = { - {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, - {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, - {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, - {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, - {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, - {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, - {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, - {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, - {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, - {"Pantheon", "48.84620987789056, 2.345152755563131"}, + {"Palais Garnier", "48.872576479306765, 2.332291112241466"}, + {"Sacre Coeur", "48.888286721920934, 2.342340862419206"}, + {"Arc de Triomphe", "48.87538726829884, 2.296113163780903"}, + {"Place de la Concorde", "48.86536119187326, 2.321850747347093"}, + {"Louvre Musuem", "48.86065813197502, 2.3381285349616725"}, + {"Les Invalides", "48.856648379569904, 2.3118555692631357"}, + {"Eiffel Tower", "48.85821022164442, 2.294239067890161"}, + {"Notre-Dame de Paris", "48.852455825574495, 2.35071182406452"}, + {"Musee Grevin", "48.872370541246816, 2.3431536410008906"}, + {"Pantheon", "48.84620987789056, 2.345152755563131"}, }; for(size_t i=0; i fields = {field("title", field_types::STRING, false), + field("loc", field_types::GEOPOINT, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if (coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"Tibetan Colony", "32.24678, 77.19239"}, + {"Civil Hospital", "32.23959, 77.18763"}, + {"Johnson Lodge", "32.24751, 77.18814"}, + + {"Lion King Rock", "32.24493, 77.17038"}, + {"Jai Durga Handloom", "32.25749, 77.17583"}, + {"Panduropa", "32.26059, 77.21798"}, + }; + + for (size_t i = 0; i < records.size(); i++) { + nlohmann::json doc; + + std::vector lat_lng; + StringUtils::split(records[i][1], lat_lng, ", "); + + double lat = std::stod(lat_lng[0]); + double lng = std::stod(lat_lng[1]); + + doc["id"] = std::to_string(i); + doc["title"] = records[i][0]; + doc["loc"] = {lat, lng}; + doc["points"] = i; + + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + std::vector geo_sort_fields = { + sort_by("loc(32.24348, 77.1893, exclude_radius: 1 km)", "ASC"), + sort_by("points", "DESC"), + }; + + auto results = coll1->search("*", + {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(6, results["found"].get()); + + std::vector expected_ids = { + "2", "1", "0", "3", "4", "5" + }; + + for (size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_STREQ(expected_ids[i].c_str(), results["hits"][i]["document"]["id"].get().c_str()); + } + + // without exclusion filter + + geo_sort_fields = { + sort_by("loc(32.24348, 77.1893)", "ASC"), + sort_by("points", "DESC"), + }; + + results = coll1->search("*", + {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(6, results["found"].get()); + + expected_ids = { + "1", "2", "0", "3", "4", "5" + }; + + for (size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_STREQ(expected_ids[i].c_str(), results["hits"][i]["document"]["id"].get().c_str()); + } + + // badly formatted exclusion filter + + geo_sort_fields = { sort_by("loc(32.24348, 77.1893, exclude_radius 1 km)", "ASC") }; + auto res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Bad syntax for geopoint sorting field `loc`", res_op.error()); + + geo_sort_fields = { sort_by("loc(32.24348, 77.1893, exclude_radius: 1 meter)", "ASC") }; + res_op = coll1->search("*", {}, "loc: (32.24348, 77.1893, 20 km)", + {}, geo_sort_fields, {0}, 10, 1, FREQUENCY); + + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Sort field's exclude radius unit must be either `km` or `mi`.", res_op.error()); + collectionManager.drop_collection("coll1"); } \ No newline at end of file