mirror of
https://github.com/typesense/typesense.git
synced 2025-05-21 22:33:27 +08:00
Add range_index
property.
This commit is contained in:
parent
1f4643cba0
commit
b4a70682c6
@ -54,6 +54,7 @@ namespace fields {
|
||||
static const std::string from = "from";
|
||||
static const std::string embed_from = "embed_from";
|
||||
static const std::string model_name = "model_name";
|
||||
static const std::string range_index = "range_index";
|
||||
|
||||
// Some models require additional parameters to be passed to the model during indexing/querying
|
||||
// For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
|
||||
@ -93,13 +94,17 @@ struct field {
|
||||
|
||||
std::string reference; // Foo.bar (reference to bar field in Foo collection).
|
||||
|
||||
bool range_index;
|
||||
|
||||
field() {}
|
||||
|
||||
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
|
||||
bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const nlohmann::json& embed = nlohmann::json()) :
|
||||
int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
|
||||
std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
|
||||
name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed(embed) {
|
||||
nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
|
||||
embed(embed), range_index(range_index) {
|
||||
|
||||
set_computed_defaults(sort, infix);
|
||||
}
|
||||
|
@ -30,6 +30,7 @@
|
||||
#include "vector_query_ops.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
#include "filter.h"
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
static constexpr size_t ARRAY_FACET_DIM = 4;
|
||||
using facet_map_t = spp::sparse_hash_map<uint32_t, facet_hash_values_t>;
|
||||
@ -305,6 +306,8 @@ private:
|
||||
|
||||
spp::sparse_hash_map<std::string, num_tree_t*> numerical_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, NumericTrie*> range_index;
|
||||
|
||||
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::vector<uint32_t>>*> geopoint_index;
|
||||
|
||||
// geo_array_field => (seq_id => values) used for exact filtering of geo array records
|
||||
|
@ -75,6 +75,23 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::reference] = "";
|
||||
}
|
||||
|
||||
if (field_json.count(fields::range_index) != 0) {
|
||||
if (!field_json.at(fields::range_index).is_boolean()) {
|
||||
return Option<bool>(400, std::string("The `range_index` property of the field `") +
|
||||
field_json[fields::name].get<std::string>() +
|
||||
std::string("` should be a boolean."));
|
||||
}
|
||||
|
||||
auto const& type = field_json["type"];
|
||||
if (type != field_types::INT32 && type != field_types::INT32_ARRAY &&
|
||||
type != field_types::INT64 && type != field_types::INT64_ARRAY &&
|
||||
type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) {
|
||||
return Option<bool>(400, std::string("The `range_index` property is only allowed for the numerical fields`"));
|
||||
}
|
||||
} else {
|
||||
field_json[fields::range_index] = false;
|
||||
}
|
||||
|
||||
if(field_json["name"] == ".*") {
|
||||
if(field_json.count(fields::facet) == 0) {
|
||||
field_json[fields::facet] = false;
|
||||
@ -297,7 +314,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
|
||||
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
|
||||
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
|
||||
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
|
||||
field_json[fields::reference], field_json[fields::embed])
|
||||
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
|
||||
);
|
||||
|
||||
if (!field_json[fields::reference].get<std::string>().empty()) {
|
||||
|
@ -646,27 +646,62 @@ void filter_result_iterator_t::init() {
|
||||
field f = index->search_schema.at(a_filter.field_name);
|
||||
|
||||
if (f.is_integer()) {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
if (f.is_int32() && f.range_index) {
|
||||
auto const& trie = index->range_index.at(a_filter.field_name);
|
||||
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t)std::stol(filter_value);
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
auto const& value = (int32_t)std::stoi(filter_value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const& range_end_value = (int32_t)std::stoi(next_filter_value);
|
||||
trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == EQUALS) {
|
||||
trie->search_equal_to(value, filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
uint32_t* to_exclude_ids = nullptr;
|
||||
uint32_t to_exclude_ids_len = 0;
|
||||
trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len);
|
||||
|
||||
auto all_ids = index->seq_ids->uncompress();
|
||||
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
|
||||
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
|
||||
|
||||
delete[] all_ids;
|
||||
delete[] to_exclude_ids;
|
||||
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
|
||||
trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
|
||||
trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS,
|
||||
filter_result.docs, filter_result.count);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto num_tree = index->numerical_index.at(a_filter.field_name);
|
||||
|
||||
filter_result.count = result_size;
|
||||
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
|
||||
const std::string& filter_value = a_filter.values[fi];
|
||||
int64_t value = (int64_t)std::stol(filter_value);
|
||||
|
||||
size_t result_size = filter_result.count;
|
||||
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
|
||||
const std::string& next_filter_value = a_filter.values[fi + 1];
|
||||
auto const range_end_value = (int64_t)std::stol(next_filter_value);
|
||||
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
|
||||
fi++;
|
||||
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
|
||||
numeric_not_equals_filter(num_tree, value,
|
||||
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
|
||||
filter_result.docs, result_size);
|
||||
} else {
|
||||
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
|
||||
}
|
||||
|
||||
filter_result.count = result_size;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_filter.apply_not_equals) {
|
||||
|
@ -88,6 +88,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
|
||||
} else {
|
||||
num_tree_t* num_tree = new num_tree_t;
|
||||
numerical_index.emplace(a_field.name, num_tree);
|
||||
|
||||
if (a_field.range_index) {
|
||||
auto trie = new NumericTrie();
|
||||
range_index.emplace(a_field.name, trie);
|
||||
}
|
||||
}
|
||||
|
||||
if(a_field.sort) {
|
||||
@ -161,6 +166,13 @@ Index::~Index() {
|
||||
|
||||
numerical_index.clear();
|
||||
|
||||
for(auto & name_tree: range_index) {
|
||||
delete name_tree.second;
|
||||
name_tree.second = nullptr;
|
||||
}
|
||||
|
||||
range_index.clear();
|
||||
|
||||
for(auto & name_map: sort_index) {
|
||||
delete name_map.second;
|
||||
name_map.second = nullptr;
|
||||
@ -737,6 +749,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
if(!afield.is_string()) {
|
||||
if (afield.type == field_types::INT32) {
|
||||
if (afield.range_index) {
|
||||
auto const& trie = range_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
int32_t value = record.doc[afield.name].get<int32_t>();
|
||||
trie->insert(value, seq_id);
|
||||
});
|
||||
}
|
||||
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
@ -899,13 +920,19 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
|
||||
|
||||
// all other numerical arrays
|
||||
auto num_tree = numerical_index.at(afield.name);
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
|
||||
auto trie = range_index.count(afield.name) > 0 ? range_index.at(afield.name) : nullptr;
|
||||
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie]
|
||||
(const index_record& record, uint32_t seq_id) {
|
||||
for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) {
|
||||
const auto& arr_value = record.doc[afield.name][arr_i];
|
||||
|
||||
if(afield.type == field_types::INT32_ARRAY) {
|
||||
const int32_t value = arr_value;
|
||||
|
||||
if (afield.range_index) {
|
||||
trie->insert(value, seq_id);
|
||||
}
|
||||
|
||||
num_tree->insert(value, seq_id);
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,35 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <collection_manager.h>
|
||||
#include "collection.h"
|
||||
#include "numeric_range_trie_test.h"
|
||||
|
||||
class NumericRangeTrieTest : public ::testing::Test {
|
||||
protected:
|
||||
Store *store;
|
||||
CollectionManager & collectionManager = CollectionManager::get_instance();
|
||||
std::atomic<bool> quit = false;
|
||||
|
||||
virtual void SetUp() {}
|
||||
std::vector<std::string> query_fields;
|
||||
std::vector<sort_by> sort_fields;
|
||||
|
||||
virtual void TearDown() {}
|
||||
void setupCollection() {
|
||||
std::string state_dir_path = "/tmp/typesense_test/collection_filtering";
|
||||
LOG(INFO) << "Truncating and creating: " << state_dir_path;
|
||||
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
|
||||
|
||||
store = new Store(state_dir_path);
|
||||
collectionManager.init(store, 1.0, "auth_key", quit);
|
||||
collectionManager.load(8, 1000);
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
setupCollection();
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
collectionManager.dispose();
|
||||
delete store;
|
||||
}
|
||||
};
|
||||
|
||||
void reset(uint32_t*& ids, uint32_t& ids_length) {
|
||||
@ -570,3 +593,70 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) {
|
||||
|
||||
ASSERT_EQ(0, ids_length);
|
||||
}
|
||||
|
||||
TEST_F(NumericRangeTrieTest, Integration) {
|
||||
Collection *coll_array_fields;
|
||||
|
||||
std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
|
||||
std::vector<field> fields = {
|
||||
field("name", field_types::STRING, false),
|
||||
field("rating", field_types::FLOAT, false),
|
||||
field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
|
||||
true), // Setting range index true.
|
||||
field("years", field_types::INT32_ARRAY, false),
|
||||
field("timestamps", field_types::INT64_ARRAY, false),
|
||||
field("tags", field_types::STRING_ARRAY, true)
|
||||
};
|
||||
|
||||
std::vector<sort_by> sort_fields = { sort_by("age", "DESC") };
|
||||
|
||||
coll_array_fields = collectionManager.get_collection("coll_array_fields").get();
|
||||
if(coll_array_fields == nullptr) {
|
||||
// ensure that default_sorting_field is a non-array numerical field
|
||||
auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years");
|
||||
ASSERT_EQ(false, coll_op.ok());
|
||||
ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str());
|
||||
|
||||
// let's try again properly
|
||||
coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age");
|
||||
coll_array_fields = coll_op.get();
|
||||
}
|
||||
|
||||
std::string json_line;
|
||||
|
||||
while (std::getline(infile, json_line)) {
|
||||
auto add_op = coll_array_fields->add(json_line);
|
||||
LOG(INFO) << add_op.error();
|
||||
ASSERT_TRUE(add_op.ok());
|
||||
}
|
||||
|
||||
infile.close();
|
||||
|
||||
// Plain search with no filters - results should be sorted by rank fields
|
||||
query_fields = {"name"};
|
||||
std::vector<std::string> facets;
|
||||
nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
std::vector<std::string> ids = {"3", "1", "4", "0", "2"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// Searching on an int32 field
|
||||
results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
|
||||
ASSERT_EQ(3, results["hits"].size());
|
||||
|
||||
ids = {"3", "1", "4"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user