From 0d030f08fea6904c7bf756cfb4a0c418f855d08b Mon Sep 17 00:00:00 2001 From: Krunal Gandhi Date: Tue, 3 Sep 2024 07:31:46 +0000 Subject: [PATCH] Faceting on top k (#1878) * do faceting on top_k results * make top_k faceting on single thread * remove logging to tsv file * Revert "remove logging to tsv file" This reverts commit 42bd4fdc4607d3cb5000080ac8aeba21b602e279. * add validation checks & aggregate facets * add tests * refactor code * refactor facet parsing to single pass --------- Co-authored-by: Kishore Nallan --- include/field.h | 6 +- include/index.h | 2 + src/collection.cpp | 487 ++++++++++---------- src/index.cpp | 59 ++- test/collection_faceting_test.cpp | 157 ++++++- test/collection_optimized_faceting_test.cpp | 257 ++++++++++- 6 files changed, 711 insertions(+), 257 deletions(-) diff --git a/include/field.h b/include/field.h index 5e58b216..0ed7c684 100644 --- a/include/field.h +++ b/include/field.h @@ -676,6 +676,8 @@ struct facet { uint32_t orig_index; + bool is_top_k = false; + bool get_range(int64_t key, std::pair& range_pair) { if(facet_range_map.empty()) { LOG (ERROR) << "Facet range is not defined!!!"; @@ -696,12 +698,12 @@ struct facet { return false; } - explicit facet(const std::string& field_name, uint32_t orig_index, std::map facet_range = {}, + explicit facet(const std::string& field_name, uint32_t orig_index, bool is_top_k = false, std::map facet_range = {}, bool is_range_q = false, bool sort_by_alpha=false, const std::string& order="", const std::string& sort_by_field="") : field_name(field_name), facet_range_map(facet_range), is_range_query(is_range_q), is_sort_by_alpha(sort_by_alpha), sort_order(order), - sort_field(sort_by_field), orig_index(orig_index) { + sort_field(sort_by_field), orig_index(orig_index), is_top_k(is_top_k) { } }; diff --git a/include/index.h b/include/index.h index 522e0e7d..36ea97db 100644 --- a/include/index.h +++ b/include/index.h @@ -1066,6 +1066,8 @@ public: float get_distance(const string& geo_field_name, 
const uint32_t& seq_id, const S2LatLng& reference_lat_lng, const std::string& unit) const; + + void get_top_k_result_ids(const std::vector>& raw_result_kvs, std::vector& result_ids) const; }; template diff --git a/src/collection.cpp b/src/collection.cpp index c3b247c4..ecfe3a07 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -2105,9 +2105,6 @@ Option Collection::search(std::string raw_query, auto drop_tokens_param = drop_tokens_param_op.get(); - std::vector> raw_result_kvs; - std::vector> override_result_kvs; - size_t total = 0; std::vector excluded_ids; @@ -2316,15 +2313,8 @@ Option Collection::search(std::string raw_query, return Option(search_op.code(), search_op.error()); } - // for grouping we have to re-aggregate - Topster& topster = *search_params->topster; - Topster& curated_topster = *search_params->curated_topster; - - topster.sort(); - curated_topster.sort(); - - populate_result_kvs(&topster, raw_result_kvs, search_params->groups_processed, sort_fields_std); - populate_result_kvs(&curated_topster, override_result_kvs, search_params->groups_processed, sort_fields_std); + auto& raw_result_kvs = search_params->raw_result_kvs; + auto& override_result_kvs = search_params->override_result_kvs; // for grouping we have to aggregate group set sizes to a count value if(group_limit) { @@ -6032,161 +6022,262 @@ bool Collection::get_enable_nested_fields() { } Option Collection::parse_facet(const std::string& facet_field, std::vector& facets) const { - const std::regex base_pattern(".+\\(.*\\)"); - const std::regex range_pattern( - "[[:print:]]+:\\[([+-]?([[:digit:]]*[.])?[[:digit:]]*)\\,\\s*([+-]?([[:digit:]]*[.])?[[:digit:]]*)\\]"); const std::string _alpha = "_alpha"; + bool top_k = false; + std::string facet_field_name, param_str; + bool paran_open = false; //for ( + bool brace_open = false; //for [ + std::string order = ""; + bool sort_alpha = false; + std::string sort_field = ""; + bool colon_found = false; + bool top_k_found = false; + bool 
sort_found = false; + unsigned facet_param_count = 0; + unsigned commaCount = 0; + bool is_wildcard = false; - if((facet_field.find(":") != std::string::npos) - && (facet_field.find("sort_by") == std::string::npos)) { //range based facet + std::vector> tupVec; - if(!std::regex_match(facet_field, base_pattern)) { - std::string error = "Facet range value is not valid."; - return Option(400, error); + for(int i = 0; i < facet_field.size(); ) { + if(facet_field[i] == '(') { + //facet field name complete, check validity + if(search_schema.count(facet_field_name) == 0 || !search_schema.at(facet_field_name).facet) { + std::string error = "Could not find a facet field named `" + facet_field_name + "` in the schema."; + return Option(404, error); + } + + paran_open = true; + i++; + continue; + } else if(facet_field[i] == '*') { + if(i == facet_field.size() - 1) { + auto prefix = facet_field.substr(0, facet_field.size() - 1); + auto pair = search_schema.equal_prefix_range(prefix); + + if(pair.first == pair.second) { + // not found + std::string error = "Could not find a facet field for `" + facet_field + "` in the schema."; + return Option(404, error); + } + + // Collect the fields that match the prefix and are marked as facet. 
+ for(auto field = pair.first; field != pair.second; field++) { + if(field->facet) { + facets.emplace_back(facet(field->name, facets.size())); + facets.back().is_wildcard_match = true; + } + } + i++; + is_wildcard = true; + continue; + } else { + return Option(404, "Only prefix matching with a wildcard is allowed."); + } + } else if(facet_field[i] == ')') { + if(paran_open == true && (facet_param_count == commaCount + 1)) { + if(!colon_found && !top_k_found) { + return Option(400, "Invalid facet param `" + param_str + "`."); + } + + paran_open = false; + commaCount = facet_param_count; + break; + } else { + return Option(400, "Invalid facet format."); + } + } else if(facet_field[i] == ':') { + if(paran_open == false || facet_param_count != commaCount) { + return Option(400, "Invalid facet format."); + } + colon_found = true; + StringUtils::trim(param_str); + + if(param_str == "sort_by") { //sort_by params + sort_found = true; + for(i; facet_field.size(); i++) { + if(facet_field[i] == ',' || facet_field[i] == ')') { + break; + } else { + param_str+=facet_field[i]; + } + } + + std::vector tokens; + StringUtils::split(param_str, tokens, ":"); + + if(tokens.size() != 3) { + std::string error = "Invalid sort format."; + return Option(400, error); + } + + if(tokens[1] == _alpha) { + const field& a_field = search_schema.at(facet_field_name); + if(!a_field.is_string()) { + std::string error = "Facet field should be string type to apply alpha sort."; + return Option(400, error); + } + sort_alpha = true; + } else { //sort_field based sort + sort_field = tokens[1]; + + if(search_schema.count(sort_field) == 0 || !search_schema.at(sort_field).facet) { + std::string error = "Could not find a facet field named `" + sort_field + "` in the schema."; + return Option(404, error); + } + + const field& a_field = search_schema.at(sort_field); + if(a_field.is_string()) { + std::string error = "Sort field should be non string type to apply sort."; + return Option(400, error); + } + } + + 
if(tokens[2] == "asc") { + order = "asc"; + } else if(tokens[2] == "desc") { + order = "desc"; + } else { + std::string error = "Invalid sort param."; + return Option(400, error); + } + facet_param_count++; + } else if(param_str == "top_k") { //top_k param + top_k_found = true; + param_str.clear(); + i++; //skip : + for(i; i < facet_field.size(); i++) { + if(facet_field[i] == ',' || facet_field[i] == ')') { + break; + } + param_str+=facet_field[i]; + } + + if(param_str.empty() || (param_str != "true" && param_str != "false")) { + return Option(400, "top_k string format is invalid."); + } + + if(param_str == "true") { + top_k = true; + } + facet_param_count++; + } else if((i + 1) < facet_field.size() && facet_field[i+1] == '[') { //range params + const field& a_field = search_schema.at(facet_field_name); + if(tupVec.empty()) { + if(!a_field.is_integer() && !a_field.is_float()) { + std::string error = "Range facet is restricted to only integer and float fields."; + return Option(400, error); + } + + if(!a_field.sort) { + return Option(400, "Range facets require sort enabled for the field."); + } + } + auto range_val = param_str; + StringUtils::trim(range_val); + if(range_val.empty()) { + return Option(400, "Facet range value is not valid."); + } + + std::string lower, upper; + int64_t lower_range, upper_range; + + brace_open = true; + auto commaFound = 0; + i+=2; //skip : and [ + param_str.clear(); + while(i < facet_field.size()) { + if(facet_field[i]== ',') { + if(commaFound == 1) { + return Option(400, "Error splitting the facet range values."); + } + + lower = param_str; + StringUtils::trim(lower); + param_str.clear(); + commaFound++; + } else if(facet_field[i] == ']') { + brace_open = false; + upper = param_str; + StringUtils::trim(upper); + i++; //skip ] and break loop + break; + } else if(facet_field[i] == ')') { + return Option(400, "Error splitting the facet range values."); + } else { + param_str += facet_field[i]; + } + i++; + } + + if(lower.empty()) { + 
lower_range = INT64_MIN; + } else if(a_field.is_integer() && StringUtils::is_int64_t(lower)) { + lower_range = std::stoll(lower); + } else if(a_field.is_float() && StringUtils::is_float(lower)) { + float val = std::stof(lower); + lower_range = Index::float_to_int64_t(val); + } else { + return Option(400, "Facet range value is not valid."); + } + + if(upper.empty()) { + upper_range = INT64_MAX; + } else if(a_field.is_integer() && StringUtils::is_int64_t(upper)) { + upper_range = std::stoll(upper); + } else if(a_field.is_float() && StringUtils::is_float(upper)) { + float val = std::stof(upper); + upper_range = Index::float_to_int64_t(val); + } else { + return Option(400, "Facet range value is not valid."); + } + + tupVec.emplace_back(lower_range, upper_range, range_val); + facet_param_count++; + } else { + return Option(400, "Invalid facet param `" + param_str + "`."); + } + + continue; + } else if(facet_field[i] == ',') { + param_str.clear(); + commaCount++; + i++; + continue; } - auto startpos = facet_field.find("("); - auto field_name = facet_field.substr(0, startpos); + if(!paran_open) { + facet_field_name+=facet_field[i]; + } else { + param_str+=facet_field[i]; + } + i++; + } - if(search_schema.count(field_name) == 0) { - std::string error = "Could not find a facet field named `" + field_name + "` in the schema."; + if(paran_open || brace_open || facet_param_count != commaCount) { + return Option(400, "Invalid facet format."); + } + + if(facet_param_count == 0 && !is_wildcard) { + //facets with params will be validated while parsing + // for normal facets need to perform check + if(search_schema.count(facet_field_name) == 0 || !search_schema.at(facet_field_name).facet) { + std::string error = "Could not find a facet field named `" + facet_field_name + "` in the schema."; return Option(404, error); } + } - if((field_name.find("sort") == std::string::npos) - && (facet_field.find("sort") != std::string::npos)) { - //sort keyword is found in facet string but not in 
facet field - std::string error = "Invalid sort format."; - return Option(400, error); - } - - const field& a_field = search_schema.at(field_name); - - if(!a_field.is_integer() && !a_field.is_float()) { - std::string error = "Range facet is restricted to only integer and float fields."; - return Option(400, error); - } - - if(!a_field.sort) { - return Option(400, "Range facets require sort enabled for the field."); - } - - facet a_facet(field_name, facets.size()); - - //starting after "(" and excluding ")" - auto range_string = std::string(facet_field.begin() + startpos + 1, facet_field.end() - 1); - - //split the ranges - std::vector result; - startpos = 0; - int index = 0; - int commaFound = 0, rangeFound = 0; - bool range_open = false; - while(index < range_string.size()) { - if(range_string[index] == ']') { - if(range_open == true) { - std::string range = range_string.substr(startpos, index + 1 - startpos); - range = StringUtils::trim(range); - result.emplace_back(range); - rangeFound++; - range_open = false; - } else { - result.clear(); - break; - } - } else if(range_string[index] == ',' && range_open == false) { - startpos = index + 1; - commaFound++; - } else if(range_string[index] == '[') { - if((commaFound == rangeFound) && range_open == false) { - range_open = true; - } else { - result.clear(); - break; - } - } - - index++; - } - - if((result.empty()) || (range_open == true)) { - std::string error = "Error splitting the facet range values."; - return Option(400, error); - } - - std::vector> tupVec; - - auto& range_map = a_facet.facet_range_map; - range_map.clear(); - for(const auto& range: result) { - //validate each range syntax - if(!std::regex_match(range, range_pattern)) { - std::string error = "Facet range value is not valid."; - return Option(400, error); - } - auto pos1 = range.find(":"); - std::string range_val = range.substr(0, pos1); - - auto pos2 = range.find(","); - auto pos3 = range.find("]"); - - int64_t lower_range, upper_range; - - 
if(a_field.is_integer()) { - auto start = pos1 + 2; - auto end = pos2 - start; - auto lower_range_str = range.substr(start, end); - StringUtils::trim(lower_range_str); - if(lower_range_str.empty()) { - lower_range = INT64_MIN; - } else { - lower_range = std::stoll(lower_range_str); - } - - start = pos2 + 1; - end = pos3 - start; - auto upper_range_str = range.substr(start, end); - StringUtils::trim(upper_range_str); - if(upper_range_str.empty()) { - upper_range = INT64_MAX; - } else { - upper_range = std::stoll(upper_range_str); - } - } else { - auto start = pos1 + 2; - auto end = pos2 - start; - auto lower_range_str = range.substr(start, end); - StringUtils::trim(lower_range_str); - if(lower_range_str.empty()) { - lower_range = INT64_MIN; - } else { - float val = std::stof(lower_range_str); - lower_range = Index::float_to_int64_t(val); - } - - start = pos2 + 1; - end = pos3 - start; - auto upper_range_str = range.substr(start, end); - StringUtils::trim(upper_range_str); - if(upper_range_str.empty()) { - upper_range = INT64_MAX; - } else { - float val = std::stof(upper_range_str); - upper_range = Index::float_to_int64_t(val); - } - } - - tupVec.emplace_back(lower_range, upper_range, range_val); - } - - //sort the range values so that we can check continuity + if(!tupVec.empty()) { //add range facets sort(tupVec.begin(), tupVec.end()); - for(const auto& tup: tupVec) { + facet a_facet(facet_field_name, facets.size()); + auto& range_map = a_facet.facet_range_map; + for(const auto& tup: tupVec) { const auto& lower_range = std::get<0>(tup); const auto& upper_range = std::get<1>(tup); const std::string& range_val = std::get<2>(tup); + //check if ranges are continous or not if((!range_map.empty()) && (range_map.find(lower_range) == range_map.end())) { std::string error = "Ranges in range facet syntax should be continous."; @@ -6195,98 +6286,12 @@ Option Collection::parse_facet(const std::string& facet_field, std::vector range_map[upper_range] = range_specs_t{range_val, 
lower_range}; } - a_facet.is_range_query = true; + a_facet.is_top_k = top_k; facets.emplace_back(std::move(a_facet)); - } else if(facet_field.find('*') != std::string::npos) { // Wildcard - if(facet_field[facet_field.size() - 1] != '*') { - return Option(404, "Only prefix matching with a wildcard is allowed."); - } - - // Trim * from the end. - auto prefix = facet_field.substr(0, facet_field.size() - 1); - auto pair = search_schema.equal_prefix_range(prefix); - - if(pair.first == pair.second) { - // not found - std::string error = "Could not find a facet field for `" + facet_field + "` in the schema."; - return Option(404, error); - } - - // Collect the fields that match the prefix and are marked as facet. - for(auto field = pair.first; field != pair.second; field++) { - if(field->facet) { - facets.emplace_back(facet(field->name, facets.size())); - facets.back().is_wildcard_match = true; - } - } - } else { - // normal facet - std::string order = ""; - bool sort_alpha = false; - std::string sort_field = ""; - std::string facet_field_copy = facet_field; - auto pos = facet_field_copy.find("("); - if(pos != std::string::npos) { - facet_field_copy = facet_field_copy.substr(0, pos); - } - - if(search_schema.count(facet_field_copy) == 0 || !search_schema.at(facet_field_copy).facet) { - std::string error = "Could not find a facet field named `" + facet_field_copy + "` in the schema."; - return Option(404, error); - } - - if(facet_field.find("sort_by") != std::string::npos) { //sort params are supplied with facet - std::vector tokens; - StringUtils::split(facet_field, tokens, ":"); - - if(tokens.size() != 3) { - std::string error = "Invalid sort format."; - return Option(400, error); - } - - //remove possible whitespaces - for(auto i = 0; i < 3; ++i) { - StringUtils::trim(tokens[i]); - } - - if(tokens[1] == _alpha) { - const field& a_field = search_schema.at(facet_field_copy); - if(!a_field.is_string()) { - std::string error = "Facet field should be string type to apply 
alpha sort."; - return Option(400, error); - } - sort_alpha = true; - } else { //sort_field based sort - sort_field = tokens[1]; - - if(search_schema.count(sort_field) == 0 || !search_schema.at(sort_field).facet) { - std::string error = "Could not find a facet field named `" + sort_field + "` in the schema."; - return Option(404, error); - } - - const field& a_field = search_schema.at(sort_field); - if(a_field.is_string()) { - std::string error = "Sort field should be non string type to apply sort."; - return Option(400, error); - } - } - - if(tokens[2].find("asc") != std::string::npos) { - order = "asc"; - } else if(tokens[2].find("desc") != std::string::npos) { - order = "desc"; - } else { - std::string error = "Invalid sort param."; - return Option(400, error); - } - } else if(facet_field != facet_field_copy) { - std::string error = "Invalid sort format."; - return Option(400, error); - } - - facets.emplace_back(facet(facet_field_copy, facets.size(), {}, false, sort_alpha, + } else if(!is_wildcard) { //add other facet types, wildcard facets are already added while parsing + facets.emplace_back(facet(facet_field_name, facets.size(), top_k, {}, false, sort_alpha, order, sort_field)); } diff --git a/src/index.cpp b/src/index.cpp index b72fbcf5..c587ab9a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3682,6 +3682,14 @@ Option Index::search(std::vector& field_query_tokens, cons process_search_results: + topster->sort(); + curated_topster->sort(); + + Collection::populate_result_kvs(topster, raw_result_kvs, groups_processed, sort_fields_std); + Collection::populate_result_kvs(curated_topster, override_result_kvs, groups_processed, sort_fields_std); + std::vector top_k_result_ids, top_k_curated_result_ids; + std::vector top_k_facets; + delete [] exclude_token_ids; delete [] excluded_result_ids; @@ -3705,15 +3713,22 @@ Option Index::search(std::vector& field_query_tokens, cons std::vector> facet_batches(num_threads); std::vector> value_facets(concurrency); + size_t 
num_value_facets = 0; for(size_t i = 0; i < facets.size(); i++) { const auto& this_facet = facets[i]; + //process facets separately which has top_k set to true + if(this_facet.is_top_k) { + top_k_facets.emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.is_top_k, this_facet.facet_range_map, + this_facet.is_range_query, this_facet.is_sort_by_alpha, this_facet.sort_order, this_facet.sort_field); + continue; + } if(facet_infos[i].use_value_index) { // value based faceting on a single thread value_facets[num_value_facets % num_threads].emplace_back(this_facet.field_name, this_facet.orig_index, - this_facet.facet_range_map, + this_facet.is_top_k, this_facet.facet_range_map, this_facet.is_range_query, this_facet.is_sort_by_alpha, this_facet.sort_order, this_facet.sort_field); num_value_facets++; @@ -3721,9 +3736,9 @@ Option Index::search(std::vector& field_query_tokens, cons } for(size_t j = 0; j < num_threads; j++) { - facet_batches[j].emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.facet_range_map, - this_facet.is_range_query, this_facet.is_sort_by_alpha, - this_facet.sort_order, this_facet.sort_field); + facet_batches[j].emplace_back(this_facet.field_name, this_facet.orig_index, this_facet.is_top_k, + this_facet.facet_range_map, this_facet.is_range_query, + this_facet.is_sort_by_alpha, this_facet.sort_order, this_facet.sort_field); } } @@ -3736,6 +3751,15 @@ Option Index::search(std::vector& field_query_tokens, cons //auto beginF = std::chrono::high_resolution_clock::now(); + if(top_k_facets.size() > 0) { + get_top_k_result_ids(raw_result_kvs, top_k_result_ids); + + do_facets(top_k_facets, facet_query, estimate_facets, facet_sample_percent, + facet_infos, group_limit, group_by_fields, group_missing_values, top_k_result_ids.data(), + top_k_result_ids.size(), max_facet_values, is_wildcard_no_filter_query, + facet_index_types); + } + for(size_t thread_id = 0; thread_id < num_threads && result_index < all_result_ids_len; 
thread_id++) { size_t batch_res_len = window_size; @@ -3857,6 +3881,14 @@ Option Index::search(std::vector& field_query_tokens, cons included_ids_vec.size(), max_facet_values, is_wildcard_no_filter_query, facet_index_types); + if(top_k_facets.size() > 0) { + get_top_k_result_ids(override_result_kvs, top_k_curated_result_ids); + do_facets(top_k_facets, facet_query, estimate_facets, facet_sample_percent, + facet_infos, group_limit, group_by_fields, group_missing_values, top_k_curated_result_ids.data(), + top_k_curated_result_ids.size(), max_facet_values, is_wildcard_no_filter_query, + facet_index_types); + } + all_result_ids_len += curated_topster->size; if(!included_ids_map.empty() && group_limit != 0) { @@ -3875,6 +3907,14 @@ Option Index::search(std::vector& field_query_tokens, cons } } + //copy top_k facets data + if(!top_k_facets.empty()) { + for(auto& this_facet : top_k_facets) { + auto& acc_facet = facets[this_facet.orig_index]; + aggregate_facet(group_limit, this_facet, acc_facet); + } + } + delete [] all_result_ids; //LOG(INFO) << "all_result_ids_len " << all_result_ids_len << " for index " << name; @@ -7919,6 +7959,17 @@ float Index::get_distance(const string& geo_field_name, const uint32_t& seq_id, return std::round(dist * 1000.0) / 1000.0; } +void Index::get_top_k_result_ids(const std::vector>& raw_result_kvs, + std::vector& result_ids) const{ + + for(const auto& group_kv : raw_result_kvs) { + for(const auto& kv : group_kv) { + result_ids.push_back(kv->key); + } + } + + std::sort(result_ids.begin(), result_ids.end()); +} /* // https://stackoverflow.com/questions/924171/geo-fencing-point-inside-outside-polygon // NOTE: polygon and point should have been transformed with `transform_for_180th_meridian` diff --git a/test/collection_faceting_test.cpp b/test/collection_faceting_test.cpp index f9f41515..a9bbbb12 100644 --- a/test/collection_faceting_test.cpp +++ b/test/collection_faceting_test.cpp @@ -1392,7 +1392,7 @@ TEST_F(CollectionFacetingTest, 
FacetParseTest){ TEST_F(CollectionFacetingTest, RangeFacetTest) { std::vector fields = {field("place", field_types::STRING, false), - field("state", field_types::STRING, false), + field("state", field_types::STRING, true), field("visitors", field_types::INT32, true), field("rating", field_types::FLOAT, true), field("trackingFrom", field_types::INT32, true),}; @@ -1694,7 +1694,7 @@ TEST_F(CollectionFacetingTest, RangeFacetTypo) { spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, "", "", {}, 1000, true, false, true, "", true); - ASSERT_STREQ("Error splitting the facet range values.", results2.error().c_str()); + ASSERT_STREQ("Invalid facet param `VeryBusy`.", results2.error().c_str()); auto results3 = coll1->search("TamilNadu", {"state"}, "", {"visitors(Busy:[0, 200000] VeryBusy:[200000, 500000])"}, //missing ',' between ranges @@ -1704,7 +1704,7 @@ TEST_F(CollectionFacetingTest, RangeFacetTypo) { spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, "", "", {}, 1000, true, false, true, "", true); - ASSERT_STREQ("Error splitting the facet range values.", results3.error().c_str()); + ASSERT_STREQ("Invalid facet format.", results3.error().c_str()); auto results4 = coll1->search("TamilNadu", {"state"}, "", {"visitors(Busy:[0 200000], VeryBusy:[200000, 500000])"}, //missing ',' between first ranges values @@ -1724,7 +1724,7 @@ TEST_F(CollectionFacetingTest, RangeFacetTypo) { spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, "", "", {}, 1000, true, false, true, "", true); - ASSERT_STREQ("Facet range value is not valid.", results5.error().c_str()); + ASSERT_STREQ("Error splitting the facet range values.", results5.error().c_str()); collectionManager.drop_collection("coll1"); } @@ -2887,7 +2887,7 @@ TEST_F(CollectionFacetingTest, FacetSortValidation) { {}, {2}); ASSERT_EQ(400, search_op.code()); - ASSERT_EQ("Invalid sort format.", search_op.error()); + ASSERT_EQ("Invalid facet param `sort`.", search_op.error()); //invalid param search_op = 
coll1->search("*", {}, "", {"phone(sort_by:_alpha:foo)"}, @@ -3286,3 +3286,150 @@ TEST_F(CollectionFacetingTest, FacetSearchIndexTypeValidation) { ASSERT_TRUE(res_op.ok()); } + +TEST_F(CollectionFacetingTest, TopKFaceting) { + std::vector fields = {field("name", field_types::STRING, true, false, true, "", 1), + field("price", field_types::FLOAT, true, false, true, "", 0)}; + + Collection* coll2 = collectionManager.create_collection( + "coll2", 1, fields, "", 0, "", + {},{}).get(); + + nlohmann::json doc; + for(int i=0; i < 500; ++i) { + doc["name"] = "jeans"; + doc["price"] = 49.99; + ASSERT_TRUE(coll2->add(doc.dump()).ok()); + + doc["name"] = "narrow jeans"; + doc["price"] = 29.99; + ASSERT_TRUE(coll2->add(doc.dump()).ok()); + } + + //normal facet + auto results = coll2->search("jeans", {"name"}, "", + {"name"}, {}, {2}, + 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(1, results["facet_counts"].size()); + ASSERT_EQ("name", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(2, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][0]["count"]); + ASSERT_EQ("narrow jeans", results["facet_counts"][0]["counts"][1]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][1]["count"]); + + //facet with top_k + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(1, results["facet_counts"].size()); + ASSERT_EQ("name", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]); + ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]); + + //some are facets with top-K + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:true)", "price"}, {}, {2}, + 10, 1, FREQUENCY, {true}).get(); + + ASSERT_EQ(2, 
results["facet_counts"].size()); + + ASSERT_EQ("name", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]); + ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]); + + ASSERT_EQ("price", results["facet_counts"][1]["field_name"]); + ASSERT_EQ(2, results["facet_counts"][1]["counts"].size()); + ASSERT_EQ("49.99", results["facet_counts"][1]["counts"][0]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][0]["count"]); + ASSERT_EQ("29.99", results["facet_counts"][1]["counts"][1]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][1]["count"]); +} + +TEST_F(CollectionFacetingTest, TopKFacetValidation) { + std::vector fields = {field("name", field_types::STRING, true, false, true, "", 1), + field("price", field_types::FLOAT, true, false, true, "", 1)}; + + Collection* coll2 = collectionManager.create_collection( + "coll2", 1, fields, "", 0, "", + {},{}).get(); + + //'=' separator instead of ":" + auto results = coll2->search("jeans", {"name"}, "", + {"name(top_k=true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet format.", results.error()); + + //typo in top_k + results = coll2->search("jeans", {"name"}, "", + {"name(top-k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet param `top-k`.", results.error()); + + results = coll2->search("jeans", {"name"}, "", + {"name(topk:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet param `topk`.", results.error()); + + //value should be boolean + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:10)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_FALSE(results.ok()); + ASSERT_EQ("top_k string format is invalid.", results.error()); + + //correct val + results = coll2->search("jeans", 
{"name"}, "", + {"name(top_k:false)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_TRUE(results.ok()); + + //with sort params + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false, sort_by:_alpha:desc)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_TRUE(results.ok()); + + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false, sort_by:price:desc)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_TRUE(results.ok()); + + //with range facets + results = coll2->search("jeans", {"name"}, "", + {"price(top_k:false, economic:[0, 30], Luxury:[30, 50])"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_TRUE(results.ok()); + + + results = coll2->search("jeans", {"name"}, "", + {"price(economic:[0, 30], top_k:true, Luxury:[30, 50])"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_TRUE(results.ok()); + + results = coll2->search("jeans", {"name"}, "", + {"price(economic:[0, 30], Luxury:[30, 50], top_k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_TRUE(results.ok()); + + //missing , seperator + results = coll2->search("jeans", {"name"}, "", + {"price(economic:[0, 30], Luxury:[30, 50] top_k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet format.", results.error()); + + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false sort_by:_alpha:desc)"}, {}, {2}, + 10, 1, FREQUENCY, {true}); + ASSERT_FALSE(results.ok()); + ASSERT_EQ("top_k string format is invalid.", results.error()); +} diff --git a/test/collection_optimized_faceting_test.cpp b/test/collection_optimized_faceting_test.cpp index af5bbbf8..168932a5 100644 --- a/test/collection_optimized_faceting_test.cpp +++ b/test/collection_optimized_faceting_test.cpp @@ -1090,7 +1090,7 @@ TEST_F(CollectionOptimizedFacetingTest, FacetParseTest){ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTest) { std::vector fields = {field("place", field_types::STRING, false), - field("state", field_types::STRING, false), + 
field("state", field_types::STRING, true), field("visitors", field_types::INT32, true), field("trackingFrom", field_types::INT32, true),}; Collection* coll1 = collectionManager.create_collection( @@ -1376,7 +1376,7 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTypo) { 4UL, 7UL, fallback, 4UL, {off}, 32767UL, 32767UL, 2UL, 2UL, false, "", true, 0UL, max_score, 100UL, 0UL, 4294967295UL, "top_values"); - ASSERT_STREQ("Error splitting the facet range values.", results2.error().c_str()); + ASSERT_STREQ("Invalid facet param `VeryBusy`.", results2.error().c_str()); auto results3 = coll1->search("TamilNadu", {"state"}, "", {"visitors(Busy:[0, 200000] VeryBusy:[200000, 500000])"}, //missing ',' between ranges @@ -1389,7 +1389,7 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTypo) { 4UL, 7UL, fallback, 4UL, {off}, 32767UL, 32767UL, 2UL, 2UL, false, "", true, 0UL, max_score, 100UL, 0UL, 4294967295UL, "top_values"); - ASSERT_STREQ("Error splitting the facet range values.", results3.error().c_str()); + ASSERT_STREQ("Invalid facet format.", results3.error().c_str()); auto results4 = coll1->search("TamilNadu", {"state"}, "", {"visitors(Busy:[0 200000], VeryBusy:[200000, 500000])"}, //missing ',' between first ranges values @@ -1415,7 +1415,7 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetTypo) { 4UL, 7UL, fallback, 4UL, {off}, 32767UL, 32767UL, 2UL, 2UL, false, "", true, 0UL, max_score, 100UL, 0UL, 4294967295UL, "top_values"); - ASSERT_STREQ("Facet range value is not valid.", results5.error().c_str()); + ASSERT_STREQ("Error splitting the facet range values.", results5.error().c_str()); collectionManager.drop_collection("coll1"); } @@ -2478,7 +2478,7 @@ TEST_F(CollectionOptimizedFacetingTest, FacetSortValidation) { {}, {2}); ASSERT_EQ(400, search_op.code()); - ASSERT_EQ("Invalid sort format.", search_op.error()); + ASSERT_EQ("Invalid facet param `sort`.", search_op.error()); //invalid param search_op = coll1->search("*", {}, "", {"phone(sort_by:_alpha:foo)"}, @@ 
-3073,3 +3073,250 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetsWithSortDisabled) { ASSERT_EQ(1, results["facet_counts"][0]["counts"][1]["count"]); ASSERT_EQ("Medium", results["facet_counts"][0]["counts"][1]["value"]); } + +TEST_F(CollectionOptimizedFacetingTest, TopKFaceting) { + std::vector fields = {field("name", field_types::STRING, true, false, true, "", 1), + field("price", field_types::FLOAT, true, false, true, "", 0)}; + + Collection* coll2 = collectionManager.create_collection( + "coll2", 1, fields, "", 0, "", + {},{}).get(); + + nlohmann::json doc; + for(int i=0; i < 500; ++i) { + doc["name"] = "jeans"; + doc["price"] = 49.99; + ASSERT_TRUE(coll2->add(doc.dump()).ok()); + + doc["name"] = "narrow jeans"; + doc["price"] = 29.99; + ASSERT_TRUE(coll2->add(doc.dump()).ok()); + } + + //normal facet (no top_k param): both values are expected, with a count of 500 each + auto results = coll2->search("jeans", {"name"}, + "", {"name"}, + {}, {2}, 10, + 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get(); + + ASSERT_EQ(1, results["facet_counts"].size()); + ASSERT_EQ("name", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(2, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][0]["count"]); + ASSERT_EQ("narrow jeans", results["facet_counts"][0]["counts"][1]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][0]["counts"][1]["count"]); + + //facet with top_k enabled: only `jeans` is expected, with a count of 250 + results = coll2->search("jeans", {"name"}, + "", {"name(top_k:true)"}, + {}, {2}, 10, + 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off},
INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get(); + + ASSERT_EQ(1, results["facet_counts"].size()); + ASSERT_EQ("name", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]); + ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]); + + //only some of the facets use top_k: name has it, price does not + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:true)", "price"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get(); + + ASSERT_EQ(2, results["facet_counts"].size()); + + ASSERT_EQ("name", results["facet_counts"][0]["field_name"]); + ASSERT_EQ(1, results["facet_counts"][0]["counts"].size()); + ASSERT_EQ("jeans", results["facet_counts"][0]["counts"][0]["value"]); + ASSERT_EQ(250, (int) results["facet_counts"][0]["counts"][0]["count"]); + + ASSERT_EQ("price", results["facet_counts"][1]["field_name"]); + ASSERT_EQ(2, results["facet_counts"][1]["counts"].size()); + ASSERT_EQ("49.99", results["facet_counts"][1]["counts"][0]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][0]["count"]); + ASSERT_EQ("29.99", results["facet_counts"][1]["counts"][1]["value"]); + ASSERT_EQ(500, (int) results["facet_counts"][1]["counts"][1]["count"]); +} + +TEST_F(CollectionOptimizedFacetingTest, TopKFacetValidation) { + std::vector fields = {field("name", field_types::STRING, true, false, true, "", 1), + field("price", field_types::FLOAT, true, false, true, "", 1)}; + + Collection* coll2 = collectionManager.create_collection( + "coll2", 1, fields, "", 0, "", + {},{}).get(); + + //'=' separator instead of ":" + auto results = coll2->search("jeans", {"name"},
"", + {"name(top_k=true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet format.", results.error()); + + //typo in the top_k param name (top-k here, topk in the next search) + results = coll2->search("jeans", {"name"}, "", + {"name(top-k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet param `top-k`.", results.error()); + + results = coll2->search("jeans", {"name"}, "", + {"name(topk:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet param `topk`.", results.error()); + + //value should be boolean, so 10 is rejected + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:10)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_FALSE(results.ok()); + ASSERT_EQ("top_k string format is invalid.", results.error()); + + //correct (boolean) value + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false)"}, {}, {2},
+ 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_TRUE(results.ok()); + + //top_k together with sort params + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false, sort_by:_alpha:desc)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_TRUE(results.ok()); + + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false, sort_by:price:desc)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_TRUE(results.ok()); + + //top_k together with range facets, placed first, in the middle and last + results = coll2->search("jeans", {"name"}, "", + {"price(top_k:false, economic:[0, 30], Luxury:[30, 50])"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + ASSERT_TRUE(results.ok()); + + + results = coll2->search("jeans", {"name"}, "", + {"price(economic:[0, 30], top_k:true, Luxury:[30, 50])"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, +
6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_TRUE(results.ok()); + + results = coll2->search("jeans", {"name"}, "", + {"price(economic:[0, 30], Luxury:[30, 50], top_k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + ASSERT_TRUE(results.ok()); + + //missing ',' separator between params + results = coll2->search("jeans", {"name"}, "", + {"price(economic:[0, 30], Luxury:[30, 50] top_k:true)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + ASSERT_FALSE(results.ok()); + ASSERT_EQ("Invalid facet format.", results.error()); + + results = coll2->search("jeans", {"name"}, "", + {"name(top_k:false sort_by:_alpha:desc)"}, {}, {2}, + 10, 1, FREQUENCY, {true}, + 10, spp::sparse_hash_set(), + spp::sparse_hash_set(), 10, "", 30, 4, "", 10, {}, {}, {}, 0, + "", "", {}, 1000, + true, false, true, "", true, + 6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX, + 2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values"); + + ASSERT_FALSE(results.ok()); + ASSERT_EQ("top_k string format is invalid.", results.error()); +}