mirror of
https://github.com/typesense/typesense.git
synced 2025-05-17 12:12:35 +08:00
Range operator for numerical filtering.
This commit is contained in:
parent
488a579ced
commit
0ad8c48115
@ -132,7 +132,8 @@ enum NUM_COMPARATOR {
|
||||
EQUALS,
|
||||
CONTAINS,
|
||||
GREATER_THAN,
|
||||
GREATER_THAN_EQUALS
|
||||
GREATER_THAN_EQUALS,
|
||||
RANGE_INCLUSIVE
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -102,6 +102,26 @@ struct filter {
|
||||
std::vector<std::string> values;
|
||||
std::vector<NUM_COMPARATOR> comparators;
|
||||
|
||||
static const std::string RANGE_OPERATOR() {
|
||||
return "..";
|
||||
}
|
||||
|
||||
static Option<bool> validate_numerical_filter_value(field _field, const std::string& raw_value) {
|
||||
if(_field.is_int32() && !StringUtils::is_int32_t(raw_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int32.");
|
||||
}
|
||||
|
||||
else if(_field.is_int64() && !StringUtils::is_int64_t(raw_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int64.");
|
||||
}
|
||||
|
||||
else if(_field.is_float() && !StringUtils::is_float(raw_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not a float.");
|
||||
}
|
||||
|
||||
return Option<bool>(true);
|
||||
}
|
||||
|
||||
static Option<NUM_COMPARATOR> extract_num_comparator(std::string & comp_and_value) {
|
||||
auto num_comparator = EQUALS;
|
||||
|
||||
@ -126,6 +146,10 @@ struct filter {
|
||||
num_comparator = GREATER_THAN;
|
||||
}
|
||||
|
||||
else if(comp_and_value.find("..") != std::string::npos) {
|
||||
num_comparator = RANGE_INCLUSIVE;
|
||||
}
|
||||
|
||||
else {
|
||||
return Option<NUM_COMPARATOR>(400, "Numerical field has an invalid comparator.");
|
||||
}
|
||||
|
@ -25,6 +25,32 @@ public:
|
||||
int64map[value]->append(id);
|
||||
}
|
||||
|
||||
void range_inclusive_search(int64_t start, int64_t end, uint32_t** ids, size_t& ids_len) {
|
||||
if(int64map.empty()) {
|
||||
return ;
|
||||
}
|
||||
|
||||
auto it_start = int64map.lower_bound(start); // iter values will be >= start
|
||||
|
||||
std::vector<uint32_t> consolidated_ids;
|
||||
while(it_start != int64map.end() && it_start->first <= end) {
|
||||
for(size_t i = 0; i < it_start->second->getLength(); i++) {
|
||||
consolidated_ids.push_back(it_start->second->at(i));
|
||||
}
|
||||
|
||||
it_start++;
|
||||
}
|
||||
|
||||
std::sort(consolidated_ids.begin(), consolidated_ids.end());
|
||||
|
||||
uint32_t *out = nullptr;
|
||||
ids_len = ArrayUtils::or_scalar(&consolidated_ids[0], consolidated_ids.size(),
|
||||
*ids, ids_len, &out);
|
||||
|
||||
delete [] *ids;
|
||||
*ids = out;
|
||||
}
|
||||
|
||||
void search(NUM_COMPARATOR comparator, int64_t value, uint32_t** ids, size_t& ids_len) {
|
||||
if(int64map.empty()) {
|
||||
return ;
|
||||
|
@ -1800,26 +1800,28 @@ Option<bool> Collection::parse_filter_query(const std::string& simple_filter_que
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
|
||||
}
|
||||
|
||||
if(_field.is_int32()) {
|
||||
if(!StringUtils::is_int32_t(filter_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int32.");
|
||||
}
|
||||
}
|
||||
if(op_comparator.get() == RANGE_INCLUSIVE) {
|
||||
// split the value around range operator to extract bounds
|
||||
std::vector<std::string> range_values;
|
||||
StringUtils::split(filter_value, range_values, filter::RANGE_OPERATOR());
|
||||
for(const std::string& range_value: range_values) {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
|
||||
if(!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
|
||||
else if(_field.is_int64()) {
|
||||
if(!StringUtils::is_int64_t(filter_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int64.");
|
||||
f.values.push_back(range_value);
|
||||
f.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
}
|
||||
|
||||
else if(_field.is_float()) {
|
||||
if(!StringUtils::is_float(filter_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not a float.");
|
||||
} else {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, filter_value);
|
||||
if(!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
}
|
||||
|
||||
f.values.push_back(filter_value);
|
||||
f.comparators.push_back(op_comparator.get());
|
||||
f.values.push_back(filter_value);
|
||||
f.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
@ -1828,25 +1830,28 @@ Option<bool> Collection::parse_filter_query(const std::string& simple_filter_que
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: " + op_comparator.error());
|
||||
}
|
||||
|
||||
if(_field.is_int32()) {
|
||||
if(!StringUtils::is_int32_t(raw_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int32.");
|
||||
}
|
||||
}
|
||||
if(op_comparator.get() == RANGE_INCLUSIVE) {
|
||||
// split the value around range operator to extract bounds
|
||||
std::vector<std::string> range_values;
|
||||
StringUtils::split(raw_value, range_values, filter::RANGE_OPERATOR());
|
||||
|
||||
else if(_field.is_int64()) {
|
||||
if(!StringUtils::is_int64_t(raw_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not an int64.");
|
||||
}
|
||||
}
|
||||
f.field_name = field_name;
|
||||
for(const std::string& range_value: range_values) {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, range_value);
|
||||
if(!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
|
||||
else if(_field.is_float()) {
|
||||
if(!StringUtils::is_float(raw_value)) {
|
||||
return Option<bool>(400, "Error with filter field `" + _field.name + "`: Not a float.");
|
||||
f.values.push_back(range_value);
|
||||
f.comparators.push_back(op_comparator.get());
|
||||
}
|
||||
} else {
|
||||
auto validate_op = filter::validate_numerical_filter_value(_field, raw_value);
|
||||
if(!validate_op.ok()) {
|
||||
return validate_op;
|
||||
}
|
||||
f = {field_name, {raw_value}, {op_comparator.get()}};
|
||||
}
|
||||
|
||||
f = {field_name, {raw_value}, {op_comparator.get()}};
|
||||
}
|
||||
} else if(_field.is_bool()) {
|
||||
if(raw_value[0] == '[' && raw_value[raw_value.size() - 1] == ']') {
|
||||
|
@ -977,29 +977,36 @@ Option<uint32_t> Index::do_filtering(uint32_t** filter_ids_out, const std::vecto
|
||||
if(f.is_integer()) {
|
||||
auto num_tree = numerical_index.at(a_filter.field_name);
|
||||
|
||||
size_t value_index = 0;
|
||||
for(const std::string & filter_value: a_filter.values) {
|
||||
if(f.type == field_types::INT32 || f.type == field_types::INT32_ARRAY) {
|
||||
// check for comparator again
|
||||
int32_t value = (int32_t) std::stoi(filter_value);
|
||||
num_tree->search(a_filter.comparators[value_index], value, &result_ids, result_ids_len);
|
||||
} else { // int64
|
||||
int64_t value = (int64_t) std::stol(filter_value);
|
||||
num_tree->search(a_filter.comparators[value_index], value, &result_ids, result_ids_len);
|
||||
}
|
||||
for(size_t fi=0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string & filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t) std::stol(filter_value);
|
||||
|
||||
value_index++;
|
||||
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = (int64_t) std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &result_ids, result_ids_len);
|
||||
fi++;
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &result_ids, result_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
} else if(f.is_float()) {
|
||||
auto num_tree = numerical_index.at(a_filter.field_name);
|
||||
|
||||
size_t value_index = 0;
|
||||
for(const std::string & filter_value: a_filter.values) {
|
||||
for(size_t fi=0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string & filter_value = a_filter.values[fi];
|
||||
float value = (float) std::atof(filter_value.c_str());
|
||||
int64_t float_int64 = float_to_in64_t(value);
|
||||
num_tree->search(a_filter.comparators[value_index], float_int64, &result_ids, result_ids_len);
|
||||
value_index++;
|
||||
|
||||
if(a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi+1];
|
||||
int64_t range_end_value = float_to_in64_t((float) std::atof(next_filter_value.c_str()));
|
||||
num_tree->range_inclusive_search(float_int64, range_end_value, &result_ids, result_ids_len);
|
||||
fi++;
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], float_int64, &result_ids, result_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
} else if(f.is_bool()) {
|
||||
|
@ -323,6 +323,7 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
|
||||
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
|
||||
std::vector<field> fields = {
|
||||
field("name", field_types::STRING, false),
|
||||
field("rating", field_types::FLOAT, false),
|
||||
field("age", field_types::INT32, false),
|
||||
field("years", field_types::INT32_ARRAY, false),
|
||||
field("timestamps", field_types::INT64_ARRAY, false),
|
||||
@ -468,6 +469,46 @@ TEST_F(CollectionFilteringTest, FilterOnNumericFields) {
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// range based filter
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age: 21..32", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
ids = {"4", "0", "2"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age: 0 .. 100", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age: [21..24, 40..65]", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(4, results["hits"].size());
|
||||
|
||||
ids = {"3", "1", "0", "2"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating: 7.812 .. 9.999", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(2, results["hits"].size());
|
||||
|
||||
ids = {"1", "2"};
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "rating: [7.812 .. 9.999, 1.05 .. 1.09]", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
// when filters don't match any record, no results should be returned
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "timestamps:>1591091288061", facets, sort_fields, 0, 10, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(0, results["hits"].size());
|
||||
|
Loading…
x
Reference in New Issue
Block a user