Implement optional filtering via sort_by clause.

This commit is contained in:
Kishore Nallan 2022-08-04 14:47:02 +05:30
parent 4c6eac9840
commit db3c014463
6 changed files with 357 additions and 17 deletions

View File

@ -442,6 +442,7 @@ namespace sort_field_const {
static const std::string desc = "DESC";
static const std::string text_match = "_text_match";
static const std::string eval = "_eval";
static const std::string seq_id = "_seq_id";
static const std::string exclude_radius = "exclude_radius";
@ -457,6 +458,12 @@ struct sort_by {
normal,
};
// State carried by a sort_by clause of the form `_eval(<filter expression>)`.
struct eval_t {
    // Filter clauses parsed out of the text inside the `_eval(...)` parentheses.
    std::vector<filter> filters;

    // Buffer of document ids produced by evaluating `filters`, plus its element count.
    // NOTE(review): the scoring code binary-searches this buffer, so it is presumably
    // sorted in ascending order -- confirm against the routine that fills it.
    // NOTE(review): nothing here releases `ids`; confirm who owns and frees the buffer.
    uint32_t* ids{nullptr};
    uint32_t size{0};
};
std::string name;
std::string order;
@ -469,6 +476,7 @@ struct sort_by {
uint32_t geo_precision;
missing_values_t missing_values;
eval_t eval;
sort_by(const std::string & name, const std::string & order):
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),

View File

@ -517,6 +517,7 @@ private:
static spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> seq_id_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> eval_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
@ -565,7 +566,7 @@ private:
const field& the_field, const std::string& field_name,
const uint32_t *filter_ids, size_t filter_ids_length,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields,
std::vector<sort_by> & sort_fields,
int last_typo,
int max_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,
@ -619,7 +620,7 @@ private:
const uint32_t* filter_ids, size_t filter_ids_length,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
uint32_t** all_result_ids,
@ -788,7 +789,7 @@ public:
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
const size_t per_page,
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
@ -878,7 +879,7 @@ public:
uint32_t& filter_ids_length, const std::vector<uint32_t>& curated_ids_sorted) const;
void populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
const std::vector<sort_by>& sort_fields_std,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const;
static void remove_matched_tokens(std::vector<std::string>& tokens, const std::set<std::string>& rule_token_set) ;
@ -1040,8 +1041,10 @@ public:
void compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
size_t filter_index,
int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index) const;
int64_t* scores,
int64_t& match_score_index) const;
void
process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,

View File

@ -523,7 +523,10 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std) const {
for(const sort_by& _sort_field: sort_fields) {
size_t num_sort_expressions = 0;
for(size_t i = 0; i < sort_fields.size(); i++) {
const sort_by& _sort_field = sort_fields[i];
sort_by sort_field_std(_sort_field.name, _sort_field.order);
if(sort_field_std.name.back() == ')') {
@ -551,6 +554,19 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_field_std.name = actual_field_name;
sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
} else if(actual_field_name == sort_field_const::eval) {
const std::string& filter_exp = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() - paran_start -
2);
Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", sort_field_std.eval.filters);
if(!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}
sort_field_std.name = actual_field_name;
num_sort_expressions++;
} else {
if(field_it == search_schema.end()) {
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
@ -674,7 +690,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
}
}
if(sort_field_std.name != sort_field_const::text_match) {
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it->second.sort || !field_it->second.index) {
std::string error = "Could not find a field named `" + sort_field_std.name +
@ -725,6 +741,11 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
return Option<bool>(422, message);
}
if(num_sort_expressions > 1) {
std::string message = "Only one sorting eval expression is allowed.";
return Option<bool>(422, message);
}
return Option<bool>(true);
}

View File

@ -37,6 +37,7 @@
spp::sparse_hash_map<uint32_t, int64_t> Index::text_match_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
@ -1306,7 +1307,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
const uint32_t* filter_ids, size_t filter_ids_length,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields,
std::vector<sort_by> & sort_fields,
std::vector<token_candidates> & token_candidates_vec,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster,
@ -2144,8 +2145,9 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
continue;
}
std::vector<sort_by> sort_fields;
search_field(0, window_tokens, nullptr, 0, num_toks_dropped, field_it->second, field_name,
nullptr, 0, {}, {}, -1, 0, searched_queries, topster, groups_processed,
nullptr, 0, {}, sort_fields, -1, 0, searched_queries, topster, groups_processed,
&result_ids, result_ids_len, field_num_results, 0, group_by_fields,
false, 4, query_hashes, token_order, false, 0, 0, false, -1, 3, 7, 4);
@ -2280,7 +2282,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name
void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
const size_t per_page,
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
@ -3109,6 +3111,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
}
std::vector<uint32_t> result_ids;
size_t filter_index = 0;
or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector<or_iterator_t>& its) {
//LOG(INFO) << "seq_id: " << seq_id;
@ -3169,7 +3172,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
int64_t scores[3] = {0};
int64_t match_score_index = -1;
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
max_field_match_score, scores, match_score_index);
size_t query_len = query_tokens.size();
@ -3238,7 +3241,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
uint32_t seq_id, int64_t max_field_match_score,
uint32_t seq_id, size_t filter_index, int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index) const {
int64_t geopoint_distances[3];
@ -3317,6 +3320,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[0] = -scores[0];
}
}
} else if(field_values[0] == &eval_sentinel_value) {
// Returns iterator to the first element that is >= to value or last if no such element is found.
bool found = false;
if (filter_index == 0 || filter_index < sort_fields[0].eval.size) {
size_t found_index = std::lower_bound(sort_fields[0].eval.ids + filter_index,
sort_fields[0].eval.ids + sort_fields[0].eval.size, seq_id) -
sort_fields[0].eval.ids;
if (found_index != sort_fields[0].eval.size && sort_fields[0].eval.ids[found_index] == seq_id) {
filter_index = found_index + 1;
found = true;
}
filter_index = found_index;
}
scores[0] = int64_t(found);
} else {
auto it = field_values[0]->find(seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
@ -3356,6 +3376,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[1] = -scores[1];
}
}
} else if(field_values[1] == &eval_sentinel_value) {
// Returns iterator to the first element that is >= to value or last if no such element is found.
bool found = false;
if (filter_index == 0 || filter_index < sort_fields[1].eval.size) {
size_t found_index = std::lower_bound(sort_fields[1].eval.ids + filter_index,
sort_fields[1].eval.ids + sort_fields[1].eval.size, seq_id) -
sort_fields[1].eval.ids;
if (found_index != sort_fields[1].eval.size && sort_fields[1].eval.ids[found_index] == seq_id) {
filter_index = found_index + 1;
found = true;
}
filter_index = found_index;
}
scores[1] = int64_t(found);
} else {
auto it = field_values[1]->find(seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
@ -3391,6 +3428,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[2] = -scores[2];
}
}
} else if(field_values[2] == &eval_sentinel_value) {
// Third sort slot is an `_eval(...)` expression: the score is 1 when `seq_id` is present in the
// evaluated id list of that expression and 0 otherwise.
// std::lower_bound returns an iterator to the first element that is >= value, or last if there is none.
bool found = false;
if (filter_index == 0 || filter_index < sort_fields[2].eval.size) {
size_t found_index = std::lower_bound(sort_fields[2].eval.ids + filter_index,
sort_fields[2].eval.ids + sort_fields[2].eval.size, seq_id) -
sort_fields[2].eval.ids;
if (found_index != sort_fields[2].eval.size && sort_fields[2].eval.ids[found_index] == seq_id) {
filter_index = found_index + 1;
found = true;
}
// NOTE(review): this unconditionally overwrites the value assigned just above when found.
filter_index = found_index;
}
// Write into the third slot. Storing into scores[0] here would clobber the first sort key
// and leave the third one at its zero default.
scores[2] = int64_t(found);
} else {
auto it = field_values[2]->find(seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
@ -3598,6 +3652,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
}
bool field_is_array = search_schema.at(the_fields[field_id].name).is_array();
size_t filter_index = 0;
for(size_t i = 0; i < raw_infix_ids_length; i++) {
auto seq_id = raw_infix_ids[i];
@ -3609,7 +3664,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
int64_t scores[3] = {0};
int64_t match_score_index = 0;
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
100, scores, match_score_index);
uint64_t distinct_id = seq_id;
@ -3756,10 +3811,11 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
size_t field_num_results = 0;
std::set<uint64> query_hashes;
size_t num_toks_dropped = 0;
std::vector<sort_by> sort_fields;
search_field(0, qtokens, nullptr, 0, num_toks_dropped,
facet_field, facet_field.faceted_name(),
all_result_ids, all_result_ids_len, {}, {}, -1, facet_query_num_typos, searched_queries, topster,
all_result_ids, all_result_ids_len, {}, sort_fields, -1, facet_query_num_typos, searched_queries, topster,
groups_processed, &field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields,
false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, -1, 3, 1000, max_candidates);
@ -3927,6 +3983,8 @@ void Index::search_wildcard(const std::vector<filter>& filters,
search_stop_ms = parent_search_stop_ms;
search_cutoff = parent_search_cutoff;
size_t filter_index = 0;
for(size_t i = 0; i < batch_res_len; i++) {
const uint32_t seq_id = batch_result_ids[i];
int64_t match_score = 0;
@ -3937,7 +3995,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
int64_t scores[3] = {0};
int64_t match_score_index = 0;
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
100, scores, match_score_index);
uint64_t distinct_id = seq_id;
@ -3989,7 +4047,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
}
void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
const std::vector<sort_by>& sort_fields_std,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const {
for (size_t i = 0; i < sort_fields_std.size(); i++) {
sort_order[i] = 1;
@ -4001,6 +4059,9 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
field_values[i] = &text_match_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::seq_id) {
field_values[i] = &seq_id_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
do_filtering(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filters, true);
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
geopoint_indices.push_back(i);
@ -4035,7 +4096,7 @@ void Index::search_field(const uint8_t & field_id,
const field& the_field, const std::string& field_name, // to handle faceted index
const uint32_t *filter_ids, size_t filter_ids_length,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields,
std::vector<sort_by> & sort_fields,
const int last_typo,
const int max_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,

View File

@ -748,6 +748,14 @@ TEST_F(CollectionManagerTest, ParseSortByClause) {
ASSERT_EQ("_text_match(buckets: 10)", sort_fields[0].name);
ASSERT_EQ("ASC", sort_fields[0].order);
sort_fields.clear();
sort_by_parsed = CollectionManager::parse_sort_by_str("_eval(brand:nike && foo:bar):DESC,points:desc ", sort_fields);
ASSERT_TRUE(sort_by_parsed);
ASSERT_EQ("_eval(brand:nike && foo:bar)", sort_fields[0].name);
ASSERT_EQ("DESC", sort_fields[0].order);
ASSERT_EQ("points", sort_fields[1].name);
ASSERT_EQ("DESC", sort_fields[1].order);
sort_fields.clear();
sort_by_parsed = CollectionManager::parse_sort_by_str("", sort_fields);
ASSERT_TRUE(sort_by_parsed);

View File

@ -1810,3 +1810,242 @@ TEST_F(CollectionSortingTest, DisallowSortingOnNonIndexedIntegerField) {
collectionManager.drop_collection("coll1");
}
// Wildcard search (`*`) with an `_eval(<filter>)` expression as the leading sort_by clause:
// documents matching the filter must be ranked ahead of the rest, with `points` breaking ties.
// Also covers malformed eval expressions and the "only one eval expression" restriction.
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingWildcard) {
    std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "brand", "type": "string" },
{"name": "points", "type": "int32" }
]
}
)";

    nlohmann::json schema = nlohmann::json::parse(coll_schema);
    Collection* coll1 = collectionManager.create_collection(schema).get();

    // Documents 0 and 3 are "Nike", the rest are "Adidas"; points equals the document index.
    for(size_t i = 0; i < 5; i++) {
        nlohmann::json doc;
        doc["title"] = "Title " + std::to_string(i);
        doc["points"] = i;
        doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
        ASSERT_TRUE(coll1->add(doc.dump()).ok());
    }

    // A vector is used deliberately: `auto x = {...}` deduces std::initializer_list, and
    // re-assigning such a variable leaves it pointing at a destroyed temporary array.
    std::vector<sort_by> sort_fields = {
        sort_by("_eval(brand:nike)", "DESC"),
        sort_by("points", "DESC"),
    };

    auto results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // compound query
    sort_fields = {
        sort_by("_eval(brand:nike && points:0)", "DESC"),
        sort_by("points", "DESC"),
    };

    results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    expected_ids = {"0", "4", "3", "2", "1"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // when no results are found for eval query
    sort_fields = {
        sort_by("_eval(brand:foobar)", "DESC"),
        sort_by("points", "DESC"),
    };

    results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    expected_ids = {"4", "3", "2", "1", "0"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // bad syntax for eval query
    sort_fields = {
        sort_by("_eval(brandnike || points:0)", "DESC"),
        sort_by("points", "DESC"),
    };

    auto res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error());

    // more bad syntax!
    sort_fields = {
        sort_by(")", "DESC"),
        sort_by("points", "DESC"),
    };

    res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error());

    // don't allow multiple sorting eval expressions
    sort_fields = {
        sort_by("_eval(brand: nike || points:0)", "DESC"),
        sort_by("_eval(brand: nike || points:0)", "DESC"),
    };

    res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Only one sorting eval expression is allowed.", res_op.error());

    collectionManager.drop_collection("coll1");
}
// Keyword search (query "title") with an `_eval(<filter>)` expression as the leading sort_by clause:
// documents matching the filter must be ranked ahead of the rest, with `points` breaking ties.
// Also covers malformed eval expressions.
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSearch) {
    std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "brand", "type": "string" },
{"name": "points", "type": "int32" }
]
}
)";

    nlohmann::json schema = nlohmann::json::parse(coll_schema);
    Collection* coll1 = collectionManager.create_collection(schema).get();

    // Documents 0 and 3 are "Nike", the rest are "Adidas"; points equals the document index.
    for(size_t i = 0; i < 5; i++) {
        nlohmann::json doc;
        doc["title"] = "Title " + std::to_string(i);
        doc["points"] = i;
        doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
        ASSERT_TRUE(coll1->add(doc.dump()).ok());
    }

    // A vector is used deliberately: `auto x = {...}` deduces std::initializer_list, and
    // re-assigning such a variable leaves it pointing at a destroyed temporary array.
    std::vector<sort_by> sort_fields = {
        sort_by("_eval(brand:nike)", "DESC"),
        sort_by("points", "DESC"),
    };

    auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // compound query
    sort_fields = {
        sort_by("_eval(brand:nike && points:0)", "DESC"),
        sort_by("points", "DESC"),
    };

    results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    expected_ids = {"0", "4", "3", "2", "1"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // when no results are found for eval query
    sort_fields = {
        sort_by("_eval(brand:foobar)", "DESC"),
        sort_by("points", "DESC"),
    };

    results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    expected_ids = {"4", "3", "2", "1", "0"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // bad syntax for eval query
    sort_fields = {
        sort_by("_eval(brandnike || points:0)", "DESC"),
        sort_by("points", "DESC"),
    };

    auto res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error());

    // more bad syntax!
    sort_fields = {
        sort_by(")", "DESC"),
        sort_by("points", "DESC"),
    };

    res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
    ASSERT_FALSE(res_op.ok());
    ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error());

    collectionManager.drop_collection("coll1");
}
// Keyword search where the `_eval(<filter>)` expression is the second, and then the third, sort_by
// clause. Every document has the same `val`, so the leading clause(s) tie and the eval expression
// decides the ranking.
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) {
    std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "brand", "type": "string" },
{"name": "points", "type": "int32" },
{"name": "val", "type": "int32" }
]
}
)";

    nlohmann::json schema = nlohmann::json::parse(coll_schema);
    Collection* coll1 = collectionManager.create_collection(schema).get();

    // Documents 0 and 3 are "Nike", the rest are "Adidas"; points equals the document index.
    for(size_t i = 0; i < 5; i++) {
        nlohmann::json doc;
        doc["title"] = "Title " + std::to_string(i);
        doc["val"] = 0;
        doc["points"] = i;
        doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
        ASSERT_TRUE(coll1->add(doc.dump()).ok());
    }

    // A vector is used deliberately: `auto x = {...}` deduces std::initializer_list, and
    // re-assigning such a variable leaves it pointing at a destroyed temporary array.
    std::vector<sort_by> sort_fields = {
        sort_by("val", "DESC"),
        sort_by("_eval(brand:nike)", "DESC"),
        sort_by("points", "DESC"),
    };

    auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // eval expression as 3rd sorting argument
    sort_fields = {
        sort_by("val", "DESC"),
        sort_by("val", "DESC"),
        sort_by("_eval(brand:nike)", "DESC"),
    };

    results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
    ASSERT_EQ(5, results["hits"].size());

    for(size_t i = 0; i < expected_ids.size(); i++) {
        ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
    }

    // Clean up like the sibling tests do so the collection name can be reused.
    collectionManager.drop_collection("coll1");
}