Range operator for numerical filtering.

This commit is contained in:
Jason Bosco 2021-01-29 13:57:52 -08:00
parent 488a579ced
commit 0ad8c48115
6 changed files with 151 additions and 47 deletions

View File

@ -132,7 +132,8 @@ enum NUM_COMPARATOR {
EQUALS,
CONTAINS,
GREATER_THAN,
GREATER_THAN_EQUALS
GREATER_THAN_EQUALS,
RANGE_INCLUSIVE
};
/**

View File

@ -102,6 +102,26 @@ struct filter {
std::vector<std::string> values;
std::vector<NUM_COMPARATOR> comparators;
/**
 * Token that separates the two bounds inside a range filter value.
 *
 * @return the two-character string "..".
 */
static const std::string RANGE_OPERATOR() {
    const std::string range_token("..");
    return range_token;
}
/**
 * Checks that a raw filter value is parseable as the numerical type of the given field.
 *
 * @param _field    field being filtered on; its type predicates and name are read.
 * @param raw_value textual value taken from the filter expression.
 * @return Option holding `true` when the value passes the check for the field's type (or when
 *         the field is none of int32 / int64 / float); otherwise a 400 error that names the
 *         field and the type that was expected.
 */
static Option<bool> validate_numerical_filter_value(field _field, const std::string& raw_value) {
    // Builds the 400 response shared by every failed type check.
    const auto type_error = [&_field](const char* type_label) {
        return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + type_label);
    };

    // Each check returns on failure, so the checks can be written as independent guards.
    if(_field.is_int32() && !StringUtils::is_int32_t(raw_value)) {
        return type_error("Not an int32.");
    }

    if(_field.is_int64() && !StringUtils::is_int64_t(raw_value)) {
        return type_error("Not an int64.");
    }

    if(_field.is_float() && !StringUtils::is_float(raw_value)) {
        return type_error("Not a float.");
    }

    return Option<bool>(true);
}
static Option<NUM_COMPARATOR> extract_num_comparator(std::string & comp_and_value) {
auto num_comparator = EQUALS;
@ -126,6 +146,10 @@ struct filter {
num_comparator = GREATER_THAN;
}
else if(comp_and_value.find("..") != std::string::npos) {
num_comparator = RANGE_INCLUSIVE;
}
else {
return Option<NUM_COMPARATOR>(400, "Numerical field has an invalid comparator.");
}

View File

@ -25,6 +25,32 @@ public:
int64map[value]->append(id);
}
/**
 * Unions into `*ids` every id stored under a key that lies in the closed interval [start, end].
 *
 * @param start   lowest key to include.
 * @param end     highest key to include; when no key falls inside the interval nothing new is
 *                contributed to the union.
 * @param ids     in/out: pointer to the caller's current id array. The old array is released
 *                with `delete[]` and replaced by the array produced by ArrayUtils::or_scalar.
 * @param ids_len in/out: length of `*ids`; overwritten with the value returned by
 *                ArrayUtils::or_scalar.
 *
 * When the underlying map is empty the function returns without touching `ids` or `ids_len`.
 */
void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len) {
    if(int64map.empty()) {
        return ;
    }

    std::vector<uint32_t> consolidated_ids;

    // lower_bound lands on the first key >= start; keep walking while keys stay <= end
    for(auto it = int64map.lower_bound(start); it != int64map.end() && it->first <= end; ++it) {
        for(size_t i = 0; i < it->second->getLength(); i++) {
            consolidated_ids.push_back(it->second->at(i));
        }
    }

    std::sort(consolidated_ids.begin(), consolidated_ids.end());

    // The same id can presumably be stored under several keys of the range (e.g. array-valued
    // fields), so collapse repeats to hand a strictly increasing list to the union below.
    consolidated_ids.erase(std::unique(consolidated_ids.begin(), consolidated_ids.end()),
                           consolidated_ids.end());

    uint32_t *out = nullptr;

    // data() is well-defined on an empty vector, whereas &consolidated_ids[0] is not; the vector
    // is empty whenever no key falls inside [start, end].
    ids_len = ArrayUtils::or_scalar(consolidated_ids.data(), consolidated_ids.size(),
                                    *ids, ids_len, &out);

    // the caller's previous array has been superseded by the merged one
    delete [] *ids;
    *ids = out;
}
void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len) {
if(int64map.empty()) {
return ;

View File

@ -1800,26 +1800,28 @@ Option<bool> Collection::parse_filter_query(const std::string& simple_filter_que
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
}
if(_field.is_int32()) {
if(!StringUtils::is_int32_t(filter_value)) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int32.");
}
}
if(op_comparator.get() == RANGE_INCLUSIVE) {
// split the value around range operator to extract bounds
std::vector<std::string> range_values;
StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR());
for(const std::string& range_value: range_values) {
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
if(!validate_op.ok()) {
return validate_op;
}
else if(_field.is_int64()) {
if(!StringUtils::is_int64_t(filter_value)) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int64.");
f.values.push_back(range_value);
f.comparators.push_back(op_comparator.get());
}
}
else if(_field.is_float()) {
if(!StringUtils::is_float(filter_value)) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not a float.");
} else {
auto validate_op = filter::validate_numerical_filter_value(_field, filter_value);
if(!validate_op.ok()) {
return validate_op;
}
}
f.values.push_back(filter_value);
f.comparators.push_back(op_comparator.get());
f.values.push_back(filter_value);
f.comparators.push_back(op_comparator.get());
}
}
} else {
@ -1828,25 +1830,28 @@ Option<bool> Collection::parse_filter_query(const std::string& simple_filter_que
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
}
if(_field.is_int32()) {
if(!StringUtils::is_int32_t(raw_value)) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int32.");
}
}
if(op_comparator.get() == RANGE_INCLUSIVE) {
// split the value around range operator to extract bounds
std::vector<std::string> range_values;
StringUtils::split(raw_value, range_values, filter::RANGE_OPERATOR());
else if(_field.is_int64()) {
if(!StringUtils::is_int64_t(raw_value)) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int64.");
}
}
f.field_name = field_name;
for(const std::string& range_value: range_values) {
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
if(!validate_op.ok()) {
return validate_op;
}
else if(_field.is_float()) {
if(!StringUtils::is_float(raw_value)) {
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not a float.");
f.values.push_back(range_value);
f.comparators.push_back(op_comparator.get());
}
} else {
auto validate_op = filter::validate_numerical_filter_value(_field, raw_value);
if(!validate_op.ok()) {
return validate_op;
}
f = {field_name, {raw_value}, {op_comparator.get()}};
}
f = {field_name, {raw_value}, {op_comparator.get()}};
}
} else if(_field.is_bool()) {
if(raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') {

View File

@ -977,29 +977,36 @@ Option<uint32_t> Index::do_filtering(uint32_t** filter_ids_out, const std::vecto
if(f.is_integer()) {
auto num_tree = numerical_index.at(a_filter.field_name);
size_t value_index = 0;
for(const std::string & filter_value: a_filter.values) {
if(f.type == field_types::INT32 || f.type == field_types::INT32_ARRAY) {
// check for comparator again
int32_t value = (int32_t) std::stoi(filter_value);
num_tree->search(a_filter.comparators[value_index], value, &result_ids, result_ids_len);
} else { // int64
int64_t value = (int64_t) std::stol(filter_value);
num_tree->search(a_filter.comparators[value_index], value, &result_ids, result_ids_len);
}
for(size_t fi=0; fi < a_filter.values.size(); fi++) {
const std::string & filter_value = a_filter.values[fi];
int64_t value = (int64_t) std::stol(filter_value);
value_index++;
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = (int64_t) std::stol(next_filter_value);
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
fi++;
} else {
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
}
}
} else if(f.is_float()) {
auto num_tree = numerical_index.at(a_filter.field_name);
size_t value_index = 0;
for(const std::string & filter_value: a_filter.values) {
for(size_t fi=0; fi < a_filter.values.size(); fi++) {
const std::string & filter_value = a_filter.values[fi];
float value = (float) std::atof(filter_value.c_str());
int64_t float_int64 = float_to_in64_t(value);
num_tree->search(a_filter.comparators[value_index], float_int64, &result_ids, result_ids_len);
value_index++;
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi+1];
int64_t range_end_value = float_to_in64_t((float) std::atof(next_filter_value.c_str()));
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
fi++;
} else {
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
}
}
} else if(f.is_bool()) {

View File

@ -323,6 +323,7 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
std::vector<field> fields = {
field("name", field_types::STRING, false),
field("rating", field_types::FLOAT, false),
field("age", field_types::INT32, false),
field("years", field_types::INT32_ARRAY, false),
field("timestamps", field_types::INT64_ARRAY, false),
@ -468,6 +469,46 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// range based filter
results = coll_array_fields->search("Jeremy", query_fields, "age: 21..32", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(3, results["hits"].size());
ids = {"4", "0", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
results = coll_array_fields->search("Jeremy", query_fields, "age: 0 .. 100", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(5, results["hits"].size());
results = coll_array_fields->search("Jeremy", query_fields, "age: [21..24, 40..65]", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(4, results["hits"].size());
ids = {"3", "1", "0", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
results = coll_array_fields->search("Jeremy", query_fields, "rating: 7.812 .. 9.999", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(2, results["hits"].size());
ids = {"1", "2"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(3, results["hits"].size());
// when filters don't match any record, no results should be returned
results = coll_array_fields->search("Jeremy", query_fields, "timestamps:>1591091288061", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
ASSERT_EQ(0, results["hits"].size());