mirror of
https://github.com/typesense/typesense.git
synced 2025-05-18 04:32:38 +08:00
Allow int64 to be used as a default sorting field.
This commit is contained in:
parent
14faa3af4e
commit
0c186481a9
4
TODO.md
4
TODO.md
@ -1,10 +1,8 @@
|
||||
# Typesense: TODO
|
||||
|
||||
## Pre-alpha
|
||||
|
||||
a) ~~Fix memory ratio (decreasing with indexing)~~
|
||||
b) ~~Speed up wildcard searches further~~
|
||||
c) Allow int64 in default sorting field
|
||||
c) ~~Allow int64 in default sorting field~~
|
||||
d) Use connection timeout for CURL rather than request timeout
|
||||
e) Update role to set max memory ration at 0.80
|
||||
f) Async import
|
||||
|
@ -44,8 +44,6 @@ typedef struct {
|
||||
uint8_t num_children;
|
||||
uint8_t partial_len;
|
||||
unsigned char partial[MAX_PREFIX_LEN];
|
||||
int32_t max_score;
|
||||
uint32_t max_token_count;
|
||||
} art_node;
|
||||
|
||||
/**
|
||||
|
33
src/art.cpp
33
src/art.cpp
@ -53,15 +53,11 @@ bool compare_art_node_frequency(const art_node *a, const art_node *b) {
|
||||
if(IS_LEAF(a)) {
|
||||
art_leaf* al = (art_leaf *) LEAF_RAW(a);
|
||||
a_value = al->values->ids.getLength();
|
||||
} else {
|
||||
a_value = a->max_token_count;
|
||||
}
|
||||
|
||||
if(IS_LEAF(b)) {
|
||||
art_leaf* bl = (art_leaf *) LEAF_RAW(b);
|
||||
b_value = bl->values->ids.getLength();
|
||||
} else {
|
||||
b_value = b->max_token_count;
|
||||
}
|
||||
|
||||
return a_value > b_value;
|
||||
@ -73,15 +69,11 @@ bool compare_art_node_score(const art_node* a, const art_node* b) {
|
||||
if(IS_LEAF(a)) {
|
||||
art_leaf* al = (art_leaf *) LEAF_RAW(a);
|
||||
a_value = al->max_score;
|
||||
} else {
|
||||
a_value = a->max_score;
|
||||
}
|
||||
|
||||
if(IS_LEAF(b)) {
|
||||
art_leaf* bl = (art_leaf *) LEAF_RAW(b);
|
||||
b_value = bl->max_score;
|
||||
} else {
|
||||
b_value = b->max_score;
|
||||
}
|
||||
|
||||
return a_value > b_value;
|
||||
@ -118,8 +110,6 @@ static art_node* alloc_node(uint8_t type) {
|
||||
abort();
|
||||
}
|
||||
n->type = type;
|
||||
n->max_score = 0;
|
||||
n->max_token_count = 0;
|
||||
return n;
|
||||
}
|
||||
|
||||
@ -448,8 +438,6 @@ static uint32_t longest_common_prefix(art_leaf *l1, art_leaf *l2, int depth) {
|
||||
}
|
||||
|
||||
static void copy_header(art_node *dest, art_node *src) {
|
||||
dest->max_score = src->max_score;
|
||||
dest->max_token_count = src->max_token_count;
|
||||
dest->num_children = src->num_children;
|
||||
dest->partial_len = src->partial_len;
|
||||
memcpy(dest->partial, src->partial, min(MAX_PREFIX_LEN, src->partial_len));
|
||||
@ -457,8 +445,6 @@ static void copy_header(art_node *dest, art_node *src) {
|
||||
|
||||
static void add_child256(art_node256 *n, art_node **ref, unsigned char c, void *child) {
|
||||
(void)ref;
|
||||
n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score);
|
||||
n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength());
|
||||
n->n.num_children++;
|
||||
n->children[c] = (art_node *) child;
|
||||
}
|
||||
@ -467,8 +453,6 @@ static void add_child48(art_node48 *n, art_node **ref, unsigned char c, void *ch
|
||||
if (n->n.num_children < 48) {
|
||||
int pos = 0;
|
||||
while (n->children[pos]) pos++;
|
||||
n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score);
|
||||
n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength());
|
||||
n->children[pos] = (art_node *) child;
|
||||
n->keys[c] = pos + 1;
|
||||
n->n.num_children++;
|
||||
@ -509,8 +493,6 @@ static void add_child16(art_node16 *n, art_node **ref, unsigned char c, void *ch
|
||||
idx = n->n.num_children;
|
||||
|
||||
// Set the child
|
||||
n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score);
|
||||
n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength());
|
||||
n->keys[idx] = c;
|
||||
n->children[idx] = (art_node *) child;
|
||||
n->n.num_children++;
|
||||
@ -543,12 +525,6 @@ static void add_child4(art_node4 *n, art_node **ref, unsigned char c, void *chil
|
||||
memmove(n->children+idx+1, n->children+idx,
|
||||
(n->n.num_children - idx)*sizeof(void*));
|
||||
|
||||
int32_t child_max_score = IS_LEAF(child) ? ((art_leaf *) LEAF_RAW(child))->max_score : ((art_node *) child)->max_score;
|
||||
uint32_t child_token_count = IS_LEAF(child) ? ((art_leaf *) LEAF_RAW(child))->values->ids.getLength() : ((art_node *) child)->max_token_count;
|
||||
|
||||
n->n.max_score = MAX(n->n.max_score, child_max_score);
|
||||
n->n.max_token_count = MAX(n->n.max_token_count, child_token_count);
|
||||
|
||||
n->keys[idx] = c;
|
||||
n->children[idx] = (art_node *) child;
|
||||
n->n.num_children++;
|
||||
@ -648,9 +624,6 @@ static void* recursive_insert(art_node *n, art_node **ref, const unsigned char *
|
||||
return NULL;
|
||||
}
|
||||
|
||||
n->max_score = MAX(n->max_score, document->score);
|
||||
n->max_token_count = MAX(n->max_token_count, num_hits);
|
||||
|
||||
// Check if given node has a prefix
|
||||
if (n->partial_len) {
|
||||
// Determine if the prefixes differ, since we need to split
|
||||
@ -1384,12 +1357,6 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len,
|
||||
art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes);
|
||||
}
|
||||
|
||||
if(token_order == FREQUENCY) {
|
||||
std::sort(nodes.begin(), nodes.end(), compare_art_node_frequency);
|
||||
} else {
|
||||
std::sort(nodes.begin(), nodes.end(), compare_art_node_score);
|
||||
}
|
||||
|
||||
//long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count();
|
||||
//!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size();
|
||||
|
||||
|
@ -246,9 +246,10 @@ Option<Collection*> CollectionManager::create_collection(const std::string name,
|
||||
fields_json.push_back(field_val);
|
||||
|
||||
if(field.name == default_sorting_field && !(field.type == field_types::INT32 ||
|
||||
field.type == field_types::INT64 ||
|
||||
field.type == field_types::FLOAT)) {
|
||||
return Option<Collection*>(400, "Default sorting field `" + default_sorting_field +
|
||||
"` must be of type int32 or float.");
|
||||
"` must be a single valued numerical field.");
|
||||
}
|
||||
|
||||
if(field.name == default_sorting_field) {
|
||||
|
@ -119,7 +119,7 @@ TEST_F(CollectionSortingTest, SortingOrder) {
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
|
||||
// Default sorting field must be int32 or float
|
||||
// Default sorting field must be a numerical field
|
||||
std::vector<field> fields = {field("name", field_types::STRING, false),
|
||||
field("tags", field_types::STRING_ARRAY, true),
|
||||
field("age", field_types::INT32, false),
|
||||
@ -129,16 +129,11 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
|
||||
|
||||
Option<Collection*> collection_op = collectionManager.create_collection("sample_collection", fields, "name");
|
||||
EXPECT_FALSE(collection_op.ok());
|
||||
EXPECT_EQ("Default sorting field `name` must be of type int32 or float.", collection_op.error());
|
||||
EXPECT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error());
|
||||
collectionManager.drop_collection("sample_collection");
|
||||
|
||||
// Default sorting field must exist as a field in schema
|
||||
|
||||
fields = {field("name", field_types::STRING, false),
|
||||
field("tags", field_types::STRING_ARRAY, true),
|
||||
field("age", field_types::INT32, false),
|
||||
field("average", field_types::INT32, false) };
|
||||
|
||||
sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
|
||||
collection_op = collectionManager.create_collection("sample_collection", fields, "NOT-DEFINED");
|
||||
EXPECT_FALSE(collection_op.ok());
|
||||
@ -146,6 +141,58 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
|
||||
collectionManager.drop_collection("sample_collection");
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, Int64AsDefaultSortingField) {
|
||||
Collection *coll_mul_fields;
|
||||
|
||||
std::ifstream infile(std::string(ROOT_DIR)+"test/multi_field_documents.jsonl");
|
||||
std::vector<field> fields = {field("title", field_types::STRING, false),
|
||||
field("starring", field_types::STRING, false),
|
||||
field("points", field_types::INT64, false),
|
||||
field("cast", field_types::STRING_ARRAY, false)};
|
||||
|
||||
coll_mul_fields = collectionManager.get_collection("coll_mul_fields");
|
||||
if(coll_mul_fields == nullptr) {
|
||||
coll_mul_fields = collectionManager.create_collection("coll_mul_fields", fields, "points").get();
|
||||
}
|
||||
|
||||
std::string json_line;
|
||||
|
||||
while (std::getline(infile, json_line)) {
|
||||
coll_mul_fields->add(json_line);
|
||||
}
|
||||
|
||||
infile.close();
|
||||
|
||||
query_fields = {"title"};
|
||||
std::vector<std::string> facets;
|
||||
sort_fields = { sort_by("points", "ASC") };
|
||||
nlohmann::json results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(10, results["hits"].size());
|
||||
|
||||
std::vector<std::string> ids = {"17", "13", "10", "4", "0", "1", "8", "6", "16", "11"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
|
||||
// limiting results to just 5, "ASC" keyword must be case insensitive
|
||||
sort_fields = { sort_by("points", "asc") };
|
||||
results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 5, 1, FREQUENCY, false).get();
|
||||
ASSERT_EQ(5, results["hits"].size());
|
||||
|
||||
ids = {"17", "13", "10", "4", "0"};
|
||||
|
||||
for(size_t i = 0; i < results["hits"].size(); i++) {
|
||||
nlohmann::json result = results["hits"].at(i);
|
||||
std::string result_id = result["document"]["id"];
|
||||
std::string id = ids.at(i);
|
||||
ASSERT_STREQ(id.c_str(), result_id.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(CollectionSortingTest, SortOnFloatFields) {
|
||||
Collection *coll_float_fields;
|
||||
|
||||
|
@ -986,7 +986,7 @@ TEST_F(CollectionTest, FilterOnNumericFields) {
|
||||
// ensure that default_sorting_field is a non-array numerical field
|
||||
auto coll_op = collectionManager.create_collection("coll_array_fields", fields, "years");
|
||||
ASSERT_EQ(false, coll_op.ok());
|
||||
ASSERT_STREQ("Default sorting field `years` must be of type int32 or float.", coll_op.error().c_str());
|
||||
ASSERT_STREQ("Default sorting field `years` must be a single valued numerical field.", coll_op.error().c_str());
|
||||
|
||||
// let's try again properly
|
||||
coll_op = collectionManager.create_collection("coll_array_fields", fields, "age");
|
||||
|
Loading…
x
Reference in New Issue
Block a user