From 77ffcd94447ca65df46eb06389ede59cbf15e251 Mon Sep 17 00:00:00 2001 From: krunal1313 Date: Thu, 2 Mar 2023 18:05:28 +0530 Subject: [PATCH] sorting grouped results on group hit count --- include/collection.h | 6 ++-- include/field.h | 1 + src/collection.cpp | 58 +++++++++++++++++++++++-------- src/index.cpp | 13 +++---- test/collection_grouping_test.cpp | 48 +++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 22 deletions(-) diff --git a/include/collection.h b/include/collection.h index 720631c0..7100ad9e 100644 --- a/include/collection.h +++ b/include/collection.h @@ -192,7 +192,7 @@ private: Option validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, - bool is_wildcard_query) const; + bool is_wildcard_query, bool is_group_by_query = false) const; Option persist_collection_meta(); @@ -351,7 +351,9 @@ public: bool facet_value_to_string(const facet &a_facet, const facet_count_t &facet_count, const nlohmann::json &document, std::string &value) const; - static void populate_result_kvs(Topster *topster, std::vector> &result_kvs); + static void populate_result_kvs(Topster *topster, std::vector> &result_kvs, + const spp::sparse_hash_map& groups_processed, + const std::vector& sort_by_fields); void batch_index(std::vector& index_records, std::vector& json_out, size_t &num_indexed, const bool& return_doc, const bool& return_id); diff --git a/include/field.h b/include/field.h index 872eb34a..1a21e50b 100644 --- a/include/field.h +++ b/include/field.h @@ -556,6 +556,7 @@ namespace sort_field_const { static const std::string text_match = "_text_match"; static const std::string eval = "_eval"; static const std::string seq_id = "_seq_id"; + static const std::string group_count = "_group_count"; static const std::string exclude_radius = "exclude_radius"; static const std::string precision = "precision"; diff --git a/src/collection.cpp b/src/collection.cpp index 7fb74f54..a6ab6caf 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -644,7 +644,8 @@ void Collection::curate_results(string& actual_query, const string& filter_query Option Collection::validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std, - const bool is_wildcard_query) const { + const bool is_wildcard_query, + const bool is_group_by_query) const { size_t num_sort_expressions = 0; @@ -819,14 +820,21 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval && sort_field_std.name != sort_field_const::seq_id) { - const auto field_it = search_schema.find(sort_field_std.name); - if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { - std::string error = "Could not find a field named `" + sort_field_std.name + - "` in the schema for sorting."; - return Option(404, error); + if(!is_group_by_query) { + const auto field_it = search_schema.find(sort_field_std.name); + if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) { + std::string error = "Could not find a field named `" + sort_field_std.name + + "` in the schema for sorting."; + return Option(404, error); + } } } + if(sort_field_std.name == sort_field_const::group_count && is_group_by_query == false) { + std::string error = " group_by parameters should not be empty when using sort_by group_count"; + return Option(404, error); + } + StringUtils::toupper(sort_field_std.order); if(sort_field_std.order != sort_field_const::asc && sort_field_std.order != sort_field_const::desc) { @@ -1292,9 +1300,11 @@ Option Collection::search(const std::string & raw_query, std::vector& sort_fields_std = sort_fields_guard.sort_fields_std; bool is_wildcard_query = (query == "*"); + bool is_group_by_query = group_by_fields.size() > 0; if(curated_sort_by.empty()) { - auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std, is_wildcard_query); + auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, + sort_fields_std, is_wildcard_query, is_group_by_query); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1305,8 +1315,8 @@ Option Collection::search(const std::string & raw_query, return Option(400, "Parameter `sort_by` is malformed."); } - auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, sort_fields_std, - is_wildcard_query); + auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, + sort_fields_std, is_wildcard_query, is_group_by_query); if(!sort_validation_op.ok()) { return Option(sort_validation_op.code(), sort_validation_op.error()); } @@ -1398,8 +1408,8 @@ Option Collection::search(const std::string & raw_query, topster.sort(); curated_topster.sort(); - populate_result_kvs(&topster, raw_result_kvs); - populate_result_kvs(&curated_topster, override_result_kvs); + populate_result_kvs(&topster, raw_result_kvs, search_params->groups_processed, sort_fields_std); + populate_result_kvs(&curated_topster, override_result_kvs, search_params->groups_processed, sort_fields_std); // for grouping we have to aggregate group set sizes to a count value if(group_limit) { @@ -1731,8 +1741,7 @@ Option Collection::search(const std::string & raw_query, if(group_limit) { group_hits["group_key"] = group_key; - uint64_t distinct_id = index->get_distinct_id(group_by_fields, kv_group[0]->key); - const auto& itr = search_params->groups_processed.find(distinct_id); + const auto& itr = search_params->groups_processed.find(kv_group[0]->distinct_key); if(itr != search_params->groups_processed.end()) { group_hits["found"] = itr->second; @@ -2317,7 +2326,9 @@ void Collection::parse_search_query(const std::string &query, std::vector> &result_kvs) { +void Collection::populate_result_kvs(Topster *topster, std::vector> &result_kvs, + const spp::sparse_hash_map& groups_processed, + const std::vector& sort_by_fields) { if(topster->distinct) { // we have to pick top-K groups Topster gtopster(topster->MAX_SIZE); @@ -2340,6 +2351,25 @@ void Collection::populate_result_kvs(Topster *topster, std::vector& g1, const std::vector& g2) { + const auto& it1 = groups_processed.find(g1[0]->distinct_key); + const auto& it2 = groups_processed.find(g2[0]->distinct_key); + + if(it1 != groups_processed.end() && it2 != groups_processed.end()) { + if(sort_by_fields[0].order == sort_field_const::asc) { + return it1->second < it2->second; + } + else { + return it1->second > it2->second; + } + } + return false; + }); + } + } else { for(uint32_t t = 0; t < topster->size; t++) { KV* kv = topster->getKV(t); diff --git a/src/index.cpp b/src/index.cpp index 79f51504..fd24b481 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3799,7 +3799,7 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i const int64_t default_score = INT64_MIN; // to handle field that doesn't exist in document (e.g. optional) // avoiding loop - if (sort_fields.size() > 0) { + if (sort_fields.size() > 0 && sort_fields[0].name != sort_field_const::group_count) { if (field_values[0] == &text_match_sentinel_value) { scores[0] = int64_t(max_field_match_score); match_score_index = 0; @@ -3855,7 +3855,7 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } } - if(sort_fields.size() > 1) { + if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) { if (field_values[1] == &text_match_sentinel_value) { scores[1] = int64_t(max_field_match_score); match_score_index = 1; @@ -3907,7 +3907,7 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i } } - if(sort_fields.size() > 2) { + if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) { if (field_values[2] == &text_match_sentinel_value) { scores[2] = int64_t(max_field_match_score); match_score_index = 2; @@ -5038,7 +5038,8 @@ void Index::score_results(const std::vector & sort_fields, const uint16 // avoiding loop if (sort_fields.size() > 0) { - if (field_values[0] == &text_match_sentinel_value) { + if (field_values[0] == &text_match_sentinel_value + && sort_fields[0].name != sort_field_const::group_count) { scores[0] = int64_t(match_score); match_score_index = 0; } else if (field_values[0] == &seq_id_sentinel_value) { @@ -5057,7 +5058,7 @@ void Index::score_results(const std::vector & sort_fields, const uint16 } } - if(sort_fields.size() > 1) { + if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) { if (field_values[1] == &text_match_sentinel_value) { scores[1] = int64_t(match_score); match_score_index = 1; @@ -5077,7 +5078,7 @@ void Index::score_results(const std::vector & sort_fields, const uint16 } } - if(sort_fields.size() > 2) { + if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) { if (field_values[2] == &text_match_sentinel_value) { scores[2] = int64_t(match_score); match_score_index = 2; diff --git a/test/collection_grouping_test.cpp b/test/collection_grouping_test.cpp index d2c2a7fd..efbb2feb 100644 --- a/test/collection_grouping_test.cpp +++ b/test/collection_grouping_test.cpp @@ -598,4 +598,52 @@ TEST_F(CollectionGroupingTest, RepeatedFieldNameGroupHitCount) { ASSERT_EQ(1, res["grouped_hits"].size()); ASSERT_EQ(1, res["grouped_hits"][0]["found"].get()); +} + +TEST_F(CollectionGroupingTest, SortingOnGroupCount) { + + std::vector sort_fields = {sort_by("_group_count", "DESC")}; + + auto res = coll_group->search("*", {}, "", {"brand"}, sort_fields, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + {}, {}, {"size"}, 2).get(); + + ASSERT_EQ(3, res["found"].get()); + ASSERT_EQ(3, res["grouped_hits"].size()); + + ASSERT_EQ(10, res["grouped_hits"][0]["group_key"][0].get()); + ASSERT_EQ(7, res["grouped_hits"][0]["found"].get()); + + ASSERT_EQ(12, res["grouped_hits"][1]["group_key"][0].get()); + ASSERT_EQ(3, res["grouped_hits"][1]["found"].get()); + + ASSERT_EQ(11, res["grouped_hits"][2]["group_key"][0].get()); + ASSERT_EQ(2, res["grouped_hits"][2]["found"].get()); + + + //search in asc order + + std::vector sort_fields2 = {sort_by("_group_count", "ASC")}; + + auto res2 = coll_group->search("*", {}, "", {"brand"}, sort_fields2, {0}, 50, 1, FREQUENCY, + {false}, Index::DROP_TOKENS_THRESHOLD, + spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 5, + "", 10, + {}, {}, {"size"}, 2).get(); + + ASSERT_EQ(3, res2["found"].get()); + ASSERT_EQ(3, res2["grouped_hits"].size()); + + ASSERT_EQ(11, res2["grouped_hits"][0]["group_key"][0].get()); + ASSERT_EQ(2, res2["grouped_hits"][0]["found"].get()); + + ASSERT_EQ(12, res2["grouped_hits"][1]["group_key"][0].get()); + ASSERT_EQ(3, res2["grouped_hits"][1]["found"].get()); + + ASSERT_EQ(10, res2["grouped_hits"][2]["group_key"][0].get()); + ASSERT_EQ(7, res2["grouped_hits"][2]["found"].get()); } \ No newline at end of file