fix pinned hits with grouping and filter (#1572)

* fix pinned_hits_with_grouping

* remove repeated group_limit check
This commit is contained in:
Krunal Gandhi 2024-02-22 13:17:42 +00:00 committed by GitHub
parent 2c21e1306b
commit d1079c633c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 137 additions and 9 deletions

View File

@ -761,9 +761,9 @@ struct facet {
return false;
}
explicit facet(const std::string& field_name, std::map<int64_t, range_specs_t> facet_range = {},
explicit facet(const std::string& field_name, uint32_t orig_index, std::map<int64_t, range_specs_t> facet_range = {},
bool is_range_q = false, bool sort_by_alpha=false, const std::string& order="",
const std::string& sort_by_field="", uint32_t orig_index = 0)
const std::string& sort_by_field="")
: field_name(field_name), facet_range_map(facet_range),
is_range_query(is_range_q), is_sort_by_alpha(sort_by_alpha), sort_order(order),
sort_field(sort_by_field), orig_index(orig_index) {

View File

@ -6059,7 +6059,7 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
return Option<bool>(400, error);
}
facet a_facet(field_name);
facet a_facet(field_name, facets.size());
//starting after "(" and excluding ")"
auto range_string = std::string(facet_field.begin() + startpos + 1, facet_field.end() - 1);
@ -6209,7 +6209,7 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
// Collect the fields that match the prefix and are marked as facet.
for (auto field = pair.first; field != pair.second; field++) {
if (field->facet) {
facets.emplace_back(facet(field->name));
facets.emplace_back(facet(field->name, facets.size()));
facets.back().is_wildcard_match = true;
}
}
@ -6278,7 +6278,7 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
return Option<bool>(400, error);
}
facets.emplace_back(facet(facet_field_copy, {}, false, sort_alpha,
facets.emplace_back(facet(facet_field_copy, facets.size(), {}, false, sort_alpha,
order, sort_field));
}

View File

@ -3661,16 +3661,16 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
if(facet_infos[i].use_value_index) {
#endif
// value based faceting on a single thread
value_facets.emplace_back(this_facet.field_name, this_facet.facet_range_map,
value_facets.emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.facet_range_map,
this_facet.is_range_query, this_facet.is_sort_by_alpha,
this_facet.sort_order, this_facet.sort_field, i);
this_facet.sort_order, this_facet.sort_field);
continue;
}
for(size_t j = 0; j < num_threads; j++) {
facet_batches[j].emplace_back(this_facet.field_name, this_facet.facet_range_map,
facet_batches[j].emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.facet_range_map,
this_facet.is_range_query, this_facet.is_sort_by_alpha,
this_facet.sort_order, this_facet.sort_field, i);
this_facet.sort_order, this_facet.sort_field);
}
}
@ -3785,6 +3785,22 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
all_result_ids_len += curated_topster->size;
if(!included_ids_map.empty() && group_limit != 0) {
for (auto &acc_facet: facets) {
for (auto &facet_kv: acc_facet.result_map) {
facet_kv.second.count = acc_facet.hash_groups[facet_kv.first].size();
if (estimate_facets) {
facet_kv.second.count = size_t(double(facet_kv.second.count) * (100.0f / facet_sample_percent));
}
}
if (estimate_facets) {
acc_facet.sampled = true;
}
}
}
delete [] all_result_ids;
//LOG(INFO) << "all_result_ids_len " << all_result_ids_len << " for index " << name;

View File

@ -1045,4 +1045,116 @@ TEST_F(CollectionGroupingTest, GroupByMultipleFacetFields) {
ASSERT_EQ(1, (int) res["facet_counts"][1]["counts"][2]["count"]);
ASSERT_STREQ("red", res["facet_counts"][1]["counts"][2]["value"].get<std::string>().c_str());
}
TEST_F(CollectionGroupingTest, GroupByMultipleFacetFieldsWithFilter) {
auto res = coll_group->search("*", {}, "size:>10", {"colors", "brand"}, {}, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10,
{}, {}, {"size"}, 2).get();
ASSERT_EQ(5, res["found_docs"].get<size_t>());
ASSERT_EQ(2, res["found"].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"].size());
ASSERT_EQ(11, res["grouped_hits"][0]["group_key"][0].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][0]["found"].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][0]["hits"].size());
ASSERT_EQ("5", res["grouped_hits"][0]["hits"][0]["document"]["id"]);
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][0]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ("1", res["grouped_hits"][0]["hits"][1]["document"]["id"]);
ASSERT_FLOAT_EQ(4.3, res["grouped_hits"][0]["hits"][1]["document"]["rating"].get<float>());
ASSERT_EQ(12, res["grouped_hits"][1]["group_key"][0].get<size_t>());
ASSERT_EQ(3, res["grouped_hits"][1]["found"].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][1]["hits"].size());
ASSERT_EQ("2", res["grouped_hits"][1]["hits"][0]["document"]["id"]);
ASSERT_FLOAT_EQ(4.6, res["grouped_hits"][1]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ("8", res["grouped_hits"][1]["hits"][1]["document"]["id"]);
ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][1]["hits"][1]["document"]["rating"].get<float>());
ASSERT_STREQ("colors", res["facet_counts"][0]["field_name"].get<std::string>().c_str());
ASSERT_EQ(2, (int) res["facet_counts"][0]["counts"][0]["count"]);
ASSERT_STREQ("blue", res["facet_counts"][0]["counts"][0]["value"].get<std::string>().c_str());
ASSERT_EQ(2, (int) res["facet_counts"][0]["counts"][1]["count"]);
ASSERT_STREQ("white", res["facet_counts"][0]["counts"][1]["value"].get<std::string>().c_str());
ASSERT_EQ(1, (int) res["facet_counts"][0]["counts"][2]["count"]);
ASSERT_STREQ("red", res["facet_counts"][0]["counts"][2]["value"].get<std::string>().c_str());
ASSERT_STREQ("brand", res["facet_counts"][1]["field_name"].get<std::string>().c_str());
ASSERT_EQ(2, (int) res["facet_counts"][1]["counts"][0]["count"]);
ASSERT_STREQ("Beta", res["facet_counts"][1]["counts"][0]["value"].get<std::string>().c_str());
ASSERT_EQ(2, (int) res["facet_counts"][1]["counts"][1]["count"]);
ASSERT_STREQ("Omega", res["facet_counts"][1]["counts"][1]["value"].get<std::string>().c_str());
ASSERT_EQ(1, (int) res["facet_counts"][1]["counts"][2]["count"]);
ASSERT_STREQ("Xorp", res["facet_counts"][1]["counts"][2]["value"].get<std::string>().c_str());
}
TEST_F(CollectionGroupingTest, GroupByMultipleFacetFieldsWithPinning) {
auto res = coll_group->search("*", {}, "size:>10", {"colors", "brand"}, {}, {0}, 50, 1, FREQUENCY,
{false}, Index::DROP_TOKENS_THRESHOLD,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 5,
"", 10,
{"3:1,4:2"}, {}, {"size"}, 2).get();
ASSERT_EQ(5, res["found_docs"].get<size_t>());
ASSERT_EQ(4, res["found"].get<size_t>());
ASSERT_EQ(4, res["grouped_hits"].size());
ASSERT_EQ(10, res["grouped_hits"][0]["group_key"][0].get<size_t>());
ASSERT_EQ(1, res["grouped_hits"][0]["hits"].size());
ASSERT_EQ("3", res["grouped_hits"][0]["hits"][0]["document"]["id"]);
ASSERT_FLOAT_EQ(4.6, res["grouped_hits"][0]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ(10, res["grouped_hits"][1]["group_key"][0].get<size_t>());
ASSERT_EQ(1, res["grouped_hits"][1]["hits"].size());
ASSERT_EQ("4", res["grouped_hits"][1]["hits"][0]["document"]["id"]);
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][1]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ(11, res["grouped_hits"][2]["group_key"][0].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][2]["found"].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][2]["hits"].size());
ASSERT_EQ("5", res["grouped_hits"][2]["hits"][0]["document"]["id"]);
ASSERT_FLOAT_EQ(4.8, res["grouped_hits"][2]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ("1", res["grouped_hits"][2]["hits"][1]["document"]["id"]);
ASSERT_FLOAT_EQ(4.3, res["grouped_hits"][2]["hits"][1]["document"]["rating"].get<float>());
ASSERT_EQ(12, res["grouped_hits"][3]["group_key"][0].get<size_t>());
ASSERT_EQ(3, res["grouped_hits"][3]["found"].get<size_t>());
ASSERT_EQ(2, res["grouped_hits"][3]["hits"].size());
ASSERT_EQ("2", res["grouped_hits"][3]["hits"][0]["document"]["id"]);
ASSERT_FLOAT_EQ(4.6, res["grouped_hits"][3]["hits"][0]["document"]["rating"].get<float>());
ASSERT_EQ("8", res["grouped_hits"][3]["hits"][1]["document"]["id"]);
ASSERT_FLOAT_EQ(4.4, res["grouped_hits"][3]["hits"][1]["document"]["rating"].get<float>());
ASSERT_STREQ("colors", res["facet_counts"][0]["field_name"].get<std::string>().c_str());
ASSERT_EQ(3, (int) res["facet_counts"][0]["counts"][0]["count"]);
ASSERT_STREQ("blue", res["facet_counts"][0]["counts"][0]["value"].get<std::string>().c_str());
ASSERT_EQ(3, (int) res["facet_counts"][0]["counts"][1]["count"]);
ASSERT_STREQ("white", res["facet_counts"][0]["counts"][1]["value"].get<std::string>().c_str());
ASSERT_EQ(1, (int) res["facet_counts"][0]["counts"][2]["count"]);
ASSERT_STREQ("red", res["facet_counts"][0]["counts"][2]["value"].get<std::string>().c_str());
ASSERT_STREQ("brand", res["facet_counts"][1]["field_name"].get<std::string>().c_str());
ASSERT_EQ(3, (int) res["facet_counts"][1]["counts"][0]["count"]);
ASSERT_STREQ("Beta", res["facet_counts"][1]["counts"][0]["value"].get<std::string>().c_str());
ASSERT_EQ(3, (int) res["facet_counts"][1]["counts"][1]["count"]);
ASSERT_STREQ("Omega", res["facet_counts"][1]["counts"][1]["value"].get<std::string>().c_str());
ASSERT_EQ(1, (int) res["facet_counts"][1]["counts"][2]["count"]);
ASSERT_STREQ("Xorp", res["facet_counts"][1]["counts"][2]["value"].get<std::string>().c_str());
}