diff --git a/include/string_utils.h b/include/string_utils.h index 62005f7e..1b12da8b 100644 --- a/include/string_utils.h +++ b/include/string_utils.h @@ -151,8 +151,18 @@ struct StringUtils { } char * p ; - strtoull(s.c_str(), &p, 10); - return (*p == 0); + unsigned long long ull = strtoull(s.c_str(), &p, 10); + return (*p == 0) && ull <= std::numeric_limits::max(); + } + + static bool is_uint32_t(const std::string &s) { + if(s.empty()) { + return false; + } + + char * p ; + unsigned long ul = strtoul(s.c_str(), &p, 10); + return (*p == 0) && ul <= std::numeric_limits::max(); } static void toupper(std::string& str) { diff --git a/src/collection.cpp b/src/collection.cpp index bdfd1801..2c12f690 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -394,6 +394,10 @@ Option Collection::search(const std::string & query, const std:: return Option(400, "No search fields specified for the query."); } + if(group_limit == 0 || group_limit >= 100) { + return Option(400, "Value of `group_limit` is invalid."); + } + std::vector excluded_ids; std::map> include_ids; // position => list of IDs populate_overrides(query, pinned_hits, hidden_hits, include_ids, excluded_ids); diff --git a/src/core_api.cpp b/src/core_api.cpp index 7d5294d3..06952a80 100644 --- a/src/core_api.cpp +++ b/src/core_api.cpp @@ -301,42 +301,42 @@ bool get_search(http_req & req, http_res & res) { } } - if(!StringUtils::is_uint64_t(req.params[DROP_TOKENS_THRESHOLD])) { + if(!StringUtils::is_uint32_t(req.params[DROP_TOKENS_THRESHOLD])) { res.set_400("Parameter `" + std::string(DROP_TOKENS_THRESHOLD) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[TYPO_TOKENS_THRESHOLD])) { + if(!StringUtils::is_uint32_t(req.params[TYPO_TOKENS_THRESHOLD])) { res.set_400("Parameter `" + std::string(TYPO_TOKENS_THRESHOLD) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[NUM_TYPOS])) { + if(!StringUtils::is_uint32_t(req.params[NUM_TYPOS])) { res.set_400("Parameter `" + std::string(NUM_TYPOS) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[PER_PAGE])) { + if(!StringUtils::is_uint32_t(req.params[PER_PAGE])) { res.set_400("Parameter `" + std::string(PER_PAGE) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[PAGE])) { + if(!StringUtils::is_uint32_t(req.params[PAGE])) { res.set_400("Parameter `" + std::string(PAGE) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[MAX_FACET_VALUES])) { + if(!StringUtils::is_uint32_t(req.params[MAX_FACET_VALUES])) { res.set_400("Parameter `" + std::string(MAX_FACET_VALUES) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[SNIPPET_THRESHOLD])) { + if(!StringUtils::is_uint32_t(req.params[SNIPPET_THRESHOLD])) { res.set_400("Parameter `" + std::string(SNIPPET_THRESHOLD) + "` must be an unsigned integer."); return false; } - if(!StringUtils::is_uint64_t(req.params[GROUP_LIMIT])) { + if(!StringUtils::is_uint32_t(req.params[GROUP_LIMIT])) { res.set_400("Parameter `" + std::string(GROUP_LIMIT) + "` must be an unsigned integer."); return false; } @@ -441,19 +441,19 @@ bool get_search(http_req & req, http_res & res) { Option result_op = collection->search(req.params[QUERY], search_fields, filter_str, facet_fields, sort_fields, std::stoi(req.params[NUM_TYPOS]), - static_cast(std::stoi(req.params[PER_PAGE])), - static_cast(std::stoi(req.params[PAGE])), + static_cast(std::stol(req.params[PER_PAGE])), + static_cast(std::stol(req.params[PAGE])), token_order, prefix, drop_tokens_threshold, include_fields, exclude_fields, - static_cast(std::stoi(req.params[MAX_FACET_VALUES])), + static_cast(std::stol(req.params[MAX_FACET_VALUES])), req.params[FACET_QUERY], - static_cast(std::stoi(req.params[SNIPPET_THRESHOLD])), + static_cast(std::stol(req.params[SNIPPET_THRESHOLD])), req.params[HIGHLIGHT_FULL_FIELDS], typo_tokens_threshold, pinned_hits, hidden_hits, group_by_fields, - static_cast(std::stoi(req.params[GROUP_LIMIT])) + static_cast(std::stol(req.params[GROUP_LIMIT])) ); uint64_t timeMillis = std::chrono::duration_cast( diff --git a/test/collection_grouping_test.cpp b/test/collection_grouping_test.cpp index 1c565bd5..63118a4e 100644 --- a/test/collection_grouping_test.cpp +++ b/test/collection_grouping_test.cpp @@ -222,6 +222,27 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) { ASSERT_EQ(1, (int) res["facet_counts"][0]["counts"][3]["count"]); ASSERT_STREQ("Zeta", res["facet_counts"][0]["counts"][3]["value"].get().c_str()); + + // respect min and max grouping limit (greater than 0 and less than 99) + auto res_op = coll_group->search("*", {}, "", {"brand"}, {}, 0, 50, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "brand: omeg", 30, + "", 10, + {}, {}, {"rating"}, 100); + + ASSERT_FALSE(res_op.ok()); + ASSERT_STREQ("Value of `group_limit` is invalid.", res_op.error().c_str()); + + res_op = coll_group->search("*", {}, "", {"brand"}, {}, 0, 50, 1, FREQUENCY, + false, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "brand: omeg", 30, + "", 10, + {}, {}, {"rating"}, 0); + + ASSERT_FALSE(res_op.ok()); + ASSERT_STREQ("Value of `group_limit` is invalid.", res_op.error().c_str()); } TEST_F(CollectionGroupingTest, GroupingWithGropLimitOfOne) { diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp index 4f8c2400..a2d64147 100644 --- a/test/string_utils_test.cpp +++ b/test/string_utils_test.cpp @@ -54,3 +54,8 @@ TEST(StringUtilsTest, HMAC) { std::string digest1 = StringUtils::hmac("KeyVal", "{\"filter_by\": \"user_id:1080\"}"); ASSERT_STREQ("IvjqWNZ5M5ElcvbMoXj45BxkQrZG4ZKEaNQoRioCx2s=", digest1.c_str()); } + +TEST(StringUtilsTest, UInt32Validation) { + std::string big_num = "99999999999999999999999999999999"; + ASSERT_FALSE(StringUtils::is_uint32_t(big_num)); +}