mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 22:33:27 +08:00
Fix nested reference sorting. (#2035)
Some checks failed
tests / test (push) Has been cancelled
Some checks failed
tests / test (push) Has been cancelled
* Fix nested reference sorting. * Refactor `Index::get_ref_seq_id`. * Refactor.
This commit is contained in:
parent
4dff771b18
commit
e15d7d94f0
@ -591,6 +591,10 @@ private:
|
||||
Option<int64_t> get_geo_distance(const std::string& geo_field_name, const uint32_t& seq_id,
|
||||
const S2LatLng& reference_lat_lng, const bool& round_distance = false) const;
|
||||
|
||||
Option<uint32_t> get_ref_seq_id_helper(const sort_by& sort_field, const uint32_t& seq_id, std::string& prev_coll_name,
|
||||
std::map<std::string, reference_filter_result_t> const*& references,
|
||||
std::string& ref_coll_name) const;
|
||||
|
||||
public:
|
||||
// for limiting number of results on multiple candidates / query rewrites
|
||||
enum {TYPO_TOKENS_THRESHOLD = 1};
|
||||
@ -1018,10 +1022,6 @@ public:
|
||||
bool enable_typos_for_numerical_tokens,
|
||||
bool enable_typos_for_alpha_numerical_tokens) const;
|
||||
|
||||
Option<bool> ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id,
|
||||
bool& reference_found, const std::map<basic_string<char>, reference_filter_result_t>& references,
|
||||
const std::string& collection_name) const;
|
||||
|
||||
Option<bool> compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
|
||||
@ -1070,9 +1070,9 @@ public:
|
||||
const std::map<basic_string<char>, reference_filter_result_t>& references,
|
||||
const S2LatLng& reference_lat_lng, const bool& round_distance = false) const;
|
||||
|
||||
Option<uint32_t> get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id, std::string& prev_coll_name,
|
||||
std::map<std::string, reference_filter_result_t> const*& references,
|
||||
std::string& ref_coll_name) const;
|
||||
Option<uint32_t> get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id,
|
||||
const std::map<std::string, reference_filter_result_t>& references,
|
||||
std::string& ref_collection_name) const;
|
||||
|
||||
void get_top_k_result_ids(const std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<uint32_t>& result_ids) const;
|
||||
|
||||
|
266
src/index.cpp
266
src/index.cpp
@ -4927,194 +4927,6 @@ Option<bool> Index::search_across_fields(const std::vector<token_t>& query_token
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Index::ref_compute_sort_scores(const sort_by& sort_field, const uint32_t& seq_id, uint32_t& ref_seq_id,
|
||||
bool& reference_found, const std::map<basic_string<char>, reference_filter_result_t>& references,
|
||||
const std::string& collection_name) const {
|
||||
auto const& ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const& multiple_references_error_message = "Multiple references found to sort by on `" +
|
||||
ref_collection_name + "." + sort_field.name + "`.";
|
||||
|
||||
if (sort_field.is_nested_join_sort_by()) {
|
||||
// Get the reference doc_id by following through all the nested join collections.
|
||||
ref_seq_id = seq_id;
|
||||
std::string prev_coll_name = collection_name;
|
||||
for (const auto &coll_name: sort_field.nested_join_collection_names) {
|
||||
// Joined on ref collection
|
||||
if (references.count(coll_name) > 0) {
|
||||
auto const& count = references.at(coll_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
break;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(coll_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(coll_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + coll_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(prev_coll_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(prev_coll_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
|
||||
auto prev_coll = cm.get_collection(prev_coll_name);
|
||||
if (prev_coll == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + prev_coll_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto sort_index_op = prev_coll->get_sort_index_value_with_lock(field_name, ref_seq_id);
|
||||
if (!sort_index_op.ok()) {
|
||||
if (sort_index_op.code() == 400) {
|
||||
return Option<bool>(400, sort_index_op.error());
|
||||
}
|
||||
reference_found = false;
|
||||
break;
|
||||
} else {
|
||||
ref_seq_id = sort_index_op.get();
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
reference_found = false;
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
break;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prev_coll_name = coll_name;
|
||||
}
|
||||
} else if (references.count(ref_collection_name) > 0) { // Joined on ref collection
|
||||
auto const& count = references.at(ref_collection_name).count;
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
ref_seq_id = references.at(ref_collection_name).docs[0];
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
} else {
|
||||
auto& cm = CollectionManager::get_instance();
|
||||
auto ref_collection = cm.get_collection(ref_collection_name);
|
||||
if (ref_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + ref_collection_name +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
// Current collection has a reference.
|
||||
if (ref_collection->is_referenced_in(collection_name)) {
|
||||
auto get_reference_field_op = ref_collection->get_referenced_in_field_with_lock(collection_name);
|
||||
if (!get_reference_field_op.ok()) {
|
||||
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
|
||||
}
|
||||
auto const& field_name = get_reference_field_op.get();
|
||||
auto const reference_helper_field_name = field_name + fields::REFERENCE_HELPER_FIELD_SUFFIX;
|
||||
|
||||
if (sort_index.count(reference_helper_field_name) == 0) {
|
||||
return Option<bool>(400, "Could not find `" + reference_helper_field_name + "` in sort_index.");
|
||||
} else if (sort_index.at(reference_helper_field_name)->count(seq_id) == 0) {
|
||||
reference_found = false;
|
||||
} else {
|
||||
ref_seq_id = sort_index.at(reference_helper_field_name)->at(seq_id);
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
joined_coll_having_reference = reference.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (joined_coll_having_reference.empty()) {
|
||||
reference_found = false;
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
auto joined_collection = cm.get_collection(joined_coll_having_reference);
|
||||
if (joined_collection == nullptr) {
|
||||
return Option<bool>(400, "Referenced collection `" + joined_coll_having_reference +
|
||||
"` in `sort_by` not found.");
|
||||
}
|
||||
|
||||
auto reference_field_name_op = ref_collection->get_referenced_in_field_with_lock(joined_coll_having_reference);
|
||||
if (!reference_field_name_op.ok()) {
|
||||
return Option<bool>(reference_field_name_op.code(), reference_field_name_op.error());
|
||||
}
|
||||
|
||||
auto const& reference_field_name = reference_field_name_op.get();
|
||||
auto const& reference = references.at(joined_coll_having_reference);
|
||||
auto const& count = reference.count;
|
||||
|
||||
if (count == 0) {
|
||||
reference_found = false;
|
||||
} else if (count == 1) {
|
||||
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
|
||||
reference.docs[0]);
|
||||
if (!op.ok()) {
|
||||
return Option<bool>(op.code(), op.error());
|
||||
}
|
||||
|
||||
ref_seq_id = op.get();
|
||||
} else {
|
||||
return Option<bool>(400, multiple_references_error_message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices,
|
||||
@ -5166,10 +4978,17 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
|
||||
|
||||
// In case of reference sort_by, we need to get the sort score of the reference doc id.
|
||||
if (is_reference_sort) {
|
||||
auto const& ref_compute_op = ref_compute_sort_scores(sort_fields[i], seq_id, ref_seq_id, reference_found,
|
||||
references, collection_name);
|
||||
if (!ref_compute_op.ok()) {
|
||||
return ref_compute_op;
|
||||
std::string ref_collection_name;
|
||||
auto get_ref_seq_id_op = get_ref_seq_id(sort_fields[i], seq_id, references, ref_collection_name);
|
||||
if (!get_ref_seq_id_op.ok()) {
|
||||
return Option<bool>(get_ref_seq_id_op.code(), "Error while sorting on `" + sort_fields[i].reference_collection_name
|
||||
+ "." + sort_fields[i].name + ": " + get_ref_seq_id_op.error());
|
||||
}
|
||||
|
||||
if (get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.
|
||||
reference_found = false;
|
||||
} else {
|
||||
ref_seq_id = get_ref_seq_id_op.get();
|
||||
}
|
||||
}
|
||||
|
||||
@ -7816,12 +7635,39 @@ std::string multiple_references_message(const std::string& coll_name, const uint
|
||||
"` document references multiple documents of `" + ref_coll_name + "` collection.";
|
||||
}
|
||||
|
||||
Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id, std::string& coll_name,
|
||||
std::map<std::string, reference_filter_result_t> const*& references,
|
||||
std::string& ref_coll_name) const {
|
||||
Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t& seq_id,
|
||||
const std::map<std::string, reference_filter_result_t>& references,
|
||||
std::string& ref_collection_name) const {
|
||||
auto collection_name = get_collection_name_with_lock();
|
||||
ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const* references_ptr = &(references);
|
||||
auto ref_seq_id = seq_id;
|
||||
|
||||
if (sort_field.is_nested_join_sort_by()) {
|
||||
// Get the reference doc_id by following through all the nested join collections.
|
||||
for (size_t i = 0; i < sort_field.nested_join_collection_names.size() - 1; i++) {
|
||||
ref_collection_name = sort_field.nested_join_collection_names[i];
|
||||
auto get_ref_seq_id_op = get_ref_seq_id_helper(sort_field, ref_seq_id, collection_name, references_ptr,
|
||||
ref_collection_name);
|
||||
if (!get_ref_seq_id_op.ok() || get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.
|
||||
return get_ref_seq_id_op;
|
||||
} else {
|
||||
ref_seq_id = get_ref_seq_id_op.get();
|
||||
}
|
||||
}
|
||||
|
||||
ref_collection_name = sort_field.nested_join_collection_names.back();
|
||||
}
|
||||
|
||||
return get_ref_seq_id_helper(sort_field, ref_seq_id, collection_name, references_ptr, ref_collection_name);
|
||||
}
|
||||
|
||||
Option<uint32_t> Index::get_ref_seq_id_helper(const sort_by& sort_field, const uint32_t& seq_id, std::string& coll_name,
|
||||
std::map<std::string, reference_filter_result_t> const*& references,
|
||||
std::string& ref_coll_name) const {
|
||||
uint32_t ref_seq_id = reference_helper_sentinel_value;
|
||||
|
||||
if (references->count(ref_coll_name) > 0) { // Joined on ref collection
|
||||
if (references != nullptr && references->count(ref_coll_name) > 0) { // Joined on ref collection
|
||||
auto& ref_result = references->at(ref_coll_name);
|
||||
auto const& count = ref_result.count;
|
||||
if (count == 1) {
|
||||
@ -7869,7 +7715,7 @@ Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t
|
||||
}
|
||||
}
|
||||
// Joined collection has a reference
|
||||
else {
|
||||
else if (references != nullptr) {
|
||||
std::string joined_coll_having_reference;
|
||||
for (const auto &reference: *references) {
|
||||
if (ref_collection->is_referenced_in(reference.first)) {
|
||||
@ -7924,30 +7770,8 @@ Option<uint32_t> Index::get_ref_seq_id(const sort_by& sort_field, const uint32_t
|
||||
Option<int64_t> Index::get_referenced_geo_distance(const sort_by& sort_field, uint32_t seq_id,
|
||||
const std::map<basic_string<char>, reference_filter_result_t>& references,
|
||||
const S2LatLng& reference_lat_lng, const bool& round_distance) const {
|
||||
auto collection_name = get_collection_name_with_lock();
|
||||
auto ref_collection_name = sort_field.reference_collection_name;
|
||||
auto const* references_ptr = &(references);
|
||||
|
||||
if (sort_field.is_nested_join_sort_by()) {
|
||||
// Get the reference doc_id by following through all the nested join collections.
|
||||
for (size_t i = 0; i < sort_field.nested_join_collection_names.size() - 1; i++) {
|
||||
ref_collection_name = sort_field.nested_join_collection_names[i];
|
||||
auto get_ref_seq_id_op = get_ref_seq_id(sort_field, seq_id, collection_name, references_ptr,
|
||||
ref_collection_name);
|
||||
if (!get_ref_seq_id_op.ok()) {
|
||||
return Option<int64_t>(400, get_ref_seq_id_op.error());
|
||||
} else if (get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.
|
||||
return Option<int64_t>(0);
|
||||
} else {
|
||||
seq_id = get_ref_seq_id_op.get();
|
||||
}
|
||||
}
|
||||
|
||||
ref_collection_name = sort_field.nested_join_collection_names.back();
|
||||
}
|
||||
|
||||
auto get_ref_seq_id_op = get_ref_seq_id(sort_field, seq_id, collection_name, references_ptr,
|
||||
ref_collection_name);
|
||||
std::string ref_collection_name;
|
||||
auto get_ref_seq_id_op = get_ref_seq_id(sort_field, seq_id, references, ref_collection_name);
|
||||
if (!get_ref_seq_id_op.ok()) {
|
||||
return Option<int64_t>(400, get_ref_seq_id_op.error());
|
||||
} else if (get_ref_seq_id_op.get() == reference_helper_sentinel_value) { // No references found.
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user