mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 14:12:27 +08:00
support random sort (#1918)
* support random sort * add empty sort_fields check * generate random score instead of shuffling * refactor sort_random_t * add seed value check
This commit is contained in:
parent
e62e7e8316
commit
3a24ba9f84
@ -407,6 +407,8 @@ namespace sort_field_const {
|
||||
|
||||
static const std::string vector_distance = "_vector_distance";
|
||||
static const std::string vector_query = "_vector_query";
|
||||
|
||||
static const std::string random_order = "_rand";
|
||||
}
|
||||
|
||||
namespace ref_include {
|
||||
@ -449,6 +451,29 @@ struct sort_vector_query_t {
|
||||
hnsw_index_t* vector_index;
|
||||
};
|
||||
|
||||
struct sort_random_t {
|
||||
bool is_enabled = false;
|
||||
mutable std::mt19937 rng;
|
||||
mutable std::uniform_int_distribution<uint32_t> distrib;
|
||||
|
||||
sort_random_t() : distrib(0, UINT32_MAX) {};
|
||||
|
||||
sort_random_t& operator=(const sort_random_t& other) {
|
||||
rng = other.rng;
|
||||
distrib = other.distrib;
|
||||
is_enabled = other.is_enabled;
|
||||
}
|
||||
|
||||
void initialize(uint32_t seed) {
|
||||
rng.seed(seed);
|
||||
is_enabled = true;
|
||||
}
|
||||
|
||||
uint32_t generate_random() const {
|
||||
return distrib(rng);
|
||||
}
|
||||
};
|
||||
|
||||
struct sort_by {
|
||||
enum missing_values_t {
|
||||
first,
|
||||
@ -483,6 +508,8 @@ struct sort_by {
|
||||
std::vector<std::string> nested_join_collection_names;
|
||||
sort_vector_query_t vector_query;
|
||||
|
||||
sort_random_t random_sort;
|
||||
|
||||
sort_by(const std::string & name, const std::string & order):
|
||||
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
|
||||
missing_values(normal) {
|
||||
@ -518,6 +545,7 @@ struct sort_by {
|
||||
reference_collection_name = other.reference_collection_name;
|
||||
nested_join_collection_names = other.nested_join_collection_names;
|
||||
vector_query = other.vector_query;
|
||||
random_sort = other.random_sort;
|
||||
}
|
||||
|
||||
sort_by& operator=(const sort_by& other) {
|
||||
|
@ -1465,7 +1465,21 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
}
|
||||
|
||||
sort_field_std.name = actual_field_name;
|
||||
} else if(actual_field_name == sort_field_const::random_order) {
|
||||
const std::string &random_sort_str = sort_field_std.name.substr(paran_start + 1,
|
||||
sort_field_std.name.size() -
|
||||
paran_start -2);
|
||||
|
||||
uint32_t seed = time(nullptr);
|
||||
if (!random_sort_str.empty()) {
|
||||
if(random_sort_str[0] == '-' || !StringUtils::is_uint32_t(random_sort_str)) {
|
||||
return Option<bool>(400, "Only positive integer seed value is allowed.");
|
||||
}
|
||||
|
||||
seed = static_cast<uint32_t>(std::stoul(random_sort_str));
|
||||
}
|
||||
sort_field_std.random_sort.initialize(seed);
|
||||
sort_field_std.name = actual_field_name;
|
||||
} else {
|
||||
if(field_it == search_schema.end()) {
|
||||
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
|
||||
@ -1592,7 +1606,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
|
||||
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
|
||||
sort_field_std.name != sort_field_const::seq_id && sort_field_std.name != sort_field_const::group_found && sort_field_std.name != sort_field_const::vector_distance &&
|
||||
sort_field_std.name != sort_field_const::vector_query) {
|
||||
sort_field_std.name != sort_field_const::vector_query && sort_field_std.name != sort_field_const::random_order) {
|
||||
const auto field_it = search_schema.find(sort_field_std.name);
|
||||
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
|
||||
std::string error = "Could not find a field named `" + sort_field_std.name +
|
||||
|
@ -4954,6 +4954,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
if (sort_fields.size() > 0) {
|
||||
auto reference_found = true;
|
||||
auto const& is_reference_sort = !sort_fields[0].reference_collection_name.empty();
|
||||
auto is_random_sort = sort_fields[0].random_sort.is_enabled;
|
||||
|
||||
// In case of reference sort_by, we need to get the sort score of the reference doc id.
|
||||
if (is_reference_sort) {
|
||||
@ -5047,7 +5048,9 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
// do nothing
|
||||
}
|
||||
} else {
|
||||
if (!is_reference_sort || reference_found) {
|
||||
if(is_random_sort) {
|
||||
scores[0] = sort_fields[0].random_sort.generate_random();
|
||||
} else if (!is_reference_sort || reference_found) {
|
||||
auto it = field_values[0]->find(is_reference_sort ? ref_seq_id : seq_id);
|
||||
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
|
||||
} else {
|
||||
|
@ -2699,4 +2699,93 @@ TEST_F(CollectionSortingTest, TestVectorQueryDistanceThresholdSorting) {
|
||||
ASSERT_EQ(0.07853113859891891, res["hits"][0]["vector_distance"].get<float>());
|
||||
ASSERT_EQ("Cell Phone", res["hits"][1]["document"]["product_name"]);
|
||||
ASSERT_EQ(0.08472149819135666, res["hits"][1]["vector_distance"].get<float>());
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, TestSortByRandomOrder) {
|
||||
auto schema_json = R"({
|
||||
"name": "digital_products",
|
||||
"fields":[
|
||||
{
|
||||
"name": "product_name","type": "string"
|
||||
}]
|
||||
})"_json;
|
||||
|
||||
|
||||
auto coll_op = collectionManager.create_collection(schema_json);
|
||||
ASSERT_TRUE(coll_op.ok());
|
||||
auto coll = coll_op.get();
|
||||
|
||||
std::vector<std::string> products = {"Samsung Smartphone", "Vivo SmartPhone", "Oneplus Smartphone", "Pixel Smartphone", "Moto Smartphone"};
|
||||
nlohmann::json doc;
|
||||
for (auto product: products) {
|
||||
doc["product_name"] = product;
|
||||
ASSERT_TRUE(coll->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
sort_fields = {
|
||||
sort_by("_rand(5)", "asc"),
|
||||
};
|
||||
|
||||
auto results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
|
||||
ASSERT_EQ("4", results["hits"][1]["document"]["id"]);
|
||||
ASSERT_EQ("0", results["hits"][2]["document"]["id"]);
|
||||
ASSERT_EQ("3", results["hits"][3]["document"]["id"]);
|
||||
ASSERT_EQ("2", results["hits"][4]["document"]["id"]);
|
||||
|
||||
|
||||
|
||||
sort_fields = {
|
||||
sort_by("_rand(8)", "asc"),
|
||||
};
|
||||
|
||||
results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"]);
|
||||
ASSERT_EQ("3", results["hits"][1]["document"]["id"]);
|
||||
ASSERT_EQ("4", results["hits"][2]["document"]["id"]);
|
||||
ASSERT_EQ("0", results["hits"][3]["document"]["id"]);
|
||||
ASSERT_EQ("2", results["hits"][4]["document"]["id"]);
|
||||
|
||||
|
||||
//without seed value it takes current time as seed
|
||||
sort_fields = {
|
||||
sort_by("_rand()", "asc"),
|
||||
};
|
||||
|
||||
results = coll->search("smartphone", {"product_name"}, "", {}, sort_fields, {0}).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
|
||||
//negative seed value is not allowed
|
||||
sort_fields = {
|
||||
sort_by("_rand(-1)", "asc"),
|
||||
};
|
||||
|
||||
auto results_op = coll->search("*", {}, "", {}, sort_fields, {0});
|
||||
ASSERT_EQ("Only positive integer seed value is allowed.", results_op.error());
|
||||
|
||||
sort_fields = {
|
||||
sort_by("_rand(sadkjkj)", "asc"),
|
||||
};
|
||||
|
||||
results_op = coll->search("*", {}, "", {}, sort_fields, {0});
|
||||
ASSERT_EQ("Only positive integer seed value is allowed.", results_op.error());
|
||||
|
||||
//typos
|
||||
sort_fields = {
|
||||
sort_by("rand()", "asc"),
|
||||
};
|
||||
|
||||
results_op = coll->search("*", {}, "", {}, sort_fields, {0});
|
||||
ASSERT_EQ("Could not find a field named `rand` in the schema for sorting.", results_op.error());
|
||||
|
||||
sort_fields = {
|
||||
sort_by("_random()", "asc"),
|
||||
};
|
||||
|
||||
results_op = coll->search("*", {}, "", {}, sort_fields, {0});
|
||||
ASSERT_EQ("Could not find a field named `_random` in the schema for sorting.", results_op.error());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user