Add range_index property.

This commit is contained in:
Harpreet Sangar 2023-06-01 16:46:04 +05:30
parent 1f4643cba0
commit b4a70682c6
6 changed files with 200 additions and 23 deletions

View File

@ -54,6 +54,7 @@ namespace fields {
static const std::string from = "from";
static const std::string embed_from = "embed_from";
static const std::string model_name = "model_name";
static const std::string range_index = "range_index";
// Some models require additional parameters to be passed to the model during indexing/querying.
// E.g. the e5-small model requires the prefix "passage:" for indexing and "query:" for querying.
@ -93,13 +94,17 @@ struct field {
std::string reference; // Foo.bar (reference to bar field in Foo collection).
// Whether a range index is requested for this field. Defaults to false so that a
// default-constructed `field` never carries an indeterminate value.
bool range_index = false;

field() {}

/**
 * Builds a field definition from explicit property values.
 *
 * Every argument is copied into the member of the same name, except `sort` and `infix`,
 * which are handed to set_computed_defaults() instead of being stored directly.
 *
 * @param range_index  Trailing, defaulted flag (false) so that existing call sites that stop
 *                     at `embed` or earlier keep compiling unchanged.
 */
field(const std::string &name, const std::string &type, const bool facet, const bool optional = false,
      bool index = true, std::string locale = "", int sort = -1, int infix = -1, bool nested = false,
      int nested_array = 0, size_t num_dim = 0, vector_distance_type_t vec_dist = cosine,
      std::string reference = "", const nlohmann::json& embed = nlohmann::json(), const bool range_index = false) :
        name(name), type(type), facet(facet), optional(optional), index(index), locale(locale),
        nested(nested), nested_array(nested_array), num_dim(num_dim), vec_dist(vec_dist), reference(reference),
        embed(embed), range_index(range_index) {

    set_computed_defaults(sort, infix);
}

View File

@ -30,6 +30,7 @@
#include "vector_query_ops.h"
#include "hnswlib/hnswlib.h"
#include "filter.h"
#include "numeric_range_trie_test.h"
static constexpr size_t ARRAY_FACET_DIM = 4;
using facet_map_t = spp::sparse_hash_map<uint32_t, facet_hash_values_t>;
@ -305,6 +306,8 @@ private:
spp::sparse_hash_map<std::string, num_tree_t*> numerical_index;
spp::sparse_hash_map<std::string, NumericTrie*> range_index;
spp::sparse_hash_map<std::string, spp::sparse_hash_map<std::string, std::vector<uint32_t>>*> geopoint_index;
// geo_array_field => (seq_id => values) used for exact filtering of geo array records

View File

@ -75,6 +75,23 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
field_json[fields::reference] = "";
}
// Validate the optional `range_index` property and default it to `false` when absent.
if (field_json.count(fields::range_index) != 0) {
    if (!field_json.at(fields::range_index).is_boolean()) {
        return Option<bool>(400, std::string("The `range_index` property of the field `") +
                                 field_json[fields::name].get<std::string>() +
                                 std::string("` should be a boolean."));
    }

    // Only an *enabled* range index restricts the field type. An explicit `false` is equivalent to
    // the default written in the else-branch below, so it must be accepted for every field type.
    const bool range_index_enabled = field_json.at(fields::range_index).get<bool>();

    auto const& type = field_json["type"];
    if (range_index_enabled &&
        type != field_types::INT32 && type != field_types::INT32_ARRAY &&
        type != field_types::INT64 && type != field_types::INT64_ARRAY &&
        type != field_types::FLOAT && type != field_types::FLOAT_ARRAY) {
        return Option<bool>(400, std::string("The `range_index` property is only allowed for the numerical fields."));
    }
} else {
    field_json[fields::range_index] = false;
}
if(field_json["name"] == ".*") {
if(field_json.count(fields::facet) == 0) {
field_json[fields::facet] = false;
@ -297,7 +314,7 @@ Option<bool> field::json_field_to_field(bool enable_nested_fields, nlohmann::jso
field_json[fields::optional], field_json[fields::index], field_json[fields::locale],
field_json[fields::sort], field_json[fields::infix], field_json[fields::nested],
field_json[fields::nested_array], field_json[fields::num_dim], vec_dist,
field_json[fields::reference], field_json[fields::embed])
field_json[fields::reference], field_json[fields::embed], field_json[fields::range_index])
);
if (!field_json[fields::reference].get<std::string>().empty()) {

View File

@ -646,27 +646,62 @@ void filter_result_iterator_t::init() {
field f = index->search_schema.at(a_filter.field_name);
if (f.is_integer()) {
auto num_tree = index->numerical_index.at(a_filter.field_name);
if (f.is_int32() && f.range_index) {
auto const& trie = index->range_index.at(a_filter.field_name);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
int64_t value = (int64_t)std::stol(filter_value);
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
auto const& value = (int32_t)std::stoi(filter_value);
size_t result_size = filter_result.count;
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = (int64_t)std::stol(next_filter_value);
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value,
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
filter_result.docs, result_size);
} else {
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const& range_end_value = (int32_t)std::stoi(next_filter_value);
trie->search_range(value, true, range_end_value, true, filter_result.docs, filter_result.count);
fi++;
} else if (a_filter.comparators[fi] == EQUALS) {
trie->search_equal_to(value, filter_result.docs, filter_result.count);
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
uint32_t* to_exclude_ids = nullptr;
uint32_t to_exclude_ids_len = 0;
trie->search_equal_to(value, to_exclude_ids, to_exclude_ids_len);
auto all_ids = index->seq_ids->uncompress();
filter_result.count = ArrayUtils::exclude_scalar(all_ids, index->seq_ids->num_ids(),
to_exclude_ids, to_exclude_ids_len, &filter_result.docs);
delete[] all_ids;
delete[] to_exclude_ids;
} else if (a_filter.comparators[fi] == GREATER_THAN || a_filter.comparators[fi] == GREATER_THAN_EQUALS) {
trie->search_greater_than(value, a_filter.comparators[fi] == GREATER_THAN_EQUALS,
filter_result.docs, filter_result.count);
} else if (a_filter.comparators[fi] == LESS_THAN || a_filter.comparators[fi] == LESS_THAN_EQUALS) {
trie->search_less_than(value, a_filter.comparators[fi] == LESS_THAN_EQUALS,
filter_result.docs, filter_result.count);
}
}
} else {
auto num_tree = index->numerical_index.at(a_filter.field_name);
filter_result.count = result_size;
for (size_t fi = 0; fi < a_filter.values.size(); fi++) {
const std::string& filter_value = a_filter.values[fi];
int64_t value = (int64_t)std::stol(filter_value);
size_t result_size = filter_result.count;
if (a_filter.comparators[fi] == RANGE_INCLUSIVE && fi+1 < a_filter.values.size()) {
const std::string& next_filter_value = a_filter.values[fi + 1];
auto const range_end_value = (int64_t)std::stol(next_filter_value);
num_tree->range_inclusive_search(value, range_end_value, &filter_result.docs, result_size);
fi++;
} else if (a_filter.comparators[fi] == NOT_EQUALS) {
numeric_not_equals_filter(num_tree, value,
index->seq_ids->uncompress(), index->seq_ids->num_ids(),
filter_result.docs, result_size);
} else {
num_tree->search(a_filter.comparators[fi], value, &filter_result.docs, result_size);
}
filter_result.count = result_size;
}
}
if (a_filter.apply_not_equals) {

View File

@ -88,6 +88,11 @@ Index::Index(const std::string& name, const uint32_t collection_id, const Store*
} else {
num_tree_t* num_tree = new num_tree_t;
numerical_index.emplace(a_field.name, num_tree);
if (a_field.range_index) {
auto trie = new NumericTrie();
range_index.emplace(a_field.name, trie);
}
}
if(a_field.sort) {
@ -161,6 +166,13 @@ Index::~Index() {
numerical_index.clear();
for(auto & name_tree: range_index) {
delete name_tree.second;
name_tree.second = nullptr;
}
range_index.clear();
for(auto & name_map: sort_index) {
delete name_map.second;
name_map.second = nullptr;
@ -737,6 +749,15 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
if(!afield.is_string()) {
if (afield.type == field_types::INT32) {
if (afield.range_index) {
auto const& trie = range_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, trie]
(const index_record& record, uint32_t seq_id) {
int32_t value = record.doc[afield.name].get<int32_t>();
trie->insert(value, seq_id);
});
}
auto num_tree = numerical_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
(const index_record& record, uint32_t seq_id) {
@ -899,13 +920,19 @@ void Index::index_field_in_memory(const field& afield, std::vector<index_record>
// all other numerical arrays
auto num_tree = numerical_index.at(afield.name);
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree]
auto trie = range_index.count(afield.name) > 0 ? range_index.at(afield.name) : nullptr;
iterate_and_index_numerical_field(iter_batch, afield, [&afield, num_tree, trie]
(const index_record& record, uint32_t seq_id) {
for(size_t arr_i = 0; arr_i < record.doc[afield.name].size(); arr_i++) {
const auto& arr_value = record.doc[afield.name][arr_i];
if(afield.type == field_types::INT32_ARRAY) {
const int32_t value = arr_value;
if (afield.range_index) {
trie->insert(value, seq_id);
}
num_tree->insert(value, seq_id);
}

View File

@ -1,12 +1,35 @@
#include <gtest/gtest.h>
#include <collection_manager.h>
#include "collection.h"
#include "numeric_range_trie_test.h"
class NumericRangeTrieTest : public ::testing::Test {
protected:
Store *store;
CollectionManager & collectionManager = CollectionManager::get_instance();
std::atomic<bool> quit = false;
virtual void SetUp() {}
std::vector<std::string> query_fields;
std::vector<sort_by> sort_fields;
virtual void TearDown() {}
void setupCollection() {
std::string state_dir_path = "/tmp/typesense_test/collection_filtering";
LOG(INFO) << "Truncating and creating: " << state_dir_path;
system(("rm -rf "+state_dir_path+" && mkdir -p "+state_dir_path).c_str());
store = new Store(state_dir_path);
collectionManager.init(store, 1.0, "auth_key", quit);
collectionManager.load(8, 1000);
}
virtual void SetUp() {
setupCollection();
}
virtual void TearDown() {
collectionManager.dispose();
delete store;
}
};
void reset(uint32_t*& ids, uint32_t& ids_length) {
@ -570,3 +593,70 @@ TEST_F(NumericRangeTrieTest, EmptyTrieOperations) {
ASSERT_EQ(0, ids_length);
}
// End-to-end check: documents are indexed into a collection whose int32 `age` field is created
// with `range_index = true`, then searched without a filter and with an `age:>24` filter.
TEST_F(NumericRangeTrieTest, Integration) {
    Collection *coll_array_fields;

    // Fail fast with a clear message if the fixture file is missing; otherwise no documents
    // would be indexed and the failure would surface later as a confusing hit-count mismatch.
    std::ifstream infile(std::string(ROOT_DIR)+"test/numeric_array_documents.jsonl");
    ASSERT_TRUE(infile.is_open()) << "Unable to open test/numeric_array_documents.jsonl";

    std::vector<field> fields = {
            field("name", field_types::STRING, false),
            field("rating", field_types::FLOAT, false),
            field("age", field_types::INT32, false, false, true, "", -1, -1, false, 0, 0, cosine, "", nlohmann::json(),
                  true), // Setting range index true.
            field("years", field_types::INT32_ARRAY, false),
            field("timestamps", field_types::INT64_ARRAY, false),
            field("tags", field_types::STRING_ARRAY, true)
    };

    // Use the fixture member instead of shadowing it with a local of the same name.
    sort_fields = { sort_by("age", "DESC") };

    coll_array_fields = collectionManager.get_collection("coll_array_fields").get();
    if(coll_array_fields == nullptr) {
        // ensure that default_sorting_field is a non-array numerical field
        auto coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "years");
        ASSERT_EQ(false, coll_op.ok());
        ASSERT_STREQ("Default sorting field `years` is not a sortable type.", coll_op.error().c_str());

        // let's try again properly
        coll_op = collectionManager.create_collection("coll_array_fields", 4, fields, "age");
        ASSERT_TRUE(coll_op.ok()) << coll_op.error();
        coll_array_fields = coll_op.get();
    }

    std::string json_line;

    while (std::getline(infile, json_line)) {
        auto add_op = coll_array_fields->add(json_line);
        // The error text is only emitted when the insert actually failed.
        ASSERT_TRUE(add_op.ok()) << add_op.error();
    }

    infile.close();

    // Plain search with no filters - results should be sorted by rank fields
    query_fields = {"name"};
    std::vector<std::string> facets;
    nlohmann::json results = coll_array_fields->search("Jeremy", query_fields, "", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(5, results["hits"].size());

    std::vector<std::string> ids = {"3", "1", "4", "0", "2"};

    for(size_t i = 0; i < results["hits"].size(); i++) {
        nlohmann::json result = results["hits"].at(i);
        std::string result_id = result["document"]["id"];
        std::string id = ids.at(i);
        ASSERT_STREQ(id.c_str(), result_id.c_str());
    }

    // Searching on an int32 field
    results = coll_array_fields->search("Jeremy", query_fields, "age:>24", facets, sort_fields, {0}, 10, 1, FREQUENCY, {false}).get();
    ASSERT_EQ(3, results["hits"].size());

    ids = {"3", "1", "4"};

    for(size_t i = 0; i < results["hits"].size(); i++) {
        nlohmann::json result = results["hits"].at(i);
        std::string result_id = result["document"]["id"];
        std::string id = ids.at(i);
        ASSERT_STREQ(id.c_str(), result_id.c_str());
    }
}