mirror of
https://github.com/typesense/typesense.git
synced 2025-05-19 05:08:43 +08:00
sorting grouped results on group hit count
This commit is contained in:
parent
7b59484e2f
commit
77ffcd9444
@ -192,7 +192,7 @@ private:
|
||||
|
||||
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
bool is_wildcard_query) const;
|
||||
bool is_wildcard_query, bool is_group_by_query = false) const;
|
||||
|
||||
Option<bool> persist_collection_meta();
|
||||
|
||||
@ -351,7 +351,9 @@ public:
|
||||
bool facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, const nlohmann::json &document,
|
||||
std::string &value) const;
|
||||
|
||||
static void populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs);
|
||||
static void populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs,
|
||||
const spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
|
||||
const std::vector<sort_by>& sort_by_fields);
|
||||
|
||||
void batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out, size_t &num_indexed,
|
||||
const bool& return_doc, const bool& return_id);
|
||||
|
@ -556,6 +556,7 @@ namespace sort_field_const {
|
||||
static const std::string text_match = "_text_match";
|
||||
static const std::string eval = "_eval";
|
||||
static const std::string seq_id = "_seq_id";
|
||||
static const std::string group_count = "_group_count";
|
||||
|
||||
static const std::string exclude_radius = "exclude_radius";
|
||||
static const std::string precision = "precision";
|
||||
|
@ -644,7 +644,8 @@ void Collection::curate_results(string& actual_query, const string& filter_query
|
||||
|
||||
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
const bool is_wildcard_query) const {
|
||||
const bool is_wildcard_query,
|
||||
const bool is_group_by_query) const {
|
||||
|
||||
size_t num_sort_expressions = 0;
|
||||
|
||||
@ -819,14 +820,21 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
|
||||
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
|
||||
sort_field_std.name != sort_field_const::seq_id) {
|
||||
const auto field_it = search_schema.find(sort_field_std.name);
|
||||
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
|
||||
std::string error = "Could not find a field named `" + sort_field_std.name +
|
||||
"` in the schema for sorting.";
|
||||
return Option<bool>(404, error);
|
||||
if(!is_group_by_query) {
|
||||
const auto field_it = search_schema.find(sort_field_std.name);
|
||||
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
|
||||
std::string error = "Could not find a field named `" + sort_field_std.name +
|
||||
"` in the schema for sorting.";
|
||||
return Option<bool>(404, error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_field_std.name == sort_field_const::group_count && is_group_by_query == false) {
|
||||
std::string error = " group_by parameters should not be empty when using sort_by group_count";
|
||||
return Option<bool>(404, error);
|
||||
}
|
||||
|
||||
StringUtils::toupper(sort_field_std.order);
|
||||
|
||||
if(sort_field_std.order != sort_field_const::asc && sort_field_std.order != sort_field_const::desc) {
|
||||
@ -1292,9 +1300,11 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
std::vector<sort_by>& sort_fields_std = sort_fields_guard.sort_fields_std;
|
||||
|
||||
bool is_wildcard_query = (query == "*");
|
||||
bool is_group_by_query = group_by_fields.size() > 0;
|
||||
|
||||
if(curated_sort_by.empty()) {
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std, is_wildcard_query);
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
|
||||
sort_fields_std, is_wildcard_query, is_group_by_query);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
@ -1305,8 +1315,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
return Option<nlohmann::json>(400, "Parameter `sort_by` is malformed.");
|
||||
}
|
||||
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, sort_fields_std,
|
||||
is_wildcard_query);
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
|
||||
sort_fields_std, is_wildcard_query, is_group_by_query);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
@ -1398,8 +1408,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
topster.sort();
|
||||
curated_topster.sort();
|
||||
|
||||
populate_result_kvs(&topster, raw_result_kvs);
|
||||
populate_result_kvs(&curated_topster, override_result_kvs);
|
||||
populate_result_kvs(&topster, raw_result_kvs, search_params->groups_processed, sort_fields_std);
|
||||
populate_result_kvs(&curated_topster, override_result_kvs, search_params->groups_processed, sort_fields_std);
|
||||
|
||||
// for grouping we have to aggregate group set sizes to a count value
|
||||
if(group_limit) {
|
||||
@ -1731,8 +1741,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
|
||||
if(group_limit) {
|
||||
group_hits["group_key"] = group_key;
|
||||
|
||||
uint64_t distinct_id = index->get_distinct_id(group_by_fields, kv_group[0]->key);
|
||||
const auto& itr = search_params->groups_processed.find(distinct_id);
|
||||
const auto& itr = search_params->groups_processed.find(kv_group[0]->distinct_key);
|
||||
|
||||
if(itr != search_params->groups_processed.end()) {
|
||||
group_hits["found"] = itr->second;
|
||||
@ -2317,7 +2326,9 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s
|
||||
}
|
||||
}
|
||||
|
||||
void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs) {
|
||||
void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs,
|
||||
const spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
|
||||
const std::vector<sort_by>& sort_by_fields) {
|
||||
if(topster->distinct) {
|
||||
// we have to pick top-K groups
|
||||
Topster gtopster(topster->MAX_SIZE);
|
||||
@ -2340,6 +2351,25 @@ void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<K
|
||||
);
|
||||
result_kvs.emplace_back(group_kvs);
|
||||
}
|
||||
|
||||
if(!sort_by_fields.empty() && sort_by_fields[0].name == sort_field_const::group_count) {
|
||||
std::sort(result_kvs.begin(), result_kvs.end(),
|
||||
[&](const std::vector<KV*>& g1, const std::vector<KV*>& g2) {
|
||||
const auto& it1 = groups_processed.find(g1[0]->distinct_key);
|
||||
const auto& it2 = groups_processed.find(g2[0]->distinct_key);
|
||||
|
||||
if(it1 != groups_processed.end() && it2 != groups_processed.end()) {
|
||||
if(sort_by_fields[0].order == sort_field_const::asc) {
|
||||
return it1->second < it2->second;
|
||||
}
|
||||
else {
|
||||
return it1->second > it2->second;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
} else {
|
||||
for(uint32_t t = 0; t < topster->size; t++) {
|
||||
KV* kv = topster->getKV(t);
|
||||
|
@ -3799,7 +3799,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
const int64_t default_score = INT64_MIN; // to handle field that doesn't exist in document (e.g. optional)
|
||||
|
||||
// avoiding loop
|
||||
if (sort_fields.size() > 0) {
|
||||
if (sort_fields.size() > 0 && sort_fields[0].name != sort_field_const::group_count) {
|
||||
if (field_values[0] == &text_match_sentinel_value) {
|
||||
scores[0] = int64_t(max_field_match_score);
|
||||
match_score_index = 0;
|
||||
@ -3855,7 +3855,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_fields.size() > 1) {
|
||||
if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) {
|
||||
if (field_values[1] == &text_match_sentinel_value) {
|
||||
scores[1] = int64_t(max_field_match_score);
|
||||
match_score_index = 1;
|
||||
@ -3907,7 +3907,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_fields.size() > 2) {
|
||||
if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) {
|
||||
if (field_values[2] == &text_match_sentinel_value) {
|
||||
scores[2] = int64_t(max_field_match_score);
|
||||
match_score_index = 2;
|
||||
@ -5038,7 +5038,8 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
|
||||
// avoiding loop
|
||||
if (sort_fields.size() > 0) {
|
||||
if (field_values[0] == &text_match_sentinel_value) {
|
||||
if (field_values[0] == &text_match_sentinel_value
|
||||
&& sort_fields[0].name != sort_field_const::group_count) {
|
||||
scores[0] = int64_t(match_score);
|
||||
match_score_index = 0;
|
||||
} else if (field_values[0] == &seq_id_sentinel_value) {
|
||||
@ -5057,7 +5058,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_fields.size() > 1) {
|
||||
if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) {
|
||||
if (field_values[1] == &text_match_sentinel_value) {
|
||||
scores[1] = int64_t(match_score);
|
||||
match_score_index = 1;
|
||||
@ -5077,7 +5078,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_fields.size() > 2) {
|
||||
if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) {
|
||||
if (field_values[2] == &text_match_sentinel_value) {
|
||||
scores[2] = int64_t(match_score);
|
||||
match_score_index = 2;
|
||||
|
@ -598,4 +598,52 @@ TEST_F(CollectionGroupingTest, RepeatedFieldNameGroupHitCount) {
|
||||
|
||||
ASSERT_EQ(1, res["grouped_hits"].size());
|
||||
ASSERT_EQ(1, res["grouped_hits"][0]["found"].get<int32_t>());
|
||||
}
|
||||
|
||||
// Verifies that grouped hits can be ordered by the number of hits in each group
// when sorting on the special `_group_count` field, in both DESC and ASC order.
// The same wildcard search is issued twice; only the sort order differs.
TEST_F(CollectionGroupingTest, SortingOnGroupCount) {
|
||||
|
||||
std::vector<sort_by> sort_fields = {sort_by("_group_count", "DESC")};
|
||||
|
||||
// NOTE(review): presumably `{"size"}` is the group_by field list and the trailing `2` the group limit
// (the asserted group keys are numeric) — confirm against the Collection::search signature.
auto res = coll_group->search("*", {}, "", {"brand"}, sort_fields, {0}, 50, 1, FREQUENCY,
|
||||
{false}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
// three groups are expected in total
ASSERT_EQ(3, res["found"].get<size_t>());
|
||||
ASSERT_EQ(3, res["grouped_hits"].size());
|
||||
|
||||
// DESC: groups appear from the largest to the smallest per-group `found` value (7, 3, 2)
ASSERT_EQ(10, res["grouped_hits"][0]["group_key"][0].get<size_t>());
|
||||
ASSERT_EQ(7, res["grouped_hits"][0]["found"].get<int32_t>());
|
||||
|
||||
ASSERT_EQ(12, res["grouped_hits"][1]["group_key"][0].get<size_t>());
|
||||
ASSERT_EQ(3, res["grouped_hits"][1]["found"].get<int32_t>());
|
||||
|
||||
ASSERT_EQ(11, res["grouped_hits"][2]["group_key"][0].get<size_t>());
|
||||
ASSERT_EQ(2, res["grouped_hits"][2]["found"].get<int32_t>());
|
||||
|
||||
|
||||
// repeat the same search, this time sorting on the group count in ascending order
|
||||
|
||||
std::vector<sort_by> sort_fields2 = {sort_by("_group_count", "ASC")};
|
||||
|
||||
auto res2 = coll_group->search("*", {}, "", {"brand"}, sort_fields2, {0}, 50, 1, FREQUENCY,
|
||||
{false}, Index::DROP_TOKENS_THRESHOLD,
|
||||
spp::sparse_hash_set<std::string>(),
|
||||
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
|
||||
"", 10,
|
||||
{}, {}, {"size"}, 2).get();
|
||||
|
||||
// the set of groups is unchanged by the sort order
ASSERT_EQ(3, res2["found"].get<size_t>());
|
||||
ASSERT_EQ(3, res2["grouped_hits"].size());
|
||||
|
||||
// ASC: the same three groups, now from the smallest to the largest `found` value (2, 3, 7)
ASSERT_EQ(11, res2["grouped_hits"][0]["group_key"][0].get<size_t>());
|
||||
ASSERT_EQ(2, res2["grouped_hits"][0]["found"].get<int32_t>());
|
||||
|
||||
ASSERT_EQ(12, res2["grouped_hits"][1]["group_key"][0].get<size_t>());
|
||||
ASSERT_EQ(3, res2["grouped_hits"][1]["found"].get<int32_t>());
|
||||
|
||||
ASSERT_EQ(10, res2["grouped_hits"][2]["group_key"][0].get<size_t>());
|
||||
ASSERT_EQ(7, res2["grouped_hits"][2]["found"].get<int32_t>());
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user