Support sort_by in overrides.

This commit is contained in:
Kishore Nallan 2022-05-23 13:40:11 +05:30
parent 1d0917dc41
commit 0da51cb874
5 changed files with 130 additions and 36 deletions

View File

@ -127,7 +127,8 @@ private:
const std::vector<std::string>& hidden_hits,
std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
std::vector<uint32_t>& excluded_ids, std::vector<const override_t*>& filter_overrides,
bool& filter_curated_hits) const;
bool& filter_curated_hits,
std::string& curated_sort_by) const;
static Option<bool> detect_new_fields(nlohmann::json& document,
const DIRTY_VALUES& dirty_values,

View File

@ -108,6 +108,8 @@ struct override_t {
bool remove_matched_tokens = false;
bool filter_curated_hits = false;
std::string sort_by;
override_t() = default;
static Option<bool> parse(const nlohmann::json& override_json, const std::string& id, override_t& override) {
@ -124,9 +126,10 @@ struct override_t {
}
if(override_json.count("includes") == 0 && override_json.count("excludes") == 0 &&
override_json.count("filter_by") == 0 && override_json.count("remove_matched_tokens") == 0) {
override_json.count("filter_by") == 0 && override_json.count("sort_by") == 0 &&
override_json.count("remove_matched_tokens") == 0) {
return Option<bool>(400, "Must contain one of: `includes`, `excludes`, "
"`filter_by`, `remove_matched_tokens`.");
"`filter_by`, `sort_by`, `remove_matched_tokens`.");
}
if(override_json.count("includes") != 0) {
@ -228,6 +231,10 @@ struct override_t {
override.filter_by = override_json["filter_by"].get<std::string>();
}
if (override_json.count("sort_by") != 0) {
override.sort_by = override_json["sort_by"].get<std::string>();
}
if(override_json.count("remove_matched_tokens") != 0) {
override.remove_matched_tokens = override_json["remove_matched_tokens"].get<bool>();
} else {
@ -286,6 +293,10 @@ struct override_t {
override["filter_by"] = filter_by;
}
if(!sort_by.empty()) {
override["sort_by"] = sort_by;
}
override["remove_matched_tokens"] = remove_matched_tokens;
override["filter_curated_hits"] = filter_curated_hits;
@ -323,7 +334,6 @@ struct search_args {
size_t all_result_ids_len;
bool exhaustive_search;
size_t concurrency;
const std::vector<const override_t*>& filter_overrides;
size_t search_cutoff_ms;
size_t min_len_1typo;
size_t min_len_2typo;
@ -351,7 +361,7 @@ struct search_args {
const std::vector<bool>& prefixes, size_t drop_tokens_threshold, size_t typo_tokens_threshold,
const std::vector<std::string>& group_by_fields, size_t group_limit,
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
size_t concurrency, const std::vector<const override_t*>& dynamic_overrides, size_t search_cutoff_ms,
size_t concurrency, size_t search_cutoff_ms,
size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<infix_t>& infixes,
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
const bool filter_curated_hits, const bool split_join_tokens) :
@ -364,7 +374,7 @@ struct search_args {
group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field),
prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0),
exhaustive_search(exhaustive_search), concurrency(concurrency),
filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms),
search_cutoff_ms(search_cutoff_ms),
min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates),
infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix),
facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits),
@ -769,7 +779,7 @@ public:
tsl::htrie_map<char, token_leaf>& qtoken_set,
std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<std::vector<KV*>>& override_result_kvs,
const size_t typo_tokens_threshold, const size_t group_limit,
const std::vector<std::string>& group_by_fields, const std::vector<const override_t*>& filter_overrides,
const std::vector<std::string>& group_by_fields,
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
size_t max_candidates, const std::vector<infix_t>& infixes, const size_t max_extra_prefix,

View File

@ -419,7 +419,8 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
std::vector<uint32_t>& excluded_ids,
std::vector<const override_t*>& filter_overrides,
bool& filter_curated_hits) const {
bool& filter_curated_hits,
std::string& curated_sort_by) const {
std::set<uint32_t> excluded_set;
@ -486,6 +487,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
}
filter_curated_hits = override.filter_curated_hits;
curated_sort_by = override.sort_by;
}
}
}
@ -915,29 +917,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
}
}
// validate sort fields and standardize
std::vector<sort_by> sort_fields_std;
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
// apply bucketing on text match score
int match_score_index = -1;
for(size_t i = 0; i < sort_fields_std.size(); i++) {
if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) {
match_score_index = i;
if(sort_fields_std[i].text_match_buckets > 1) {
// we will disable prioritize exact match because it's incompatible with bucketing
prioritize_exact_match = false;
}
break;
}
}
// check for valid pagination
if(page < 1) {
std::string message = "Page must be an integer of value greater than 0.";
@ -994,8 +973,9 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
std::vector<const override_t*> filter_overrides;
std::string query = raw_query;
bool filter_curated_hits = false;
std::string curated_sort_by;
curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits,
included_ids, excluded_ids, filter_overrides, filter_curated_hits);
included_ids, excluded_ids, filter_overrides, filter_curated_hits, curated_sort_by);
if(filter_curated_hits_option == 0 || filter_curated_hits_option == 1) {
// When query param has explicit value set, override level configuration takes lower precedence.
@ -1025,6 +1005,43 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
}
*/
// validate sort fields and standardize
std::vector<sort_by> sort_fields_std;
if(curated_sort_by.empty()) {
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
} else {
std::vector<sort_by> curated_sort_fields;
bool parsed_sort_by = CollectionManager::parse_sort_by_str(curated_sort_by, curated_sort_fields);
if(!parsed_sort_by) {
return Option<nlohmann::json>(400, "Parameter `sort_by` is malformed.");
}
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, sort_fields_std);
if(!sort_validation_op.ok()) {
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
}
}
// apply bucketing on text match score
int match_score_index = -1;
for(size_t i = 0; i < sort_fields_std.size(); i++) {
if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) {
match_score_index = i;
if(sort_fields_std[i].text_match_buckets > 1) {
// we will disable prioritize exact match because it's incompatible with bucketing
prioritize_exact_match = false;
}
break;
}
}
//LOG(INFO) << "Num indices used for querying: " << indices.size();
std::vector<query_tokens_t> field_query_tokens;
std::vector<std::string> q_tokens; // used for auxillary highlighting
@ -1081,7 +1098,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
drop_tokens_threshold, typo_tokens_threshold,
group_by_fields, group_limit, default_sorting_field,
prioritize_exact_match,
exhaustive_search, 4, filter_overrides,
exhaustive_search, 4,
search_stop_millis,
min_len_1typo, min_len_2typo, max_candidates, infixes,
max_extra_prefix, max_extra_suffix, facet_query_num_typos,

View File

@ -1841,7 +1841,6 @@ void Index::run_search(search_args* search_params) {
search_params->raw_result_kvs, search_params->override_result_kvs,
search_params->typo_tokens_threshold,
search_params->group_limit, search_params->group_by_fields,
search_params->filter_overrides,
search_params->default_sorting_field,
search_params->prioritize_exact_match,
search_params->exhaustive_search,
@ -2271,7 +2270,6 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<std::vector<KV*>>& override_result_kvs,
const size_t typo_tokens_threshold, const size_t group_limit,
const std::vector<std::string>& group_by_fields,
const std::vector<const override_t*>& filter_overrides,
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
size_t max_candidates, const std::vector<infix_t>& infixes, const size_t max_extra_prefix,

View File

@ -271,7 +271,8 @@ TEST_F(CollectionOverrideTest, OverrideJSONValidation) {
parse_op = override_t::parse(include_json2, "", override2);
ASSERT_FALSE(parse_op.ok());
ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `remove_matched_tokens`.", parse_op.error().c_str());
ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`.",
parse_op.error().c_str());
include_json2["includes"] = nlohmann::json::array();
include_json2["includes"][0] = 100;
@ -2062,6 +2063,73 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithJustRemoveTokens) {
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionOverrideTest, StaticSorting) {
Collection *coll1;
std::vector<field> fields = {field("name", field_types::STRING, false),
field("price", field_types::FLOAT, true),
field("points", field_types::INT32, false)};
coll1 = collectionManager.get_collection("coll1").get();
if(coll1 == nullptr) {
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
}
nlohmann::json doc1;
doc1["id"] = "0";
doc1["name"] = "Amazing Shoes";
doc1["price"] = 399.99;
doc1["points"] = 3;
nlohmann::json doc2;
doc2["id"] = "1";
doc2["name"] = "Track Shoes";
doc2["price"] = 49.99;
doc2["points"] = 5;
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
ASSERT_TRUE(coll1->add(doc2.dump()).ok());
std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
nlohmann::json override_json_contains = {
{"id", "static-sort"},
{
"rule", {
{"query", "shoes"},
{"match", override_t::MATCH_CONTAINS}
}
},
{"remove_matched_tokens", true},
{"sort_by", "price:desc"}
};
override_t override_contains;
auto op = override_t::parse(override_json_contains, "static-sort", override_contains);
ASSERT_TRUE(op.ok());
// without override kicking in
auto results = coll1->search("shoes", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
// now add override
coll1->add_override(override_contains);
results = coll1->search("shoes", {"name"}, "",
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
// with override we will sort on price
ASSERT_EQ(2, results["hits"].size());
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
collectionManager.drop_collection("coll1");
}
TEST_F(CollectionOverrideTest, DynamicFilteringWithPartialTokenMatch) {
// when query tokens do not match placeholder field value exactly, don't do filtering
Collection* coll1;