Fix nested reference sorting. (#2035)
Some checks failed
tests / test (push) Has been cancelled

* Fix nested reference sorting.

* Refactor `Index::get_ref_seq_id`.

* Refactor.
This commit is contained in:
Harpreet Sangar 2024-10-28 11:57:04 +05:30 committed by GitHub
parent 4dff771b18
commit e15d7d94f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 1112 additions and 1097 deletions

View File

@ -591,6 +591,10 @@ private:
Option<int64_t> get_geo_distance(const std::string& geo_field_name, const uint32_t& seq_id,
const S2LatLng& reference_lat_lng, const bool& round_distance = false) const;
Option<uint32_t> get_ref_seq_id_helper(const sort_by& sort_field, const uint32_t& seq_id, std::string& prev_coll_name,
std::map<std::string, reference_filter_result_t> const*& references,
std::string& ref_coll_name) const;
public:
// for limiting number of results on multiple candidates / query rewrites
enum {TYPO_TOKENS_THRESHOLD = 1};
@ -1018,10 +1022,6 @@ public:
bool enable_typos_for_numerical_tokens,
bool enable_typos_for_alpha_numerical_tokens) const;
Option<bool> ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id,
bool& reference_found, const std::map<basic_string<char>, reference_filter_result_t>& references,
const std::string& collection_name) const;
Option<bool> compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
@ -1070,9 +1070,9 @@ public:
const std::map<basic_string<char>, reference_filter_result_t>& references,
const S2LatLng& reference_lat_lng, const bool& round_distance = false) const;
Option<uint32_t> get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id, std::string& prev_coll_name,
std::map<std::string, reference_filter_result_t> const*& references,
std::string& ref_coll_name) const;
// Resolves the seq_id of the document that `sort_field` refers to, starting from `seq_id` of this
// index's collection.
//
// For a nested join sort_by, every entry of `sort_field.nested_join_collection_names` is followed in
// order, feeding the id resolved at one hop into the next; otherwise a single lookup is done against
// `sort_field.reference_collection_name`.
//
// `references`          - reference results of the joined collections for the current document.
// `ref_collection_name` - out-param; overwritten with the name of the collection that was last looked up.
//
// Returns the resolved seq_id, `reference_helper_sentinel_value` when no reference was found, or the
// error of the failing lookup.
Option<uint32_t> get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id,
const std::map<std::string, reference_filter_result_t>& references,
std::string& ref_collection_name) const;
void get_top_k_result_ids(const std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<uint32_t>& result_ids) const;

View File

@ -4927,194 +4927,6 @@ Option<bool> Index::search_across_fields(const std::vector<token_t>& query_token
return Option<bool>(true);
}
// Resolves, for a reference `sort_by` field, the seq_id of the referenced document that the sort
// score has to be read from.
//
// `sort_field`      - the reference sort field (plain or nested join).
// `seq_id`          - seq_id of the document in `collection_name` that is being scored.
// `ref_seq_id`      - out-param; receives the seq_id of the referenced document when one is found.
// `reference_found` - out-param; set to `false` when no reference exists. It is never set to `true`
//                     here, so the caller presumably initialises it - TODO confirm.
// `references`      - reference results of the joined collections for the current document.
// `collection_name` - name of the collection that `seq_id` belongs to.
//
// Returns `Option<bool>(true)` on success (including the "no reference found" case) and an error
// Option when a collection or field cannot be resolved or more than one reference matches.
Option<bool> Index::ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id,
bool& reference_found, const std::map<basic_string<char>, reference_filter_result_t>& references,
const std::string& collection_name) const {
auto const& ref_collection_name = sort_field.reference_collection_name;
// Built eagerly; shared by every "more than one referenced document" error below.
auto const& multiple_references_error_message = "Multiple references found to sort by on `" +
ref_collection_name + "." + sort_field.name + "`.";
if (sort_field.is_nested_join_sort_by()) {
// Get the reference doc_id by following through all the nested join collections.
ref_seq_id = seq_id;
// Collection that `ref_seq_id` currently belongs to; advanced at the end of every hop.
std::string prev_coll_name = collection_name;
for (const auto &coll_name: sort_field.nested_join_collection_names) {
// Joined on ref collection
// NOTE(review): this lookup uses the top-level `references` map on every hop, so the id taken
// here does not depend on the `ref_seq_id` resolved by previous hops.
if (references.count(coll_name) > 0) {
auto const& count = references.at(coll_name).count;
if (count == 0) {
reference_found = false;
break;
} else if (count == 1) {
ref_seq_id = references.at(coll_name).docs[0];
} else {
return Option<bool>(400, multiple_references_error_message);
}
} else {
auto& cm = CollectionManager::get_instance();
auto ref_collection = cm.get_collection(coll_name);
if (ref_collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + coll_name +
"` in `sort_by` not found.");
}
// Current collection has a reference.
if (ref_collection->is_referenced_in(prev_coll_name)) {
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(prev_coll_name);
if (!get_reference_field_op.ok()) {
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
}
auto const& field_name = get_reference_field_op.get();
auto prev_coll = cm.get_collection(prev_coll_name);
if (prev_coll == nullptr) {
return Option<bool>(400, "Referenced collection `" + prev_coll_name +
"` in `sort_by` not found.");
}
// Read the referenced id stored for `ref_seq_id` in the previous collection's sort index.
auto sort_index_op = prev_coll->get_sort_index_value_with_lock(field_name, ref_seq_id);
if (!sort_index_op.ok()) {
// Only a 400 is surfaced as an error; any other failure is treated as "no reference".
if (sort_index_op.code() == 400) {
return Option<bool>(400, sort_index_op.error());
}
reference_found = false;
break;
} else {
ref_seq_id = sort_index_op.get();
}
}
// Joined collection has a reference
else {
// Pick the first joined collection (in map order) that `ref_collection` is referenced in.
std::string joined_coll_having_reference;
for (const auto &reference: references) {
if (ref_collection->is_referenced_in(reference.first)) {
joined_coll_having_reference = reference.first;
break;
}
}
if (joined_coll_having_reference.empty()) {
reference_found = false;
return Option<bool>(true);
}
auto joined_collection = cm.get_collection(joined_coll_having_reference);
if (joined_collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
"` in `sort_by` not found.");
}
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
if (!reference_field_name_op.ok()) {
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
}
auto const& reference_field_name = reference_field_name_op.get();
auto const& reference = references.at(joined_coll_having_reference);
auto const& count = reference.count;
if (count == 0) {
reference_found = false;
break;
} else if (count == 1) {
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
reference.docs[0]);
// Unlike the "current collection has a reference" branch, every failure is propagated here.
if (!op.ok()) {
return Option<bool>(op.code(), op.error());
}
ref_seq_id = op.get();
} else {
return Option<bool>(400, multiple_references_error_message);
}
}
}
prev_coll_name = coll_name;
}
} else if (references.count(ref_collection_name) > 0) { // Joined on ref collection
auto const& count = references.at(ref_collection_name).count;
if (count == 0) {
reference_found = false;
} else if (count == 1) {
ref_seq_id = references.at(ref_collection_name).docs[0];
} else {
return Option<bool>(400, multiple_references_error_message);
}
} else {
auto& cm = CollectionManager::get_instance();
auto ref_collection = cm.get_collection(ref_collection_name);
if (ref_collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + ref_collection_name +
"` in `sort_by` not found.");
}
// Current collection has a reference.
if (ref_collection->is_referenced_in(collection_name)) {
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name);
if (!get_reference_field_op.ok()) {
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
}
auto const& field_name = get_reference_field_op.get();
// The referenced id is read from this index's own `sort_index`, under the reference field's
// name with the helper suffix appended.
auto const reference_helper_field_name = field_name + fields::REFERENCE_HELPER_FIELD_SUFFIX;
if (sort_index.count(reference_helper_field_name) == 0) {
return Option<bool>(400, "Could not find `" + reference_helper_field_name + "` in sort_index.");
} else if (sort_index.at(reference_helper_field_name)->count(seq_id) == 0) {
reference_found = false;
} else {
ref_seq_id = sort_index.at(reference_helper_field_name)->at(seq_id);
}
}
// Joined collection has a reference
else {
// Pick the first joined collection (in map order) that `ref_collection` is referenced in.
std::string joined_coll_having_reference;
for (const auto &reference: references) {
if (ref_collection->is_referenced_in(reference.first)) {
joined_coll_having_reference = reference.first;
break;
}
}
if (joined_coll_having_reference.empty()) {
reference_found = false;
return Option<bool>(true);
}
auto joined_collection = cm.get_collection(joined_coll_having_reference);
if (joined_collection == nullptr) {
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
"` in `sort_by` not found.");
}
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
if (!reference_field_name_op.ok()) {
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
}
auto const& reference_field_name = reference_field_name_op.get();
auto const& reference = references.at(joined_coll_having_reference);
auto const& count = reference.count;
if (count == 0) {
reference_found = false;
} else if (count == 1) {
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
reference.docs[0]);
if (!op.ok()) {
return Option<bool>(op.code(), op.error());
}
ref_seq_id = op.get();
} else {
return Option<bool>(400, multiple_references_error_message);
}
}
}
return Option<bool>(true);
}
Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
@ -5166,10 +4978,17 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// In case of reference sort_by, we need to get the sort score of the reference doc id.
if (is_reference_sort) {
auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[i], seq_id, ref_seq_id, reference_found,
references, collection_name);
if (!ref_compute_op.ok()) {
return ref_compute_op;
std::string ref_collection_name;
auto get_ref_seq_id_op = get_ref_seq_id(sort_fields[i], seq_id, references, ref_collection_name);
if (!get_ref_seq_id_op.ok()) {
return Option<bool>(get_ref_seq_id_op.code(), "Error while sorting on `" + sort_fields[i].reference_collection_name
+ "." + sort_fields[i].name + ": " + get_ref_seq_id_op.error());
}
if (get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.
reference_found = false;
} else {
ref_seq_id = get_ref_seq_id_op.get();
}
}
@ -7816,12 +7635,39 @@ std::string multiple_references_message(const std::string& coll_name, const uint
"` document references multiple documents of `" + ref_coll_name + "` collection.";
}
Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id, std::string& coll_name,
std::map<std::string, reference_filter_result_t> const*& references,
std::string& ref_coll_name) const {
// Resolves the seq_id of the document referenced by `sort_field`, starting at `seq_id`.
//
// A nested join sort_by is resolved hop by hop through `sort_field.nested_join_collection_names`;
// the walk stops at the first hop that fails or yields `reference_helper_sentinel_value`, and that
// result is returned as-is. `ref_collection_name` is overwritten with the collection name of the
// lookup performed last.
Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id,
                                       const std::map<std::string, reference_filter_result_t>& references,
                                       std::string& ref_collection_name) const {
    auto collection_name = get_collection_name_with_lock();
    ref_collection_name = sort_field.reference_collection_name;

    // The helper takes these by reference, so they have to be named lvalues that persist across hops.
    auto const* references_ptr = &references;
    uint32_t current_seq_id = seq_id;

    if (!sort_field.is_nested_join_sort_by()) {
        return get_ref_seq_id_helper(sort_field, current_seq_id, collection_name, references_ptr,
                                     ref_collection_name);
    }

    // Every collection except the last one is an intermediate hop of the nested join.
    for (size_t hop = 0; hop < sort_field.nested_join_collection_names.size() - 1; hop++) {
        ref_collection_name = sort_field.nested_join_collection_names[hop];
        auto hop_op = get_ref_seq_id_helper(sort_field, current_seq_id, collection_name, references_ptr,
                                            ref_collection_name);
        if (!hop_op.ok()) {
            return hop_op;
        }
        if (hop_op.get() == reference_helper_sentinel_value) { // No references found.
            return hop_op;
        }
        current_seq_id = hop_op.get();
    }

    // Final hop: the collection that actually holds the field being sorted on.
    ref_collection_name = sort_field.nested_join_collection_names.back();
    return get_ref_seq_id_helper(sort_field, current_seq_id, collection_name, references_ptr,
                                 ref_collection_name);
}
Option<uint32_t> Index::get_ref_seq_id_helper(const sort_by& sort_field, const uint32_t& seq_id, std::string& coll_name,
std::map<std::string, reference_filter_result_t> const*& references,
std::string& ref_coll_name) const {
uint32_t ref_seq_id = reference_helper_sentinel_value;
if (references->count(ref_coll_name) > 0) { // Joined on ref collection
if (references != nullptr && references->count(ref_coll_name) > 0) { // Joined on ref collection
auto& ref_result = references->at(ref_coll_name);
auto const& count = ref_result.count;
if (count == 1) {
@ -7869,7 +7715,7 @@ Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t
}
}
// Joined collection has a reference
else {
else if (references != nullptr) {
std::string joined_coll_having_reference;
for (const auto &reference: *references) {
if (ref_collection->is_referenced_in(reference.first)) {
@ -7924,30 +7770,8 @@ Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t
Option<int64_t> Index::get_referenced_geo_distance(const sort_by& sort_field, uint32_t seq_id,
const std::map<basic_string<char>, reference_filter_result_t>& references,
const S2LatLng& reference_lat_lng, const bool& round_distance) const {
auto collection_name = get_collection_name_with_lock();
auto ref_collection_name = sort_field.reference_collection_name;
auto const* references_ptr = &(references);
if (sort_field.is_nested_join_sort_by()) {
// Get the reference doc_id by following through all the nested join collections.
for (size_t i = 0; i < sort_field.nested_join_collection_names.size() - 1; i++) {
ref_collection_name = sort_field.nested_join_collection_names[i];
auto get_ref_seq_id_op = get_ref_seq_id(sort_field, seq_id, collection_name, references_ptr,
ref_collection_name);
if (!get_ref_seq_id_op.ok()) {
return Option<int64_t>(400, get_ref_seq_id_op.error());
} else if (get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.
return Option<int64_t>(0);
} else {
seq_id = get_ref_seq_id_op.get();
}
}
ref_collection_name = sort_field.nested_join_collection_names.back();
}
auto get_ref_seq_id_op = get_ref_seq_id(sort_field, seq_id, collection_name, references_ptr,
ref_collection_name);
std::string ref_collection_name;
auto get_ref_seq_id_op = get_ref_seq_id(sort_field, seq_id, references, ref_collection_name);
if (!get_ref_seq_id_op.ok()) {
return Option<int64_t>(400, get_ref_seq_id_op.error());
} else if (get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.

File diff suppressed because it is too large Load Diff