throw error when range faceted field has sorting disabled (#1726)

* range facets with sort disabled

* update test names

* throw error when range faceted field has sort disabled
This commit is contained in:
Krunal Gandhi 2024-05-15 11:38:57 +00:00 committed by GitHub
parent be57f68e64
commit 6075362709
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 172 additions and 94 deletions

View File

@ -6191,13 +6191,14 @@ bool Collection::get_enable_nested_fields() {
Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector<facet>& facets) const {
const std::regex base_pattern(".+\\(.*\\)");
const std::regex range_pattern("[[0-9]*[a-z A-Z]+[0-9]*:\\[([+-]?([0-9]*[.])?[0-9]*)\\,\\s*([+-]?([0-9]*[.])?[0-9]*)\\]");
const std::regex range_pattern(
"[[0-9]*[a-z A-Z]+[0-9]*:\\[([+-]?([0-9]*[.])?[0-9]*)\\,\\s*([+-]?([0-9]*[.])?[0-9]*)\\]");
const std::string _alpha = "_alpha";
if ((facet_field.find(":") != std::string::npos)
&& (facet_field.find("sort_by") == std::string::npos)) { //range based facet
if((facet_field.find(":") != std::string::npos)
&& (facet_field.find("sort_by") == std::string::npos)) { //range based facet
if (!std::regex_match(facet_field, base_pattern)) {
if(!std::regex_match(facet_field, base_pattern)) {
std::string error = "Facet range value is not valid.";
return Option<bool>(400, error);
}
@ -6211,7 +6212,7 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
}
if((field_name.find("sort") == std::string::npos)
&& (facet_field.find("sort") != std::string::npos)) {
&& (facet_field.find("sort") != std::string::npos)) {
//sort keyword is found in facet string but not in facet field
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
@ -6219,11 +6220,15 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
const field& a_field = search_schema.at(field_name);
if(!a_field.is_integer() && !a_field.is_float()){
if(!a_field.is_integer() && !a_field.is_float()) {
std::string error = "Range facet is restricted to only integer and float fields.";
return Option<bool>(400, error);
}
if(!a_field.sort) {
return Option<bool>(400, "Range facets require sort enabled for the field.");
}
facet a_facet(field_name, facets.size());
//starting after "(" and excluding ")"
@ -6232,32 +6237,28 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
//split the ranges
std::vector<std::string> result;
startpos = 0;
int index=0;
int index = 0;
int commaFound = 0, rangeFound = 0;
bool range_open=false;
while(index < range_string.size()){
if(range_string[index] == ']'){
if(range_open == true){
bool range_open = false;
while(index < range_string.size()) {
if(range_string[index] == ']') {
if(range_open == true) {
std::string range = range_string.substr(startpos, index + 1 - startpos);
range=StringUtils::trim(range);
range = StringUtils::trim(range);
result.emplace_back(range);
rangeFound++;
range_open=false;
}
else{
range_open = false;
} else {
result.clear();
break;
}
}
else if(range_string[index] == ',' && range_open == false){
startpos = index+1;
} else if(range_string[index] == ',' && range_open == false) {
startpos = index + 1;
commaFound++;
}
else if(range_string[index] == '['){
if((commaFound == rangeFound) && range_open==false){
range_open=true;
}
else{
} else if(range_string[index] == '[') {
if((commaFound == rangeFound) && range_open == false) {
range_open = true;
} else {
result.clear();
break;
}
@ -6266,7 +6267,7 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
index++;
}
if((result.empty()) || (range_open==true)){
if((result.empty()) || (range_open == true)) {
std::string error = "Error splitting the facet range values.";
return Option<bool>(400, error);
}
@ -6275,9 +6276,9 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
auto& range_map = a_facet.facet_range_map;
range_map.clear();
for(const auto& range : result){
for(const auto& range: result) {
//validate each range syntax
if(!std::regex_match(range, range_pattern)){
if(!std::regex_match(range, range_pattern)) {
std::string error = "Facet range value is not valid.";
return Option<bool>(400, error);
}
@ -6339,27 +6340,27 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
//sort the range values so that we can check continuity
sort(tupVec.begin(), tupVec.end());
for(const auto& tup : tupVec){
for(const auto& tup: tupVec) {
const auto& lower_range = std::get<0>(tup);
const auto& upper_range = std::get<1>(tup);
const std::string& range_val = std::get<2>(tup);
//check if ranges are continous or not
if((!range_map.empty()) && (range_map.find(lower_range)== range_map.end())){
if((!range_map.empty()) && (range_map.find(lower_range) == range_map.end())) {
std::string error = "Ranges in range facet syntax should be continous.";
return Option<bool>(400, error);
}
range_map[upper_range] = range_specs_t{range_val, lower_range};
range_map[upper_range] = range_specs_t{range_val, lower_range};
}
a_facet.is_range_query = true;
facets.emplace_back(std::move(a_facet));
} else if (facet_field.find('*') != std::string::npos) { // Wildcard
if (facet_field[facet_field.size() - 1] != '*') {
return Option<bool>(404, "Only prefix matching with a wildcard is allowed.");
}
} else if(facet_field.find('*') != std::string::npos) { // Wildcard
if(facet_field[facet_field.size() - 1] != '*') {
return Option<bool>(404, "Only prefix matching with a wildcard is allowed.");
}
// Trim * from the end.
auto prefix = facet_field.substr(0, facet_field.size() - 1);
@ -6372,80 +6373,80 @@ Option<bool> Collection::parse_facet(const std::string& facet_field, std::vector
}
// Collect the fields that match the prefix and are marked as facet.
for (auto field = pair.first; field != pair.second; field++) {
if (field->facet) {
for(auto field = pair.first; field != pair.second; field++) {
if(field->facet) {
facets.emplace_back(facet(field->name, facets.size()));
facets.back().is_wildcard_match = true;
}
}
} else {
// normal facet
std::string order = "";
bool sort_alpha = false;
std::string sort_field = "";
std::string facet_field_copy = facet_field;
auto pos = facet_field_copy.find("(");
if(pos != std::string::npos) {
facet_field_copy = facet_field_copy.substr(0, pos);
}
// normal facet
std::string order = "";
bool sort_alpha = false;
std::string sort_field = "";
std::string facet_field_copy = facet_field;
auto pos = facet_field_copy.find("(");
if(pos != std::string::npos) {
facet_field_copy = facet_field_copy.substr(0, pos);
}
if (search_schema.count(facet_field_copy) == 0 || !search_schema.at(facet_field_copy).facet) {
std::string error = "Could not find a facet field named `" + facet_field_copy + "` in the schema.";
return Option<bool>(404, error);
}
if(search_schema.count(facet_field_copy) == 0 || !search_schema.at(facet_field_copy).facet) {
std::string error = "Could not find a facet field named `" + facet_field_copy + "` in the schema.";
return Option<bool>(404, error);
}
if (facet_field.find("sort_by") != std::string::npos) { //sort params are supplied with facet
std::vector<std::string> tokens;
StringUtils::split(facet_field, tokens, ":");
if(facet_field.find("sort_by") != std::string::npos) { //sort params are supplied with facet
std::vector<std::string> tokens;
StringUtils::split(facet_field, tokens, ":");
if(tokens.size() != 3) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
if(tokens.size() != 3) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
//remove possible whitespaces
for(auto i=0; i < 3; ++i) {
StringUtils::trim(tokens[i]);
}
//remove possible whitespaces
for(auto i = 0; i < 3; ++i) {
StringUtils::trim(tokens[i]);
}
if(tokens[1] == _alpha) {
const field &a_field = search_schema.at(facet_field_copy);
if (!a_field.is_string()) {
std::string error = "Facet field should be string type to apply alpha sort.";
return Option<bool>(400, error);
}
sort_alpha = true;
} else { //sort_field based sort
sort_field = tokens[1];
if(tokens[1] == _alpha) {
const field& a_field = search_schema.at(facet_field_copy);
if(!a_field.is_string()) {
std::string error = "Facet field should be string type to apply alpha sort.";
return Option<bool>(400, error);
}
sort_alpha = true;
} else { //sort_field based sort
sort_field = tokens[1];
if (search_schema.count(sort_field) == 0 || !search_schema.at(sort_field).facet) {
std::string error = "Could not find a facet field named `" + sort_field + "` in the schema.";
return Option<bool>(404, error);
}
if(search_schema.count(sort_field) == 0 || !search_schema.at(sort_field).facet) {
std::string error = "Could not find a facet field named `" + sort_field + "` in the schema.";
return Option<bool>(404, error);
}
const field &a_field = search_schema.at(sort_field);
if (a_field.is_string()) {
std::string error = "Sort field should be non string type to apply sort.";
return Option<bool>(400, error);
}
}
const field& a_field = search_schema.at(sort_field);
if(a_field.is_string()) {
std::string error = "Sort field should be non string type to apply sort.";
return Option<bool>(400, error);
}
}
if (tokens[2].find("asc") != std::string::npos) {
order = "asc";
} else if (tokens[2].find("desc") != std::string::npos) {
order = "desc";
} else {
std::string error = "Invalid sort param.";
return Option<bool>(400, error);
}
} else if (facet_field != facet_field_copy) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
if(tokens[2].find("asc") != std::string::npos) {
order = "asc";
} else if(tokens[2].find("desc") != std::string::npos) {
order = "desc";
} else {
std::string error = "Invalid sort param.";
return Option<bool>(400, error);
}
} else if(facet_field != facet_field_copy) {
std::string error = "Invalid sort format.";
return Option<bool>(400, error);
}
facets.emplace_back(facet(facet_field_copy, facets.size(), {}, false, sort_alpha,
order, sort_field));
}
facets.emplace_back(facet(facet_field_copy, facets.size(), {}, false, sort_alpha,
order, sort_field));
}
return Option<bool>(true);
}

View File

@ -5780,6 +5780,7 @@ void Index::compute_facet_infos(const std::vector<facet>& facets, facet_query_t&
bool facet_value_index_exists = facet_index_v4->has_value_index(facet_field.name);
//as we use sort index for range facets with hash based index, sort index should be present
if(facet_index_type == exhaustive) {
facet_infos[findex].use_value_index = false;
}

View File

@ -3137,3 +3137,36 @@ TEST_F(CollectionFacetingTest, FacetingWithCoercedString) {
ASSERT_EQ(3, results["facet_counts"][0]["counts"].size());
ASSERT_EQ(1, results["facet_counts"][0]["counts"][0]["count"]);
}
// A range facet requested on a numeric field must be rejected with an explicit
// error message instead of producing facet counts.
TEST_F(CollectionFacetingTest, RangeFacetsWithSortDisabled) {
    // NOTE(review): the last constructor argument is presumably the sort flag
    // (0 = disabled for `brand` and `price`) — confirm against the field constructor.
    std::vector<field> schema_fields = {
        field("name", field_types::STRING, false, false, true, "", 1),
        field("brand", field_types::STRING, true, false, true, "", 0),
        field("price", field_types::FLOAT, true, false, true, "", 0)
    };

    Collection* coll2 = collectionManager.create_collection("coll2", 1, schema_fields, "", 0, "", {}, {}).get();

    // first product
    nlohmann::json keyboard;
    keyboard["id"] = "pd-1";
    keyboard["name"] = "keyboard";
    keyboard["brand"] = "Logitech";
    keyboard["price"] = 49.99;
    ASSERT_TRUE(coll2->add(keyboard.dump()).ok());

    // second product
    nlohmann::json mouse;
    mouse["id"] = "pd-2";
    mouse["name"] = "mouse";
    mouse["brand"] = "Logitech";
    mouse["price"] = 29.99;
    ASSERT_TRUE(coll2->add(mouse.dump()).ok());

    // No facet index type is passed here, so (per the original author's note) the
    // hash index is used, and range faceting on it needs the field to be sortable.
    auto search_op = coll2->search("*", {}, "brand:=Logitech",
                                   {"price(Low:[0, 30], Medium:[30, 75], High:[75, ])"},
                                   {}, {2}, 10, 1, FREQUENCY, {true});

    ASSERT_FALSE(search_op.ok());
    ASSERT_EQ("Range facets require sort enabled for the field.", search_op.error());
}

View File

@ -2948,3 +2948,46 @@ TEST_F(CollectionOptimizedFacetingTest, RangeFacetRangeLabelWithSpace) {
ASSERT_EQ(1, (int) results["facet_counts"][0]["counts"][0]["count"]);
ASSERT_EQ("small tvs with display size", results["facet_counts"][0]["counts"][0]["value"]);
}
// Runs the same range facet query as the hash-index variant of this test, but
// with the full positional argument list of search(), and expects the "Low"
// and "Medium" buckets to each contain one document.
TEST_F(CollectionOptimizedFacetingTest, RangeFacetsWithSortDisabled) {
    // NOTE(review): the last constructor argument is presumably the sort flag;
    // what -1 resolves to for `brand` and `price` is not visible here — confirm
    // against the field constructor.
    std::vector<field> schema_fields = {
        field("name", field_types::STRING, false, false, true, "", 1),
        field("brand", field_types::STRING, true, false, true, "", -1),
        field("price", field_types::FLOAT, true, false, true, "", -1)
    };

    Collection* coll2 = collectionManager.create_collection("coll2", 1, schema_fields, "", 0, "", {}, {}).get();

    // first product
    nlohmann::json keyboard;
    keyboard["id"] = "pd-1";
    keyboard["name"] = "keyboard";
    keyboard["brand"] = "Logitech";
    keyboard["price"] = 49.99;
    ASSERT_TRUE(coll2->add(keyboard.dump()).ok());

    // second product
    nlohmann::json mouse;
    mouse["id"] = "pd-2";
    mouse["name"] = "mouse";
    mouse["brand"] = "Logitech";
    mouse["price"] = 29.99;
    ASSERT_TRUE(coll2->add(mouse.dump()).ok());

    // Per the original author's note the value index is forced for this query;
    // presumably that is what the trailing "top_values" argument selects —
    // confirm against Collection::search.
    auto results = coll2->search(
            "*", {}, "brand:=Logitech",
            {"price(Low:[0, 30], Medium:[30, 75], High:[75, ])"},
            {}, {2}, 10, 1, FREQUENCY, {true}, 10,
            spp::sparse_hash_set<std::string>(), spp::sparse_hash_set<std::string>(),
            10, "", 30, 4, "", 10, {}, {}, {}, 0,
            "<mark>", "</mark>", {}, 1000,
            true, false, true, "", true,
            6000*1000, 4, 7, fallback, 4, {off}, INT16_MAX, INT16_MAX,
            2, 2, false, "", true, 0, max_score, 100, 0, 0, "top_values").get();

    // Non-const reference on purpose: keeps the same operator[] overloads as
    // indexing through `results` directly.
    auto& range_counts = results["facet_counts"][0]["counts"];

    // when value index is forced it works
    ASSERT_EQ(2, range_counts.size());

    ASSERT_EQ(1, range_counts[0]["count"]);
    ASSERT_EQ("Low", range_counts[0]["value"]);

    ASSERT_EQ(1, range_counts[1]["count"]);
    ASSERT_EQ("Medium", range_counts[1]["value"]);
}