mirror of
https://github.com/typesense/typesense.git
synced 2025-05-20 05:32:30 +08:00
Support sort_by in overrides.
This commit is contained in:
parent
1d0917dc41
commit
0da51cb874
@ -127,7 +127,8 @@ private:
|
||||
const std::vector<std::string>& hidden_hits,
|
||||
std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
std::vector<uint32_t>& excluded_ids, std::vector<const override_t*>& filter_overrides,
|
||||
bool& filter_curated_hits) const;
|
||||
bool& filter_curated_hits,
|
||||
std::string& curated_sort_by) const;
|
||||
|
||||
static Option<bool> detect_new_fields(nlohmann::json& document,
|
||||
const DIRTY_VALUES& dirty_values,
|
||||
|
@ -108,6 +108,8 @@ struct override_t {
|
||||
bool remove_matched_tokens = false;
|
||||
bool filter_curated_hits = false;
|
||||
|
||||
std::string sort_by;
|
||||
|
||||
override_t() = default;
|
||||
|
||||
static Option<bool> parse(const nlohmann::json& override_json, const std::string& id, override_t& override) {
|
||||
@ -124,9 +126,10 @@ struct override_t {
|
||||
}
|
||||
|
||||
if(override_json.count("includes") == 0 && override_json.count("excludes") == 0 &&
|
||||
override_json.count("filter_by") == 0 && override_json.count("remove_matched_tokens") == 0) {
|
||||
override_json.count("filter_by") == 0 && override_json.count("sort_by") == 0 &&
|
||||
override_json.count("remove_matched_tokens") == 0) {
|
||||
return Option<bool>(400, "Must contain one of: `includes`, `excludes`, "
|
||||
"`filter_by`, `remove_matched_tokens`.");
|
||||
"`filter_by`, `sort_by`, `remove_matched_tokens`.");
|
||||
}
|
||||
|
||||
if(override_json.count("includes") != 0) {
|
||||
@ -228,6 +231,10 @@ struct override_t {
|
||||
override.filter_by = override_json["filter_by"].get<std::string>();
|
||||
}
|
||||
|
||||
if (override_json.count("sort_by") != 0) {
|
||||
override.sort_by = override_json["sort_by"].get<std::string>();
|
||||
}
|
||||
|
||||
if(override_json.count("remove_matched_tokens") != 0) {
|
||||
override.remove_matched_tokens = override_json["remove_matched_tokens"].get<bool>();
|
||||
} else {
|
||||
@ -286,6 +293,10 @@ struct override_t {
|
||||
override["filter_by"] = filter_by;
|
||||
}
|
||||
|
||||
if(!sort_by.empty()) {
|
||||
override["sort_by"] = sort_by;
|
||||
}
|
||||
|
||||
override["remove_matched_tokens"] = remove_matched_tokens;
|
||||
override["filter_curated_hits"] = filter_curated_hits;
|
||||
|
||||
@ -323,7 +334,6 @@ struct search_args {
|
||||
size_t all_result_ids_len;
|
||||
bool exhaustive_search;
|
||||
size_t concurrency;
|
||||
const std::vector<const override_t*>& filter_overrides;
|
||||
size_t search_cutoff_ms;
|
||||
size_t min_len_1typo;
|
||||
size_t min_len_2typo;
|
||||
@ -351,7 +361,7 @@ struct search_args {
|
||||
const std::vector<bool>& prefixes, size_t drop_tokens_threshold, size_t typo_tokens_threshold,
|
||||
const std::vector<std::string>& group_by_fields, size_t group_limit,
|
||||
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
|
||||
size_t concurrency, const std::vector<const override_t*>& dynamic_overrides, size_t search_cutoff_ms,
|
||||
size_t concurrency, size_t search_cutoff_ms,
|
||||
size_t min_len_1typo, size_t min_len_2typo, size_t max_candidates, const std::vector<infix_t>& infixes,
|
||||
const size_t max_extra_prefix, const size_t max_extra_suffix, const size_t facet_query_num_typos,
|
||||
const bool filter_curated_hits, const bool split_join_tokens) :
|
||||
@ -364,7 +374,7 @@ struct search_args {
|
||||
group_by_fields(group_by_fields), group_limit(group_limit), default_sorting_field(default_sorting_field),
|
||||
prioritize_exact_match(prioritize_exact_match), all_result_ids_len(0),
|
||||
exhaustive_search(exhaustive_search), concurrency(concurrency),
|
||||
filter_overrides(dynamic_overrides), search_cutoff_ms(search_cutoff_ms),
|
||||
search_cutoff_ms(search_cutoff_ms),
|
||||
min_len_1typo(min_len_1typo), min_len_2typo(min_len_2typo), max_candidates(max_candidates),
|
||||
infixes(infixes), max_extra_prefix(max_extra_prefix), max_extra_suffix(max_extra_suffix),
|
||||
facet_query_num_typos(facet_query_num_typos), filter_curated_hits(filter_curated_hits),
|
||||
@ -769,7 +779,7 @@ public:
|
||||
tsl::htrie_map<char, token_leaf>& qtoken_set,
|
||||
std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<std::vector<KV*>>& override_result_kvs,
|
||||
const size_t typo_tokens_threshold, const size_t group_limit,
|
||||
const std::vector<std::string>& group_by_fields, const std::vector<const override_t*>& filter_overrides,
|
||||
const std::vector<std::string>& group_by_fields,
|
||||
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
|
||||
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
|
||||
size_t max_candidates, const std::vector<infix_t>& infixes, const size_t max_extra_prefix,
|
||||
|
@ -419,7 +419,8 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
|
||||
std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
|
||||
std::vector<uint32_t>& excluded_ids,
|
||||
std::vector<const override_t*>& filter_overrides,
|
||||
bool& filter_curated_hits) const {
|
||||
bool& filter_curated_hits,
|
||||
std::string& curated_sort_by) const {
|
||||
|
||||
std::set<uint32_t> excluded_set;
|
||||
|
||||
@ -486,6 +487,7 @@ void Collection::curate_results(string& actual_query, bool enable_overrides, boo
|
||||
}
|
||||
|
||||
filter_curated_hits = override.filter_curated_hits;
|
||||
curated_sort_by = override.sort_by;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -915,29 +917,6 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
|
||||
}
|
||||
}
|
||||
|
||||
// validate sort fields and standardize
|
||||
|
||||
std::vector<sort_by> sort_fields_std;
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
|
||||
// apply bucketing on text match score
|
||||
int match_score_index = -1;
|
||||
for(size_t i = 0; i < sort_fields_std.size(); i++) {
|
||||
if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) {
|
||||
match_score_index = i;
|
||||
|
||||
if(sort_fields_std[i].text_match_buckets > 1) {
|
||||
// we will disable prioritize exact match because it's incompatible with bucketing
|
||||
prioritize_exact_match = false;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// check for valid pagination
|
||||
if(page < 1) {
|
||||
std::string message = "Page must be an integer of value greater than 0.";
|
||||
@ -994,8 +973,9 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
|
||||
std::vector<const override_t*> filter_overrides;
|
||||
std::string query = raw_query;
|
||||
bool filter_curated_hits = false;
|
||||
std::string curated_sort_by;
|
||||
curate_results(query, enable_overrides, pre_segmented_query, pinned_hits, hidden_hits,
|
||||
included_ids, excluded_ids, filter_overrides, filter_curated_hits);
|
||||
included_ids, excluded_ids, filter_overrides, filter_curated_hits, curated_sort_by);
|
||||
|
||||
if(filter_curated_hits_option == 0 || filter_curated_hits_option == 1) {
|
||||
// When query param has explicit value set, override level configuration takes lower precedence.
|
||||
@ -1025,6 +1005,43 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
|
||||
}
|
||||
*/
|
||||
|
||||
// validate sort fields and standardize
|
||||
|
||||
std::vector<sort_by> sort_fields_std;
|
||||
|
||||
if(curated_sort_by.empty()) {
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(sort_fields, sort_fields_std);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
} else {
|
||||
std::vector<sort_by> curated_sort_fields;
|
||||
bool parsed_sort_by = CollectionManager::parse_sort_by_str(curated_sort_by, curated_sort_fields);
|
||||
if(!parsed_sort_by) {
|
||||
return Option<nlohmann::json>(400, "Parameter `sort_by` is malformed.");
|
||||
}
|
||||
|
||||
auto sort_validation_op = validate_and_standardize_sort_fields(curated_sort_fields, sort_fields_std);
|
||||
if(!sort_validation_op.ok()) {
|
||||
return Option<nlohmann::json>(sort_validation_op.code(), sort_validation_op.error());
|
||||
}
|
||||
}
|
||||
|
||||
// apply bucketing on text match score
|
||||
int match_score_index = -1;
|
||||
for(size_t i = 0; i < sort_fields_std.size(); i++) {
|
||||
if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) {
|
||||
match_score_index = i;
|
||||
|
||||
if(sort_fields_std[i].text_match_buckets > 1) {
|
||||
// we will disable prioritize exact match because it's incompatible with bucketing
|
||||
prioritize_exact_match = false;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
//LOG(INFO) << "Num indices used for querying: " << indices.size();
|
||||
std::vector<query_tokens_t> field_query_tokens;
|
||||
std::vector<std::string> q_tokens; // used for auxillary highlighting
|
||||
@ -1081,7 +1098,7 @@ Option<nlohmann::json> Collection::search(const std::string & raw_query, const s
|
||||
drop_tokens_threshold, typo_tokens_threshold,
|
||||
group_by_fields, group_limit, default_sorting_field,
|
||||
prioritize_exact_match,
|
||||
exhaustive_search, 4, filter_overrides,
|
||||
exhaustive_search, 4,
|
||||
search_stop_millis,
|
||||
min_len_1typo, min_len_2typo, max_candidates, infixes,
|
||||
max_extra_prefix, max_extra_suffix, facet_query_num_typos,
|
||||
|
@ -1841,7 +1841,6 @@ void Index::run_search(search_args* search_params) {
|
||||
search_params->raw_result_kvs, search_params->override_result_kvs,
|
||||
search_params->typo_tokens_threshold,
|
||||
search_params->group_limit, search_params->group_by_fields,
|
||||
search_params->filter_overrides,
|
||||
search_params->default_sorting_field,
|
||||
search_params->prioritize_exact_match,
|
||||
search_params->exhaustive_search,
|
||||
@ -2271,7 +2270,6 @@ void Index::search(std::vector<query_tokens_t>& field_query_tokens, const std::v
|
||||
std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<std::vector<KV*>>& override_result_kvs,
|
||||
const size_t typo_tokens_threshold, const size_t group_limit,
|
||||
const std::vector<std::string>& group_by_fields,
|
||||
const std::vector<const override_t*>& filter_overrides,
|
||||
const string& default_sorting_field, bool prioritize_exact_match, bool exhaustive_search,
|
||||
size_t concurrency, size_t search_cutoff_ms, size_t min_len_1typo, size_t min_len_2typo,
|
||||
size_t max_candidates, const std::vector<infix_t>& infixes, const size_t max_extra_prefix,
|
||||
|
@ -271,7 +271,8 @@ TEST_F(CollectionOverrideTest, OverrideJSONValidation) {
|
||||
|
||||
parse_op = override_t::parse(include_json2, "", override2);
|
||||
ASSERT_FALSE(parse_op.ok());
|
||||
ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `remove_matched_tokens`.", parse_op.error().c_str());
|
||||
ASSERT_STREQ("Must contain one of: `includes`, `excludes`, `filter_by`, `sort_by`, `remove_matched_tokens`.",
|
||||
parse_op.error().c_str());
|
||||
|
||||
include_json2["includes"] = nlohmann::json::array();
|
||||
include_json2["includes"][0] = 100;
|
||||
@ -2062,6 +2063,73 @@ TEST_F(CollectionOverrideTest, DynamicFilteringWithJustRemoveTokens) {
|
||||
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
// Verifies that an override carrying a `sort_by` clause replaces the sort order
// requested by the search call whenever the override's rule matches the query.
// Two documents are indexed whose ordering by `points` is the reverse of their
// ordering by `price`, so the change in sort order is observable in the hits.
TEST_F(CollectionOverrideTest, StaticSorting) {
|
||||
Collection *coll1;
|
||||
|
||||
// Schema: a string `name`, a float `price` and an int32 `points`.
// NOTE(review): the third `field` argument is presumably the facet flag -- confirm.
std::vector<field> fields = {field("name", field_types::STRING, false),
|
||||
field("price", field_types::FLOAT, true),
|
||||
field("points", field_types::INT32, false)};
|
||||
|
||||
// Reuse "coll1" if it already exists, otherwise create it.
// NOTE(review): "points" is presumably the default sorting field -- confirm.
coll1 = collectionManager.get_collection("coll1").get();
|
||||
if(coll1 == nullptr) {
|
||||
coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
|
||||
}
|
||||
|
||||
// doc "0": higher price (399.99), lower points (3).
nlohmann::json doc1;
|
||||
doc1["id"] = "0";
|
||||
doc1["name"] = "Amazing Shoes";
|
||||
doc1["price"] = 399.99;
|
||||
doc1["points"] = 3;
|
||||
|
||||
// doc "1": lower price (49.99), higher points (5).
nlohmann::json doc2;
|
||||
doc2["id"] = "1";
|
||||
doc2["name"] = "Track Shoes";
|
||||
doc2["price"] = 49.99;
|
||||
doc2["points"] = 5;
|
||||
|
||||
ASSERT_TRUE(coll1->add(doc1.dump()).ok());
|
||||
ASSERT_TRUE(coll1->add(doc2.dump()).ok());
|
||||
|
||||
// Sort order passed to every search below: text match first, then points descending.
std::vector<sort_by> sort_fields = { sort_by("_text_match", "DESC"), sort_by("points", "DESC") };
|
||||
|
||||
// Override definition: rule on query "shoes" with MATCH_CONTAINS, which sets
// `remove_matched_tokens` and supplies its own `sort_by` of "price:desc".
nlohmann::json override_json_contains = {
|
||||
{"id", "static-sort"},
|
||||
{
|
||||
"rule", {
|
||||
{"query", "shoes"},
|
||||
{"match", override_t::MATCH_CONTAINS}
|
||||
}
|
||||
},
|
||||
{"remove_matched_tokens", true},
|
||||
{"sort_by", "price:desc"}
|
||||
};
|
||||
|
||||
// The JSON must parse cleanly into an override_t before it is registered.
override_t override_contains;
|
||||
auto op = override_t::parse(override_json_contains, "static-sort", override_contains);
|
||||
ASSERT_TRUE(op.ok());
|
||||
|
||||
// Baseline: the override is not registered yet, so the caller's sort_fields apply.
auto results = coll1->search("shoes", {"name"}, "",
|
||||
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
|
||||
|
||||
// Expect doc "1" (points 5) ahead of doc "0" (points 3).
ASSERT_EQ(2, results["hits"].size());
|
||||
ASSERT_EQ("1", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
ASSERT_EQ("0", results["hits"][1]["document"]["id"].get<std::string>());
|
||||
|
||||
// Register the override and repeat the identical search.
coll1->add_override(override_contains);
|
||||
|
||||
results = coll1->search("shoes", {"name"}, "",
|
||||
{}, sort_fields, {2}, 10, 1, FREQUENCY, {true}, 0).get();
|
||||
|
||||
// With the override active the hits are ordered by price descending:
// doc "0" (399.99) now precedes doc "1" (49.99).
ASSERT_EQ(2, results["hits"].size());
|
||||
ASSERT_EQ("0", results["hits"][0]["document"]["id"].get<std::string>());
|
||||
ASSERT_EQ("1", results["hits"][1]["document"]["id"].get<std::string>());
|
||||
|
||||
// Drop the collection at the end of the test.
collectionManager.drop_collection("coll1");
|
||||
}
|
||||
|
||||
TEST_F(CollectionOverrideTest, DynamicFilteringWithPartialTokenMatch) {
|
||||
// when query tokens do not match placeholder field value exactly, don't do filtering
|
||||
Collection* coll1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user