sorting grouped results on group hit count

This commit is contained in:
krunal1313 2023-03-02 18:05:28 +05:30
parent 7b59484e2f
commit 77ffcd9444
5 changed files with 104 additions and 22 deletions

View File

@ -192,7 +192,7 @@ private:
Option<bool> validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
bool is_wildcard_query) const;
bool is_wildcard_query, bool is_group_by_query = false) const;
Option<bool> persist_collection_meta();
@ -351,7 +351,9 @@ public:
bool facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, const nlohmann::json &document,
std::string &value) const;
static void populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs);
static void populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs,
const spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
const std::vector<sort_by>& sort_by_fields);
void batch_index(std::vector<index_record>& index_records, std::vector<std::string>& json_out, size_t &num_indexed,
const bool& return_doc, const bool& return_id);

View File

@ -556,6 +556,7 @@ namespace sort_field_const {
static const std::string text_match = "_text_match";
static const std::string eval = "_eval";
static const std::string seq_id = "_seq_id";
static const std::string group_count = "_group_count";
static const std::string exclude_radius = "exclude_radius";
static const std::string precision = "precision";

View File

@ -644,7 +644,8 @@ void Collection::curate_results(string& actual_query, const string& filter_query
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std,
const bool is_wildcard_query) const {
const bool is_wildcard_query,
const bool is_group_by_query) const {
size_t num_sort_expressions = 0;
@ -819,14 +820,21 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
sort_field_std.name != sort_field_const::seq_id) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
std::string error = "Could not find a field named `" + sort_field_std.name +
"` in the schema for sorting.";
return Option<bool>(404, error);
if(!is_group_by_query) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
std::string error = "Could not find a field named `" + sort_field_std.name +
"` in the schema for sorting.";
return Option<bool>(404, error);
}
}
}
if(sort_field_std.name == sort_field_const::group_count && is_group_by_query == false) {
std::string error = " group_by parameters should not be empty when using sort_by group_count";
return Option<bool>(404, error);
}
StringUtils::toupper(sort_field_std.order);
if(sort_field_std.order != sort_field_const::asc && sort_field_std.order != sort_field_const::desc) {
@ -1292,9 +1300,11 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
std::vector<sort_by>& sort_fields_std = sort_fields_guard.sort_fields_std;
bool is_wildcard_query = (query == "*");
bool is_group_by_query = group_by_fields.size() > 0;
if(curated_sort_by.empty()) {
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std, is_wildcard_query);
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields,
sort_fields_std, is_wildcard_query, is_group_by_query);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
@ -1305,8 +1315,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
return Option<nlohmann::json>(400, "Parameter `sort_by` is malformed.");
}
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, sort_fields_std,
is_wildcard_query);
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields,
sort_fields_std, is_wildcard_query, is_group_by_query);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
@ -1398,8 +1408,8 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
topster.sort();
curated_topster.sort();
populate_result_kvs(&topster, raw_result_kvs);
populate_result_kvs(&curated_topster, override_result_kvs);
populate_result_kvs(&topster, raw_result_kvs, search_params->groups_processed, sort_fields_std);
populate_result_kvs(&curated_topster, override_result_kvs, search_params->groups_processed, sort_fields_std);
// for grouping we have to aggregate group set sizes to a count value
if(group_limit) {
@ -1731,8 +1741,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
if(group_limit) {
group_hits["group_key"] = group_key;
uint64_t distinct_id = index->get_distinct_id(group_by_fields, kv_group[0]->key);
const auto& itr = search_params->groups_processed.find(distinct_id);
const auto& itr = search_params->groups_processed.find(kv_group[0]->distinct_key);
if(itr != search_params->groups_processed.end()) {
group_hits["found"] = itr->second;
@ -2317,7 +2326,9 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s
}
}
void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs) {
void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<KV *>> &result_kvs,
const spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
const std::vector<sort_by>& sort_by_fields) {
if(topster->distinct) {
// we have to pick top-K groups
Topster gtopster(topster->MAX_SIZE);
@ -2340,6 +2351,25 @@ void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<K
);
result_kvs.emplace_back(group_kvs);
}
if(!sort_by_fields.empty() && sort_by_fields[0].name == sort_field_const::group_count) {
std::sort(result_kvs.begin(), result_kvs.end(),
[&](const std::vector<KV*>& g1, const std::vector<KV*>& g2) {
const auto& it1 = groups_processed.find(g1[0]->distinct_key);
const auto& it2 = groups_processed.find(g2[0]->distinct_key);
if(it1 != groups_processed.end() && it2 != groups_processed.end()) {
if(sort_by_fields[0].order == sort_field_const::asc) {
return it1->second < it2->second;
}
else {
return it1->second > it2->second;
}
}
return false;
});
}
} else {
for(uint32_t t = 0; t < topster->size; t++) {
KV* kv = topster->getKV(t);

View File

@ -3799,7 +3799,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
const int64_t default_score = INT64_MIN; // to handle field that doesn't exist in document (e.g. optional)
// avoiding loop
if (sort_fields.size() > 0) {
if (sort_fields.size() > 0 && sort_fields[0].name != sort_field_const::group_count) {
if (field_values[0] == &text_match_sentinel_value) {
scores[0] = int64_t(max_field_match_score);
match_score_index = 0;
@ -3855,7 +3855,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
}
if(sort_fields.size() > 1) {
if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) {
if (field_values[1] == &text_match_sentinel_value) {
scores[1] = int64_t(max_field_match_score);
match_score_index = 1;
@ -3907,7 +3907,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
}
if(sort_fields.size() > 2) {
if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) {
if (field_values[2] == &text_match_sentinel_value) {
scores[2] = int64_t(max_field_match_score);
match_score_index = 2;
@ -5038,7 +5038,8 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
// avoiding loop
if (sort_fields.size() > 0) {
if (field_values[0] == &text_match_sentinel_value) {
if (field_values[0] == &text_match_sentinel_value
&& sort_fields[0].name != sort_field_const::group_count) {
scores[0] = int64_t(match_score);
match_score_index = 0;
} else if (field_values[0] == &seq_id_sentinel_value) {
@ -5057,7 +5058,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
}
}
if(sort_fields.size() > 1) {
if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) {
if (field_values[1] == &text_match_sentinel_value) {
scores[1] = int64_t(match_score);
match_score_index = 1;
@ -5077,7 +5078,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
}
}
if(sort_fields.size() > 2) {
if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) {
if (field_values[2] == &text_match_sentinel_value) {
scores[2] = int64_t(match_score);
match_score_index = 2;

View File

@ -598,4 +598,52 @@ TEST_F(CollectionGroupingTest, RepeatedFieldNameGroupHitCount) {
ASSERT_EQ(1, res["grouped_hits"].size());
ASSERT_EQ(1, res["grouped_hits"][0]["found"].get<int32_t>());
}
TEST_F(CollectionGroupingTest, SortingOnGroupCount) {
std::vector<sort_by> sort_fields = {sort_by("_group_count", "DESC")};
auto res = coll_group->search("*", {}, "", {"brand"}, sort_fields, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10,
{}, {}, {"size"}, 2).get();
ASSERT_EQ(3, res["found"].get<size_t>());
ASSERT_EQ(3, res["grouped_hits"].size());
ASSERT_EQ(10, res["grouped_hits"][0]["group_key"][0].get<size_t>());
ASSERT_EQ(7, res["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_EQ(12, res["grouped_hits"][1]["group_key"][0].get<size_t>());
ASSERT_EQ(3, res["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_EQ(11, res["grouped_hits"][2]["group_key"][0].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][2]["found"].get<int32_t>());
//search in asc order
std::vector<sort_by> sort_fields2 = {sort_by("_group_count", "ASC")};
auto res2 = coll_group->search("*", {}, "", {"brand"}, sort_fields2, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10,
{}, {}, {"size"}, 2).get();
ASSERT_EQ(3, res2["found"].get<size_t>());
ASSERT_EQ(3, res2["grouped_hits"].size());
ASSERT_EQ(11, res2["grouped_hits"][0]["group_key"][0].get<size_t>());
ASSERT_EQ(2, res2["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_EQ(12, res2["grouped_hits"][1]["group_key"][0].get<size_t>());
ASSERT_EQ(3, res2["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_EQ(10, res2["grouped_hits"][2]["group_key"][0].get<size_t>());
ASSERT_EQ(7, res2["grouped_hits"][2]["found"].get<int32_t>());
}