Merge pull request #915 from krunal1313/group_hits_count

Group hits count
This commit is contained in:
Kishore Nallan 2023-03-01 18:51:01 +05:30 committed by GitHub
commit 543fc2ca7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 137 additions and 46 deletions

View File

@ -133,7 +133,7 @@ struct search_args {
const enable_t split_join_tokens;
tsl::htrie_map<char, token_leaf> qtoken_set;
spp::sparse_hash_set<uint64_t> groups_processed;
spp::sparse_hash_map<uint64_t, uint32_t> groups_processed;
std::vector<std::vector<art_leaf*>> searched_queries;
Topster* topster;
Topster* curated_topster;
@ -406,7 +406,7 @@ private:
int last_typo,
int max_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
Topster* topster, spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t** all_result_ids, size_t & all_result_ids_len,
size_t& field_num_results,
size_t group_limit,
@ -433,7 +433,7 @@ private:
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
Topster* topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
const size_t typo_tokens_threshold,
const size_t group_limit,
@ -459,7 +459,7 @@ private:
const std::vector<uint32_t>& curated_ids,
std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
Topster* topster, spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t** all_result_ids,
size_t & all_result_ids_len,
size_t& field_num_results,
@ -603,7 +603,7 @@ public:
void score_results(const std::vector<sort_by> &sort_fields, const uint16_t &query_index, const uint8_t &field_id,
bool field_is_array, const uint32_t total_cost,
Topster *topster, const std::vector<art_leaf *> &query_suggestion,
spp::sparse_hash_set<uint64_t> &groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
const uint32_t seq_id, const int sort_order[3],
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
@ -660,7 +660,7 @@ public:
const size_t per_page,
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
const size_t drop_tokens_threshold, size_t& all_result_ids_len,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<std::vector<KV*>>& override_result_kvs,
@ -729,7 +729,7 @@ public:
void search_wildcard(filter_node_t const* const& filter_tree_root,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
const std::vector<sort_by>& sort_fields, Topster* topster, Topster* curated_topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries, const size_t group_limit,
const std::vector<std::string>& group_by_fields, const std::set<uint32_t>& curated_ids,
const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
@ -781,7 +781,7 @@ public:
const std::vector<size_t>& geopoint_indices,
const std::vector<uint32_t>& curated_ids_sorted,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
spp::sparse_hash_set<uint64_t>& groups_processed) const;
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed) const;
void do_synonym_search(const std::vector<search_field_t>& the_fields,
const text_match_type_t match_type,
@ -801,7 +801,7 @@ public:
Topster* actual_topster,
std::vector<std::vector<token_t>>& q_pos_synonyms,
int syn_orig_num_tokens,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
const uint32_t* filter_ids, uint32_t filter_ids_length,
@ -827,7 +827,7 @@ public:
const std::vector<uint32_t>& num_typos,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
Topster* topster, spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,
@ -863,7 +863,7 @@ public:
const text_match_type_t match_type,
const std::vector<sort_by>& sort_fields,
Topster* topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
const size_t group_limit,
@ -889,7 +889,7 @@ public:
const size_t min_typo, const std::vector<uint32_t>& num_typos,
Topster* topster, Topster* curated_topster, const token_ordering& token_order,
const std::vector<bool>& prefixes, const size_t drop_tokens_threshold,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
const size_t typo_tokens_threshold, const size_t group_limit,
const std::vector<std::string>& group_by_fields, bool prioritize_exact_match,

View File

@ -99,6 +99,8 @@ struct Topster {
std::unordered_map<uint64_t, KV*> kv_map;
spp::sparse_hash_set<uint64_t> group_doc_seq_ids;
spp::sparse_hash_map<uint64_t, Topster*> group_kv_map;
size_t distinct;
@ -144,23 +146,33 @@ struct Topster {
(*b)->array_index = a_index;
}
bool add(KV* kv) {
int add(KV* kv) {
/*LOG(INFO) << "kv_map size: " << kv_map.size() << " -- kvs[0]: " << kvs[0]->scores[kvs[0]->match_score_index];
for(auto& mkv: kv_map) {
LOG(INFO) << "kv key: " << mkv.first << " => " << mkv.second->scores[mkv.second->match_score_index];
}*/
int ret = 1;
bool less_than_min_heap = (size >= MAX_SIZE) && is_smaller(kv, kvs[0]);
size_t heap_op_index = 0;
if(!distinct && less_than_min_heap) {
// for non-distinct, if incoming value is smaller than min-heap ignore
return false;
return 0;
}
bool SIFT_DOWN = true;
if(distinct) {
const auto& doc_seq_id_exists =
(group_doc_seq_ids.find(kv->key) != group_doc_seq_ids.end());
if(doc_seq_id_exists) {
ret = 2;
}
group_doc_seq_ids.emplace(kv->key);
// Grouping cannot be a streaming operation, so aggregate the KVs associated with every group.
auto kvs_it = group_kv_map.find(kv->distinct_key);
if(kvs_it != group_kv_map.end()) {
@ -171,7 +183,7 @@ struct Topster {
group_kv_map.insert({kv->distinct_key, g_topster});
}
return true;
return ret;
} else { // not distinct
//LOG(INFO) << "Searching for key: " << kv->key;
@ -193,7 +205,7 @@ struct Topster {
bool smaller_than_existing = is_smaller(kv, existing_kv);
if(smaller_than_existing) {
return false;
return 0;
}
SIFT_DOWN = true;
@ -256,7 +268,7 @@ struct Topster {
}
}
return true;
return ret;
}
static bool is_greater(const struct KV* i, const struct KV* j) {

View File

@ -1239,7 +1239,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
std::vector<std::vector<KV*>> override_result_kvs;
size_t total_found = 0;
spp::sparse_hash_set<uint64_t> groups_processed; // used to calculate total_found for grouped query
std::vector<uint32_t> excluded_ids;
std::vector<std::pair<uint32_t, uint32_t>> included_ids; // ID -> position
@ -1735,6 +1734,13 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query,
if(group_limit) {
group_hits["group_key"] = group_key;
uint64_t distinct_id = index->get_distinct_id(group_by_fields, kv_group[0]->key);
const auto& itr = search_params->groups_processed.find(distinct_id);
if(itr != search_params->groups_processed.end()) {
group_hits["found"] = itr->second;
}
result["grouped_hits"].push_back(group_hits);
}
}

View File

@ -1404,7 +1404,7 @@ void Index::search_all_candidates(const size_t num_search_fields,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
Topster* topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
const size_t typo_tokens_threshold,
const size_t group_limit,
@ -1481,7 +1481,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
std::vector<token_candidates> & token_candidates_vec,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t** all_result_ids, size_t & all_result_ids_len,
size_t& field_num_results,
const size_t typo_tokens_threshold,
@ -2334,7 +2334,7 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
std::vector<facet> facets;
std::vector<std::vector<art_leaf*>> searched_queries;
Topster* topster = nullptr;
spp::sparse_hash_set<uint64_t> groups_processed;
spp::sparse_hash_map<uint64_t, uint32_t> groups_processed;
uint32_t* result_ids = nullptr;
size_t result_ids_len = 0;
size_t field_num_results = 0;
@ -2491,7 +2491,7 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
const size_t per_page,
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
const size_t drop_tokens_threshold, size_t& all_result_ids_len,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<std::vector<KV*>>& override_result_kvs,
@ -2578,7 +2578,6 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = get_distinct_id(group_by_fields, seq_id);
groups_processed.emplace(distinct_id);
}
int64_t scores[3] = {0};
@ -2587,7 +2586,11 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
result_ids.push_back(seq_id);
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
topster->add(&kv);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
groups_processed[distinct_id]++;
}
if (result_ids.size() == page * per_page) {
break;
@ -2670,7 +2673,6 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = get_distinct_id(group_by_fields, seq_id);
groups_processed.emplace(distinct_id);
}
auto vec_dist_score = (field_vector_index->distance_type == cosine) ? std::abs(dist_label.first) :
@ -2683,7 +2685,11 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
//LOG(INFO) << "SEQ_ID: " << seq_id << ", score: " << dist_label.first;
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
topster->add(&kv);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
groups_processed[distinct_id]++;
}
nearest_ids.push_back(seq_id);
}
@ -3099,7 +3105,7 @@ void Index::fuzzy_search_fields(const std::vector<search_field_t>& the_fields,
const std::vector<uint32_t>& num_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
Topster* topster, spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t*& all_result_ids, size_t & all_result_ids_len,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,
@ -3495,7 +3501,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
const text_match_type_t match_type,
const std::vector<sort_by>& sort_fields,
Topster* topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
tsl::htrie_map<char, token_leaf>& qtoken_set,
const size_t group_limit,
@ -3644,7 +3650,6 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = get_distinct_id(group_by_fields, seq_id);
groups_processed.emplace(distinct_id);
}
int64_t scores[3] = {0};
@ -3698,7 +3703,11 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
if(match_score_index != -1) {
kv.scores[match_score_index] = aggregated_score;
}
topster->add(&kv);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
groups_processed[distinct_id]++;
}
result_ids.push_back(seq_id);
});
@ -4069,7 +4078,7 @@ void Index::do_synonym_search(const std::vector<search_field_t>& the_fields,
Topster* actual_topster,
std::vector<std::vector<token_t>>& q_pos_synonyms,
int syn_orig_num_tokens,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
const uint32_t* filter_ids, const uint32_t filter_ids_length,
@ -4106,7 +4115,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
const std::vector<size_t>& geopoint_indices,
const std::vector<uint32_t>& curated_ids_sorted,
uint32_t*& all_result_ids, size_t& all_result_ids_len,
spp::sparse_hash_set<uint64_t>& groups_processed) const {
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed) const {
for(size_t field_id = 0; field_id < num_search_fields; field_id++) {
auto& field_name = the_fields[field_id].name;
@ -4163,11 +4172,13 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = get_distinct_id(group_by_fields, seq_id);
groups_processed.emplace(distinct_id);
}
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
actual_topster->add(&kv);
int ret = actual_topster->add(&kv);
if(group_limit != 0 && ret < 2) {
groups_processed[distinct_id]++;
}
if(((i + 1) % (1 << 12)) == 0) {
BREAK_CIRCUIT_BREAKER
@ -4301,7 +4312,7 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
std::vector<std::vector<art_leaf*>> searched_queries;
Topster* topster = nullptr;
spp::sparse_hash_set<uint64_t> groups_processed;
spp::sparse_hash_map<uint64_t, uint32_t> groups_processed;
uint32_t* field_result_ids = nullptr;
size_t field_result_ids_len = 0;
size_t field_num_results = 0;
@ -4419,7 +4430,7 @@ void Index::curate_filtered_ids(filter_node_t const* const& filter_tree_root, co
void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
const std::map<size_t, std::map<size_t, uint32_t>>& included_ids_map,
const std::vector<sort_by>& sort_fields, Topster* topster, Topster* curated_topster,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
std::vector<std::vector<art_leaf*>>& searched_queries, const size_t group_limit,
const std::vector<std::string>& group_by_fields, const std::set<uint32_t>& curated_ids,
const std::vector<uint32_t>& curated_ids_sorted, const uint32_t* exclude_token_ids,
@ -4439,7 +4450,7 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
const size_t window_size = (num_threads == 0) ? 0 :
(filter_ids_length + num_threads - 1) / num_threads; // rounds up
spp::sparse_hash_set<uint64_t> tgroups_processed[num_threads];
spp::sparse_hash_map<uint64_t, uint64_t> tgroups_processed[num_threads];
Topster* topsters[num_threads];
std::vector<posting_list_t::iterator_t> plists;
@ -4498,11 +4509,14 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = get_distinct_id(group_by_fields, seq_id);
tgroups_processed[thread_id].emplace(distinct_id);
}
KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores);
topsters[thread_id]->add(&kv);
int ret = topsters[thread_id]->add(&kv);
if(group_limit != 0 && ret < 2) {
tgroups_processed[thread_id][distinct_id]++;
}
if(check_for_circuit_break && ((i + 1) % (1 << 15)) == 0) {
// check only once every 2^15 docs to reduce overhead
@ -4525,7 +4539,10 @@ void Index::search_wildcard(filter_node_t const* const& filter_tree_root,
search_cutoff = parent_search_cutoff;
for(size_t thread_id = 0; thread_id < num_processed; thread_id++) {
groups_processed.insert(tgroups_processed[thread_id].begin(), tgroups_processed[thread_id].end());
//groups_processed.insert(tgroups_processed[thread_id].begin(), tgroups_processed[thread_id].end());
for(const auto& it : tgroups_processed[thread_id]) {
groups_processed[it.first]+= it.second;
}
aggregate_topster(topster, topsters[thread_id]);
delete topsters[thread_id];
}
@ -4586,7 +4603,7 @@ void Index::search_field(const uint8_t & field_id,
const int last_typo,
const int max_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
Topster* topster, spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
uint32_t** all_result_ids, size_t & all_result_ids_len, size_t& field_num_results,
const size_t group_limit, const std::vector<std::string>& group_by_fields,
bool prioritize_exact_match,
@ -4886,7 +4903,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
const uint8_t & field_id, const bool field_is_array, const uint32_t total_cost,
Topster* topster,
const std::vector<art_leaf *> &query_suggestion,
spp::sparse_hash_set<uint64_t>& groups_processed,
spp::sparse_hash_map<uint64_t, uint32_t>& groups_processed,
const uint32_t seq_id, const int sort_order[3],
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
@ -5084,12 +5101,14 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
if(group_limit != 0) {
distinct_id = get_distinct_id(group_by_fields, seq_id);
groups_processed.emplace(distinct_id);
}
//LOG(INFO) << "Seq id: " << seq_id << ", match_score: " << match_score;
KV kv(query_index, seq_id, distinct_id, match_score_index, scores);
topster->add(&kv);
int ret = topster->add(&kv);
if(group_limit != 0 && ret < 2) {
groups_processed[distinct_id]++;
}
//long long int timeNanos = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - begin).count();
//LOG(INFO) << "Time taken for results iteration: " << timeNanos << "ms";

View File

@ -73,17 +73,20 @@ TEST_F(CollectionGroupingTest, GroupingBasics) {
ASSERT_EQ(3, res["grouped_hits"].size());
ASSERT_EQ(11, res["grouped_hits"][0]["group_key"][0].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][0]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ(11, res["grouped_hits"][0]["hits"][0]["document"]["size"].get<size_t>());
ASSERT_STREQ("5", res["grouped_hits"][0]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.3, res["grouped_hits"][0]["hits"][1]["document"]["rating"].get<float>());
ASSERT_STREQ("1", res["grouped_hits"][0]["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(7, res["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][1]["hits"][0]["document"]["rating"].get<float>());
ASSERT_STREQ("4", res["grouped_hits"][1]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.6, res["grouped_hits"][1]["hits"][1]["document"]["rating"].get<float>());
ASSERT_STREQ("3", res["grouped_hits"][1]["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(3, res["grouped_hits"][2]["found"].get<int32_t>());
ASSERT_FLOAT_EQ(4.6, res["grouped_hits"][2]["hits"][0]["document"]["rating"].get<float>());
ASSERT_STREQ("2", res["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][2]["hits"][1]["document"]["rating"].get<float>());
@ -117,22 +120,26 @@ TEST_F(CollectionGroupingTest, GroupingBasics) {
ASSERT_EQ(7, res["grouped_hits"].size());
ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][0]["group_key"][0].get<float>());
ASSERT_EQ(1, res["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_EQ(12, res["grouped_hits"][0]["hits"][0]["document"]["size"].get<uint32_t>());
ASSERT_STREQ("8", res["grouped_hits"][0]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][0]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ(4, res["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_EQ(12, res["grouped_hits"][1]["hits"][0]["document"]["size"].get<uint32_t>());
ASSERT_STREQ("6", res["grouped_hits"][1]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.3, res["grouped_hits"][1]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ(11, res["grouped_hits"][1]["hits"][1]["document"]["size"].get<uint32_t>());
ASSERT_STREQ("1", res["grouped_hits"][1]["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.3, res["grouped_hits"][1]["hits"][1]["document"]["rating"].get<float>());
ASSERT_EQ(1, res["grouped_hits"][5]["found"].get<int32_t>());
ASSERT_EQ(10, res["grouped_hits"][5]["hits"][0]["document"]["size"].get<uint32_t>());
ASSERT_STREQ("9", res["grouped_hits"][5]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.1, res["grouped_hits"][5]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ(1, res["grouped_hits"][6]["found"].get<int32_t>());
ASSERT_EQ(10, res["grouped_hits"][6]["hits"][0]["document"]["size"].get<uint32_t>());
ASSERT_STREQ("0", res["grouped_hits"][6]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_FLOAT_EQ(4.5, res["grouped_hits"][6]["hits"][0]["document"]["rating"].get<float>());
@ -164,6 +171,7 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) {
ASSERT_EQ(10, res["found"].get<size_t>());
ASSERT_EQ(10, res["grouped_hits"].size());
ASSERT_EQ(1, res["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_EQ(11, res["grouped_hits"][0]["group_key"][0].get<size_t>());
ASSERT_STREQ("Beta", res["grouped_hits"][0]["group_key"][1].get<std::string>().c_str());
@ -176,10 +184,12 @@ TEST_F(CollectionGroupingTest, GroupingCompoundKey) {
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][0]["hits"][0]["document"]["rating"].get<float>());
ASSERT_STREQ("5", res["grouped_hits"][0]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(1, res["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_EQ(1, res["grouped_hits"][1]["hits"].size());
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][1]["hits"][0]["document"]["rating"].get<float>());
ASSERT_STREQ("4", res["grouped_hits"][1]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(2, res["grouped_hits"][2]["found"].get<int32_t>());
ASSERT_EQ(2, res["grouped_hits"][2]["hits"].size());
ASSERT_FLOAT_EQ(4.6, res["grouped_hits"][2]["hits"][0]["document"]["rating"].get<float>());
ASSERT_STREQ("3", res["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
@ -306,16 +316,19 @@ TEST_F(CollectionGroupingTest, GroupingWithMultiFieldRelevance) {
ASSERT_EQ(3, results["found"].get<size_t>());
ASSERT_EQ(3, results["grouped_hits"].size());
ASSERT_EQ(3, results["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_STREQ("pop", results["grouped_hits"][0]["group_key"][0].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][0]["hits"].size());
ASSERT_STREQ("1", results["grouped_hits"][0]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("4", results["grouped_hits"][0]["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_STREQ("rock", results["grouped_hits"][1]["group_key"][0].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][1]["hits"].size());
ASSERT_STREQ("5", results["grouped_hits"][1]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_STREQ("0", results["grouped_hits"][1]["hits"][1]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][2]["found"].get<int32_t>());
ASSERT_STREQ("country", results["grouped_hits"][2]["group_key"][0].get<std::string>().c_str());
ASSERT_EQ(2, results["grouped_hits"][2]["hits"].size());
ASSERT_STREQ("3", results["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
@ -339,11 +352,20 @@ TEST_F(CollectionGroupingTest, GroupingWithGropLimitOfOne) {
for(auto i=0; i<5; i++) {
ASSERT_EQ(1, res["grouped_hits"][i]["hits"].size());
}
ASSERT_EQ(3, res["grouped_hits"][0]["found"].get<int32_t>());
ASSERT_STREQ("5", res["grouped_hits"][0]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(4, res["grouped_hits"][1]["found"].get<int32_t>());
ASSERT_STREQ("3", res["grouped_hits"][1]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(2, res["grouped_hits"][2]["found"].get<int32_t>());
ASSERT_STREQ("8", res["grouped_hits"][2]["hits"][0]["document"]["id"].get<std::string>().c_str());
ASSERT_EQ(2, res["grouped_hits"][3]["found"].get<int32_t>());
ASSERT_STREQ("10", res["grouped_hits"][3]["hits"][0]["document"]["id"].get<std::string>().c_str()); // unbranded
ASSERT_EQ(1, res["grouped_hits"][4]["found"].get<int32_t>());
ASSERT_STREQ("9", res["grouped_hits"][4]["hits"][0]["document"]["id"].get<std::string>().c_str());
// facet counts should each be 1, including unbranded
@ -545,3 +567,35 @@ TEST_F(CollectionGroupingTest, UseHighestValueInGroupForOrdering) {
ASSERT_STREQ("249", res["grouped_hits"][0]["group_key"][0].get<std::string>().c_str());
ASSERT_EQ(2, res["grouped_hits"][0]["hits"].size());
}
// Regression test: a single document that matches the query in more than one
// of the searched fields ("title" and "colors" both match the prefix "f")
// must be counted only once in the group's "found" total — i.e. the grouped
// hit count must not be inflated by per-field repeat additions of the same doc.
TEST_F(CollectionGroupingTest, RepeatedFieldNameGroupHitCount) {
std::vector<field> fields = {
field("title", field_types::STRING, false),
field("brand", field_types::STRING, true, true),
field("colors", field_types::STRING, true, false),
};
// Reuse "coll2" if a previous test already created it; otherwise create it.
Collection* coll2 = collectionManager.get_collection("coll2").get();
if(coll2 == nullptr) {
coll2 = collectionManager.create_collection("coll2", 1, fields).get();
}
// One document whose "title" ("foobar") and "colors" ("foo") both match the
// prefix query "f" issued below.
nlohmann::json doc;
doc["id"] = "0";
doc["title"] = "foobar";
doc["brand"] = "Omega";
doc["colors"] = "foo";
ASSERT_TRUE(coll2->add(doc.dump()).ok());
// Prefix search across both matching fields. NOTE(review): the trailing
// {"brand"}, 2 arguments are presumably group_by_fields and group_limit
// (grouping on "brand") — confirm against Collection::search's signature.
auto res = coll2->search("f", {"title", "colors"}, "", {}, {}, {0}, 10, 1, FREQUENCY,
{true}, 10,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10,
{}, {}, {"brand"}, 2).get();
// Exactly one group, and the doc contributes 1 (not 2) to its "found" count.
ASSERT_EQ(1, res["grouped_hits"].size());
ASSERT_EQ(1, res["grouped_hits"][0]["found"].get<int32_t>());
}