Assign default sorting score if reference is not found while sorting by a reference field. (#1770)

This commit is contained in:
Harpreet Sangar 2024-06-04 17:21:39 +05:30 committed by GitHub
parent f47c09fd63
commit 561b01bb51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 183 additions and 36 deletions

View File

@ -4739,6 +4739,8 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// avoiding loop
if (sort_fields.size() > 0) {
auto reference_found = true;
// In case of reference sort_by, we need to get the sort score of the reference doc id.
if (!sort_fields[0].reference_collection_name.empty()) {
auto& sort_field = sort_fields[0];
@ -4750,11 +4752,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// Joined on ref collection
if (references.count(ref_collection_name) > 0) {
if (references.at(ref_collection_name).count == 1) {
auto const& count = references.at(ref_collection_name).count;
if (count == 0) {
reference_found = false;
} else if (count == 1) {
ref_seq_id = references.at(ref_collection_name).docs[0];
} else {
return Option<bool>(400, references.at(ref_collection_name).count > 1 ?
multiple_references_error_message : no_references_error_message);
return Option<bool>(400, multiple_references_error_message);
}
} else {
auto& cm = CollectionManager::get_instance();
@ -4771,11 +4775,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
}
auto const& field_name = get_reference_field_op.get();
if (sort_index.count(field_name) == 0 || sort_index.at(field_name)->count(seq_id) == 0) {
return Option<bool>(400, "Could not find a reference for doc " + std::to_string(seq_id));
if (sort_index.count(field_name) == 0) {
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
reference_found = false;
} else {
ref_seq_id = sort_index.at(field_name)->at(seq_id);
}
ref_seq_id = sort_index.at(field_name)->at(seq_id);
}
// Joined collection has a reference
else {
@ -4806,7 +4812,9 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
auto const& reference = references.at(joined_coll_having_reference);
auto const& count = reference.count;
if (count == 1) {
if (count == 0) {
reference_found = false;
} else if (count == 1) {
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
reference.docs[0]);
if (!op.ok()) {
@ -4815,8 +4823,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
ref_seq_id = op.get();
} else {
return Option<bool>(400, count > 1 ? multiple_references_error_message :
no_references_error_message);
return Option<bool>(400, multiple_references_error_message);
}
}
}
@ -4832,6 +4839,8 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
} else if(field_values[0] == &str_sentinel_value) {
if (sort_fields[0].reference_collection_name.empty()) {
scores[0] = str_sort_index.at(sort_fields[0].name)->rank(seq_id);
} else if (!reference_found) {
scores[0] = adi_tree_t::NOT_FOUND;
} else {
auto& cm = CollectionManager::get_instance();
auto ref_collection = cm.get_collection(sort_fields[0].reference_collection_name);
@ -4896,8 +4905,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// do nothing
}
} else {
auto it = field_values[0]->find(sort_fields[0].reference_collection_name.empty() ? seq_id : ref_seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
if (sort_fields[0].reference_collection_name.empty() || reference_found) {
auto it = field_values[0]->find(sort_fields[0].reference_collection_name.empty() ? seq_id : ref_seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
} else {
scores[0] = default_score;
}
if(scores[0] == INT64_MIN && sort_fields[0].missing_values == sort_by::missing_values_t::first) {
// By default, missing numerical values are always going to be sorted to the end
@ -4914,6 +4927,8 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
}
if(sort_fields.size() > 1) {
auto reference_found = true;
// In case of reference sort_by, we need to get the sort score of the reference doc id.
if (!sort_fields[1].reference_collection_name.empty()) {
auto& sort_field = sort_fields[1];
@ -4925,11 +4940,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// Joined on ref collection
if (references.count(ref_collection_name) > 0) {
if (references.at(ref_collection_name).count == 1) {
auto const& count = references.at(ref_collection_name).count;
if (count == 0) {
reference_found = false;
} else if (count == 1) {
ref_seq_id = references.at(ref_collection_name).docs[0];
} else {
return Option<bool>(400, references.at(ref_collection_name).count > 1 ?
multiple_references_error_message : no_references_error_message);
return Option<bool>(400, multiple_references_error_message);
}
} else {
auto& cm = CollectionManager::get_instance();
@ -4946,11 +4963,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
}
auto const& field_name = get_reference_field_op.get();
if (sort_index.count(field_name) == 0 || sort_index.at(field_name)->count(seq_id) == 0) {
return Option<bool>(400, "Could not find a reference for doc " + std::to_string(seq_id));
if (sort_index.count(field_name) == 0) {
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
reference_found = false;
} else {
ref_seq_id = sort_index.at(field_name)->at(seq_id);
}
ref_seq_id = sort_index.at(field_name)->at(seq_id);
}
// Joined collection has a reference
else {
@ -4981,7 +5000,9 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
auto const& reference = references.at(joined_coll_having_reference);
auto const& count = reference.count;
if (count == 1) {
if (count == 0) {
reference_found = false;
} else if (count == 1) {
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
reference.docs[0]);
if (!op.ok()) {
@ -4990,8 +5011,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
ref_seq_id = op.get();
} else {
return Option<bool>(400, count > 1 ? multiple_references_error_message :
no_references_error_message);
return Option<bool>(400, multiple_references_error_message);
}
}
}
@ -5007,6 +5027,8 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
} else if(field_values[1] == &str_sentinel_value) {
if (sort_fields[1].reference_collection_name.empty()) {
scores[1] = str_sort_index.at(sort_fields[1].name)->rank(seq_id);
} else if (!reference_found) {
scores[1] = adi_tree_t::NOT_FOUND;
} else {
auto& cm = CollectionManager::get_instance();
auto ref_collection = cm.get_collection(sort_fields[1].reference_collection_name);
@ -5072,8 +5094,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
}
} else {
auto it = field_values[1]->find(sort_fields[1].reference_collection_name.empty() ? seq_id : ref_seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
if (sort_fields[1].reference_collection_name.empty() || reference_found) {
auto it = field_values[1]->find(sort_fields[1].reference_collection_name.empty() ? seq_id : ref_seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
} else {
scores[1] = default_score;
}
if(scores[1] == INT64_MIN && sort_fields[1].missing_values == sort_by::missing_values_t::first) {
bool is_asc = (sort_order[1] == -1);
scores[1] = is_asc ? (INT64_MIN + 1) : INT64_MAX;
@ -5086,6 +5113,8 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
}
if(sort_fields.size() > 2) {
auto reference_found = true;
// In case of reference sort_by, we need to get the sort score of the reference doc id.
if (!sort_fields[2].reference_collection_name.empty()) {
auto& sort_field = sort_fields[2];
@ -5097,11 +5126,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// Joined on ref collection
if (references.count(ref_collection_name) > 0) {
if (references.at(ref_collection_name).count == 1) {
auto const& count = references.at(ref_collection_name).count;
if (count == 0) {
reference_found = false;
} else if (count == 1) {
ref_seq_id = references.at(ref_collection_name).docs[0];
} else {
return Option<bool>(400, references.at(ref_collection_name).count > 1 ?
multiple_references_error_message : no_references_error_message);
return Option<bool>(400, multiple_references_error_message);
}
} else {
auto& cm = CollectionManager::get_instance();
@ -5118,11 +5149,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
return Option<bool>(get_reference_field_op.code(), get_reference_field_op.error());
}
auto const& field_name = get_reference_field_op.get();
if (sort_index.count(field_name) == 0 || sort_index.at(field_name)->count(seq_id) == 0) {
return Option<bool>(400, "Could not find a reference for doc " + std::to_string(seq_id));
if (sort_index.count(field_name) == 0) {
return Option<bool>(400, "Could not find `" + field_name + "` in sort_index.");
} else if (sort_index.at(field_name)->count(seq_id) == 0) {
reference_found = false;
} else {
ref_seq_id = sort_index.at(field_name)->at(seq_id);
}
ref_seq_id = sort_index.at(field_name)->at(seq_id);
}
// Joined collection has a reference
else {
@ -5153,7 +5186,9 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
auto const& reference = references.at(joined_coll_having_reference);
auto const& count = reference.count;
if (count == 1) {
if (count == 0) {
reference_found = false;
} else if (count == 1) {
auto op = joined_collection->get_sort_index_value_with_lock(reference_field_name,
reference.docs[0]);
if (!op.ok()) {
@ -5162,8 +5197,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
ref_seq_id = op.get();
} else {
return Option<bool>(400, count > 1 ? multiple_references_error_message :
no_references_error_message);
return Option<bool>(400, multiple_references_error_message);
}
}
}
@ -5179,6 +5213,8 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
} else if(field_values[2] == &str_sentinel_value) {
if (sort_fields[2].reference_collection_name.empty()) {
scores[2] = str_sort_index.at(sort_fields[2].name)->rank(seq_id);
} else if (!reference_found) {
scores[2] = adi_tree_t::NOT_FOUND;
} else {
auto& cm = CollectionManager::get_instance();
auto ref_collection = cm.get_collection(sort_fields[2].reference_collection_name);
@ -5243,8 +5279,13 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
// do nothing
}
} else {
auto it = field_values[2]->find(sort_fields[2].reference_collection_name.empty() ? seq_id : ref_seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
if (sort_fields[2].reference_collection_name.empty() || reference_found) {
auto it = field_values[2]->find(sort_fields[2].reference_collection_name.empty() ? seq_id : ref_seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
} else {
scores[2] = default_score;
}
if(scores[2] == INT64_MIN && sort_fields[2].missing_values == sort_by::missing_values_t::first) {
bool is_asc = (sort_order[2] == -1);
scores[2] = is_asc ? (INT64_MIN + 1) : INT64_MAX;

View File

@ -5283,6 +5283,112 @@ TEST_F(CollectionJoinTest, SortByReference) {
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_FALSE(search_op.ok());
ASSERT_EQ("Multiple references found to sort by on `Customers.product_price`.", search_op.error());
schema_json =
R"({
"name": "Ads",
"fields": [
{"name": "id", "type": "string"}
]
})"_json;
documents = {
R"({
"id": "ad_a"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
schema_json =
R"({
"name": "Structures",
"fields": [
{"name": "id", "type": "string"},
{"name": "name", "type": "string", "sort": true}
]
})"_json;
documents = {
R"({
"id": "struct_a",
"name": "foo"
})"_json,
R"({
"id": "struct_b",
"name": "bar"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
schema_json =
R"({
"name": "Candidates",
"fields": [
{"name": "structure", "type": "string", "reference": "Structures.id", "optional": true},
{"name": "ad", "type": "string", "reference": "Ads.id", "optional": true}
]
})"_json;
documents = {
R"({
"structure": "struct_a"
})"_json,
R"({
"ad": "ad_a"
})"_json,
R"({
"structure": "struct_b"
})"_json
};
collection_create_op = collectionManager.create_collection(schema_json);
ASSERT_TRUE(collection_create_op.ok());
for (auto const &json: documents) {
auto add_op = collection_create_op.get()->add(json.dump());
if (!add_op.ok()) {
LOG(INFO) << add_op.error();
}
ASSERT_TRUE(add_op.ok());
}
req_params = {
{"collection", "Candidates"},
{"q", "*"},
{"filter_by", "$Ads(id:*) || $Structures(id:*)"},
{"sort_by", "$Structures(name: asc)"}
};
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_TRUE(search_op.ok());
res_obj = nlohmann::json::parse(json_res);
ASSERT_EQ(3, res_obj["found"].get<size_t>());
ASSERT_EQ(3, res_obj["hits"].size());
ASSERT_EQ("2", res_obj["hits"][0]["document"].at("id"));
ASSERT_EQ("bar", res_obj["hits"][0]["document"]["Structures"].at("name"));
ASSERT_EQ(0, res_obj["hits"][0]["document"].count("Ads"));
ASSERT_EQ("0", res_obj["hits"][1]["document"].at("id"));
ASSERT_EQ("foo", res_obj["hits"][1]["document"]["Structures"].at("name"));
ASSERT_EQ(0, res_obj["hits"][1]["document"].count("Ads"));
ASSERT_EQ("1", res_obj["hits"][2]["document"].at("id"));
ASSERT_EQ(0, res_obj["hits"][2]["document"].count("Structures"));
ASSERT_EQ(1, res_obj["hits"][2]["document"].count("Ads"));
}
TEST_F(CollectionJoinTest, FilterByReferenceAlias) {