Allow int64 to be used as a default sorting field.

This commit is contained in:
kishorenc 2020-08-12 18:16:27 +05:30
parent 14faa3af4e
commit 0c186481a9
6 changed files with 58 additions and 47 deletions

View File

@ -1,10 +1,8 @@
# Typesense: TODO
## Pre-alpha
a) ~~Fix memory ratio (decreasing with indexing)~~
b) ~~Speed up wildcard searches further~~
c) Allow int64 in default sorting field
c) ~~Allow int64 in default sorting field~~
d) Use connection timeout for CURL rather than request timeout
e) Update role to set max memory ratio at 0.80
f) Async import

View File

@ -44,8 +44,6 @@ typedef struct {
uint8_t num_children;
uint8_t partial_len;
unsigned char partial[MAX_PREFIX_LEN];
int32_t max_score;
uint32_t max_token_count;
} art_node;
/**

View File

@ -53,15 +53,11 @@ bool compare_art_node_frequency(const art_node *a, const art_node *b) {
if(IS_LEAF(a)) {
art_leaf* al = (art_leaf *) LEAF_RAW(a);
a_value = al->values->ids.getLength();
} else {
a_value = a->max_token_count;
}
if(IS_LEAF(b)) {
art_leaf* bl = (art_leaf *) LEAF_RAW(b);
b_value = bl->values->ids.getLength();
} else {
b_value = b->max_token_count;
}
return a_value > b_value;
@ -73,15 +69,11 @@ bool compare_art_node_score(const art_node* a, const art_node* b) {
if(IS_LEAF(a)) {
art_leaf* al = (art_leaf *) LEAF_RAW(a);
a_value = al->max_score;
} else {
a_value = a->max_score;
}
if(IS_LEAF(b)) {
art_leaf* bl = (art_leaf *) LEAF_RAW(b);
b_value = bl->max_score;
} else {
b_value = b->max_score;
}
return a_value > b_value;
@ -118,8 +110,6 @@ static art_node* alloc_node(uint8_t type) {
abort();
}
n->type = type;
n->max_score = 0;
n->max_token_count = 0;
return n;
}
@ -448,8 +438,6 @@ static uint32_t longest_common_prefix(art_leaf *l1, art_leaf *l2, int depth) {
}
static void copy_header(art_node *dest, art_node *src) {
dest->max_score = src->max_score;
dest->max_token_count = src->max_token_count;
dest->num_children = src->num_children;
dest->partial_len = src->partial_len;
memcpy(dest->partial, src->partial, min(MAX_PREFIX_LEN, src->partial_len));
@ -457,8 +445,6 @@ static void copy_header(art_node *dest, art_node *src) {
static void add_child256(art_node256 *n, art_node **ref, unsigned char c, void *child) {
(void)ref;
n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score);
n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength());
n->n.num_children++;
n->children[c] = (art_node *) child;
}
@ -467,8 +453,6 @@ static void add_child48(art_node48 *n, art_node **ref, unsigned char c, void *ch
if (n->n.num_children < 48) {
int pos = 0;
while (n->children[pos]) pos++;
n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score);
n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength());
n->children[pos] = (art_node *) child;
n->keys[c] = pos + 1;
n->n.num_children++;
@ -509,8 +493,6 @@ static void add_child16(art_node16 *n, art_node **ref, unsigned char c, void *ch
idx = n->n.num_children;
// Set the child
n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score);
n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength());
n->keys[idx] = c;
n->children[idx] = (art_node *) child;
n->n.num_children++;
@ -543,12 +525,6 @@ static void add_child4(art_node4 *n, art_node **ref, unsigned char c, void *chil
memmove(n->children+idx+1, n->children+idx,
(n->n.num_children - idx)*sizeof(void*));
int32_t child_max_score = IS_LEAF(child) ? ((art_leaf *) LEAF_RAW(child))->max_score : ((art_node *) child)->max_score;
uint32_t child_token_count = IS_LEAF(child) ? ((art_leaf *) LEAF_RAW(child))->values->ids.getLength() : ((art_node *) child)->max_token_count;
n->n.max_score = MAX(n->n.max_score, child_max_score);
n->n.max_token_count = MAX(n->n.max_token_count, child_token_count);
n->keys[idx] = c;
n->children[idx] = (art_node *) child;
n->n.num_children++;
@ -648,9 +624,6 @@ static void* recursive_insert(art_node *n, art_node **ref, const unsigned char *
return NULL;
}
n->max_score = MAX(n->max_score, document->score);
n->max_token_count = MAX(n->max_token_count, num_hits);
// Check if given node has a prefix
if (n->partial_len) {
// Determine if the prefixes differ, since we need to split
@ -1384,12 +1357,6 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len,
art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes);
}
if(token_order == FREQUENCY) {
std::sort(nodes.begin(), nodes.end(), compare_art_node_frequency);
} else {
std::sort(nodes.begin(), nodes.end(), compare_art_node_score);
}
//long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count();
//!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size();

View File

@ -246,9 +246,10 @@ Option<Collection*> CollectionManager::create_collection(const std::string name,
fields_json.push_back(field_val);
if(field.name == default_sorting_field && !(field.type == field_types::INT32 ||
field.type == field_types::INT64 ||
field.type == field_types::FLOAT)) {
return Option<Collection*>(400, "Default sorting field `" + default_sorting_field +
"` must be of type int32 or float.");
"` must be a single valued numerical field.");
}
if(field.name == default_sorting_field) {

View File

@ -119,7 +119,7 @@ TEST_F(CollectionSortingTest, SortingOrder) {
}
TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
// Default sorting field must be int32 or float
// Default sorting field must be a numerical field
std::vector<field> fields = {field("name", field_types::STRING, false),
field("tags", field_types::STRING_ARRAY, true),
field("age", field_types::INT32, false),
@ -129,16 +129,11 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
Option<Collection*> collection_op = collectionManager.create_collection("sample_collection", fields, "name");
EXPECT_FALSE(collection_op.ok());
EXPECT_EQ("Default sorting field `name` must be of type int32 or float.", collection_op.error());
EXPECT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error());
collectionManager.drop_collection("sample_collection");
// Default sorting field must exist as a field in schema
fields = {field("name", field_types::STRING, false),
field("tags", field_types::STRING_ARRAY, true),
field("age", field_types::INT32, false),
field("average", field_types::INT32, false) };
sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") };
collection_op = collectionManager.create_collection("sample_collection", fields, "NOT-DEFINED");
EXPECT_FALSE(collection_op.ok());
@ -146,6 +141,58 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) {
collectionManager.drop_collection("sample_collection");
}
// Checks that a collection can be created with an INT64 field ("points") as its
// default sorting field, and that searches sorted on that field return hits in
// the expected order.
TEST_F(CollectionSortingTest, Int64AsDefaultSortingField) {
Collection *coll_mul_fields;
// Fixture documents are read one JSON document per line.
std::ifstream infile(std::string(ROOT_DIR)+"test/multi_field_documents.jsonl");
// "points" is declared as INT64 here; it is also passed below as the default sorting field.
std::vector<field> fields = {field("title", field_types::STRING, false),
field("starring", field_types::STRING, false),
field("points", field_types::INT64, false),
field("cast", field_types::STRING_ARRAY, false)};
// Reuse the collection if it already exists; otherwise create it with "points"
// as the default sorting field.
coll_mul_fields = collectionManager.get_collection("coll_mul_fields");
if(coll_mul_fields == nullptr) {
// NOTE(review): the Option returned by create_collection is unwrapped without
// checking ok() first — presumably acceptable in a test; confirm.
coll_mul_fields = collectionManager.create_collection("coll_mul_fields", fields, "points").get();
}
// Index every line of the fixture file as a document.
std::string json_line;
while (std::getline(infile, json_line)) {
coll_mul_fields->add(json_line);
}
infile.close();
// Search for "the" in the "title" field, sorted ascending on the int64 "points" field.
query_fields = {"title"};
std::vector<std::string> facets;
sort_fields = { sort_by("points", "ASC") };
nlohmann::json results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get();
ASSERT_EQ(10, results["hits"].size());
// Expected document ids, in the order the hits must be returned.
std::vector<std::string> ids = {"17", "13", "10", "4", "0", "1", "8", "6", "16", "11"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
// limiting results to just 5, "ASC" keyword must be case insensitive
sort_fields = { sort_by("points", "asc") };
// Same query as above, except the seventh argument is 5 instead of 15 —
// presumably the per-page result limit; verify against Collection::search.
results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 5, 1, FREQUENCY, false).get();
ASSERT_EQ(5, results["hits"].size());
// The expected ids are the first five of the previous expected ordering.
ids = {"17", "13", "10", "4", "0"};
for(size_t i = 0; i < results["hits"].size(); i++) {
nlohmann::json result = results["hits"].at(i);
std::string result_id = result["document"]["id"];
std::string id = ids.at(i);
ASSERT_STREQ(id.c_str(), result_id.c_str());
}
}
TEST_F(CollectionSortingTest, SortOnFloatFields) {
Collection *coll_float_fields;

View File

@ -986,7 +986,7 @@ TEST_F(CollectionTest, FilterOnNumericFields) {
// ensure that default_sorting_field is a non-array numerical field
auto coll_op = collectionManager.create_collection("coll_array_fields", fields, "years");
ASSERT_EQ(false, coll_op.ok());
ASSERT_STREQ("Default sorting field `years` must be of type int32 or float.", coll_op.error().c_str());
ASSERT_STREQ("Default sorting field `years` must be a single valued numerical field.", coll_op.error().c_str());
// let's try again properly
coll_op = collectionManager.create_collection("coll_array_fields", fields, "age");