mirror of
https://github.com/typesense/typesense.git
synced 2025-05-18 20:52:50 +08:00
Implement optional filtering via sort_by clause.
This commit is contained in:
parent
4c6eac9840
commit
db3c014463
@ -442,6 +442,7 @@ namespace sort_field_const {
|
||||
static const std::string desc = "DESC";
|
||||
|
||||
static const std::string text_match = "_text_match";
|
||||
static const std::string eval = "_eval";
|
||||
static const std::string seq_id = "_seq_id";
|
||||
|
||||
static const std::string exclude_radius = "exclude_radius";
|
||||
@ -457,6 +458,12 @@ struct sort_by {
|
||||
normal,
|
||||
};
|
||||
|
||||
struct eval_t {
|
||||
std::vector<filter> filters;
|
||||
uint32_t* ids = nullptr;
|
||||
uint32_t size = 0;
|
||||
};
|
||||
|
||||
std::string name;
|
||||
std::string order;
|
||||
|
||||
@ -469,6 +476,7 @@ struct sort_by {
|
||||
uint32_t geo_precision;
|
||||
|
||||
missing_values_t missing_values;
|
||||
eval_t eval;
|
||||
|
||||
sort_by(const std::string & name, const std::string & order):
|
||||
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
|
||||
|
@ -517,6 +517,7 @@ private:
|
||||
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> seq_id_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> eval_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
|
||||
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
|
||||
|
||||
@ -565,7 +566,7 @@ private:
|
||||
const field& the_field, const std::string& field_name,
|
||||
const uint32_t *filter_ids, size_t filter_ids_length,
|
||||
const std::vector<uint32_t>& curated_ids,
|
||||
const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by> & sort_fields,
|
||||
int last_typo,
|
||||
int max_typos,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries,
|
||||
@ -619,7 +620,7 @@ private:
|
||||
const uint32_t* filter_ids, size_t filter_ids_length,
|
||||
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
|
||||
const std::vector<uint32_t>& curated_ids,
|
||||
const std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
|
||||
std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries,
|
||||
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
|
||||
uint32_t** all_result_ids,
|
||||
@ -788,7 +789,7 @@ public:
|
||||
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
|
||||
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
|
||||
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
|
||||
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
|
||||
const size_t per_page,
|
||||
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
|
||||
@ -878,7 +879,7 @@ public:
|
||||
uint32_t& filter_ids_length, const std::vector<uint32_t>& curated_ids_sorted) const;
|
||||
|
||||
void populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
const std::vector<sort_by>& sort_fields_std,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const;
|
||||
|
||||
static void remove_matched_tokens(std::vector<std::string>& tokens, const std::set<std::string>& rule_token_set) ;
|
||||
@ -1040,8 +1041,10 @@ public:
|
||||
void compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
|
||||
size_t filter_index,
|
||||
int64_t max_field_match_score,
|
||||
int64_t* scores, int64_t& match_score_index) const;
|
||||
int64_t* scores,
|
||||
int64_t& match_score_index) const;
|
||||
|
||||
void
|
||||
process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
|
@ -523,7 +523,10 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
|
||||
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by>& sort_fields_std) const {
|
||||
|
||||
for(const sort_by& _sort_field: sort_fields) {
|
||||
size_t num_sort_expressions = 0;
|
||||
|
||||
for(size_t i = 0; i < sort_fields.size(); i++) {
|
||||
const sort_by& _sort_field = sort_fields[i];
|
||||
sort_by sort_field_std(_sort_field.name, _sort_field.order);
|
||||
|
||||
if(sort_field_std.name.back() == ')') {
|
||||
@ -551,6 +554,19 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
sort_field_std.name = actual_field_name;
|
||||
sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
|
||||
|
||||
} else if(actual_field_name == sort_field_const::eval) {
|
||||
const std::string& filter_exp = sort_field_std.name.substr(paran_start + 1,
|
||||
sort_field_std.name.size() - paran_start -
|
||||
2);
|
||||
Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
|
||||
store, "", sort_field_std.eval.filters);
|
||||
if(!parse_filter_op.ok()) {
|
||||
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
|
||||
}
|
||||
|
||||
sort_field_std.name = actual_field_name;
|
||||
num_sort_expressions++;
|
||||
|
||||
} else {
|
||||
if(field_it == search_schema.end()) {
|
||||
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
|
||||
@ -674,7 +690,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
}
|
||||
}
|
||||
|
||||
if(sort_field_std.name != sort_field_const::text_match) {
|
||||
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval) {
|
||||
const auto field_it = search_schema.find(sort_field_std.name);
|
||||
if(field_it == search_schema.end() || !field_it->second.sort || !field_it->second.index) {
|
||||
std::string error = "Could not find a field named `" + sort_field_std.name +
|
||||
@ -725,6 +741,11 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
|
||||
return Option<bool>(422, message);
|
||||
}
|
||||
|
||||
if(num_sort_expressions > 1) {
|
||||
std::string message = "Only one sorting eval expression is allowed.";
|
||||
return Option<bool>(422, message);
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
|
@ -37,6 +37,7 @@
|
||||
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::text_match_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
|
||||
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
|
||||
|
||||
@ -1306,7 +1307,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
|
||||
const uint32_t* filter_ids, size_t filter_ids_length,
|
||||
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
|
||||
const std::vector<uint32_t>& curated_ids,
|
||||
const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by> & sort_fields,
|
||||
std::vector<token_candidates> & token_candidates_vec,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries,
|
||||
Topster* topster,
|
||||
@ -2144,8 +2145,9 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<sort_by> sort_fields;
|
||||
search_field(0, window_tokens, nullptr, 0, num_toks_dropped, field_it->second, field_name,
|
||||
nullptr, 0, {}, {}, -1, 0, searched_queries, topster, groups_processed,
|
||||
nullptr, 0, {}, sort_fields, -1, 0, searched_queries, topster, groups_processed,
|
||||
&result_ids, result_ids_len, field_num_results, 0, group_by_fields,
|
||||
false, 4, query_hashes, token_order, false, 0, 0, false, -1, 3, 7, 4);
|
||||
|
||||
@ -2280,7 +2282,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name
|
||||
void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
|
||||
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
|
||||
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
|
||||
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
|
||||
const size_t per_page,
|
||||
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
|
||||
@ -3109,6 +3111,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
|
||||
}
|
||||
|
||||
std::vector<uint32_t> result_ids;
|
||||
size_t filter_index = 0;
|
||||
|
||||
or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector<or_iterator_t>& its) {
|
||||
//LOG(INFO) << "seq_id: " << seq_id;
|
||||
@ -3169,7 +3172,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
|
||||
int64_t scores[3] = {0};
|
||||
int64_t match_score_index = -1;
|
||||
|
||||
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
|
||||
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
|
||||
max_field_match_score, scores, match_score_index);
|
||||
|
||||
size_t query_len = query_tokens.size();
|
||||
@ -3238,7 +3241,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
|
||||
void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
|
||||
const std::vector<size_t>& geopoint_indices,
|
||||
uint32_t seq_id, int64_t max_field_match_score,
|
||||
uint32_t seq_id, size_t filter_index, int64_t max_field_match_score,
|
||||
int64_t* scores, int64_t& match_score_index) const {
|
||||
|
||||
int64_t geopoint_distances[3];
|
||||
@ -3317,6 +3320,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
scores[0] = -scores[0];
|
||||
}
|
||||
}
|
||||
} else if(field_values[0] == &eval_sentinel_value) {
|
||||
// Returns iterator to the first element that is >= to value or last if no such element is found.
|
||||
bool found = false;
|
||||
if (filter_index == 0 || filter_index < sort_fields[0].eval.size) {
|
||||
size_t found_index = std::lower_bound(sort_fields[0].eval.ids + filter_index,
|
||||
sort_fields[0].eval.ids + sort_fields[0].eval.size, seq_id) -
|
||||
sort_fields[0].eval.ids;
|
||||
|
||||
if (found_index != sort_fields[0].eval.size && sort_fields[0].eval.ids[found_index] == seq_id) {
|
||||
filter_index = found_index + 1;
|
||||
found = true;
|
||||
}
|
||||
|
||||
filter_index = found_index;
|
||||
}
|
||||
|
||||
scores[0] = int64_t(found);
|
||||
} else {
|
||||
auto it = field_values[0]->find(seq_id);
|
||||
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
|
||||
@ -3356,6 +3376,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
scores[1] = -scores[1];
|
||||
}
|
||||
}
|
||||
} else if(field_values[1] == &eval_sentinel_value) {
|
||||
// Returns iterator to the first element that is >= to value or last if no such element is found.
|
||||
bool found = false;
|
||||
if (filter_index == 0 || filter_index < sort_fields[1].eval.size) {
|
||||
size_t found_index = std::lower_bound(sort_fields[1].eval.ids + filter_index,
|
||||
sort_fields[1].eval.ids + sort_fields[1].eval.size, seq_id) -
|
||||
sort_fields[1].eval.ids;
|
||||
|
||||
if (found_index != sort_fields[1].eval.size && sort_fields[1].eval.ids[found_index] == seq_id) {
|
||||
filter_index = found_index + 1;
|
||||
found = true;
|
||||
}
|
||||
|
||||
filter_index = found_index;
|
||||
}
|
||||
|
||||
scores[1] = int64_t(found);
|
||||
} else {
|
||||
auto it = field_values[1]->find(seq_id);
|
||||
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
|
||||
@ -3391,6 +3428,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
|
||||
scores[2] = -scores[2];
|
||||
}
|
||||
}
|
||||
} else if(field_values[2] == &eval_sentinel_value) {
|
||||
// Returns iterator to the first element that is >= to value or last if no such element is found.
|
||||
bool found = false;
|
||||
if (filter_index == 0 || filter_index < sort_fields[2].eval.size) {
|
||||
size_t found_index = std::lower_bound(sort_fields[2].eval.ids + filter_index,
|
||||
sort_fields[2].eval.ids + sort_fields[2].eval.size, seq_id) -
|
||||
sort_fields[2].eval.ids;
|
||||
|
||||
if (found_index != sort_fields[2].eval.size && sort_fields[2].eval.ids[found_index] == seq_id) {
|
||||
filter_index = found_index + 1;
|
||||
found = true;
|
||||
}
|
||||
|
||||
filter_index = found_index;
|
||||
}
|
||||
|
||||
scores[0] = int64_t(found);
|
||||
} else {
|
||||
auto it = field_values[2]->find(seq_id);
|
||||
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
|
||||
@ -3598,6 +3652,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
|
||||
}
|
||||
|
||||
bool field_is_array = search_schema.at(the_fields[field_id].name).is_array();
|
||||
size_t filter_index = 0;
|
||||
|
||||
for(size_t i = 0; i < raw_infix_ids_length; i++) {
|
||||
auto seq_id = raw_infix_ids[i];
|
||||
@ -3609,7 +3664,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
|
||||
int64_t scores[3] = {0};
|
||||
int64_t match_score_index = 0;
|
||||
|
||||
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
|
||||
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
|
||||
100, scores, match_score_index);
|
||||
|
||||
uint64_t distinct_id = seq_id;
|
||||
@ -3756,10 +3811,11 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
|
||||
size_t field_num_results = 0;
|
||||
std::set<uint64> query_hashes;
|
||||
size_t num_toks_dropped = 0;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
search_field(0, qtokens, nullptr, 0, num_toks_dropped,
|
||||
facet_field, facet_field.faceted_name(),
|
||||
all_result_ids, all_result_ids_len, {}, {}, -1, facet_query_num_typos, searched_queries, topster,
|
||||
all_result_ids, all_result_ids_len, {}, sort_fields, -1, facet_query_num_typos, searched_queries, topster,
|
||||
groups_processed, &field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields,
|
||||
false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, -1, 3, 1000, max_candidates);
|
||||
|
||||
@ -3927,6 +3983,8 @@ void Index::search_wildcard(const std::vector<filter>& filters,
|
||||
search_stop_ms = parent_search_stop_ms;
|
||||
search_cutoff = parent_search_cutoff;
|
||||
|
||||
size_t filter_index = 0;
|
||||
|
||||
for(size_t i = 0; i < batch_res_len; i++) {
|
||||
const uint32_t seq_id = batch_result_ids[i];
|
||||
int64_t match_score = 0;
|
||||
@ -3937,7 +3995,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
|
||||
int64_t scores[3] = {0};
|
||||
int64_t match_score_index = 0;
|
||||
|
||||
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
|
||||
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
|
||||
100, scores, match_score_index);
|
||||
|
||||
uint64_t distinct_id = seq_id;
|
||||
@ -3989,7 +4047,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
|
||||
}
|
||||
|
||||
void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
|
||||
const std::vector<sort_by>& sort_fields_std,
|
||||
std::vector<sort_by>& sort_fields_std,
|
||||
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const {
|
||||
for (size_t i = 0; i < sort_fields_std.size(); i++) {
|
||||
sort_order[i] = 1;
|
||||
@ -4001,6 +4059,9 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
|
||||
field_values[i] = &text_match_sentinel_value;
|
||||
} else if (sort_fields_std[i].name == sort_field_const::seq_id) {
|
||||
field_values[i] = &seq_id_sentinel_value;
|
||||
} else if (sort_fields_std[i].name == sort_field_const::eval) {
|
||||
field_values[i] = &eval_sentinel_value;
|
||||
do_filtering(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filters, true);
|
||||
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
|
||||
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
|
||||
geopoint_indices.push_back(i);
|
||||
@ -4035,7 +4096,7 @@ void Index::search_field(const uint8_t & field_id,
|
||||
const field& the_field, const std::string& field_name, // to handle faceted index
|
||||
const uint32_t *filter_ids, size_t filter_ids_length,
|
||||
const std::vector<uint32_t>& curated_ids,
|
||||
const std::vector<sort_by> & sort_fields,
|
||||
std::vector<sort_by> & sort_fields,
|
||||
const int last_typo,
|
||||
const int max_typos,
|
||||
std::vector<std::vector<art_leaf*>> & searched_queries,
|
||||
|
@ -748,6 +748,14 @@ TEST_F(CollectionManagerTest, ParseSortByClause) {
|
||||
ASSERT_EQ("_text_match(buckets: 10)", sort_fields[0].name);
|
||||
ASSERT_EQ("ASC", sort_fields[0].order);
|
||||
|
||||
sort_fields.clear();
|
||||
sort_by_parsed = CollectionManager::parse_sort_by_str("_eval(brand:nike && foo:bar):DESC,points:desc ", sort_fields);
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
ASSERT_EQ("_eval(brand:nike && foo:bar)", sort_fields[0].name);
|
||||
ASSERT_EQ("DESC", sort_fields[0].order);
|
||||
ASSERT_EQ("points", sort_fields[1].name);
|
||||
ASSERT_EQ("DESC", sort_fields[1].order);
|
||||
|
||||
sort_fields.clear();
|
||||
sort_by_parsed = CollectionManager::parse_sort_by_str("", sort_fields);
|
||||
ASSERT_TRUE(sort_by_parsed);
|
||||
|
@ -1810,3 +1810,242 @@ TEST_F(CollectionSortingTest, DisallowSortingOnNonIndexedIntegerField) {
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingWildcard) {
|
||||
std::string coll_schema = R"(
|
||||
{
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "title", "type": "string" },
|
||||
{"name": "brand", "type": "string" },
|
||||
{"name": "points", "type": "int32" }
|
||||
]
|
||||
}
|
||||
)";
|
||||
|
||||
nlohmann::json schema = nlohmann::json::parse(coll_schema);
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
for(size_t i = 0; i < 5; i++) {
|
||||
nlohmann::json doc;
|
||||
doc["title"] = "Title " + std::to_string(i);
|
||||
doc["points"] = i;
|
||||
|
||||
doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
auto sort_fields = {
|
||||
sort_by("_eval(brand:nike)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
auto results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// compound query
|
||||
sort_fields = {
|
||||
sort_by("_eval(brand:nike && points:0)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
expected_ids = {"0", "4", "3", "2", "1"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// when no results are found for eval query
|
||||
sort_fields = {
|
||||
sort_by("_eval(brand:foobar)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
expected_ids = {"4", "3", "2", "1", "0"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// bad syntax for eval query
|
||||
sort_fields = {
|
||||
sort_by("_eval(brandnike || points:0)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
auto res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error());
|
||||
|
||||
// more bad syntax!
|
||||
sort_fields = {
|
||||
sort_by(")", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error());
|
||||
|
||||
// don't allow multiple sorting eval expressions
|
||||
sort_fields = {
|
||||
sort_by("_eval(brand: nike || points:0)", "DESC"),
|
||||
sort_by("_eval(brand: nike || points:0)", "DESC"),
|
||||
};
|
||||
|
||||
res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Only one sorting eval expression is allowed.", res_op.error());
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSearch) {
|
||||
std::string coll_schema = R"(
|
||||
{
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "title", "type": "string" },
|
||||
{"name": "brand", "type": "string" },
|
||||
{"name": "points", "type": "int32" }
|
||||
]
|
||||
}
|
||||
)";
|
||||
|
||||
nlohmann::json schema = nlohmann::json::parse(coll_schema);
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
for(size_t i = 0; i < 5; i++) {
|
||||
nlohmann::json doc;
|
||||
doc["title"] = "Title " + std::to_string(i);
|
||||
doc["points"] = i;
|
||||
|
||||
doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
auto sort_fields = {
|
||||
sort_by("_eval(brand:nike)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// compound query
|
||||
sort_fields = {
|
||||
sort_by("_eval(brand:nike && points:0)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
expected_ids = {"0", "4", "3", "2", "1"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// when no results are found for eval query
|
||||
sort_fields = {
|
||||
sort_by("_eval(brand:foobar)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
expected_ids = {"4", "3", "2", "1", "0"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// bad syntax for eval query
|
||||
sort_fields = {
|
||||
sort_by("_eval(brandnike || points:0)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
auto res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error());
|
||||
|
||||
// more bad syntax!
|
||||
sort_fields = {
|
||||
sort_by(")", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
|
||||
ASSERT_FALSE(res_op.ok());
|
||||
ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error());
|
||||
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) {
|
||||
std::string coll_schema = R"(
|
||||
{
|
||||
"name": "coll1",
|
||||
"fields": [
|
||||
{"name": "title", "type": "string" },
|
||||
{"name": "brand", "type": "string" },
|
||||
{"name": "points", "type": "int32" },
|
||||
{"name": "val", "type": "int32" }
|
||||
]
|
||||
}
|
||||
)";
|
||||
|
||||
nlohmann::json schema = nlohmann::json::parse(coll_schema);
|
||||
Collection* coll1 = collectionManager.create_collection(schema).get();
|
||||
|
||||
for(size_t i = 0; i < 5; i++) {
|
||||
nlohmann::json doc;
|
||||
doc["title"] = "Title " + std::to_string(i);
|
||||
doc["val"] = 0;
|
||||
doc["points"] = i;
|
||||
doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
|
||||
ASSERT_TRUE(coll1->add(doc.dump()).ok());
|
||||
}
|
||||
|
||||
auto sort_fields = {
|
||||
sort_by("val", "DESC"),
|
||||
sort_by("_eval(brand:nike)", "DESC"),
|
||||
sort_by("points", "DESC"),
|
||||
};
|
||||
|
||||
auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
|
||||
// eval expression as 3rd sorting argument
|
||||
sort_fields = {
|
||||
sort_by("val", "DESC"),
|
||||
sort_by("val", "DESC"),
|
||||
sort_by("_eval(brand:nike)", "DESC"),
|
||||
};
|
||||
|
||||
results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
for(size_t i = 0; i > expected_ids.size(); i++) {
|
||||
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user