making sorting more generic

This commit is contained in:
krunal1313 2023-03-02 18:54:12 +05:30
parent 77ffcd9444
commit de5a7d08a2
2 changed files with 30 additions and 27 deletions

View File

@ -2333,10 +2333,32 @@ void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<K
// we have to pick top-K groups
Topster gtopster(topster->MAX_SIZE);
int group_count_index = -1;
int group_sort_order = 1;
for(int i = 0; i < sort_by_fields.size(); ++i) {
if(sort_by_fields[i].name == sort_field_const::group_count) {
group_count_index = i;
if(sort_by_fields[i].order == sort_field_const::asc) {
group_sort_order *= -1;
}
break;
}
}
// Feed the leading KV of every non-empty group into `gtopster`, from which the
// top-K groups are picked.
for(auto& group_topster: topster->group_kv_map) {
    // Order the hits inside the group first, so that getKV(0) is the group's first-ranked entry.
    group_topster.second->sort();
    if(group_topster.second->size != 0) {
        KV* kv_head = group_topster.second->getKV(0);
        if(group_count_index >= 0) {
            // `group_count` is one of the sort fields: substitute the value stored for this
            // group in `groups_processed` (presumably the number of hits in the group --
            // confirm against where `groups_processed` is populated).
            const auto itr = groups_processed.find(kv_head->distinct_key);
            if(itr != groups_processed.end()) {
                // Write into the slot that belongs to the `group_count` sort field rather than
                // always slot 0; otherwise, when `group_count` is the 2nd or 3rd sort field, the
                // primary sort score would be clobbered and the `group_count` slot would keep its
                // placeholder value. `group_sort_order` is -1 for ascending order, which negates
                // the value.
                // NOTE(review): assumes `scores` has one slot per sort field, i.e.
                // group_count_index is a valid index into it -- confirm against KV's definition.
                kv_head->scores[group_count_index] = itr->second * group_sort_order;
            }
        }
        gtopster.add(kv_head);
    }
}
@ -2351,25 +2373,6 @@ void Collection::populate_result_kvs(Topster *topster, std::vector<std::vector<K
);
result_kvs.emplace_back(group_kvs);
}
if(!sort_by_fields.empty() && sort_by_fields[0].name == sort_field_const::group_count) {
std::sort(result_kvs.begin(), result_kvs.end(),
[&](const std::vector<KV*>& g1, const std::vector<KV*>& g2) {
const auto& it1 = groups_processed.find(g1[0]->distinct_key);
const auto& it2 = groups_processed.find(g2[0]->distinct_key);
if(it1 != groups_processed.end() && it2 != groups_processed.end()) {
if(sort_by_fields[0].order == sort_field_const::asc) {
return it1->second < it2->second;
}
else {
return it1->second > it2->second;
}
}
return false;
});
}
} else {
for(uint32_t t = 0; t < topster->size; t++) {
KV* kv = topster->getKV(t);

View File

@ -3799,7 +3799,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
const int64_t default_score = INT64_MIN; // to handle field that doesn't exist in document (e.g. optional)
// avoiding loop
if (sort_fields.size() > 0 && sort_fields[0].name != sort_field_const::group_count) {
if (sort_fields.size() > 0) {
if (field_values[0] == &text_match_sentinel_value) {
scores[0] = int64_t(max_field_match_score);
match_score_index = 0;
@ -3855,7 +3855,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
}
if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) {
if(sort_fields.size() > 1) {
if (field_values[1] == &text_match_sentinel_value) {
scores[1] = int64_t(max_field_match_score);
match_score_index = 1;
@ -3907,7 +3907,7 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
}
}
if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) {
if(sort_fields.size() > 2) {
if (field_values[2] == &text_match_sentinel_value) {
scores[2] = int64_t(max_field_match_score);
match_score_index = 2;
@ -4569,7 +4569,8 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
if (sort_fields_std[i].name == sort_field_const::text_match) {
field_values[i] = &text_match_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::seq_id) {
} else if (sort_fields_std[i].name == sort_field_const::seq_id ||
sort_fields_std[i].name == sort_field_const::group_count) {
field_values[i] = &seq_id_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
@ -5038,8 +5039,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
// avoiding loop
if (sort_fields.size() > 0) {
if (field_values[0] == &text_match_sentinel_value
&& sort_fields[0].name != sort_field_const::group_count) {
if (field_values[0] == &text_match_sentinel_value) {
scores[0] = int64_t(match_score);
match_score_index = 0;
} else if (field_values[0] == &seq_id_sentinel_value) {
@ -5058,7 +5058,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
}
}
if(sort_fields.size() > 1 && sort_fields[1].name != sort_field_const::group_count) {
if(sort_fields.size() > 1) {
if (field_values[1] == &text_match_sentinel_value) {
scores[1] = int64_t(match_score);
match_score_index = 1;
@ -5078,7 +5078,7 @@ void Index::score_results(const std::vector<sort_by> & sort_fields, const uint16
}
}
if(sort_fields.size() > 2 && sort_fields[2].name != sort_field_const::group_count) {
if(sort_fields.size() > 2) {
if (field_values[2] == &text_match_sentinel_value) {
scores[2] = int64_t(match_score);
match_score_index = 2;