From 0da51cb874b8c3d52004dccb37c6ea416cb38838 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Mon, 23 May 2022 13:40:11 +0530 Subject: [PATCH] Support sort_by in overrides. --- include/collection.h | 3 +- include/index.h | 22 +++++++--- src/collection.cpp | 69 ++++++++++++++++++------------ src/index.cpp | 2 - test/collection_override_test.cpp | 70 ++++++++++++++++++++++++++++++- 5 files changed, 130 insertions(+), 36 deletions(-) diff --git a/include/collection.h b/include/collection.h index bb084c2f..0ec47c93 100644 --- a/include/collection.h +++ b/include/collection.h @@ -127,7 +127,8 @@ private: const std::vector& hidden_hits, std::vector>& included_ids, std::vector& excluded_ids, std::vector& filter_overrides, - bool& filter_curated_hits) const; + bool& filter_curated_hits, + std::string& curated_sort_by) const; static Option detect_new_fields(nlohmann::json& document, const DIRTY_VALUES& dirty_values, diff --git a/include/index.h b/include/index.h index 57e2add2..3b5f8a49 100644 --- a/include/index.h +++ b/include/index.h @@ -108,6 +108,8 @@ struct override_t { bool remove_matched_tokens = false; bool filter_curated_hits = false; + std::string sort_by; + override_t() = default; static Option parse(const nlohmann::json& override_json, const std::string& id, override_t& override) { @@ -124,9 +126,10 @@ struct override_t { } if(override_json.count("includes") == 0 && override_json.count("excludes") == 0 && - override_json.count("filter_by") == 0 && override_json.count("remove_matched_tokens") == 0) { + override_json.count("filter_by") == 0 && override_json.count("sort_by") == 0 && + override_json.count("remove_matched_tokens") == 0) { return Option(400, "Must contain one of: `includes`, `excludes`, " - "`filter_by`, `remove_matched_tokens`."); + "`filter_by`, `sort_by`, `remove_matched_tokens`."); } if(override_json.count("includes") != 0) { @@ -228,6 +231,10 @@ struct override_t { override.filter_by = override_json["filter_by"].get(); } + if (override_json.count("sort_by") != 0) { + override.sort_by = override_json["sort_by"].get(); + } + if(override_json.count("remove_matched_tokens") != 0) { override.remove_matched_tokens = override_json["remove_matched_tokens"].get(); } else { @@ -286,6 +293,10 @@ struct override_t { override["filter_by"] = filter_by; } + if(!sort_by.empty()) { + override["sort_by"] = sort_by; + } + override["remove_matched_tokens"] = remove_matched_tokens; override["filter_curated_hits"] = filter_curated_hits; @@ -323,7 +334,6 @@ struct search_args { size_t all_result_ids_len; bool exhaustive_search; size_t concurrency; - const std::vector& filter_overrides; size_t search_cutoff_ms; size_t min_len_1typo; size_t min_len_2typo; @@ -351,7 +361,7 @@ struct search_args { const std::vector& prefixes, size_t drop_tokens_threshold, size_t typo_tokens_threshold, const std::vector& group_by_fields, size_t group_limit, const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search, - size_t concurrency, const std::vector& dynamic_overrides, size_t search_cutoff_ms, + size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos, const bool filter_curated_hits, const bool split_join_tokens) : @@ -364,7 +374,7 @@ struct search_args { group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field), prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0), exhaustive_search(exhaustive_search), concurrency(concurrency), - filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms), + search_cutoff_ms(search_cutoff_ms), min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates), infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix), facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits), @@ -769,7 +779,7 @@ public: tsl::htrie_map& qtoken_set, std::vector>& raw_result_kvs, std::vector>& override_result_kvs, const size_t typo_tokens_threshold, const size_t group_limit, - const std::vector& group_by_fields, const std::vector& filter_overrides, + const std::vector& group_by_fields, const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search, size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, diff --git a/src/collection.cpp b/src/collection.cpp index e22bdde1..0b3a0894 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -419,7 +419,8 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo std::vector>& included_ids, std::vector& excluded_ids, std::vector& filter_overrides, - bool& filter_curated_hits) const { + bool& filter_curated_hits, + std::string& curated_sort_by) const { std::set excluded_set; @@ -486,6 +487,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo } filter_curated_hits = override.filter_curated_hits; + curated_sort_by = override.sort_by; } } } @@ -915,29 +917,6 @@ Option Collection::search(const std::string & raw_query, const s } } - // validate sort fields and standardize - - std::vector sort_fields_std; - auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std); - if(!sort_validation_op.ok()) { - return Option(sort_validation_op.code(), sort_validation_op.error()); - } - - // apply bucketing on text match score - int match_score_index = -1; - for(size_t i = 0; i < sort_fields_std.size(); i++) { - if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) { - match_score_index = i; - - if(sort_fields_std[i].text_match_buckets > 1) { - // we will disable prioritize exact match because it's incompatible with bucketing - prioritize_exact_match = false; - } - - break; - } - } - // check for valid pagination if(page < 1) { std::string message = "Page must be an integer of value greater than 0."; @@ -994,8 +973,9 @@ Option Collection::search(const std::string & raw_query, const s std::vector filter_overrides; std::string query = raw_query; bool filter_curated_hits = false; + std::string curated_sort_by; curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits, - included_ids, excluded_ids, filter_overrides, filter_curated_hits); + included_ids, excluded_ids, filter_overrides, filter_curated_hits, curated_sort_by); if(filter_curated_hits_option == 0 || filter_curated_hits_option == 1) { // When query param has explicit value set, override level configuration takes lower precedence. @@ -1025,6 +1005,43 @@ Option Collection::search(const std::string & raw_query, const s } */ + // validate sort fields and standardize + + std::vector sort_fields_std; + + if(curated_sort_by.empty()) { + auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std); + if(!sort_validation_op.ok()) { + return Option(sort_validation_op.code(), sort_validation_op.error()); + } + } else { + std::vector curated_sort_fields; + bool parsed_sort_by = CollectionManager::parse_sort_by_str(curated_sort_by, curated_sort_fields); + if(!parsed_sort_by) { + return Option(400, "Parameter `sort_by` is malformed."); + } + + auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, sort_fields_std); + if(!sort_validation_op.ok()) { + return Option(sort_validation_op.code(), sort_validation_op.error()); + } + } + + // apply bucketing on text match score + int match_score_index = -1; + for(size_t i = 0; i < sort_fields_std.size(); i++) { + if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) { + match_score_index = i; + + if(sort_fields_std[i].text_match_buckets > 1) { + // we will disable prioritize exact match because it's incompatible with bucketing + prioritize_exact_match = false; + } + + break; + } + } + //LOG(INFO) << "Num indices used for querying: " << indices.size(); std::vector field_query_tokens; std::vector q_tokens; // used for auxillary highlighting @@ -1081,7 +1098,7 @@ Option Collection::search(const std::string & raw_query, const s drop_tokens_threshold, typo_tokens_threshold, group_by_fields, group_limit, default_sorting_field, prioritize_exact_match, - exhaustive_search, 4, filter_overrides, + exhaustive_search, 4, search_stop_millis, min_len_1typo, min_len_2typo, max_candidates, infixes, max_extra_prefix, max_extra_suffix, facet_query_num_typos, diff --git a/src/index.cpp b/src/index.cpp index 65e6d9e9..63fd7fb9 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1841,7 +1841,6 @@ void Index::run_search(search_args* search_params) { search_params->raw_result_kvs, search_params->override_result_kvs, search_params->typo_tokens_threshold, search_params->group_limit, search_params->group_by_fields, - search_params->filter_overrides, search_params->default_sorting_field, search_params->prioritize_exact_match, search_params->exhaustive_search, @@ -2271,7 +2270,6 @@ void Index::search(std::vector& field_query_tokens, const std::v std::vector>& raw_result_kvs, std::vector>& override_result_kvs, const size_t typo_tokens_threshold, const size_t group_limit, const std::vector& group_by_fields, - const std::vector& filter_overrides, const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search, size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector& infixes, const size_t max_extra_prefix, diff --git a/test/collection_override_test.cpp b/test/collection_override_test.cpp index df3c131f..90716096 100644 --- a/test/collection_override_test.cpp +++ b/test/collection_override_test.cpp @@ -271,7 +271,8 @@ TEST_F(CollectionOverrideTest, OverrideJSONValidation) { parse_op = override_t::parse(include_json2, "", override2); ASSERT_FALSE(parse_op.ok()); - ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `remove_matched_tokens`.", parse_op.error().c_str()); + ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`.", + parse_op.error().c_str()); include_json2["includes"] = nlohmann::json::array(); include_json2["includes"][0] = 100; @@ -2062,6 +2063,73 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithJustRemoveTokens) { collectionManager.drop_collection("coll1"); } +TEST_F(CollectionOverrideTest, StaticSorting) { + Collection *coll1; + + std::vector fields = {field("name", field_types::STRING, false), + field("price", field_types::FLOAT, true), + field("points", field_types::INT32, false)}; + + coll1 = collectionManager.get_collection("coll1").get(); + if(coll1 == nullptr) { + coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get(); + } + + nlohmann::json doc1; + doc1["id"] = "0"; + doc1["name"] = "Amazing Shoes"; + doc1["price"] = 399.99; + doc1["points"] = 3; + + nlohmann::json doc2; + doc2["id"] = "1"; + doc2["name"] = "Track Shoes"; + doc2["price"] = 49.99; + doc2["points"] = 5; + + ASSERT_TRUE(coll1->add(doc1.dump()).ok()); + ASSERT_TRUE(coll1->add(doc2.dump()).ok()); + + std::vector sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") }; + + nlohmann::json override_json_contains = { + {"id", "static-sort"}, + { + "rule", { + {"query", "shoes"}, + {"match", override_t::MATCH_CONTAINS} + } + }, + {"remove_matched_tokens", true}, + {"sort_by", "price:desc"} + }; + + override_t override_contains; + auto op = override_t::parse(override_json_contains, "static-sort", override_contains); + ASSERT_TRUE(op.ok()); + + // without override kicking in + auto results = coll1->search("shoes", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("1", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("0", results["hits"][1]["document"]["id"].get()); + + // now add override + coll1->add_override(override_contains); + + results = coll1->search("shoes", {"name"}, "", + {}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get(); + + // with override we will sort on price + ASSERT_EQ(2, results["hits"].size()); + ASSERT_EQ("0", results["hits"][0]["document"]["id"].get()); + ASSERT_EQ("1", results["hits"][1]["document"]["id"].get()); + + collectionManager.drop_collection("coll1"); +} + TEST_F(CollectionOverrideTest, DynamicFilteringWithPartialTokenMatch) { // when query tokens do not match placeholder field value exactly, don't do filtering Collection* coll1;