diff --git a/include/collection.h b/include/collection.h
index 66bade45..2246a033 100644
--- a/include/collection.h
+++ b/include/collection.h
@@ -373,7 +373,8 @@ public:
const size_t group_limit = 0,
const std::string& highlight_start_tag="",
const std::string& highlight_end_tag="",
- std::vector query_by_weights={});
+ std::vector query_by_weights={},
+ size_t limit_hits=UINT32_MAX);
Option get_filter_ids(const std::string & simple_filter_query,
std::vector>& index_ids);
diff --git a/src/collection.cpp b/src/collection.cpp
index b16a2649..41329ba7 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -512,7 +512,8 @@ Option Collection::search(const std::string & query, const std::
const size_t group_limit,
const std::string& highlight_start_tag,
const std::string& highlight_end_tag,
- std::vector query_by_weights) {
+ std::vector query_by_weights,
+ size_t limit_hits) {
if(query != "*" && search_fields.empty()) {
return Option(400, "No search fields specified for the query.");
@@ -743,6 +744,11 @@ Option Collection::search(const std::string & query, const std::
return Option(422, message);
}
+ if((page * per_page) > limit_hits) {
+ std::string message = "Only upto " + std::to_string(limit_hits) + " hits can be fetched.";
+ return Option(422, message);
+ }
+
size_t max_hits;
// ensure that `max_hits` never exceeds number of documents in collection
diff --git a/src/core_api.cpp b/src/core_api.cpp
index ae97670e..04209ef9 100644
--- a/src/core_api.cpp
+++ b/src/core_api.cpp
@@ -277,6 +277,7 @@ bool get_search(http_req & req, http_res & res) {
const char *GROUP_BY = "group_by";
const char *GROUP_LIMIT = "group_limit";
+ const char *LIMIT_HITS = "limit_hits";
const char *PER_PAGE = "per_page";
const char *PAGE = "page";
const char *CALLBACK = "callback";
@@ -328,6 +329,10 @@ bool get_search(http_req & req, http_res & res) {
req.params[FACET_QUERY] = "";
}
+ if(req.params.count(LIMIT_HITS) == 0) {
+ req.params[LIMIT_HITS] = std::to_string(UINT32_MAX);
+ }
+
if(req.params.count(SNIPPET_THRESHOLD) == 0) {
req.params[SNIPPET_THRESHOLD] = "30";
}
@@ -427,6 +432,11 @@ bool get_search(http_req & req, http_res & res) {
return false;
}
+ if(!StringUtils::is_uint32_t(req.params[LIMIT_HITS])) {
+ res.set_400("Parameter `" + std::string(LIMIT_HITS) + "` must be an unsigned integer.");
+ return false;
+ }
+
if(!StringUtils::is_uint32_t(req.params[SNIPPET_THRESHOLD])) {
res.set_400("Parameter `" + std::string(SNIPPET_THRESHOLD) + "` must be an unsigned integer.");
return false;
@@ -531,7 +541,8 @@ bool get_search(http_req & req, http_res & res) {
static_cast(std::stol(req.params[GROUP_LIMIT])),
req.params[HIGHLIGHT_START_TAG],
req.params[HIGHLIGHT_END_TAG],
- query_by_weights
+ query_by_weights,
+ static_cast(std::stol(req.params[LIMIT_HITS]))
);
uint64_t timeMillis = std::chrono::duration_cast(
diff --git a/test/collection_test.cpp b/test/collection_test.cpp
index ff6040d0..c20f1d44 100644
--- a/test/collection_test.cpp
+++ b/test/collection_test.cpp
@@ -2566,6 +2566,16 @@ TEST_F(CollectionTest, WildcardQueryReturnsResultsBasedOnPerPageParam) {
ASSERT_EQ(5, results["hits"].size());
ASSERT_EQ(25, results["found"].get());
+
+ // enforce limit_hits
+ res_op = collection->search("*", query_fields, "", facets, sort_fields, 0, 10, 3,
+ FREQUENCY, false, 1000,
+ spp::sparse_hash_set(),
+ spp::sparse_hash_set(), 10, "", 30, 4, "", 40, {}, {}, {}, 0,
+ "", "", {1}, 20);
+
+ ASSERT_FALSE(res_op.ok());
+ ASSERT_STREQ("Only upto 20 hits can be fetched.", res_op.error().c_str());
}
TEST_F(CollectionTest, RemoveIfFound) {