support random sort (#1918)

* support random sort

* add empty sort_fields check

* generate random score instead of shuffling

* refactor sort_random_t

* add seed value check
This commit is contained in:
Krunal Gandhi 2024-08-29 12:50:33 +00:00 committed by GitHub
parent e62e7e8316
commit 3a24ba9f84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 2 deletions

View File

@ -407,6 +407,8 @@ namespace sort_field_const {
static const std::string vector_distance = "_vector_distance";
static const std::string vector_query = "_vector_query";
static const std::string random_order = "_rand";
}
namespace ref_include {
@ -449,6 +451,29 @@ struct sort_vector_query_t {
hnsw_index_t* vector_index;
};
struct sort_random_t {
bool is_enabled = false;
mutable std::mt19937 rng;
mutable std::uniform_int_distribution<uint32_t> distrib;
sort_random_t() : distrib(0, UINT32_MAX) {};
sort_random_t& operator=(const sort_random_t& other) {
rng = other.rng;
distrib = other.distrib;
is_enabled = other.is_enabled;
}
void initialize(uint32_t seed) {
rng.seed(seed);
is_enabled = true;
}
uint32_t generate_random() const {
return distrib(rng);
}
};
struct sort_by {
enum missing_values_t {
first,
@ -483,6 +508,8 @@ struct sort_by {
std::vector<std::string> nested_join_collection_names;
sort_vector_query_t vector_query;
sort_random_t random_sort;
sort_by(const std::string & name, const std::string & order):
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
missing_values(normal) {
@ -518,6 +545,7 @@ struct sort_by {
reference_collection_name = other.reference_collection_name;
nested_join_collection_names = other.nested_join_collection_names;
vector_query = other.vector_query;
random_sort = other.random_sort;
}
sort_by& operator=(const sort_by& other) {

View File

@ -1465,7 +1465,21 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
}
sort_field_std.name = actual_field_name;
} else if(actual_field_name == sort_field_const::random_order) {
const std::string &random_sort_str = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() -
paran_start -2);
uint32_t seed = time(nullptr);
if (!random_sort_str.empty()) {
if(random_sort_str[0] == '-' || !StringUtils::is_uint32_t(random_sort_str)) {
return Option<bool>(400, "Only positive integer seed value is allowed.");
}
seed = static_cast<uint32_t>(std::stoul(random_sort_str));
}
sort_field_std.random_sort.initialize(seed);
sort_field_std.name = actual_field_name;
} else {
if(field_it == search_schema.end()) {
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
@ -1592,7 +1606,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance &&
sort_field_std.name != sort_field_const::vector_query) {
sort_field_std.name != sort_field_const::vector_query && sort_field_std.name != sort_field_const::random_order) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
std::string error = "Could not find a field named `" + sort_field_std.name +

View File

@ -4954,6 +4954,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
if (sort_fields.size() > 0) {
auto reference_found = true;
auto const& is_reference_sort = !sort_fields[0].reference_collection_name.empty();
auto is_random_sort = sort_fields[0].random_sort.is_enabled;
// In case of reference sort_by, we need to get the sort score of the reference doc id.
if (is_reference_sort) {
@ -5047,7 +5048,9 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// do nothing
}
} else {
if (!is_reference_sort || reference_found) {
if(is_random_sort) {
scores[0] = sort_fields[0].random_sort.generate_random();
} else if (!is_reference_sort || reference_found) {
auto it = field_values[0]->find(is_reference_sort ? ref_seq_id : seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
} else {

View File

@ -2699,4 +2699,93 @@ TEST_F(CollectionSortingTest, TestVectorQueryDistanceThresholdSorting) {
ASSERT_EQ(0.07853113859891891, res["hits"][0]["vector_distance"].get<float>());
ASSERT_EQ("Cell Phone", res["hits"][1]["document"]["product_name"]);
ASSERT_EQ(0.08472149819135666, res["hits"][1]["vector_distance"].get<float>());
}
TEST_F(CollectionSortingTest, TestSortByRandomOrder) {
auto schema_json = R"({
"name": "digital_products",
"fields":[
{
"name": "product_name","type": "string"
}]
})"_json;
auto coll_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(coll_op.ok());
auto coll = coll_op.get();
std::vector<std::string> products = {"Samsung Smartphone", "Vivo SmartPhone", "Oneplus Smartphone", "Pixel Smartphone", "Moto Smartphone"};
nlohmann::json doc;
for (auto product: products) {
doc["product_name"] = product;
ASSERT_TRUE(coll->add(doc.dump()).ok());
}
sort_fields = {
sort_by("_rand(5)", "asc"),
};
auto results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
ASSERT_EQ(5, results["hits"].size());
ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
ASSERT_EQ("4", results["hits"][1]["document"]["id"]);
ASSERT_EQ("0", results["hits"][2]["document"]["id"]);
ASSERT_EQ("3", results["hits"][3]["document"]["id"]);
ASSERT_EQ("2", results["hits"][4]["document"]["id"]);
sort_fields = {
sort_by("_rand(8)", "asc"),
};
results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
ASSERT_EQ(5, results["hits"].size());
ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
ASSERT_EQ("3", results["hits"][1]["document"]["id"]);
ASSERT_EQ("4", results["hits"][2]["document"]["id"]);
ASSERT_EQ("0", results["hits"][3]["document"]["id"]);
ASSERT_EQ("2", results["hits"][4]["document"]["id"]);
//without seed value it takes current time as seed
sort_fields = {
sort_by("_rand()", "asc"),
};
results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
ASSERT_EQ(5, results["hits"].size());
//negative seed value is not allowed
sort_fields = {
sort_by("_rand(-1)", "asc"),
};
auto results_op = coll->search("*", {}, "", {}, sort_fields, {0});
ASSERT_EQ("Only positive integer seed value is allowed.", results_op.error());
sort_fields = {
sort_by("_rand(sadkjkj)", "asc"),
};
results_op = coll->search("*", {}, "", {}, sort_fields, {0});
ASSERT_EQ("Only positive integer seed value is allowed.", results_op.error());
//typos
sort_fields = {
sort_by("rand()", "asc"),
};
results_op = coll->search("*", {}, "", {}, sort_fields, {0});
ASSERT_EQ("Could not find a field named `rand` in the schema for sorting.", results_op.error());
sort_fields = {
sort_by("_random()", "asc"),
};
results_op = coll->search("*", {}, "", {}, sort_fields, {0});
ASSERT_EQ("Could not find a field named `_random` in the schema for sorting.", results_op.error());
}