Merge branch 'v0.24-changes' into v0.24-nested

# Conflicts:
#	src/index.cpp
This commit is contained in:
Kishore Nallan 2022-08-05 14:14:52 +05:30
commit 58bdcc2e6f
11 changed files with 592 additions and 25 deletions

View File

@ -483,6 +483,7 @@ namespace sort_field_const {
static const std::string desc = "DESC";
static const std::string text_match = "_text_match";
static const std::string eval = "_eval";
static const std::string seq_id = "_seq_id";
static const std::string exclude_radius = "exclude_radius";
@ -498,6 +499,12 @@ struct sort_by {
normal,
};
struct eval_t {
std::vector<filter> filters;
uint32_t* ids = nullptr;
uint32_t size = 0;
};
std::string name;
std::string order;
@ -510,6 +517,7 @@ struct sort_by {
uint32_t geo_precision;
missing_values_t missing_values;
eval_t eval;
sort_by(const std::string & name, const std::string & order):
name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),

View File

@ -110,6 +110,11 @@ struct override_t {
bool stop_processing = true;
std::string sort_by;
std::string replace_query;
// epoch seconds
int64_t effective_from_ts = -1;
int64_t effective_to_ts = -1;
override_t() = default;
@ -128,9 +133,10 @@ struct override_t {
if(override_json.count("includes") == 0 && override_json.count("excludes") == 0 &&
override_json.count("filter_by") == 0 && override_json.count("sort_by") == 0 &&
override_json.count("remove_matched_tokens") == 0) {
override_json.count("remove_matched_tokens") == 0 &&
override_json.count("replace_query") == 0) {
return Option<bool>(400, "Must contain one of: `includes`, `excludes`, "
"`filter_by`, `sort_by`, `remove_matched_tokens`.");
"`filter_by`, `sort_by`, `remove_matched_tokens`, `replace_query`.");
}
if(override_json.count("includes") != 0) {
@ -242,6 +248,13 @@ struct override_t {
override.sort_by = override_json["sort_by"].get<std::string>();
}
if (override_json.count("replace_query") != 0) {
if(override_json.count("remove_matched_tokens") != 0) {
return Option<bool>(400, "Only one of `replace_query` or `remove_matched_tokens` can be specified.");
}
override.replace_query = override_json["replace_query"].get<std::string>();
}
if(override_json.count("remove_matched_tokens") != 0) {
override.remove_matched_tokens = override_json["remove_matched_tokens"].get<bool>();
} else {
@ -256,6 +269,14 @@ struct override_t {
override.stop_processing = override_json["stop_processing"].get<bool>();
}
if(override_json.count("effective_from_ts") != 0) {
override.effective_from_ts = override_json["effective_from_ts"].get<int64_t>();
}
if(override_json.count("effective_to_ts") != 0) {
override.effective_to_ts = override_json["effective_to_ts"].get<int64_t>();
}
// we have to also detect if it is a dynamic query rule
size_t i = 0;
while(i < override.rule.query.size()) {
@ -308,6 +329,18 @@ struct override_t {
override["sort_by"] = sort_by;
}
if(!replace_query.empty()) {
override["replace_query"] = replace_query;
}
if(effective_from_ts != -1) {
override["effective_from_ts"] = effective_from_ts;
}
if(effective_to_ts != -1) {
override["effective_to_ts"] = effective_to_ts;
}
override["remove_matched_tokens"] = remove_matched_tokens;
override["filter_curated_hits"] = filter_curated_hits;
override["stop_processing"] = stop_processing;
@ -517,6 +550,7 @@ private:
static spp::sparse_hash_map<uint32_t, int64_t> text_match_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> seq_id_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> eval_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> geo_sentinel_value;
static spp::sparse_hash_map<uint32_t, int64_t> str_sentinel_value;
@ -565,7 +599,7 @@ private:
const field& the_field, const std::string& field_name,
const uint32_t *filter_ids, size_t filter_ids_length,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields,
std::vector<sort_by> & sort_fields,
int last_typo,
int max_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,
@ -619,7 +653,7 @@ private:
const uint32_t* filter_ids, size_t filter_ids_length,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
std::vector<sort_by> & sort_fields, std::vector<token_candidates> & token_to_candidates,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster, spp::sparse_hash_set<uint64_t>& groups_processed,
uint32_t** all_result_ids,
@ -788,7 +822,7 @@ public:
void search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
const size_t per_page,
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
@ -878,7 +912,7 @@ public:
uint32_t& filter_ids_length, const std::vector<uint32_t>& curated_ids_sorted) const;
void populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
const std::vector<sort_by>& sort_fields_std,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const;
static void remove_matched_tokens(std::vector<std::string>& tokens, const std::set<std::string>& rule_token_set) ;
@ -1040,8 +1074,10 @@ public:
void compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
size_t filter_index,
int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index) const;
int64_t* scores,
int64_t& match_score_index) const;
void
process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,

View File

@ -36,8 +36,9 @@ private:
int32_t utf8_start_index = 0;
char* normalized_text = nullptr;
// non-deletable singleton
const icu::Normalizer2* nfkd;
// non-deletable singletons
const icu::Normalizer2* nfkd = nullptr;
const icu::Normalizer2* nfkc = nullptr;
icu::Transliterator* transliterator = nullptr;

View File

@ -423,6 +423,15 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
for(const auto& override_kv: overrides) {
const auto& override = override_kv.second;
auto now_epoch = int64_t(std::time(0));
if(override.effective_from_ts != -1 && now_epoch < override.effective_from_ts) {
continue;
}
if(override.effective_to_ts != -1 && now_epoch > override.effective_to_ts) {
continue;
}
// ID-based overrides are applied first as they take precedence over filter-based overrides
if(!override.filter_by.empty()) {
filter_overrides.push_back(&override);
@ -453,7 +462,9 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
}
}
if(override.remove_matched_tokens && override.filter_by.empty()) {
if(!override.replace_query.empty()) {
actual_query = override.replace_query;
} else if(override.remove_matched_tokens && override.filter_by.empty()) {
// don't prematurely remove tokens from query because dynamic filtering will require them
StringUtils::replace_all(query, override.rule.query, "");
StringUtils::trim(query);
@ -495,7 +506,10 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<sort_by> & sort_fields,
std::vector<sort_by>& sort_fields_std) const {
for(const sort_by& _sort_field: sort_fields) {
size_t num_sort_expressions = 0;
for(size_t i = 0; i < sort_fields.size(); i++) {
const sort_by& _sort_field = sort_fields[i];
sort_by sort_field_std(_sort_field.name, _sort_field.order);
if(sort_field_std.name.back() == ')') {
@ -523,6 +537,19 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_field_std.name = actual_field_name;
sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
} else if(actual_field_name == sort_field_const::eval) {
const std::string& filter_exp = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() - paran_start -
2);
Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", sort_field_std.eval.filters);
if(!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}
sort_field_std.name = actual_field_name;
num_sort_expressions++;
} else {
if(field_it == search_schema.end()) {
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
@ -646,7 +673,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
}
}
if(sort_field_std.name != sort_field_const::text_match) {
if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval) {
const auto field_it = search_schema.find(sort_field_std.name);
if(field_it == search_schema.end() || !field_it.value().sort || !field_it.value().index) {
std::string error = "Could not find a field named `" + sort_field_std.name +
@ -697,6 +724,11 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
return Option<bool>(422, message);
}
if(num_sort_expressions > 1) {
std::string message = "Only one sorting eval expression is allowed.";
return Option<bool>(422, message);
}
return Option<bool>(true);
}

View File

@ -37,6 +37,7 @@
spp::sparse_hash_map<uint32_t, int64_t> Index::text_match_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::seq_id_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::eval_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::geo_sentinel_value;
spp::sparse_hash_map<uint32_t, int64_t> Index::str_sentinel_value;
@ -1305,7 +1306,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array,
const uint32_t* filter_ids, size_t filter_ids_length,
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields,
std::vector<sort_by> & sort_fields,
std::vector<token_candidates> & token_candidates_vec,
std::vector<std::vector<art_leaf*>> & searched_queries,
Topster* topster,
@ -2143,8 +2144,9 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string&
continue;
}
std::vector<sort_by> sort_fields;
search_field(0, window_tokens, nullptr, 0, num_toks_dropped, field_it.value(), field_name,
nullptr, 0, {}, {}, -1, 0, searched_queries, topster, groups_processed,
nullptr, 0, {}, sort_fields, -1, 0, searched_queries, topster, groups_processed,
&result_ids, result_ids_len, field_num_results, 0, group_by_fields,
false, 4, query_hashes, token_order, false, 0, 0, false, -1, 3, 7, 4);
@ -2279,7 +2281,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name
void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::vector<search_field_t>& the_fields,
std::vector<filter>& filters, std::vector<facet>& facets, facet_query_t& facet_query,
const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
const std::vector<uint32_t>& excluded_ids, const std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& excluded_ids, std::vector<sort_by>& sort_fields_std,
const std::vector<uint32_t>& num_typos, Topster* topster, Topster* curated_topster,
const size_t per_page,
const size_t page, const token_ordering token_order, const std::vector<bool>& prefixes,
@ -3108,6 +3110,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
}
std::vector<uint32_t> result_ids;
size_t filter_index = 0;
or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector<or_iterator_t>& its) {
//LOG(INFO) << "seq_id: " << seq_id;
@ -3168,7 +3171,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
int64_t scores[3] = {0};
int64_t match_score_index = -1;
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
max_field_match_score, scores, match_score_index);
size_t query_len = query_tokens.size();
@ -3237,7 +3240,7 @@ void Index::search_across_fields(const std::vector<token_t>& query_tokens,
void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const int* sort_order,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices,
uint32_t seq_id, int64_t max_field_match_score,
uint32_t seq_id, size_t filter_index, int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index) const {
int64_t geopoint_distances[3];
@ -3316,6 +3319,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[0] = -scores[0];
}
}
} else if(field_values[0] == &eval_sentinel_value) {
// Returns iterator to the first element that is >= to value or last if no such element is found.
bool found = false;
if (filter_index == 0 || filter_index < sort_fields[0].eval.size) {
size_t found_index = std::lower_bound(sort_fields[0].eval.ids + filter_index,
sort_fields[0].eval.ids + sort_fields[0].eval.size, seq_id) -
sort_fields[0].eval.ids;
if (found_index != sort_fields[0].eval.size && sort_fields[0].eval.ids[found_index] == seq_id) {
filter_index = found_index + 1;
found = true;
}
filter_index = found_index;
}
scores[0] = int64_t(found);
} else {
auto it = field_values[0]->find(seq_id);
scores[0] = (it == field_values[0]->end()) ? default_score : it->second;
@ -3355,6 +3375,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[1] = -scores[1];
}
}
} else if(field_values[1] == &eval_sentinel_value) {
// Returns iterator to the first element that is >= to value or last if no such element is found.
bool found = false;
if (filter_index == 0 || filter_index < sort_fields[1].eval.size) {
size_t found_index = std::lower_bound(sort_fields[1].eval.ids + filter_index,
sort_fields[1].eval.ids + sort_fields[1].eval.size, seq_id) -
sort_fields[1].eval.ids;
if (found_index != sort_fields[1].eval.size && sort_fields[1].eval.ids[found_index] == seq_id) {
filter_index = found_index + 1;
found = true;
}
filter_index = found_index;
}
scores[1] = int64_t(found);
} else {
auto it = field_values[1]->find(seq_id);
scores[1] = (it == field_values[1]->end()) ? default_score : it->second;
@ -3390,6 +3427,23 @@ void Index::compute_sort_scores(const std::vector<sort_by>& sort_fields, const i
scores[2] = -scores[2];
}
}
} else if(field_values[2] == &eval_sentinel_value) {
// Returns iterator to the first element that is >= to value or last if no such element is found.
bool found = false;
if (filter_index == 0 || filter_index < sort_fields[2].eval.size) {
size_t found_index = std::lower_bound(sort_fields[2].eval.ids + filter_index,
sort_fields[2].eval.ids + sort_fields[2].eval.size, seq_id) -
sort_fields[2].eval.ids;
if (found_index != sort_fields[2].eval.size && sort_fields[2].eval.ids[found_index] == seq_id) {
filter_index = found_index + 1;
found = true;
}
filter_index = found_index;
}
scores[0] = int64_t(found);
} else {
auto it = field_values[2]->find(seq_id);
scores[2] = (it == field_values[2]->end()) ? default_score : it->second;
@ -3597,6 +3651,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
}
bool field_is_array = search_schema.at(the_fields[field_id].name).is_array();
size_t filter_index = 0;
for(size_t i = 0; i < raw_infix_ids_length; i++) {
auto seq_id = raw_infix_ids[i];
@ -3608,7 +3663,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector<se
int64_t scores[3] = {0};
int64_t match_score_index = 0;
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
100, scores, match_score_index);
uint64_t distinct_id = seq_id;
@ -3755,10 +3810,11 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
size_t field_num_results = 0;
std::set<uint64> query_hashes;
size_t num_toks_dropped = 0;
std::vector<sort_by> sort_fields;
search_field(0, qtokens, nullptr, 0, num_toks_dropped,
facet_field, facet_field.faceted_name(),
all_result_ids, all_result_ids_len, {}, {}, -1, facet_query_num_typos, searched_queries, topster,
all_result_ids, all_result_ids_len, {}, sort_fields, -1, facet_query_num_typos, searched_queries, topster,
groups_processed, &field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields,
false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, -1, 3, 1000, max_candidates);
@ -3926,6 +3982,8 @@ void Index::search_wildcard(const std::vector<filter>& filters,
search_stop_ms = parent_search_stop_ms;
search_cutoff = parent_search_cutoff;
size_t filter_index = 0;
for(size_t i = 0; i < batch_res_len; i++) {
const uint32_t seq_id = batch_result_ids[i];
int64_t match_score = 0;
@ -3936,7 +3994,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
int64_t scores[3] = {0};
int64_t match_score_index = 0;
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id,
compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index,
100, scores, match_score_index);
uint64_t distinct_id = seq_id;
@ -3988,7 +4046,7 @@ void Index::search_wildcard(const std::vector<filter>& filters,
}
void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
const std::vector<sort_by>& sort_fields_std,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3>& field_values) const {
for (size_t i = 0; i < sort_fields_std.size(); i++) {
sort_order[i] = 1;
@ -4000,6 +4058,9 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
field_values[i] = &text_match_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::seq_id) {
field_values[i] = &seq_id_sentinel_value;
} else if (sort_fields_std[i].name == sort_field_const::eval) {
field_values[i] = &eval_sentinel_value;
do_filtering(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filters, true);
} else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) {
if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) {
geopoint_indices.push_back(i);
@ -4034,7 +4095,7 @@ void Index::search_field(const uint8_t & field_id,
const field& the_field, const std::string& field_name, // to handle faceted index
const uint32_t *filter_ids, size_t filter_ids_length,
const std::vector<uint32_t>& curated_ids,
const std::vector<sort_by> & sort_fields,
std::vector<sort_by> & sort_fields,
const int last_typo,
const int max_typos,
std::vector<std::vector<art_leaf*>> & searched_queries,

View File

@ -16,7 +16,14 @@ Tokenizer::Tokenizer(const std::string& input, bool normalize, bool no_op, const
}
UErrorCode errcode = U_ZERO_ERROR;
nfkd = icu::Normalizer2::getNFKDInstance(errcode);
if(locale == "ko") {
nfkd = icu::Normalizer2::getNFKDInstance(errcode);
}
if(locale == "th") {
nfkc = icu::Normalizer2::getNFKCInstance(errcode);
}
cd = iconv_open("ASCII//TRANSLIT", "UTF-8");
@ -119,6 +126,16 @@ bool Tokenizer::next(std::string &token, size_t& token_index, size_t& start_inde
auto raw_text = unicode_text.tempSubStringBetween(start_pos, end_pos);
transliterator->transliterate(raw_text);
token = raw_text.toUTF8String(word);
} else if(locale == "th") {
UErrorCode errcode = U_ZERO_ERROR;
icu::UnicodeString src = unicode_text.tempSubStringBetween(start_pos, end_pos);
icu::UnicodeString dst;
nfkc->normalize(src, dst, errcode);
if(!U_FAILURE(errcode)) {
token = dst.toUTF8String(word);
} else {
LOG(ERROR) << "Unicode error during parsing: " << errcode;
}
} else {
token = unicode_text.tempSubStringBetween(start_pos, end_pos).toUTF8String(word);
}

View File

@ -187,6 +187,39 @@ TEST_F(CollectionLocaleTest, SearchAgainstThaiText) {
ASSERT_EQ("<mark>พกไฟ</mark>\nเสมอ", results["hits"][0]["highlights"][0]["snippet"].get<std::string>());
}
TEST_F(CollectionLocaleTest, ThaiTextShouldBeNormalizedToNFKC) {
Collection *coll1;
std::vector<field> fields = {field("title", field_types::STRING, false, false, true, "th"),
field("artist", field_types::STRING, false),
field("points", field_types::INT32, false),};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
std::vector<std::vector<std::string>> records = {
{"น้ำมัน", "Dustin Kensrue"},
};
for(size_t i=0; i<records.size(); i++) {
nlohmann::json doc;
doc["id"] = std::to_string(i);
doc["title"] = records[i][0];
doc["artist"] = records[i][1];
doc["points"] = i;
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
auto results = coll1->search("น้ํามัน",{"title"}, "", {}, {},
{0}, 10, 1, FREQUENCY).get();
ASSERT_EQ(1, results["found"].get<size_t>());
}
TEST_F(CollectionLocaleTest, SearchThaiTextPreSegmentedQuery) {
Collection *coll1;

View File

@ -748,6 +748,14 @@ TEST_F(CollectionManagerTest, ParseSortByClause) {
ASSERT_EQ("_text_match(buckets: 10)", sort_fields[0].name);
ASSERT_EQ("ASC", sort_fields[0].order);
sort_fields.clear();
sort_by_parsed = CollectionManager::parse_sort_by_str("_eval(brand:nike && foo:bar):DESC,points:desc ", sort_fields);
ASSERT_TRUE(sort_by_parsed);
ASSERT_EQ("_eval(brand:nike && foo:bar)", sort_fields[0].name);
ASSERT_EQ("DESC", sort_fields[0].order);
ASSERT_EQ("points", sort_fields[1].name);
ASSERT_EQ("DESC", sort_fields[1].order);
sort_fields.clear();
sort_by_parsed = CollectionManager::parse_sort_by_str("", sort_fields);
ASSERT_TRUE(sort_by_parsed);

View File

@ -271,7 +271,7 @@ TEST_F(CollectionOverrideTest, OverrideJSONValidation) {
parse_op = override_t::parse(include_json2, "", override2);
ASSERT_FALSE(parse_op.ok());
ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`.",
ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`, `replace_query`.",
parse_op.error().c_str());
include_json2["includes"] = nlohmann::json::array();
@ -781,6 +781,138 @@ TEST_F(CollectionOverrideTest, IncludeOverrideWithFilterBy) {
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
}
TEST_F(CollectionOverrideTest, ReplaceQuery) {
Collection *coll1;
std::vector<field> fields = {field("name", field_types::STRING, false),
field("points", field_types::INT32, false)};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
nlohmann::json doc1;
doc1["id"] = "0";
doc1["name"] = "Amazing Shoes";
doc1["points"] = 30;
nlohmann::json doc2;
doc2["id"] = "1";
doc2["name"] = "Fast Shoes";
doc2["points"] = 50;
nlohmann::json doc3;
doc3["id"] = "2";
doc3["name"] = "Comfortable Socks";
doc3["points"] = 1;
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
ASSERT_TRUE(coll1->add(doc2.dump()).ok());
ASSERT_TRUE(coll1->add(doc3.dump()).ok());
std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
nlohmann::json override_json = R"({
"id": "rule-1",
"rule": {
"query": "boots",
"match": "exact"
},
"replace_query": "shoes"
})"_json;
override_t override_rule;
auto op = override_t::parse(override_json, "rule-1", override_rule);
ASSERT_TRUE(op.ok());
coll1->add_override(override_rule);
auto results = coll1->search("boots", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
// don't allow both remove_matched_tokens and replace_query
override_json["remove_matched_tokens"] = true;
op = override_t::parse(override_json, "rule-1", override_rule);
ASSERT_FALSE(op.ok());
ASSERT_EQ("Only one of `replace_query` or `remove_matched_tokens` can be specified.", op.error());
}
TEST_F(CollectionOverrideTest, WindowForRule) {
Collection *coll1;
std::vector<field> fields = {field("name", field_types::STRING, false),
field("points", field_types::INT32, false)};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
nlohmann::json doc1;
doc1["id"] = "0";
doc1["name"] = "Amazing Shoes";
doc1["points"] = 30;
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
nlohmann::json override_json = R"({
"id": "rule-1",
"rule": {
"query": "boots",
"match": "exact"
},
"replace_query": "shoes"
})"_json;
override_t override_rule;
auto op = override_t::parse(override_json, "rule-1", override_rule);
ASSERT_TRUE(op.ok());
coll1->add_override(override_rule);
auto results = coll1->search("boots", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(1, results["hits"].size());
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
// rule must not match when window_start is set into the future
override_json["effective_from_ts"] = 35677971263; // year 3100, here we come! ;)
op = override_t::parse(override_json, "rule-1", override_rule);
ASSERT_TRUE(op.ok());
coll1->add_override(override_rule);
results = coll1->search("boots", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(0, results["hits"].size());
// rule must not match when window_end is set into the past
override_json["effective_from_ts"] = -1;
override_json["effective_to_ts"] = 965388863;
op = override_t::parse(override_json, "rule-1", override_rule);
ASSERT_TRUE(op.ok());
coll1->add_override(override_rule);
results = coll1->search("boots", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(0, results["hits"].size());
// resetting both should bring the override back in action
override_json["effective_from_ts"] = 965388863;
override_json["effective_to_ts"] = 35677971263;
op = override_t::parse(override_json, "rule-1", override_rule);
ASSERT_TRUE(op.ok());
coll1->add_override(override_rule);
results = coll1->search("boots", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(1, results["hits"].size());
}
TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) {
auto pinned_hits = "13:1,4:2";

View File

@ -1810,3 +1810,242 @@ TEST_F(CollectionSortingTest, DisallowSortingOnNonIndexedIntegerField) {
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingWildcard) {
std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "brand", "type": "string" },
{"name": "points", "type": "int32" }
]
}
)";
nlohmann::json schema = nlohmann::json::parse(coll_schema);
Collection* coll1 = collectionManager.create_collection(schema).get();
for(size_t i = 0; i < 5; i++) {
nlohmann::json doc;
doc["title"] = "Title " + std::to_string(i);
doc["points"] = i;
doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
auto sort_fields = {
sort_by("_eval(brand:nike)", "DESC"),
sort_by("points", "DESC"),
};
auto results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// compound query
sort_fields = {
sort_by("_eval(brand:nike && points:0)", "DESC"),
sort_by("points", "DESC"),
};
results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
expected_ids = {"0", "4", "3", "2", "1"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// when no results are found for eval query
sort_fields = {
sort_by("_eval(brand:foobar)", "DESC"),
sort_by("points", "DESC"),
};
results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
expected_ids = {"4", "3", "2", "1", "0"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// bad syntax for eval query
sort_fields = {
sort_by("_eval(brandnike || points:0)", "DESC"),
sort_by("points", "DESC"),
};
auto res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error());
// more bad syntax!
sort_fields = {
sort_by(")", "DESC"),
sort_by("points", "DESC"),
};
res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error());
// don't allow multiple sorting eval expressions
sort_fields = {
sort_by("_eval(brand: nike || points:0)", "DESC"),
sort_by("_eval(brand: nike || points:0)", "DESC"),
};
res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Only one sorting eval expression is allowed.", res_op.error());
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSearch) {
std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "brand", "type": "string" },
{"name": "points", "type": "int32" }
]
}
)";
nlohmann::json schema = nlohmann::json::parse(coll_schema);
Collection* coll1 = collectionManager.create_collection(schema).get();
for(size_t i = 0; i < 5; i++) {
nlohmann::json doc;
doc["title"] = "Title " + std::to_string(i);
doc["points"] = i;
doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
auto sort_fields = {
sort_by("_eval(brand:nike)", "DESC"),
sort_by("points", "DESC"),
};
auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// compound query
sort_fields = {
sort_by("_eval(brand:nike && points:0)", "DESC"),
sort_by("points", "DESC"),
};
results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
expected_ids = {"0", "4", "3", "2", "1"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// when no results are found for eval query
sort_fields = {
sort_by("_eval(brand:foobar)", "DESC"),
sort_by("points", "DESC"),
};
results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
expected_ids = {"4", "3", "2", "1", "0"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// bad syntax for eval query
sort_fields = {
sort_by("_eval(brandnike || points:0)", "DESC"),
sort_by("points", "DESC"),
};
auto res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error());
// more bad syntax!
sort_fields = {
sort_by(")", "DESC"),
sort_by("points", "DESC"),
};
res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10);
ASSERT_FALSE(res_op.ok());
ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error());
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) {
std::string coll_schema = R"(
{
"name": "coll1",
"fields": [
{"name": "title", "type": "string" },
{"name": "brand", "type": "string" },
{"name": "points", "type": "int32" },
{"name": "val", "type": "int32" }
]
}
)";
nlohmann::json schema = nlohmann::json::parse(coll_schema);
Collection* coll1 = collectionManager.create_collection(schema).get();
for(size_t i = 0; i < 5; i++) {
nlohmann::json doc;
doc["title"] = "Title " + std::to_string(i);
doc["val"] = 0;
doc["points"] = i;
doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas";
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}
auto sort_fields = {
sort_by("val", "DESC"),
sort_by("_eval(brand:nike)", "DESC"),
sort_by("points", "DESC"),
};
auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"};
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
// eval expression as 3rd sorting argument
sort_fields = {
sort_by("val", "DESC"),
sort_by("val", "DESC"),
sort_by("_eval(brand:nike)", "DESC"),
};
results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get();
ASSERT_EQ(5, results["hits"].size());
for(size_t i = 0; i > expected_ids.size(); i++) {
ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>());
}
}

View File

@ -241,7 +241,7 @@ TEST(TokenizerTest, ShouldTokenizeLocaleText) {
ASSERT_EQ(4, tokens.size());
ASSERT_EQ("จิ้งจอก", tokens[0]);
ASSERT_EQ("สี", tokens[1]);
ASSERT_EQ("น้ตาล", tokens[2]);
ASSERT_EQ("น้ําตาล", tokens[2]);
ASSERT_EQ("ด่วน", tokens[3]);
tokens.clear();