Faceting on top k (#1878)

* do faceting on top_k results

* make top_k faceting on single thread

* remove logging to tsv file

* Revert "remove logging to tsv file"

This reverts commit 42bd4fdc4607d3cb5000080ac8aeba21b602e279.

* add validation checks & aggregate facets

* add tests

* refactor code

* refactor facet parsing to single pass

---------

Co-authored-by: Kishore Nallan <kishorenc@gmail.com>
This commit is contained in:
Krunal Gandhi 2024-09-03 07:31:46 +00:00 committed by GitHub
parent 478f0ce322
commit 0d030f08fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 711 additions and 257 deletions

View File

@ -676,6 +676,8 @@ struct facet {
uint32_t orig_index;
bool is_top_k = false;
bool get_range(int64_t key, std::pair<int64_t, std::string>& range_pair) {
if(facet_range_map.empty()) {
LOG (ERROR) << "Facet range is not defined!!!";
@ -696,12 +698,12 @@ struct facet {
return false;
}
explicit facet(const std::string& field_name, uint32_t orig_index, std::map<int64_t, range_specs_t> facet_range = {},
explicit facet(const std::string& field_name, uint32_t orig_index, bool is_top_k = false, std::map<int64_t, range_specs_t> facet_range = {},
bool is_range_q = false, bool sort_by_alpha=false, const std::string& order="",
const std::string& sort_by_field="")
: field_name(field_name), facet_range_map(facet_range),
is_range_query(is_range_q), is_sort_by_alpha(sort_by_alpha), sort_order(order),
sort_field(sort_by_field), orig_index(orig_index) {
sort_field(sort_by_field), orig_index(orig_index), is_top_k(is_top_k) {
}
};

View File

@ -1066,6 +1066,8 @@ public:
float get_distance(const string& geo_field_name, const uint32_t& seq_id,
const S2LatLng& reference_lat_lng, const std::string& unit) const;
void get_top_k_result_ids(const std::vector<std::vector<KV*>>& raw_result_kvs, std::vector<uint32_t>& result_ids) const;
};
template<class T>

View File

@ -2105,9 +2105,6 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
auto drop_tokens_param = drop_tokens_param_op.get();
std::vector<std::vector<KV*>> raw_result_kvs;
std::vector<std::vector<KV*>> override_result_kvs;
size_t total = 0;
std::vector<uint32_t> excluded_ids;
@ -2316,15 +2313,8 @@ Option<nlohmann::json> Collection::search(std::string raw_query,
return Option<nlohmann::json>(search_op.code(), search_op.error());
}
// for grouping we have to re-aggregate
Topster& topster = *search_params->topster;
Topster& curated_topster = *search_params->curated_topster;
topster.sort();
curated_topster.sort();
populate_result_kvs(&topster, raw_result_kvs, search_params->groups_processed, sort_fields_std);
populate_result_kvs(&curated_topster, override_result_kvs, search_params->groups_processed, sort_fields_std);
auto& raw_result_kvs = search_params->raw_result_kvs;
auto& override_result_kvs = search_params->override_result_kvs;
// for grouping we have to aggregate group set sizes to a count value
if(group_limit) {
@ -6032,161 +6022,262 @@ bool Collection::get_enable_nested_fields() {
}
Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector<facet>& facets) const {
const std::regex base_pattern(".+\\(.*\\)");
const std::regex range_pattern(
"[[:print:]]+:\\[([+-]?([[:digit:]]*[.])?[[:digit:]]*)\\,\\s*([+-]?([[:digit:]]*[.])?[[:digit:]]*)\\]");
const std::string _alpha = "_alpha";
bool top_k = false;
std::string facet_field_name, param_str;
bool paran_open = false; //for (
bool brace_open = false; //for [
std::string order = "";
bool sort_alpha = false;
std::string sort_field = "";
bool colon_found = false;
bool top_k_found = false;
bool sort_found = false;
unsigned facet_param_count = 0;
unsigned commaCount = 0;
bool is_wildcard = false;
if((facet_field.find(":") != std::string::npos)
&& (facet_field.find("sort_by") == std::string::npos)) { //range based facet
std::vector<std::tuple<int64_t, int64_t, std::string>> tupVec;
if(!std::regex_match(facet_field, base_pattern)) {
std::string error = "Facet range value is not valid.";
return Option<bool>(400, error);
for(int i = 0; i < facet_field.size(); ) {
if(facet_field[i] == '(') {
//facet field name complete, check validity
if(search_schema.count(facet_field_name) == 0 || !search_schema.at(facet_field_name).facet) {
std::string error = "Could not find a facet field named `" + facet_field_name + "` in the schema.";
return Option<bool>(404, error);
}
paran_open = true;
i++;
continue;
} else if(facet_field[i] == '*') {
if(i == facet_field.size() - 1) {
auto prefix = facet_field.substr(0, facet_field.size() - 1);
auto pair = search_schema.equal_prefix_range(prefix);
if(pair.first == pair.second) {
// not found
std::string error = "Could not find a facet field for `" + facet_field + "` in the schema.";
return Option<bool>(404, error);
}
// Collect the fields that match the prefix and are marked as facet.
for(auto field = pair.first; field != pair.second; field++) {
if(field->facet) {
facets.emplace_back(facet(field->name, facets.size()));
facets.back().is_wildcard_match = true;
}
}
i++;
is_wildcard = true;
continue;
} else {
return Option<bool>(404, "Only prefix matching with a wildcard is allowed.");
}
} else if(facet_field[i] == ')') {
if(paran_open == true && (facet_param_count == commaCount + 1)) {
if(!colon_found && !top_k_found) {
return Option<bool>(400, "Invalid facet param `" + param_str + "`.");
}
paran_open = false;
commaCount = facet_param_count;
break;
} else {
return Option<bool>(400, "Invalid facet format.");
}
} else if(facet_field[i] == ':') {
if(paran_open == false || facet_param_count != commaCount) {
return Option<bool>(400, "Invalid facet format.");
}
colon_found = true;
StringUtils::trim(param_str);
if(param_str == "sort_by") { //sort_by params
sort_found = true;
for(i; facet_field.size(); i++) {
if(facet_field[i] == ',' || facet_field[i] == ')') {
break;
} else {
param_str+=facet_field[i];
}
}
std::vector<std::string> tokens;
StringUtils::split(param_str, tokens, ":");
if(tokens.size() != 3) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
if(tokens[1] == _alpha) {
const field& a_field = search_schema.at(facet_field_name);
if(!a_field.is_string()) {
std::string error = "Facet field should be string type to apply alpha sort.";
return Option<bool>(400, error);
}
sort_alpha = true;
} else { //sort_field based sort
sort_field = tokens[1];
if(search_schema.count(sort_field) == 0 || !search_schema.at(sort_field).facet) {
std::string error = "Could not find a facet field named `" + sort_field + "` in the schema.";
return Option<bool>(404, error);
}
const field& a_field = search_schema.at(sort_field);
if(a_field.is_string()) {
std::string error = "Sort field should be non string type to apply sort.";
return Option<bool>(400, error);
}
}
if(tokens[2] == "asc") {
order = "asc";
} else if(tokens[2] == "desc") {
order = "desc";
} else {
std::string error = "Invalid sort param.";
return Option<bool>(400, error);
}
facet_param_count++;
} else if(param_str == "top_k") { //top_k param
top_k_found = true;
param_str.clear();
i++; //skip :
for(i; i < facet_field.size(); i++) {
if(facet_field[i] == ',' || facet_field[i] == ')') {
break;
}
param_str+=facet_field[i];
}
if(param_str.empty() || (param_str != "true" && param_str != "false")) {
return Option<bool>(400, "top_k string format is invalid.");
}
if(param_str == "true") {
top_k = true;
}
facet_param_count++;
} else if((i + 1) < facet_field.size() && facet_field[i+1] == '[') { //range params
const field& a_field = search_schema.at(facet_field_name);
if(tupVec.empty()) {
if(!a_field.is_integer() && !a_field.is_float()) {
std::string error = "Range facet is restricted to only integer and float fields.";
return Option<bool>(400, error);
}
if(!a_field.sort) {
return Option<bool>(400, "Range facets require sort enabled for the field.");
}
}
auto range_val = param_str;
StringUtils::trim(range_val);
if(range_val.empty()) {
return Option<bool>(400, "Facet range value is not valid.");
}
std::string lower, upper;
int64_t lower_range, upper_range;
brace_open = true;
auto commaFound = 0;
i+=2; //skip : and [
param_str.clear();
while(i < facet_field.size()) {
if(facet_field[i]== ',') {
if(commaFound == 1) {
return Option<bool>(400, "Error splitting the facet range values.");
}
lower = param_str;
StringUtils::trim(lower);
param_str.clear();
commaFound++;
} else if(facet_field[i] == ']') {
brace_open = false;
upper = param_str;
StringUtils::trim(upper);
i++; //skip ] and break loop
break;
} else if(facet_field[i] == ')') {
return Option<bool>(400, "Error splitting the facet range values.");
} else {
param_str += facet_field[i];
}
i++;
}
if(lower.empty()) {
lower_range = INT64_MIN;
} else if(a_field.is_integer() && StringUtils::is_int64_t(lower)) {
lower_range = std::stoll(lower);
} else if(a_field.is_float() && StringUtils::is_float(lower)) {
float val = std::stof(lower);
lower_range = Index::float_to_int64_t(val);
} else {
return Option<bool>(400, "Facet range value is not valid.");
}
if(upper.empty()) {
upper_range = INT64_MAX;
} else if(a_field.is_integer() && StringUtils::is_int64_t(upper)) {
upper_range = std::stoll(upper);
} else if(a_field.is_float() && StringUtils::is_float(upper)) {
float val = std::stof(upper);
upper_range = Index::float_to_int64_t(val);
} else {
return Option<bool>(400, "Facet range value is not valid.");
}
tupVec.emplace_back(lower_range, upper_range, range_val);
facet_param_count++;
} else {
return Option<bool>(400, "Invalid facet param `" + param_str + "`.");
}
continue;
} else if(facet_field[i] == ',') {
param_str.clear();
commaCount++;
i++;
continue;
}
auto startpos = facet_field.find("(");
auto field_name = facet_field.substr(0, startpos);
if(!paran_open) {
facet_field_name+=facet_field[i];
} else {
param_str+=facet_field[i];
}
i++;
}
if(search_schema.count(field_name) == 0) {
std::string error = "Could not find a facet field named `" + field_name + "` in the schema.";
if(paran_open || brace_open || facet_param_count != commaCount) {
return Option<bool>(400, "Invalid facet format.");
}
if(facet_param_count == 0 && !is_wildcard) {
//facets with params will be validated while parsing
// for normal facets need to perform check
if(search_schema.count(facet_field_name) == 0 || !search_schema.at(facet_field_name).facet) {
std::string error = "Could not find a facet field named `" + facet_field_name + "` in the schema.";
return Option<bool>(404, error);
}
}
if((field_name.find("sort") == std::string::npos)
&& (facet_field.find("sort") != std::string::npos)) {
//sort keyword is found in facet string but not in facet field
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
const field& a_field = search_schema.at(field_name);
if(!a_field.is_integer() && !a_field.is_float()) {
std::string error = "Range facet is restricted to only integer and float fields.";
return Option<bool>(400, error);
}
if(!a_field.sort) {
return Option<bool>(400, "Range facets require sort enabled for the field.");
}
facet a_facet(field_name, facets.size());
//starting after "(" and excluding ")"
auto range_string = std::string(facet_field.begin() + startpos + 1, facet_field.end() - 1);
//split the ranges
std::vector<std::string> result;
startpos = 0;
int index = 0;
int commaFound = 0, rangeFound = 0;
bool range_open = false;
while(index < range_string.size()) {
if(range_string[index] == ']') {
if(range_open == true) {
std::string range = range_string.substr(startpos, index + 1 - startpos);
range = StringUtils::trim(range);
result.emplace_back(range);
rangeFound++;
range_open = false;
} else {
result.clear();
break;
}
} else if(range_string[index] == ',' && range_open == false) {
startpos = index + 1;
commaFound++;
} else if(range_string[index] == '[') {
if((commaFound == rangeFound) && range_open == false) {
range_open = true;
} else {
result.clear();
break;
}
}
index++;
}
if((result.empty()) || (range_open == true)) {
std::string error = "Error splitting the facet range values.";
return Option<bool>(400, error);
}
std::vector<std::tuple<int64_t, int64_t, std::string>> tupVec;
auto& range_map = a_facet.facet_range_map;
range_map.clear();
for(const auto& range: result) {
//validate each range syntax
if(!std::regex_match(range, range_pattern)) {
std::string error = "Facet range value is not valid.";
return Option<bool>(400, error);
}
auto pos1 = range.find(":");
std::string range_val = range.substr(0, pos1);
auto pos2 = range.find(",");
auto pos3 = range.find("]");
int64_t lower_range, upper_range;
if(a_field.is_integer()) {
auto start = pos1 + 2;
auto end = pos2 - start;
auto lower_range_str = range.substr(start, end);
StringUtils::trim(lower_range_str);
if(lower_range_str.empty()) {
lower_range = INT64_MIN;
} else {
lower_range = std::stoll(lower_range_str);
}
start = pos2 + 1;
end = pos3 - start;
auto upper_range_str = range.substr(start, end);
StringUtils::trim(upper_range_str);
if(upper_range_str.empty()) {
upper_range = INT64_MAX;
} else {
upper_range = std::stoll(upper_range_str);
}
} else {
auto start = pos1 + 2;
auto end = pos2 - start;
auto lower_range_str = range.substr(start, end);
StringUtils::trim(lower_range_str);
if(lower_range_str.empty()) {
lower_range = INT64_MIN;
} else {
float val = std::stof(lower_range_str);
lower_range = Index::float_to_int64_t(val);
}
start = pos2 + 1;
end = pos3 - start;
auto upper_range_str = range.substr(start, end);
StringUtils::trim(upper_range_str);
if(upper_range_str.empty()) {
upper_range = INT64_MAX;
} else {
float val = std::stof(upper_range_str);
upper_range = Index::float_to_int64_t(val);
}
}
tupVec.emplace_back(lower_range, upper_range, range_val);
}
//sort the range values so that we can check continuity
if(!tupVec.empty()) { //add range facets
sort(tupVec.begin(), tupVec.end());
for(const auto& tup: tupVec) {
facet a_facet(facet_field_name, facets.size());
auto& range_map = a_facet.facet_range_map;
for(const auto& tup: tupVec) {
const auto& lower_range = std::get<0>(tup);
const auto& upper_range = std::get<1>(tup);
const std::string& range_val = std::get<2>(tup);
//check if ranges are continous or not
if((!range_map.empty()) && (range_map.find(lower_range) == range_map.end())) {
std::string error = "Ranges in range facet syntax should be continous.";
@ -6195,98 +6286,12 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
range_map[upper_range] = range_specs_t{range_val, lower_range};
}
a_facet.is_range_query = true;
a_facet.is_top_k = top_k;
facets.emplace_back(std::move(a_facet));
} else if(facet_field.find('*') != std::string::npos) { // Wildcard
if(facet_field[facet_field.size() - 1] != '*') {
return Option<bool>(404, "Only prefix matching with a wildcard is allowed.");
}
// Trim * from the end.
auto prefix = facet_field.substr(0, facet_field.size() - 1);
auto pair = search_schema.equal_prefix_range(prefix);
if(pair.first == pair.second) {
// not found
std::string error = "Could not find a facet field for `" + facet_field + "` in the schema.";
return Option<bool>(404, error);
}
// Collect the fields that match the prefix and are marked as facet.
for(auto field = pair.first; field != pair.second; field++) {
if(field->facet) {
facets.emplace_back(facet(field->name, facets.size()));
facets.back().is_wildcard_match = true;
}
}
} else {
// normal facet
std::string order = "";
bool sort_alpha = false;
std::string sort_field = "";
std::string facet_field_copy = facet_field;
auto pos = facet_field_copy.find("(");
if(pos != std::string::npos) {
facet_field_copy = facet_field_copy.substr(0, pos);
}
if(search_schema.count(facet_field_copy) == 0 || !search_schema.at(facet_field_copy).facet) {
std::string error = "Could not find a facet field named `" + facet_field_copy + "` in the schema.";
return Option<bool>(404, error);
}
if(facet_field.find("sort_by") != std::string::npos) { //sort params are supplied with facet
std::vector<std::string> tokens;
StringUtils::split(facet_field, tokens, ":");
if(tokens.size() != 3) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
//remove possible whitespaces
for(auto i = 0; i < 3; ++i) {
StringUtils::trim(tokens[i]);
}
if(tokens[1] == _alpha) {
const field& a_field = search_schema.at(facet_field_copy);
if(!a_field.is_string()) {
std::string error = "Facet field should be string type to apply alpha sort.";
return Option<bool>(400, error);
}
sort_alpha = true;
} else { //sort_field based sort
sort_field = tokens[1];
if(search_schema.count(sort_field) == 0 || !search_schema.at(sort_field).facet) {
std::string error = "Could not find a facet field named `" + sort_field + "` in the schema.";
return Option<bool>(404, error);
}
const field& a_field = search_schema.at(sort_field);
if(a_field.is_string()) {
std::string error = "Sort field should be non string type to apply sort.";
return Option<bool>(400, error);
}
}
if(tokens[2].find("asc") != std::string::npos) {
order = "asc";
} else if(tokens[2].find("desc") != std::string::npos) {
order = "desc";
} else {
std::string error = "Invalid sort param.";
return Option<bool>(400, error);
}
} else if(facet_field != facet_field_copy) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
facets.emplace_back(facet(facet_field_copy, facets.size(), {}, false, sort_alpha,
} else if(!is_wildcard) { //add other facet types, wildcard facets are already added while parsing
facets.emplace_back(facet(facet_field_name, facets.size(), top_k, {}, false, sort_alpha,
order, sort_field));
}

View File

@ -3682,6 +3682,14 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
process_search_results:
topster->sort();
curated_topster->sort();
Collection::populate_result_kvs(topster, raw_result_kvs, groups_processed, sort_fields_std);
Collection::populate_result_kvs(curated_topster, override_result_kvs, groups_processed, sort_fields_std);
std::vector<uint32_t> top_k_result_ids, top_k_curated_result_ids;
std::vector<facet> top_k_facets;
delete [] exclude_token_ids;
delete [] excluded_result_ids;
@ -3705,15 +3713,22 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
std::vector<std::vector<facet>> facet_batches(num_threads);
std::vector<std::vector<facet>> value_facets(concurrency);
size_t num_value_facets = 0;
for(size_t i = 0; i < facets.size(); i++) {
const auto& this_facet = facets[i];
//process facets separately which has top_k set to true
if(this_facet.is_top_k) {
top_k_facets.emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.is_top_k, this_facet.facet_range_map,
this_facet.is_range_query, this_facet.is_sort_by_alpha, this_facet.sort_order, this_facet.sort_field);
continue;
}
if(facet_infos[i].use_value_index) {
// value based faceting on a single thread
value_facets[num_value_facets % num_threads].emplace_back(this_facet.field_name, this_facet.orig_index,
this_facet.facet_range_map,
this_facet.is_top_k, this_facet.facet_range_map,
this_facet.is_range_query, this_facet.is_sort_by_alpha,
this_facet.sort_order, this_facet.sort_field);
num_value_facets++;
@ -3721,9 +3736,9 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
for(size_t j = 0; j < num_threads; j++) {
facet_batches[j].emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.facet_range_map,
this_facet.is_range_query, this_facet.is_sort_by_alpha,
this_facet.sort_order, this_facet.sort_field);
facet_batches[j].emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.is_top_k,
this_facet.facet_range_map, this_facet.is_range_query,
this_facet.is_sort_by_alpha, this_facet.sort_order, this_facet.sort_field);
}
}
@ -3736,6 +3751,15 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
//auto beginF = std::chrono::high_resolution_clock::now();
if(top_k_facets.size() > 0) {
get_top_k_result_ids(raw_result_kvs, top_k_result_ids);
do_facets(top_k_facets, facet_query, estimate_facets, facet_sample_percent,
facet_infos, group_limit, group_by_fields, group_missing_values, top_k_result_ids.data(),
top_k_result_ids.size(), max_facet_values, is_wildcard_no_filter_query,
facet_index_types);
}
for(size_t thread_id = 0; thread_id < num_threads && result_index < all_result_ids_len; thread_id++) {
size_t batch_res_len = window_size;
@ -3857,6 +3881,14 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
included_ids_vec.size(), max_facet_values, is_wildcard_no_filter_query,
facet_index_types);
if(top_k_facets.size() > 0) {
get_top_k_result_ids(override_result_kvs, top_k_curated_result_ids);
do_facets(top_k_facets, facet_query, estimate_facets, facet_sample_percent,
facet_infos, group_limit, group_by_fields, group_missing_values, top_k_curated_result_ids.data(),
top_k_curated_result_ids.size(), max_facet_values, is_wildcard_no_filter_query,
facet_index_types);
}
all_result_ids_len += curated_topster->size;
if(!included_ids_map.empty() && group_limit != 0) {
@ -3875,6 +3907,14 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
}
//copy top_k facets data
if(!top_k_facets.empty()) {
for(auto& this_facet : top_k_facets) {
auto& acc_facet = facets[this_facet.orig_index];
aggregate_facet(group_limit, this_facet, acc_facet);
}
}
delete [] all_result_ids;
//LOG(INFO) << "all_result_ids_len " << all_result_ids_len << " for index " << name;
@ -7919,6 +7959,17 @@ float Index::get_distance(const string& geo_field_name, const uint32_t& seq_id,
return std::round(dist * 1000.0) / 1000.0;
}
void Index::get_top_k_result_ids(const std::vector<std::vector<KV*>>& raw_result_kvs,
std::vector<uint32_t>& result_ids) const{
for(const auto& group_kv : raw_result_kvs) {
for(const auto& kv : group_kv) {
result_ids.push_back(kv->key);
}
}
std::sort(result_ids.begin(), result_ids.end());
}
/*
// https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon
// NOTE: polygon and point should have been transformed with `transform_for_180th_meridian`

View File

@ -1392,7 +1392,7 @@ TEST_F(CollectionFacetingTest, FacetParseTest){
TEST_F(CollectionFacetingTest, RangeFacetTest) {
std::vector<field> fields = {field("place", field_types::STRING, false),
field("state", field_types::STRING, false),
field("state", field_types::STRING, true),
field("visitors", field_types::INT32, true),
field("rating", field_types::FLOAT, true),
field("trackingFrom", field_types::INT32, true),};
@ -1694,7 +1694,7 @@ TEST_F(CollectionFacetingTest, RangeFacetTypo) {
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true);
ASSERT_STREQ("Error splitting the facet range values.", results2.error().c_str());
ASSERT_STREQ("Invalid facet param `VeryBusy`.", results2.error().c_str());
auto results3 = coll1->search("TamilNadu", {"state"},
"", {"visitors(Busy:[0, 200000] VeryBusy:[200000, 500000])"}, //missing ',' between ranges
@ -1704,7 +1704,7 @@ TEST_F(CollectionFacetingTest, RangeFacetTypo) {
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true);
ASSERT_STREQ("Error splitting the facet range values.", results3.error().c_str());
ASSERT_STREQ("Invalid facet format.", results3.error().c_str());
auto results4 = coll1->search("TamilNadu", {"state"},
"", {"visitors(Busy:[0 200000], VeryBusy:[200000, 500000])"}, //missing ',' between first ranges values
@ -1724,7 +1724,7 @@ TEST_F(CollectionFacetingTest, RangeFacetTypo) {
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true);
ASSERT_STREQ("Facet range value is not valid.", results5.error().c_str());
ASSERT_STREQ("Error splitting the facet range values.", results5.error().c_str());
collectionManager.drop_collection("coll1");
}
@ -2887,7 +2887,7 @@ TEST_F(CollectionFacetingTest, FacetSortValidation) {
{}, {2});
ASSERT_EQ(400, search_op.code());
ASSERT_EQ("Invalid sort format.", search_op.error());
ASSERT_EQ("Invalid facet param `sort`.", search_op.error());
//invalid param
search_op = coll1->search("*", {}, "", {"phone(sort_by:_alpha:foo)"},
@ -3286,3 +3286,150 @@ TEST_F(CollectionFacetingTest, FacetSearchIndexTypeValidation) {
ASSERT_TRUE(res_op.ok());
}
TEST_F(CollectionFacetingTest, TopKFaceting) {
std::vector<field> fields = {field("name", field_types::STRING, true, false, true, "", 1),
field("price", field_types::FLOAT, true, false, true, "", 0)};
Collection* coll2 = collectionManager.create_collection(
"coll2", 1, fields, "", 0, "",
{},{}).get();
nlohmann::json doc;
for(int i=0; i < 500; ++i) {
doc["name"] = "jeans";
doc["price"] = 49.99;
ASSERT_TRUE(coll2->add(doc.dump()).ok());
doc["name"] = "narrow jeans";
doc["price"] = 29.99;
ASSERT_TRUE(coll2->add(doc.dump()).ok());
}
//normal facet
auto results = coll2->search("jeans", {"name"}, "",
{"name"}, {}, {2},
10, 1, FREQUENCY, {true}).get();
ASSERT_EQ(1, results["facet_counts"].size());
ASSERT_EQ("name", results["facet_counts"][0]["field_name"]);
ASSERT_EQ(2, results["facet_counts"][0]["counts"].size());
ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][0]["count"]);
ASSERT_EQ("narrow jeans", results["facet_counts"][0]["counts"][1]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][1]["count"]);
//facet with top_k
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:true)"}, {}, {2},
10, 1, FREQUENCY, {true}).get();
ASSERT_EQ(1, results["facet_counts"].size());
ASSERT_EQ("name", results["facet_counts"][0]["field_name"]);
ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());
ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]);
ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]);
//some are facets with top-K
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:true)", "price"}, {}, {2},
10, 1, FREQUENCY, {true}).get();
ASSERT_EQ(2, results["facet_counts"].size());
ASSERT_EQ("name", results["facet_counts"][0]["field_name"]);
ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());
ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]);
ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]);
ASSERT_EQ("price", results["facet_counts"][1]["field_name"]);
ASSERT_EQ(2, results["facet_counts"][1]["counts"].size());
ASSERT_EQ("49.99", results["facet_counts"][1]["counts"][0]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][0]["count"]);
ASSERT_EQ("29.99", results["facet_counts"][1]["counts"][1]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][1]["count"]);
}
TEST_F(CollectionFacetingTest, TopKFacetValidation) {
std::vector<field> fields = {field("name", field_types::STRING, true, false, true, "", 1),
field("price", field_types::FLOAT, true, false, true, "", 1)};
Collection* coll2 = collectionManager.create_collection(
"coll2", 1, fields, "", 0, "",
{},{}).get();
//'=' separator instead of ":"
auto results = coll2->search("jeans", {"name"}, "",
{"name(top_k=true)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet format.", results.error());
//typo in top_k
results = coll2->search("jeans", {"name"}, "",
{"name(top-k:true)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet param `top-k`.", results.error());
results = coll2->search("jeans", {"name"}, "",
{"name(topk:true)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet param `topk`.", results.error());
//value should be boolean
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:10)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_FALSE(results.ok());
ASSERT_EQ("top_k string format is invalid.", results.error());
//correct val
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_TRUE(results.ok());
//with sort params
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false, sort_by:_alpha:desc)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_TRUE(results.ok());
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false, sort_by:price:desc)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_TRUE(results.ok());
//with range facets
results = coll2->search("jeans", {"name"}, "",
{"price(top_k:false, economic:[0, 30], Luxury:[30, 50])"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_TRUE(results.ok());
results = coll2->search("jeans", {"name"}, "",
{"price(economic:[0, 30], top_k:true, Luxury:[30, 50])"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_TRUE(results.ok());
results = coll2->search("jeans", {"name"}, "",
{"price(economic:[0, 30], Luxury:[30, 50], top_k:true)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_TRUE(results.ok());
//missing , seperator
results = coll2->search("jeans", {"name"}, "",
{"price(economic:[0, 30], Luxury:[30, 50] top_k:true)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet format.", results.error());
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false sort_by:_alpha:desc)"}, {}, {2},
10, 1, FREQUENCY, {true});
ASSERT_FALSE(results.ok());
ASSERT_EQ("top_k string format is invalid.", results.error());
}

View File

@ -1090,7 +1090,7 @@ TEST_F(CollectionOptimizedFacetingTest, FacetParseTest){
TEST_F(CollectionOptimizedFacetingTest, RangeFacetTest) {
std::vector<field> fields = {field("place", field_types::STRING, false),
field("state", field_types::STRING, false),
field("state", field_types::STRING, true),
field("visitors", field_types::INT32, true),
field("trackingFrom", field_types::INT32, true),};
Collection* coll1 = collectionManager.create_collection(
@ -1376,7 +1376,7 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTypo) {
4UL, 7UL, fallback, 4UL, {off}, 32767UL, 32767UL, 2UL, 2UL, false,
"", true, 0UL, max_score, 100UL, 0UL, 4294967295UL, "top_values");
ASSERT_STREQ("Error splitting the facet range values.", results2.error().c_str());
ASSERT_STREQ("Invalid facet param `VeryBusy`.", results2.error().c_str());
auto results3 = coll1->search("TamilNadu", {"state"},
"", {"visitors(Busy:[0, 200000] VeryBusy:[200000, 500000])"}, //missing ',' between ranges
@ -1389,7 +1389,7 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTypo) {
4UL, 7UL, fallback, 4UL, {off}, 32767UL, 32767UL, 2UL, 2UL, false,
"", true, 0UL, max_score, 100UL, 0UL, 4294967295UL, "top_values");
ASSERT_STREQ("Error splitting the facet range values.", results3.error().c_str());
ASSERT_STREQ("Invalid facet format.", results3.error().c_str());
auto results4 = coll1->search("TamilNadu", {"state"},
"", {"visitors(Busy:[0 200000], VeryBusy:[200000, 500000])"}, //missing ',' between first ranges values
@ -1415,7 +1415,7 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTypo) {
4UL, 7UL, fallback, 4UL, {off}, 32767UL, 32767UL, 2UL, 2UL, false,
"", true, 0UL, max_score, 100UL, 0UL, 4294967295UL, "top_values");
ASSERT_STREQ("Facet range value is not valid.", results5.error().c_str());
ASSERT_STREQ("Error splitting the facet range values.", results5.error().c_str());
collectionManager.drop_collection("coll1");
}
@ -2478,7 +2478,7 @@ TEST_F(CollectionOptimizedFacetingTest, FacetSortValidation) {
{}, {2});
ASSERT_EQ(400, search_op.code());
ASSERT_EQ("Invalid sort format.", search_op.error());
ASSERT_EQ("Invalid facet param `sort`.", search_op.error());
//invalid param
search_op = coll1->search("*", {}, "", {"phone(sort_by:_alpha:foo)"},
@ -3073,3 +3073,250 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetsWithSortDisabled) {
ASSERT_EQ(1, results["facet_counts"][0]["counts"][1]["count"]);
ASSERT_EQ("Medium", results["facet_counts"][0]["counts"][1]["value"]);
}
TEST_F(CollectionOptimizedFacetingTest, TopKFaceting) {
std::vector<field> fields = {field("name", field_types::STRING, true, false, true, "", 1),
field("price", field_types::FLOAT, true, false, true, "", 0)};
Collection* coll2 = collectionManager.create_collection(
"coll2", 1, fields, "", 0, "",
{},{}).get();
nlohmann::json doc;
for(int i=0; i < 500; ++i) {
doc["name"] = "jeans";
doc["price"] = 49.99;
ASSERT_TRUE(coll2->add(doc.dump()).ok());
doc["name"] = "narrow jeans";
doc["price"] = 29.99;
ASSERT_TRUE(coll2->add(doc.dump()).ok());
}
//normal facet
auto results = coll2->search("jeans", {"name"},
"", {"name"},
{}, {2}, 10,
1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get();
ASSERT_EQ(1, results["facet_counts"].size());
ASSERT_EQ("name", results["facet_counts"][0]["field_name"]);
ASSERT_EQ(2, results["facet_counts"][0]["counts"].size());
ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][0]["count"]);
ASSERT_EQ("narrow jeans", results["facet_counts"][0]["counts"][1]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][1]["count"]);
//facet with top_k
results = coll2->search("jeans", {"name"},
"", {"name(top_k:true)"},
{}, {2}, 10,
1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get();
ASSERT_EQ(1, results["facet_counts"].size());
ASSERT_EQ("name", results["facet_counts"][0]["field_name"]);
ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());
ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]);
ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]);
//some are facets with top-K
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:true)", "price"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get();
ASSERT_EQ(2, results["facet_counts"].size());
ASSERT_EQ("name", results["facet_counts"][0]["field_name"]);
ASSERT_EQ(1, results["facet_counts"][0]["counts"].size());
ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]);
ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]);
ASSERT_EQ("price", results["facet_counts"][1]["field_name"]);
ASSERT_EQ(2, results["facet_counts"][1]["counts"].size());
ASSERT_EQ("49.99", results["facet_counts"][1]["counts"][0]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][0]["count"]);
ASSERT_EQ("29.99", results["facet_counts"][1]["counts"][1]["value"]);
ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][1]["count"]);
}
TEST_F(CollectionOptimizedFacetingTest, TopKFacetValidation) {
std::vector<field> fields = {field("name", field_types::STRING, true, false, true, "", 1),
field("price", field_types::FLOAT, true, false, true, "", 1)};
Collection* coll2 = collectionManager.create_collection(
"coll2", 1, fields, "", 0, "",
{},{}).get();
//'=' separator instead of ":"
auto results = coll2->search("jeans", {"name"}, "",
{"name(top_k=true)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet format.", results.error());
//typo in top_k
results = coll2->search("jeans", {"name"}, "",
{"name(top-k:true)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet param `top-k`.", results.error());
results = coll2->search("jeans", {"name"}, "",
{"name(topk:true)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet param `topk`.", results.error());
//value should be boolean
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:10)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_FALSE(results.ok());
ASSERT_EQ("top_k string format is invalid.", results.error());
//correct val
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_TRUE(results.ok());
//with sort params
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false, sort_by:_alpha:desc)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_TRUE(results.ok());
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false, sort_by:price:desc)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_TRUE(results.ok());
//with range facets
results = coll2->search("jeans", {"name"}, "",
{"price(top_k:false, economic:[0, 30], Luxury:[30, 50])"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_TRUE(results.ok());
results = coll2->search("jeans", {"name"}, "",
{"price(economic:[0, 30], top_k:true, Luxury:[30, 50])"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_TRUE(results.ok());
results = coll2->search("jeans", {"name"}, "",
{"price(economic:[0, 30], Luxury:[30, 50], top_k:true)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_TRUE(results.ok());
//missing , seperator
results = coll2->search("jeans", {"name"}, "",
{"price(economic:[0, 30], Luxury:[30, 50] top_k:true)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_FALSE(results.ok());
ASSERT_EQ("Invalid facet format.", results.error());
results = coll2->search("jeans", {"name"}, "",
{"name(top_k:false sort_by:_alpha:desc)"}, {}, {2},
10, 1, FREQUENCY, {true},
10, spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 10, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000,
true, false, true, "", true,
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values");
ASSERT_FALSE(results.ok());
ASSERT_EQ("top_k string format is invalid.", results.error());
}