From db3c0144635a3498cdbe9579ba9b3ccd9c34784c Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 4 Aug 2022 14:47:02 +0530 Subject: [PATCH 1/5] Implement optional filtering via sort_by clause. --- include/field.h | 8 ++ include/index.h | 13 +- src/collection.cpp | 25 +++- src/index.cpp | 81 +++++++++-- test/collection_manager_test.cpp | 8 ++ test/collection_sorting_test.cpp | 239 +++++++++++++++++++++++++++++++ 6 files changed, 357 insertions(+), 17 deletions(-) diff --git a/include/field.h b/include/field.h index c18179af..2f558129 100644 --- a/include/field.h +++ b/include/field.h @@ -442,6 +442,7 @@ namespace sort_field_const { static const std::string desc = "DESC"; static const std::string text_match = "_text_match"; + static const std::string eval = "_eval"; static const std::string seq_id = "_seq_id"; static const std::string exclude_radius = "exclude_radius"; @@ -457,6 +458,12 @@ struct sort_by { normal, }; + struct eval_t { + std::vector filters; + uint32_t* ids = nullptr; + uint32_t size = 0; + }; + std::string name; std::string order; @@ -469,6 +476,7 @@ struct sort_by { uint32_t geo_precision; missing_values_t missing_values; + eval_t eval; sort_by(const std::string & name, const std::string & order): name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0), diff --git a/include/index.h b/include/index.h index 896393c6..26e94e84 100644 --- a/include/index.h +++ b/include/index.h @@ -517,6 +517,7 @@ private: static spp::sparse_hash_map text_match_sentinel_value; static spp::sparse_hash_map seq_id_sentinel_value; + static spp::sparse_hash_map eval_sentinel_value; static spp::sparse_hash_map geo_sentinel_value; static spp::sparse_hash_map str_sentinel_value; @@ -565,7 +566,7 @@ private: const field& the_field, const std::string& field_name, const uint32_t *filter_ids, size_t filter_ids_length, const std::vector& curated_ids, - const std::vector & sort_fields, + std::vector & sort_fields, int last_typo, int 
max_typos, std::vector> & searched_queries, @@ -619,7 +620,7 @@ private: const uint32_t* filter_ids, size_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::vector& curated_ids, - const std::vector & sort_fields, std::vector & token_to_candidates, + std::vector & sort_fields, std::vector & token_to_candidates, std::vector> & searched_queries, Topster* topster, spp::sparse_hash_set& groups_processed, uint32_t** all_result_ids, @@ -788,7 +789,7 @@ public: void search(std::vector& field_query_tokens, const std::vector& the_fields, std::vector& filters, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, - const std::vector& excluded_ids, const std::vector& sort_fields_std, + const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, const size_t per_page, const size_t page, const token_ordering token_order, const std::vector& prefixes, @@ -878,7 +879,7 @@ public: uint32_t& filter_ids_length, const std::vector& curated_ids_sorted) const; void populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, - const std::vector& sort_fields_std, + std::vector& sort_fields_std, std::array*, 3>& field_values) const; static void remove_matched_tokens(std::vector& tokens, const std::set& rule_token_set) ; @@ -1040,8 +1041,10 @@ public: void compute_sort_scores(const std::vector& sort_fields, const int* sort_order, std::array*, 3> field_values, const std::vector& geopoint_indices, uint32_t seq_id, + size_t filter_index, int64_t max_field_match_score, - int64_t* scores, int64_t& match_score_index) const; + int64_t* scores, + int64_t& match_score_index) const; void process_curated_ids(const std::vector>& included_ids, diff --git a/src/collection.cpp b/src/collection.cpp index 03c585cc..00de1d2d 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -523,7 +523,10 @@ void Collection::curate_results(string& 
actual_query, bool enable_overrides, boo Option Collection::validate_and_standardize_sort_fields(const std::vector & sort_fields, std::vector& sort_fields_std) const { - for(const sort_by& _sort_field: sort_fields) { + size_t num_sort_expressions = 0; + + for(size_t i = 0; i < sort_fields.size(); i++) { + const sort_by& _sort_field = sort_fields[i]; sort_by sort_field_std(_sort_field.name, _sort_field.order); if(sort_field_std.name.back() == ')') { @@ -551,6 +554,19 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< sort_field_std.name = actual_field_name; sort_field_std.text_match_buckets = std::stoll(match_parts[1]); + } else if(actual_field_name == sort_field_const::eval) { + const std::string& filter_exp = sort_field_std.name.substr(paran_start + 1, + sort_field_std.name.size() - paran_start - + 2); + Option parse_filter_op = filter::parse_filter_query(filter_exp, search_schema, + store, "", sort_field_std.eval.filters); + if(!parse_filter_op.ok()) { + return Option(parse_filter_op.code(), "Error parsing eval expression in sort_by clause."); + } + + sort_field_std.name = actual_field_name; + num_sort_expressions++; + } else { if(field_it == search_schema.end()) { std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting."; @@ -674,7 +690,7 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< } } - if(sort_field_std.name != sort_field_const::text_match) { + if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval) { const auto field_it = search_schema.find(sort_field_std.name); if(field_it == search_schema.end() || !field_it->second.sort || !field_it->second.index) { std::string error = "Could not find a field named `" + sort_field_std.name + @@ -725,6 +741,11 @@ Option Collection::validate_and_standardize_sort_fields(const std::vector< return Option(422, message); } + if(num_sort_expressions > 1) { + std::string 
message = "Only one sorting eval expression is allowed."; + return Option(422, message); + } + return Option(true); } diff --git a/src/index.cpp b/src/index.cpp index 5e2413a5..24850e07 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -37,6 +37,7 @@ spp::sparse_hash_map Index::text_match_sentinel_value; spp::sparse_hash_map Index::seq_id_sentinel_value; +spp::sparse_hash_map Index::eval_sentinel_value; spp::sparse_hash_map Index::geo_sentinel_value; spp::sparse_hash_map Index::str_sentinel_value; @@ -1306,7 +1307,7 @@ void Index::search_candidates(const uint8_t & field_id, bool field_is_array, const uint32_t* filter_ids, size_t filter_ids_length, const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, const std::vector& curated_ids, - const std::vector & sort_fields, + std::vector & sort_fields, std::vector & token_candidates_vec, std::vector> & searched_queries, Topster* topster, @@ -2144,8 +2145,9 @@ bool Index::check_for_overrides(const token_ordering& token_order, const string& continue; } + std::vector sort_fields; search_field(0, window_tokens, nullptr, 0, num_toks_dropped, field_it->second, field_name, - nullptr, 0, {}, {}, -1, 0, searched_queries, topster, groups_processed, + nullptr, 0, {}, sort_fields, -1, 0, searched_queries, topster, groups_processed, &result_ids, result_ids_len, field_num_results, 0, group_by_fields, false, 4, query_hashes, token_order, false, 0, 0, false, -1, 3, 7, 4); @@ -2280,7 +2282,7 @@ void Index::search_infix(const std::string& query, const std::string& field_name void Index::search(std::vector& field_query_tokens, const std::vector& the_fields, std::vector& filters, std::vector& facets, facet_query_t& facet_query, const std::vector>& included_ids, - const std::vector& excluded_ids, const std::vector& sort_fields_std, + const std::vector& excluded_ids, std::vector& sort_fields_std, const std::vector& num_typos, Topster* topster, Topster* curated_topster, const size_t per_page, const size_t page, const token_ordering 
token_order, const std::vector& prefixes, @@ -3109,6 +3111,7 @@ void Index::search_across_fields(const std::vector& query_tokens, } std::vector result_ids; + size_t filter_index = 0; or_iterator_t::intersect(token_its, istate, [&](uint32_t seq_id, const std::vector& its) { //LOG(INFO) << "seq_id: " << seq_id; @@ -3169,7 +3172,7 @@ void Index::search_across_fields(const std::vector& query_tokens, int64_t scores[3] = {0}; int64_t match_score_index = -1; - compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, + compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index, max_field_match_score, scores, match_score_index); size_t query_len = query_tokens.size(); @@ -3238,7 +3241,7 @@ void Index::search_across_fields(const std::vector& query_tokens, void Index::compute_sort_scores(const std::vector& sort_fields, const int* sort_order, std::array*, 3> field_values, const std::vector& geopoint_indices, - uint32_t seq_id, int64_t max_field_match_score, + uint32_t seq_id, size_t filter_index, int64_t max_field_match_score, int64_t* scores, int64_t& match_score_index) const { int64_t geopoint_distances[3]; @@ -3317,6 +3320,23 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[0] = -scores[0]; } } + } else if(field_values[0] == &eval_sentinel_value) { + // Returns iterator to the first element that is >= to value or last if no such element is found. 
+ bool found = false; + if (filter_index == 0 || filter_index < sort_fields[0].eval.size) { + size_t found_index = std::lower_bound(sort_fields[0].eval.ids + filter_index, + sort_fields[0].eval.ids + sort_fields[0].eval.size, seq_id) - + sort_fields[0].eval.ids; + + if (found_index != sort_fields[0].eval.size && sort_fields[0].eval.ids[found_index] == seq_id) { + filter_index = found_index + 1; + found = true; + } + + filter_index = found_index; + } + + scores[0] = int64_t(found); } else { auto it = field_values[0]->find(seq_id); scores[0] = (it == field_values[0]->end()) ? default_score : it->second; @@ -3356,6 +3376,23 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[1] = -scores[1]; } } + } else if(field_values[1] == &eval_sentinel_value) { + // Returns iterator to the first element that is >= to value or last if no such element is found. + bool found = false; + if (filter_index == 0 || filter_index < sort_fields[1].eval.size) { + size_t found_index = std::lower_bound(sort_fields[1].eval.ids + filter_index, + sort_fields[1].eval.ids + sort_fields[1].eval.size, seq_id) - + sort_fields[1].eval.ids; + + if (found_index != sort_fields[1].eval.size && sort_fields[1].eval.ids[found_index] == seq_id) { + filter_index = found_index + 1; + found = true; + } + + filter_index = found_index; + } + + scores[1] = int64_t(found); } else { auto it = field_values[1]->find(seq_id); scores[1] = (it == field_values[1]->end()) ? default_score : it->second; @@ -3391,6 +3428,23 @@ void Index::compute_sort_scores(const std::vector& sort_fields, const i scores[2] = -scores[2]; } } + } else if(field_values[2] == &eval_sentinel_value) { + // Returns iterator to the first element that is >= to value or last if no such element is found. 
+ bool found = false; + if (filter_index == 0 || filter_index < sort_fields[2].eval.size) { + size_t found_index = std::lower_bound(sort_fields[2].eval.ids + filter_index, + sort_fields[2].eval.ids + sort_fields[2].eval.size, seq_id) - + sort_fields[2].eval.ids; + + if (found_index != sort_fields[2].eval.size && sort_fields[2].eval.ids[found_index] == seq_id) { + filter_index = found_index + 1; + found = true; + } + + filter_index = found_index; + } + + scores[2] = int64_t(found); /* eval match (0/1) belongs to the third sort slot; writing scores[0] here would clobber the primary sort key */ } else { auto it = field_values[2]->find(seq_id); scores[2] = (it == field_values[2]->end()) ? default_score : it->second; @@ -3598,6 +3652,7 @@ void Index::do_infix_search(const size_t num_search_fields, const std::vector& facets, facet_query_t& size_t field_num_results = 0; std::set query_hashes; size_t num_toks_dropped = 0; + std::vector sort_fields; search_field(0, qtokens, nullptr, 0, num_toks_dropped, facet_field, facet_field.faceted_name(), - all_result_ids, all_result_ids_len, {}, {}, -1, facet_query_num_typos, searched_queries, topster, + all_result_ids, all_result_ids_len, {}, sort_fields, -1, facet_query_num_typos, searched_queries, topster, groups_processed, &field_result_ids, field_result_ids_len, field_num_results, 0, group_by_fields, false, 4, query_hashes, MAX_SCORE, true, 0, 1, false, -1, 3, 1000, max_candidates); @@ -3927,6 +3983,8 @@ void Index::search_wildcard(const std::vector& filters, search_stop_ms = parent_search_stop_ms; search_cutoff = parent_search_cutoff; + size_t filter_index = 0; + for(size_t i = 0; i < batch_res_len; i++) { const uint32_t seq_id = batch_result_ids[i]; int64_t match_score = 0; @@ -3937,7 +3995,7 @@ void Index::search_wildcard(const std::vector& filters, int64_t scores[3] = {0}; int64_t match_score_index = 0; - compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, + compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices, seq_id, filter_index, 100, scores, match_score_index); uint64_t 
distinct_id = seq_id; @@ -3989,7 +4047,7 @@ void Index::search_wildcard(const std::vector& filters, } void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint_indices, - const std::vector& sort_fields_std, + std::vector& sort_fields_std, std::array*, 3>& field_values) const { for (size_t i = 0; i < sort_fields_std.size(); i++) { sort_order[i] = 1; @@ -4001,6 +4059,9 @@ void Index::populate_sort_mapping(int* sort_order, std::vector& geopoint field_values[i] = &text_match_sentinel_value; } else if (sort_fields_std[i].name == sort_field_const::seq_id) { field_values[i] = &seq_id_sentinel_value; + } else if (sort_fields_std[i].name == sort_field_const::eval) { + field_values[i] = &eval_sentinel_value; + do_filtering(sort_fields_std[i].eval.ids, sort_fields_std[i].eval.size, sort_fields_std[i].eval.filters, true); } else if (search_schema.count(sort_fields_std[i].name) != 0 && search_schema.at(sort_fields_std[i].name).sort) { if (search_schema.at(sort_fields_std[i].name).type == field_types::GEOPOINT_ARRAY) { geopoint_indices.push_back(i); @@ -4035,7 +4096,7 @@ void Index::search_field(const uint8_t & field_id, const field& the_field, const std::string& field_name, // to handle faceted index const uint32_t *filter_ids, size_t filter_ids_length, const std::vector& curated_ids, - const std::vector & sort_fields, + std::vector & sort_fields, const int last_typo, const int max_typos, std::vector> & searched_queries, diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index 2b40068f..d92b29f9 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -748,6 +748,14 @@ TEST_F(CollectionManagerTest, ParseSortByClause) { ASSERT_EQ("_text_match(buckets: 10)", sort_fields[0].name); ASSERT_EQ("ASC", sort_fields[0].order); + sort_fields.clear(); + sort_by_parsed = CollectionManager::parse_sort_by_str("_eval(brand:nike && foo:bar):DESC,points:desc ", sort_fields); + ASSERT_TRUE(sort_by_parsed); + 
ASSERT_EQ("_eval(brand:nike && foo:bar)", sort_fields[0].name); + ASSERT_EQ("DESC", sort_fields[0].order); + ASSERT_EQ("points", sort_fields[1].name); + ASSERT_EQ("DESC", sort_fields[1].order); + sort_fields.clear(); sort_by_parsed = CollectionManager::parse_sort_by_str("", sort_fields); ASSERT_TRUE(sort_by_parsed); diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 134d5a2a..4a7202ef 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -1810,3 +1810,242 @@ TEST_F(CollectionSortingTest, DisallowSortingOnNonIndexedIntegerField) { collectionManager.drop_collection("coll1"); } + +TEST_F(CollectionSortingTest, OptionalFilteringViaSortingWildcard) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "brand", "type": "string" }, + {"name": "points", "type": "int32" } + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + for(size_t i = 0; i < 5; i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = i; + + doc["brand"] = (i == 0 || i == 3) ? 
"Nike" : "Adidas"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + /* owning vector: a braced list re-assigned to an `auto` std::initializer_list would dangle after the statement */ std::vector<sort_by> sort_fields = { + sort_by("_eval(brand:nike)", "DESC"), + sort_by("points", "DESC"), + }; + + auto results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"}; + for(size_t i = 0; i < expected_ids.size(); i++) { /* `<`, so the hit order is actually verified */ + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // compound query + sort_fields = { + sort_by("_eval(brand:nike && points:0)", "DESC"), + sort_by("points", "DESC"), + }; + + results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + ASSERT_EQ(5, results["hits"].size()); + + expected_ids = {"0", "4", "3", "2", "1"}; + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // when no results are found for eval query + sort_fields = { + sort_by("_eval(brand:foobar)", "DESC"), + sort_by("points", "DESC"), + }; + + results = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + + ASSERT_EQ(5, results["hits"].size()); + expected_ids = {"4", "3", "2", "1", "0"}; + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // bad syntax for eval query + sort_fields = { + sort_by("_eval(brandnike || points:0)", "DESC"), + sort_by("points", "DESC"), + }; + + auto res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error()); + + // more bad syntax! 
+ sort_fields = { + sort_by(")", "DESC"), + sort_by("points", "DESC"), + }; + + res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error()); + + // don't allow multiple sorting eval expressions + sort_fields = { + sort_by("_eval(brand: nike || points:0)", "DESC"), + sort_by("_eval(brand: nike || points:0)", "DESC"), + }; + + res_op = coll1->search("*", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Only one sorting eval expression is allowed.", res_op.error()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSearch) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "brand", "type": "string" }, + {"name": "points", "type": "int32" } + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + for(size_t i = 0; i < 5; i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["points"] = i; + + doc["brand"] = (i == 0 || i == 3) ? 
"Nike" : "Adidas"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + /* owning vector: a braced list re-assigned to an `auto` std::initializer_list would dangle after the statement */ std::vector<sort_by> sort_fields = { + sort_by("_eval(brand:nike)", "DESC"), + sort_by("points", "DESC"), + }; + + auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"}; + for(size_t i = 0; i < expected_ids.size(); i++) { /* `<`, so the hit order is actually verified */ + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // compound query + sort_fields = { + sort_by("_eval(brand:nike && points:0)", "DESC"), + sort_by("points", "DESC"), + }; + + results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + ASSERT_EQ(5, results["hits"].size()); + + expected_ids = {"0", "4", "3", "2", "1"}; + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // when no results are found for eval query + sort_fields = { + sort_by("_eval(brand:foobar)", "DESC"), + sort_by("points", "DESC"), + }; + + results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + + ASSERT_EQ(5, results["hits"].size()); + expected_ids = {"4", "3", "2", "1", "0"}; + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // bad syntax for eval query + sort_fields = { + sort_by("_eval(brandnike || points:0)", "DESC"), + sort_by("points", "DESC"), + }; + + auto res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Error parsing eval expression in sort_by clause.", res_op.error()); + + // more bad syntax! 
+ sort_fields = { + sort_by(")", "DESC"), + sort_by("points", "DESC"), + }; + + res_op = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10); + ASSERT_FALSE(res_op.ok()); + ASSERT_EQ("Could not find a field named `)` in the schema for sorting.", res_op.error()); + + collectionManager.drop_collection("coll1"); +} + +TEST_F(CollectionSortingTest, OptionalFilteringViaSortingSecondThirdParams) { + std::string coll_schema = R"( + { + "name": "coll1", + "fields": [ + {"name": "title", "type": "string" }, + {"name": "brand", "type": "string" }, + {"name": "points", "type": "int32" }, + {"name": "val", "type": "int32" } + ] + } + )"; + + nlohmann::json schema = nlohmann::json::parse(coll_schema); + Collection* coll1 = collectionManager.create_collection(schema).get(); + + for(size_t i = 0; i < 5; i++) { + nlohmann::json doc; + doc["title"] = "Title " + std::to_string(i); + doc["val"] = 0; + doc["points"] = i; + doc["brand"] = (i == 0 || i == 3) ? "Nike" : "Adidas"; + ASSERT_TRUE(coll1->add(doc.dump()).ok()); + } + + /* owning vector: a braced list re-assigned to an `auto` std::initializer_list would dangle after the statement */ std::vector<sort_by> sort_fields = { + sort_by("val", "DESC"), + sort_by("_eval(brand:nike)", "DESC"), + sort_by("points", "DESC"), + }; + + auto results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + + ASSERT_EQ(5, results["hits"].size()); + std::vector<std::string> expected_ids = {"3", "0", "4", "2", "1"}; + for(size_t i = 0; i < expected_ids.size(); i++) { /* `<`, so the hit order is actually verified */ + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } + + // eval expression as 3rd sorting argument + sort_fields = { + sort_by("val", "DESC"), + sort_by("val", "DESC"), + sort_by("_eval(brand:nike)", "DESC"), + }; + + results = coll1->search("title", {"title"}, "", {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 10).get(); + ASSERT_EQ(5, results["hits"].size()); + for(size_t i = 0; i < expected_ids.size(); i++) { + ASSERT_EQ(expected_ids[i], results["hits"][i]["document"]["id"].get<std::string>()); + } +} From 
6f38ee4270fba39db0e76f9ba210c982e652d7e7 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 4 Aug 2022 16:49:06 +0530 Subject: [PATCH 2/5] Support replacement of query in overrides. --- include/index.h | 17 +++++++-- src/collection.cpp | 4 ++- test/collection_override_test.cpp | 60 +++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/include/index.h b/include/index.h index 26e94e84..f3733db1 100644 --- a/include/index.h +++ b/include/index.h @@ -110,6 +110,7 @@ struct override_t { bool stop_processing = true; std::string sort_by; + std::string replace_query; override_t() = default; @@ -128,9 +129,10 @@ struct override_t { if(override_json.count("includes") == 0 && override_json.count("excludes") == 0 && override_json.count("filter_by") == 0 && override_json.count("sort_by") == 0 && - override_json.count("remove_matched_tokens") == 0) { + override_json.count("remove_matched_tokens") == 0 && + override_json.count("replace_query") == 0) { return Option(400, "Must contain one of: `includes`, `excludes`, " - "`filter_by`, `sort_by`, `remove_matched_tokens`."); + "`filter_by`, `sort_by`, `remove_matched_tokens`, `replace_query`."); } if(override_json.count("includes") != 0) { @@ -242,6 +244,13 @@ struct override_t { override.sort_by = override_json["sort_by"].get(); } + if (override_json.count("replace_query") != 0) { + if(override_json.count("remove_matched_tokens") != 0) { + return Option(400, "Only one of `replace_query` or `remove_matched_tokens` can be specified."); + } + override.replace_query = override_json["replace_query"].get(); + } + if(override_json.count("remove_matched_tokens") != 0) { override.remove_matched_tokens = override_json["remove_matched_tokens"].get(); } else { @@ -308,6 +317,10 @@ struct override_t { override["sort_by"] = sort_by; } + if(!replace_query.empty()) { + override["replace_query"] = replace_query; + } + override["remove_matched_tokens"] = remove_matched_tokens; 
override["filter_curated_hits"] = filter_curated_hits; override["stop_processing"] = stop_processing; diff --git a/src/collection.cpp b/src/collection.cpp index 00de1d2d..c57707e4 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -481,7 +481,9 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo } } - if(override.remove_matched_tokens && override.filter_by.empty()) { + if(!override.replace_query.empty()) { + actual_query = override.replace_query; + } else if(override.remove_matched_tokens && override.filter_by.empty()) { // don't prematurely remove tokens from query because dynamic filtering will require them StringUtils::replace_all(query, override.rule.query, ""); StringUtils::trim(query); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 27100308..5834ccea 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -781,6 +781,66 @@ TEST_F(CollectionOverrideTest, IncludeOverrideWithFilterBy) { ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); } +TEST_F(CollectionOverrideTest, ReplaceQuery) { + Collection *coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["points"] = 30; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["name"] = "Fast Shoes"; + doc2["points"] = 50; + + nlohmann::json doc3; + doc3["id"] = "2"; + doc3["name"] = "Comfortable Socks"; + doc3["points"] = 1; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + ASSERT_TRUE(coll1->add(doc3.dump()).ok()); + + std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; + + nlohmann::json 
override_json = R"({ + "id": "rule-1", + "rule": { + "query": "boots", + "match": "exact" + }, + "replace_query": "shoes" + })"_json; + + override_t override_rule; + auto op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + auto results = coll1->search("boots", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + + // don't allow both remove_matched_tokens and replace_query + override_json["remove_matched_tokens"] = true; + op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_FALSE(op.ok()); + ASSERT_EQ("Only one of `replace_query` or `remove_matched_tokens` can be specified.", op.error()); +} + TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) { auto pinned_hits = "13:1,4:2"; From 743abd461c58092122928293b57ea3e99ff89ac6 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 4 Aug 2022 17:09:55 +0530 Subject: [PATCH 3/5] Allow override to be active within specific time window. 
--- include/index.h | 20 +++++++++ src/collection.cpp | 9 ++++ test/collection_override_test.cpp | 74 ++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/include/index.h b/include/index.h index f3733db1..c5d71091 100644 --- a/include/index.h +++ b/include/index.h @@ -112,6 +112,10 @@ struct override_t { std::string sort_by; std::string replace_query; + // epoch seconds + int64_t window_start_ts = -1; + int64_t window_end_ts = -1; + override_t() = default; static Option parse(const nlohmann::json& override_json, const std::string& id, override_t& override) { @@ -265,6 +269,14 @@ struct override_t { override.stop_processing = override_json["stop_processing"].get(); } + if(override_json.count("window_start_ts") != 0) { + override.window_start_ts = override_json["window_start_ts"].get(); + } + + if(override_json.count("window_end_ts") != 0) { + override.window_end_ts = override_json["window_end_ts"].get(); + } + // we have to also detect if it is a dynamic query rule size_t i = 0; while(i < override.rule.query.size()) { @@ -321,6 +333,14 @@ struct override_t { override["replace_query"] = replace_query; } + if(window_start_ts != -1) { + override["window_start_ts"] = window_start_ts; + } + + if(window_end_ts != -1) { + override["window_end_ts"] = window_end_ts; + } + override["remove_matched_tokens"] = remove_matched_tokens; override["filter_curated_hits"] = filter_curated_hits; override["stop_processing"] = stop_processing; diff --git a/src/collection.cpp b/src/collection.cpp index c57707e4..a3b735af 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -451,6 +451,15 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo for(const auto& override_kv: overrides) { const auto& override = override_kv.second; + auto now_epoch = int64_t(std::time(0)); + if(override.window_start_ts != -1 && now_epoch < override.window_start_ts) { + continue; + } + + if(override.window_end_ts != -1 && now_epoch > 
override.window_end_ts) { + continue; + } + // ID-based overrides are applied first as they take precedence over filter-based overrides if(!override.filter_by.empty()) { filter_overrides.push_back(&override); diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 5834ccea..781948f4 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -271,7 +271,7 @@ TEST_F(CollectionOverrideTest, OverrideJSONValidation) { parse_op = override_t::parse(include_json2, "", override2); ASSERT_FALSE(parse_op.ok()); - ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`.", + ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`, `replace_query`.", parse_op.error().c_str()); include_json2["includes"] = nlohmann::json::array(); @@ -841,6 +841,78 @@ TEST_F(CollectionOverrideTest, ReplaceQuery) { ASSERT_EQ("Only one of `replace_query` or `remove_matched_tokens` can be specified.", op.error()); } +TEST_F(CollectionOverrideTest, WindowForRule) { + Collection *coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["points"] = 30; + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + + std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; + + nlohmann::json override_json = R"({ + "id": "rule-1", + "rule": { + "query": "boots", + "match": "exact" + }, + "replace_query": "shoes" + })"_json; + + override_t override_rule; + auto op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + auto results = 
coll1->search("boots", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + + ASSERT_EQ(1, results["hits"].size()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + + // rule must not match when window_start is set into the future + override_json["window_start_ts"] = 35677971263; // year 3100, here we come! ;) + op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + results = coll1->search("boots", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + ASSERT_EQ(0, results["hits"].size()); + + // rule must not match when window_end is set into the past + override_json["window_start_ts"] = -1; + override_json["window_end_ts"] = 965388863; + op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + results = coll1->search("boots", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + ASSERT_EQ(0, results["hits"].size()); + + // resetting both should bring the override back in action + override_json["window_start_ts"] = 965388863; + override_json["window_end_ts"] = 35677971263; + op = override_t::parse(override_json, "rule-1", override_rule); + ASSERT_TRUE(op.ok()); + coll1->add_override(override_rule); + + results = coll1->search("boots", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + ASSERT_EQ(1, results["hits"].size()); +} + TEST_F(CollectionOverrideTest, PinnedAndHiddenHits) { auto pinned_hits = "13:1,4:2"; From 6729b72b1a0a2980c6501c803dcb160ec86c151d Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Thu, 4 Aug 2022 19:52:14 +0530 Subject: [PATCH 4/5] Normalize thai text via nfkc. 
--- include/tokenizer.h | 5 +++-- src/tokenizer.cpp | 19 ++++++++++++++++++- test/collection_locale_test.cpp | 33 +++++++++++++++++++++++++++++++++ test/tokenizer_test.cpp | 2 +- 4 files changed, 55 insertions(+), 4 deletions(-) diff --git a/include/tokenizer.h b/include/tokenizer.h index 35195637..36f3af90 100644 --- a/include/tokenizer.h +++ b/include/tokenizer.h @@ -36,8 +36,9 @@ private: int32_t utf8_start_index = 0; char* normalized_text = nullptr; - // non-deletable singleton - const icu::Normalizer2* nfkd; + // non-deletable singletons + const icu::Normalizer2* nfkd = nullptr; + const icu::Normalizer2* nfkc = nullptr; icu::Transliterator* transliterator = nullptr; diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index 3c9cb3d0..41aab66e 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -16,7 +16,14 @@ Tokenizer::Tokenizer(const std::string& input, bool normalize, bool no_op, const } UErrorCode errcode = U_ZERO_ERROR; - nfkd = icu::Normalizer2::getNFKDInstance(errcode); + + if(locale == "ko") { + nfkd = icu::Normalizer2::getNFKDInstance(errcode); + } + + if(locale == "th") { + nfkc = icu::Normalizer2::getNFKCInstance(errcode); + } cd = iconv_open("ASCII//TRANSLIT", "UTF-8"); @@ -119,6 +126,16 @@ bool Tokenizer::next(std::string &token, size_t& token_index, size_t& start_inde auto raw_text = unicode_text.tempSubStringBetween(start_pos, end_pos); transliterator->transliterate(raw_text); token = raw_text.toUTF8String(word); + } else if(locale == "th") { + UErrorCode errcode = U_ZERO_ERROR; + icu::UnicodeString src = unicode_text.tempSubStringBetween(start_pos, end_pos); + icu::UnicodeString dst; + nfkc->normalize(src, dst, errcode); + if(!U_FAILURE(errcode)) { + token = dst.toUTF8String(word); + } else { + LOG(ERROR) << "Unicode error during parsing: " << errcode; + } } else { token = unicode_text.tempSubStringBetween(start_pos, end_pos).toUTF8String(word); } diff --git a/test/collection_locale_test.cpp b/test/collection_locale_test.cpp index 
e76c5d2c..4940f2b9 100644 --- a/test/collection_locale_test.cpp +++ b/test/collection_locale_test.cpp @@ -187,6 +187,39 @@ TEST_F(CollectionLocaleTest, SearchAgainstThaiText) { ASSERT_EQ("พกไฟ\nเสมอ", results["hits"][0]["highlights"][0]["snippet"].get()); } +TEST_F(CollectionLocaleTest, ThaiTextShouldBeNormalizedToNFKC) { + Collection *coll1; + + std::vector fields = {field("title", field_types::STRING, false, false, true, "th"), + field("artist", field_types::STRING, false), + field("points", field_types::INT32, false),}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + std::vector> records = { + {"น้ำมัน", "Dustin Kensrue"}, + }; + + for(size_t i=0; iadd(doc.dump()).ok()); + } + + auto results = coll1->search("น้ํามัน",{"title"}, "", {}, {}, + {0}, 10, 1, FREQUENCY).get(); + + ASSERT_EQ(1, results["found"].get()); +} + TEST_F(CollectionLocaleTest, SearchThaiTextPreSegmentedQuery) { Collection *coll1; diff --git a/test/tokenizer_test.cpp b/test/tokenizer_test.cpp index 1fa43f78..258a7a17 100644 --- a/test/tokenizer_test.cpp +++ b/test/tokenizer_test.cpp @@ -241,7 +241,7 @@ TEST(TokenizerTest, ShouldTokenizeLocaleText) { ASSERT_EQ(4, tokens.size()); ASSERT_EQ("จิ้งจอก", tokens[0]); ASSERT_EQ("สี", tokens[1]); - ASSERT_EQ("น้ำตาล", tokens[2]); + ASSERT_EQ("น้ําตาล", tokens[2]); ASSERT_EQ("ด่วน", tokens[3]); tokens.clear(); From 460abfa69e39171194da4899327c68f361771601 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Fri, 5 Aug 2022 10:58:18 +0530 Subject: [PATCH 5/5] Tweak override windowing naming scheme. 
--- include/index.h | 20 ++++++++++---------- src/collection.cpp | 4 ++-- test/collection_override_test.cpp | 10 +++++----- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/index.h b/include/index.h index c5d71091..b9a2ca6b 100644 --- a/include/index.h +++ b/include/index.h @@ -113,8 +113,8 @@ struct override_t { std::string replace_query; // epoch seconds - int64_t window_start_ts = -1; - int64_t window_end_ts = -1; + int64_t effective_from_ts = -1; + int64_t effective_to_ts = -1; override_t() = default; @@ -269,12 +269,12 @@ struct override_t { override.stop_processing = override_json["stop_processing"].get(); } - if(override_json.count("window_start_ts") != 0) { - override.window_start_ts = override_json["window_start_ts"].get(); + if(override_json.count("effective_from_ts") != 0) { + override.effective_from_ts = override_json["effective_from_ts"].get(); } - if(override_json.count("window_end_ts") != 0) { - override.window_end_ts = override_json["window_end_ts"].get(); + if(override_json.count("effective_to_ts") != 0) { + override.effective_to_ts = override_json["effective_to_ts"].get(); } // we have to also detect if it is a dynamic query rule @@ -333,12 +333,12 @@ struct override_t { override["replace_query"] = replace_query; } - if(window_start_ts != -1) { - override["window_start_ts"] = window_start_ts; + if(effective_from_ts != -1) { + override["effective_from_ts"] = effective_from_ts; } - if(window_end_ts != -1) { - override["window_end_ts"] = window_end_ts; + if(effective_to_ts != -1) { + override["effective_to_ts"] = effective_to_ts; } override["remove_matched_tokens"] = remove_matched_tokens; diff --git a/src/collection.cpp b/src/collection.cpp index a3b735af..7bf0561b 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -452,11 +452,11 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo const auto& override = override_kv.second; auto now_epoch = int64_t(std::time(0)); - 
if(override.window_start_ts != -1 && now_epoch < override.window_start_ts) { + if(override.effective_from_ts != -1 && now_epoch < override.effective_from_ts) { continue; } - if(override.window_end_ts != -1 && now_epoch > override.window_end_ts) { + if(override.effective_to_ts != -1 && now_epoch > override.effective_to_ts) { continue; } diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index 781948f4..7e1cef25 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -881,7 +881,7 @@ TEST_F(CollectionOverrideTest, WindowForRule) { ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); // rule must not match when window_start is set into the future - override_json["window_start_ts"] = 35677971263; // year 3100, here we come! ;) + override_json["effective_from_ts"] = 35677971263; // year 3100, here we come! ;) op = override_t::parse(override_json, "rule-1", override_rule); ASSERT_TRUE(op.ok()); coll1->add_override(override_rule); @@ -891,8 +891,8 @@ TEST_F(CollectionOverrideTest, WindowForRule) { ASSERT_EQ(0, results["hits"].size()); // rule must not match when window_end is set into the past - override_json["window_start_ts"] = -1; - override_json["window_end_ts"] = 965388863; + override_json["effective_from_ts"] = -1; + override_json["effective_to_ts"] = 965388863; op = override_t::parse(override_json, "rule-1", override_rule); ASSERT_TRUE(op.ok()); coll1->add_override(override_rule); @@ -902,8 +902,8 @@ TEST_F(CollectionOverrideTest, WindowForRule) { ASSERT_EQ(0, results["hits"].size()); // resetting both should bring the override back in action - override_json["window_start_ts"] = 965388863; - override_json["window_end_ts"] = 35677971263; + override_json["effective_from_ts"] = 965388863; + override_json["effective_to_ts"] = 35677971263; op = override_t::parse(override_json, "rule-1", override_rule); ASSERT_TRUE(op.ok()); coll1->add_override(override_rule);