From d351523655b0de261084393cc89525907ef291f5 Mon Sep 17 00:00:00 2001 From: Kishore Nallan Date: Sun, 20 Aug 2017 21:06:49 +0530 Subject: [PATCH] Allow results to be sorted on a float field. --- include/collection.h | 13 ++-- include/field.h | 6 +- include/number.h | 80 ++++++++++++++++++++++ include/person.h | 85 ++++++++++++++++++++++++ include/topster.h | 45 +------------ src/api.cpp | 4 +- src/collection.cpp | 81 ++++++++++++++++------- test/collection_manager_test.cpp | 4 +- test/collection_test.cpp | 110 +++++++++++++++++++++++++------ test/float_documents.jsonl | 7 ++ test/match_score_test.cpp | 6 ++ test/topster_test.cpp | 24 +++---- 12 files changed, 351 insertions(+), 114 deletions(-) create mode 100644 include/number.h create mode 100644 include/person.h create mode 100644 test/float_documents.jsonl diff --git a/include/collection.h b/include/collection.h index 0c56d6a5..fc2a5425 100644 --- a/include/collection.h +++ b/include/collection.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -52,7 +53,7 @@ private: spp::sparse_hash_map facet_schema; - std::vector sort_fields; + spp::sparse_hash_map sort_schema; Store* store; @@ -60,7 +61,7 @@ private: spp::sparse_hash_map facet_index; - spp::sparse_hash_map*> sort_index; + spp::sparse_hash_map*> sort_index; std::string token_ranking_field; @@ -84,14 +85,14 @@ private: size_t result_index, std::vector> &token_positions) const; void search_field(std::string & query, const std::string & field, uint32_t *filter_ids, size_t filter_ids_length, - std::vector & facets, const std::vector & sort_fields, + std::vector & facets, const std::vector & sort_fields, const int num_typos, const size_t num_results, std::vector> & searched_queries, int & searched_queries_index, Topster<100> & topster, uint32_t** all_result_ids, size_t & all_result_ids_len, const token_ordering token_order = FREQUENCY, const bool prefix = false); void search_candidates(uint32_t* filter_ids, size_t filter_ids_length, std::vector & facets, - const std::vector & sort_fields, int & candidate_rank, + const std::vector & sort_fields, int & candidate_rank, std::vector> & token_to_candidates, std::vector> & searched_queries, Topster<100> & topster, size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len, @@ -155,7 +156,7 @@ public: Option search(std::string query, const std::vector search_fields, const std::string & simple_filter_query, const std::vector & facet_fields, - const std::vector & sort_fields, const int num_typos, + const std::vector & sort_fields, const int num_typos, const size_t per_page = 10, const size_t page = 1, const token_ordering token_order = FREQUENCY, const bool prefix = false); @@ -163,7 +164,7 @@ public: Option remove(const std::string & id); - void score_results(const std::vector & sort_fields, const int & query_index, const int & candidate_rank, + void score_results(const std::vector & sort_fields, const int & query_index, const int & candidate_rank, Topster<100> &topster, const std::vector & query_suggestion, const uint32_t *result_ids, const size_t result_size) const; diff --git a/include/field.h b/include/field.h index 0b0e4bd3..7036eebc 100644 --- a/include/field.h +++ b/include/field.h @@ -81,15 +81,15 @@ namespace sort_field_const { static const std::string desc = "DESC"; } -struct sort_field { +struct sort_by { std::string name; std::string order; - sort_field(const std::string & name, const std::string & order): name(name), order(order) { + sort_by(const std::string & name, const std::string & order): name(name), order(order) { } - sort_field& operator=(sort_field other) { + sort_by& operator=(sort_by other) { name = other.name; order = other.order; return *this; diff --git a/include/number.h b/include/number.h new file mode 100644 index 00000000..af930a51 --- /dev/null +++ b/include/number.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +struct number_t { + bool is_float; + union { + float floatval; + int64_t intval; + }; + + number_t(): intval(0), is_float(false) { + + } + + number_t(bool is_float, float floatval): floatval(floatval), is_float(is_float) { + + } + + number_t(bool is_float, int64_t intval): intval(intval), is_float(is_float) { + + } + + + number_t(float val): floatval(val), is_float(true) { + + } + + number_t(int64_t val): intval(val), is_float(false) { + + } + + inline void operator = (const float & val) { + floatval = val; + is_float = true; + } + + inline void operator = (const int64_t & val) { + intval = val; + is_float = false; + } + + inline bool operator == (const number_t & rhs) const { + if(is_float) { + return floatval == rhs.floatval; + } + return intval == rhs.intval; + } + + inline bool operator < (const number_t & rhs) const { + if(is_float) { + return floatval < rhs.floatval; + } + return intval < rhs.intval; + } + + inline bool operator > (const number_t & rhs) const { + if(is_float) { + return floatval > rhs.floatval; + } + return intval > rhs.intval; + } + + inline number_t operator * (const number_t & rhs) const { + if(is_float) { + return number_t((float)(floatval * rhs.floatval)); + } + return number_t((int64_t)(intval * rhs.intval)); + } + + inline number_t operator-() { + if(is_float) { + floatval = -floatval; + } else { + intval = -intval; + } + + return *this; + } +}; \ No newline at end of file diff --git a/include/person.h b/include/person.h new file mode 100644 index 00000000..bd87e3c8 --- /dev/null +++ b/include/person.h @@ -0,0 +1,85 @@ +#pragma once + +struct person { + bool is_float; + union { + int64_t intval; + float floatval; + }; + + person(): intval(0), is_float(false) { + + } + + person(bool is_float, float floatval): floatval(floatval), is_float(is_float) { + + } + + person(bool is_float, int64_t intval): intval(intval), is_float(is_float) { + + } + + + person(float val): floatval(val), is_float(true) { + + } + + person(int64_t val): intval(val), is_float(false) { + + } + + inline void operator = (const float & val) { + floatval = val; + is_float = true; + } + + inline void operator = (const int64_t & val) { + intval = val; + is_float = false; + } + + inline bool operator == (const person & rhs) const { + if(is_float) { + return floatval == rhs.floatval; + } + return intval == rhs.intval; + } + + inline bool operator < (const person & rhs) const { + if(is_float) { + return floatval < rhs.floatval; + } + return intval < rhs.intval; + } + + inline bool operator > (const person & rhs) const { + if(is_float) { + return floatval > rhs.floatval; + } + return intval > rhs.intval; + } + + inline person operator * (const person & rhs) const { + if(is_float) { + return person(floatval * rhs.floatval); + } + return person(intval * rhs.intval); + } +}; + +namespace std +{ +// inject specialization of std::hash for Person into namespace std +// ---------------------------------------------------------------- + template<> + struct hash + { + std::size_t operator()(person const &p) const + { + std::size_t seed = 0; + spp::hash_combine(seed, p.is_float); + spp::hash_combine(seed, p.intval); + return seed; + } + }; +} \ No newline at end of file diff --git a/include/topster.h b/include/topster.h index 36f6df3b..347d1ca4 100644 --- a/include/topster.h +++ b/include/topster.h @@ -6,56 +6,13 @@ #include #include #include +#include /* * Remembers the max-K elements seen so far using a min-heap */ template struct Topster { - struct number_t { - bool is_float; - union { - float floatval; - int64_t intval; - }; - - number_t(): intval(0), is_float(false) { - - } - - number_t(float val): floatval(val), is_float(true) { - - } - - number_t(int64_t val): intval(val), is_float(false) { - - } - - inline void operator = (const float & val) { - floatval = val; - is_float = true; - } - - inline void operator = (const int64_t & val) { - intval = val; - is_float = false; - } - - inline bool operator < (const number_t & rhs) const { - if(is_float) { - return floatval < rhs.floatval; - } - return intval < rhs.intval; - } - - inline bool operator > (const number_t & rhs) const { - if(is_float) { - return floatval > rhs.floatval; - } - return intval > rhs.intval; - } - }; - struct KV { uint16_t query_index; uint64_t key; diff --git a/src/api.cpp b/src/api.cpp index 4d62c6e2..695c042a 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -188,7 +188,7 @@ void get_search(http_req & req, http_res & res) { std::vector facet_fields; StringUtils::split(req.params[FACET_BY], facet_fields, "&&"); - std::vector sort_fields; + std::vector sort_fields; if(req.params.count(SORT_BY) != 0) { std::vector sort_field_strs; StringUtils::split(req.params[SORT_BY], sort_field_strs, ","); @@ -206,7 +206,7 @@ void get_search(http_req & req, http_res & res) { } StringUtils::toupper(expression_parts[1]); - sort_fields.push_back(sort_field(expression_parts[0], expression_parts[1])); + sort_fields.push_back(sort_by(expression_parts[0], expression_parts[1])); } } diff --git a/src/collection.cpp b/src/collection.cpp index 1d6a5840..b5aa7330 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -11,7 +11,7 @@ Collection::Collection(const std::string name, const uint32_t collection_id, con const std::vector &search_fields, const std::vector & facet_fields, const std::vector & sort_fields, const std::string token_ranking_field): name(name), collection_id(collection_id), next_seq_id(next_seq_id), store(store), - sort_fields(sort_fields), token_ranking_field(token_ranking_field) { + token_ranking_field(token_ranking_field) { for(const field& field: search_fields) { art_tree *t = new art_tree; @@ -27,8 +27,9 @@ Collection::Collection(const std::string name, const uint32_t collection_id, con } for(const field & sort_field: sort_fields) { - spp::sparse_hash_map * doc_to_score = new spp::sparse_hash_map(); + spp::sparse_hash_map * doc_to_score = new spp::sparse_hash_map(); sort_index.emplace(sort_field.name, doc_to_score); + sort_schema.emplace(sort_field.name, sort_field); } num_documents = 0; @@ -218,7 +219,9 @@ Option Collection::index_in_memory(const nlohmann::json &document, uin } } - for(const field & sort_field: sort_fields) { + for(const std::pair & field_pair: sort_schema) { + const field & sort_field = field_pair.second; + if(document.count(sort_field.name) == 0) { return Option<>(400, "Field `" + sort_field.name + "` has been declared as a sort field in the schema, " "but is not found in the document."); @@ -228,8 +231,13 @@ Option Collection::index_in_memory(const nlohmann::json &document, uin return Option<>(400, "Sort field `" + sort_field.name + "` must be a number."); } - spp::sparse_hash_map *doc_to_score = sort_index.at(sort_field.name); - doc_to_score->emplace(seq_id, document[sort_field.name].get()); + spp::sparse_hash_map *doc_to_score = sort_index.at(sort_field.name); + + if(document[sort_field.name].is_number_integer()) { + doc_to_score->emplace(seq_id, document[sort_field.name].get()); + } else { + doc_to_score->emplace(seq_id, document[sort_field.name].get()); + } } num_documents += 1; @@ -401,7 +409,7 @@ void Collection::do_facets(std::vector & facets, uint32_t* result_ids, si } void Collection::search_candidates(uint32_t* filter_ids, size_t filter_ids_length, std::vector & facets, - const std::vector & sort_fields, int & candidate_rank, + const std::vector & sort_fields, int & candidate_rank, std::vector> & token_to_candidates, std::vector> & searched_queries, Topster<100> & topster, size_t & total_results, uint32_t** all_result_ids, size_t & all_result_ids_len, @@ -649,7 +657,7 @@ Option Collection::do_filtering(uint32_t** filter_ids_out, const std:: Option Collection::search(std::string query, const std::vector search_fields, const std::string & simple_filter_query, const std::vector & facet_fields, - const std::vector & sort_fields, const int num_typos, + const std::vector & sort_fields, const int num_typos, const size_t per_page, const size_t page, const token_ordering token_order, const bool prefix) { nlohmann::json result = nlohmann::json::object(); @@ -680,9 +688,9 @@ Option Collection::search(std::string query, const std::vector sort_fields_std; + std::vector sort_fields_std; - for(const sort_field & _sort_field: sort_fields) { + for(const sort_by & _sort_field: sort_fields) { if(sort_index.count(_sort_field.name) == 0) { std::string error = "Could not find a sort field named `" + _sort_field.name + "` in the schema."; return Option(400, error); @@ -888,7 +896,7 @@ Option Collection::search(std::string query, const std::vector & facets, const std::vector & sort_fields, const int num_typos, + std::vector & facets, const std::vector & sort_fields, const int num_typos, const size_t num_results, std::vector> & searched_queries, int & searched_queries_index, Topster<100> &topster, uint32_t** all_result_ids, size_t & all_result_ids_len, const token_ordering token_order, const bool prefix) { @@ -1048,7 +1056,7 @@ void Collection::log_leaves(const int cost, const std::string &token, const std: } } -void Collection::score_results(const std::vector & sort_fields, const int & query_index, const int & candidate_rank, +void Collection::score_results(const std::vector & sort_fields, const int & query_index, const int & candidate_rank, Topster<100> & topster, const std::vector &query_suggestion, const uint32_t *result_ids, const size_t result_size) const { @@ -1061,25 +1069,43 @@ void Collection::score_results(const std::vector & sort_fields, cons leaf_to_indices.emplace(token_leaf, indices); } - spp::sparse_hash_map * primary_rank_scores = nullptr; - spp::sparse_hash_map * secondary_rank_scores = nullptr; + spp::sparse_hash_map * primary_rank_scores = nullptr; + spp::sparse_hash_map * secondary_rank_scores = nullptr; // Used for asc/desc ordering. NOTE: Topster keeps biggest keys (i.e. it's desc in nature) - int64_t primary_rank_factor = 1; - int64_t secondary_rank_factor = 1; + number_t primary_rank_factor; + number_t secondary_rank_factor; if(sort_fields.size() > 0) { // assumed that rank field exists in the index - checked earlier in the chain primary_rank_scores = sort_index.at(sort_fields[0].name); + + // initialize primary_rank_factor + field sort_field = sort_schema.at(sort_fields[0].name); + if(sort_field.is_integer()) { + primary_rank_factor = ((int64_t) 1); + } else { + primary_rank_factor = ((float) 1); + } + if(sort_fields[0].order == sort_field_const::asc) { - primary_rank_factor = -1; + primary_rank_factor = -primary_rank_factor; } } if(sort_fields.size() > 1) { secondary_rank_scores = sort_index.at(sort_fields[1].name); + + // initialize secondary_rank_factor + field sort_field = sort_schema.at(sort_fields[1].name); + if(sort_field.is_integer()) { + secondary_rank_factor = ((int64_t) 1); + } else { + secondary_rank_factor = ((float) 1); + } + if(sort_fields[1].order == sort_field_const::asc) { - secondary_rank_factor = -1; + secondary_rank_factor = -secondary_rank_factor; } } @@ -1105,13 +1131,15 @@ void Collection::score_results(const std::vector & sort_fields, cons (candidate_rank_score << 8) + (MAX_SEARCH_TOKENS - mscore.distance); - int64_t primary_rank_score = (primary_rank_scores && primary_rank_scores->count(seq_id) > 0) ? - primary_rank_scores->at(seq_id) : 0; - int64_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ? - secondary_rank_scores->at(seq_id) : 0; + const int64_t default_score = 0; + number_t primary_rank_score = (primary_rank_scores && primary_rank_scores->count(seq_id) > 0) ? + primary_rank_scores->at(seq_id) : default_score; + number_t secondary_rank_score = (secondary_rank_scores && secondary_rank_scores->count(seq_id) > 0) ? + secondary_rank_scores->at(seq_id) : default_score; - topster.add(seq_id, query_index, match_score, primary_rank_factor * primary_rank_score, - secondary_rank_factor * secondary_rank_score); + const number_t & primary_rank_value = primary_rank_score * primary_rank_factor; + const number_t & secondary_rank_value = secondary_rank_score * secondary_rank_factor; + topster.add(seq_id, query_index, match_score, primary_rank_value, secondary_rank_value); /*std::cout << "candidate_rank_score: " << candidate_rank_score << ", words_present: " << mscore.words_present << ", match_score: " << match_score << ", primary_rank_score: " << primary_rank_score @@ -1399,7 +1427,12 @@ std::vector Collection::get_facet_fields() { } std::vector Collection::get_sort_fields() { - return sort_fields; + std::vector sort_fields_copy; + for(auto it = sort_schema.begin(); it != sort_schema.end(); ++it) { + sort_fields_copy.push_back(it->second); + } + + return sort_fields_copy; } spp::sparse_hash_map Collection::get_schema() { diff --git a/test/collection_manager_test.cpp b/test/collection_manager_test.cpp index e1a9be7f..777bb592 100644 --- a/test/collection_manager_test.cpp +++ b/test/collection_manager_test.cpp @@ -14,7 +14,7 @@ protected: std::vector facet_fields; std::vector sort_fields_index; - std::vector sort_fields; + std::vector sort_fields; void setupCollection() { std::string state_dir_path = "/tmp/typesense_test/coll_manager_test_db"; @@ -26,7 +26,7 @@ protected: search_fields = {field("title", field_types::STRING), field("starring", field_types::STRING)}; facet_fields = {field("starring", field_types::STRING)}; - sort_fields = { sort_field("points", "DESC") }; + sort_fields = { sort_by("points", "DESC") }; sort_fields_index = { field("points", "INT32") }; collection1 = collectionManager.create_collection("collection1", search_fields, facet_fields, diff --git a/test/collection_test.cpp b/test/collection_test.cpp index 7bfd82ba..70d28449 100644 --- a/test/collection_test.cpp +++ b/test/collection_test.cpp @@ -5,6 +5,8 @@ #include #include #include "collection.h" +#include "person.h" +#include "number.h" class CollectionTest : public ::testing::Test { protected: @@ -14,7 +16,7 @@ protected: CollectionManager & collectionManager = CollectionManager::get_instance(); std::vector facet_fields; std::vector sort_fields_index; - std::vector sort_fields; + std::vector sort_fields; void setupCollection() { std::string state_dir_path = "/tmp/typesense_test/collection"; @@ -29,7 +31,7 @@ protected: query_fields = {"title"}; facet_fields = { }; - sort_fields = { sort_field("points", "DESC") }; + sort_fields = { sort_by("points", "DESC") }; sort_fields_index = { field("points", "INT32") }; collection = collectionManager.get_collection("collection"); @@ -94,7 +96,7 @@ TEST_F(CollectionTest, ExactSearchShouldBeStable) { } // check ASC sorting - std::vector sort_fields_asc = { sort_field("points", "ASC") }; + std::vector sort_fields_asc = { sort_by("points", "ASC") }; results = collection->search("the", query_fields, "", facets, sort_fields_asc, 0, 10).get(); ASSERT_EQ(7, results["hits"].size()); @@ -135,7 +137,7 @@ TEST_F(CollectionTest, ExactPhraseSearch) { } // Check ASC sort order - std::vector sort_fields_asc = { sort_field("points", "ASC") }; + std::vector sort_fields_asc = { sort_by("points", "ASC") }; results = collection->search("rocket launch", query_fields, "", facets, sort_fields_asc, 0, 10).get(); ASSERT_EQ(5, results["hits"].size()); ASSERT_EQ(5, results["found"].get()); @@ -495,7 +497,7 @@ TEST_F(CollectionTest, FilterOnNumericFields) { std::vector fields = {field("name", field_types::STRING), field("age", field_types::INT32), field("years", field_types::INT32_ARRAY), field("timestamps", field_types::INT64_ARRAY)}; - std::vector sort_fields = { sort_field("age", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC") }; std::vector sort_fields_index = { field("age", "INT32") }; coll_array_fields = collectionManager.get_collection("coll_array_fields"); @@ -632,8 +634,8 @@ TEST_F(CollectionTest, FilterOnFloatFields) { field("top_3", field_types::FLOAT_ARRAY), field("rating", field_types::FLOAT)}; std::vector sort_fields_index = { field("rating", "FLOAT") }; - std::vector sort_fields_desc = { sort_field("rating", "DESC") }; - std::vector sort_fields_asc = { sort_field("rating", "ASC") }; + std::vector sort_fields_desc = { sort_by("rating", "DESC") }; + std::vector sort_fields_asc = { sort_by("rating", "ASC") }; coll_array_fields = collectionManager.get_collection("coll_array_fields"); if(coll_array_fields == nullptr) { @@ -673,7 +675,7 @@ TEST_F(CollectionTest, FilterOnFloatFields) { nlohmann::json result = results["hits"].at(i); std::string result_id = result["id"]; std::string id = ids.at(i); - ASSERT_STREQ(id.c_str(), result_id.c_str()); + ASSERT_STREQ(id.c_str(), result_id.c_str()); //? } // Searching on a float field, sorted desc by rating @@ -761,6 +763,72 @@ TEST_F(CollectionTest, FilterOnFloatFields) { collectionManager.drop_collection("coll_array_fields"); } +TEST_F(CollectionTest, SortOnFloatFields) { + Collection *coll_float_fields; + + std::ifstream infile(std::string(ROOT_DIR)+"test/float_documents.jsonl"); + std::vector fields = {field("title", field_types::STRING), field("score", field_types::FLOAT)}; + std::vector sort_fields_index = { field("score", "FLOAT"), field("average", "FLOAT") }; + std::vector sort_fields_desc = { sort_by("score", "DESC"), sort_by("average", "DESC") }; + + coll_float_fields = collectionManager.get_collection("coll_float_fields"); + if(coll_float_fields == nullptr) { + coll_float_fields = collectionManager.create_collection("coll_float_fields", fields, facet_fields, sort_fields_index); + } + + std::string json_line; + + while (std::getline(infile, json_line)) { + coll_float_fields->add(json_line); + } + + infile.close(); + + query_fields = {"title"}; + std::vector facets; + nlohmann::json results = coll_float_fields->search("Jeremy", query_fields, "", facets, sort_fields_desc, 0, 10, 1, FREQUENCY, false).get(); + ASSERT_EQ(7, results["hits"].size()); + + std::vector ids = {"2", "0", "3", "1", "5", "4", "6"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["id"]; + std::string id = ids.at(i); + EXPECT_STREQ(id.c_str(), result_id.c_str()); + } + + std::vector sort_fields_asc = { sort_by("score", "ASC"), sort_by("average", "ASC") }; + results = coll_float_fields->search("Jeremy", query_fields, "", facets, sort_fields_asc, 0, 10, 1, FREQUENCY, false).get(); + ASSERT_EQ(7, results["hits"].size()); + + ids = {"6", "4", "5", "1", "3", "0", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["id"]; + std::string id = ids.at(i); + EXPECT_STREQ(id.c_str(), result_id.c_str()); + } + + // second field by desc + + std::vector sort_fields_asc_desc = { sort_by("score", "ASC"), sort_by("average", "DESC") }; + results = coll_float_fields->search("Jeremy", query_fields, "", facets, sort_fields_asc_desc, 0, 10, 1, FREQUENCY, false).get(); + ASSERT_EQ(7, results["hits"].size()); + + ids = {"5", "4", "6", "1", "3", "0", "2"}; + + for(size_t i = 0; i < results["hits"].size(); i++) { + nlohmann::json result = results["hits"].at(i); + std::string result_id = result["id"]; + std::string id = ids.at(i); + EXPECT_STREQ(id.c_str(), result_id.c_str()); + } + + collectionManager.drop_collection("coll_float_fields"); +} + TEST_F(CollectionTest, FilterOnTextFields) { Collection *coll_array_fields; @@ -770,7 +838,7 @@ TEST_F(CollectionTest, FilterOnTextFields) { field("tags", field_types::STRING_ARRAY)}; std::vector sort_fields_index = { field("age", "INT32") }; - std::vector sort_fields = { sort_field("age", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC") }; coll_array_fields = collectionManager.get_collection("coll_array_fields"); if(coll_array_fields == nullptr) { @@ -842,7 +910,7 @@ TEST_F(CollectionTest, HandleBadlyFormedFilterQuery) { field("tags", field_types::STRING_ARRAY)}; std::vector sort_fields_index = { field("age", "INT32") }; - std::vector sort_fields = { sort_field("age", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC") }; coll_array_fields = collectionManager.get_collection("coll_array_fields"); if(coll_array_fields == nullptr) { @@ -898,7 +966,7 @@ TEST_F(CollectionTest, FacetCounts) { facet_fields = {field("tags", field_types::STRING_ARRAY), field("name", field_types::STRING)}; std::vector sort_fields_index = { field("age", "DESC") }; - std::vector sort_fields = { sort_field("age", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC") }; coll_array_fields = collectionManager.get_collection("coll_array_fields"); if(coll_array_fields == nullptr) { @@ -991,7 +1059,7 @@ TEST_F(CollectionTest, SortingOrder) { query_fields = {"title"}; std::vector facets; - sort_fields = { sort_field("points", "ASC") }; + sort_fields = { sort_by("points", "ASC") }; nlohmann::json results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get(); ASSERT_EQ(10, results["hits"].size()); @@ -1005,7 +1073,7 @@ TEST_F(CollectionTest, SortingOrder) { } // limiting results to just 5, "ASC" keyword must be case insensitive - sort_fields = { sort_field("points", "asc") }; + sort_fields = { sort_by("points", "asc") }; results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 5, 1, FREQUENCY, false).get(); ASSERT_EQ(5, results["hits"].size()); @@ -1020,7 +1088,7 @@ TEST_F(CollectionTest, SortingOrder) { // desc - sort_fields = { sort_field("points", "dEsc") }; + sort_fields = { sort_by("points", "dEsc") }; results = coll_mul_fields->search("the", query_fields, "", facets, sort_fields, 0, 15, 1, FREQUENCY, false).get(); ASSERT_EQ(10, results["hits"].size()); @@ -1062,7 +1130,7 @@ TEST_F(CollectionTest, SearchingWithMissingFields) { field("tags", field_types::STRING_ARRAY)}; facet_fields = {field("tags", field_types::STRING_ARRAY), field("name", field_types::STRING)}; std::vector sort_fields_index = { field("age", "DESC") }; - std::vector sort_fields = { sort_field("age", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC") }; coll_array_fields = collectionManager.get_collection("coll_array_fields"); if(coll_array_fields == nullptr) { @@ -1097,11 +1165,11 @@ TEST_F(CollectionTest, SearchingWithMissingFields) { ASSERT_STREQ("Could not find a facet field named `timestamps` in the schema.", res_op.error().c_str()); // when a rank field is not defined in the schema - res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_field("timestamps", "ASC") }, 0, 10); + res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_by("timestamps", "ASC") }, 0, 10); ASSERT_EQ(400, res_op.code()); ASSERT_STREQ("Could not find a sort field named `timestamps` in the schema.", res_op.error().c_str()); - res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_field("_rank", "ASC") }, 0, 10); + res_op = coll_array_fields->search("the", {"name"}, "", {}, { sort_by("_rank", "ASC") }, 0, 10); ASSERT_EQ(400, res_op.code()); ASSERT_STREQ("Could not find a sort field named `_rank` in the schema.", res_op.error().c_str()); @@ -1116,7 +1184,7 @@ TEST_F(CollectionTest, IndexingWithBadData) { facet_fields = {field("tags", field_types::STRING_ARRAY)}; std::vector sort_fields_index = { field("age", "INT32"), field("average", "INT32") }; - std::vector sort_fields = { sort_field("age", "DESC"), sort_field("average", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") }; sample_collection = collectionManager.get_collection("sample_collection"); if(sample_collection == nullptr) { @@ -1193,7 +1261,7 @@ TEST_F(CollectionTest, EmptyIndexShouldNotCrash) { facet_fields = {field("tags", field_types::STRING_ARRAY)}; std::vector sort_fields_index = { field("age", "INT32"), field("average", "INT32") }; - std::vector sort_fields = { sort_field("age", "DESC"), sort_field("average", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") }; empty_coll = collectionManager.get_collection("empty_coll"); if(empty_coll == nullptr) { @@ -1212,7 +1280,7 @@ TEST_F(CollectionTest, IdFieldShouldBeAString) { facet_fields = {field("tags", field_types::STRING_ARRAY)}; std::vector sort_fields_index = { field("age", "INT32"), field("average", "INT32") }; - std::vector sort_fields = { sort_field("age", "DESC"), sort_field("average", "DESC") }; + std::vector sort_fields = { sort_by("age", "DESC"), sort_by("average", "DESC") }; coll1 = collectionManager.get_collection("coll1"); if(coll1 == nullptr) { @@ -1241,7 +1309,7 @@ TEST_F(CollectionTest, DeletionOfADocument) { std::vector search_fields = {field("title", field_types::STRING)}; std::vector query_fields = {"title"}; std::vector facet_fields = { }; - std::vector sort_fields = { sort_field("points", "DESC") }; + std::vector sort_fields = { sort_by("points", "DESC") }; std::vector sort_fields_index = { field("points", "INT32") }; Collection *collection_for_del; diff --git a/test/float_documents.jsonl b/test/float_documents.jsonl new file mode 100644 index 00000000..65de0262 --- /dev/null +++ b/test/float_documents.jsonl @@ -0,0 +1,7 @@ +{"title": "Jeremy Howard", "score": 1.09, "average": 1.45} +{"title": "Jeremy Howard", "score": -9.998, "average": -2.408 } +{"title": "Jeremy Howard", "score": 7.812, "average": 0.001 } +{"title": "Jeremy Howard", "score": 0.0, "average": 11.533 } +{"title": "Jeremy Howard", "score": -9.999, "average": -11.38 } +{"title": "Jeremy Howard", "score": -9.999, "average": 19.38 } +{"title": "Jeremy Howard", "score": -9.999, "average": -21.38 } \ No newline at end of file diff --git a/test/match_score_test.cpp b/test/match_score_test.cpp index 12642cb9..bc527559 100644 --- a/test/match_score_test.cpp +++ b/test/match_score_test.cpp @@ -19,4 +19,10 @@ TEST(MatchScoreTest, ShouldPackTokenOffsets) { ASSERT_EQ(0, offset_diffs[1]); ASSERT_EQ(1, offset_diffs[2]); ASSERT_EQ(2, offset_diffs[3]); + + uint16_t min_token_offset3[1] = {123}; + MatchScore::pack_token_offsets(min_token_offset3, 1, 0, offset_diffs); + + ASSERT_EQ(1, offset_diffs[0]); + ASSERT_EQ(0, offset_diffs[1]); } \ No newline at end of file diff --git a/test/topster_test.cpp b/test/topster_test.cpp index 169d4ba5..8f4ed1d3 100644 --- a/test/topster_test.cpp +++ b/test/topster_test.cpp @@ -52,18 +52,18 @@ TEST(TopsterTest, StoreMaxFloatValuesWithoutRepetition) { float primary_attr; int64_t secondary_attr; } data[12] = { - {0, 1, 11, 20.04, 30}, - {0, 2, 4, 20, 30}, - {2, 3, 7, 20, 30}, - {0, 4, 11, 20.05, 30}, - {0, 4, 11, 20.05, 30}, - {1, 5, 9, 24.50, 34}, - {0, 6, 6, 20, 30}, - {2, 7, 6, 22, 30}, - {1, 8, 9, 24.50, 30}, - {1, 8, 9, 24.50, 30}, - {0, 9, 8, 24.50, 30}, - {3, 10, 5, 20, 30}, + {0, 1, 11, 1.09, 30}, + {0, 2, 11, -20, 30}, + {2, 3, 11, -20, 30}, + {0, 4, 11, 7.812, 30}, + {0, 4, 11, 7.812, 30}, + {1, 5, 11, 0.0, 34}, + {0, 6, 11, -22, 30}, + {2, 7, 11, -22, 30}, + {1, 8, 11, -9.998, 30}, + {1, 8, 11, -9.998, 30}, + {0, 9, 11, -9.999, 30}, + {3, 10, 11, -20, 30}, }; for(int i = 0; i < 12; i++) {