diff --git a/include/collection.h b/include/collection.h index 66bade45..2246a033 100644 --- a/include/collection.h +++ b/include/collection.h @@ -373,7 +373,8 @@ public: const size_t group_limit = 0, const std::string& highlight_start_tag="", const std::string& highlight_end_tag="", - std::vector query_by_weights={}); + std::vector query_by_weights={}, + size_t limit_hits=UINT32_MAX); Option get_filter_ids(const std::string & simple_filter_query, std::vector>& index_ids); diff --git a/src/collection.cpp b/src/collection.cpp index b16a2649..41329ba7 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -512,7 +512,8 @@ Option Collection::search(const std::string & query, const std:: const size_t group_limit, const std::string& highlight_start_tag, const std::string& highlight_end_tag, - std::vector query_by_weights) { + std::vector query_by_weights, + size_t limit_hits) { if(query != "*" && search_fields.empty()) { return Option(400, "No search fields specified for the query."); @@ -743,6 +744,11 @@ Option Collection::search(const std::string & query, const std:: return Option(422, message); } + if((page * per_page) > limit_hits) { + std::string message = "Only upto " + std::to_string(limit_hits) + " hits can be fetched."; + return Option(422, message); + } + size_t max_hits; // ensure that `max_hits` never exceeds number of documents in collection diff --git a/src/core_api.cpp b/src/core_api.cpp index ae97670e..04209ef9 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -277,6 +277,7 @@ bool get_search(http_req & req, http_res & res) { const char *GROUP_BY = "group_by"; const char *GROUP_LIMIT = "group_limit"; + const char *LIMIT_HITS = "limit_hits"; const char *PER_PAGE = "per_page"; const char *PAGE = "page"; const char *CALLBACK = "callback"; @@ -328,6 +329,10 @@ bool get_search(http_req & req, http_res & res) { req.params[FACET_QUERY] = ""; } + if(req.params.count(LIMIT_HITS) == 0) { + req.params[LIMIT_HITS] = std::to_string(UINT32_MAX); + } + if(req.params.count(SNIPPET_THRESHOLD) == 0) { req.params[SNIPPET_THRESHOLD] = "30"; } @@ -427,6 +432,11 @@ bool get_search(http_req & req, http_res & res) { return false; } + if(!StringUtils::is_uint32_t(req.params[LIMIT_HITS])) { + res.set_400("Parameter `" + std::string(LIMIT_HITS) + "` must be an unsigned integer."); + return false; + } + if(!StringUtils::is_uint32_t(req.params[SNIPPET_THRESHOLD])) { res.set_400("Parameter `" + std::string(SNIPPET_THRESHOLD) + "` must be an unsigned integer."); return false; @@ -531,7 +541,8 @@ bool get_search(http_req & req, http_res & res) { static_cast(std::stol(req.params[GROUP_LIMIT])), req.params[HIGHLIGHT_START_TAG], req.params[HIGHLIGHT_END_TAG], - query_by_weights + query_by_weights, + static_cast(std::stol(req.params[LIMIT_HITS])) ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/test/collection_test.cpp b/test/collection_test.cpp index ff6040d0..c20f1d44 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -2566,6 +2566,16 @@ TEST_F(CollectionTest, WildcardQueryReturnsResultsBasedOnPerPageParam) { ASSERT_EQ(5, results["hits"].size()); ASSERT_EQ(25, results["found"].get()); + + // enforce limit_hits + res_op = collection->search("*", query_fields, "", facets, sort_fields, 0, 10, 3, + FREQUENCY, false, 1000, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 40, {}, {}, {}, 0, + "", "", {1}, 20); + + ASSERT_FALSE(res_op.ok()); + ASSERT_STREQ("Only upto 20 hits can be fetched.", res_op.error().c_str()); } TEST_F(CollectionTest, RemoveIfFound) {