diff --git a/include/field.h b/include/field.h
index cbd62af2..a0eca2af 100644
--- a/include/field.h
+++ b/include/field.h
@@ -54,6 +54,7 @@ namespace fields {
     static const std::string from = "from";
     static const std::string embed_from = "embed_from";
     static const std::string model_name = "model_name";
+    static const std::string range_index = "range_index"; // schema key; json_field_to_field defaults it to false
 
     // Some models require additional parameters to be passed to the model during indexing/querying
     // For e.g. e5-small model requires prefix "passage:" for indexing and "query:" for querying
@@ -93,13 +94,17 @@ struct field {
 
     std::string reference; // Foo.bar (reference to bar field in Foo collection).
 
+    bool range_index = false; // default-initialized so that `field() {}` does not leave it indeterminate
+
     field() {}
 
     field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
           bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
-          int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine, std::string reference = "", const nlohmann::json& embed = nlohmann::json()) :
+          int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
+          std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
             name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
-            nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference), embed(embed) {
+            nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
+            embed(embed), range_index(range_index) {
 
         set_computed_defaults(sort, infix);
     }
diff --git a/include/index.h b/include/index.h
index 4570b77c..0c89463d 100644
--- a/include/index.h
+++ b/include/index.h
@@ -30,6 +30,7 @@
 #include "vector_query_ops.h"
 #include "hnswlib/hnswlib.h"
 #include "filter.h"
+#include "numeric_range_trie_test.h" // NOTE(review): presumably declares NumericTrie; a *_test.h name in a production header looks unintended - confirm
 
 static constexpr size_t ARRAY_FACET_DIM = 4;
 using facet_map_t = spp::sparse_hash_map;
@@ -305,6 +306,8 @@ private:
 
     spp::sparse_hash_map numerical_index;
 
+    spp::sparse_hash_map range_index; // field name => trie; entries are created in Index::Index and deleted in ~Index
+
     spp::sparse_hash_map>*> geopoint_index;
 
     // geo_array_field => (seq_id => values) used for exact filtering of geo array records
diff --git a/src/field.cpp b/src/field.cpp
index f48ecf50..3882b45f 100644
--- a/src/field.cpp
+++ b/src/field.cpp
@@ -75,6 +75,23 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
         field_json[fields::reference] = "";
     }
 
+    if (field_json.count(fields::range_index) != 0) { // validate `range_index` only when the caller supplied it
+        if (!field_json.at(fields::range_index).is_boolean()) {
+            return Option(400, std::string("The `range_index` property of the field `") +
+                               field_json[fields::name].get() +
+                               std::string("` should be a boolean."));
+        }
+
+        auto const& type = field_json["type"];
+        if (type != field_types::INT32 && type != field_types::INT32_ARRAY &&
+            type != field_types::INT64 && type != field_types::INT64_ARRAY &&
+            type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) {
+            return Option(400, std::string("The `range_index` property is only allowed for the numerical fields."));
+        }
+    } else {
+        field_json[fields::range_index] = false;
+    }
+
     if(field_json["name"] == ".*") {
         if(field_json.count(fields::facet) == 0) {
             field_json[fields::facet] = false;
@@ -297,7 +314,7 @@ Option field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
                   field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
                   field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
                   field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
-                  field_json[fields::reference], field_json[fields::embed])
+                  field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
     );
 
     if (!field_json[fields::reference].get().empty()) {
diff --git a/src/filter_result_iterator.cpp b/src/filter_result_iterator.cpp
index a6b61134..8a3c6d89 100644
--- a/src/filter_result_iterator.cpp
+++ b/src/filter_result_iterator.cpp
@@ -646,27 +646,62 @@ void filter_result_iterator_t::init() {
     field f = index->search_schema.at(a_filter.field_name);
 
     if (f.is_integer()) {
-        auto num_tree = index->numerical_index.at(a_filter.field_name);
+        if (f.is_int32() && f.range_index) { // int32 fields with a range index are answered from the trie
+            auto const& trie = index->range_index.at(a_filter.field_name);
 
-        for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
-            const std::string& filter_value = a_filter.values[fi];
-            int64_t value = (int64_t)std::stol(filter_value);
+            for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
+                const std::string& filter_value = a_filter.values[fi];
+                const auto value = (int32_t)std::stoi(filter_value); // std::stoi throws on bad input; assumes upstream validation - TODO confirm
 
-            size_t result_size = filter_result.count;
-            if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
-                const std::string& next_filter_value = a_filter.values[fi + 1];
-                auto const range_end_value = (int64_t)std::stol(next_filter_value);
-                num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
-                fi++;
-            } else if (a_filter.comparators[fi] == NOT_EQUALS) {
-                numeric_not_equals_filter(num_tree, value,
-                                          index->seq_ids->uncompress(), index->seq_ids->num_ids(),
-                                          filter_result.docs, result_size);
-            } else {
-                num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
+                if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
+                    const std::string& next_filter_value = a_filter.values[fi + 1];
+                    const auto range_end_value = (int32_t)std::stoi(next_filter_value);
+                    trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count);
+                    fi++; // the range end consumed the next value as well
+                } else if (a_filter.comparators[fi] == EQUALS) {
+                    trie->search_equal_to(value, filter_result.docs, filter_result.count);
+                } else if (a_filter.comparators[fi] == NOT_EQUALS) {
+                    uint32_t* to_exclude_ids = nullptr;
+                    uint32_t to_exclude_ids_len = 0;
+                    trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len);
+                    // NOTE(review): the call below appears to overwrite filter_result.docs; confirm ids from an earlier filter value are not lost or leaked.
+                    auto all_ids = index->seq_ids->uncompress();
+                    filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
+                                                                     to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
+
+                    delete[] all_ids;
+                    delete[] to_exclude_ids;
+                } else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
+                    trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
+                                              filter_result.docs, filter_result.count);
+                } else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
+                    trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS,
+                                           filter_result.docs, filter_result.count);
+                }
             }
+        } else {
+            auto num_tree = index->numerical_index.at(a_filter.field_name);
 
-            filter_result.count = result_size;
+            for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
+                const std::string& filter_value = a_filter.values[fi];
+                int64_t value = (int64_t)std::stol(filter_value);
+
+                size_t result_size = filter_result.count;
+                if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
+                    const std::string& next_filter_value = a_filter.values[fi + 1];
+                    auto const range_end_value = (int64_t)std::stol(next_filter_value);
+                    num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
+                    fi++;
+                } else if (a_filter.comparators[fi] == NOT_EQUALS) {
+                    numeric_not_equals_filter(num_tree, value,
+                                              index->seq_ids->uncompress(), index->seq_ids->num_ids(),
+                                              filter_result.docs, result_size);
+                } else {
+                    num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
+                }
+
+                filter_result.count = result_size;
+            }
         }
 
         if (a_filter.apply_not_equals) {
diff --git a/src/index.cpp b/src/index.cpp
index a55646bc..cc2510ec 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -88,6 +88,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
         } else {
             num_tree_t* num_tree = new num_tree_t;
             numerical_index.emplace(a_field.name, num_tree);
+
+            if (a_field.range_index) { // the trie is built in addition to the num_tree, not instead of it
+                auto trie = new NumericTrie();
+                range_index.emplace(a_field.name, trie);
+            }
         }
 
         if(a_field.sort) {
@@ -161,6 +166,13 @@ Index::~Index() {
 
     numerical_index.clear();
 
+    for(auto & name_tree: range_index) {
+        delete name_tree.second;
+        name_tree.second = nullptr;
+    }
+
+    range_index.clear();
+
     for(auto & name_map: sort_index) {
         delete name_map.second;
         name_map.second = nullptr;
@@ -737,6 +749,15 @@ void Index::index_field_in_memory(const field& afield, std::vector
 
     if(!afield.is_string()) {
         if (afield.type == field_types::INT32) {
+            if (afield.range_index) {
+                auto const& trie = range_index.at(afield.name);
+                iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
+                        (const index_record& record, uint32_t seq_id) {
+                    int32_t value = record.doc[afield.name].get();
+                    trie->insert(value, seq_id);
+                });
+            }
+
             auto num_tree = numerical_index.at(afield.name);
             iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
                     (const index_record& record, uint32_t seq_id) {
@@ -899,13 +920,19 @@ void Index::index_field_in_memory(const field& afield, std::vector
 
             // all other numerical arrays
             auto num_tree = numerical_index.at(afield.name);
-            iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
+            auto trie = range_index.count(afield.name) > 0 ? range_index.at(afield.name) : nullptr;
+            iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie]
                     (const index_record& record, uint32_t seq_id) {
                 for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) {
                     const auto& arr_value = record.doc[afield.name][arr_i];
 
                     if(afield.type == field_types::INT32_ARRAY) {
                         const int32_t value = arr_value;
+
+                        if (trie != nullptr) { // guard the pointer that is dereferenced; it is null when no trie exists for this field
+                            trie->insert(value, seq_id);
+                        }
+
                         num_tree->insert(value, seq_id);
                     }
 
diff --git a/test/numeric_range_trie_test.cpp b/test/numeric_range_trie_test.cpp
index b8ba9186..5d9cca7d 100644
--- a/test/numeric_range_trie_test.cpp
+++ b/test/numeric_range_trie_test.cpp
@@ -1,12 +1,35 @@
 #include
+#include
+#include "collection.h"
 #include "numeric_range_trie_test.h"
 
 class NumericRangeTrieTest : public ::testing::Test {
 protected:
+    Store *store;
+    CollectionManager & collectionManager = CollectionManager::get_instance();
+    std::atomic quit = false;
 
-    virtual void SetUp() {}
+    std::vector query_fields;
+    std::vector sort_fields;
 
-    virtual void TearDown() {}
+    void setupCollection() {
+        std::string state_dir_path = "/tmp/typesense_test/collection_filtering";
+        LOG(INFO) << "Truncating and creating: " << state_dir_path;
+        system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
+
+        store = new Store(state_dir_path);
+        collectionManager.init(store, 1.0, "auth_key", quit);
+        collectionManager.load(8, 1000);
+    }
+
+    virtual void SetUp() {
+        setupCollection();
+    }
+
+    virtual void TearDown() {
+        collectionManager.dispose();
+        delete store;
+    }
 };
 
 void reset(uint32_t*& ids, uint32_t& ids_length) {
@@ -570,3 +593,70 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) {
 
     ASSERT_EQ(0, ids_length);
 }
+
+TEST_F(NumericRangeTrieTest, Integration) {
+    Collection *coll_array_fields;
+
+    std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
+    std::vector fields = {
+        field("name", field_types::STRING, false),
+        field("rating", field_types::FLOAT, false),
+        field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
+              true), // Setting range index true.
+        field("years", field_types::INT32_ARRAY, false),
+        field("timestamps", field_types::INT64_ARRAY, false),
+        field("tags", field_types::STRING_ARRAY, true)
+    };
+
+    sort_fields = { sort_by("age", "DESC") }; // assign the fixture member instead of shadowing it with a local
+
+    coll_array_fields = collectionManager.get_collection("coll_array_fields").get();
+    if(coll_array_fields == nullptr) {
+        // ensure that default_sorting_field is a non-array numerical field
+        auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years");
+        ASSERT_EQ(false, coll_op.ok());
+        ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str());
+
+        // let's try again properly
+        coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age");
+        coll_array_fields = coll_op.get();
+    }
+
+    std::string json_line;
+
+    while (std::getline(infile, json_line)) {
+        auto add_op = coll_array_fields->add(json_line);
+        LOG(INFO) << add_op.error();
+        ASSERT_TRUE(add_op.ok());
+    }
+
+    infile.close();
+
+    // Plain search with no filters - results should be sorted by rank fields
+    query_fields = {"name"};
+    std::vector facets;
+    nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
+    ASSERT_EQ(5, results["hits"].size());
+
+    std::vector ids = {"3", "1", "4", "0", "2"};
+
+    for(size_t i = 0; i < results["hits"].size(); i++) {
+        nlohmann::json result = results["hits"].at(i);
+        std::string result_id = result["document"]["id"];
+        std::string id = ids.at(i);
+        ASSERT_STREQ(id.c_str(), result_id.c_str());
+    }
+
+    // Searching on an int32 field
+    results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
+    ASSERT_EQ(3, results["hits"].size());
+
+    ids = {"3", "1", "4"};
+
+    for(size_t i = 0; i < results["hits"].size(); i++) {
+        nlohmann::json result = results["hits"].at(i);
+        std::string result_id = result["document"]["id"];
+        std::string id = ids.at(i);
+        ASSERT_STREQ(id.c_str(), result_id.c_str());
+    }
+}