diff --git a/include/field.h b/include/field.h
index 1ffe7a86..b98a0309 100644
--- a/include/field.h
+++ b/include/field.h
@@ -407,6 +407,8 @@ namespace sort_field_const {
 
     static const std::string vector_distance = "_vector_distance";
     static const std::string vector_query = "_vector_query";
+
+    static const std::string random_order = "_rand";
 }
 
 namespace ref_include {
@@ -449,6 +451,36 @@ struct sort_vector_query_t {
     hnsw_index_t* vector_index;
 };
 
+// State for `_rand(seed)` sorting: a seedable Mersenne Twister plus a distribution
+// spanning the whole uint32_t range. Both are mutable so that the const
+// generate_random() can advance them.
+struct sort_random_t {
+    bool is_enabled = false;
+    mutable std::mt19937 rng;
+    mutable std::uniform_int_distribution<uint32_t> distrib;
+
+    sort_random_t() : distrib(0, UINT32_MAX) {};
+
+    sort_random_t& operator=(const sort_random_t& other) {
+        rng = other.rng;
+        distrib = other.distrib;
+        is_enabled = other.is_enabled;
+        // Required: flowing off the end of a value-returning function is undefined behaviour.
+        return *this;
+    }
+
+    // Seeds the generator and switches random sorting on.
+    void initialize(uint32_t seed) {
+        rng.seed(seed);
+        is_enabled = true;
+    }
+
+    // Draws the next value of the seeded sequence.
+    uint32_t generate_random() const {
+        return distrib(rng);
+    }
+};
+
 struct sort_by {
     enum missing_values_t {
         first,
@@ -483,6 +515,8 @@ struct sort_by {
     std::vector<std::string> nested_join_collection_names;
     sort_vector_query_t vector_query;
 
+    sort_random_t random_sort;
+
     sort_by(const std::string & name, const std::string & order):
             name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
             missing_values(normal) {
@@ -518,6 +552,7 @@ struct sort_by {
         reference_collection_name = other.reference_collection_name;
         nested_join_collection_names = other.nested_join_collection_names;
         vector_query = other.vector_query;
+        random_sort = other.random_sort;
     }
 
     sort_by& operator=(const sort_by& other) {
diff --git a/src/collection.cpp b/src/collection.cpp
index 4a7a134f..570a9fb8 100644
--- a/src/collection.cpp
+++ b/src/collection.cpp
@@ -1465,6 +1465,22 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
                 }
 
                 sort_field_std.name = actual_field_name;
+            } else if(actual_field_name == sort_field_const::random_order) {
+                // Seed text is whatever follows `paran_start`, up to but excluding the final character of the name.
+                const std::string &random_sort_str = sort_field_std.name.substr(paran_start + 1,
+                                                                      sort_field_std.name.size() -
+                                                                      paran_start -2);
+                // Default seed is the current time; an explicit seed makes the order repeatable.
+                uint32_t seed = time(nullptr);
+                if (!random_sort_str.empty()) {
+                    if(random_sort_str[0] == '-' || !StringUtils::is_uint32_t(random_sort_str)) {
+                        return Option<bool>(400, "Only positive integer seed value is allowed.");
+                    }
+
+                    seed = static_cast<uint32_t>(std::stoul(random_sort_str));
+                }
+                sort_field_std.random_sort.initialize(seed);
+                sort_field_std.name = actual_field_name;
             } else {
                 if(field_it == search_schema.end()) {
                     std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
@@ -1592,7 +1608,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
         if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
             sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found &&
             sort_field_std.name != sort_field_const::vector_distance &&
-            sort_field_std.name != sort_field_const::vector_query) {
+            sort_field_std.name != sort_field_const::vector_query && sort_field_std.name != sort_field_const::random_order) {
             const auto field_it = search_schema.find(sort_field_std.name);
             if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
                 std::string error = "Could not find a field named `" + sort_field_std.name +
diff --git a/src/index.cpp b/src/index.cpp
index 6af2dbb6..aea8ab93 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -4954,6 +4954,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
     if (sort_fields.size() > 0) {
         auto reference_found = true;
         auto const& is_reference_sort = !sort_fields[0].reference_collection_name.empty();
+        auto is_random_sort = sort_fields[0].random_sort.is_enabled;
 
         // In case of reference sort_by, we need to get the sort score of the reference doc id.
         if (is_reference_sort) {
@@ -5047,7 +5048,9 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
                 // do nothing
             }
         } else {
-            if (!is_reference_sort || reference_found) {
+            if(is_random_sort) {
+                scores[0] = sort_fields[0].random_sort.generate_random();
+            } else if (!is_reference_sort || reference_found) {
                 auto it = field_values[0]->find(is_reference_sort ? ref_seq_id : seq_id);
                 scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
             } else {
diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp
index 9bf3e5d2..ef3fa5b9 100644
--- a/test/collection_sorting_test.cpp
+++ b/test/collection_sorting_test.cpp
@@ -2699,4 +2699,93 @@ TEST_F(CollectionSortingTest, TestVectorQueryDistanceThresholdSorting) {
     ASSERT_EQ(0.07853113859891891, res["hits"][0]["vector_distance"].get<float>());
     ASSERT_EQ("Cell Phone", res["hits"][1]["document"]["product_name"]);
     ASSERT_EQ(0.08472149819135666, res["hits"][1]["vector_distance"].get<float>());
+}
+
+TEST_F(CollectionSortingTest, TestSortByRandomOrder) {
+    auto schema_json = R"({
+        "name": "digital_products",
+        "fields":[
+            {
+                "name": "product_name","type": "string"
+            }]
+    })"_json;
+
+
+    auto coll_op = collectionManager.create_collection(schema_json);
+    ASSERT_TRUE(coll_op.ok());
+    auto coll = coll_op.get();
+
+    std::vector<std::string> products = {"Samsung Smartphone", "Vivo SmartPhone", "Oneplus Smartphone", "Pixel Smartphone", "Moto Smartphone"};
+    nlohmann::json doc;
+    for (auto product: products) {
+        doc["product_name"] = product;
+        ASSERT_TRUE(coll->add(doc.dump()).ok());
+    }
+
+    sort_fields = {
+            sort_by("_rand(5)", "asc"),
+    };
+
+    auto results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
+    ASSERT_EQ(5, results["hits"].size());
+    ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
+    ASSERT_EQ("4", results["hits"][1]["document"]["id"]);
+    ASSERT_EQ("0", results["hits"][2]["document"]["id"]);
+    ASSERT_EQ("3", results["hits"][3]["document"]["id"]);
+    ASSERT_EQ("2", results["hits"][4]["document"]["id"]);
+
+
+
+    sort_fields = {
+            sort_by("_rand(8)", "asc"),
+    };
+
+    results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
+
+    ASSERT_EQ(5, results["hits"].size());
+    ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
+    ASSERT_EQ("3", results["hits"][1]["document"]["id"]);
+    ASSERT_EQ("4", results["hits"][2]["document"]["id"]);
+    ASSERT_EQ("0", results["hits"][3]["document"]["id"]);
+    ASSERT_EQ("2", results["hits"][4]["document"]["id"]);
+
+
+    //without seed value it takes current time as seed
+    sort_fields = {
+            sort_by("_rand()", "asc"),
+    };
+
+    results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
+    ASSERT_EQ(5, results["hits"].size());
+
+
+    //negative seed value is not allowed
+    sort_fields = {
+            sort_by("_rand(-1)", "asc"),
+    };
+
+    auto results_op = coll->search("*", {}, "", {}, sort_fields, {0});
+    ASSERT_EQ("Only positive integer seed value is allowed.", results_op.error());
+
+    sort_fields = {
+            sort_by("_rand(sadkjkj)", "asc"),
+    };
+
+    results_op = coll->search("*", {}, "", {}, sort_fields, {0});
+    ASSERT_EQ("Only positive integer seed value is allowed.", results_op.error());
+
+    //typos
+    sort_fields = {
+            sort_by("rand()", "asc"),
+    };
+
+    results_op = coll->search("*", {}, "", {}, sort_fields, {0});
+    ASSERT_EQ("Could not find a field named `rand` in the schema for sorting.", results_op.error());
+
+    sort_fields = {
+            sort_by("_random()", "asc"),
+    };
+
+    results_op = coll->search("*", {}, "", {}, sort_fields, {0});
+    ASSERT_EQ("Could not find a field named `_random` in the schema for sorting.", results_op.error());
 }
\ No newline at end of file