From 0c186481a980710a6469cf76962d5acad16f4a3b Mon Sep 17 00:00:00 2001 From: kishorenc Date: Wed, 12 Aug 2020 18:16:27 +0530 Subject: [PATCH] Allow int64 to be used as a default sorting field. --- TODO.md | 4 +-- include/art.h | 2 -- src/art.cpp | 33 ----------------- src/collection_manager.cpp | 3 +- test/collection_sorting_test.cpp | 61 ++++++++++++++++++++++++++++---- test/collection_test.cpp | 2 +- 6 files changed, 58 insertions(+), 47 deletions(-) diff --git a/TODO.md b/TODO.md index 0c5afdb8..d4553c67 100644 --- a/TODO.md +++ b/TODO.md @@ -1,10 +1,8 @@ # Typesense: TODO -## Pre-alpha - a) ~~Fix memory ratio (decreasing with indexing)~~ b) ~~Speed up wildcard searches further~~ -c) Allow int64 in default sorting field +c) ~~Allow int64 in default sorting field~~ d) Use connection timeout for CURL rather than request timeout e) Update role to set max memory ration at 0.80 f) Async import diff --git a/include/art.h b/include/art.h index 12068ec9..1d3a286d 100644 --- a/include/art.h +++ b/include/art.h @@ -44,8 +44,6 @@ typedef struct { uint8_t num_children; uint8_t partial_len; unsigned char partial[MAX_PREFIX_LEN]; - int32_t max_score; - uint32_t max_token_count; } art_node; /** diff --git a/src/art.cpp b/src/art.cpp index 1da84cf1..0cd2eed3 100644 --- a/src/art.cpp +++ b/src/art.cpp @@ -53,15 +53,11 @@ bool compare_art_node_frequency(const art_node *a, const art_node *b) { if(IS_LEAF(a)) { art_leaf* al = (art_leaf *) LEAF_RAW(a); a_value = al->values->ids.getLength(); - } else { - a_value = a->max_token_count; } if(IS_LEAF(b)) { art_leaf* bl = (art_leaf *) LEAF_RAW(b); b_value = bl->values->ids.getLength(); - } else { - b_value = b->max_token_count; } return a_value > b_value; @@ -73,15 +69,11 @@ bool compare_art_node_score(const art_node* a, const art_node* b) { if(IS_LEAF(a)) { art_leaf* al = (art_leaf *) LEAF_RAW(a); a_value = al->max_score; - } else { - a_value = a->max_score; } if(IS_LEAF(b)) { art_leaf* bl = (art_leaf *) LEAF_RAW(b); b_value = bl->max_score; - } else { - b_value = b->max_score; } return a_value > b_value; @@ -118,8 +110,6 @@ static art_node* alloc_node(uint8_t type) { abort(); } n->type = type; - n->max_score = 0; - n->max_token_count = 0; return n; } @@ -448,8 +438,6 @@ static uint32_t longest_common_prefix(art_leaf *l1, art_leaf *l2, int depth) { } static void copy_header(art_node *dest, art_node *src) { - dest->max_score = src->max_score; - dest->max_token_count = src->max_token_count; dest->num_children = src->num_children; dest->partial_len = src->partial_len; memcpy(dest->partial, src->partial, min(MAX_PREFIX_LEN, src->partial_len)); @@ -457,8 +445,6 @@ static void copy_header(art_node *dest, art_node *src) { static void add_child256(art_node256 *n, art_node **ref, unsigned char c, void *child) { (void)ref; - n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score); - n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength()); n->n.num_children++; n->children[c] = (art_node *) child; } @@ -467,8 +453,6 @@ static void add_child48(art_node48 *n, art_node **ref, unsigned char c, void *ch if (n->n.num_children < 48) { int pos = 0; while (n->children[pos]) pos++; - n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score); - n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength()); n->children[pos] = (art_node *) child; n->keys[c] = pos + 1; n->n.num_children++; @@ -509,8 +493,6 @@ static void add_child16(art_node16 *n, art_node **ref, unsigned char c, void *ch idx = n->n.num_children; // Set the child - n->n.max_score = MAX(n->n.max_score, ((art_leaf *) LEAF_RAW(child))->max_score); - n->n.max_token_count = MAX(n->n.max_token_count, ((art_leaf *) LEAF_RAW(child))->values->ids.getLength()); n->keys[idx] = c; n->children[idx] = (art_node *) child; n->n.num_children++; @@ -543,12 +525,6 @@ static void add_child4(art_node4 *n, art_node **ref, unsigned char c, void *chil memmove(n->children+idx+1, n->children+idx, (n->n.num_children - idx)*sizeof(void*)); - int32_t child_max_score = IS_LEAF(child) ? ((art_leaf *) LEAF_RAW(child))->max_score : ((art_node *) child)->max_score; - uint32_t child_token_count = IS_LEAF(child) ? ((art_leaf *) LEAF_RAW(child))->values->ids.getLength() : ((art_node *) child)->max_token_count; - - n->n.max_score = MAX(n->n.max_score, child_max_score); - n->n.max_token_count = MAX(n->n.max_token_count, child_token_count); - n->keys[idx] = c; n->children[idx] = (art_node *) child; n->n.num_children++; @@ -648,9 +624,6 @@ static void* recursive_insert(art_node *n, art_node **ref, const unsigned char * return NULL; } - n->max_score = MAX(n->max_score, document->score); - n->max_token_count = MAX(n->max_token_count, num_hits); - // Check if given node has a prefix if (n->partial_len) { // Determine if the prefixes differ, since we need to split @@ -1384,12 +1357,6 @@ int art_fuzzy_search(art_tree *t, const unsigned char *term, const int term_len, art_fuzzy_recurse(0, 0, t->root, -1, term, term_len, irow, jrow, min_cost, max_cost, prefix, nodes); } - if(token_order == FREQUENCY) { - std::sort(nodes.begin(), nodes.end(), compare_art_node_frequency); - } else { - std::sort(nodes.begin(), nodes.end(), compare_art_node_score); - } - //long long int time_micro = microseconds(std::chrono::high_resolution_clock::now() - begin).count(); //!LOG(INFO) << "Time taken for fuzz: " << time_micro << "us, size of nodes: " << nodes.size(); diff --git a/src/collection_manager.cpp b/src/collection_manager.cpp index d929671f..33a27878 100644 --- a/src/collection_manager.cpp +++ b/src/collection_manager.cpp @@ -246,9 +246,10 @@ Option CollectionManager::create_collection(const std::string name, fields_json.push_back(field_val); if(field.name == default_sorting_field && !(field.type == field_types::INT32 || + field.type == field_types::INT64 || field.type == field_types::FLOAT)) { return Option(400, "Default sorting field `" + default_sorting_field + - "` must be of type int32 or float."); + "` must be a single valued numerical field."); } if(field.name == default_sorting_field) { diff --git a/test/collection_sorting_test.cpp b/test/collection_sorting_test.cpp index 900612f1..1a5901a5 100644 --- a/test/collection_sorting_test.cpp +++ b/test/collection_sorting_test.cpp @@ -119,7 +119,7 @@ TEST_F(CollectionSortingTest, SortingOrder) { } TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) { - // Default sorting field must be int32 or float + // Default sorting field must be a numerical field std::vector fields = {field("name", field_types::STRING, false), field("tags", field_types::STRING_ARRAY, true), field("age", field_types::INT32, false), @@ -129,16 +129,11 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) { Option collection_op = collectionManager.create_collection("sample_collection", fields, "name"); EXPECT_FALSE(collection_op.ok()); - EXPECT_EQ("Default sorting field `name` must be of type int32 or float.", collection_op.error()); + EXPECT_EQ("Default sorting field `name` must be a single valued numerical field.", collection_op.error()); collectionManager.drop_collection("sample_collection"); // Default sorting field must exist as a field in schema - fields = {field("name", field_types::STRING, false), - field("tags", field_types::STRING_ARRAY, true), - field("age", field_types::INT32, false), - field("average", field_types::INT32, false) }; - sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") }; collection_op = collectionManager.create_collection("sample_collection", fields, "NOT-DEFINED"); EXPECT_FALSE(collection_op.ok()); @@ -146,6 +141,58 @@ TEST_F(CollectionSortingTest, DefaultSortingFieldValidations) { collectionManager.drop_collection("sample_collection"); } +TEST_F(CollectionSortingTest, Int64AsDefaultSortingField) { + Collection *coll_mul_fields; + + std::ifstream infile(std::string(ROOT_DIR)+"test/multi_field_documents.jsonl"); + std::vector fields = {field("title", field_types::STRING, false), + field("starring", field_types::STRING, false), + field("points", field_types::INT64, false), + field("cast", field_types::STRING_ARRAY, false)}; + + coll_mul_fields = collectionManager.get_collection("coll_mul_fields"); + if(coll_mul_fields == nullptr) { + coll_mul_fields = collectionManager.create_collection("coll_mul_fields", fields, "points").get(); + } + + std::string json_line; + + while (std::getline(infile, json_line)) { + coll_mul_fields->add(json_line); + } + + infile.close(); + + query_fields = {"title"}; + std::vector facets; + sort_fields = { sort_by("points", "ASC") }; + nlohmann::json results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get(); + ASSERT_EQ(10, results["hits"].size()); + + std::vector ids = {"17", "13", "10", "4", "0", "1", "8", "6", "16", "11"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } + + // limiting results to just 5, "ASC" keyword must be case insensitive + sort_fields = { sort_by("points", "asc") }; + results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 5, 1, FREQUENCY, false).get(); + ASSERT_EQ(5, results["hits"].size()); + + ids = {"17", "13", "10", "4", "0"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["document"]["id"]; + std::string id = ids.at(i); + ASSERT_STREQ(id.c_str(), result_id.c_str()); + } +} + TEST_F(CollectionSortingTest, SortOnFloatFields) { Collection *coll_float_fields; diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 29ecd996..6ef3b73f 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -986,7 +986,7 @@ TEST_F(CollectionTest, FilterOnNumericFields) { // ensure that default_sorting_field is a non-array numerical field auto coll_op = collectionManager.create_collection("coll_array_fields", fields, "years"); ASSERT_EQ(false, coll_op.ok()); - ASSERT_STREQ("Default sorting field `years` must be of type int32 or float.", coll_op.error().c_str()); + ASSERT_STREQ("Default sorting field `years` must be a single valued numerical field.", coll_op.error().c_str()); // let's try again properly coll_op = collectionManager.create_collection("coll_array_fields", fields, "age");